second commit

This commit is contained in:
Александр Геннадьевич Сальный
2022-10-15 21:01:12 +03:00
commit 7caeeaaff5
1329 changed files with 489315 additions and 0 deletions

View File

@@ -0,0 +1,92 @@
"""
This namespace represents the core functionality that has to be built-in
and deal with private internal data structures. Things in this namespace
are publicly available in either trio, trio.lowlevel, or trio.testing.
"""
import sys
from ._exceptions import (
TrioInternalError,
RunFinishedError,
WouldBlock,
Cancelled,
BusyResourceError,
ClosedResourceError,
BrokenResourceError,
EndOfChannel,
)
from ._multierror import MultiError
from ._ki import (
enable_ki_protection,
disable_ki_protection,
currently_ki_protected,
)
# Imports that always exist
from ._run import (
Task,
CancelScope,
run,
open_nursery,
checkpoint,
current_task,
current_effective_deadline,
checkpoint_if_cancelled,
TASK_STATUS_IGNORED,
current_statistics,
current_trio_token,
reschedule,
remove_instrument,
add_instrument,
current_clock,
current_root_task,
spawn_system_task,
current_time,
wait_all_tasks_blocked,
wait_readable,
wait_writable,
notify_closing,
Nursery,
start_guest_run,
)
# Has to come after _run to resolve a circular import
from ._traps import (
cancel_shielded_checkpoint,
Abort,
wait_task_rescheduled,
temporarily_detach_coroutine_object,
permanently_detach_coroutine_object,
reattach_detached_coroutine_object,
)
from ._entry_queue import TrioToken
from ._parking_lot import ParkingLot
from ._unbounded_queue import UnboundedQueue
from ._local import RunVar
from ._thread_cache import start_thread_soon
from ._mock_clock import MockClock
# Windows imports
if sys.platform == "win32":
from ._run import (
monitor_completion_key,
current_iocp,
register_with_iocp,
wait_overlapped,
write_overlapped,
readinto_overlapped,
)
# Kqueue imports
elif sys.platform != "linux" and sys.platform != "win32":
from ._run import current_kqueue, monitor_kevent, wait_kevent
del sys # It would be better to import sys as _sys, but mypy does not understand it

View File

@@ -0,0 +1,193 @@
import attr
import logging
import sys
import warnings
import weakref
from .._util import name_asyncgen
from . import _run
from .. import _core
# Used to log exceptions in async generator finalizers
ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors")
@attr.s(eq=False, slots=True)
class AsyncGenerators:
# Async generators are added to this set when first iterated. Any
# left after the main task exits will be closed before trio.run()
# returns. During most of the run, this is a WeakSet so GC works.
# During shutdown, when we're finalizing all the remaining
# asyncgens after the system nursery has been closed, it's a
# regular set so we don't have to deal with GC firing at
# unexpected times.
alive = attr.ib(factory=weakref.WeakSet)
# This collects async generators that get garbage collected during
# the one-tick window between the system nursery closing and the
# init task starting end-of-run asyncgen finalization.
trailing_needs_finalize = attr.ib(factory=set)
prev_hooks = attr.ib(init=False)
def install_hooks(self, runner):
def firstiter(agen):
if hasattr(_run.GLOBAL_RUN_CONTEXT, "task"):
self.alive.add(agen)
else:
# An async generator first iterated outside of a Trio
# task doesn't belong to Trio. Probably we're in guest
# mode and the async generator belongs to our host.
# The locals dictionary is the only good place to
# remember this fact, at least until
# https://bugs.python.org/issue40916 is implemented.
agen.ag_frame.f_locals["@trio_foreign_asyncgen"] = True
if self.prev_hooks.firstiter is not None:
self.prev_hooks.firstiter(agen)
def finalize_in_trio_context(agen, agen_name):
try:
runner.spawn_system_task(
self._finalize_one,
agen,
agen_name,
name=f"close asyncgen {agen_name} (abandoned)",
)
except RuntimeError:
# There is a one-tick window where the system nursery
# is closed but the init task hasn't yet made
# self.asyncgens a strong set to disable GC. We seem to
# have hit it.
self.trailing_needs_finalize.add(agen)
def finalizer(agen):
agen_name = name_asyncgen(agen)
try:
is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen")
except AttributeError: # pragma: no cover
is_ours = True
if is_ours:
runner.entry_queue.run_sync_soon(
finalize_in_trio_context, agen, agen_name
)
# Do this last, because it might raise an exception
# depending on the user's warnings filter. (That
# exception will be printed to the terminal and
# ignored, since we're running in GC context.)
warnings.warn(
f"Async generator {agen_name!r} was garbage collected before it "
f"had been exhausted. Surround its use in 'async with "
f"aclosing(...):' to ensure that it gets cleaned up as soon as "
f"you're done using it.",
ResourceWarning,
stacklevel=2,
source=agen,
)
else:
# Not ours -> forward to the host loop's async generator finalizer
if self.prev_hooks.finalizer is not None:
self.prev_hooks.finalizer(agen)
else:
# Host has no finalizer. Reimplement the default
# Python behavior with no hooks installed: throw in
# GeneratorExit, step once, raise RuntimeError if
# it doesn't exit.
closer = agen.aclose()
try:
# If the next thing is a yield, this will raise RuntimeError
# which we allow to propagate
closer.send(None)
except StopIteration:
pass
else:
# If the next thing is an await, we get here. Give a nicer
# error than the default "async generator ignored GeneratorExit"
raise RuntimeError(
f"Non-Trio async generator {agen_name!r} awaited something "
f"during finalization; install a finalization hook to "
f"support this, or wrap it in 'async with aclosing(...):'"
)
self.prev_hooks = sys.get_asyncgen_hooks()
sys.set_asyncgen_hooks(firstiter=firstiter, finalizer=finalizer)
async def finalize_remaining(self, runner):
# This is called from init after shutting down the system nursery.
# The only tasks running at this point are init and
# the run_sync_soon task, and since the system nursery is closed,
# there's no way for user code to spawn more.
assert _core.current_task() is runner.init_task
assert len(runner.tasks) == 2
# To make async generator finalization easier to reason
# about, we'll shut down asyncgen garbage collection by turning
# the alive WeakSet into a regular set.
self.alive = set(self.alive)
# Process all pending run_sync_soon callbacks, in case one of
# them was an asyncgen finalizer that snuck in under the wire.
runner.entry_queue.run_sync_soon(runner.reschedule, runner.init_task)
await _core.wait_task_rescheduled(
lambda _: _core.Abort.FAILED # pragma: no cover
)
self.alive.update(self.trailing_needs_finalize)
self.trailing_needs_finalize.clear()
# None of the still-living tasks use async generators, so
# every async generator must be suspended at a yield point --
# there's no one to be doing the iteration. That's good,
# because aclose() only works on an asyncgen that's suspended
# at a yield point. (If it's suspended at an event loop trap,
# because someone is in the middle of iterating it, then you
# get a RuntimeError on 3.8+, and a nasty surprise on earlier
# versions due to https://bugs.python.org/issue32526.)
#
# However, once we start aclose() of one async generator, it
# might start fetching the next value from another, thus
# preventing us from closing that other (at least until
# aclose() of the first one is complete). This constraint
# effectively requires us to finalize the remaining asyncgens
# in arbitrary order, rather than doing all of them at the
# same time. On 3.8+ we could defer any generator with
# ag_running=True to a later batch, but that only catches
# the case where our aclose() starts after the user's
# asend()/etc. If our aclose() starts first, then the
# user's asend()/etc will raise RuntimeError, since they're
# probably not checking ag_running.
#
# It might be possible to allow some parallelized cleanup if
# we can determine that a certain set of asyncgens have no
# interdependencies, using gc.get_referents() and such.
# But just doing one at a time will typically work well enough
# (since each aclose() executes in a cancelled scope) and
# is much easier to reason about.
# It's possible that that cleanup code will itself create
# more async generators, so we iterate repeatedly until
# all are gone.
while self.alive:
batch = self.alive
self.alive = set()
for agen in batch:
await self._finalize_one(agen, name_asyncgen(agen))
def close(self):
sys.set_asyncgen_hooks(*self.prev_hooks)
async def _finalize_one(self, agen, name):
try:
# This shield ensures that finalize_asyncgen never exits
# with an exception, not even a Cancelled. The inside
# is cancelled so there's no deadlock risk.
with _core.CancelScope(shield=True) as cancel_scope:
cancel_scope.cancel()
await agen.aclose()
except BaseException:
ASYNCGEN_LOGGER.exception(
"Exception ignored during finalization of async generator %r -- "
"surround your use of the generator in 'async with aclosing(...):' "
"to raise exceptions like this in the context where they're generated",
name,
)

View File

@@ -0,0 +1,195 @@
from collections import deque
import threading
import attr
from .. import _core
from .._util import NoPublicConstructor
from ._wakeup_socketpair import WakeupSocketpair
@attr.s(slots=True)
class EntryQueue:
# This used to use a queue.Queue. but that was broken, because Queues are
# implemented in Python, and not reentrant -- so it was thread-safe, but
# not signal-safe. deque is implemented in C, so each operation is atomic
# WRT threads (and this is guaranteed in the docs), AND each operation is
# atomic WRT signal delivery (signal handlers can run on either side, but
# not *during* a deque operation). dict makes similar guarantees - and
# it's even ordered!
queue = attr.ib(factory=deque)
idempotent_queue = attr.ib(factory=dict)
wakeup = attr.ib(factory=WakeupSocketpair)
done = attr.ib(default=False)
# Must be a reentrant lock, because it's acquired from signal handlers.
# RLock is signal-safe as of cpython 3.2. NB that this does mean that the
# lock is effectively *disabled* when we enter from signal context. The
# way we use the lock this is OK though, because when
# run_sync_soon is called from a signal it's atomic WRT the
# main thread -- it just might happen at some inconvenient place. But if
# you look at the one place where the main thread holds the lock, it's
# just to make 1 assignment, so that's atomic WRT a signal anyway.
lock = attr.ib(factory=threading.RLock)
async def task(self):
assert _core.currently_ki_protected()
# RLock has two implementations: a signal-safe version in _thread, and
# and signal-UNsafe version in threading. We need the signal safe
# version. Python 3.2 and later should always use this anyway, but,
# since the symptoms if this goes wrong are just "weird rare
# deadlocks", then let's make a little check.
# See:
# https://bugs.python.org/issue13697#msg237140
assert self.lock.__class__.__module__ == "_thread"
def run_cb(job):
# We run this with KI protection enabled; it's the callback's
# job to disable it if it wants it disabled. Exceptions are
# treated like system task exceptions (i.e., converted into
# TrioInternalError and cause everything to shut down).
sync_fn, args = job
try:
sync_fn(*args)
except BaseException as exc:
async def kill_everything(exc):
raise exc
try:
_core.spawn_system_task(kill_everything, exc)
except RuntimeError:
# We're quite late in the shutdown process and the
# system nursery is already closed.
# TODO(2020-06): this is a gross hack and should
# be fixed soon when we address #1607.
_core.current_task().parent_nursery.start_soon(kill_everything, exc)
return True
# This has to be carefully written to be safe in the face of new items
# being queued while we iterate, and to do a bounded amount of work on
# each pass:
def run_all_bounded():
for _ in range(len(self.queue)):
run_cb(self.queue.popleft())
for job in list(self.idempotent_queue):
del self.idempotent_queue[job]
run_cb(job)
try:
while True:
run_all_bounded()
if not self.queue and not self.idempotent_queue:
await self.wakeup.wait_woken()
else:
await _core.checkpoint()
except _core.Cancelled:
# Keep the work done with this lock held as minimal as possible,
# because it doesn't protect us against concurrent signal delivery
# (see the comment above). Notice that this code would still be
# correct if written like:
# self.done = True
# with self.lock:
# pass
# because all we want is to force run_sync_soon
# to either be completely before or completely after the write to
# done. That's why we don't need the lock to protect
# against signal handlers.
with self.lock:
self.done = True
# No more jobs will be submitted, so just clear out any residual
# ones:
run_all_bounded()
assert not self.queue
assert not self.idempotent_queue
def close(self):
self.wakeup.close()
def size(self):
return len(self.queue) + len(self.idempotent_queue)
def run_sync_soon(self, sync_fn, *args, idempotent=False):
with self.lock:
if self.done:
raise _core.RunFinishedError("run() has exited")
# We have to hold the lock all the way through here, because
# otherwise the main thread might exit *while* we're doing these
# calls, and then our queue item might not be processed, or the
# wakeup call might trigger an OSError b/c the IO manager has
# already been shut down.
if idempotent:
self.idempotent_queue[(sync_fn, args)] = None
else:
self.queue.append((sync_fn, args))
self.wakeup.wakeup_thread_and_signal_safe()
@attr.s(eq=False, hash=False, slots=True)
class TrioToken(metaclass=NoPublicConstructor):
"""An opaque object representing a single call to :func:`trio.run`.
It has no public constructor; instead, see :func:`current_trio_token`.
This object has two uses:
1. It lets you re-enter the Trio run loop from external threads or signal
handlers. This is the low-level primitive that :func:`trio.to_thread`
and `trio.from_thread` use to communicate with worker threads, that
`trio.open_signal_receiver` uses to receive notifications about
signals, and so forth.
2. Each call to :func:`trio.run` has exactly one associated
:class:`TrioToken` object, so you can use it to identify a particular
call.
"""
_reentry_queue = attr.ib()
def run_sync_soon(self, sync_fn, *args, idempotent=False):
"""Schedule a call to ``sync_fn(*args)`` to occur in the context of a
Trio task.
This is safe to call from the main thread, from other threads, and
from signal handlers. This is the fundamental primitive used to
re-enter the Trio run loop from outside of it.
The call will happen "soon", but there's no guarantee about exactly
when, and no mechanism provided for finding out when it's happened.
If you need this, you'll have to build your own.
The call is effectively run as part of a system task (see
:func:`~trio.lowlevel.spawn_system_task`). In particular this means
that:
* :exc:`KeyboardInterrupt` protection is *enabled* by default; if
you want ``sync_fn`` to be interruptible by control-C, then you
need to use :func:`~trio.lowlevel.disable_ki_protection`
explicitly.
* If ``sync_fn`` raises an exception, then it's converted into a
:exc:`~trio.TrioInternalError` and *all* tasks are cancelled. You
should be careful that ``sync_fn`` doesn't crash.
All calls with ``idempotent=False`` are processed in strict
first-in first-out order.
If ``idempotent=True``, then ``sync_fn`` and ``args`` must be
hashable, and Trio will make a best-effort attempt to discard any
call submission which is equal to an already-pending call. Trio
will process these in first-in first-out order.
Any ordering guarantees apply separately to ``idempotent=False``
and ``idempotent=True`` calls; there's no rule for how calls in the
different categories are ordered with respect to each other.
:raises trio.RunFinishedError:
if the associated call to :func:`trio.run`
has already exited. (Any call that *doesn't* raise this error
is guaranteed to be fully processed before :func:`trio.run`
exits.)
"""
self._reentry_queue.run_sync_soon(sync_fn, *args, idempotent=idempotent)

View File

@@ -0,0 +1,114 @@
import attr
from trio._util import NoPublicConstructor
class TrioInternalError(Exception):
"""Raised by :func:`run` if we encounter a bug in Trio, or (possibly) a
misuse of one of the low-level :mod:`trio.lowlevel` APIs.
This should never happen! If you get this error, please file a bug.
Unfortunately, if you get this error it also means that all bets are off
Trio doesn't know what is going on and its normal invariants may be void.
(For example, we might have "lost track" of a task. Or lost track of all
tasks.) Again, though, this shouldn't happen.
"""
class RunFinishedError(RuntimeError):
"""Raised by `trio.from_thread.run` and similar functions if the
corresponding call to :func:`trio.run` has already finished.
"""
class WouldBlock(Exception):
"""Raised by ``X_nowait`` functions if ``X`` would block."""
class Cancelled(BaseException, metaclass=NoPublicConstructor):
"""Raised by blocking calls if the surrounding scope has been cancelled.
You should let this exception propagate, to be caught by the relevant
cancel scope. To remind you of this, it inherits from :exc:`BaseException`
instead of :exc:`Exception`, just like :exc:`KeyboardInterrupt` and
:exc:`SystemExit` do. This means that if you write something like::
try:
...
except Exception:
...
then this *won't* catch a :exc:`Cancelled` exception.
You cannot raise :exc:`Cancelled` yourself. Attempting to do so
will produce a :exc:`TypeError`. Use :meth:`cancel_scope.cancel()
<trio.CancelScope.cancel>` instead.
.. note::
In the US it's also common to see this word spelled "canceled", with
only one "l". This is a `recent
<https://books.google.com/ngrams/graph?content=canceled%2Ccancelled&year_start=1800&year_end=2000&corpus=5&smoothing=3&direct_url=t1%3B%2Ccanceled%3B%2Cc0%3B.t1%3B%2Ccancelled%3B%2Cc0>`__
and `US-specific
<https://books.google.com/ngrams/graph?content=canceled%2Ccancelled&year_start=1800&year_end=2000&corpus=18&smoothing=3&share=&direct_url=t1%3B%2Ccanceled%3B%2Cc0%3B.t1%3B%2Ccancelled%3B%2Cc0>`__
innovation, and even in the US both forms are still commonly used. So
for consistency with the rest of the world and with "cancellation"
(which always has two "l"s), Trio uses the two "l" spelling
everywhere.
"""
def __str__(self):
return "Cancelled"
class BusyResourceError(Exception):
"""Raised when a task attempts to use a resource that some other task is
already using, and this would lead to bugs and nonsense.
For example, if two tasks try to send data through the same socket at the
same time, Trio will raise :class:`BusyResourceError` instead of letting
the data get scrambled.
"""
class ClosedResourceError(Exception):
"""Raised when attempting to use a resource after it has been closed.
Note that "closed" here means that *your* code closed the resource,
generally by calling a method with a name like ``close`` or ``aclose``, or
by exiting a context manager. If a problem arises elsewhere for example,
because of a network failure, or because a remote peer closed their end of
a connection then that should be indicated by a different exception
class, like :exc:`BrokenResourceError` or an :exc:`OSError` subclass.
"""
class BrokenResourceError(Exception):
"""Raised when an attempt to use a resource fails due to external
circumstances.
For example, you might get this if you try to send data on a stream where
the remote side has already closed the connection.
You *don't* get this error if *you* closed the resource in that case you
get :class:`ClosedResourceError`.
This exception's ``__cause__`` attribute will often contain more
information about the underlying error.
"""
class EndOfChannel(Exception):
"""Raised when trying to receive from a :class:`trio.abc.ReceiveChannel`
that has no more data to receive.
This is analogous to an "end-of-file" condition, but for channels.
"""

View File

@@ -0,0 +1,47 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument
# fmt: off
def add_instrument(instrument: Instrument) ->None:
"""Start instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to activate.
If ``instrument`` is already active, does nothing.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.instruments.add_instrument(instrument)
except AttributeError:
raise RuntimeError("must be called from async context")
def remove_instrument(instrument: Instrument) ->None:
"""Stop instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to de-activate.
Raises:
KeyError: if the instrument is not currently active. This could
occur either because you never added it, or because you added it
and then it raised an unhandled exception and was automatically
deactivated.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.instruments.remove_instrument(instrument)
except AttributeError:
raise RuntimeError("must be called from async context")
# fmt: on

View File

@@ -0,0 +1,35 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument
# fmt: off
async def wait_readable(fd):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_writable(fd):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
except AttributeError:
raise RuntimeError("must be called from async context")
def notify_closing(fd):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
except AttributeError:
raise RuntimeError("must be called from async context")
# fmt: on

View File

@@ -0,0 +1,59 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument
# fmt: off
def current_kqueue():
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue()
except AttributeError:
raise RuntimeError("must be called from async context")
def monitor_kevent(ident, filter):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter)
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_kevent(ident, filter, abort_func):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(ident, filter, abort_func)
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_readable(fd):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_writable(fd):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
except AttributeError:
raise RuntimeError("must be called from async context")
def notify_closing(fd):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
except AttributeError:
raise RuntimeError("must be called from async context")
# fmt: on

View File

@@ -0,0 +1,83 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument
# fmt: off
async def wait_readable(sock):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock)
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_writable(sock):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock)
except AttributeError:
raise RuntimeError("must be called from async context")
def notify_closing(handle):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle)
except AttributeError:
raise RuntimeError("must be called from async context")
def register_with_iocp(handle):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle)
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_overlapped(handle, lpOverlapped):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped(handle, lpOverlapped)
except AttributeError:
raise RuntimeError("must be called from async context")
async def write_overlapped(handle, data, file_offset=0):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped(handle, data, file_offset)
except AttributeError:
raise RuntimeError("must be called from async context")
async def readinto_overlapped(handle, buffer, file_offset=0):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped(handle, buffer, file_offset)
except AttributeError:
raise RuntimeError("must be called from async context")
def current_iocp():
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp()
except AttributeError:
raise RuntimeError("must be called from async context")
def monitor_completion_key():
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key()
except AttributeError:
raise RuntimeError("must be called from async context")
# fmt: on

View File

@@ -0,0 +1,241 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument
# fmt: off
def current_statistics():
"""Returns an object containing run-loop-level debugging information.
Currently the following fields are defined:
* ``tasks_living`` (int): The number of tasks that have been spawned
and not yet exited.
* ``tasks_runnable`` (int): The number of tasks that are currently
queued on the run queue (as opposed to blocked waiting for something
to happen).
* ``seconds_to_next_deadline`` (float): The time until the next
pending cancel scope deadline. May be negative if the deadline has
expired but we haven't yet processed cancellations. May be
:data:`~math.inf` if there are no pending deadlines.
* ``run_sync_soon_queue_size`` (int): The number of
unprocessed callbacks queued via
:meth:`trio.lowlevel.TrioToken.run_sync_soon`.
* ``io_statistics`` (object): Some statistics from Trio's I/O
backend. This always has an attribute ``backend`` which is a string
naming which operating-system-specific I/O backend is in use; the
other attributes vary between backends.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_statistics()
except AttributeError:
raise RuntimeError("must be called from async context")
def current_time():
"""Returns the current time according to Trio's internal clock.
Returns:
float: The current time.
Raises:
RuntimeError: if not inside a call to :func:`trio.run`.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_time()
except AttributeError:
raise RuntimeError("must be called from async context")
def current_clock():
"""Returns the current :class:`~trio.abc.Clock`."""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_clock()
except AttributeError:
raise RuntimeError("must be called from async context")
def current_root_task():
"""Returns the current root :class:`Task`.
This is the task that is the ultimate parent of all other tasks.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_root_task()
except AttributeError:
raise RuntimeError("must be called from async context")
def reschedule(task, next_send=_NO_SEND):
"""Reschedule the given task with the given
:class:`outcome.Outcome`.
See :func:`wait_task_rescheduled` for the gory details.
There must be exactly one call to :func:`reschedule` for every call to
:func:`wait_task_rescheduled`. (And when counting, keep in mind that
returning :data:`Abort.SUCCEEDED` from an abort callback is equivalent
to calling :func:`reschedule` once.)
Args:
task (trio.lowlevel.Task): the task to be rescheduled. Must be blocked
in a call to :func:`wait_task_rescheduled`.
next_send (outcome.Outcome): the value (or error) to return (or
raise) from :func:`wait_task_rescheduled`.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send)
except AttributeError:
raise RuntimeError("must be called from async context")
def spawn_system_task(async_fn, *args, name=None, context=None):
"""Spawn a "system" task.
System tasks have a few differences from regular tasks:
* They don't need an explicit nursery; instead they go into the
internal "system nursery".
* If a system task raises an exception, then it's converted into a
:exc:`~trio.TrioInternalError` and *all* tasks are cancelled. If you
write a system task, you should be careful to make sure it doesn't
crash.
* System tasks are automatically cancelled when the main task exits.
* By default, system tasks have :exc:`KeyboardInterrupt` protection
*enabled*. If you want your task to be interruptible by control-C,
then you need to use :func:`disable_ki_protection` explicitly (and
come up with some plan for what to do with a
:exc:`KeyboardInterrupt`, given that system tasks aren't allowed to
raise exceptions).
* System tasks do not inherit context variables from their creator.
Towards the end of a call to :meth:`trio.run`, after the main
task and all system tasks have exited, the system nursery
becomes closed. At this point, new calls to
:func:`spawn_system_task` will raise ``RuntimeError("Nursery
is closed to new arrivals")`` instead of creating a system
task. It's possible to encounter this state either in
a ``finally`` block in an async generator, or in a callback
passed to :meth:`TrioToken.run_sync_soon` at the right moment.
Args:
async_fn: An async callable.
args: Positional arguments for ``async_fn``. If you want to pass
keyword arguments, use :func:`functools.partial`.
name: The name for this task. Only used for debugging/introspection
(e.g. ``repr(task_obj)``). If this isn't a string,
:func:`spawn_system_task` will try to make it one. A common use
case is if you're wrapping a function before spawning a new
task, you might pass the original function as the ``name=`` to
make debugging easier.
context: An optional ``contextvars.Context`` object with context variables
to use for this task. You would normally get a copy of the current
context with ``context = contextvars.copy_context()`` and then you would
pass that ``context`` object here.
Returns:
Task: the newly spawned task
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.spawn_system_task(async_fn, *args, name=name, context=context)
except AttributeError:
raise RuntimeError("must be called from async context")
def current_trio_token():
"""Retrieve the :class:`TrioToken` for the current call to
:func:`trio.run`.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_trio_token()
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_all_tasks_blocked(cushion=0.0):
"""Block until there are no runnable tasks.
This is useful in testing code when you want to give other tasks a
chance to "settle down". The calling task is blocked, and doesn't wake
up until all other tasks are also blocked for at least ``cushion``
seconds. (Setting a non-zero ``cushion`` is intended to handle cases
like two tasks talking to each other over a local socket, where we
want to ignore the potential brief moment between a send and receive
when all tasks are blocked.)
Note that ``cushion`` is measured in *real* time, not the Trio clock
time.
If there are multiple tasks blocked in :func:`wait_all_tasks_blocked`,
then the one with the shortest ``cushion`` is the one woken (and
this task becoming unblocked resets the timers for the remaining
tasks). If there are multiple tasks that have exactly the same
``cushion``, then all are woken.
You should also consider :class:`trio.testing.Sequencer`, which
provides a more explicit way to control execution ordering within a
test, and will often produce more readable tests.
Example:
Here's an example of one way to test that Trio's locks are fair: we
take the lock in the parent, start a child, wait for the child to be
blocked waiting for the lock (!), and then check that we can't
release and immediately re-acquire the lock::
async def lock_taker(lock):
await lock.acquire()
lock.release()
async def test_lock_fairness():
lock = trio.Lock()
await lock.acquire()
async with trio.open_nursery() as nursery:
nursery.start_soon(lock_taker, lock)
# child hasn't run yet, we have the lock
assert lock.locked()
assert lock._owner is trio.lowlevel.current_task()
await trio.testing.wait_all_tasks_blocked()
# now the child has run and is blocked on lock.acquire(), we
# still have the lock
assert lock.locked()
assert lock._owner is trio.lowlevel.current_task()
lock.release()
try:
# The child has a prior claim, so we can't have it
lock.acquire_nowait()
except trio.WouldBlock:
assert lock._owner is not trio.lowlevel.current_task()
print("PASS")
else:
print("FAIL")
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion)
except AttributeError:
raise RuntimeError("must be called from async context")
# fmt: on

View File

@@ -0,0 +1,108 @@
import logging
import types
import attr
from typing import Any, Callable, Dict, List, Sequence, Iterator, TypeVar
from .._abc import Instrument
# Used to log exceptions in instruments
INSTRUMENT_LOGGER = logging.getLogger("trio.abc.Instrument")
F = TypeVar("F", bound=Callable[..., Any])
# Decorator to mark methods public. This does nothing by itself, but
# trio/_tools/gen_exports.py looks for it.
def _public(fn: F) -> F:
return fn
class Instruments(Dict[str, Dict[Instrument, None]]):
"""A collection of `trio.abc.Instrument` organized by hook.
Instrumentation calls are rather expensive, and we don't want a
rarely-used instrument (like before_run()) to slow down hot
operations (like before_task_step()). Thus, we cache the set of
instruments to be called for each hook, and skip the instrumentation
call if there's nothing currently installed for that hook.
"""
__slots__ = ()
def __init__(self, incoming: Sequence[Instrument]):
self["_all"] = {}
for instrument in incoming:
self.add_instrument(instrument)
@_public
def add_instrument(self, instrument: Instrument) -> None:
"""Start instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to activate.
If ``instrument`` is already active, does nothing.
"""
if instrument in self["_all"]:
return
self["_all"][instrument] = None
try:
for name in dir(instrument):
if name.startswith("_"):
continue
try:
prototype = getattr(Instrument, name)
except AttributeError:
continue
impl = getattr(instrument, name)
if isinstance(impl, types.MethodType) and impl.__func__ is prototype:
# Inherited unchanged from _abc.Instrument
continue
self.setdefault(name, {})[instrument] = None
except:
self.remove_instrument(instrument)
raise
@_public
def remove_instrument(self, instrument: Instrument) -> None:
"""Stop instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to de-activate.
Raises:
KeyError: if the instrument is not currently active. This could
occur either because you never added it, or because you added it
and then it raised an unhandled exception and was automatically
deactivated.
"""
# If instrument isn't present, the KeyError propagates out
self["_all"].pop(instrument)
for hookname, instruments in list(self.items()):
if instrument in instruments:
del instruments[instrument]
if not instruments:
del self[hookname]
def call(self, hookname: str, *args: Any) -> None:
"""Call hookname(*args) on each applicable instrument.
You must first check whether there are any instruments installed for
that hook, e.g.::
if "before_task_step" in instruments:
instruments.call("before_task_step", task)
"""
for instrument in list(self[hookname]):
try:
getattr(instrument, hookname)(*args)
except:
self.remove_instrument(instrument)
INSTRUMENT_LOGGER.exception(
"Exception raised when calling %r on instrument %r. "
"Instrument has been disabled.",
hookname,
instrument,
)

View File

@@ -0,0 +1,22 @@
import copy
import outcome
from .. import _core
# Utility function shared between _io_epoll and _io_windows
def wake_all(waiters, exc):
try:
current_task = _core.current_task()
except RuntimeError:
current_task = None
raise_at_end = False
for attr_name in ["read_task", "write_task"]:
task = getattr(waiters, attr_name)
if task is not None:
if task is current_task:
raise_at_end = True
else:
_core.reschedule(task, outcome.Error(copy.copy(exc)))
setattr(waiters, attr_name, None)
if raise_at_end:
raise exc

View File

@@ -0,0 +1,317 @@
import select
import sys
import attr
from collections import defaultdict
from typing import Dict, TYPE_CHECKING
from .. import _core
from ._run import _public
from ._io_common import wake_all
from ._wakeup_socketpair import WakeupSocketpair
assert not TYPE_CHECKING or sys.platform == "linux"
@attr.s(slots=True, eq=False, frozen=True)
class _EpollStatistics:
tasks_waiting_read = attr.ib()
tasks_waiting_write = attr.ib()
backend = attr.ib(default="epoll")
# Some facts about epoll
# ----------------------
#
# Internally, an epoll object is sort of like a WeakKeyDictionary where the
# keys are tuples of (fd number, file object). When you call epoll_ctl, you
# pass in an fd; that gets converted to an (fd number, file object) tuple by
# looking up the fd in the process's fd table at the time of the call. When an
# event happens on the file object, epoll_wait drops the file object part, and
# just returns the fd number in its event. So from the outside it looks like
# it's keeping a table of fds, but really it's a bit more complicated. This
# has some subtle consequences.
#
# In general, file objects inside the kernel are reference counted. Each entry
# in a process's fd table holds a strong reference to the corresponding file
# object, and most operations that use file objects take a temporary strong
# reference while they're working. So when you call close() on an fd, that
# might or might not cause the file object to be deallocated -- it depends on
# whether there are any other references to that file object. Some common ways
# this can happen:
#
# - after calling dup(), you have two fds in the same process referring to the
# same file object. Even if you close one fd (= remove that entry from the
# fd table), the file object will be kept alive by the other fd.
# - when calling fork(), the child inherits a copy of the parent's fd table,
# so all the file objects get another reference. (But if the fork() is
# followed by exec(), then all of the child's fds that have the CLOEXEC flag
# set will be closed at that point.)
# - most syscalls that work on fds take a strong reference to the underlying
# file object while they're using it. So there's one thread blocked in
# read(fd), and then another thread calls close() on the last fd referring
# to that object, the underlying file won't actually be closed until
# after read() returns.
#
# However, epoll does *not* take a reference to any of the file objects in its
# interest set (that's what makes it similar to a WeakKeyDictionary). File
# objects inside an epoll interest set will be deallocated if all *other*
# references to them are closed. And when that happens, the epoll object will
# automatically deregister that file object and stop reporting events on it.
# So that's quite handy.
#
# But, what happens if we do this?
#
# fd1 = open(...)
# epoll_ctl(EPOLL_CTL_ADD, fd1, ...)
# fd2 = dup(fd1)
# close(fd1)
#
# In this case, the dup() keeps the underlying file object alive, so it
# remains registered in the epoll object's interest set, as the tuple (fd1,
# file object). But, fd1 no longer refers to this file object! You might think
# there was some magic to handle this, but unfortunately no; the consequences
# are totally predictable from what I said above:
#
# If any events occur on the file object, then epoll will report them as
# happening on fd1, even though that doesn't make sense.
#
# Perhaps we would like to deregister fd1 to stop getting nonsensical events.
# But how? When we call epoll_ctl, we have to pass an fd number, which will
# get expanded to an (fd number, file object) tuple. We can't pass fd1,
# because when epoll_ctl tries to look it up, it won't find our file object.
# And we can't pass fd2, because that will get expanded to (fd2, file object),
# which is a different lookup key. In fact, it's *impossible* to de-register
# this fd!
#
# We could even have fd1 get assigned to another file object, and then we can
# have multiple keys registered simultaneously using the same fd number, like:
# (fd1, file object 1), (fd1, file object 2). And if events happen on either
# file object, then epoll will happily report that something happened to
# "fd1".
#
# Now here's what makes this especially nasty: suppose the old file object
# becomes, say, readable. That means that every time we call epoll_wait, it
# will return immediately to tell us that "fd1" is readable. Normally, we
# would handle this by de-registering fd1, waking up the corresponding call to
# wait_readable, then the user will call read() or recv() or something, and
# we're fine. But if this happens on a stale fd where we can't remove the
# registration, then we might get stuck in a state where epoll_wait *always*
# returns immediately, so our event loop becomes unable to sleep, and now our
# program is burning 100% of the CPU doing nothing, with no way out.
#
#
# What does this mean for Trio?
# -----------------------------
#
# Since we don't control the user's code, we have no way to guarantee that we
# don't get stuck with stale fd's in our epoll interest set. For example, a
# user could call wait_readable(fd) in one task, and then while that's
# running, they might close(fd) from another task. In this situation, they're
# *supposed* to call notify_closing(fd) to let us know what's happening, so we
# can interrupt the wait_readable() call and avoid getting into this mess. And
# that's the only thing that can possibly work correctly in all cases. But
# sometimes user code has bugs. So if this does happen, we'd like to degrade
# gracefully, and survive without corrupting Trio's internal state or
# otherwise causing the whole program to explode messily.
#
# Our solution: we always use EPOLLONESHOT. This way, we might get *one*
# spurious event on a stale fd, but then epoll will automatically silence it
# until we explicitly say that we want more events... and if we have a stale
# fd, then we actually can't re-enable it! So we can't get stuck in an
# infinite busy-loop. If there's a stale fd hanging around, then it might
# cause a spurious `BusyResourceError`, or cause one wait_* call to return
# before it should have... but in general, the wait_* functions are allowed to
# have some spurious wakeups; the user code will just attempt the operation,
# get EWOULDBLOCK, and call wait_* again. And the program as a whole will
# survive, any exceptions will propagate, etc.
#
# As a bonus, EPOLLONESHOT also saves us having to explicitly deregister fds
# on the normal wakeup path, so it's a bit more efficient in general.
#
# However, EPOLLONESHOT has a few trade-offs to consider:
#
# First, you can't combine EPOLLONESHOT with EPOLLEXCLUSIVE. This is a bit sad
# in one somewhat rare case: if you have a multi-process server where a group
# of processes all share the same listening socket, then EPOLLEXCLUSIVE can be
# used to avoid "thundering herd" problems when a new connection comes in. But
# this isn't too bad. It's not clear if EPOLLEXCLUSIVE even works for us
# anyway:
#
# https://stackoverflow.com/questions/41582560/how-does-epolls-epollexclusive-mode-interact-with-level-triggering
#
# And it's not clear that EPOLLEXCLUSIVE is a great approach either:
#
# https://blog.cloudflare.com/the-sad-state-of-linux-socket-balancing/
#
# And if we do need to support this, we could always add support through some
# more-specialized API in the future. So this isn't a blocker to using
# EPOLLONESHOT.
#
# Second, EPOLLONESHOT does not actually *deregister* the fd after delivering
# an event (EPOLL_CTL_DEL). Instead, it keeps the fd registered, but
# effectively does an EPOLL_CTL_MOD to set the fd's interest flags to
# all-zeros. So we could still end up with an fd hanging around in the
# interest set for a long time, even if we're not using it.
#
# Fortunately, this isn't a problem, because it's only a weak reference if
# we have a stale fd that's been silenced by EPOLLONESHOT, then it wastes a
# tiny bit of kernel memory remembering this fd that can never be revived, but
# when the underlying file object is eventually closed, that memory will be
# reclaimed. So that's OK.
#
# The other issue is that when someone calls wait_*, using EPOLLONESHOT means
# that if we have ever waited for this fd before, we have to use EPOLL_CTL_MOD
# to re-enable it; but if it's a new fd, we have to use EPOLL_CTL_ADD. How do
# we know which one to use? There's no reasonable way to track which fds are
# currently registered -- remember, we're assuming the user might have gone
# and rearranged their fds without telling us!
#
# Fortunately, this also has a simple solution: if we wait on a socket or
# other fd once, then we'll probably wait on it lots of times. And the epoll
# object itself knows which fds it already has registered. So when an fd comes
# in, we optimistically assume that it's been waited on before, and try doing
# EPOLL_CTL_MOD. And if that fails with an ENOENT error, then we try again
# with EPOLL_CTL_ADD.
#
# So that's why this code is the way it is. And now you know more than you
# wanted to about how epoll works.
@attr.s(slots=True, eq=False)
class EpollWaiters:
read_task = attr.ib(default=None)
write_task = attr.ib(default=None)
current_flags = attr.ib(default=0)
@attr.s(slots=True, eq=False, hash=False)
class EpollIOManager:
_epoll = attr.ib(factory=select.epoll)
# {fd: EpollWaiters}
_registered = attr.ib(
factory=lambda: defaultdict(EpollWaiters), type=Dict[int, EpollWaiters]
)
_force_wakeup = attr.ib(factory=WakeupSocketpair)
_force_wakeup_fd = attr.ib(default=None)
def __attrs_post_init__(self):
self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN)
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
def statistics(self):
tasks_waiting_read = 0
tasks_waiting_write = 0
for waiter in self._registered.values():
if waiter.read_task is not None:
tasks_waiting_read += 1
if waiter.write_task is not None:
tasks_waiting_write += 1
return _EpollStatistics(
tasks_waiting_read=tasks_waiting_read,
tasks_waiting_write=tasks_waiting_write,
)
def close(self):
self._epoll.close()
self._force_wakeup.close()
def force_wakeup(self):
self._force_wakeup.wakeup_thread_and_signal_safe()
# Return value must be False-y IFF the timeout expired, NOT if any I/O
# happened or force_wakeup was called. Otherwise it can be anything; gets
# passed straight through to process_events.
def get_events(self, timeout):
# max_events must be > 0 or epoll gets cranky
# accessing self._registered from a thread looks dangerous, but it's
# OK because it doesn't matter if our value is a little bit off.
max_events = max(1, len(self._registered))
return self._epoll.poll(timeout, max_events)
def process_events(self, events):
for fd, flags in events:
if fd == self._force_wakeup_fd:
self._force_wakeup.drain()
continue
waiters = self._registered[fd]
# EPOLLONESHOT always clears the flags when an event is delivered
waiters.current_flags = 0
# Clever hack stolen from selectors.EpollSelector: an event
# with EPOLLHUP or EPOLLERR flags wakes both readers and
# writers.
if flags & ~select.EPOLLIN and waiters.write_task is not None:
_core.reschedule(waiters.write_task)
waiters.write_task = None
if flags & ~select.EPOLLOUT and waiters.read_task is not None:
_core.reschedule(waiters.read_task)
waiters.read_task = None
self._update_registrations(fd)
def _update_registrations(self, fd):
waiters = self._registered[fd]
wanted_flags = 0
if waiters.read_task is not None:
wanted_flags |= select.EPOLLIN
if waiters.write_task is not None:
wanted_flags |= select.EPOLLOUT
if wanted_flags != waiters.current_flags:
try:
try:
# First try EPOLL_CTL_MOD
self._epoll.modify(fd, wanted_flags | select.EPOLLONESHOT)
except OSError:
# If that fails, it might be a new fd; try EPOLL_CTL_ADD
self._epoll.register(fd, wanted_flags | select.EPOLLONESHOT)
waiters.current_flags = wanted_flags
except OSError as exc:
# If everything fails, probably it's a bad fd, e.g. because
# the fd was closed behind our back. In this case we don't
# want to try to unregister the fd, because that will probably
# fail too. Just clear our state and wake everyone up.
del self._registered[fd]
# This could raise (in case we're calling this inside one of
# the to-be-woken tasks), so we have to do it last.
wake_all(waiters, exc)
return
if not wanted_flags:
del self._registered[fd]
async def _epoll_wait(self, fd, attr_name):
if not isinstance(fd, int):
fd = fd.fileno()
waiters = self._registered[fd]
if getattr(waiters, attr_name) is not None:
raise _core.BusyResourceError(
"another task is already reading / writing this fd"
)
setattr(waiters, attr_name, _core.current_task())
self._update_registrations(fd)
def abort(_):
setattr(waiters, attr_name, None)
self._update_registrations(fd)
return _core.Abort.SUCCEEDED
await _core.wait_task_rescheduled(abort)
@_public
async def wait_readable(self, fd):
await self._epoll_wait(fd, "read_task")
@_public
async def wait_writable(self, fd):
await self._epoll_wait(fd, "write_task")
@_public
def notify_closing(self, fd):
if not isinstance(fd, int):
fd = fd.fileno()
wake_all(
self._registered[fd],
_core.ClosedResourceError("another task closed this fd"),
)
del self._registered[fd]
try:
self._epoll.unregister(fd)
except (OSError, ValueError):
pass

View File

@@ -0,0 +1,196 @@
import select
import sys
from typing import TYPE_CHECKING
import outcome
from contextlib import contextmanager
import attr
import errno
from .. import _core
from ._run import _public
from ._wakeup_socketpair import WakeupSocketpair
assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32")
@attr.s(slots=True, eq=False, frozen=True)
class _KqueueStatistics:
tasks_waiting = attr.ib()
monitors = attr.ib()
backend = attr.ib(default="kqueue")
@attr.s(slots=True, eq=False)
class KqueueIOManager:
_kqueue = attr.ib(factory=select.kqueue)
# {(ident, filter): Task or UnboundedQueue}
_registered = attr.ib(factory=dict)
_force_wakeup = attr.ib(factory=WakeupSocketpair)
_force_wakeup_fd = attr.ib(default=None)
def __attrs_post_init__(self):
force_wakeup_event = select.kevent(
self._force_wakeup.wakeup_sock, select.KQ_FILTER_READ, select.KQ_EV_ADD
)
self._kqueue.control([force_wakeup_event], 0)
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
def statistics(self):
tasks_waiting = 0
monitors = 0
for receiver in self._registered.values():
if type(receiver) is _core.Task:
tasks_waiting += 1
else:
monitors += 1
return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors)
def close(self):
self._kqueue.close()
self._force_wakeup.close()
def force_wakeup(self):
self._force_wakeup.wakeup_thread_and_signal_safe()
def get_events(self, timeout):
# max_events must be > 0 or kqueue gets cranky
# and we generally want this to be strictly larger than the actual
# number of events we get, so that we can tell that we've gotten
# all the events in just 1 call.
max_events = len(self._registered) + 1
events = []
while True:
batch = self._kqueue.control([], max_events, timeout)
events += batch
if len(batch) < max_events:
break
else:
timeout = 0
# and loop back to the start
return events
def process_events(self, events):
for event in events:
key = (event.ident, event.filter)
if event.ident == self._force_wakeup_fd:
self._force_wakeup.drain()
continue
receiver = self._registered[key]
if event.flags & select.KQ_EV_ONESHOT:
del self._registered[key]
if type(receiver) is _core.Task:
_core.reschedule(receiver, outcome.Value(event))
else:
receiver.put_nowait(event)
# kevent registration is complicated -- e.g. aio submission can
# implicitly perform a EV_ADD, and EVFILT_PROC with NOTE_TRACK will
# automatically register filters for child processes. So our lowlevel
# API is *very* low-level: we expose the kqueue itself for adding
# events or sticking into AIO submission structs, and split waiting
# off into separate methods. It's your responsibility to make sure
# that handle_io never receives an event without a corresponding
# registration! This may be challenging if you want to be careful
# about e.g. KeyboardInterrupt. Possibly this API could be improved to
# be more ergonomic...
@_public
def current_kqueue(self):
return self._kqueue
@contextmanager
@_public
def monitor_kevent(self, ident, filter):
key = (ident, filter)
if key in self._registered:
raise _core.BusyResourceError(
"attempt to register multiple listeners for same ident/filter pair"
)
q = _core.UnboundedQueue()
self._registered[key] = q
try:
yield q
finally:
del self._registered[key]
@_public
async def wait_kevent(self, ident, filter, abort_func):
key = (ident, filter)
if key in self._registered:
raise _core.BusyResourceError(
"attempt to register multiple listeners for same ident/filter pair"
)
self._registered[key] = _core.current_task()
def abort(raise_cancel):
r = abort_func(raise_cancel)
if r is _core.Abort.SUCCEEDED:
del self._registered[key]
return r
return await _core.wait_task_rescheduled(abort)
async def _wait_common(self, fd, filter):
if not isinstance(fd, int):
fd = fd.fileno()
flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT
event = select.kevent(fd, filter, flags)
self._kqueue.control([event], 0)
def abort(_):
event = select.kevent(fd, filter, select.KQ_EV_DELETE)
try:
self._kqueue.control([event], 0)
except OSError as exc:
# kqueue tracks individual fds (*not* the underlying file
# object, see _io_epoll.py for a long discussion of why this
# distinction matters), and automatically deregisters an event
# if the fd is closed. So if kqueue.control says that it
# doesn't know about this event, then probably it's because
# the fd was closed behind our backs. (Too bad we can't ask it
# to wake us up when this happens, versus discovering it after
# the fact... oh well, you can't have everything.)
#
# FreeBSD reports this using EBADF. macOS uses ENOENT.
if exc.errno in (errno.EBADF, errno.ENOENT): # pragma: no branch
pass
else: # pragma: no cover
# As far as we know, this branch can't happen.
raise
return _core.Abort.SUCCEEDED
await self.wait_kevent(fd, filter, abort)
@_public
async def wait_readable(self, fd):
await self._wait_common(fd, select.KQ_FILTER_READ)
@_public
async def wait_writable(self, fd):
await self._wait_common(fd, select.KQ_FILTER_WRITE)
@_public
def notify_closing(self, fd):
if not isinstance(fd, int):
fd = fd.fileno()
for filter in [select.KQ_FILTER_READ, select.KQ_FILTER_WRITE]:
key = (fd, filter)
receiver = self._registered.get(key)
if receiver is None:
continue
if type(receiver) is _core.Task:
event = select.kevent(fd, filter, select.KQ_EV_DELETE)
self._kqueue.control([event], 0)
exc = _core.ClosedResourceError("another task closed this fd")
_core.reschedule(receiver, outcome.Error(exc))
del self._registered[key]
else:
# XX this is an interesting example of a case where being able
# to close a queue would be useful...
raise NotImplementedError(
"can't close an fd that monitor_kevent is using"
)

View File

@@ -0,0 +1,868 @@
import itertools
from contextlib import contextmanager
import enum
import socket
import sys
from typing import TYPE_CHECKING
import attr
from outcome import Value
from .. import _core
from ._run import _public
from ._io_common import wake_all
from ._windows_cffi import (
ffi,
kernel32,
ntdll,
ws2_32,
INVALID_HANDLE_VALUE,
raise_winerror,
_handle,
ErrorCodes,
FileFlags,
AFDPollFlags,
WSAIoctls,
CompletionModes,
IoControlCodes,
)
assert not TYPE_CHECKING or sys.platform == "win32"
# There's a lot to be said about the overall design of a Windows event
# loop. See
#
# https://github.com/python-trio/trio/issues/52
#
# for discussion. This now just has some lower-level notes:
#
# How IOCP fits together:
#
# The general model is that you call some function like ReadFile or WriteFile
# to tell the kernel that you want it to perform some operation, and the
# kernel goes off and does that in the background, then at some point later it
# sends you a notification that the operation is complete. There are some more
# exotic APIs that don't quite fit this pattern, but most APIs do.
#
# Each background operation is tracked using an OVERLAPPED struct, that
# uniquely identifies that particular operation.
#
# An "IOCP" (or "I/O completion port") is an object that lets the kernel send
# us these notifications -- basically it's just a kernel->userspace queue.
#
# Each IOCP notification is represented by an OVERLAPPED_ENTRY struct, which
# contains 3 fields:
# - The "completion key". This is an opaque integer that we pick, and use
# however is convenient.
# - pointer to the OVERLAPPED struct for the completed operation.
# - dwNumberOfBytesTransferred (an integer).
#
# And in addition, for regular I/O, the OVERLAPPED structure gets filled in
# with:
# - result code (named "Internal")
# - number of bytes transferred (named "InternalHigh"); usually redundant
# with dwNumberOfBytesTransferred.
#
# There are also some other entries in OVERLAPPED which only matter on input:
# - Offset and OffsetHigh which are inputs to {Read,Write}File and
# otherwise always zero
# - hEvent which is for if you aren't using IOCP; we always set it to zero.
#
# That describes the usual pattern for operations and the usual meaning of
# these struct fields, but really these are just some arbitrary chunks of
# bytes that get passed back and forth, so some operations like to overload
# them to mean something else.
#
# You can also directly queue an OVERLAPPED_ENTRY object to an IOCP by calling
# PostQueuedCompletionStatus. When you use this you get to set all the
# OVERLAPPED_ENTRY fields to arbitrary values.
#
# You can request to cancel any operation if you know which handle it was
# issued on + the OVERLAPPED struct that identifies it (via CancelIoEx). This
# request might fail because the operation has already completed, or it might
# be queued to happen in the background, so you only find out whether it
# succeeded or failed later, when we get back the notification for the
# operation being complete.
#
# There are three types of operations that we support:
#
# == Regular I/O operations on handles (e.g. files or named pipes) ==
#
# Implemented by: register_with_iocp, wait_overlapped
#
# To use these, you have to register the handle with your IOCP first. Once
# it's registered, any operations on that handle will automatically send
# completion events to that IOCP, with a completion key that you specify *when
# the handle is registered* (so you can't use different completion keys for
# different operations).
#
# We give these two dedicated completion keys: CKeys.WAIT_OVERLAPPED for
# regular operations, and CKeys.LATE_CANCEL that's used to make
# wait_overlapped cancellable even if the user forgot to call
# register_with_iocp. The problem here is that after we request the cancel,
# wait_overlapped keeps blocking until it sees the completion notification...
# but if the user forgot to register_with_iocp, then the completion will never
# come, so the cancellation will never resolve. To avoid this, whenever we try
# to cancel an I/O operation and the cancellation fails, we use
# PostQueuedCompletionStatus to send a CKeys.LATE_CANCEL notification. If this
# arrives before the real completion, we assume the user forgot to call
# register_with_iocp on their handle, and raise an error accordingly.
#
# == Socket state notifications ==
#
# Implemented by: wait_readable, wait_writable
#
# The public APIs that windows provides for this are all really awkward and
# don't integrate with IOCP. So we drop down to a lower level, and talk
# directly to the socket device driver in the kernel, which is called "AFD".
# Unfortunately, this is a totally undocumented internal API. Fortunately
# libuv also does this, so we can be pretty confident that MS won't break it
# on us, and there is a *little* bit of information out there if you go
# digging.
#
# Basically: we open a magic file that refers to the AFD driver, register the
# magic file with our IOCP, and then we can issue regular overlapped I/O
# operations on that handle. Specifically, the operation we use is called
# IOCTL_AFD_POLL, which lets us pass in a buffer describing which events we're
# interested in on a given socket (readable, writable, etc.). Later, when the
# operation completes, the kernel rewrites the buffer we passed in to record
# which events happened, and uses IOCP as normal to notify us that this
# operation has completed.
#
# Unfortunately, the Windows kernel seems to have bugs if you try to issue
# multiple simultaneous IOCTL_AFD_POLL operations on the same socket (see
# notes-to-self/afd-lab.py). So if a user calls wait_readable and
# wait_writable at the same time, we have to combine those into a single
# IOCTL_AFD_POLL. This means we can't just use the wait_overlapped machinery.
# Instead we have some dedicated code to handle these operations, and a
# dedicated completion key CKeys.AFD_POLL.
#
# Sources of information:
# - https://github.com/python-trio/trio/issues/52
# - Wepoll: https://github.com/piscisaureus/wepoll/
# - libuv: https://github.com/libuv/libuv/
# - ReactOS: https://github.com/reactos/reactos/
# - Ancient leaked copies of the Windows NT and Winsock source code:
# https://github.com/pustladi/Windows-2000/blob/661d000d50637ed6fab2329d30e31775046588a9/private/net/sockets/winsock2/wsp/msafd/select.c#L59-L655
# https://github.com/metoo10987/WinNT4/blob/f5c14e6b42c8f45c20fe88d14c61f9d6e0386b8e/private/ntos/afd/poll.c#L68-L707
# - The WSAEventSelect docs (this exposes a finer-grained set of events than
# select(), so if you squint you can treat it as a source of information on
# the fine-grained AFD poll types)
#
#
# == Everything else ==
#
# There are also some weirder APIs for interacting with IOCP. For example, the
# "Job" API lets you specify an IOCP handle and "completion key", and then in
# the future whenever certain events happen it sends uses IOCP to send a
# notification. These notifications don't correspond to any particular
# operation; they're just spontaneous messages you get. The
# "dwNumberOfBytesTransferred" field gets repurposed to carry an identifier
# for the message type (e.g. JOB_OBJECT_MSG_EXIT_PROCESS), and the
# "lpOverlapped" field gets repurposed to carry some arbitrary data that
# depends on the message type (e.g. the pid of the process that exited).
#
# To handle these, we have monitor_completion_key, where we hand out an
# unassigned completion key, let users set it up however they want, and then
# get any events that arrive on that key.
#
# (Note: monitor_completion_key is not documented or fully baked; expect it to
# change in the future.)
# Our completion keys
class CKeys(enum.IntEnum):
AFD_POLL = 0
WAIT_OVERLAPPED = 1
LATE_CANCEL = 2
FORCE_WAKEUP = 3
USER_DEFINED = 4 # and above
def _check(success):
if not success:
raise_winerror()
return success
def _get_underlying_socket(sock, *, which=WSAIoctls.SIO_BASE_HANDLE):
if hasattr(sock, "fileno"):
sock = sock.fileno()
base_ptr = ffi.new("HANDLE *")
out_size = ffi.new("DWORD *")
failed = ws2_32.WSAIoctl(
ffi.cast("SOCKET", sock),
which,
ffi.NULL,
0,
base_ptr,
ffi.sizeof("HANDLE"),
out_size,
ffi.NULL,
ffi.NULL,
)
if failed:
code = ws2_32.WSAGetLastError()
raise_winerror(code)
return base_ptr[0]
def _get_base_socket(sock):
# There is a development kit for LSPs called Komodia Redirector.
# It does some unusual (some might say evil) things like intercepting
# SIO_BASE_HANDLE (fails) and SIO_BSP_HANDLE_SELECT (returns the same
# socket) in a misguided attempt to prevent bypassing it. It's been used
# in malware including the infamous Lenovo Superfish incident from 2015,
# but unfortunately is also used in some legitimate products such as
# parental control tools and Astrill VPN. Komodia happens to not
# block SIO_BSP_HANDLE_POLL, so we'll try SIO_BASE_HANDLE and fall back
# to SIO_BSP_HANDLE_POLL if it doesn't work.
# References:
# - https://github.com/piscisaureus/wepoll/blob/0598a791bf9cbbf480793d778930fc635b044980/wepoll.c#L2223
# - https://github.com/tokio-rs/mio/issues/1314
while True:
try:
# If this is not a Komodia-intercepted socket, we can just use
# SIO_BASE_HANDLE.
return _get_underlying_socket(sock)
except OSError as ex:
if ex.winerror == ErrorCodes.ERROR_NOT_SOCKET:
# SIO_BASE_HANDLE might fail even without LSP intervention,
# if we get something that's not a socket.
raise
if hasattr(sock, "fileno"):
sock = sock.fileno()
sock = _handle(sock)
next_sock = _get_underlying_socket(
sock, which=WSAIoctls.SIO_BSP_HANDLE_POLL
)
if next_sock == sock:
# If BSP_HANDLE_POLL returns the same socket we already had,
# then there's no layering going on and we need to fail
# to prevent an infinite loop.
raise RuntimeError(
"Unexpected network configuration detected: "
"SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't "
"return a different socket. Please file a bug at "
"https://github.com/python-trio/trio/issues/new, "
"and include the output of running: "
"netsh winsock show catalog"
)
# Otherwise we've gotten at least one layer deeper, so
# loop back around to keep digging.
sock = next_sock
def _afd_helper_handle():
# The "AFD" driver is exposed at the NT path "\Device\Afd". We're using
# the Win32 CreateFile, though, so we have to pass a Win32 path. \\.\ is
# how Win32 refers to the NT \GLOBAL??\ directory, and GLOBALROOT is a
# symlink inside that directory that points to the root of the NT path
# system. So by sticking that in front of the NT path, we get a Win32
# path. Alternatively, we could use NtCreateFile directly, since it takes
# an NT path. But we already wrap CreateFileW so this was easier.
# References:
# https://blogs.msdn.microsoft.com/jeremykuhne/2016/05/02/dos-to-nt-a-paths-journey/
# https://stackoverflow.com/a/21704022
#
# I'm actually not sure what the \Trio part at the end of the path does.
# Wepoll uses \Device\Afd\Wepoll, so I just copied them. (I'm guessing it
# might be visible in some debug tools, and is otherwise arbitrary?)
rawname = r"\\.\GLOBALROOT\Device\Afd\Trio".encode("utf-16le") + b"\0\0"
rawname_buf = ffi.from_buffer(rawname)
handle = kernel32.CreateFileW(
ffi.cast("LPCWSTR", rawname_buf),
FileFlags.SYNCHRONIZE,
FileFlags.FILE_SHARE_READ | FileFlags.FILE_SHARE_WRITE,
ffi.NULL, # no security attributes
FileFlags.OPEN_EXISTING,
FileFlags.FILE_FLAG_OVERLAPPED,
ffi.NULL, # no template file
)
if handle == INVALID_HANDLE_VALUE: # pragma: no cover
raise_winerror()
return handle
# AFD_POLL has a finer-grained set of events than other APIs. We collapse them
# down into Unix-style "readable" and "writable".
#
# Note: AFD_POLL_LOCAL_CLOSE isn't a reliable substitute for notify_closing(),
# because even if the user closes the socket *handle*, the socket *object*
# could still remain open, e.g. if the socket was dup'ed (possibly into
# another process). Explicitly calling notify_closing() guarantees that
# everyone waiting on the *handle* wakes up, which is what you'd expect.
#
# However, we can't avoid getting LOCAL_CLOSE notifications -- the kernel
# delivers them whether we ask for them or not -- so better to include them
# here for documentation, and so that when we check (delivered & requested) we
# get a match.
READABLE_FLAGS = (
AFDPollFlags.AFD_POLL_RECEIVE
| AFDPollFlags.AFD_POLL_ACCEPT
| AFDPollFlags.AFD_POLL_DISCONNECT # other side sent an EOF
| AFDPollFlags.AFD_POLL_ABORT
| AFDPollFlags.AFD_POLL_LOCAL_CLOSE
)
WRITABLE_FLAGS = (
AFDPollFlags.AFD_POLL_SEND
| AFDPollFlags.AFD_POLL_CONNECT_FAIL
| AFDPollFlags.AFD_POLL_ABORT
| AFDPollFlags.AFD_POLL_LOCAL_CLOSE
)
# Annoyingly, while the API makes it *seem* like you can happily issue as many
# independent AFD_POLL operations as you want without them interfering with
# each other, in fact if you issue two AFD_POLL operations for the same socket
# at the same time with notification going to the same IOCP port, then Windows
# gets super confused. For example, if we issue one operation from
# wait_readable, and another independent operation from wait_writable, then
# Windows may complete the wait_writable operation when the socket becomes
# readable.
#
# To avoid this, we have to coalesce all the operations on a single socket
# into one, and when the set of waiters changes we have to throw away the old
# operation and start a new one.
@attr.s(slots=True, eq=False)
class AFDWaiters:
read_task = attr.ib(default=None)
write_task = attr.ib(default=None)
current_op = attr.ib(default=None)
# We also need to bundle up all the info for a single op into a standalone
# object, because we need to keep all these objects alive until the operation
# finishes, even if we're throwing it away.
@attr.s(slots=True, eq=False, frozen=True)
class AFDPollOp:
lpOverlapped = attr.ib()
poll_info = attr.ib()
waiters = attr.ib()
afd_group = attr.ib()
# The Windows kernel has a weird issue when using AFD handles. If you have N
# instances of wait_readable/wait_writable registered with a single AFD handle,
# then cancelling any one of them takes something like O(N**2) time. So if we
# used just a single AFD handle, then cancellation would quickly become very
# expensive, e.g. a program with N active sockets would take something like
# O(N**3) time to unwind after control-C. The solution is to spread our sockets
# out over multiple AFD handles, so that N doesn't grow too large for any
# individual handle.
MAX_AFD_GROUP_SIZE = 500 # at 1000, the cubic scaling is just starting to bite
@attr.s(slots=True, eq=False)
class AFDGroup:
size = attr.ib()
handle = attr.ib()
@attr.s(slots=True, eq=False, frozen=True)
class _WindowsStatistics:
tasks_waiting_read = attr.ib()
tasks_waiting_write = attr.ib()
tasks_waiting_overlapped = attr.ib()
completion_key_monitors = attr.ib()
backend = attr.ib(default="windows")
# Maximum number of events to dequeue from the completion port on each pass
# through the run loop. Somewhat arbitrary. Should be large enough to collect
# a good set of tasks on each loop, but not so large to waste tons of memory.
# (Each WindowsIOManager holds a buffer whose size is ~32x this number.)
MAX_EVENTS = 1000
@attr.s(frozen=True)
class CompletionKeyEventInfo:
lpOverlapped = attr.ib()
dwNumberOfBytesTransferred = attr.ib()
class WindowsIOManager:
def __init__(self):
# If this method raises an exception, then __del__ could run on a
# half-initialized object. So we initialize everything that __del__
# touches to safe values up front, before we do anything that can
# fail.
self._iocp = None
self._all_afd_handles = []
self._iocp = _check(
kernel32.CreateIoCompletionPort(INVALID_HANDLE_VALUE, ffi.NULL, 0, 0)
)
self._events = ffi.new("OVERLAPPED_ENTRY[]", MAX_EVENTS)
self._vacant_afd_groups = set()
# {lpOverlapped: AFDPollOp}
self._afd_ops = {}
# {socket handle: AFDWaiters}
self._afd_waiters = {}
# {lpOverlapped: task}
self._overlapped_waiters = {}
self._posted_too_late_to_cancel = set()
self._completion_key_queues = {}
self._completion_key_counter = itertools.count(CKeys.USER_DEFINED)
with socket.socket() as s:
# We assume we're not working with any LSP that changes
# how select() is supposed to work. Validate this by
# ensuring that the result of SIO_BSP_HANDLE_SELECT (the
# LSP-hookable mechanism for "what should I use for
# select()?") matches that of SIO_BASE_HANDLE ("what is
# the real non-hooked underlying socket here?").
#
# This doesn't work for Komodia-based LSPs; see the comments
# in _get_base_socket() for details. But we have special
# logic for those, so we just skip this check if
# SIO_BASE_HANDLE fails.
# LSPs can in theory override this, but we believe that it never
# actually happens in the wild (except Komodia)
select_handle = _get_underlying_socket(
s, which=WSAIoctls.SIO_BSP_HANDLE_SELECT
)
try:
# LSPs shouldn't override this...
base_handle = _get_underlying_socket(s, which=WSAIoctls.SIO_BASE_HANDLE)
except OSError:
# But Komodia-based LSPs do anyway, in a way that causes
# a failure with WSAEFAULT. We have special handling for
# them in _get_base_socket(). Make sure it works.
_get_base_socket(s)
else:
if base_handle != select_handle:
raise RuntimeError(
"Unexpected network configuration detected: "
"SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ. "
"Please file a bug at "
"https://github.com/python-trio/trio/issues/new, "
"and include the output of running: "
"netsh winsock show catalog"
)
def close(self):
try:
if self._iocp is not None:
iocp = self._iocp
self._iocp = None
_check(kernel32.CloseHandle(iocp))
finally:
while self._all_afd_handles:
afd_handle = self._all_afd_handles.pop()
_check(kernel32.CloseHandle(afd_handle))
def __del__(self):
self.close()
def statistics(self):
tasks_waiting_read = 0
tasks_waiting_write = 0
for waiter in self._afd_waiters.values():
if waiter.read_task is not None:
tasks_waiting_read += 1
if waiter.write_task is not None:
tasks_waiting_write += 1
return _WindowsStatistics(
tasks_waiting_read=tasks_waiting_read,
tasks_waiting_write=tasks_waiting_write,
tasks_waiting_overlapped=len(self._overlapped_waiters),
completion_key_monitors=len(self._completion_key_queues),
)
def force_wakeup(self):
_check(
kernel32.PostQueuedCompletionStatus(
self._iocp, 0, CKeys.FORCE_WAKEUP, ffi.NULL
)
)
def get_events(self, timeout):
received = ffi.new("PULONG")
milliseconds = round(1000 * timeout)
if timeout > 0 and milliseconds == 0:
milliseconds = 1
try:
_check(
kernel32.GetQueuedCompletionStatusEx(
self._iocp, self._events, MAX_EVENTS, received, milliseconds, 0
)
)
except OSError as exc:
if exc.winerror != ErrorCodes.WAIT_TIMEOUT: # pragma: no cover
raise
return 0
return received[0]
def process_events(self, received):
for i in range(received):
entry = self._events[i]
if entry.lpCompletionKey == CKeys.AFD_POLL:
lpo = entry.lpOverlapped
op = self._afd_ops.pop(lpo)
waiters = op.waiters
if waiters.current_op is not op:
# Stale op, nothing to do
pass
else:
waiters.current_op = None
# I don't think this can happen, so if it does let's crash
# and get a debug trace.
if lpo.Internal != 0: # pragma: no cover
code = ntdll.RtlNtStatusToDosError(lpo.Internal)
raise_winerror(code)
flags = op.poll_info.Handles[0].Events
if waiters.read_task and flags & READABLE_FLAGS:
_core.reschedule(waiters.read_task)
waiters.read_task = None
if waiters.write_task and flags & WRITABLE_FLAGS:
_core.reschedule(waiters.write_task)
waiters.write_task = None
self._refresh_afd(op.poll_info.Handles[0].Handle)
elif entry.lpCompletionKey == CKeys.WAIT_OVERLAPPED:
# Regular I/O event, dispatch on lpOverlapped
waiter = self._overlapped_waiters.pop(entry.lpOverlapped)
overlapped = entry.lpOverlapped
transferred = entry.dwNumberOfBytesTransferred
info = CompletionKeyEventInfo(
lpOverlapped=overlapped, dwNumberOfBytesTransferred=transferred
)
_core.reschedule(waiter, Value(info))
elif entry.lpCompletionKey == CKeys.LATE_CANCEL:
# Post made by a regular I/O event's abort_fn
# after it failed to cancel the I/O. If we still
# have a waiter with this lpOverlapped, we didn't
# get the regular I/O completion and almost
# certainly the user forgot to call
# register_with_iocp.
self._posted_too_late_to_cancel.remove(entry.lpOverlapped)
try:
waiter = self._overlapped_waiters.pop(entry.lpOverlapped)
except KeyError:
# Looks like the actual completion got here before this
# fallback post did -- we're in the "expected" case of
# too-late-to-cancel, where the user did nothing wrong.
# Nothing more to do.
pass
else:
exc = _core.TrioInternalError(
"Failed to cancel overlapped I/O in {} and didn't "
"receive the completion either. Did you forget to "
"call register_with_iocp()?".format(waiter.name)
)
# Raising this out of handle_io ensures that
# the user will see our message even if some
# other task is in an uncancellable wait due
# to the same underlying forgot-to-register
# issue (if their CancelIoEx succeeds, we
# have no way of noticing that their completion
# won't arrive). Unfortunately it loses the
# task traceback. If you're debugging this
# error and can't tell where it's coming from,
# try changing this line to
# _core.reschedule(waiter, outcome.Error(exc))
raise exc
elif entry.lpCompletionKey == CKeys.FORCE_WAKEUP:
pass
else:
# dispatch on lpCompletionKey
queue = self._completion_key_queues[entry.lpCompletionKey]
overlapped = int(ffi.cast("uintptr_t", entry.lpOverlapped))
transferred = entry.dwNumberOfBytesTransferred
info = CompletionKeyEventInfo(
lpOverlapped=overlapped, dwNumberOfBytesTransferred=transferred
)
queue.put_nowait(info)
def _register_with_iocp(self, handle, completion_key):
handle = _handle(handle)
_check(kernel32.CreateIoCompletionPort(handle, self._iocp, completion_key, 0))
# Supposedly this makes things slightly faster, by disabling the
# ability to do WaitForSingleObject(handle). We would never want to do
# that anyway, so might as well get the extra speed (if any).
# Ref: http://www.lenholgate.com/blog/2009/09/interesting-blog-posts-on-high-performance-servers.html
_check(
kernel32.SetFileCompletionNotificationModes(
handle, CompletionModes.FILE_SKIP_SET_EVENT_ON_HANDLE
)
)
################################################################
# AFD stuff
################################################################
def _refresh_afd(self, base_handle):
waiters = self._afd_waiters[base_handle]
if waiters.current_op is not None:
afd_group = waiters.current_op.afd_group
try:
_check(
kernel32.CancelIoEx(
afd_group.handle, waiters.current_op.lpOverlapped
)
)
except OSError as exc:
if exc.winerror != ErrorCodes.ERROR_NOT_FOUND:
# I don't think this is possible, so if it happens let's
# crash noisily.
raise # pragma: no cover
waiters.current_op = None
afd_group.size -= 1
self._vacant_afd_groups.add(afd_group)
flags = 0
if waiters.read_task is not None:
flags |= READABLE_FLAGS
if waiters.write_task is not None:
flags |= WRITABLE_FLAGS
if not flags:
del self._afd_waiters[base_handle]
else:
try:
afd_group = self._vacant_afd_groups.pop()
except KeyError:
afd_group = AFDGroup(0, _afd_helper_handle())
self._register_with_iocp(afd_group.handle, CKeys.AFD_POLL)
self._all_afd_handles.append(afd_group.handle)
self._vacant_afd_groups.add(afd_group)
lpOverlapped = ffi.new("LPOVERLAPPED")
poll_info = ffi.new("AFD_POLL_INFO *")
poll_info.Timeout = 2**63 - 1 # INT64_MAX
poll_info.NumberOfHandles = 1
poll_info.Exclusive = 0
poll_info.Handles[0].Handle = base_handle
poll_info.Handles[0].Status = 0
poll_info.Handles[0].Events = flags
try:
_check(
kernel32.DeviceIoControl(
afd_group.handle,
IoControlCodes.IOCTL_AFD_POLL,
poll_info,
ffi.sizeof("AFD_POLL_INFO"),
poll_info,
ffi.sizeof("AFD_POLL_INFO"),
ffi.NULL,
lpOverlapped,
)
)
except OSError as exc:
if exc.winerror != ErrorCodes.ERROR_IO_PENDING:
# This could happen if the socket handle got closed behind
# our back while a wait_* call was pending, and we tried
# to re-issue the call. Clear our state and wake up any
# pending calls.
del self._afd_waiters[base_handle]
# Do this last, because it could raise.
wake_all(waiters, exc)
return
op = AFDPollOp(lpOverlapped, poll_info, waiters, afd_group)
waiters.current_op = op
self._afd_ops[lpOverlapped] = op
afd_group.size += 1
if afd_group.size >= MAX_AFD_GROUP_SIZE:
self._vacant_afd_groups.remove(afd_group)
async def _afd_poll(self, sock, mode):
base_handle = _get_base_socket(sock)
waiters = self._afd_waiters.get(base_handle)
if waiters is None:
waiters = AFDWaiters()
self._afd_waiters[base_handle] = waiters
if getattr(waiters, mode) is not None:
raise _core.BusyResourceError
setattr(waiters, mode, _core.current_task())
# Could potentially raise if the handle is somehow invalid; that's OK,
# we let it escape.
self._refresh_afd(base_handle)
def abort_fn(_):
setattr(waiters, mode, None)
self._refresh_afd(base_handle)
return _core.Abort.SUCCEEDED
await _core.wait_task_rescheduled(abort_fn)
@_public
async def wait_readable(self, sock):
await self._afd_poll(sock, "read_task")
@_public
async def wait_writable(self, sock):
await self._afd_poll(sock, "write_task")
@_public
def notify_closing(self, handle):
handle = _get_base_socket(handle)
waiters = self._afd_waiters.get(handle)
if waiters is not None:
wake_all(waiters, _core.ClosedResourceError())
self._refresh_afd(handle)
################################################################
# Regular overlapped operations
################################################################
@_public
def register_with_iocp(self, handle):
self._register_with_iocp(handle, CKeys.WAIT_OVERLAPPED)
@_public
async def wait_overlapped(self, handle, lpOverlapped):
handle = _handle(handle)
if isinstance(lpOverlapped, int):
lpOverlapped = ffi.cast("LPOVERLAPPED", lpOverlapped)
if lpOverlapped in self._overlapped_waiters:
raise _core.BusyResourceError(
"another task is already waiting on that lpOverlapped"
)
task = _core.current_task()
self._overlapped_waiters[lpOverlapped] = task
raise_cancel = None
def abort(raise_cancel_):
nonlocal raise_cancel
raise_cancel = raise_cancel_
try:
_check(kernel32.CancelIoEx(handle, lpOverlapped))
except OSError as exc:
if exc.winerror == ErrorCodes.ERROR_NOT_FOUND:
# Too late to cancel. If this happens because the
# operation is already completed, we don't need to do
# anything; we'll get a notification of that completion
# soon. But another possibility is that the operation was
# performed on a handle that wasn't registered with our
# IOCP (ie, the user forgot to call register_with_iocp),
# in which case we're just never going to see the
# completion. To avoid an uncancellable infinite sleep in
# the latter case, we'll PostQueuedCompletionStatus here,
# and if our post arrives before the original completion
# does, we'll assume the handle wasn't registered.
_check(
kernel32.PostQueuedCompletionStatus(
self._iocp, 0, CKeys.LATE_CANCEL, lpOverlapped
)
)
# Keep the lpOverlapped referenced so its address
# doesn't get reused until our posted completion
# status has been processed. Otherwise, we can
# get confused about which completion goes with
# which I/O.
self._posted_too_late_to_cancel.add(lpOverlapped)
else: # pragma: no cover
raise _core.TrioInternalError(
"CancelIoEx failed with unexpected error"
) from exc
return _core.Abort.FAILED
info = await _core.wait_task_rescheduled(abort)
if lpOverlapped.Internal != 0:
# the lpOverlapped reports the error as an NT status code,
# which we must convert back to a Win32 error code before
# it will produce the right sorts of exceptions
code = ntdll.RtlNtStatusToDosError(lpOverlapped.Internal)
if code == ErrorCodes.ERROR_OPERATION_ABORTED:
if raise_cancel is not None:
raise_cancel()
else:
# We didn't request this cancellation, so assume
# it happened due to the underlying handle being
# closed before the operation could complete.
raise _core.ClosedResourceError("another task closed this resource")
else:
raise_winerror(code)
return info
async def _perform_overlapped(self, handle, submit_fn):
# submit_fn(lpOverlapped) submits some I/O
# it may raise an OSError with ERROR_IO_PENDING
# the handle must already be registered using
# register_with_iocp(handle)
# This always does a schedule point, but it's possible that the
# operation will not be cancellable, depending on how Windows is
# feeling today. So we need to check for cancellation manually.
await _core.checkpoint_if_cancelled()
lpOverlapped = ffi.new("LPOVERLAPPED")
try:
submit_fn(lpOverlapped)
except OSError as exc:
if exc.winerror != ErrorCodes.ERROR_IO_PENDING:
raise
await self.wait_overlapped(handle, lpOverlapped)
return lpOverlapped
@_public
async def write_overlapped(self, handle, data, file_offset=0):
with ffi.from_buffer(data) as cbuf:
def submit_write(lpOverlapped):
# yes, these are the real documented names
offset_fields = lpOverlapped.DUMMYUNIONNAME.DUMMYSTRUCTNAME
offset_fields.Offset = file_offset & 0xFFFFFFFF
offset_fields.OffsetHigh = file_offset >> 32
_check(
kernel32.WriteFile(
_handle(handle),
ffi.cast("LPCVOID", cbuf),
len(cbuf),
ffi.NULL,
lpOverlapped,
)
)
lpOverlapped = await self._perform_overlapped(handle, submit_write)
# this is "number of bytes transferred"
return lpOverlapped.InternalHigh
@_public
async def readinto_overlapped(self, handle, buffer, file_offset=0):
with ffi.from_buffer(buffer, require_writable=True) as cbuf:
def submit_read(lpOverlapped):
offset_fields = lpOverlapped.DUMMYUNIONNAME.DUMMYSTRUCTNAME
offset_fields.Offset = file_offset & 0xFFFFFFFF
offset_fields.OffsetHigh = file_offset >> 32
_check(
kernel32.ReadFile(
_handle(handle),
ffi.cast("LPVOID", cbuf),
len(cbuf),
ffi.NULL,
lpOverlapped,
)
)
lpOverlapped = await self._perform_overlapped(handle, submit_read)
return lpOverlapped.InternalHigh
################################################################
# Raw IOCP operations
################################################################
@_public
def current_iocp(self):
return int(ffi.cast("uintptr_t", self._iocp))
@contextmanager
@_public
def monitor_completion_key(self):
key = next(self._completion_key_counter)
queue = _core.UnboundedQueue()
self._completion_key_queues[key] = queue
try:
yield (key, queue)
finally:
del self._completion_key_queues[key]

View File

@@ -0,0 +1,200 @@
import inspect
import signal
import sys
from functools import wraps
import attr
import async_generator
from .._util import is_main_thread
if False:
from typing import Any, TypeVar, Callable
F = TypeVar("F", bound=Callable[..., Any])
# In ordinary single-threaded Python code, when you hit control-C, it raises
# an exception and automatically does all the regular unwinding stuff.
#
# In Trio code, we would like hitting control-C to raise an exception and
# automatically do all the regular unwinding stuff. In particular, we would
# like to maintain our invariant that all tasks always run to completion (one
# way or another), by unwinding all of them.
#
# But it's basically impossible to write the core task running code in such a
# way that it can maintain this invariant in the face of KeyboardInterrupt
# exceptions arising at arbitrary bytecode positions. Similarly, if a
# KeyboardInterrupt happened at the wrong moment inside pretty much any of our
# inter-task synchronization or I/O primitives, then the system state could
# get corrupted and prevent our being able to clean up properly.
#
# So, we need a way to defer KeyboardInterrupt processing from these critical
# sections.
#
# Things that don't work:
#
# - Listen for SIGINT and process it in a system task: works fine for
# well-behaved programs that regularly pass through the event loop, but if
# user-code goes into an infinite loop then it can't be interrupted. Which
# is unfortunate, since dealing with infinite loops is what
# KeyboardInterrupt is for!
#
# - Use pthread_sigmask to disable signal delivery during critical section:
# (a) windows has no pthread_sigmask, (b) python threads start with all
# signals unblocked, so if there are any threads around they'll receive the
# signal and then tell the main thread to run the handler, even if the main
# thread has that signal blocked.
#
# - Install a signal handler which checks a global variable to decide whether
# to raise the exception immediately (if we're in a non-critical section),
# or to schedule it on the event loop (if we're in a critical section). The
# problem here is that it's impossible to transition safely out of user code:
#
# with keyboard_interrupt_enabled:
# msg = coro.send(value)
#
# If this raises a KeyboardInterrupt, it might be because the coroutine got
# interrupted and has unwound... or it might be the KeyboardInterrupt
# arrived just *after* 'send' returned, so the coroutine is still running
# but we just lost the message it sent. (And worse, in our actual task
# runner, the send is hidden inside a utility function etc.)
#
# Solution:
#
# Mark *stack frames* as being interrupt-safe or interrupt-unsafe, and from
# the signal handler check which kind of frame we're currently in when
# deciding whether to raise or schedule the exception.
#
# There are still some cases where this can fail, like if someone hits
# control-C while the process is in the event loop, and then it immediately
# enters an infinite loop in user code. In this case the user has to hit
# control-C a second time. And of course if the user code is written so that
# it doesn't actually exit after a task crashes and everything gets cancelled,
# then there's not much to be done. (Hitting control-C repeatedly might help,
# but in general the solution is to kill the process some other way, just like
# for any Python program that's written to catch and ignore
# KeyboardInterrupt.)
# We use this special string as a unique key into the frame locals dictionary.
# The @ ensures it is not a valid identifier and can't clash with any possible
# real local name. See: https://github.com/python-trio/trio/issues/469
LOCALS_KEY_KI_PROTECTION_ENABLED = "@TRIO_KI_PROTECTION_ENABLED"
# NB: according to the signal.signal docs, 'frame' can be None on entry to
# this function:
def ki_protection_enabled(frame):
while frame is not None:
if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals:
return frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED]
if frame.f_code.co_name == "__del__":
return True
frame = frame.f_back
return True
def currently_ki_protected():
r"""Check whether the calling code has :exc:`KeyboardInterrupt` protection
enabled.
It's surprisingly easy to think that one's :exc:`KeyboardInterrupt`
protection is enabled when it isn't, or vice-versa. This function tells
you what Trio thinks of the matter, which makes it useful for ``assert``\s
and unit tests.
Returns:
bool: True if protection is enabled, and False otherwise.
"""
return ki_protection_enabled(sys._getframe())
def _ki_protection_decorator(enabled):
def decorator(fn):
# In some version of Python, isgeneratorfunction returns true for
# coroutine functions, so we have to check for coroutine functions
# first.
if inspect.iscoroutinefunction(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
# See the comment for regular generators below
coro = fn(*args, **kwargs)
coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return coro
return wrapper
elif inspect.isgeneratorfunction(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
# It's important that we inject this directly into the
# generator's locals, as opposed to setting it here and then
# doing 'yield from'. The reason is, if a generator is
# throw()n into, then it may magically pop to the top of the
# stack. And @contextmanager generators in particular are a
# case where we often want KI protection, and which are often
# thrown into! See:
# https://bugs.python.org/issue29590
gen = fn(*args, **kwargs)
gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return gen
return wrapper
elif async_generator.isasyncgenfunction(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
# See the comment for regular generators above
agen = fn(*args, **kwargs)
agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return agen
return wrapper
else:
@wraps(fn)
def wrapper(*args, **kwargs):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return fn(*args, **kwargs)
return wrapper
return decorator
enable_ki_protection = _ki_protection_decorator(True) # type: Callable[[F], F]
enable_ki_protection.__name__ = "enable_ki_protection"
disable_ki_protection = _ki_protection_decorator(False) # type: Callable[[F], F]
disable_ki_protection.__name__ = "disable_ki_protection"
@attr.s
class KIManager:
handler = attr.ib(default=None)
def install(self, deliver_cb, restrict_keyboard_interrupt_to_checkpoints):
assert self.handler is None
if (
not is_main_thread()
or signal.getsignal(signal.SIGINT) != signal.default_int_handler
):
return
def handler(signum, frame):
assert signum == signal.SIGINT
protection_enabled = ki_protection_enabled(frame)
if protection_enabled or restrict_keyboard_interrupt_to_checkpoints:
deliver_cb()
else:
raise KeyboardInterrupt
self.handler = handler
signal.signal(signal.SIGINT, handler)
def close(self):
if self.handler is not None:
if signal.getsignal(signal.SIGINT) is self.handler:
signal.signal(signal.SIGINT, signal.default_int_handler)
self.handler = None

View File

@@ -0,0 +1,95 @@
# Runvar implementations
import attr
from . import _run
from .._util import Final
@attr.s(eq=False, hash=False, slots=True)
class _RunVarToken:
_no_value = object()
_var = attr.ib()
previous_value = attr.ib(default=_no_value)
redeemed = attr.ib(default=False, init=False)
@classmethod
def empty(cls, var):
return cls(var)
@attr.s(eq=False, hash=False, slots=True)
class RunVar(metaclass=Final):
"""The run-local variant of a context variable.
:class:`RunVar` objects are similar to context variable objects,
except that they are shared across a single call to :func:`trio.run`
rather than a single task.
"""
_NO_DEFAULT = object()
_name = attr.ib()
_default = attr.ib(default=_NO_DEFAULT)
def get(self, default=_NO_DEFAULT):
"""Gets the value of this :class:`RunVar` for the current run call."""
try:
return _run.GLOBAL_RUN_CONTEXT.runner._locals[self]
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context") from None
except KeyError:
# contextvars consistency
if default is not self._NO_DEFAULT:
return default
if self._default is not self._NO_DEFAULT:
return self._default
raise LookupError(self) from None
def set(self, value):
"""Sets the value of this :class:`RunVar` for this current run
call.
"""
try:
old_value = self.get()
except LookupError:
token = _RunVarToken.empty(self)
else:
token = _RunVarToken(self, old_value)
# This can't fail, because if we weren't in Trio context then the
# get() above would have failed.
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value
return token
def reset(self, token):
"""Resets the value of this :class:`RunVar` to what it was
previously specified by the token.
"""
if token is None:
raise TypeError("token must not be none")
if token.redeemed:
raise ValueError("token has already been used")
if token._var is not self:
raise ValueError("token is not for us")
previous = token.previous_value
try:
if previous is _RunVarToken._no_value:
_run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self)
else:
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context")
token.redeemed = True
def __repr__(self):
return "<RunVar name={!r}>".format(self._name)

View File

@@ -0,0 +1,165 @@
import time
from math import inf
from .. import _core
from ._run import GLOBAL_RUN_CONTEXT
from .._abc import Clock
from .._util import Final
################################################################
# The glorious MockClock
################################################################
# Prior art:
# https://twistedmatrix.com/documents/current/api/twisted.internet.task.Clock.html
# https://github.com/ztellman/manifold/issues/57
class MockClock(Clock, metaclass=Final):
"""A user-controllable clock suitable for writing tests.
Args:
rate (float): the initial :attr:`rate`.
autojump_threshold (float): the initial :attr:`autojump_threshold`.
.. attribute:: rate
How many seconds of clock time pass per second of real time. Default is
0.0, i.e. the clock only advances through manuals calls to :meth:`jump`
or when the :attr:`autojump_threshold` is triggered. You can assign to
this attribute to change it.
.. attribute:: autojump_threshold
The clock keeps an eye on the run loop, and if at any point it detects
that all tasks have been blocked for this many real seconds (i.e.,
according to the actual clock, not this clock), then the clock
automatically jumps ahead to the run loop's next scheduled
timeout. Default is :data:`math.inf`, i.e., to never autojump. You can
assign to this attribute to change it.
Basically the idea is that if you have code or tests that use sleeps
and timeouts, you can use this to make it run much faster, totally
automatically. (At least, as long as those sleeps/timeouts are
happening inside Trio; if your test involves talking to external
service and waiting for it to timeout then obviously we can't help you
there.)
You should set this to the smallest value that lets you reliably avoid
"false alarms" where some I/O is in flight (e.g. between two halves of
a socketpair) but the threshold gets triggered and time gets advanced
anyway. This will depend on the details of your tests and test
environment. If you aren't doing any I/O (like in our sleeping example
above) then just set it to zero, and the clock will jump whenever all
tasks are blocked.
.. note:: If you use ``autojump_threshold`` and
`wait_all_tasks_blocked` at the same time, then you might wonder how
they interact, since they both cause things to happen after the run
loop goes idle for some time. The answer is:
`wait_all_tasks_blocked` takes priority. If there's a task blocked
in `wait_all_tasks_blocked`, then the autojump feature treats that
as active task and does *not* jump the clock.
"""
def __init__(self, rate=0.0, autojump_threshold=inf):
# when the real clock said 'real_base', the virtual time was
# 'virtual_base', and since then it's advanced at 'rate' virtual
# seconds per real second.
self._real_base = 0.0
self._virtual_base = 0.0
self._rate = 0.0
self._autojump_threshold = 0.0
# kept as an attribute so that our tests can monkeypatch it
self._real_clock = time.perf_counter
# use the property update logic to set initial values
self.rate = rate
self.autojump_threshold = autojump_threshold
def __repr__(self):
return "<MockClock, time={:.7f}, rate={} @ {:#x}>".format(
self.current_time(), self._rate, id(self)
)
@property
def rate(self):
return self._rate
@rate.setter
def rate(self, new_rate):
if new_rate < 0:
raise ValueError("rate must be >= 0")
else:
real = self._real_clock()
virtual = self._real_to_virtual(real)
self._virtual_base = virtual
self._real_base = real
self._rate = float(new_rate)
@property
def autojump_threshold(self):
return self._autojump_threshold
@autojump_threshold.setter
def autojump_threshold(self, new_autojump_threshold):
self._autojump_threshold = float(new_autojump_threshold)
self._try_resync_autojump_threshold()
# runner.clock_autojump_threshold is an internal API that isn't easily
# usable by custom third-party Clock objects. If you need access to this
# functionality, let us know, and we'll figure out how to make a public
# API. Discussion:
#
# https://github.com/python-trio/trio/issues/1587
def _try_resync_autojump_threshold(self):
try:
runner = GLOBAL_RUN_CONTEXT.runner
if runner.is_guest:
runner.force_guest_tick_asap()
except AttributeError:
pass
else:
runner.clock_autojump_threshold = self._autojump_threshold
# Invoked by the run loop when runner.clock_autojump_threshold is
# exceeded.
def _autojump(self):
statistics = _core.current_statistics()
jump = statistics.seconds_to_next_deadline
if 0 < jump < inf:
self.jump(jump)
def _real_to_virtual(self, real):
real_offset = real - self._real_base
virtual_offset = self._rate * real_offset
return self._virtual_base + virtual_offset
def start_clock(self):
self._try_resync_autojump_threshold()
def current_time(self):
return self._real_to_virtual(self._real_clock())
def deadline_to_sleep_time(self, deadline):
virtual_timeout = deadline - self.current_time()
if virtual_timeout <= 0:
return 0
elif self._rate > 0:
return virtual_timeout / self._rate
else:
return 999999999
def jump(self, seconds):
"""Manually advance the clock by the given number of seconds.
Args:
seconds (float): the number of seconds to jump the clock forward.
Raises:
ValueError: if you try to pass a negative value for ``seconds``.
"""
if seconds < 0:
raise ValueError("time can't go backwards")
self._virtual_base += seconds

View File

@@ -0,0 +1,516 @@
import sys
import traceback
import textwrap
import warnings
import attr
################################################################
# MultiError
################################################################
def _filter_impl(handler, root_exc):
# We have a tree of MultiError's, like:
#
# MultiError([
# ValueError,
# MultiError([
# KeyError,
# ValueError,
# ]),
# ])
#
# or similar.
#
# We want to
# 1) apply the filter to each of the leaf exceptions -- each leaf
# might stay the same, be replaced (with the original exception
# potentially sticking around as __context__ or __cause__), or
# disappear altogether.
# 2) simplify the resulting tree -- remove empty nodes, and replace
# singleton MultiError's with their contents, e.g.:
# MultiError([KeyError]) -> KeyError
# (This can happen recursively, e.g. if the two ValueErrors above
# get caught then we'll just be left with a bare KeyError.)
# 3) preserve sensible tracebacks
#
# It's the tracebacks that are most confusing. As a MultiError
# propagates through the stack, it accumulates traceback frames, but
# the exceptions inside it don't. Semantically, the traceback for a
# leaf exception is the concatenation the tracebacks of all the
# exceptions you see when traversing the exception tree from the root
# to that leaf. Our correctness invariant is that this concatenated
# traceback should be the same before and after.
#
# The easy way to do that would be to, at the beginning of this
# function, "push" all tracebacks down to the leafs, so all the
# MultiErrors have __traceback__=None, and all the leafs have complete
# tracebacks. But whenever possible, we'd actually prefer to keep
# tracebacks as high up in the tree as possible, because this lets us
# keep only a single copy of the common parts of these exception's
# tracebacks. This is cheaper (in memory + time -- tracebacks are
# unpleasantly quadratic-ish to work with, and this might matter if
# you have thousands of exceptions, which can happen e.g. after
# cancelling a large task pool, and no-one will ever look at their
# tracebacks!), and more importantly, factoring out redundant parts of
# the tracebacks makes them more readable if/when users do see them.
#
# So instead our strategy is:
# - first go through and construct the new tree, preserving any
# unchanged subtrees
# - then go through the original tree (!) and push tracebacks down
# until either we hit a leaf, or we hit a subtree which was
# preserved in the new tree.
# This used to also support async handler functions. But that runs into:
# https://bugs.python.org/issue29600
# which is difficult to fix on our end.
# Filters a subtree, ignoring tracebacks, while keeping a record of
# which MultiErrors were preserved unchanged
def filter_tree(exc, preserved):
if isinstance(exc, MultiError):
new_exceptions = []
changed = False
for child_exc in exc.exceptions:
new_child_exc = filter_tree(child_exc, preserved)
if new_child_exc is not child_exc:
changed = True
if new_child_exc is not None:
new_exceptions.append(new_child_exc)
if not new_exceptions:
return None
elif changed:
return MultiError(new_exceptions)
else:
preserved.add(id(exc))
return exc
else:
new_exc = handler(exc)
# Our version of implicit exception chaining
if new_exc is not None and new_exc is not exc:
new_exc.__context__ = exc
return new_exc
def push_tb_down(tb, exc, preserved):
if id(exc) in preserved:
return
new_tb = concat_tb(tb, exc.__traceback__)
if isinstance(exc, MultiError):
for child_exc in exc.exceptions:
push_tb_down(new_tb, child_exc, preserved)
exc.__traceback__ = None
else:
exc.__traceback__ = new_tb
preserved = set()
new_root_exc = filter_tree(root_exc, preserved)
push_tb_down(None, root_exc, preserved)
# Delete the local functions to avoid a reference cycle (see
# test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage)
del filter_tree, push_tb_down
return new_root_exc
# Normally I'm a big fan of (a)contextmanager, but in this case I found it
# easier to use the raw context manager protocol, because it makes it a lot
# easier to reason about how we're mutating the traceback as we go. (End
# result: if the exception gets modified, then the 'raise' here makes this
# frame show up in the traceback; otherwise, we leave no trace.)
@attr.s(frozen=True)
class MultiErrorCatcher:
_handler = attr.ib()
def __enter__(self):
pass
def __exit__(self, etype, exc, tb):
if exc is not None:
filtered_exc = MultiError.filter(self._handler, exc)
if filtered_exc is exc:
# Let the interpreter re-raise it
return False
if filtered_exc is None:
# Swallow the exception
return True
# When we raise filtered_exc, Python will unconditionally blow
# away its __context__ attribute and replace it with the original
# exc we caught. So after we raise it, we have to pause it while
# it's in flight to put the correct __context__ back.
old_context = filtered_exc.__context__
try:
raise filtered_exc
finally:
_, value, _ = sys.exc_info()
assert value is filtered_exc
value.__context__ = old_context
# delete references from locals to avoid creating cycles
# see test_MultiError_catch_doesnt_create_cyclic_garbage
del _, filtered_exc, value
class MultiError(BaseException):
"""An exception that contains other exceptions; also known as an
"inception".
It's main use is to represent the situation when multiple child tasks all
raise errors "in parallel".
Args:
exceptions (list): The exceptions
Returns:
If ``len(exceptions) == 1``, returns that exception. This means that a
call to ``MultiError(...)`` is not guaranteed to return a
:exc:`MultiError` object!
Otherwise, returns a new :exc:`MultiError` object.
Raises:
TypeError: if any of the passed in objects are not instances of
:exc:`BaseException`.
"""
def __init__(self, exceptions):
# Avoid recursion when exceptions[0] returned by __new__() happens
# to be a MultiError and subsequently __init__() is called.
if hasattr(self, "exceptions"):
# __init__ was already called on this object
assert len(exceptions) == 1 and exceptions[0] is self
return
self.exceptions = exceptions
def __new__(cls, exceptions):
exceptions = list(exceptions)
for exc in exceptions:
if not isinstance(exc, BaseException):
raise TypeError("Expected an exception object, not {!r}".format(exc))
if len(exceptions) == 1:
# If this lone object happens to itself be a MultiError, then
# Python will implicitly call our __init__ on it again. See
# special handling in __init__.
return exceptions[0]
else:
# The base class __new__() implicitly invokes our __init__, which
# is what we want.
#
# In an earlier version of the code, we didn't define __init__ and
# simply set the `exceptions` attribute directly on the new object.
# However, linters expect attributes to be initialized in __init__.
return BaseException.__new__(cls, exceptions)
def __str__(self):
return ", ".join(repr(exc) for exc in self.exceptions)
def __repr__(self):
return "<MultiError: {}>".format(self)
@classmethod
def filter(cls, handler, root_exc):
"""Apply the given ``handler`` to all the exceptions in ``root_exc``.
Args:
handler: A callable that takes an atomic (non-MultiError) exception
as input, and returns either a new exception object or None.
root_exc: An exception, often (though not necessarily) a
:exc:`MultiError`.
Returns:
A new exception object in which each component exception ``exc`` has
been replaced by the result of running ``handler(exc)`` or, if
``handler`` returned None for all the inputs, returns None.
"""
return _filter_impl(handler, root_exc)
@classmethod
def catch(cls, handler):
"""Return a context manager that catches and re-throws exceptions
after running :meth:`filter` on them.
Args:
handler: as for :meth:`filter`
"""
return MultiErrorCatcher(handler)
# Clean up exception printing:
MultiError.__module__ = "trio"
################################################################
# concat_tb
################################################################
# We need to compute a new traceback that is the concatenation of two existing
# tracebacks. This requires copying the entries in 'head' and then pointing
# the final tb_next to 'tail'.
#
# NB: 'tail' might be None, which requires some special handling in the ctypes
# version.
#
# The complication here is that Python doesn't actually support copying or
# modifying traceback objects, so we have to get creative...
#
# On CPython, we use ctypes. On PyPy, we use "transparent proxies".
#
# Jinja2 is a useful source of inspiration:
# https://github.com/pallets/jinja/blob/master/jinja2/debug.py
try:
import tputil
except ImportError:
have_tproxy = False
else:
have_tproxy = True
if have_tproxy:
# http://doc.pypy.org/en/latest/objspace-proxies.html
def copy_tb(base_tb, tb_next):
def controller(operation):
# Rationale for pragma: I looked fairly carefully and tried a few
# things, and AFAICT it's not actually possible to get any
# 'opname' that isn't __getattr__ or __getattribute__. So there's
# no missing test we could add, and no value in coverage nagging
# us about adding one.
if operation.opname in [
"__getattribute__",
"__getattr__",
]: # pragma: no cover
if operation.args[0] == "tb_next":
return tb_next
return operation.delegate()
return tputil.make_proxy(controller, type(base_tb), base_tb)
else:
# ctypes it is
import ctypes
# How to handle refcounting? I don't want to use ctypes.py_object because
# I don't understand or trust it, and I don't want to use
# ctypes.pythonapi.Py_{Inc,Dec}Ref because we might clash with user code
# that also tries to use them but with different types. So private _ctypes
# APIs it is!
import _ctypes
class CTraceback(ctypes.Structure):
_fields_ = [
("PyObject_HEAD", ctypes.c_byte * object().__sizeof__()),
("tb_next", ctypes.c_void_p),
("tb_frame", ctypes.c_void_p),
("tb_lasti", ctypes.c_int),
("tb_lineno", ctypes.c_int),
]
def copy_tb(base_tb, tb_next):
# TracebackType has no public constructor, so allocate one the hard way
try:
raise ValueError
except ValueError as exc:
new_tb = exc.__traceback__
c_new_tb = CTraceback.from_address(id(new_tb))
# At the C level, tb_next either pointer to the next traceback or is
# NULL. c_void_p and the .tb_next accessor both convert NULL to None,
# but we shouldn't DECREF None just because we assigned to a NULL
# pointer! Here we know that our new traceback has only 1 frame in it,
# so we can assume the tb_next field is NULL.
assert c_new_tb.tb_next is None
# If tb_next is None, then we want to set c_new_tb.tb_next to NULL,
# which it already is, so we're done. Otherwise, we have to actually
# do some work:
if tb_next is not None:
_ctypes.Py_INCREF(tb_next)
c_new_tb.tb_next = id(tb_next)
assert c_new_tb.tb_frame is not None
_ctypes.Py_INCREF(base_tb.tb_frame)
old_tb_frame = new_tb.tb_frame
c_new_tb.tb_frame = id(base_tb.tb_frame)
_ctypes.Py_DECREF(old_tb_frame)
c_new_tb.tb_lasti = base_tb.tb_lasti
c_new_tb.tb_lineno = base_tb.tb_lineno
try:
return new_tb
finally:
# delete references from locals to avoid creating cycles
# see test_MultiError_catch_doesnt_create_cyclic_garbage
del new_tb, old_tb_frame
def concat_tb(head, tail):
# We have to use an iterative algorithm here, because in the worst case
# this might be a RecursionError stack that is by definition too deep to
# process by recursion!
head_tbs = []
pointer = head
while pointer is not None:
head_tbs.append(pointer)
pointer = pointer.tb_next
current_head = tail
for head_tb in reversed(head_tbs):
current_head = copy_tb(head_tb, tb_next=current_head)
return current_head
################################################################
# MultiError traceback formatting
#
# What follows is terrible, terrible monkey patching of
# traceback.TracebackException to add support for handling
# MultiErrors
################################################################
traceback_exception_original_init = traceback.TracebackException.__init__
def traceback_exception_init(
self,
exc_type,
exc_value,
exc_traceback,
*,
limit=None,
lookup_lines=True,
capture_locals=False,
compact=False,
_seen=None,
):
if sys.version_info >= (3, 10):
kwargs = {"compact": compact}
else:
kwargs = {}
# Capture the original exception and its cause and context as TracebackExceptions
traceback_exception_original_init(
self,
exc_type,
exc_value,
exc_traceback,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
_seen=_seen,
**kwargs,
)
seen_was_none = _seen is None
if _seen is None:
_seen = set()
# Capture each of the exceptions in the MultiError along with each of their causes and contexts
if isinstance(exc_value, MultiError):
embedded = []
for exc in exc_value.exceptions:
if id(exc) not in _seen:
embedded.append(
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=None if seen_was_none else set(_seen),
)
)
self.embedded = embedded
else:
self.embedded = []
traceback.TracebackException.__init__ = traceback_exception_init # type: ignore
traceback_exception_original_format = traceback.TracebackException.format
def traceback_exception_format(self, *, chain=True):
yield from traceback_exception_original_format(self, chain=chain)
for i, exc in enumerate(self.embedded):
yield "\nDetails of embedded exception {}:\n\n".format(i + 1)
yield from (textwrap.indent(line, " " * 2) for line in exc.format(chain=chain))
traceback.TracebackException.format = traceback_exception_format # type: ignore
def trio_excepthook(etype, value, tb):
for chunk in traceback.format_exception(etype, value, tb):
sys.stderr.write(chunk)
monkeypatched_or_warned = False
if "IPython" in sys.modules:
import IPython
ip = IPython.get_ipython()
if ip is not None:
if ip.custom_exceptions != ():
warnings.warn(
"IPython detected, but you already have a custom exception "
"handler installed. I'll skip installing Trio's custom "
"handler, but this means MultiErrors will not show full "
"tracebacks.",
category=RuntimeWarning,
)
monkeypatched_or_warned = True
else:
def trio_show_traceback(self, etype, value, tb, tb_offset=None):
# XX it would be better to integrate with IPython's fancy
# exception formatting stuff (and not ignore tb_offset)
trio_excepthook(etype, value, tb)
ip.set_custom_exc((MultiError,), trio_show_traceback)
monkeypatched_or_warned = True
if sys.excepthook is sys.__excepthook__:
sys.excepthook = trio_excepthook
monkeypatched_or_warned = True
# Ubuntu's system Python has a sitecustomize.py file that import
# apport_python_hook and replaces sys.excepthook.
#
# The custom hook captures the error for crash reporting, and then calls
# sys.__excepthook__ to actually print the error.
#
# We don't mind it capturing the error for crash reporting, but we want to
# take over printing the error. So we monkeypatch the apport_python_hook
# module so that instead of calling sys.__excepthook__, it calls our custom
# hook.
#
# More details: https://github.com/python-trio/trio/issues/1065
if getattr(sys.excepthook, "__name__", None) == "apport_excepthook":
import apport_python_hook
assert sys.excepthook is apport_python_hook.apport_excepthook
# Give it a descriptive name as a hint for anyone who's stuck trying to
# debug this mess later.
class TrioFakeSysModuleForApport:
pass
fake_sys = TrioFakeSysModuleForApport()
fake_sys.__dict__.update(sys.__dict__)
fake_sys.__excepthook__ = trio_excepthook # type: ignore
apport_python_hook.sys = fake_sys
monkeypatched_or_warned = True
if not monkeypatched_or_warned:
warnings.warn(
"You seem to already have a custom sys.excepthook handler "
"installed. I'll skip installing Trio's custom handler, but this "
"means MultiErrors will not show full tracebacks.",
category=RuntimeWarning,
)

View File

@@ -0,0 +1,215 @@
# ParkingLot provides an abstraction for a fair waitqueue with cancellation
# and requeueing support. Inspiration:
#
# https://webkit.org/blog/6161/locking-in-webkit/
# https://amanieu.github.io/parking_lot/
#
# which were in turn heavily influenced by
#
# http://gee.cs.oswego.edu/dl/papers/aqs.pdf
#
# Compared to these, our use of cooperative scheduling allows some
# simplifications (no need for internal locking). On the other hand, the need
# to support Trio's strong cancellation semantics adds some complications
# (tasks need to know where they're queued so they can cancel). Also, in the
# above work, the ParkingLot is a global structure that holds a collection of
# waitqueues keyed by lock address, and which are opportunistically allocated
# and destroyed as contention arises; this allows the worst-case memory usage
# for all waitqueues to be O(#tasks). Here we allocate a separate wait queue
# for each synchronization object, so we're O(#objects + #tasks). This isn't
# *so* bad since compared to our synchronization objects are heavier than
# theirs and our tasks are lighter, so for us #objects is smaller and #tasks
# is larger.
#
# This is in the core because for two reasons. First, it's used by
# UnboundedQueue, and UnboundedQueue is used for a number of things in the
# core. And second, it's responsible for providing fairness to all of our
# high-level synchronization primitives (locks, queues, etc.). For now with
# our FIFO scheduler this is relatively trivial (it's just a FIFO waitqueue),
# but in the future we ever start support task priorities or fair scheduling
#
# https://github.com/python-trio/trio/issues/32
#
# then all we'll have to do is update this. (Well, full-fledged task
# priorities might also require priority inheritance, which would require more
# work.)
#
# For discussion of data structures to use here, see:
#
# https://github.com/dabeaz/curio/issues/136
#
# (and also the articles above). Currently we use a SortedDict ordered by a
# global monotonic counter that ensures FIFO ordering. The main advantage of
# this is that it's easy to implement :-). An intrusive doubly-linked list
# would also be a natural approach, so long as we only handle FIFO ordering.
#
# XX: should we switch to the shared global ParkingLot approach?
#
# XX: we should probably add support for "parking tokens" to allow for
# task-fair RWlock (basically: when parking a task needs to be able to mark
# itself as a reader or a writer, and then a task-fair wakeup policy is, wake
# the next task, and if it's a reader than keep waking tasks so long as they
# are readers). Without this I think you can implement write-biased or
# read-biased RWlocks (by using two parking lots and drawing from whichever is
# preferred), but not task-fair -- and task-fair plays much more nicely with
# WFQ. (Consider what happens in the two-lot implementation if you're
# write-biased but all the pending writers are blocked at the scheduler level
# by the WFQ logic...)
# ...alternatively, "phase-fair" RWlocks are pretty interesting:
# http://www.cs.unc.edu/~anderson/papers/ecrts09b.pdf
# Useful summary:
# https://docs.oracle.com/javase/7/docs/api/java/util/concurrent/locks/ReadWriteLock.html
#
# XX: if we do add WFQ, then we might have to drop the current feature where
# unpark returns the tasks that were unparked. Rationale: suppose that at the
# time we call unpark, the next task is deprioritized... and then, before it
# becomes runnable, a new task parks which *is* runnable. Ideally we should
# immediately wake the new task, and leave the old task on the queue for
# later. But this means we can't commit to which task we are unparking when
# unpark is called.
#
# See: https://github.com/python-trio/trio/issues/53
import attr
from collections import OrderedDict
from .. import _core
from .._util import Final
@attr.s(frozen=True, slots=True)
class _ParkingLotStatistics:
tasks_waiting = attr.ib()
@attr.s(eq=False, hash=False, slots=True)
class ParkingLot(metaclass=Final):
"""A fair wait queue with cancellation and requeueing.
This class encapsulates the tricky parts of implementing a wait
queue. It's useful for implementing higher-level synchronization
primitives like queues and locks.
In addition to the methods below, you can use ``len(parking_lot)`` to get
the number of parked tasks, and ``if parking_lot: ...`` to check whether
there are any parked tasks.
"""
# {task: None}, we just want a deque where we can quickly delete random
# items
_parked = attr.ib(factory=OrderedDict, init=False)
def __len__(self):
"""Returns the number of parked tasks."""
return len(self._parked)
def __bool__(self):
"""True if there are parked tasks, False otherwise."""
return bool(self._parked)
# XX this currently returns None
# if we ever add the ability to repark while one's resuming place in
# line (for false wakeups), then we could have it return a ticket that
# abstracts the "place in line" concept.
@_core.enable_ki_protection
async def park(self):
"""Park the current task until woken by a call to :meth:`unpark` or
:meth:`unpark_all`.
"""
task = _core.current_task()
self._parked[task] = None
task.custom_sleep_data = self
def abort_fn(_):
del task.custom_sleep_data._parked[task]
return _core.Abort.SUCCEEDED
await _core.wait_task_rescheduled(abort_fn)
def _pop_several(self, count):
for _ in range(min(count, len(self._parked))):
task, _ = self._parked.popitem(last=False)
yield task
@_core.enable_ki_protection
def unpark(self, *, count=1):
"""Unpark one or more tasks.
This wakes up ``count`` tasks that are blocked in :meth:`park`. If
there are fewer than ``count`` tasks parked, then wakes as many tasks
are available and then returns successfully.
Args:
count (int): the number of tasks to unpark.
"""
tasks = list(self._pop_several(count))
for task in tasks:
_core.reschedule(task)
return tasks
def unpark_all(self):
"""Unpark all parked tasks."""
return self.unpark(count=len(self))
@_core.enable_ki_protection
def repark(self, new_lot, *, count=1):
"""Move parked tasks from one :class:`ParkingLot` object to another.
This dequeues ``count`` tasks from one lot, and requeues them on
another, preserving order. For example::
async def parker(lot):
print("sleeping")
await lot.park()
print("woken")
async def main():
lot1 = trio.lowlevel.ParkingLot()
lot2 = trio.lowlevel.ParkingLot()
async with trio.open_nursery() as nursery:
nursery.start_soon(parker, lot1)
await trio.testing.wait_all_tasks_blocked()
assert len(lot1) == 1
assert len(lot2) == 0
lot1.repark(lot2)
assert len(lot1) == 0
assert len(lot2) == 1
# This wakes up the task that was originally parked in lot1
lot2.unpark()
If there are fewer than ``count`` tasks parked, then reparks as many
tasks as are available and then returns successfully.
Args:
new_lot (ParkingLot): the parking lot to move tasks to.
count (int): the number of tasks to move.
"""
if not isinstance(new_lot, ParkingLot):
raise TypeError("new_lot must be a ParkingLot")
for task in self._pop_several(count):
new_lot._parked[task] = None
task.custom_sleep_data = new_lot
def repark_all(self, new_lot):
"""Move all parked tasks from one :class:`ParkingLot` object to
another.
See :meth:`repark` for details.
"""
return self.repark(new_lot, count=len(self))
def statistics(self):
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``tasks_waiting``: The number of tasks blocked on this lot's
:meth:`park` method.
"""
return _ParkingLotStatistics(tasks_waiting=len(self._parked))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,171 @@
from threading import Thread, Lock
import outcome
from itertools import count
# The "thread cache" is a simple unbounded thread pool, i.e., it automatically
# spawns as many threads as needed to handle all the requests its given. Its
# only purpose is to cache worker threads so that they don't have to be
# started from scratch every time we want to delegate some work to a thread.
# It's expected that some higher-level code will track how many threads are in
# use to avoid overwhelming the system (e.g. the limiter= argument to
# trio.to_thread.run_sync).
#
# To maximize sharing, there's only one thread cache per process, even if you
# have multiple calls to trio.run.
#
# Guarantees:
#
# It's safe to call start_thread_soon simultaneously from
# multiple threads.
#
# Idle threads are chosen in LIFO order, i.e. we *don't* spread work evenly
# over all threads. Instead we try to let some threads do most of the work
# while others sit idle as much as possible. Compared to FIFO, this has better
# memory cache behavior, and it makes it easier to detect when we have too
# many threads, so idle ones can exit.
#
# This code assumes that 'dict' has the following properties:
#
# - __setitem__, __delitem__, and popitem are all thread-safe and atomic with
# respect to each other. This is guaranteed by the GIL.
#
# - popitem returns the most-recently-added item (i.e., __setitem__ + popitem
# give you a LIFO queue). This relies on dicts being insertion-ordered, like
# they are in py36+.
# How long a thread will idle waiting for new work before gives up and exits.
# This value is pretty arbitrary; I don't think it matters too much.
IDLE_TIMEOUT = 10 # seconds
name_counter = count()
class WorkerThread:
def __init__(self, thread_cache):
self._job = None
self._thread_cache = thread_cache
# This Lock is used in an unconventional way.
#
# "Unlocked" means we have a pending job that's been assigned to us;
# "locked" means that we don't.
#
# Initially we have no job, so it starts out in locked state.
self._worker_lock = Lock()
self._worker_lock.acquire()
thread = Thread(target=self._work, daemon=True)
thread.name = f"Trio worker thread {next(name_counter)}"
thread.start()
def _handle_job(self):
# Handle job in a separate method to ensure user-created
# objects are cleaned up in a consistent manner.
fn, deliver = self._job
self._job = None
result = outcome.capture(fn)
# Tell the cache that we're available to be assigned a new
# job. We do this *before* calling 'deliver', so that if
# 'deliver' triggers a new job, it can be assigned to us
# instead of spawning a new thread.
self._thread_cache._idle_workers[self] = None
deliver(result)
def _work(self):
while True:
if self._worker_lock.acquire(timeout=IDLE_TIMEOUT):
# We got a job
self._handle_job()
else:
# Timeout acquiring lock, so we can probably exit. But,
# there's a race condition: we might be assigned a job *just*
# as we're about to exit. So we have to check.
try:
del self._thread_cache._idle_workers[self]
except KeyError:
# Someone else removed us from the idle worker queue, so
# they must be in the process of assigning us a job - loop
# around and wait for it.
continue
else:
# We successfully removed ourselves from the idle
# worker queue, so no more jobs are incoming; it's safe to
# exit.
return
class ThreadCache:
def __init__(self):
self._idle_workers = {}
def start_thread_soon(self, fn, deliver):
try:
worker, _ = self._idle_workers.popitem()
except KeyError:
worker = WorkerThread(self)
worker._job = (fn, deliver)
worker._worker_lock.release()
THREAD_CACHE = ThreadCache()
def start_thread_soon(fn, deliver):
"""Runs ``deliver(outcome.capture(fn))`` in a worker thread.
Generally ``fn`` does some blocking work, and ``deliver`` delivers the
result back to whoever is interested.
This is a low-level, no-frills interface, very similar to using
`threading.Thread` to spawn a thread directly. The main difference is
that this function tries to re-use threads when possible, so it can be
a bit faster than `threading.Thread`.
Worker threads have the `~threading.Thread.daemon` flag set, which means
that if your main thread exits, worker threads will automatically be
killed. If you want to make sure that your ``fn`` runs to completion, then
you should make sure that the main thread remains alive until ``deliver``
is called.
It is safe to call this function simultaneously from multiple threads.
Args:
fn (sync function): Performs arbitrary blocking work.
deliver (sync function): Takes the `outcome.Outcome` of ``fn``, and
delivers it. *Must not block.*
Because worker threads are cached and reused for multiple calls, neither
function should mutate thread-level state, like `threading.local` objects
or if they do, they should be careful to revert their changes before
returning.
Note:
The split between ``fn`` and ``deliver`` serves two purposes. First,
it's convenient, since most callers need something like this anyway.
Second, it avoids a small race condition that could cause too many
threads to be spawned. Consider a program that wants to run several
jobs sequentially on a thread, so the main thread submits a job, waits
for it to finish, submits another job, etc. In theory, this program
should only need one worker thread. But what could happen is:
1. Worker thread: First job finishes, and calls ``deliver``.
2. Main thread: receives notification that the job finished, and calls
``start_thread_soon``.
3. Main thread: sees that no worker threads are marked idle, so spawns
a second worker thread.
4. Original worker thread: marks itself as idle.
To avoid this, threads mark themselves as idle *before* calling
``deliver``.
Is this potential extra thread a major problem? Maybe not, but it's
easy enough to avoid, and we figure that if the user is trying to
limit how many threads they're using then it's polite to respect that.
"""
THREAD_CACHE.start_thread_soon(fn, deliver)

View File

@@ -0,0 +1,270 @@
# These are the only functions that ever yield back to the task runner.
import types
import enum
import attr
import outcome
from . import _run
# Helper for the bottommost 'yield'. You can't use 'yield' inside an async
# function, but you can inside a generator, and if you decorate your generator
# with @types.coroutine, then it's even awaitable. However, it's still not a
# real async function: in particular, it isn't recognized by
# inspect.iscoroutinefunction, and it doesn't trigger the unawaited coroutine
# tracking machinery. Since our traps are public APIs, we make them real async
# functions, and then this helper takes care of the actual yield:
@types.coroutine
def _async_yield(obj):
return (yield obj)
# This class object is used as a singleton.
# Not exported in the trio._core namespace, but imported directly by _run.
class CancelShieldedCheckpoint:
pass
async def cancel_shielded_checkpoint():
"""Introduce a schedule point, but not a cancel point.
This is *not* a :ref:`checkpoint <checkpoints>`, but it is half of a
checkpoint, and when combined with :func:`checkpoint_if_cancelled` it can
make a full checkpoint.
Equivalent to (but potentially more efficient than)::
with trio.CancelScope(shield=True):
await trio.lowlevel.checkpoint()
"""
return (await _async_yield(CancelShieldedCheckpoint)).unwrap()
# Return values for abort functions
class Abort(enum.Enum):
""":class:`enum.Enum` used as the return value from abort functions.
See :func:`wait_task_rescheduled` for details.
.. data:: SUCCEEDED
FAILED
"""
SUCCEEDED = 1
FAILED = 2
# Not exported in the trio._core namespace, but imported directly by _run.
@attr.s(frozen=True)
class WaitTaskRescheduled:
abort_func = attr.ib()
async def wait_task_rescheduled(abort_func):
"""Put the current task to sleep, with cancellation support.
This is the lowest-level API for blocking in Trio. Every time a
:class:`~trio.lowlevel.Task` blocks, it does so by calling this function
(usually indirectly via some higher-level API).
This is a tricky interface with no guard rails. If you can use
:class:`ParkingLot` or the built-in I/O wait functions instead, then you
should.
Generally the way it works is that before calling this function, you make
arrangements for "someone" to call :func:`reschedule` on the current task
at some later point.
Then you call :func:`wait_task_rescheduled`, passing in ``abort_func``, an
"abort callback".
(Terminology: in Trio, "aborting" is the process of attempting to
interrupt a blocked task to deliver a cancellation.)
There are two possibilities for what happens next:
1. "Someone" calls :func:`reschedule` on the current task, and
:func:`wait_task_rescheduled` returns or raises whatever value or error
was passed to :func:`reschedule`.
2. The call's context transitions to a cancelled state (e.g. due to a
timeout expiring). When this happens, the ``abort_func`` is called. Its
interface looks like::
def abort_func(raise_cancel):
...
return trio.lowlevel.Abort.SUCCEEDED # or FAILED
It should attempt to clean up any state associated with this call, and
in particular, arrange that :func:`reschedule` will *not* be called
later. If (and only if!) it is successful, then it should return
:data:`Abort.SUCCEEDED`, in which case the task will automatically be
rescheduled with an appropriate :exc:`~trio.Cancelled` error.
Otherwise, it should return :data:`Abort.FAILED`. This means that the
task can't be cancelled at this time, and still has to make sure that
"someone" eventually calls :func:`reschedule`.
At that point there are again two possibilities. You can simply ignore
the cancellation altogether: wait for the operation to complete and
then reschedule and continue as normal. (For example, this is what
:func:`trio.to_thread.run_sync` does if cancellation is disabled.)
The other possibility is that the ``abort_func`` does succeed in
cancelling the operation, but for some reason isn't able to report that
right away. (Example: on Windows, it's possible to request that an
async ("overlapped") I/O operation be cancelled, but this request is
*also* asynchronous you don't find out until later whether the
operation was actually cancelled or not.) To report a delayed
cancellation, then you should reschedule the task yourself, and call
the ``raise_cancel`` callback passed to ``abort_func`` to raise a
:exc:`~trio.Cancelled` (or possibly :exc:`KeyboardInterrupt`) exception
into this task. Either of the approaches sketched below can work::
# Option 1:
# Catch the exception from raise_cancel and inject it into the task.
# (This is what Trio does automatically for you if you return
# Abort.SUCCEEDED.)
trio.lowlevel.reschedule(task, outcome.capture(raise_cancel))
# Option 2:
# wait to be woken by "someone", and then decide whether to raise
# the error from inside the task.
outer_raise_cancel = None
def abort(inner_raise_cancel):
nonlocal outer_raise_cancel
outer_raise_cancel = inner_raise_cancel
TRY_TO_CANCEL_OPERATION()
return trio.lowlevel.Abort.FAILED
await wait_task_rescheduled(abort)
if OPERATION_WAS_SUCCESSFULLY_CANCELLED:
# raises the error
outer_raise_cancel()
In any case it's guaranteed that we only call the ``abort_func`` at most
once per call to :func:`wait_task_rescheduled`.
Sometimes, it's useful to be able to share some mutable sleep-related data
between the sleeping task, the abort function, and the waking task. You
can use the sleeping task's :data:`~Task.custom_sleep_data` attribute to
store this data, and Trio won't touch it, except to make sure that it gets
cleared when the task is rescheduled.
.. warning::
If your ``abort_func`` raises an error, or returns any value other than
:data:`Abort.SUCCEEDED` or :data:`Abort.FAILED`, then Trio will crash
violently. Be careful! Similarly, it is entirely possible to deadlock a
Trio program by failing to reschedule a blocked task, or cause havoc by
calling :func:`reschedule` too many times. Remember what we said up
above about how you should use a higher-level API if at all possible?
"""
return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap()
# Not exported in the trio._core namespace, but imported directly by _run.
@attr.s(frozen=True)
class PermanentlyDetachCoroutineObject:
final_outcome = attr.ib()
async def permanently_detach_coroutine_object(final_outcome):
"""Permanently detach the current task from the Trio scheduler.
Normally, a Trio task doesn't exit until its coroutine object exits. When
you call this function, Trio acts like the coroutine object just exited
and the task terminates with the given outcome. This is useful if you want
to permanently switch the coroutine object over to a different coroutine
runner.
When the calling coroutine enters this function it's running under Trio,
and when the function returns it's running under the foreign coroutine
runner.
You should make sure that the coroutine object has released any
Trio-specific resources it has acquired (e.g. nurseries).
Args:
final_outcome (outcome.Outcome): Trio acts as if the current task exited
with the given return value or exception.
Returns or raises whatever value or exception the new coroutine runner
uses to resume the coroutine.
"""
if _run.current_task().child_nurseries:
raise RuntimeError(
"can't permanently detach a coroutine object with open nurseries"
)
return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome))
async def temporarily_detach_coroutine_object(abort_func):
"""Temporarily detach the current coroutine object from the Trio
scheduler.
When the calling coroutine enters this function it's running under Trio,
and when the function returns it's running under the foreign coroutine
runner.
The Trio :class:`Task` will continue to exist, but will be suspended until
you use :func:`reattach_detached_coroutine_object` to resume it. In the
mean time, you can use another coroutine runner to schedule the coroutine
object. In fact, you have to the function doesn't return until the
coroutine is advanced from outside.
Note that you'll need to save the current :class:`Task` object to later
resume; you can retrieve it with :func:`current_task`. You can also use
this :class:`Task` object to retrieve the coroutine object see
:data:`Task.coro`.
Args:
abort_func: Same as for :func:`wait_task_rescheduled`, except that it
must return :data:`Abort.FAILED`. (If it returned
:data:`Abort.SUCCEEDED`, then Trio would attempt to reschedule the
detached task directly without going through
:func:`reattach_detached_coroutine_object`, which would be bad.)
Your ``abort_func`` should still arrange for whatever the coroutine
object is doing to be cancelled, and then reattach to Trio and call
the ``raise_cancel`` callback, if possible.
Returns or raises whatever value or exception the new coroutine runner
uses to resume the coroutine.
"""
return await _async_yield(WaitTaskRescheduled(abort_func))
async def reattach_detached_coroutine_object(task, yield_value):
"""Reattach a coroutine object that was detached using
:func:`temporarily_detach_coroutine_object`.
When the calling coroutine enters this function it's running under the
foreign coroutine runner, and when the function returns it's running under
Trio.
This must be called from inside the coroutine being resumed, and yields
whatever value you pass in. (Presumably you'll pass a value that will
cause the current coroutine runner to stop scheduling this task.) Then the
coroutine is resumed by the Trio scheduler at the next opportunity.
Args:
task (Task): The Trio task object that the current coroutine was
detached from.
yield_value (object): The object to yield to the current coroutine
runner.
"""
# This is a kind of crude check in particular, it can fail if the
# passed-in task is where the coroutine *runner* is running. But this is
# an experts-only interface, and there's no easy way to do a more accurate
# check, so I guess that's OK.
if not task.coro.cr_running:
raise RuntimeError("given task does not match calling coroutine")
_run.reschedule(task, outcome.Value("reattaching"))
value = await _async_yield(yield_value)
assert value == outcome.Value("reattaching")

View File

@@ -0,0 +1,149 @@
import attr
from .. import _core
from .._deprecate import deprecated
from .._util import Final
@attr.s(frozen=True)
class _UnboundedQueueStats:
qsize = attr.ib()
tasks_waiting = attr.ib()
class UnboundedQueue(metaclass=Final):
"""An unbounded queue suitable for certain unusual forms of inter-task
communication.
This class is designed for use as a queue in cases where the producer for
some reason cannot be subjected to back-pressure, i.e., :meth:`put_nowait`
has to always succeed. In order to prevent the queue backlog from actually
growing without bound, the consumer API is modified to dequeue items in
"batches". If a consumer task processes each batch without yielding, then
this helps achieve (but does not guarantee) an effective bound on the
queue's memory use, at the cost of potentially increasing system latencies
in general. You should generally prefer to use a memory channel
instead if you can.
Currently each batch completely empties the queue, but `this may change in
the future <https://github.com/python-trio/trio/issues/51>`__.
A :class:`UnboundedQueue` object can be used as an asynchronous iterator,
where each iteration returns a new batch of items. I.e., these two loops
are equivalent::
async for batch in queue:
...
while True:
obj = await queue.get_batch()
...
"""
@deprecated(
"0.9.0",
issue=497,
thing="trio.lowlevel.UnboundedQueue",
instead="trio.open_memory_channel(math.inf)",
)
def __init__(self):
self._lot = _core.ParkingLot()
self._data = []
# used to allow handoff from put to the first task in the lot
self._can_get = False
def __repr__(self):
return "<UnboundedQueue holding {} items>".format(len(self._data))
def qsize(self):
"""Returns the number of items currently in the queue."""
return len(self._data)
def empty(self):
"""Returns True if the queue is empty, False otherwise.
There is some subtlety to interpreting this method's return value: see
`issue #63 <https://github.com/python-trio/trio/issues/63>`__.
"""
return not self._data
@_core.enable_ki_protection
def put_nowait(self, obj):
"""Put an object into the queue, without blocking.
This always succeeds, because the queue is unbounded. We don't provide
a blocking ``put`` method, because it would never need to block.
Args:
obj (object): The object to enqueue.
"""
if not self._data:
assert not self._can_get
if self._lot:
self._lot.unpark(count=1)
else:
self._can_get = True
self._data.append(obj)
def _get_batch_protected(self):
data = self._data.copy()
self._data.clear()
self._can_get = False
return data
def get_batch_nowait(self):
"""Attempt to get the next batch from the queue, without blocking.
Returns:
list: A list of dequeued items, in order. On a successful call this
list is always non-empty; if it would be empty we raise
:exc:`~trio.WouldBlock` instead.
Raises:
~trio.WouldBlock: if the queue is empty.
"""
if not self._can_get:
raise _core.WouldBlock
return self._get_batch_protected()
async def get_batch(self):
"""Get the next batch from the queue, blocking as necessary.
Returns:
list: A list of dequeued items, in order. This list is always
non-empty.
"""
await _core.checkpoint_if_cancelled()
if not self._can_get:
await self._lot.park()
return self._get_batch_protected()
else:
try:
return self._get_batch_protected()
finally:
await _core.cancel_shielded_checkpoint()
def statistics(self):
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``qsize``: The number of items currently in the queue.
* ``tasks_waiting``: The number of tasks blocked on this queue's
:meth:`get_batch` method.
"""
return _UnboundedQueueStats(
qsize=len(self._data), tasks_waiting=self._lot.statistics().tasks_waiting
)
def __aiter__(self):
return self
async def __anext__(self):
return await self.get_batch()

View File

@@ -0,0 +1,71 @@
import socket
import signal
import warnings
from .. import _core
from .._util import is_main_thread
class WakeupSocketpair:
def __init__(self):
self.wakeup_sock, self.write_sock = socket.socketpair()
self.wakeup_sock.setblocking(False)
self.write_sock.setblocking(False)
# This somewhat reduces the amount of memory wasted queueing up data
# for wakeups. With these settings, maximum number of 1-byte sends
# before getting BlockingIOError:
# Linux 4.8: 6
# macOS (darwin 15.5): 1
# Windows 10: 525347
# Windows you're weird. (And on Windows setting SNDBUF to 0 makes send
# blocking, even on non-blocking sockets, so don't do that.)
self.wakeup_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1)
self.write_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1)
# On Windows this is a TCP socket so this might matter. On other
# platforms this fails b/c AF_UNIX sockets aren't actually TCP.
try:
self.write_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
except OSError:
pass
self.old_wakeup_fd = None
def wakeup_thread_and_signal_safe(self):
try:
self.write_sock.send(b"\x00")
except BlockingIOError:
pass
async def wait_woken(self):
await _core.wait_readable(self.wakeup_sock)
self.drain()
def drain(self):
try:
while True:
self.wakeup_sock.recv(2**16)
except BlockingIOError:
pass
def wakeup_on_signals(self):
assert self.old_wakeup_fd is None
if not is_main_thread():
return
fd = self.write_sock.fileno()
self.old_wakeup_fd = signal.set_wakeup_fd(fd, warn_on_full_buffer=False)
if self.old_wakeup_fd != -1:
warnings.warn(
RuntimeWarning(
"It looks like Trio's signal handling code might have "
"collided with another library you're using. If you're "
"running Trio in guest mode, then this might mean you "
"should set host_uses_signal_set_wakeup_fd=True. "
"Otherwise, file a bug on Trio and we'll help you figure "
"out what's going on."
)
)
def close(self):
self.wakeup_sock.close()
self.write_sock.close()
if self.old_wakeup_fd is not None:
signal.set_wakeup_fd(self.old_wakeup_fd)

View File

@@ -0,0 +1,323 @@
import cffi
import re
import enum
################################################################
# Functions and types
################################################################
LIB = """
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa383751(v=vs.85).aspx
typedef int BOOL;
typedef unsigned char BYTE;
typedef BYTE BOOLEAN;
typedef void* PVOID;
typedef PVOID HANDLE;
typedef unsigned long DWORD;
typedef unsigned long ULONG;
typedef unsigned int NTSTATUS;
typedef unsigned long u_long;
typedef ULONG *PULONG;
typedef const void *LPCVOID;
typedef void *LPVOID;
typedef const wchar_t *LPCWSTR;
typedef uintptr_t ULONG_PTR;
typedef uintptr_t UINT_PTR;
typedef UINT_PTR SOCKET;
typedef struct _OVERLAPPED {
ULONG_PTR Internal;
ULONG_PTR InternalHigh;
union {
struct {
DWORD Offset;
DWORD OffsetHigh;
} DUMMYSTRUCTNAME;
PVOID Pointer;
} DUMMYUNIONNAME;
HANDLE hEvent;
} OVERLAPPED, *LPOVERLAPPED;
typedef OVERLAPPED WSAOVERLAPPED;
typedef LPOVERLAPPED LPWSAOVERLAPPED;
typedef PVOID LPSECURITY_ATTRIBUTES;
typedef PVOID LPCSTR;
typedef struct _OVERLAPPED_ENTRY {
ULONG_PTR lpCompletionKey;
LPOVERLAPPED lpOverlapped;
ULONG_PTR Internal;
DWORD dwNumberOfBytesTransferred;
} OVERLAPPED_ENTRY, *LPOVERLAPPED_ENTRY;
// kernel32.dll
HANDLE WINAPI CreateIoCompletionPort(
_In_ HANDLE FileHandle,
_In_opt_ HANDLE ExistingCompletionPort,
_In_ ULONG_PTR CompletionKey,
_In_ DWORD NumberOfConcurrentThreads
);
BOOL SetFileCompletionNotificationModes(
HANDLE FileHandle,
UCHAR Flags
);
HANDLE CreateFileW(
LPCWSTR lpFileName,
DWORD dwDesiredAccess,
DWORD dwShareMode,
LPSECURITY_ATTRIBUTES lpSecurityAttributes,
DWORD dwCreationDisposition,
DWORD dwFlagsAndAttributes,
HANDLE hTemplateFile
);
BOOL WINAPI CloseHandle(
_In_ HANDLE hObject
);
BOOL WINAPI PostQueuedCompletionStatus(
_In_ HANDLE CompletionPort,
_In_ DWORD dwNumberOfBytesTransferred,
_In_ ULONG_PTR dwCompletionKey,
_In_opt_ LPOVERLAPPED lpOverlapped
);
BOOL WINAPI GetQueuedCompletionStatusEx(
_In_ HANDLE CompletionPort,
_Out_ LPOVERLAPPED_ENTRY lpCompletionPortEntries,
_In_ ULONG ulCount,
_Out_ PULONG ulNumEntriesRemoved,
_In_ DWORD dwMilliseconds,
_In_ BOOL fAlertable
);
BOOL WINAPI CancelIoEx(
_In_ HANDLE hFile,
_In_opt_ LPOVERLAPPED lpOverlapped
);
BOOL WriteFile(
HANDLE hFile,
LPCVOID lpBuffer,
DWORD nNumberOfBytesToWrite,
LPDWORD lpNumberOfBytesWritten,
LPOVERLAPPED lpOverlapped
);
BOOL ReadFile(
HANDLE hFile,
LPVOID lpBuffer,
DWORD nNumberOfBytesToRead,
LPDWORD lpNumberOfBytesRead,
LPOVERLAPPED lpOverlapped
);
BOOL WINAPI SetConsoleCtrlHandler(
_In_opt_ void* HandlerRoutine,
_In_ BOOL Add
);
HANDLE CreateEventA(
LPSECURITY_ATTRIBUTES lpEventAttributes,
BOOL bManualReset,
BOOL bInitialState,
LPCSTR lpName
);
BOOL SetEvent(
HANDLE hEvent
);
BOOL ResetEvent(
HANDLE hEvent
);
DWORD WaitForSingleObject(
HANDLE hHandle,
DWORD dwMilliseconds
);
DWORD WaitForMultipleObjects(
DWORD nCount,
HANDLE *lpHandles,
BOOL bWaitAll,
DWORD dwMilliseconds
);
ULONG RtlNtStatusToDosError(
NTSTATUS Status
);
int WSAIoctl(
SOCKET s,
DWORD dwIoControlCode,
LPVOID lpvInBuffer,
DWORD cbInBuffer,
LPVOID lpvOutBuffer,
DWORD cbOutBuffer,
LPDWORD lpcbBytesReturned,
LPWSAOVERLAPPED lpOverlapped,
// actually LPWSAOVERLAPPED_COMPLETION_ROUTINE
void* lpCompletionRoutine
);
int WSAGetLastError();
BOOL DeviceIoControl(
HANDLE hDevice,
DWORD dwIoControlCode,
LPVOID lpInBuffer,
DWORD nInBufferSize,
LPVOID lpOutBuffer,
DWORD nOutBufferSize,
LPDWORD lpBytesReturned,
LPOVERLAPPED lpOverlapped
);
// From https://github.com/piscisaureus/wepoll/blob/master/src/afd.h
typedef struct _AFD_POLL_HANDLE_INFO {
HANDLE Handle;
ULONG Events;
NTSTATUS Status;
} AFD_POLL_HANDLE_INFO, *PAFD_POLL_HANDLE_INFO;
// This is really defined as a messy union to allow stuff like
// i.DUMMYSTRUCTNAME.LowPart, but we don't need those complications.
// Under all that it's just an int64.
typedef int64_t LARGE_INTEGER;
typedef struct _AFD_POLL_INFO {
LARGE_INTEGER Timeout;
ULONG NumberOfHandles;
ULONG Exclusive;
AFD_POLL_HANDLE_INFO Handles[1];
} AFD_POLL_INFO, *PAFD_POLL_INFO;
"""
# cribbed from pywincffi
# programmatically strips out those annotations MSDN likes, like _In_
REGEX_SAL_ANNOTATION = re.compile(
r"\b(_In_|_Inout_|_Out_|_Outptr_|_Reserved_)(opt_)?\b"
)
LIB = REGEX_SAL_ANNOTATION.sub(" ", LIB)
# Other fixups:
# - get rid of FAR, cffi doesn't like it
LIB = re.sub(r"\bFAR\b", " ", LIB)
# - PASCAL is apparently an alias for __stdcall (on modern compilers - modern
# being _MSC_VER >= 800)
LIB = re.sub(r"\bPASCAL\b", "__stdcall", LIB)
ffi = cffi.FFI()
ffi.cdef(LIB)
kernel32 = ffi.dlopen("kernel32.dll")
ntdll = ffi.dlopen("ntdll.dll")
ws2_32 = ffi.dlopen("ws2_32.dll")
################################################################
# Magic numbers
################################################################
# Here's a great resource for looking these up:
# https://www.magnumdb.com
# (Tip: check the box to see "Hex value")
INVALID_HANDLE_VALUE = ffi.cast("HANDLE", -1)
class ErrorCodes(enum.IntEnum):
STATUS_TIMEOUT = 0x102
WAIT_TIMEOUT = 0x102
WAIT_ABANDONED = 0x80
WAIT_OBJECT_0 = 0x00 # object is signaled
WAIT_FAILED = 0xFFFFFFFF
ERROR_IO_PENDING = 997
ERROR_OPERATION_ABORTED = 995
ERROR_ABANDONED_WAIT_0 = 735
ERROR_INVALID_HANDLE = 6
ERROR_INVALID_PARMETER = 87
ERROR_NOT_FOUND = 1168
ERROR_NOT_SOCKET = 10038
class FileFlags(enum.IntEnum):
GENERIC_READ = 0x80000000
SYNCHRONIZE = 0x00100000
FILE_FLAG_OVERLAPPED = 0x40000000
FILE_SHARE_READ = 1
FILE_SHARE_WRITE = 2
FILE_SHARE_DELETE = 4
CREATE_NEW = 1
CREATE_ALWAYS = 2
OPEN_EXISTING = 3
OPEN_ALWAYS = 4
TRUNCATE_EXISTING = 5
class AFDPollFlags(enum.IntFlag):
# These are drawn from a combination of:
# https://github.com/piscisaureus/wepoll/blob/master/src/afd.h
# https://github.com/reactos/reactos/blob/master/sdk/include/reactos/drivers/afd/shared.h
AFD_POLL_RECEIVE = 0x0001
AFD_POLL_RECEIVE_EXPEDITED = 0x0002 # OOB/urgent data
AFD_POLL_SEND = 0x0004
AFD_POLL_DISCONNECT = 0x0008 # received EOF (FIN)
AFD_POLL_ABORT = 0x0010 # received RST
AFD_POLL_LOCAL_CLOSE = 0x0020 # local socket object closed
AFD_POLL_CONNECT = 0x0040 # socket is successfully connected
AFD_POLL_ACCEPT = 0x0080 # you can call accept on this socket
AFD_POLL_CONNECT_FAIL = 0x0100 # connect() terminated unsuccessfully
# See WSAEventSelect docs for more details on these four:
AFD_POLL_QOS = 0x0200
AFD_POLL_GROUP_QOS = 0x0400
AFD_POLL_ROUTING_INTERFACE_CHANGE = 0x0800
AFD_POLL_EVENT_ADDRESS_LIST_CHANGE = 0x1000
class WSAIoctls(enum.IntEnum):
SIO_BASE_HANDLE = 0x48000022
SIO_BSP_HANDLE_SELECT = 0x4800001C
SIO_BSP_HANDLE_POLL = 0x4800001D
class CompletionModes(enum.IntFlag):
FILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 0x1
FILE_SKIP_SET_EVENT_ON_HANDLE = 0x2
class IoControlCodes(enum.IntEnum):
IOCTL_AFD_POLL = 0x00012024
################################################################
# Generic helpers
################################################################
def _handle(obj):
# For now, represent handles as either cffi HANDLEs or as ints. If you
# try to pass in a file descriptor instead, it's not going to work
# out. (For that msvcrt.get_osfhandle does the trick, but I don't know if
# we'll actually need that for anything...) For sockets this doesn't
# matter, Python never allocates an fd. So let's wait until we actually
# encounter the problem before worrying about it.
if type(obj) is int:
return ffi.cast("HANDLE", obj)
else:
return obj
def raise_winerror(winerror=None, *, filename=None, filename2=None):
if winerror is None:
winerror, msg = ffi.getwinerror()
else:
_, msg = ffi.getwinerror(winerror)
# https://docs.python.org/3/library/exceptions.html#OSError
raise OSError(0, msg, filename, winerror, filename2)

View File

@@ -0,0 +1,25 @@
import pytest
import inspect
# XX this should move into a global something
from ...testing import MockClock, trio_test
@pytest.fixture
def mock_clock():
return MockClock()
@pytest.fixture
def autojump_clock():
return MockClock(autojump_threshold=0)
# FIXME: split off into a package (or just make part of Trio's public
# interface?), with config file to enable? and I guess a mark option too; I
# guess it's useful with the class- and file-level marking machinery (where
# the raw @trio_test decorator isn't enough).
@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem):
if inspect.iscoroutinefunction(pyfuncitem.obj):
pyfuncitem.obj = trio_test(pyfuncitem.obj)

View File

@@ -0,0 +1,320 @@
import sys
import weakref
import pytest
from math import inf
from functools import partial
from async_generator import aclosing
from ... import _core
from .tutil import gc_collect_harder, buggy_pypy_asyncgens, restore_unraisablehook
def test_asyncgen_basics():
collected = []
async def example(cause):
try:
try:
yield 42
except GeneratorExit:
pass
await _core.checkpoint()
except _core.Cancelled:
assert "exhausted" not in cause
task_name = _core.current_task().name
assert cause in task_name or task_name == "<init>"
assert _core.current_effective_deadline() == -inf
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
collected.append(cause)
else:
assert "async_main" in _core.current_task().name
assert "exhausted" in cause
assert _core.current_effective_deadline() == inf
await _core.checkpoint()
collected.append(cause)
saved = []
async def async_main():
# GC'ed before exhausted
with pytest.warns(
ResourceWarning, match="Async generator.*collected before.*exhausted"
):
assert 42 == await example("abandoned").asend(None)
gc_collect_harder()
await _core.wait_all_tasks_blocked()
assert collected.pop() == "abandoned"
# aclosing() ensures it's cleaned up at point of use
async with aclosing(example("exhausted 1")) as aiter:
assert 42 == await aiter.asend(None)
assert collected.pop() == "exhausted 1"
# Also fine if you exhaust it at point of use
async for val in example("exhausted 2"):
assert val == 42
assert collected.pop() == "exhausted 2"
gc_collect_harder()
# No problems saving the geniter when using either of these patterns
async with aclosing(example("exhausted 3")) as aiter:
saved.append(aiter)
assert 42 == await aiter.asend(None)
assert collected.pop() == "exhausted 3"
# Also fine if you exhaust it at point of use
saved.append(example("exhausted 4"))
async for val in saved[-1]:
assert val == 42
assert collected.pop() == "exhausted 4"
# Leave one referenced-but-unexhausted and make sure it gets cleaned up
if buggy_pypy_asyncgens:
collected.append("outlived run")
else:
saved.append(example("outlived run"))
assert 42 == await saved[-1].asend(None)
assert collected == []
_core.run(async_main)
assert collected.pop() == "outlived run"
for agen in saved:
assert agen.ag_frame is None # all should now be exhausted
async def test_asyncgen_throws_during_finalization(caplog):
record = []
async def agen():
try:
yield 1
finally:
await _core.cancel_shielded_checkpoint()
record.append("crashing")
raise ValueError("oops")
with restore_unraisablehook():
await agen().asend(None)
gc_collect_harder()
await _core.wait_all_tasks_blocked()
assert record == ["crashing"]
exc_type, exc_value, exc_traceback = caplog.records[0].exc_info
assert exc_type is ValueError
assert str(exc_value) == "oops"
assert "during finalization of async generator" in caplog.records[0].message
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy")
def test_firstiter_after_closing():
saved = []
record = []
async def funky_agen():
try:
yield 1
except GeneratorExit:
record.append("cleanup 1")
raise
try:
yield 2
finally:
record.append("cleanup 2")
await funky_agen().asend(None)
async def async_main():
aiter = funky_agen()
saved.append(aiter)
assert 1 == await aiter.asend(None)
assert 2 == await aiter.asend(None)
_core.run(async_main)
assert record == ["cleanup 2", "cleanup 1"]
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy")
def test_interdependent_asyncgen_cleanup_order():
saved = []
record = []
async def innermost():
try:
yield 1
finally:
await _core.cancel_shielded_checkpoint()
record.append("innermost")
async def agen(label, inner):
try:
yield await inner.asend(None)
finally:
# Either `inner` has already been cleaned up, or
# we're about to exhaust it. Either way, we wind
# up with `record` containing the labels in
# innermost-to-outermost order.
with pytest.raises(StopAsyncIteration):
await inner.asend(None)
record.append(label)
async def async_main():
# This makes a chain of 101 interdependent asyncgens:
# agen(99)'s cleanup will iterate agen(98)'s will iterate
# ... agen(0)'s will iterate innermost()'s
ag_chain = innermost()
for idx in range(100):
ag_chain = agen(idx, ag_chain)
saved.append(ag_chain)
assert 1 == await ag_chain.asend(None)
assert record == []
_core.run(async_main)
assert record == ["innermost"] + list(range(100))
@restore_unraisablehook()
def test_last_minute_gc_edge_case():
saved = []
record = []
needs_retry = True
async def agen():
try:
yield 1
finally:
record.append("cleaned up")
def collect_at_opportune_moment(token):
runner = _core._run.GLOBAL_RUN_CONTEXT.runner
if runner.system_nursery._closed and isinstance(
runner.asyncgens.alive, weakref.WeakSet
):
saved.clear()
record.append("final collection")
gc_collect_harder()
record.append("done")
else:
try:
token.run_sync_soon(collect_at_opportune_moment, token)
except _core.RunFinishedError: # pragma: no cover
nonlocal needs_retry
needs_retry = True
async def async_main():
token = _core.current_trio_token()
token.run_sync_soon(collect_at_opportune_moment, token)
saved.append(agen())
await saved[-1].asend(None)
# Actually running into the edge case requires that the run_sync_soon task
# execute in between the system nursery's closure and the strong-ification
# of runner.asyncgens. There's about a 25% chance that it doesn't
# (if the run_sync_soon task runs before init on one tick and after init
# on the next tick); if we try enough times, we can make the chance of
# failure as small as we want.
for attempt in range(50):
needs_retry = False
del record[:]
del saved[:]
_core.run(async_main)
if needs_retry: # pragma: no cover
if not buggy_pypy_asyncgens:
assert record == ["cleaned up"]
else:
assert record == ["final collection", "done", "cleaned up"]
break
else: # pragma: no cover
pytest.fail(
f"Didn't manage to hit the trailing_finalizer_asyncgens case "
f"despite trying {attempt} times"
)
async def step_outside_async_context(aiter):
# abort_fns run outside of task context, at least if they're
# triggered by a deadline expiry rather than a direct
# cancellation. Thus, an asyncgen first iterated inside one
# will appear non-Trio, and since no other hooks were installed,
# will use the last-ditch fallback handling (that tries to mimic
# CPython's behavior with no hooks).
#
# NB: the strangeness with aiter being an attribute of abort_fn is
# to make it as easy as possible to ensure we don't hang onto a
# reference to aiter inside the guts of the run loop.
def abort_fn(_):
with pytest.raises(StopIteration, match="42"):
abort_fn.aiter.asend(None).send(None)
del abort_fn.aiter
return _core.Abort.SUCCEEDED
abort_fn.aiter = aiter
async with _core.open_nursery() as nursery:
nursery.start_soon(_core.wait_task_rescheduled, abort_fn)
await _core.wait_all_tasks_blocked()
nursery.cancel_scope.deadline = _core.current_time()
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy")
async def test_fallback_when_no_hook_claims_it(capsys):
async def well_behaved():
yield 42
async def yields_after_yield():
with pytest.raises(GeneratorExit):
yield 42
yield 100
async def awaits_after_yield():
with pytest.raises(GeneratorExit):
yield 42
await _core.cancel_shielded_checkpoint()
with restore_unraisablehook():
await step_outside_async_context(well_behaved())
gc_collect_harder()
assert capsys.readouterr().err == ""
await step_outside_async_context(yields_after_yield())
gc_collect_harder()
assert "ignored GeneratorExit" in capsys.readouterr().err
await step_outside_async_context(awaits_after_yield())
gc_collect_harder()
assert "awaited something during finalization" in capsys.readouterr().err
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy")
def test_delegation_to_existing_hooks():
record = []
def my_firstiter(agen):
record.append("firstiter " + agen.ag_frame.f_locals["arg"])
def my_finalizer(agen):
record.append("finalizer " + agen.ag_frame.f_locals["arg"])
async def example(arg):
try:
yield 42
finally:
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
record.append("trio collected " + arg)
async def async_main():
await step_outside_async_context(example("theirs"))
assert 42 == await example("ours").asend(None)
gc_collect_harder()
assert record == ["firstiter theirs", "finalizer theirs"]
record[:] = []
await _core.wait_all_tasks_blocked()
assert record == ["trio collected ours"]
with restore_unraisablehook():
old_hooks = sys.get_asyncgen_hooks()
sys.set_asyncgen_hooks(my_firstiter, my_finalizer)
try:
_core.run(async_main)
finally:
assert sys.get_asyncgen_hooks() == (my_firstiter, my_finalizer)
sys.set_asyncgen_hooks(*old_hooks)

View File

@@ -0,0 +1,546 @@
import pytest
import asyncio
import contextvars
import sys
import traceback
import queue
from functools import partial
from math import inf
import signal
import socket
import threading
import time
import warnings
import trio
import trio.testing
from .tutil import gc_collect_harder, buggy_pypy_asyncgens, restore_unraisablehook
from ..._util import signal_raise
# The simplest possible "host" loop.
# Nice features:
# - we can run code "outside" of trio using the schedule function passed to
# our main
# - final result is returned
# - any unhandled exceptions cause an immediate crash
def trivial_guest_run(trio_fn, **start_guest_run_kwargs):
todo = queue.Queue()
host_thread = threading.current_thread()
def run_sync_soon_threadsafe(fn):
if host_thread is threading.current_thread(): # pragma: no cover
crash = partial(
pytest.fail, "run_sync_soon_threadsafe called from host thread"
)
todo.put(("run", crash))
todo.put(("run", fn))
def run_sync_soon_not_threadsafe(fn):
if host_thread is not threading.current_thread(): # pragma: no cover
crash = partial(
pytest.fail, "run_sync_soon_not_threadsafe called from worker thread"
)
todo.put(("run", crash))
todo.put(("run", fn))
def done_callback(outcome):
todo.put(("unwrap", outcome))
trio.lowlevel.start_guest_run(
trio_fn,
run_sync_soon_not_threadsafe,
run_sync_soon_threadsafe=run_sync_soon_threadsafe,
run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe,
done_callback=done_callback,
**start_guest_run_kwargs,
)
try:
while True:
op, obj = todo.get()
if op == "run":
obj()
elif op == "unwrap":
return obj.unwrap()
else: # pragma: no cover
assert False
finally:
# Make sure that exceptions raised here don't capture these, so that
# if an exception does cause us to abandon a run then the Trio state
# has a chance to be GC'ed and warn about it.
del todo, run_sync_soon_threadsafe, done_callback
def test_guest_trivial():
async def trio_return(in_host):
await trio.sleep(0)
return "ok"
assert trivial_guest_run(trio_return) == "ok"
async def trio_fail(in_host):
raise KeyError("whoopsiedaisy")
with pytest.raises(KeyError, match="whoopsiedaisy"):
trivial_guest_run(trio_fail)
def test_guest_can_do_io():
async def trio_main(in_host):
record = []
a, b = trio.socket.socketpair()
with a, b:
async with trio.open_nursery() as nursery:
async def do_receive():
record.append(await a.recv(1))
nursery.start_soon(do_receive)
await trio.testing.wait_all_tasks_blocked()
await b.send(b"x")
assert record == [b"x"]
trivial_guest_run(trio_main)
def test_host_can_directly_wake_trio_task():
async def trio_main(in_host):
ev = trio.Event()
in_host(ev.set)
await ev.wait()
return "ok"
assert trivial_guest_run(trio_main) == "ok"
def test_host_altering_deadlines_wakes_trio_up():
def set_deadline(cscope, new_deadline):
cscope.deadline = new_deadline
async def trio_main(in_host):
with trio.CancelScope() as cscope:
in_host(lambda: set_deadline(cscope, -inf))
await trio.sleep_forever()
assert cscope.cancelled_caught
with trio.CancelScope() as cscope:
# also do a change that doesn't affect the next deadline, just to
# exercise that path
in_host(lambda: set_deadline(cscope, 1e6))
in_host(lambda: set_deadline(cscope, -inf))
await trio.sleep(999)
assert cscope.cancelled_caught
return "ok"
assert trivial_guest_run(trio_main) == "ok"
def test_warn_set_wakeup_fd_overwrite():
assert signal.set_wakeup_fd(-1) == -1
async def trio_main(in_host):
return "ok"
a, b = socket.socketpair()
with a, b:
a.setblocking(False)
# Warn if there's already a wakeup fd
signal.set_wakeup_fd(a.fileno())
try:
with pytest.warns(RuntimeWarning, match="signal handling code.*collided"):
assert trivial_guest_run(trio_main) == "ok"
finally:
assert signal.set_wakeup_fd(-1) == a.fileno()
signal.set_wakeup_fd(a.fileno())
try:
with pytest.warns(RuntimeWarning, match="signal handling code.*collided"):
assert (
trivial_guest_run(trio_main, host_uses_signal_set_wakeup_fd=False)
== "ok"
)
finally:
assert signal.set_wakeup_fd(-1) == a.fileno()
# Don't warn if there isn't already a wakeup fd
with warnings.catch_warnings():
warnings.simplefilter("error")
assert trivial_guest_run(trio_main) == "ok"
with warnings.catch_warnings():
warnings.simplefilter("error")
assert (
trivial_guest_run(trio_main, host_uses_signal_set_wakeup_fd=True)
== "ok"
)
# If there's already a wakeup fd, but we've been told to trust it,
# then it's left alone and there's no warning
signal.set_wakeup_fd(a.fileno())
try:
async def trio_check_wakeup_fd_unaltered(in_host):
fd = signal.set_wakeup_fd(-1)
assert fd == a.fileno()
signal.set_wakeup_fd(fd)
return "ok"
with warnings.catch_warnings():
warnings.simplefilter("error")
assert (
trivial_guest_run(
trio_check_wakeup_fd_unaltered,
host_uses_signal_set_wakeup_fd=True,
)
== "ok"
)
finally:
assert signal.set_wakeup_fd(-1) == a.fileno()
def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked():
# This is designed to hit the branch in unrolled_run where:
# idle_primed=True
# runner.runq is empty
# events is Truth-y
# ...and confirm that in this case, wait_all_tasks_blocked does not get
# triggered.
def set_deadline(cscope, new_deadline):
print(f"setting deadline {new_deadline}")
cscope.deadline = new_deadline
async def trio_main(in_host):
async def sit_in_wait_all_tasks_blocked(watb_cscope):
with watb_cscope:
# Overall point of this test is that this
# wait_all_tasks_blocked should *not* return normally, but
# only by cancellation.
await trio.testing.wait_all_tasks_blocked(cushion=9999)
assert False # pragma: no cover
assert watb_cscope.cancelled_caught
async def get_woken_by_host_deadline(watb_cscope):
with trio.CancelScope() as cscope:
print("scheduling stuff to happen")
# Altering the deadline from the host, to something in the
# future, will cause the run loop to wake up, but then
# discover that there is nothing to do and go back to sleep.
# This should *not* trigger wait_all_tasks_blocked.
#
# So the 'before_io_wait' here will wait until we're blocking
# with the wait_all_tasks_blocked primed, and then schedule a
# deadline change. The critical test is that this should *not*
# wake up 'sit_in_wait_all_tasks_blocked'.
#
# The after we've had a chance to wake up
# 'sit_in_wait_all_tasks_blocked', we want the test to
# actually end. So in after_io_wait we schedule a second host
# call to tear things down.
class InstrumentHelper:
def __init__(self):
self.primed = False
def before_io_wait(self, timeout):
print(f"before_io_wait({timeout})")
if timeout == 9999: # pragma: no branch
assert not self.primed
in_host(lambda: set_deadline(cscope, 1e9))
self.primed = True
def after_io_wait(self, timeout):
if self.primed: # pragma: no branch
print("instrument triggered")
in_host(lambda: cscope.cancel())
trio.lowlevel.remove_instrument(self)
trio.lowlevel.add_instrument(InstrumentHelper())
await trio.sleep_forever()
assert cscope.cancelled_caught
watb_cscope.cancel()
async with trio.open_nursery() as nursery:
watb_cscope = trio.CancelScope()
nursery.start_soon(sit_in_wait_all_tasks_blocked, watb_cscope)
await trio.testing.wait_all_tasks_blocked()
nursery.start_soon(get_woken_by_host_deadline, watb_cscope)
return "ok"
assert trivial_guest_run(trio_main) == "ok"
@restore_unraisablehook()
def test_guest_warns_if_abandoned():
# This warning is emitted from the garbage collector. So we have to make
# sure that our abandoned run is garbage. The easiest way to do this is to
# put it into a function, so that we're sure all the local state,
# traceback frames, etc. are garbage once it returns.
def do_abandoned_guest_run():
async def abandoned_main(in_host):
in_host(lambda: 1 / 0)
while True:
await trio.sleep(0)
with pytest.raises(ZeroDivisionError):
trivial_guest_run(abandoned_main)
with pytest.warns(RuntimeWarning, match="Trio guest run got abandoned"):
do_abandoned_guest_run()
gc_collect_harder()
# If you have problems some day figuring out what's holding onto a
# reference to the unrolled_run generator and making this test fail,
# then this might be useful to help track it down. (It assumes you
# also hack start_guest_run so that it does 'global W; W =
# weakref(unrolled_run_gen)'.)
#
# import gc
# print(trio._core._run.W)
# targets = [trio._core._run.W()]
# for i in range(15):
# new_targets = []
# for target in targets:
# new_targets += gc.get_referrers(target)
# new_targets.remove(targets)
# print("#####################")
# print(f"depth {i}: {len(new_targets)}")
# print(new_targets)
# targets = new_targets
with pytest.raises(RuntimeError):
trio.current_time()
def aiotrio_run(trio_fn, *, pass_not_threadsafe=True, **start_guest_run_kwargs):
loop = asyncio.new_event_loop()
async def aio_main():
trio_done_fut = loop.create_future()
def trio_done_callback(main_outcome):
print(f"trio_fn finished: {main_outcome!r}")
trio_done_fut.set_result(main_outcome)
if pass_not_threadsafe:
start_guest_run_kwargs["run_sync_soon_not_threadsafe"] = loop.call_soon
trio.lowlevel.start_guest_run(
trio_fn,
run_sync_soon_threadsafe=loop.call_soon_threadsafe,
done_callback=trio_done_callback,
**start_guest_run_kwargs,
)
return (await trio_done_fut).unwrap()
try:
return loop.run_until_complete(aio_main())
finally:
loop.close()
def test_guest_mode_on_asyncio():
async def trio_main():
print("trio_main!")
to_trio, from_aio = trio.open_memory_channel(float("inf"))
from_trio = asyncio.Queue()
aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio))
# Make sure we have at least one tick where we don't need to go into
# the thread
await trio.sleep(0)
from_trio.put_nowait(0)
async for n in from_aio:
print(f"trio got: {n}")
from_trio.put_nowait(n + 1)
if n >= 10:
aio_task.cancel()
return "trio-main-done"
async def aio_pingpong(from_trio, to_trio):
print("aio_pingpong!")
try:
while True:
n = await from_trio.get()
print(f"aio got: {n}")
to_trio.send_nowait(n + 1)
except asyncio.CancelledError:
raise
except: # pragma: no cover
traceback.print_exc()
raise
assert (
aiotrio_run(
trio_main,
# Not all versions of asyncio we test on can actually be trusted,
# but this test doesn't care about signal handling, and it's
# easier to just avoid the warnings.
host_uses_signal_set_wakeup_fd=True,
)
== "trio-main-done"
)
assert (
aiotrio_run(
trio_main,
# Also check that passing only call_soon_threadsafe works, via the
# fallback path where we use it for everything.
pass_not_threadsafe=False,
host_uses_signal_set_wakeup_fd=True,
)
== "trio-main-done"
)
def test_guest_mode_internal_errors(monkeypatch, recwarn):
with monkeypatch.context() as m:
async def crash_in_run_loop(in_host):
m.setattr("trio._core._run.GLOBAL_RUN_CONTEXT.runner.runq", "HI")
await trio.sleep(1)
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(crash_in_run_loop)
with monkeypatch.context() as m:
async def crash_in_io(in_host):
m.setattr("trio._core._run.TheIOManager.get_events", None)
await trio.sleep(0)
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(crash_in_io)
with monkeypatch.context() as m:
async def crash_in_worker_thread_io(in_host):
t = threading.current_thread()
old_get_events = trio._core._run.TheIOManager.get_events
def bad_get_events(*args):
if threading.current_thread() is not t:
raise ValueError("oh no!")
else:
return old_get_events(*args)
m.setattr("trio._core._run.TheIOManager.get_events", bad_get_events)
await trio.sleep(1)
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(crash_in_worker_thread_io)
gc_collect_harder()
def test_guest_mode_ki():
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
# Check SIGINT in Trio func and in host func
async def trio_main(in_host):
with pytest.raises(KeyboardInterrupt):
signal_raise(signal.SIGINT)
# Host SIGINT should get injected into Trio
in_host(partial(signal_raise, signal.SIGINT))
await trio.sleep(10)
with pytest.raises(KeyboardInterrupt) as excinfo:
trivial_guest_run(trio_main)
assert excinfo.value.__context__ is None
# Signal handler should be restored properly on exit
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
# Also check chaining in the case where KI is injected after main exits
final_exc = KeyError("whoa")
async def trio_main_raising(in_host):
in_host(partial(signal_raise, signal.SIGINT))
raise final_exc
with pytest.raises(KeyboardInterrupt) as excinfo:
trivial_guest_run(trio_main_raising)
assert excinfo.value.__context__ is final_exc
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
def test_guest_mode_autojump_clock_threshold_changing():
# This is super obscure and probably no-one will ever notice, but
# technically mutating the MockClock.autojump_threshold from the host
# should wake up the guest, so let's test it.
clock = trio.testing.MockClock()
DURATION = 120
async def trio_main(in_host):
assert trio.current_time() == 0
in_host(lambda: setattr(clock, "autojump_threshold", 0))
await trio.sleep(DURATION)
assert trio.current_time() == DURATION
start = time.monotonic()
trivial_guest_run(trio_main, clock=clock)
end = time.monotonic()
# Should be basically instantaneous, but we'll leave a generous buffer to
# account for any CI weirdness
assert end - start < DURATION / 2
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="PyPy 7.2 is buggy")
@pytest.mark.xfail(
sys.implementation.name == "pypy",
reason="async generator issue under investigation",
)
@restore_unraisablehook()
def test_guest_mode_asyncgens():
import sniffio
record = set()
async def agen(label):
assert sniffio.current_async_library() == label
try:
yield 1
finally:
library = sniffio.current_async_library()
try:
await sys.modules[library].sleep(0)
except trio.Cancelled:
pass
record.add((label, library))
async def iterate_in_aio():
# "trio" gets inherited from our Trio caller if we don't set this
sniffio.current_async_library_cvar.set("asyncio")
await agen("asyncio").asend(None)
async def trio_main():
task = asyncio.ensure_future(iterate_in_aio())
done_evt = trio.Event()
task.add_done_callback(lambda _: done_evt.set())
with trio.fail_after(1):
await done_evt.wait()
await agen("trio").asend(None)
gc_collect_harder()
# Ensure we don't pollute the thread-level context if run under
# an asyncio without contextvars support (3.6)
context = contextvars.copy_context()
context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True)
assert record == {("asyncio", "asyncio"), ("trio", "trio")}

View File

@@ -0,0 +1,253 @@
import attr
import pytest
from ... import _core, _abc
from .tutil import check_sequence_matches
@attr.s(eq=False, hash=False)
class TaskRecorder:
record = attr.ib(factory=list)
def before_run(self):
self.record.append(("before_run",))
def task_scheduled(self, task):
self.record.append(("schedule", task))
def before_task_step(self, task):
assert task is _core.current_task()
self.record.append(("before", task))
def after_task_step(self, task):
assert task is _core.current_task()
self.record.append(("after", task))
def after_run(self):
self.record.append(("after_run",))
def filter_tasks(self, tasks):
for item in self.record:
if item[0] in ("schedule", "before", "after") and item[1] in tasks:
yield item
if item[0] in ("before_run", "after_run"):
yield item
def test_instruments(recwarn):
r1 = TaskRecorder()
r2 = TaskRecorder()
r3 = TaskRecorder()
task = None
# We use a child task for this, because the main task does some extra
# bookkeeping stuff that can leak into the instrument results, and we
# don't want to deal with it.
async def task_fn():
nonlocal task
task = _core.current_task()
for _ in range(4):
await _core.checkpoint()
# replace r2 with r3, to test that we can manipulate them as we go
_core.remove_instrument(r2)
with pytest.raises(KeyError):
_core.remove_instrument(r2)
# add is idempotent
_core.add_instrument(r3)
_core.add_instrument(r3)
for _ in range(1):
await _core.checkpoint()
async def main():
async with _core.open_nursery() as nursery:
nursery.start_soon(task_fn)
_core.run(main, instruments=[r1, r2])
# It sleeps 5 times, so it runs 6 times. Note that checkpoint()
# reschedules the task immediately upon yielding, before the
# after_task_step event fires.
expected = (
[("before_run",), ("schedule", task)]
+ [("before", task), ("schedule", task), ("after", task)] * 5
+ [("before", task), ("after", task), ("after_run",)]
)
assert r1.record == r2.record + r3.record
assert list(r1.filter_tasks([task])) == expected
def test_instruments_interleave():
tasks = {}
async def two_step1():
tasks["t1"] = _core.current_task()
await _core.checkpoint()
async def two_step2():
tasks["t2"] = _core.current_task()
await _core.checkpoint()
async def main():
async with _core.open_nursery() as nursery:
nursery.start_soon(two_step1)
nursery.start_soon(two_step2)
r = TaskRecorder()
_core.run(main, instruments=[r])
expected = [
("before_run",),
("schedule", tasks["t1"]),
("schedule", tasks["t2"]),
{
("before", tasks["t1"]),
("schedule", tasks["t1"]),
("after", tasks["t1"]),
("before", tasks["t2"]),
("schedule", tasks["t2"]),
("after", tasks["t2"]),
},
{
("before", tasks["t1"]),
("after", tasks["t1"]),
("before", tasks["t2"]),
("after", tasks["t2"]),
},
("after_run",),
]
print(list(r.filter_tasks(tasks.values())))
check_sequence_matches(list(r.filter_tasks(tasks.values())), expected)
def test_null_instrument():
# undefined instrument methods are skipped
class NullInstrument:
def something_unrelated(self):
pass # pragma: no cover
async def main():
await _core.checkpoint()
_core.run(main, instruments=[NullInstrument()])
def test_instrument_before_after_run():
record = []
class BeforeAfterRun:
def before_run(self):
record.append("before_run")
def after_run(self):
record.append("after_run")
async def main():
pass
_core.run(main, instruments=[BeforeAfterRun()])
assert record == ["before_run", "after_run"]
def test_instrument_task_spawn_exit():
record = []
class SpawnExitRecorder:
def task_spawned(self, task):
record.append(("spawned", task))
def task_exited(self, task):
record.append(("exited", task))
async def main():
return _core.current_task()
main_task = _core.run(main, instruments=[SpawnExitRecorder()])
assert ("spawned", main_task) in record
assert ("exited", main_task) in record
# This test also tests having a crash before the initial task is even spawned,
# which is very difficult to handle.
def test_instruments_crash(caplog):
record = []
class BrokenInstrument:
def task_scheduled(self, task):
record.append("scheduled")
raise ValueError("oops")
def close(self):
# Shouldn't be called -- tests that the instrument disabling logic
# works right.
record.append("closed") # pragma: no cover
async def main():
record.append("main ran")
return _core.current_task()
r = TaskRecorder()
main_task = _core.run(main, instruments=[r, BrokenInstrument()])
assert record == ["scheduled", "main ran"]
# the TaskRecorder kept going throughout, even though the BrokenInstrument
# was disabled
assert ("after", main_task) in r.record
assert ("after_run",) in r.record
# And we got a log message
exc_type, exc_value, exc_traceback = caplog.records[0].exc_info
assert exc_type is ValueError
assert str(exc_value) == "oops"
assert "Instrument has been disabled" in caplog.records[0].message
def test_instruments_monkeypatch():
class NullInstrument(_abc.Instrument):
pass
instrument = NullInstrument()
async def main():
record = []
# Changing the set of hooks implemented by an instrument after
# it's installed doesn't make them start being called right away
instrument.before_task_step = record.append
await _core.checkpoint()
await _core.checkpoint()
assert len(record) == 0
# But if we remove and re-add the instrument, the new hooks are
# picked up
_core.remove_instrument(instrument)
_core.add_instrument(instrument)
await _core.checkpoint()
await _core.checkpoint()
assert record.count(_core.current_task()) == 2
_core.remove_instrument(instrument)
await _core.checkpoint()
await _core.checkpoint()
assert record.count(_core.current_task()) == 2
_core.run(main, instruments=[instrument])
def test_instrument_that_raises_on_getattr():
class EvilInstrument:
def task_exited(self, task):
assert False # pragma: no cover
@property
def after_run(self):
raise ValueError("oops")
async def main():
with pytest.raises(ValueError):
_core.add_instrument(EvilInstrument())
# Make sure the instrument is fully removed from the per-method lists
runner = _core.current_task()._runner
assert "after_run" not in runner.instruments
assert "task_exited" not in runner.instruments
_core.run(main)

View File

@@ -0,0 +1,447 @@
import pytest
import socket as stdlib_socket
import select
import random
import errno
from contextlib import suppress
from ... import _core
from ...testing import wait_all_tasks_blocked, Sequencer, assert_checkpoints
import trio
# Cross-platform tests for IO handling
def fill_socket(sock):
try:
while True:
sock.send(b"x" * 65536)
except BlockingIOError:
pass
def drain_socket(sock):
try:
while True:
sock.recv(65536)
except BlockingIOError:
pass
@pytest.fixture
def socketpair():
pair = stdlib_socket.socketpair()
for sock in pair:
sock.setblocking(False)
yield pair
for sock in pair:
sock.close()
def using_fileno(fn):
def fileno_wrapper(fileobj):
return fn(fileobj.fileno())
name = "<{} on fileno>".format(fn.__name__)
fileno_wrapper.__name__ = fileno_wrapper.__qualname__ = name
return fileno_wrapper
wait_readable_options = [trio.lowlevel.wait_readable]
wait_writable_options = [trio.lowlevel.wait_writable]
notify_closing_options = [trio.lowlevel.notify_closing]
for options_list in [
wait_readable_options,
wait_writable_options,
notify_closing_options,
]:
options_list += [using_fileno(f) for f in options_list]
# Decorators that feed in different settings for wait_readable / wait_writable
# / notify_closing.
# Note that if you use all three decorators on the same test, it will run all
# N**3 *combinations*
read_socket_test = pytest.mark.parametrize(
"wait_readable", wait_readable_options, ids=lambda fn: fn.__name__
)
write_socket_test = pytest.mark.parametrize(
"wait_writable", wait_writable_options, ids=lambda fn: fn.__name__
)
notify_closing_test = pytest.mark.parametrize(
"notify_closing", notify_closing_options, ids=lambda fn: fn.__name__
)
# XX These tests are all a bit dicey because they can't distinguish between
# wait_on_{read,writ}able blocking the way it should, versus blocking
# momentarily and then immediately resuming.
@read_socket_test
@write_socket_test
async def test_wait_basic(socketpair, wait_readable, wait_writable):
a, b = socketpair
# They start out writable()
with assert_checkpoints():
await wait_writable(a)
# But readable() blocks until data arrives
record = []
async def block_on_read():
try:
with assert_checkpoints():
await wait_readable(a)
except _core.Cancelled:
record.append("cancelled")
else:
record.append("readable")
assert a.recv(10) == b"x"
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_read)
await wait_all_tasks_blocked()
assert record == []
b.send(b"x")
fill_socket(a)
# Now writable will block, but readable won't
with assert_checkpoints():
await wait_readable(b)
record = []
async def block_on_write():
try:
with assert_checkpoints():
await wait_writable(a)
except _core.Cancelled:
record.append("cancelled")
else:
record.append("writable")
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_write)
await wait_all_tasks_blocked()
assert record == []
drain_socket(b)
# check cancellation
record = []
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_read)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert record == ["cancelled"]
fill_socket(a)
record = []
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_write)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert record == ["cancelled"]
@read_socket_test
async def test_double_read(socketpair, wait_readable):
a, b = socketpair
# You can't have two tasks trying to read from a socket at the same time
async with _core.open_nursery() as nursery:
nursery.start_soon(wait_readable, a)
await wait_all_tasks_blocked()
with pytest.raises(_core.BusyResourceError):
await wait_readable(a)
nursery.cancel_scope.cancel()
@write_socket_test
async def test_double_write(socketpair, wait_writable):
a, b = socketpair
# You can't have two tasks trying to write to a socket at the same time
fill_socket(a)
async with _core.open_nursery() as nursery:
nursery.start_soon(wait_writable, a)
await wait_all_tasks_blocked()
with pytest.raises(_core.BusyResourceError):
await wait_writable(a)
nursery.cancel_scope.cancel()
@read_socket_test
@write_socket_test
@notify_closing_test
async def test_interrupted_by_close(
socketpair, wait_readable, wait_writable, notify_closing
):
a, b = socketpair
async def reader():
with pytest.raises(_core.ClosedResourceError):
await wait_readable(a)
async def writer():
with pytest.raises(_core.ClosedResourceError):
await wait_writable(a)
fill_socket(a)
async with _core.open_nursery() as nursery:
nursery.start_soon(reader)
nursery.start_soon(writer)
await wait_all_tasks_blocked()
notify_closing(a)
@read_socket_test
@write_socket_test
async def test_socket_simultaneous_read_write(socketpair, wait_readable, wait_writable):
record = []
async def r_task(sock):
await wait_readable(sock)
record.append("r_task")
async def w_task(sock):
await wait_writable(sock)
record.append("w_task")
a, b = socketpair
fill_socket(a)
async with _core.open_nursery() as nursery:
nursery.start_soon(r_task, a)
nursery.start_soon(w_task, a)
await wait_all_tasks_blocked()
assert record == []
b.send(b"x")
await wait_all_tasks_blocked()
assert record == ["r_task"]
drain_socket(b)
await wait_all_tasks_blocked()
assert record == ["r_task", "w_task"]
@read_socket_test
@write_socket_test
async def test_socket_actual_streaming(socketpair, wait_readable, wait_writable):
a, b = socketpair
# Use a small send buffer on one of the sockets to increase the chance of
# getting partial writes
a.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_SNDBUF, 10000)
N = 1000000 # 1 megabyte
MAX_CHUNK = 65536
results = {}
async def sender(sock, seed, key):
r = random.Random(seed)
sent = 0
while sent < N:
print("sent", sent)
chunk = bytearray(r.randrange(MAX_CHUNK))
while chunk:
with assert_checkpoints():
await wait_writable(sock)
this_chunk_size = sock.send(chunk)
sent += this_chunk_size
del chunk[:this_chunk_size]
sock.shutdown(stdlib_socket.SHUT_WR)
results[key] = sent
async def receiver(sock, key):
received = 0
while True:
print("received", received)
with assert_checkpoints():
await wait_readable(sock)
this_chunk_size = len(sock.recv(MAX_CHUNK))
if not this_chunk_size:
break
received += this_chunk_size
results[key] = received
async with _core.open_nursery() as nursery:
nursery.start_soon(sender, a, 0, "send_a")
nursery.start_soon(sender, b, 1, "send_b")
nursery.start_soon(receiver, a, "recv_a")
nursery.start_soon(receiver, b, "recv_b")
assert results["send_a"] == results["recv_b"]
assert results["send_b"] == results["recv_a"]
async def test_notify_closing_on_invalid_object():
# It should either be a no-op (generally on Unix, where we don't know
# which fds are valid), or an OSError (on Windows, where we currently only
# support sockets, so we have to do some validation to figure out whether
# it's a socket or a regular handle).
got_oserror = False
got_no_error = False
try:
trio.lowlevel.notify_closing(-1)
except OSError:
got_oserror = True
else:
got_no_error = True
assert got_oserror or got_no_error
async def test_wait_on_invalid_object():
# We definitely want to raise an error everywhere if you pass in an
# invalid fd to wait_*
for wait in [trio.lowlevel.wait_readable, trio.lowlevel.wait_writable]:
with stdlib_socket.socket() as s:
fileno = s.fileno()
# We just closed the socket and don't do anything else in between, so
# we can be confident that the fileno hasn't be reassigned.
with pytest.raises(OSError):
await wait(fileno)
async def test_io_manager_statistics():
def check(*, expected_readers, expected_writers):
statistics = _core.current_statistics()
print(statistics)
iostats = statistics.io_statistics
if iostats.backend in ["epoll", "windows"]:
assert iostats.tasks_waiting_read == expected_readers
assert iostats.tasks_waiting_write == expected_writers
else:
assert iostats.backend == "kqueue"
assert iostats.tasks_waiting == expected_readers + expected_writers
a1, b1 = stdlib_socket.socketpair()
a2, b2 = stdlib_socket.socketpair()
a3, b3 = stdlib_socket.socketpair()
for sock in [a1, b1, a2, b2, a3, b3]:
sock.setblocking(False)
with a1, b1, a2, b2, a3, b3:
# let the call_soon_task settle down
await wait_all_tasks_blocked()
# 1 for call_soon_task
check(expected_readers=1, expected_writers=0)
# We want:
# - one socket with a writer blocked
# - two sockets with a reader blocked
# - a socket with both blocked
fill_socket(a1)
fill_socket(a3)
async with _core.open_nursery() as nursery:
nursery.start_soon(_core.wait_writable, a1)
nursery.start_soon(_core.wait_readable, a2)
nursery.start_soon(_core.wait_readable, b2)
nursery.start_soon(_core.wait_writable, a3)
nursery.start_soon(_core.wait_readable, a3)
await wait_all_tasks_blocked()
# +1 for call_soon_task
check(expected_readers=3 + 1, expected_writers=2)
nursery.cancel_scope.cancel()
# 1 for call_soon_task
check(expected_readers=1, expected_writers=0)
async def test_can_survive_unnotified_close():
# An "unnotified" close is when the user closes an fd/socket/handle
# directly, without calling notify_closing first. This should never happen
# -- users should call notify_closing before closing things. But, just in
# case they don't, we would still like to avoid exploding.
#
# Acceptable behaviors:
# - wait_* never return, but can be cancelled cleanly
# - wait_* exit cleanly
# - wait_* raise an OSError
#
# Not acceptable:
# - getting stuck in an uncancellable state
# - TrioInternalError blowing up the whole run
#
# This test exercises some tricky "unnotified close" scenarios, to make
# sure we get the "acceptable" behaviors.
async def allow_OSError(async_func, *args):
with suppress(OSError):
await async_func(*args)
with stdlib_socket.socket() as s:
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s)
await wait_all_tasks_blocked()
s.close()
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# We hit different paths on Windows depending on whether we close the last
# handle to the object (which produces a LOCAL_CLOSE notification and
# wakes up wait_readable), or only close one of the handles (which leaves
# wait_readable pending until cancelled).
with stdlib_socket.socket() as s, s.dup() as s2: # noqa: F841
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s)
await wait_all_tasks_blocked()
s.close()
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# A more elaborate case, with two tasks waiting. On windows and epoll,
# the two tasks get muxed together onto a single underlying wait
# operation. So when they're cancelled, there's a brief moment where one
# of the tasks is cancelled but the other isn't, so we try to re-issue the
# underlying wait operation. But here, the handle we were going to use to
# do that has been pulled out from under our feet... so test that we can
# survive this.
a, b = stdlib_socket.socketpair()
with a, b, a.dup() as a2: # noqa: F841
a.setblocking(False)
b.setblocking(False)
fill_socket(a)
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a)
nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a)
await wait_all_tasks_blocked()
a.close()
nursery.cancel_scope.cancel()
# A similar case, but now the single-task-wakeup happens due to I/O
# arriving, not a cancellation, so the operation gets re-issued from
# handle_io context rather than abort context.
a, b = stdlib_socket.socketpair()
with a, b, a.dup() as a2: # noqa: F841
print("a={}, b={}, a2={}".format(a.fileno(), b.fileno(), a2.fileno()))
a.setblocking(False)
b.setblocking(False)
fill_socket(a)
e = trio.Event()
# We want to wait for the kernel to process the wakeup on 'a', if any.
# But depending on the platform, we might not get a wakeup on 'a'. So
# we put one task to sleep waiting on 'a', and we put a second task to
# sleep waiting on 'a2', with the idea that the 'a2' notification will
# definitely arrive, and when it does then we can assume that whatever
# notification was going to arrive for 'a' has also arrived.
async def wait_readable_a2_then_set():
await trio.lowlevel.wait_readable(a2)
e.set()
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a)
nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a)
nursery.start_soon(wait_readable_a2_then_set)
await wait_all_tasks_blocked()
a.close()
b.send(b"x")
# Make sure that the wakeup has been received and everything has
# settled before cancelling the wait_writable.
await e.wait()
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()

View File

@@ -0,0 +1,501 @@
import outcome
import pytest
import sys
import os
import signal
import threading
import contextlib
import time
from async_generator import (
async_generator,
yield_,
isasyncgenfunction,
asynccontextmanager,
)
from ... import _core
from ...testing import wait_all_tasks_blocked
from ..._util import signal_raise, is_main_thread
from ..._timeouts import sleep
from .tutil import slow
def ki_self():
signal_raise(signal.SIGINT)
def test_ki_self():
with pytest.raises(KeyboardInterrupt):
ki_self()
async def test_ki_enabled():
# Regular tasks aren't KI-protected
assert not _core.currently_ki_protected()
# Low-level call-soon callbacks are KI-protected
token = _core.current_trio_token()
record = []
def check():
record.append(_core.currently_ki_protected())
token.run_sync_soon(check)
await wait_all_tasks_blocked()
assert record == [True]
@_core.enable_ki_protection
def protected():
assert _core.currently_ki_protected()
unprotected()
@_core.disable_ki_protection
def unprotected():
assert not _core.currently_ki_protected()
protected()
@_core.enable_ki_protection
async def aprotected():
assert _core.currently_ki_protected()
await aunprotected()
@_core.disable_ki_protection
async def aunprotected():
assert not _core.currently_ki_protected()
await aprotected()
# make sure that the decorator here overrides the automatic manipulation
# that start_soon() does:
async with _core.open_nursery() as nursery:
nursery.start_soon(aprotected)
nursery.start_soon(aunprotected)
@_core.enable_ki_protection
def gen_protected():
assert _core.currently_ki_protected()
yield
for _ in gen_protected():
pass
@_core.disable_ki_protection
def gen_unprotected():
assert not _core.currently_ki_protected()
yield
for _ in gen_unprotected():
pass
# This used to be broken due to
#
# https://bugs.python.org/issue29590
#
# Specifically, after a coroutine is resumed with .throw(), then the stack
# makes it look like the immediate caller is the function that called
# .throw(), not the actual caller. So child() here would have a caller deep in
# the guts of the run loop, and always be protected, even when it shouldn't
# have been. (Solution: we don't use .throw() anymore.)
async def test_ki_enabled_after_yield_briefly():
@_core.enable_ki_protection
async def protected():
await child(True)
@_core.disable_ki_protection
async def unprotected():
await child(False)
async def child(expected):
import traceback
traceback.print_stack()
assert _core.currently_ki_protected() == expected
await _core.checkpoint()
traceback.print_stack()
assert _core.currently_ki_protected() == expected
await protected()
await unprotected()
# This also used to be broken due to
# https://bugs.python.org/issue29590
async def test_generator_based_context_manager_throw():
@contextlib.contextmanager
@_core.enable_ki_protection
def protected_manager():
assert _core.currently_ki_protected()
try:
yield
finally:
assert _core.currently_ki_protected()
with protected_manager():
assert not _core.currently_ki_protected()
with pytest.raises(KeyError):
# This is the one that used to fail
with protected_manager():
raise KeyError
async def test_agen_protection():
@_core.enable_ki_protection
@async_generator
async def agen_protected1():
assert _core.currently_ki_protected()
try:
await yield_()
finally:
assert _core.currently_ki_protected()
@_core.disable_ki_protection
@async_generator
async def agen_unprotected1():
assert not _core.currently_ki_protected()
try:
await yield_()
finally:
assert not _core.currently_ki_protected()
# Swap the order of the decorators:
@async_generator
@_core.enable_ki_protection
async def agen_protected2():
assert _core.currently_ki_protected()
try:
await yield_()
finally:
assert _core.currently_ki_protected()
@async_generator
@_core.disable_ki_protection
async def agen_unprotected2():
assert not _core.currently_ki_protected()
try:
await yield_()
finally:
assert not _core.currently_ki_protected()
# Native async generators
@_core.enable_ki_protection
async def agen_protected3():
assert _core.currently_ki_protected()
try:
yield
finally:
assert _core.currently_ki_protected()
@_core.disable_ki_protection
async def agen_unprotected3():
assert not _core.currently_ki_protected()
try:
yield
finally:
assert not _core.currently_ki_protected()
for agen_fn in [
agen_protected1,
agen_protected2,
agen_protected3,
agen_unprotected1,
agen_unprotected2,
agen_unprotected3,
]:
async for _ in agen_fn(): # noqa
assert not _core.currently_ki_protected()
# asynccontextmanager insists that the function passed must itself be an
# async gen function, not a wrapper around one
if isasyncgenfunction(agen_fn):
async with asynccontextmanager(agen_fn)():
assert not _core.currently_ki_protected()
# Another case that's tricky due to:
# https://bugs.python.org/issue29590
with pytest.raises(KeyError):
async with asynccontextmanager(agen_fn)():
raise KeyError
# Test the case where there's no magic local anywhere in the call stack
def test_ki_disabled_out_of_context():
assert _core.currently_ki_protected()
def test_ki_disabled_in_del():
def nestedfunction():
return _core.currently_ki_protected()
def __del__():
assert _core.currently_ki_protected()
assert nestedfunction()
@_core.disable_ki_protection
def outerfunction():
assert not _core.currently_ki_protected()
assert not nestedfunction()
__del__()
__del__()
outerfunction()
assert nestedfunction()
def test_ki_protection_works():
async def sleeper(name, record):
try:
while True:
await _core.checkpoint()
except _core.Cancelled:
record.add(name + " ok")
async def raiser(name, record):
try:
# os.kill runs signal handlers before returning, so we don't need
# to worry that the handler will be delayed
print("killing, protection =", _core.currently_ki_protected())
ki_self()
except KeyboardInterrupt:
print("raised!")
# Make sure we aren't getting cancelled as well as siginted
await _core.checkpoint()
record.add(name + " raise ok")
raise
else:
print("didn't raise!")
# If we didn't raise (b/c protected), then we *should* get
# cancelled at the next opportunity
try:
await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)
except _core.Cancelled:
record.add(name + " cancel ok")
# simulated control-C during raiser, which is *unprotected*
print("check 1")
record = set()
async def check_unprotected_kill():
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper, "s1", record)
nursery.start_soon(sleeper, "s2", record)
nursery.start_soon(raiser, "r1", record)
with pytest.raises(KeyboardInterrupt):
_core.run(check_unprotected_kill)
assert record == {"s1 ok", "s2 ok", "r1 raise ok"}
# simulated control-C during raiser, which is *protected*, so the KI gets
# delivered to the main task instead
print("check 2")
record = set()
async def check_protected_kill():
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper, "s1", record)
nursery.start_soon(sleeper, "s2", record)
nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record)
# __aexit__ blocks, and then receives the KI
with pytest.raises(KeyboardInterrupt):
_core.run(check_protected_kill)
assert record == {"s1 ok", "s2 ok", "r1 cancel ok"}
# kill at last moment still raises (run_sync_soon until it raises an
# error, then kill)
print("check 3")
async def check_kill_during_shutdown():
token = _core.current_trio_token()
def kill_during_shutdown():
assert _core.currently_ki_protected()
try:
token.run_sync_soon(kill_during_shutdown)
except _core.RunFinishedError:
# it's too late for regular handling! handle this!
print("kill! kill!")
ki_self()
token.run_sync_soon(kill_during_shutdown)
with pytest.raises(KeyboardInterrupt):
_core.run(check_kill_during_shutdown)
# KI arrives very early, before main is even spawned
print("check 4")
class InstrumentOfDeath:
def before_run(self):
ki_self()
async def main():
await _core.checkpoint()
with pytest.raises(KeyboardInterrupt):
_core.run(main, instruments=[InstrumentOfDeath()])
# checkpoint_if_cancelled notices pending KI
print("check 5")
@_core.enable_ki_protection
async def main():
assert _core.currently_ki_protected()
ki_self()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint_if_cancelled()
_core.run(main)
# KI arrives while main task is not abortable, b/c already scheduled
print("check 6")
@_core.enable_ki_protection
async def main():
assert _core.currently_ki_protected()
ki_self()
await _core.cancel_shielded_checkpoint()
await _core.cancel_shielded_checkpoint()
await _core.cancel_shielded_checkpoint()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
_core.run(main)
# KI arrives while main task is not abortable, b/c refuses to be aborted
print("check 7")
@_core.enable_ki_protection
async def main():
assert _core.currently_ki_protected()
ki_self()
task = _core.current_task()
def abort(_):
_core.reschedule(task, outcome.Value(1))
return _core.Abort.FAILED
assert await _core.wait_task_rescheduled(abort) == 1
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
_core.run(main)
# KI delivered via slow abort
print("check 8")
@_core.enable_ki_protection
async def main():
assert _core.currently_ki_protected()
ki_self()
task = _core.current_task()
def abort(raise_cancel):
result = outcome.capture(raise_cancel)
_core.reschedule(task, result)
return _core.Abort.FAILED
with pytest.raises(KeyboardInterrupt):
assert await _core.wait_task_rescheduled(abort)
await _core.checkpoint()
_core.run(main)
# KI arrives just before main task exits, so the run_sync_soon machinery
# is still functioning and will accept the callback to deliver the KI, but
# by the time the callback is actually run, main has exited and can't be
# aborted.
print("check 9")
@_core.enable_ki_protection
async def main():
ki_self()
with pytest.raises(KeyboardInterrupt):
_core.run(main)
print("check 10")
# KI in unprotected code, with
# restrict_keyboard_interrupt_to_checkpoints=True
record = []
async def main():
# We're not KI protected...
assert not _core.currently_ki_protected()
ki_self()
# ...but even after the KI, we keep running uninterrupted...
record.append("ok")
# ...until we hit a checkpoint:
with pytest.raises(KeyboardInterrupt):
await sleep(10)
_core.run(main, restrict_keyboard_interrupt_to_checkpoints=True)
assert record == ["ok"]
record = []
# Exact same code raises KI early if we leave off the argument, doesn't
# even reach the record.append call:
with pytest.raises(KeyboardInterrupt):
_core.run(main)
assert record == []
# KI arrives while main task is inside a cancelled cancellation scope
# the KeyboardInterrupt should take priority
print("check 11")
@_core.enable_ki_protection
async def main():
assert _core.currently_ki_protected()
with _core.CancelScope() as cancel_scope:
cancel_scope.cancel()
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
ki_self()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
_core.run(main)
def test_ki_is_good_neighbor():
# in the unlikely event someone overwrites our signal handler, we leave
# the overwritten one be
try:
orig = signal.getsignal(signal.SIGINT)
def my_handler(signum, frame): # pragma: no cover
pass
async def main():
signal.signal(signal.SIGINT, my_handler)
_core.run(main)
assert signal.getsignal(signal.SIGINT) is my_handler
finally:
signal.signal(signal.SIGINT, orig)
# Regression test for #461
def test_ki_with_broken_threads():
thread = threading.main_thread()
# scary!
original = threading._active[thread.ident]
# put this in a try finally so we don't have a chance of cascading a
# breakage down to everything else
try:
del threading._active[thread.ident]
@_core.enable_ki_protection
async def inner():
assert signal.getsignal(signal.SIGINT) != signal.default_int_handler
_core.run(inner)
finally:
threading._active[thread.ident] = original

View File

@@ -0,0 +1,115 @@
import pytest
from ... import _core
# scary runvar tests
def test_runvar_smoketest():
t1 = _core.RunVar("test1")
t2 = _core.RunVar("test2", default="catfish")
assert "RunVar" in repr(t1)
async def first_check():
with pytest.raises(LookupError):
t1.get()
t1.set("swordfish")
assert t1.get() == "swordfish"
assert t2.get() == "catfish"
assert t2.get(default="eel") == "eel"
t2.set("goldfish")
assert t2.get() == "goldfish"
assert t2.get(default="tuna") == "goldfish"
async def second_check():
with pytest.raises(LookupError):
t1.get()
assert t2.get() == "catfish"
_core.run(first_check)
_core.run(second_check)
def test_runvar_resetting():
t1 = _core.RunVar("test1")
t2 = _core.RunVar("test2", default="dogfish")
t3 = _core.RunVar("test3")
async def reset_check():
token = t1.set("moonfish")
assert t1.get() == "moonfish"
t1.reset(token)
with pytest.raises(TypeError):
t1.reset(None)
with pytest.raises(LookupError):
t1.get()
token2 = t2.set("catdogfish")
assert t2.get() == "catdogfish"
t2.reset(token2)
assert t2.get() == "dogfish"
with pytest.raises(ValueError):
t2.reset(token2)
token3 = t3.set("basculin")
assert t3.get() == "basculin"
with pytest.raises(ValueError):
t1.reset(token3)
_core.run(reset_check)
def test_runvar_sync():
t1 = _core.RunVar("test1")
async def sync_check():
async def task1():
t1.set("plaice")
assert t1.get() == "plaice"
async def task2(tok):
t1.reset(token)
with pytest.raises(LookupError):
t1.get()
t1.set("cod")
async with _core.open_nursery() as n:
token = t1.set("cod")
assert t1.get() == "cod"
n.start_soon(task1)
await _core.wait_all_tasks_blocked()
assert t1.get() == "plaice"
n.start_soon(task2, token)
await _core.wait_all_tasks_blocked()
assert t1.get() == "cod"
_core.run(sync_check)
def test_accessing_runvar_outside_run_call_fails():
t1 = _core.RunVar("test1")
with pytest.raises(RuntimeError):
t1.set("asdf")
with pytest.raises(RuntimeError):
t1.get()
async def get_token():
return t1.set("ok")
token = _core.run(get_token)
with pytest.raises(RuntimeError):
t1.reset(token)

View File

@@ -0,0 +1,170 @@
from math import inf
import time
import pytest
from trio import sleep
from ... import _core
from .. import wait_all_tasks_blocked
from .._mock_clock import MockClock
from .tutil import slow
def test_mock_clock():
REAL_NOW = 123.0
c = MockClock()
c._real_clock = lambda: REAL_NOW
repr(c) # smoke test
assert c.rate == 0
assert c.current_time() == 0
c.jump(1.2)
assert c.current_time() == 1.2
with pytest.raises(ValueError):
c.jump(-1)
assert c.current_time() == 1.2
assert c.deadline_to_sleep_time(1.1) == 0
assert c.deadline_to_sleep_time(1.2) == 0
assert c.deadline_to_sleep_time(1.3) > 999999
with pytest.raises(ValueError):
c.rate = -1
assert c.rate == 0
c.rate = 2
assert c.current_time() == 1.2
REAL_NOW += 1
assert c.current_time() == 3.2
assert c.deadline_to_sleep_time(3.1) == 0
assert c.deadline_to_sleep_time(3.2) == 0
assert c.deadline_to_sleep_time(4.2) == 0.5
c.rate = 0.5
assert c.current_time() == 3.2
assert c.deadline_to_sleep_time(3.1) == 0
assert c.deadline_to_sleep_time(3.2) == 0
assert c.deadline_to_sleep_time(4.2) == 2.0
c.jump(0.8)
assert c.current_time() == 4.0
REAL_NOW += 1
assert c.current_time() == 4.5
c2 = MockClock(rate=3)
assert c2.rate == 3
assert c2.current_time() < 10
async def test_mock_clock_autojump(mock_clock):
assert mock_clock.autojump_threshold == inf
mock_clock.autojump_threshold = 0
assert mock_clock.autojump_threshold == 0
real_start = time.perf_counter()
virtual_start = _core.current_time()
for i in range(10):
print("sleeping {} seconds".format(10 * i))
await sleep(10 * i)
print("woke up!")
assert virtual_start + 10 * i == _core.current_time()
virtual_start = _core.current_time()
real_duration = time.perf_counter() - real_start
print("Slept {} seconds in {} seconds".format(10 * sum(range(10)), real_duration))
assert real_duration < 1
mock_clock.autojump_threshold = 0.02
t = _core.current_time()
# this should wake up before the autojump threshold triggers, so time
# shouldn't change
await wait_all_tasks_blocked()
assert t == _core.current_time()
# this should too
await wait_all_tasks_blocked(0.01)
assert t == _core.current_time()
# set up a situation where the autojump task is blocked for a long long
# time, to make sure that cancel-and-adjust-threshold logic is working
mock_clock.autojump_threshold = 10000
await wait_all_tasks_blocked()
mock_clock.autojump_threshold = 0
# if the above line didn't take affect immediately, then this would be
# bad:
await sleep(100000)
async def test_mock_clock_autojump_interference(mock_clock):
mock_clock.autojump_threshold = 0.02
mock_clock2 = MockClock()
# messing with the autojump threshold of a clock that isn't actually
# installed in the run loop shouldn't do anything.
mock_clock2.autojump_threshold = 0.01
# if the autojump_threshold of 0.01 were in effect, then the next line
# would block forever, as the autojump task kept waking up to try to
# jump the clock.
await wait_all_tasks_blocked(0.015)
# but the 0.02 limit does apply
await sleep(100000)
def test_mock_clock_autojump_preset():
# Check that we can set the autojump_threshold before the clock is
# actually in use, and it gets picked up
mock_clock = MockClock(autojump_threshold=0.1)
mock_clock.autojump_threshold = 0.01
real_start = time.perf_counter()
_core.run(sleep, 10000, clock=mock_clock)
assert time.perf_counter() - real_start < 1
async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock):
# Checks that autojump_threshold=0 doesn't interfere with
# calling wait_all_tasks_blocked with the default cushion=0.
mock_clock.autojump_threshold = 0
record = []
async def sleeper():
await sleep(100)
record.append("yawn")
async def waiter():
await wait_all_tasks_blocked()
record.append("waiter woke")
await sleep(1000)
record.append("waiter done")
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper)
nursery.start_soon(waiter)
assert record == ["waiter woke", "yawn", "waiter done"]
@slow
async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(mock_clock):
# Checks that autojump_threshold=0 doesn't interfere with
# calling wait_all_tasks_blocked with a non-zero cushion.
mock_clock.autojump_threshold = 0
record = []
async def sleeper():
await sleep(100)
record.append("yawn")
async def waiter():
await wait_all_tasks_blocked(1)
record.append("waiter done")
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper)
nursery.start_soon(waiter)
assert record == ["waiter done", "yawn"]

View File

@@ -0,0 +1,772 @@
import gc
import logging
import pytest
from traceback import (
extract_tb,
print_exception,
format_exception,
)
from traceback import _cause_message # type: ignore
import sys
import os
import re
from pathlib import Path
import subprocess
from .tutil import slow
from .._multierror import MultiError, concat_tb
from ..._core import open_nursery
class NotHashableException(Exception):
code = None
def __init__(self, code):
super().__init__()
self.code = code
def __eq__(self, other):
if not isinstance(other, NotHashableException):
return False
return self.code == other.code
async def raise_nothashable(code):
raise NotHashableException(code)
def raiser1():
raiser1_2()
def raiser1_2():
raiser1_3()
def raiser1_3():
raise ValueError("raiser1_string")
def raiser2():
raiser2_2()
def raiser2_2():
raise KeyError("raiser2_string")
def raiser3():
raise NameError
def get_exc(raiser):
try:
raiser()
except Exception as exc:
return exc
def get_tb(raiser):
return get_exc(raiser).__traceback__
def einfo(exc):
return (type(exc), exc, exc.__traceback__)
def test_concat_tb():
tb1 = get_tb(raiser1)
tb2 = get_tb(raiser2)
# These return a list of (filename, lineno, fn name, text) tuples
# https://docs.python.org/3/library/traceback.html#traceback.extract_tb
entries1 = extract_tb(tb1)
entries2 = extract_tb(tb2)
tb12 = concat_tb(tb1, tb2)
assert extract_tb(tb12) == entries1 + entries2
tb21 = concat_tb(tb2, tb1)
assert extract_tb(tb21) == entries2 + entries1
# Check degenerate cases
assert extract_tb(concat_tb(None, tb1)) == entries1
assert extract_tb(concat_tb(tb1, None)) == entries1
assert concat_tb(None, None) is None
# Make sure the original tracebacks didn't get mutated by mistake
assert extract_tb(get_tb(raiser1)) == entries1
assert extract_tb(get_tb(raiser2)) == entries2
def test_MultiError():
exc1 = get_exc(raiser1)
exc2 = get_exc(raiser2)
assert MultiError([exc1]) is exc1
m = MultiError([exc1, exc2])
assert m.exceptions == [exc1, exc2]
assert "ValueError" in str(m)
assert "ValueError" in repr(m)
with pytest.raises(TypeError):
MultiError(object())
with pytest.raises(TypeError):
MultiError([KeyError(), ValueError])
def test_MultiErrorOfSingleMultiError():
# For MultiError([MultiError]), ensure there is no bad recursion by the
# constructor where __init__ is called if __new__ returns a bare MultiError.
exceptions = [KeyError(), ValueError()]
a = MultiError(exceptions)
b = MultiError([a])
assert b == a
assert b.exceptions == exceptions
async def test_MultiErrorNotHashable():
exc1 = NotHashableException(42)
exc2 = NotHashableException(4242)
exc3 = ValueError()
assert exc1 != exc2
assert exc1 != exc3
with pytest.raises(MultiError):
async with open_nursery() as nursery:
nursery.start_soon(raise_nothashable, 42)
nursery.start_soon(raise_nothashable, 4242)
def test_MultiError_filter_NotHashable():
excs = MultiError([NotHashableException(42), ValueError()])
def handle_ValueError(exc):
if isinstance(exc, ValueError):
return None
else:
return exc
filtered_excs = MultiError.filter(handle_ValueError, excs)
assert isinstance(filtered_excs, NotHashableException)
def test_traceback_recursion():
exc1 = RuntimeError()
exc2 = KeyError()
exc3 = NotHashableException(42)
# Note how this creates a loop, where exc1 refers to exc1
# This could trigger an infinite recursion; the 'seen' set is supposed to prevent
# this.
exc1.__cause__ = MultiError([exc1, exc2, exc3])
format_exception(*einfo(exc1))
def make_tree():
# Returns an object like:
# MultiError([
# MultiError([
# ValueError,
# KeyError,
# ]),
# NameError,
# ])
# where all exceptions except the root have a non-trivial traceback.
exc1 = get_exc(raiser1)
exc2 = get_exc(raiser2)
exc3 = get_exc(raiser3)
# Give m12 a non-trivial traceback
try:
raise MultiError([exc1, exc2])
except BaseException as m12:
return MultiError([m12, exc3])
def assert_tree_eq(m1, m2):
if m1 is None or m2 is None:
assert m1 is m2
return
assert type(m1) is type(m2)
assert extract_tb(m1.__traceback__) == extract_tb(m2.__traceback__)
assert_tree_eq(m1.__cause__, m2.__cause__)
assert_tree_eq(m1.__context__, m2.__context__)
if isinstance(m1, MultiError):
assert len(m1.exceptions) == len(m2.exceptions)
for e1, e2 in zip(m1.exceptions, m2.exceptions):
assert_tree_eq(e1, e2)
def test_MultiError_filter():
def null_handler(exc):
return exc
m = make_tree()
assert_tree_eq(m, m)
assert MultiError.filter(null_handler, m) is m
assert_tree_eq(m, make_tree())
# Make sure we don't pick up any detritus if run in a context where
# implicit exception chaining would like to kick in
m = make_tree()
try:
raise ValueError
except ValueError:
assert MultiError.filter(null_handler, m) is m
assert_tree_eq(m, make_tree())
def simple_filter(exc):
if isinstance(exc, ValueError):
return None
if isinstance(exc, KeyError):
return RuntimeError()
return exc
new_m = MultiError.filter(simple_filter, make_tree())
assert isinstance(new_m, MultiError)
assert len(new_m.exceptions) == 2
# was: [[ValueError, KeyError], NameError]
# ValueError disappeared & KeyError became RuntimeError, so now:
assert isinstance(new_m.exceptions[0], RuntimeError)
assert isinstance(new_m.exceptions[1], NameError)
# implicit chaining:
assert isinstance(new_m.exceptions[0].__context__, KeyError)
# also, the traceback on the KeyError incorporates what used to be the
# traceback on its parent MultiError
orig = make_tree()
# make sure we have the right path
assert isinstance(orig.exceptions[0].exceptions[1], KeyError)
# get original traceback summary
orig_extracted = (
extract_tb(orig.__traceback__)
+ extract_tb(orig.exceptions[0].__traceback__)
+ extract_tb(orig.exceptions[0].exceptions[1].__traceback__)
)
def p(exc):
print_exception(type(exc), exc, exc.__traceback__)
p(orig)
p(orig.exceptions[0])
p(orig.exceptions[0].exceptions[1])
p(new_m.exceptions[0].__context__)
# compare to the new path
assert new_m.__traceback__ is None
new_extracted = extract_tb(new_m.exceptions[0].__context__.__traceback__)
assert orig_extracted == new_extracted
# check preserving partial tree
def filter_NameError(exc):
if isinstance(exc, NameError):
return None
return exc
m = make_tree()
new_m = MultiError.filter(filter_NameError, m)
# with the NameError gone, the other branch gets promoted
assert new_m is m.exceptions[0]
# check fully handling everything
def filter_all(exc):
return None
assert MultiError.filter(filter_all, make_tree()) is None
def test_MultiError_catch():
# No exception to catch
def noop(_):
pass # pragma: no cover
with MultiError.catch(noop):
pass
# Simple pass-through of all exceptions
m = make_tree()
with pytest.raises(MultiError) as excinfo:
with MultiError.catch(lambda exc: exc):
raise m
assert excinfo.value is m
# Should be unchanged, except that we added a traceback frame by raising
# it here
assert m.__traceback__ is not None
assert m.__traceback__.tb_frame.f_code.co_name == "test_MultiError_catch"
assert m.__traceback__.tb_next is None
m.__traceback__ = None
assert_tree_eq(m, make_tree())
# Swallows everything
with MultiError.catch(lambda _: None):
raise make_tree()
def simple_filter(exc):
if isinstance(exc, ValueError):
return None
if isinstance(exc, KeyError):
return RuntimeError()
return exc
with pytest.raises(MultiError) as excinfo:
with MultiError.catch(simple_filter):
raise make_tree()
new_m = excinfo.value
assert isinstance(new_m, MultiError)
assert len(new_m.exceptions) == 2
# was: [[ValueError, KeyError], NameError]
# ValueError disappeared & KeyError became RuntimeError, so now:
assert isinstance(new_m.exceptions[0], RuntimeError)
assert isinstance(new_m.exceptions[1], NameError)
# Make sure that Python did not successfully attach the old MultiError to
# our new MultiError's __context__
assert not new_m.__suppress_context__
assert new_m.__context__ is None
# check preservation of __cause__ and __context__
v = ValueError()
v.__cause__ = KeyError()
with pytest.raises(ValueError) as excinfo:
with MultiError.catch(lambda exc: exc):
raise v
assert isinstance(excinfo.value.__cause__, KeyError)
v = ValueError()
context = KeyError()
v.__context__ = context
with pytest.raises(ValueError) as excinfo:
with MultiError.catch(lambda exc: exc):
raise v
assert excinfo.value.__context__ is context
assert not excinfo.value.__suppress_context__
for suppress_context in [True, False]:
v = ValueError()
context = KeyError()
v.__context__ = context
v.__suppress_context__ = suppress_context
distractor = RuntimeError()
with pytest.raises(ValueError) as excinfo:
def catch_RuntimeError(exc):
if isinstance(exc, RuntimeError):
return None
else:
return exc
with MultiError.catch(catch_RuntimeError):
raise MultiError([v, distractor])
assert excinfo.value.__context__ is context
assert excinfo.value.__suppress_context__ == suppress_context
@pytest.mark.skipif(
sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC"
)
def test_MultiError_catch_doesnt_create_cyclic_garbage():
# https://github.com/python-trio/trio/pull/2063
gc.collect()
old_flags = gc.get_debug()
def make_multi():
# make_tree creates cycles itself, so a simple
raise MultiError([get_exc(raiser1), get_exc(raiser2)])
def simple_filter(exc):
if isinstance(exc, ValueError):
return Exception()
if isinstance(exc, KeyError):
return RuntimeError()
assert False, "only ValueError and KeyError should exist" # pragma: no cover
try:
gc.set_debug(gc.DEBUG_SAVEALL)
with pytest.raises(MultiError):
# covers MultiErrorCatcher.__exit__ and _multierror.copy_tb
with MultiError.catch(simple_filter):
raise make_multi()
gc.collect()
assert not gc.garbage
finally:
gc.set_debug(old_flags)
gc.garbage.clear()
def assert_match_in_seq(pattern_list, string):
offset = 0
print("looking for pattern matches...")
for pattern in pattern_list:
print("checking pattern:", pattern)
reobj = re.compile(pattern)
match = reobj.search(string, offset)
assert match is not None
offset = match.end()
def test_assert_match_in_seq():
assert_match_in_seq(["a", "b"], "xx a xx b xx")
assert_match_in_seq(["b", "a"], "xx b xx a xx")
with pytest.raises(AssertionError):
assert_match_in_seq(["a", "b"], "xx b xx a xx")
def test_format_exception():
exc = get_exc(raiser1)
formatted = "".join(format_exception(*einfo(exc)))
assert "raiser1_string" in formatted
assert "in raiser1_3" in formatted
assert "raiser2_string" not in formatted
assert "in raiser2_2" not in formatted
assert "direct cause" not in formatted
assert "During handling" not in formatted
exc = get_exc(raiser1)
exc.__cause__ = get_exc(raiser2)
formatted = "".join(format_exception(*einfo(exc)))
assert "raiser1_string" in formatted
assert "in raiser1_3" in formatted
assert "raiser2_string" in formatted
assert "in raiser2_2" in formatted
assert "direct cause" in formatted
assert "During handling" not in formatted
# ensure cause included
assert _cause_message in formatted
exc = get_exc(raiser1)
exc.__context__ = get_exc(raiser2)
formatted = "".join(format_exception(*einfo(exc)))
assert "raiser1_string" in formatted
assert "in raiser1_3" in formatted
assert "raiser2_string" in formatted
assert "in raiser2_2" in formatted
assert "direct cause" not in formatted
assert "During handling" in formatted
exc.__suppress_context__ = True
formatted = "".join(format_exception(*einfo(exc)))
assert "raiser1_string" in formatted
assert "in raiser1_3" in formatted
assert "raiser2_string" not in formatted
assert "in raiser2_2" not in formatted
assert "direct cause" not in formatted
assert "During handling" not in formatted
# chain=False
exc = get_exc(raiser1)
exc.__context__ = get_exc(raiser2)
formatted = "".join(format_exception(*einfo(exc), chain=False))
assert "raiser1_string" in formatted
assert "in raiser1_3" in formatted
assert "raiser2_string" not in formatted
assert "in raiser2_2" not in formatted
assert "direct cause" not in formatted
assert "During handling" not in formatted
# limit
exc = get_exc(raiser1)
exc.__context__ = get_exc(raiser2)
# get_exc adds a frame that counts against the limit, so limit=2 means we
# get 1 deep into the raiser stack
formatted = "".join(format_exception(*einfo(exc), limit=2))
print(formatted)
assert "raiser1_string" in formatted
assert "in raiser1" in formatted
assert "in raiser1_2" not in formatted
assert "raiser2_string" in formatted
assert "in raiser2" in formatted
assert "in raiser2_2" not in formatted
exc = get_exc(raiser1)
exc.__context__ = get_exc(raiser2)
formatted = "".join(format_exception(*einfo(exc), limit=1))
print(formatted)
assert "raiser1_string" in formatted
assert "in raiser1" not in formatted
assert "raiser2_string" in formatted
assert "in raiser2" not in formatted
# handles loops
exc = get_exc(raiser1)
exc.__cause__ = exc
formatted = "".join(format_exception(*einfo(exc)))
assert "raiser1_string" in formatted
assert "in raiser1_3" in formatted
assert "raiser2_string" not in formatted
assert "in raiser2_2" not in formatted
# ensure duplicate exception is not included as cause
assert _cause_message not in formatted
# MultiError
formatted = "".join(format_exception(*einfo(make_tree())))
print(formatted)
assert_match_in_seq(
[
# Outer exception is MultiError
r"MultiError:",
# First embedded exception is the embedded MultiError
r"\nDetails of embedded exception 1",
# Which has a single stack frame from make_tree raising it
r"in make_tree",
# Then it has two embedded exceptions
r" Details of embedded exception 1",
r"in raiser1_2",
# for some reason ValueError has no quotes
r"ValueError: raiser1_string",
r" Details of embedded exception 2",
r"in raiser2_2",
# But KeyError does have quotes
r"KeyError: 'raiser2_string'",
# And finally the NameError, which is a sibling of the embedded
# MultiError
r"\nDetails of embedded exception 2:",
r"in raiser3",
r"NameError",
],
formatted,
)
# Prints duplicate exceptions in sub-exceptions
exc1 = get_exc(raiser1)
def raise1_raiser1():
try:
raise exc1
except:
raise ValueError("foo")
def raise2_raiser1():
try:
raise exc1
except:
raise KeyError("bar")
exc2 = get_exc(raise1_raiser1)
exc3 = get_exc(raise2_raiser1)
try:
raise MultiError([exc2, exc3])
except MultiError as e:
exc = e
formatted = "".join(format_exception(*einfo(exc)))
print(formatted)
assert_match_in_seq(
[
r"Traceback",
# Outer exception is MultiError
r"MultiError:",
# First embedded exception is the embedded ValueError with cause of raiser1
r"\nDetails of embedded exception 1",
# Print details of exc1
r" Traceback",
r"in get_exc",
r"in raiser1",
r"ValueError: raiser1_string",
# Print details of exc2
r"\n During handling of the above exception, another exception occurred:",
r" Traceback",
r"in get_exc",
r"in raise1_raiser1",
r" ValueError: foo",
# Second embedded exception is the embedded KeyError with cause of raiser1
r"\nDetails of embedded exception 2",
# Print details of exc1 again
r" Traceback",
r"in get_exc",
r"in raiser1",
r"ValueError: raiser1_string",
# Print details of exc3
r"\n During handling of the above exception, another exception occurred:",
r" Traceback",
r"in get_exc",
r"in raise2_raiser1",
r" KeyError: 'bar'",
],
formatted,
)
def test_logging(caplog):
exc1 = get_exc(raiser1)
exc2 = get_exc(raiser2)
m = MultiError([exc1, exc2])
message = "test test test"
try:
raise m
except MultiError as exc:
logging.getLogger().exception(message)
# Join lines together
formatted = "".join(format_exception(type(exc), exc, exc.__traceback__))
assert message in caplog.text
assert formatted in caplog.text
def run_script(name, use_ipython=False):
import trio
trio_path = Path(trio.__file__).parent.parent
script_path = Path(__file__).parent / "test_multierror_scripts" / name
env = dict(os.environ)
print("parent PYTHONPATH:", env.get("PYTHONPATH"))
if "PYTHONPATH" in env: # pragma: no cover
pp = env["PYTHONPATH"].split(os.pathsep)
else:
pp = []
pp.insert(0, str(trio_path))
pp.insert(0, str(script_path.parent))
env["PYTHONPATH"] = os.pathsep.join(pp)
print("subprocess PYTHONPATH:", env.get("PYTHONPATH"))
if use_ipython:
lines = [script_path.read_text(), "exit()"]
cmd = [
sys.executable,
"-u",
"-m",
"IPython",
# no startup files
"--quick",
"--TerminalIPythonApp.code_to_run=" + "\n".join(lines),
]
else:
cmd = [sys.executable, "-u", str(script_path)]
print("running:", cmd)
completed = subprocess.run(
cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)
print("process output:")
print(completed.stdout.decode("utf-8"))
return completed
def check_simple_excepthook(completed):
assert_match_in_seq(
[
"in <module>",
"MultiError",
"Details of embedded exception 1",
"in exc1_fn",
"ValueError",
"Details of embedded exception 2",
"in exc2_fn",
"KeyError",
],
completed.stdout.decode("utf-8"),
)
def test_simple_excepthook():
completed = run_script("simple_excepthook.py")
check_simple_excepthook(completed)
def test_custom_excepthook():
# Check that user-defined excepthooks aren't overridden
completed = run_script("custom_excepthook.py")
assert_match_in_seq(
[
# The warning
"RuntimeWarning",
"already have a custom",
# The message printed by the custom hook, proving we didn't
# override it
"custom running!",
# The MultiError
"MultiError:",
],
completed.stdout.decode("utf-8"),
)
# This warning is triggered by ipython 7.5.0 on python 3.8
import warnings
warnings.filterwarnings(
"ignore",
message='.*"@coroutine" decorator is deprecated',
category=DeprecationWarning,
module="IPython.*",
)
try:
import IPython
except ImportError: # pragma: no cover
have_ipython = False
else:
have_ipython = True
need_ipython = pytest.mark.skipif(not have_ipython, reason="need IPython")
@slow
@need_ipython
def test_ipython_exc_handler():
completed = run_script("simple_excepthook.py", use_ipython=True)
check_simple_excepthook(completed)
@slow
@need_ipython
def test_ipython_imported_but_unused():
completed = run_script("simple_excepthook_IPython.py")
check_simple_excepthook(completed)
@slow
def test_partial_imported_but_unused():
# Check that a functools.partial as sys.excepthook doesn't cause an exception when
# importing trio. This was a problem due to the lack of a .__name__ attribute and
# happens when inside a pytest-qt test case for example.
completed = run_script("simple_excepthook_partial.py")
completed.check_returncode()
@slow
@need_ipython
def test_ipython_custom_exc_handler():
# Check we get a nice warning (but only one!) if the user is using IPython
# and already has some other set_custom_exc handler installed.
completed = run_script("ipython_custom_exc.py", use_ipython=True)
assert_match_in_seq(
[
# The warning
"RuntimeWarning",
"IPython detected",
"skip installing Trio",
# The MultiError
"MultiError",
"ValueError",
"KeyError",
],
completed.stdout.decode("utf-8"),
)
# Make sure our other warning doesn't show up
assert "custom sys.excepthook" not in completed.stdout.decode("utf-8")
@slow
@pytest.mark.skipif(
not Path("/usr/lib/python3/dist-packages/apport_python_hook.py").exists(),
reason="need Ubuntu with python3-apport installed",
)
def test_apport_excepthook_monkeypatch_interaction():
completed = run_script("apport_excepthook.py")
stdout = completed.stdout.decode("utf-8")
# No warning
assert "custom sys.excepthook" not in stdout
# Proper traceback
assert_match_in_seq(
["Details of embedded", "KeyError", "Details of embedded", "ValueError"],
stdout,
)

View File

@@ -0,0 +1,2 @@
# This isn't really a package, everything in here is a standalone script. This
# __init__.py is just to fool setup.py into actually installing the things.

View File

@@ -0,0 +1,7 @@
# https://coverage.readthedocs.io/en/latest/subprocess.html
try:
import coverage
except ImportError: # pragma: no cover
pass
else:
coverage.process_startup()

View File

@@ -0,0 +1,13 @@
# The apport_python_hook package is only installed as part of Ubuntu's system
# python, and not available in venvs. So before we can import it we have to
# make sure it's on sys.path.
import sys
sys.path.append("/usr/lib/python3/dist-packages")
import apport_python_hook
apport_python_hook.install()
import trio
raise trio.MultiError([KeyError("key_error"), ValueError("value_error")])

View File

@@ -0,0 +1,18 @@
import _common
import sys
def custom_excepthook(*args):
print("custom running!")
return sys.__excepthook__(*args)
sys.excepthook = custom_excepthook
# Should warn that we'll get kinda-broken tracebacks
import trio
# The custom excepthook should run, because Trio was polite and didn't
# override it
raise trio.MultiError([ValueError(), KeyError()])

View File

@@ -0,0 +1,36 @@
import _common
# Override the regular excepthook too -- it doesn't change anything either way
# because ipython doesn't use it, but we want to make sure Trio doesn't warn
# about it.
import sys
def custom_excepthook(*args):
print("custom running!")
return sys.__excepthook__(*args)
sys.excepthook = custom_excepthook
import IPython
ip = IPython.get_ipython()
# Set this to some random nonsense
class SomeError(Exception):
pass
def custom_exc_hook(etype, value, tb, tb_offset=None):
ip.showtraceback()
ip.set_custom_exc((SomeError,), custom_exc_hook)
import trio
# The custom excepthook should run, because Trio was polite and didn't
# override it
raise trio.MultiError([ValueError(), KeyError()])

View File

@@ -0,0 +1,21 @@
import _common
import trio
def exc1_fn():
try:
raise ValueError
except Exception as exc:
return exc
def exc2_fn():
try:
raise KeyError
except Exception as exc:
return exc
# This should be printed nicely, because Trio overrode sys.excepthook
raise trio.MultiError([exc1_fn(), exc2_fn()])

View File

@@ -0,0 +1,7 @@
import _common
# To tickle the "is IPython loaded?" logic, make sure that Trio tolerates
# IPython loaded but not actually in use
import IPython
import simple_excepthook

View File

@@ -0,0 +1,13 @@
import functools
import sys
import _common
# just making sure importing Trio doesn't fail if sys.excepthook doesn't have a
# .__name__ attribute
sys.excepthook = functools.partial(sys.excepthook)
assert not hasattr(sys.excepthook, "__name__")
import trio

View File

@@ -0,0 +1,197 @@
import pytest
from ... import _core
from ...testing import wait_all_tasks_blocked
from .._parking_lot import ParkingLot
from .tutil import check_sequence_matches
async def test_parking_lot_basic():
record = []
async def waiter(i, lot):
record.append("sleep {}".format(i))
await lot.park()
record.append("wake {}".format(i))
async with _core.open_nursery() as nursery:
lot = ParkingLot()
assert not lot
assert len(lot) == 0
assert lot.statistics().tasks_waiting == 0
for i in range(3):
nursery.start_soon(waiter, i, lot)
await wait_all_tasks_blocked()
assert len(record) == 3
assert bool(lot)
assert len(lot) == 3
assert lot.statistics().tasks_waiting == 3
lot.unpark_all()
assert lot.statistics().tasks_waiting == 0
await wait_all_tasks_blocked()
assert len(record) == 6
check_sequence_matches(
record, [{"sleep 0", "sleep 1", "sleep 2"}, {"wake 0", "wake 1", "wake 2"}]
)
async with _core.open_nursery() as nursery:
record = []
for i in range(3):
nursery.start_soon(waiter, i, lot)
await wait_all_tasks_blocked()
assert len(record) == 3
for i in range(3):
lot.unpark()
await wait_all_tasks_blocked()
# 1-by-1 wakeups are strict FIFO
assert record == [
"sleep 0",
"sleep 1",
"sleep 2",
"wake 0",
"wake 1",
"wake 2",
]
# It's legal (but a no-op) to try and unpark while there's nothing parked
lot.unpark()
lot.unpark(count=1)
lot.unpark(count=100)
# Check unpark with count
async with _core.open_nursery() as nursery:
record = []
for i in range(3):
nursery.start_soon(waiter, i, lot)
await wait_all_tasks_blocked()
lot.unpark(count=2)
await wait_all_tasks_blocked()
check_sequence_matches(
record, ["sleep 0", "sleep 1", "sleep 2", {"wake 0", "wake 1"}]
)
lot.unpark_all()
async def cancellable_waiter(name, lot, scopes, record):
with _core.CancelScope() as scope:
scopes[name] = scope
record.append("sleep {}".format(name))
try:
await lot.park()
except _core.Cancelled:
record.append("cancelled {}".format(name))
else:
record.append("wake {}".format(name))
async def test_parking_lot_cancel():
record = []
scopes = {}
async with _core.open_nursery() as nursery:
lot = ParkingLot()
nursery.start_soon(cancellable_waiter, 1, lot, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 2, lot, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 3, lot, scopes, record)
await wait_all_tasks_blocked()
assert len(record) == 3
scopes[2].cancel()
await wait_all_tasks_blocked()
assert len(record) == 4
lot.unpark_all()
await wait_all_tasks_blocked()
assert len(record) == 6
check_sequence_matches(
record, ["sleep 1", "sleep 2", "sleep 3", "cancelled 2", {"wake 1", "wake 3"}]
)
async def test_parking_lot_repark():
record = []
scopes = {}
lot1 = ParkingLot()
lot2 = ParkingLot()
with pytest.raises(TypeError):
lot1.repark([])
async with _core.open_nursery() as nursery:
nursery.start_soon(cancellable_waiter, 1, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 2, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 3, lot1, scopes, record)
await wait_all_tasks_blocked()
assert len(record) == 3
assert len(lot1) == 3
lot1.repark(lot2)
assert len(lot1) == 2
assert len(lot2) == 1
lot2.unpark_all()
await wait_all_tasks_blocked()
assert len(record) == 4
assert record == ["sleep 1", "sleep 2", "sleep 3", "wake 1"]
lot1.repark_all(lot2)
assert len(lot1) == 0
assert len(lot2) == 2
scopes[2].cancel()
await wait_all_tasks_blocked()
assert len(lot2) == 1
assert record == [
"sleep 1",
"sleep 2",
"sleep 3",
"wake 1",
"cancelled 2",
]
lot2.unpark_all()
await wait_all_tasks_blocked()
assert record == [
"sleep 1",
"sleep 2",
"sleep 3",
"wake 1",
"cancelled 2",
"wake 3",
]
async def test_parking_lot_repark_with_count():
record = []
scopes = {}
lot1 = ParkingLot()
lot2 = ParkingLot()
async with _core.open_nursery() as nursery:
nursery.start_soon(cancellable_waiter, 1, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 2, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 3, lot1, scopes, record)
await wait_all_tasks_blocked()
assert len(record) == 3
assert len(lot1) == 3
assert len(lot2) == 0
lot1.repark(lot2, count=2)
assert len(lot1) == 1
assert len(lot2) == 2
while lot2:
lot2.unpark()
await wait_all_tasks_blocked()
assert record == [
"sleep 1",
"sleep 2",
"sleep 3",
"wake 1",
"wake 2",
]
lot1.unpark_all()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,164 @@
import pytest
import threading
from queue import Queue
import time
import sys
from contextlib import contextmanager
from .tutil import slow, gc_collect_harder, disable_threading_excepthook
from .. import _thread_cache
from .._thread_cache import start_thread_soon, ThreadCache
def test_thread_cache_basics():
q = Queue()
def fn():
raise RuntimeError("hi")
def deliver(outcome):
q.put(outcome)
start_thread_soon(fn, deliver)
outcome = q.get()
with pytest.raises(RuntimeError, match="hi"):
outcome.unwrap()
def test_thread_cache_deref():
res = [False]
class del_me:
def __call__(self):
return 42
def __del__(self):
res[0] = True
q = Queue()
def deliver(outcome):
q.put(outcome)
start_thread_soon(del_me(), deliver)
outcome = q.get()
assert outcome.unwrap() == 42
gc_collect_harder()
assert res[0]
@slow
def test_spawning_new_thread_from_deliver_reuses_starting_thread():
# We know that no-one else is using the thread cache, so if we keep
# submitting new jobs the instant the previous one is finished, we should
# keep getting the same thread over and over. This tests both that the
# thread cache is LIFO, and that threads can be assigned new work *before*
# deliver exits.
# Make sure there are a few threads running, so if we weren't LIFO then we
# could grab the wrong one.
q = Queue()
COUNT = 5
for _ in range(COUNT):
start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result))
for _ in range(COUNT):
q.get().unwrap()
seen_threads = set()
done = threading.Event()
def deliver(n, _):
print(n)
seen_threads.add(threading.current_thread())
if n == 0:
done.set()
else:
start_thread_soon(lambda: None, lambda _: deliver(n - 1, _))
start_thread_soon(lambda: None, lambda _: deliver(5, _))
done.wait()
assert len(seen_threads) == 1
@slow
def test_idle_threads_exit(monkeypatch):
# Temporarily set the idle timeout to something tiny, to speed up the
# test. (But non-zero, so that the worker loop will at least yield the
# CPU.)
monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001)
q = Queue()
start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread()))
seen_thread = q.get()
# Since the idle timeout is 0, after sleeping for 1 second, the thread
# should have exited
time.sleep(1)
assert not seen_thread.is_alive()
@contextmanager
def _join_started_threads():
before = frozenset(threading.enumerate())
try:
yield
finally:
for thread in threading.enumerate():
if thread not in before:
thread.join()
def test_race_between_idle_exit_and_job_assignment(monkeypatch):
# This is a lock where the first few times you try to acquire it with a
# timeout, it waits until the lock is available and then pretends to time
# out. Using this in our thread cache implementation causes the following
# sequence:
#
# 1. start_thread_soon grabs the worker thread, assigns it a job, and
# releases its lock.
# 2. The worker thread wakes up (because the lock has been released), but
# the JankyLock lies to it and tells it that the lock timed out. So the
# worker thread tries to exit.
# 3. The worker thread checks for the race between exiting and being
# assigned a job, and discovers that it *is* in the process of being
# assigned a job, so it loops around and tries to acquire the lock
# again.
# 4. Eventually the JankyLock admits that the lock is available, and
# everything proceeds as normal.
class JankyLock:
def __init__(self):
self._lock = threading.Lock()
self._counter = 3
def acquire(self, timeout=None):
self._lock.acquire()
if timeout is None:
return True
else:
if self._counter > 0:
self._counter -= 1
self._lock.release()
return False
return True
def release(self):
self._lock.release()
monkeypatch.setattr(_thread_cache, "Lock", JankyLock)
with disable_threading_excepthook(), _join_started_threads():
tc = ThreadCache()
done = threading.Event()
tc.start_thread_soon(lambda: None, lambda _: done.set())
done.wait()
# Let's kill the thread we started, so it doesn't hang around until the
# test suite finishes. Doesn't really do any harm, but it can be confusing
# to see it in debug output. This is hacky, and leaves our ThreadCache
# object in an inconsistent state... but it doesn't matter, because we're
# not going to use it again anyway.
tc.start_thread_soon(lambda: None, lambda _: sys.exit())

View File

@@ -0,0 +1,13 @@
import pytest
from .tutil import check_sequence_matches
def test_check_sequence_matches():
check_sequence_matches([1, 2, 3], [1, 2, 3])
with pytest.raises(AssertionError):
check_sequence_matches([1, 3, 2], [1, 2, 3])
check_sequence_matches([1, 2, 3, 4], [1, {2, 3}, 4])
check_sequence_matches([1, 3, 2, 4], [1, {2, 3}, 4])
with pytest.raises(AssertionError):
check_sequence_matches([1, 2, 4, 3], [1, {2, 3}, 4])

View File

@@ -0,0 +1,152 @@
import itertools
import pytest
from ... import _core
from ...testing import assert_checkpoints, wait_all_tasks_blocked
pytestmark = pytest.mark.filterwarnings(
"ignore:.*UnboundedQueue:trio.TrioDeprecationWarning"
)
async def test_UnboundedQueue_basic():
q = _core.UnboundedQueue()
q.put_nowait("hi")
assert await q.get_batch() == ["hi"]
with pytest.raises(_core.WouldBlock):
q.get_batch_nowait()
q.put_nowait(1)
q.put_nowait(2)
q.put_nowait(3)
assert q.get_batch_nowait() == [1, 2, 3]
assert q.empty()
assert q.qsize() == 0
q.put_nowait(None)
assert not q.empty()
assert q.qsize() == 1
stats = q.statistics()
assert stats.qsize == 1
assert stats.tasks_waiting == 0
# smoke test
repr(q)
async def test_UnboundedQueue_blocking():
record = []
q = _core.UnboundedQueue()
async def get_batch_consumer():
while True:
batch = await q.get_batch()
assert batch
record.append(batch)
async def aiter_consumer():
async for batch in q:
assert batch
record.append(batch)
for consumer in (get_batch_consumer, aiter_consumer):
record.clear()
async with _core.open_nursery() as nursery:
nursery.start_soon(consumer)
await _core.wait_all_tasks_blocked()
stats = q.statistics()
assert stats.qsize == 0
assert stats.tasks_waiting == 1
q.put_nowait(10)
q.put_nowait(11)
await _core.wait_all_tasks_blocked()
q.put_nowait(12)
await _core.wait_all_tasks_blocked()
assert record == [[10, 11], [12]]
nursery.cancel_scope.cancel()
async def test_UnboundedQueue_fairness():
q = _core.UnboundedQueue()
# If there's no-one else around, we can put stuff in and take it out
# again, no problem
q.put_nowait(1)
q.put_nowait(2)
assert q.get_batch_nowait() == [1, 2]
result = None
async def get_batch(q):
nonlocal result
result = await q.get_batch()
# But if someone else is waiting to read, then they get dibs
async with _core.open_nursery() as nursery:
nursery.start_soon(get_batch, q)
await _core.wait_all_tasks_blocked()
q.put_nowait(3)
q.put_nowait(4)
with pytest.raises(_core.WouldBlock):
q.get_batch_nowait()
assert result == [3, 4]
# If two tasks are trying to read, they alternate
record = []
async def reader(name):
while True:
record.append((name, await q.get_batch()))
async with _core.open_nursery() as nursery:
nursery.start_soon(reader, "a")
await _core.wait_all_tasks_blocked()
nursery.start_soon(reader, "b")
await _core.wait_all_tasks_blocked()
for i in range(20):
q.put_nowait(i)
await _core.wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert record == list(zip(itertools.cycle("ab"), [[i] for i in range(20)]))
async def test_UnboundedQueue_trivial_yields():
q = _core.UnboundedQueue()
q.put_nowait(None)
with assert_checkpoints():
await q.get_batch()
q.put_nowait(None)
with assert_checkpoints():
async for _ in q: # noqa # pragma: no branch
break
async def test_UnboundedQueue_no_spurious_wakeups():
# If we have two tasks waiting, and put two items into the queue... then
# only one task wakes up
record = []
async def getter(q, i):
got = await q.get_batch()
record.append((i, got))
async with _core.open_nursery() as nursery:
q = _core.UnboundedQueue()
nursery.start_soon(getter, q, 1)
await wait_all_tasks_blocked()
nursery.start_soon(getter, q, 2)
await wait_all_tasks_blocked()
for i in range(10):
q.put_nowait(i)
await wait_all_tasks_blocked()
assert record == [(1, list(range(10)))]
nursery.cancel_scope.cancel()

View File

@@ -0,0 +1 @@
import pytest

View File

@@ -0,0 +1,218 @@
import os
import tempfile
from contextlib import contextmanager
import pytest
on_windows = os.name == "nt"
# Mark all the tests in this file as being windows-only
pytestmark = pytest.mark.skipif(not on_windows, reason="windows only")
from .tutil import slow, gc_collect_harder, restore_unraisablehook
from ... import _core, sleep, move_on_after
from ...testing import wait_all_tasks_blocked
if on_windows:
from .._windows_cffi import (
ffi,
kernel32,
INVALID_HANDLE_VALUE,
raise_winerror,
FileFlags,
)
# The undocumented API that this is testing should be changed to stop using
# UnboundedQueue (or just removed until we have time to redo it), but until
# then we filter out the warning.
@pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning")
async def test_completion_key_listen():
async def post(key):
iocp = ffi.cast("HANDLE", _core.current_iocp())
for i in range(10):
print("post", i)
if i % 3 == 0:
await _core.checkpoint()
success = kernel32.PostQueuedCompletionStatus(iocp, i, key, ffi.NULL)
assert success
with _core.monitor_completion_key() as (key, queue):
async with _core.open_nursery() as nursery:
nursery.start_soon(post, key)
i = 0
print("loop")
async for batch in queue: # pragma: no branch
print("got some", batch)
for info in batch:
assert info.lpOverlapped == 0
assert info.dwNumberOfBytesTransferred == i
i += 1
if i == 10:
break
print("end loop")
async def test_readinto_overlapped():
data = b"1" * 1024 + b"2" * 1024 + b"3" * 1024 + b"4" * 1024
buffer = bytearray(len(data))
with tempfile.TemporaryDirectory() as tdir:
tfile = os.path.join(tdir, "numbers.txt")
with open(tfile, "wb") as fp:
fp.write(data)
fp.flush()
rawname = tfile.encode("utf-16le") + b"\0\0"
rawname_buf = ffi.from_buffer(rawname)
handle = kernel32.CreateFileW(
ffi.cast("LPCWSTR", rawname_buf),
FileFlags.GENERIC_READ,
FileFlags.FILE_SHARE_READ,
ffi.NULL, # no security attributes
FileFlags.OPEN_EXISTING,
FileFlags.FILE_FLAG_OVERLAPPED,
ffi.NULL, # no template file
)
if handle == INVALID_HANDLE_VALUE: # pragma: no cover
raise_winerror()
try:
with memoryview(buffer) as buffer_view:
async def read_region(start, end):
await _core.readinto_overlapped(
handle, buffer_view[start:end], start
)
_core.register_with_iocp(handle)
async with _core.open_nursery() as nursery:
for start in range(0, 4096, 512):
nursery.start_soon(read_region, start, start + 512)
assert buffer == data
with pytest.raises(BufferError):
await _core.readinto_overlapped(handle, b"immutable")
finally:
kernel32.CloseHandle(handle)
@contextmanager
def pipe_with_overlapped_read():
from asyncio.windows_utils import pipe
import msvcrt
read_handle, write_handle = pipe(overlapped=(True, False))
try:
write_fd = msvcrt.open_osfhandle(write_handle, 0)
yield os.fdopen(write_fd, "wb", closefd=False), read_handle
finally:
kernel32.CloseHandle(ffi.cast("HANDLE", read_handle))
kernel32.CloseHandle(ffi.cast("HANDLE", write_handle))
@restore_unraisablehook()
def test_forgot_to_register_with_iocp():
with pipe_with_overlapped_read() as (write_fp, read_handle):
with write_fp:
write_fp.write(b"test\n")
left_run_yet = False
async def main():
target = bytearray(1)
try:
async with _core.open_nursery() as nursery:
nursery.start_soon(
_core.readinto_overlapped, read_handle, target, name="xyz"
)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
finally:
# Run loop is exited without unwinding running tasks, so
# we don't get here until the main() coroutine is GC'ed
assert left_run_yet
with pytest.raises(_core.TrioInternalError) as exc_info:
_core.run(main)
left_run_yet = True
assert "Failed to cancel overlapped I/O in xyz " in str(exc_info.value)
assert "forget to call register_with_iocp()?" in str(exc_info.value)
# Make sure the Nursery.__del__ assertion about dangling children
# gets put with the correct test
del exc_info
gc_collect_harder()
@slow
async def test_too_late_to_cancel():
import time
with pipe_with_overlapped_read() as (write_fp, read_handle):
_core.register_with_iocp(read_handle)
target = bytearray(6)
async with _core.open_nursery() as nursery:
# Start an async read in the background
nursery.start_soon(_core.readinto_overlapped, read_handle, target)
await wait_all_tasks_blocked()
# Synchronous write to the other end of the pipe
with write_fp:
write_fp.write(b"test1\ntest2\n")
# Note: not trio.sleep! We're making sure the OS level
# ReadFile completes, before Trio has a chance to execute
# another checkpoint and notice it completed.
time.sleep(1)
nursery.cancel_scope.cancel()
assert target[:6] == b"test1\n"
# Do another I/O to make sure we've actually processed the
# fallback completion that was posted when CancelIoEx failed.
assert await _core.readinto_overlapped(read_handle, target) == 6
assert target[:6] == b"test2\n"
def test_lsp_that_hooks_select_gives_good_error(monkeypatch):
from .._windows_cffi import WSAIoctls, _handle
from .. import _io_windows
def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE):
if hasattr(sock, "fileno"): # pragma: no branch
sock = sock.fileno()
if which == WSAIoctls.SIO_BSP_HANDLE_SELECT:
return _handle(sock + 1)
else:
return _handle(sock)
monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying)
with pytest.raises(
RuntimeError, match="SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ"
):
_core.run(sleep, 0)
def test_lsp_that_completely_hides_base_socket_gives_good_error(monkeypatch):
# This tests behavior with an LSP that fails SIO_BASE_HANDLE and returns
# self for SIO_BSP_HANDLE_SELECT (like Komodia), but also returns
# self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to
# make sure we get an error rather than an infinite loop.
from .._windows_cffi import WSAIoctls, _handle
from .. import _io_windows
def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE):
if hasattr(sock, "fileno"): # pragma: no branch
sock = sock.fileno()
if which == WSAIoctls.SIO_BASE_HANDLE:
raise OSError("nope")
else:
return _handle(sock)
monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying)
with pytest.raises(
RuntimeError,
match="SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't return a diff",
):
_core.run(sleep, 0)

View File

@@ -0,0 +1,146 @@
# Utilities for testing
import asyncio
import socket as stdlib_socket
import threading
import os
import sys
from typing import TYPE_CHECKING
import pytest
import warnings
from contextlib import contextmanager, closing
import gc
# See trio/tests/conftest.py for the other half of this
from trio.tests.conftest import RUN_SLOW
slow = pytest.mark.skipif(not RUN_SLOW, reason="use --run-slow to run slow tests")
# PyPy 7.2 was released with a bug that just never called the async
# generator 'firstiter' hook at all. This impacts tests of end-of-run
# finalization (nothing gets added to runner.asyncgens) and tests of
# "foreign" async generator behavior (since the firstiter hook is what
# marks the asyncgen as foreign), but most tests of GC-mediated
# finalization still work.
buggy_pypy_asyncgens = (
not TYPE_CHECKING
and sys.implementation.name == "pypy"
and sys.pypy_version_info < (7, 3)
)
try:
s = stdlib_socket.socket(stdlib_socket.AF_INET6, stdlib_socket.SOCK_STREAM, 0)
except OSError: # pragma: no cover
# Some systems don't even support creating an IPv6 socket, let alone
# binding it. (ex: Linux with 'ipv6.disable=1' in the kernel command line)
# We don't have any of those in our CI, and there's nothing that gets
# tested _only_ if can_create_ipv6 = False, so we'll just no-cover this.
can_create_ipv6 = False
can_bind_ipv6 = False
else:
can_create_ipv6 = True
with s:
try:
s.bind(("::1", 0))
except OSError:
can_bind_ipv6 = False
else:
can_bind_ipv6 = True
creates_ipv6 = pytest.mark.skipif(not can_create_ipv6, reason="need IPv6")
binds_ipv6 = pytest.mark.skipif(not can_bind_ipv6, reason="need IPv6")
def gc_collect_harder():
# In the test suite we sometimes want to call gc.collect() to make sure
# that any objects with noisy __del__ methods (e.g. unawaited coroutines)
# get collected before we continue, so their noise doesn't leak into
# unrelated tests.
#
# On PyPy, coroutine objects (for example) can survive at least 1 round of
# garbage collection, because executing their __del__ method to print the
# warning can cause them to be resurrected. So we call collect a few times
# to make sure.
for _ in range(4):
gc.collect()
# Some of our tests need to leak coroutines, and thus trigger the
# "RuntimeWarning: coroutine '...' was never awaited" message. This context
# manager should be used anywhere this happens to hide those messages, because
# when expected they're clutter.
@contextmanager
def ignore_coroutine_never_awaited_warnings():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="coroutine '.*' was never awaited")
try:
yield
finally:
# Make sure to trigger any coroutine __del__ methods now, before
# we leave the context manager.
gc_collect_harder()
def _noop(*args, **kwargs):
pass
if sys.version_info >= (3, 8):
@contextmanager
def restore_unraisablehook():
sys.unraisablehook, prev = sys.__unraisablehook__, sys.unraisablehook
try:
yield
finally:
sys.unraisablehook = prev
@contextmanager
def disable_threading_excepthook():
if sys.version_info >= (3, 10):
threading.excepthook, prev = threading.__excepthook__, threading.excepthook
else:
threading.excepthook, prev = _noop, threading.excepthook
try:
yield
finally:
threading.excepthook = prev
else:
@contextmanager
def restore_unraisablehook(): # pragma: no cover
yield
@contextmanager
def disable_threading_excepthook(): # pragma: no cover
yield
# template is like:
# [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 2.2, 3] or [1, 2.2, 2.1, 3]
def check_sequence_matches(seq, template):
i = 0
for pattern in template:
if not isinstance(pattern, set):
pattern = {pattern}
got = set(seq[i : i + len(pattern)])
assert got == pattern
i += len(got)
# https://bugs.freebsd.org/bugzilla/show_bug.cgi?id=246350
skip_if_fbsd_pipes_broken = pytest.mark.skipif(
sys.platform != "win32" # prevent mypy from complaining about missing uname
and hasattr(os, "uname")
and os.uname().sysname == "FreeBSD"
and os.uname().release[:4] < "12.2",
reason="hangs on FreeBSD 12.1 and earlier, due to FreeBSD bug #246350",
)
def create_asyncio_future_in_new_loop():
with closing(asyncio.new_event_loop()) as loop:
return loop.create_future()