second commit

This commit is contained in:
Александр Геннадьевич Сальный
2022-10-15 21:01:12 +03:00
commit 7caeeaaff5
1329 changed files with 489315 additions and 0 deletions

View File

@@ -0,0 +1,132 @@
"""Trio - A friendly Python library for async concurrency and I/O
"""
# General layout:
#
# trio/_core/... is the self-contained core library. It does various
# shenanigans to export a consistent "core API", but parts of the core API are
# too low-level to be recommended for regular use.
#
# trio/*.py define a set of more usable tools on top of this. They import from
# trio._core and from each other.
#
# This file pulls together the friendly public API, by re-exporting the more
# innocuous bits of the _core API + the higher-level tools from trio/*.py.
from ._version import __version__
from ._core import (
TrioInternalError,
RunFinishedError,
WouldBlock,
Cancelled,
BusyResourceError,
ClosedResourceError,
MultiError,
run,
open_nursery,
CancelScope,
current_effective_deadline,
TASK_STATUS_IGNORED,
current_time,
BrokenResourceError,
EndOfChannel,
Nursery,
)
from ._timeouts import (
move_on_at,
move_on_after,
sleep_forever,
sleep_until,
sleep,
fail_at,
fail_after,
TooSlowError,
)
from ._sync import (
Event,
CapacityLimiter,
Semaphore,
Lock,
StrictFIFOLock,
Condition,
)
from ._highlevel_generic import aclose_forcefully, StapledStream
from ._channel import (
open_memory_channel,
MemorySendChannel,
MemoryReceiveChannel,
)
from ._signals import open_signal_receiver
from ._highlevel_socket import SocketStream, SocketListener
from ._file_io import open_file, wrap_file
from ._path import Path
from ._subprocess import Process, run_process
from ._ssl import SSLStream, SSLListener, NeedHandshakeError
from ._highlevel_serve_listeners import serve_listeners
from ._highlevel_open_tcp_stream import open_tcp_stream
from ._highlevel_open_tcp_listeners import open_tcp_listeners, serve_tcp
from ._highlevel_open_unix_stream import open_unix_socket
from ._highlevel_ssl_helpers import (
open_ssl_over_tcp_stream,
open_ssl_over_tcp_listeners,
serve_ssl_over_tcp,
)
from ._deprecate import TrioDeprecationWarning
# Submodules imported by default
from . import lowlevel
from . import socket
from . import abc
from . import from_thread
from . import to_thread
# Not imported by default, but mentioned here so static analysis tools like
# pylint will know that it exists.
if False:
from . import testing
from . import _deprecate
_deprecate.enable_attribute_deprecations(__name__)
__deprecated_attributes__ = {
"open_process": _deprecate.DeprecatedAttribute(
value=lowlevel.open_process,
version="0.20.0",
issue=1104,
instead="trio.lowlevel.open_process",
),
}
# Having the public path in .__module__ attributes is important for:
# - exception names in printed tracebacks
# - sphinx :show-inheritance:
# - deprecation warnings
# - pickle
# - probably other stuff
from ._util import fixup_module_metadata
fixup_module_metadata(__name__, globals())
fixup_module_metadata(lowlevel.__name__, lowlevel.__dict__)
fixup_module_metadata(socket.__name__, socket.__dict__)
fixup_module_metadata(abc.__name__, abc.__dict__)
fixup_module_metadata(from_thread.__name__, from_thread.__dict__)
fixup_module_metadata(to_thread.__name__, to_thread.__dict__)
del fixup_module_metadata

View File

@@ -0,0 +1,652 @@
# coding: utf-8
from abc import ABCMeta, abstractmethod
from typing import Generic, TypeVar
import trio
# We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a
# __dict__ onto subclasses.
class Clock(metaclass=ABCMeta):
"""The interface for custom run loop clocks."""
__slots__ = ()
@abstractmethod
def start_clock(self):
"""Do any setup this clock might need.
Called at the beginning of the run.
"""
@abstractmethod
def current_time(self):
"""Return the current time, according to this clock.
This is used to implement functions like :func:`trio.current_time` and
:func:`trio.move_on_after`.
Returns:
float: The current time.
"""
@abstractmethod
def deadline_to_sleep_time(self, deadline):
"""Compute the real time until the given deadline.
This is called before we enter a system-specific wait function like
:func:`select.select`, to get the timeout to pass.
For a clock using wall-time, this should be something like::
return deadline - self.current_time()
but of course it may be different if you're implementing some kind of
virtual clock.
Args:
deadline (float): The absolute time of the next deadline,
according to this clock.
Returns:
float: The number of real seconds to sleep until the given
deadline. May be :data:`math.inf`.
"""
class Instrument(metaclass=ABCMeta):
"""The interface for run loop instrumentation.
Instruments don't have to inherit from this abstract base class, and all
of these methods are optional. This class serves mostly as documentation.
"""
__slots__ = ()
def before_run(self):
"""Called at the beginning of :func:`trio.run`."""
def after_run(self):
"""Called just before :func:`trio.run` returns."""
def task_spawned(self, task):
"""Called when the given task is created.
Args:
task (trio.lowlevel.Task): The new task.
"""
def task_scheduled(self, task):
"""Called when the given task becomes runnable.
It may still be some time before it actually runs, if there are other
runnable tasks ahead of it.
Args:
task (trio.lowlevel.Task): The task that became runnable.
"""
def before_task_step(self, task):
"""Called immediately before we resume running the given task.
Args:
task (trio.lowlevel.Task): The task that is about to run.
"""
def after_task_step(self, task):
"""Called when we return to the main run loop after a task has yielded.
Args:
task (trio.lowlevel.Task): The task that just ran.
"""
def task_exited(self, task):
"""Called when the given task exits.
Args:
task (trio.lowlevel.Task): The finished task.
"""
def before_io_wait(self, timeout):
"""Called before blocking to wait for I/O readiness.
Args:
timeout (float): The number of seconds we are willing to wait.
"""
def after_io_wait(self, timeout):
"""Called after handling pending I/O.
Args:
timeout (float): The number of seconds we were willing to
wait. This much time may or may not have elapsed, depending on
whether any I/O was ready.
"""
class HostnameResolver(metaclass=ABCMeta):
"""If you have a custom hostname resolver, then implementing
:class:`HostnameResolver` allows you to register this to be used by Trio.
See :func:`trio.socket.set_custom_hostname_resolver`.
"""
__slots__ = ()
@abstractmethod
async def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0):
"""A custom implementation of :func:`~trio.socket.getaddrinfo`.
Called by :func:`trio.socket.getaddrinfo`.
If ``host`` is given as a numeric IP address, then
:func:`~trio.socket.getaddrinfo` may handle the request itself rather
than calling this method.
Any required IDNA encoding is handled before calling this function;
your implementation can assume that it will never see U-labels like
``"café.com"``, and only needs to handle A-labels like
``b"xn--caf-dma.com"``.
"""
@abstractmethod
async def getnameinfo(self, sockaddr, flags):
"""A custom implementation of :func:`~trio.socket.getnameinfo`.
Called by :func:`trio.socket.getnameinfo`.
"""
class SocketFactory(metaclass=ABCMeta):
"""If you write a custom class implementing the Trio socket interface,
then you can use a :class:`SocketFactory` to get Trio to use it.
See :func:`trio.socket.set_custom_socket_factory`.
"""
@abstractmethod
def socket(self, family=None, type=None, proto=None):
"""Create and return a socket object.
Your socket object must inherit from :class:`trio.socket.SocketType`,
which is an empty class whose only purpose is to "mark" which classes
should be considered valid Trio sockets.
Called by :func:`trio.socket.socket`.
Note that unlike :func:`trio.socket.socket`, this does not take a
``fileno=`` argument. If a ``fileno=`` is specified, then
:func:`trio.socket.socket` returns a regular Trio socket object
instead of calling this method.
"""
class AsyncResource(metaclass=ABCMeta):
"""A standard interface for resources that needs to be cleaned up, and
where that cleanup may require blocking operations.
This class distinguishes between "graceful" closes, which may perform I/O
and thus block, and a "forceful" close, which cannot. For example, cleanly
shutting down a TLS-encrypted connection requires sending a "goodbye"
message; but if a peer has become non-responsive, then sending this
message might block forever, so we may want to just drop the connection
instead. Therefore the :meth:`aclose` method is unusual in that it
should always close the connection (or at least make its best attempt)
*even if it fails*; failure indicates a failure to achieve grace, not a
failure to close the connection.
Objects that implement this interface can be used as async context
managers, i.e., you can write::
async with create_resource() as some_async_resource:
...
Entering the context manager is synchronous (not a checkpoint); exiting it
calls :meth:`aclose`. The default implementations of
``__aenter__`` and ``__aexit__`` should be adequate for all subclasses.
"""
__slots__ = ()
@abstractmethod
async def aclose(self):
"""Close this resource, possibly blocking.
IMPORTANT: This method may block in order to perform a "graceful"
shutdown. But, if this fails, then it still *must* close any
underlying resources before returning. An error from this method
indicates a failure to achieve grace, *not* a failure to close the
connection.
For example, suppose we call :meth:`aclose` on a TLS-encrypted
connection. This requires sending a "goodbye" message; but if the peer
has become non-responsive, then our attempt to send this message might
block forever, and eventually time out and be cancelled. In this case
the :meth:`aclose` method on :class:`~trio.SSLStream` will
immediately close the underlying transport stream using
:func:`trio.aclose_forcefully` before raising :exc:`~trio.Cancelled`.
If the resource is already closed, then this method should silently
succeed.
Once this method completes, any other pending or future operations on
this resource should generally raise :exc:`~trio.ClosedResourceError`,
unless there's a good reason to do otherwise.
See also: :func:`trio.aclose_forcefully`.
"""
async def __aenter__(self):
return self
async def __aexit__(self, *args):
await self.aclose()
class SendStream(AsyncResource):
"""A standard interface for sending data on a byte stream.
The underlying stream may be unidirectional, or bidirectional. If it's
bidirectional, then you probably want to also implement
:class:`ReceiveStream`, which makes your object a :class:`Stream`.
:class:`SendStream` objects also implement the :class:`AsyncResource`
interface, so they can be closed by calling :meth:`~AsyncResource.aclose`
or using an ``async with`` block.
If you want to send Python objects rather than raw bytes, see
:class:`SendChannel`.
"""
__slots__ = ()
@abstractmethod
async def send_all(self, data):
"""Sends the given data through the stream, blocking if necessary.
Args:
data (bytes, bytearray, or memoryview): The data to send.
Raises:
trio.BusyResourceError: if another task is already executing a
:meth:`send_all`, :meth:`wait_send_all_might_not_block`, or
:meth:`HalfCloseableStream.send_eof` on this stream.
trio.BrokenResourceError: if something has gone wrong, and the stream
is broken.
trio.ClosedResourceError: if you previously closed this stream
object, or if another task closes this stream object while
:meth:`send_all` is running.
Most low-level operations in Trio provide a guarantee: if they raise
:exc:`trio.Cancelled`, this means that they had no effect, so the
system remains in a known state. This is **not true** for
:meth:`send_all`. If this operation raises :exc:`trio.Cancelled` (or
any other exception for that matter), then it may have sent some, all,
or none of the requested data, and there is no way to know which.
"""
@abstractmethod
async def wait_send_all_might_not_block(self):
"""Block until it's possible that :meth:`send_all` might not block.
This method may return early: it's possible that after it returns,
:meth:`send_all` will still block. (In the worst case, if no better
implementation is available, then it might always return immediately
without blocking. It's nice to do better than that when possible,
though.)
This method **must not** return *late*: if it's possible for
:meth:`send_all` to complete without blocking, then it must
return. When implementing it, err on the side of returning early.
Raises:
trio.BusyResourceError: if another task is already executing a
:meth:`send_all`, :meth:`wait_send_all_might_not_block`, or
:meth:`HalfCloseableStream.send_eof` on this stream.
trio.BrokenResourceError: if something has gone wrong, and the stream
is broken.
trio.ClosedResourceError: if you previously closed this stream
object, or if another task closes this stream object while
:meth:`wait_send_all_might_not_block` is running.
Note:
This method is intended to aid in implementing protocols that want
to delay choosing which data to send until the last moment. E.g.,
suppose you're working on an implementation of a remote display server
like `VNC
<https://en.wikipedia.org/wiki/Virtual_Network_Computing>`__, and
the network connection is currently backed up so that if you call
:meth:`send_all` now then it will sit for 0.5 seconds before actually
sending anything. In this case it doesn't make sense to take a
screenshot, then wait 0.5 seconds, and then send it, because the
screen will keep changing while you wait; it's better to wait 0.5
seconds, then take the screenshot, and then send it, because this
way the data you deliver will be more
up-to-date. Using :meth:`wait_send_all_might_not_block` makes it
possible to implement the better strategy.
If you use this method, you might also want to read up on
``TCP_NOTSENT_LOWAT``.
Further reading:
* `Prioritization Only Works When There's Pending Data to Prioritize
<https://insouciant.org/tech/prioritization-only-works-when-theres-pending-data-to-prioritize/>`__
* WWDC 2015: Your App and Next Generation Networks: `slides
<http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1>`__,
`video and transcript
<https://developer.apple.com/videos/play/wwdc2015/719/>`__
"""
class ReceiveStream(AsyncResource):
"""A standard interface for receiving data on a byte stream.
The underlying stream may be unidirectional, or bidirectional. If it's
bidirectional, then you probably want to also implement
:class:`SendStream`, which makes your object a :class:`Stream`.
:class:`ReceiveStream` objects also implement the :class:`AsyncResource`
interface, so they can be closed by calling :meth:`~AsyncResource.aclose`
or using an ``async with`` block.
If you want to receive Python objects rather than raw bytes, see
:class:`ReceiveChannel`.
`ReceiveStream` objects can be used in ``async for`` loops. Each iteration
will produce an arbitrary sized chunk of bytes, like calling
`receive_some` with no arguments. Every chunk will contain at least one
byte, and the loop automatically exits when reaching end-of-file.
"""
__slots__ = ()
@abstractmethod
async def receive_some(self, max_bytes=None):
"""Wait until there is data available on this stream, and then return
some of it.
A return value of ``b""`` (an empty bytestring) indicates that the
stream has reached end-of-file. Implementations should be careful that
they return ``b""`` if, and only if, the stream has reached
end-of-file!
Args:
max_bytes (int): The maximum number of bytes to return. Must be
greater than zero. Optional; if omitted, then the stream object
is free to pick a reasonable default.
Returns:
bytes or bytearray: The data received.
Raises:
trio.BusyResourceError: if two tasks attempt to call
:meth:`receive_some` on the same stream at the same time.
trio.BrokenResourceError: if something has gone wrong, and the stream
is broken.
trio.ClosedResourceError: if you previously closed this stream
object, or if another task closes this stream object while
:meth:`receive_some` is running.
"""
def __aiter__(self):
return self
async def __anext__(self):
data = await self.receive_some()
if not data:
raise StopAsyncIteration
return data
class Stream(SendStream, ReceiveStream):
"""A standard interface for interacting with bidirectional byte streams.
A :class:`Stream` is an object that implements both the
:class:`SendStream` and :class:`ReceiveStream` interfaces.
If implementing this interface, you should consider whether you can go one
step further and implement :class:`HalfCloseableStream`.
"""
__slots__ = ()
class HalfCloseableStream(Stream):
"""This interface extends :class:`Stream` to also allow closing the send
part of the stream without closing the receive part.
"""
__slots__ = ()
@abstractmethod
async def send_eof(self):
"""Send an end-of-file indication on this stream, if possible.
The difference between :meth:`send_eof` and
:meth:`~AsyncResource.aclose` is that :meth:`send_eof` is a
*unidirectional* end-of-file indication. After you call this method,
you shouldn't try sending any more data on this stream, and your
remote peer should receive an end-of-file indication (eventually,
after receiving all the data you sent before that). But, they may
continue to send data to you, and you can continue to receive it by
calling :meth:`~ReceiveStream.receive_some`. You can think of it as
calling :meth:`~AsyncResource.aclose` on just the
:class:`SendStream` "half" of the stream object (and in fact that's
literally how :class:`trio.StapledStream` implements it).
Examples:
* On a socket, this corresponds to ``shutdown(..., SHUT_WR)`` (`man
page <https://linux.die.net/man/2/shutdown>`__).
* The SSH protocol provides the ability to multiplex bidirectional
"channels" on top of a single encrypted connection. A Trio
implementation of SSH could expose these channels as
:class:`HalfCloseableStream` objects, and calling :meth:`send_eof`
would send an ``SSH_MSG_CHANNEL_EOF`` request (see `RFC 4254 §5.3
<https://tools.ietf.org/html/rfc4254#section-5.3>`__).
* On an SSL/TLS-encrypted connection, the protocol doesn't provide any
way to do a unidirectional shutdown without closing the connection
entirely, so :class:`~trio.SSLStream` implements
:class:`Stream`, not :class:`HalfCloseableStream`.
If an EOF has already been sent, then this method should silently
succeed.
Raises:
trio.BusyResourceError: if another task is already executing a
:meth:`~SendStream.send_all`,
:meth:`~SendStream.wait_send_all_might_not_block`, or
:meth:`send_eof` on this stream.
trio.BrokenResourceError: if something has gone wrong, and the stream
is broken.
trio.ClosedResourceError: if you previously closed this stream
object, or if another task closes this stream object while
:meth:`send_eof` is running.
"""
# A regular invariant generic type
T = TypeVar("T")
# The type of object produced by a ReceiveChannel (covariant because
# ReceiveChannel[Derived] can be passed to someone expecting
# ReceiveChannel[Base])
ReceiveType = TypeVar("ReceiveType", covariant=True)
# The type of object accepted by a SendChannel (contravariant because
# SendChannel[Base] can be passed to someone expecting
# SendChannel[Derived])
SendType = TypeVar("SendType", contravariant=True)
# The type of object produced by a Listener (covariant plus must be
# an AsyncResource)
T_resource = TypeVar("T_resource", bound=AsyncResource, covariant=True)
class Listener(AsyncResource, Generic[T_resource]):
"""A standard interface for listening for incoming connections.
:class:`Listener` objects also implement the :class:`AsyncResource`
interface, so they can be closed by calling :meth:`~AsyncResource.aclose`
or using an ``async with`` block.
"""
__slots__ = ()
@abstractmethod
async def accept(self):
"""Wait until an incoming connection arrives, and then return it.
Returns:
AsyncResource: An object representing the incoming connection. In
practice this is generally some kind of :class:`Stream`,
but in principle you could also define a :class:`Listener` that
returned, say, channel objects.
Raises:
trio.BusyResourceError: if two tasks attempt to call
:meth:`accept` on the same listener at the same time.
trio.ClosedResourceError: if you previously closed this listener
object, or if another task closes this listener object while
:meth:`accept` is running.
Listeners don't generally raise :exc:`~trio.BrokenResourceError`,
because for listeners there is no general condition of "the
network/remote peer broke the connection" that can be handled in a
generic way, like there is for streams. Other errors *can* occur and
be raised from :meth:`accept` for example, if you run out of file
descriptors then you might get an :class:`OSError` with its errno set
to ``EMFILE``.
"""
class SendChannel(AsyncResource, Generic[SendType]):
"""A standard interface for sending Python objects to some receiver.
`SendChannel` objects also implement the `AsyncResource` interface, so
they can be closed by calling `~AsyncResource.aclose` or using an ``async
with`` block.
If you want to send raw bytes rather than Python objects, see
`SendStream`.
"""
__slots__ = ()
@abstractmethod
async def send(self, value: SendType) -> None:
"""Attempt to send an object through the channel, blocking if necessary.
Args:
value (object): The object to send.
Raises:
trio.BrokenResourceError: if something has gone wrong, and the
channel is broken. For example, you may get this if the receiver
has already been closed.
trio.ClosedResourceError: if you previously closed this
:class:`SendChannel` object, or if another task closes it while
:meth:`send` is running.
trio.BusyResourceError: some channels allow multiple tasks to call
`send` at the same time, but others don't. If you try to call
`send` simultaneously from multiple tasks on a channel that
doesn't support it, then you can get `~trio.BusyResourceError`.
"""
class ReceiveChannel(AsyncResource, Generic[ReceiveType]):
"""A standard interface for receiving Python objects from some sender.
You can iterate over a :class:`ReceiveChannel` using an ``async for``
loop::
async for value in receive_channel:
...
This is equivalent to calling :meth:`receive` repeatedly. The loop exits
without error when `receive` raises `~trio.EndOfChannel`.
`ReceiveChannel` objects also implement the `AsyncResource` interface, so
they can be closed by calling `~AsyncResource.aclose` or using an ``async
with`` block.
If you want to receive raw bytes rather than Python objects, see
`ReceiveStream`.
"""
__slots__ = ()
@abstractmethod
async def receive(self) -> ReceiveType:
"""Attempt to receive an incoming object, blocking if necessary.
Returns:
object: Whatever object was received.
Raises:
trio.EndOfChannel: if the sender has been closed cleanly, and no
more objects are coming. This is not an error condition.
trio.ClosedResourceError: if you previously closed this
:class:`ReceiveChannel` object.
trio.BrokenResourceError: if something has gone wrong, and the
channel is broken.
trio.BusyResourceError: some channels allow multiple tasks to call
`receive` at the same time, but others don't. If you try to call
`receive` simultaneously from multiple tasks on a channel that
doesn't support it, then you can get `~trio.BusyResourceError`.
"""
def __aiter__(self):
return self
async def __anext__(self) -> ReceiveType:
try:
return await self.receive()
except trio.EndOfChannel:
raise StopAsyncIteration
class Channel(SendChannel[T], ReceiveChannel[T]):
"""A standard interface for interacting with bidirectional channels.
A `Channel` is an object that implements both the `SendChannel` and
`ReceiveChannel` interfaces, so you can both send and receive objects.
"""

View File

@@ -0,0 +1,386 @@
from collections import deque, OrderedDict
from math import inf
import attr
from outcome import Error, Value
from .abc import SendChannel, ReceiveChannel, Channel
from ._util import generic_function, NoPublicConstructor
import trio
from ._core import enable_ki_protection
@generic_function
def open_memory_channel(max_buffer_size):
"""Open a channel for passing objects between tasks within a process.
Memory channels are lightweight, cheap to allocate, and entirely
in-memory. They don't involve any operating-system resources, or any kind
of serialization. They just pass Python objects directly between tasks
(with a possible stop in an internal buffer along the way).
Channel objects can be closed by calling `~trio.abc.AsyncResource.aclose`
or using ``async with``. They are *not* automatically closed when garbage
collected. Closing memory channels isn't mandatory, but it is generally a
good idea, because it helps avoid situations where tasks get stuck waiting
on a channel when there's no-one on the other side. See
:ref:`channel-shutdown` for details.
Memory channel operations are all atomic with respect to
cancellation, either `~trio.abc.ReceiveChannel.receive` will
successfully return an object, or it will raise :exc:`Cancelled`
while leaving the channel unchanged.
Args:
max_buffer_size (int or math.inf): The maximum number of items that can
be buffered in the channel before :meth:`~trio.abc.SendChannel.send`
blocks. Choosing a sensible value here is important to ensure that
backpressure is communicated promptly and avoid unnecessary latency;
see :ref:`channel-buffering` for more details. If in doubt, use 0.
Returns:
A pair ``(send_channel, receive_channel)``. If you have
trouble remembering which order these go in, remember: data
flows from left → right.
In addition to the standard channel methods, all memory channel objects
provide a ``statistics()`` method, which returns an object with the
following fields:
* ``current_buffer_used``: The number of items currently stored in the
channel buffer.
* ``max_buffer_size``: The maximum number of items allowed in the buffer,
as passed to :func:`open_memory_channel`.
* ``open_send_channels``: The number of open
:class:`MemorySendChannel` endpoints pointing to this channel.
Initially 1, but can be increased by
:meth:`MemorySendChannel.clone`.
* ``open_receive_channels``: Likewise, but for open
:class:`MemoryReceiveChannel` endpoints.
* ``tasks_waiting_send``: The number of tasks blocked in ``send`` on this
channel (summing over all clones).
* ``tasks_waiting_receive``: The number of tasks blocked in ``receive`` on
this channel (summing over all clones).
"""
if max_buffer_size != inf and not isinstance(max_buffer_size, int):
raise TypeError("max_buffer_size must be an integer or math.inf")
if max_buffer_size < 0:
raise ValueError("max_buffer_size must be >= 0")
state = MemoryChannelState(max_buffer_size)
return (
MemorySendChannel._create(state),
MemoryReceiveChannel._create(state),
)
@attr.s(frozen=True, slots=True)
class MemoryChannelStats:
current_buffer_used = attr.ib()
max_buffer_size = attr.ib()
open_send_channels = attr.ib()
open_receive_channels = attr.ib()
tasks_waiting_send = attr.ib()
tasks_waiting_receive = attr.ib()
@attr.s(slots=True)
class MemoryChannelState:
max_buffer_size = attr.ib()
data = attr.ib(factory=deque)
# Counts of open endpoints using this state
open_send_channels = attr.ib(default=0)
open_receive_channels = attr.ib(default=0)
# {task: value}
send_tasks = attr.ib(factory=OrderedDict)
# {task: None}
receive_tasks = attr.ib(factory=OrderedDict)
def statistics(self):
return MemoryChannelStats(
current_buffer_used=len(self.data),
max_buffer_size=self.max_buffer_size,
open_send_channels=self.open_send_channels,
open_receive_channels=self.open_receive_channels,
tasks_waiting_send=len(self.send_tasks),
tasks_waiting_receive=len(self.receive_tasks),
)
@attr.s(eq=False, repr=False)
class MemorySendChannel(SendChannel, metaclass=NoPublicConstructor):
_state = attr.ib()
_closed = attr.ib(default=False)
# This is just the tasks waiting on *this* object. As compared to
# self._state.send_tasks, which includes tasks from this object and
# all clones.
_tasks = attr.ib(factory=set)
def __attrs_post_init__(self):
self._state.open_send_channels += 1
def __repr__(self):
return "<send channel at {:#x}, using buffer at {:#x}>".format(
id(self), id(self._state)
)
def statistics(self):
# XX should we also report statistics specific to this object?
return self._state.statistics()
@enable_ki_protection
def send_nowait(self, value):
"""Like `~trio.abc.SendChannel.send`, but if the channel's buffer is
full, raises `WouldBlock` instead of blocking.
"""
if self._closed:
raise trio.ClosedResourceError
if self._state.open_receive_channels == 0:
raise trio.BrokenResourceError
if self._state.receive_tasks:
assert not self._state.data
task, _ = self._state.receive_tasks.popitem(last=False)
task.custom_sleep_data._tasks.remove(task)
trio.lowlevel.reschedule(task, Value(value))
elif len(self._state.data) < self._state.max_buffer_size:
self._state.data.append(value)
else:
raise trio.WouldBlock
@enable_ki_protection
async def send(self, value):
"""See `SendChannel.send <trio.abc.SendChannel.send>`.
Memory channels allow multiple tasks to call `send` at the same time.
"""
await trio.lowlevel.checkpoint_if_cancelled()
try:
self.send_nowait(value)
except trio.WouldBlock:
pass
else:
await trio.lowlevel.cancel_shielded_checkpoint()
return
task = trio.lowlevel.current_task()
self._tasks.add(task)
self._state.send_tasks[task] = value
task.custom_sleep_data = self
def abort_fn(_):
self._tasks.remove(task)
del self._state.send_tasks[task]
return trio.lowlevel.Abort.SUCCEEDED
await trio.lowlevel.wait_task_rescheduled(abort_fn)
@enable_ki_protection
def clone(self):
"""Clone this send channel object.
This returns a new `MemorySendChannel` object, which acts as a
duplicate of the original: sending on the new object does exactly the
same thing as sending on the old object. (If you're familiar with
`os.dup`, then this is a similar idea.)
However, closing one of the objects does not close the other, and
receivers don't get `EndOfChannel` until *all* clones have been
closed.
This is useful for communication patterns that involve multiple
producers all sending objects to the same destination. If you give
each producer its own clone of the `MemorySendChannel`, and then make
sure to close each `MemorySendChannel` when it's finished, receivers
will automatically get notified when all producers are finished. See
:ref:`channel-mpmc` for examples.
Raises:
trio.ClosedResourceError: if you already closed this
`MemorySendChannel` object.
"""
if self._closed:
raise trio.ClosedResourceError
return MemorySendChannel._create(self._state)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
@enable_ki_protection
def close(self):
"""Close this send channel object synchronously.
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
Memory channels can also be closed synchronously. This has the same
effect on the channel and other tasks using it, but `close` is not a
trio checkpoint. This simplifies cleaning up in cancelled tasks.
Using ``with send_channel:`` will close the channel object on leaving
the with block.
"""
if self._closed:
return
self._closed = True
for task in self._tasks:
trio.lowlevel.reschedule(task, Error(trio.ClosedResourceError()))
del self._state.send_tasks[task]
self._tasks.clear()
self._state.open_send_channels -= 1
if self._state.open_send_channels == 0:
assert not self._state.send_tasks
for task in self._state.receive_tasks:
task.custom_sleep_data._tasks.remove(task)
trio.lowlevel.reschedule(task, Error(trio.EndOfChannel()))
self._state.receive_tasks.clear()
@enable_ki_protection
async def aclose(self):
self.close()
await trio.lowlevel.checkpoint()
@attr.s(eq=False, repr=False)
class MemoryReceiveChannel(ReceiveChannel, metaclass=NoPublicConstructor):
_state = attr.ib()
_closed = attr.ib(default=False)
_tasks = attr.ib(factory=set)
def __attrs_post_init__(self):
self._state.open_receive_channels += 1
def statistics(self):
return self._state.statistics()
def __repr__(self):
return "<receive channel at {:#x}, using buffer at {:#x}>".format(
id(self), id(self._state)
)
@enable_ki_protection
def receive_nowait(self):
"""Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing
ready to receive, raises `WouldBlock` instead of blocking.
"""
if self._closed:
raise trio.ClosedResourceError
if self._state.send_tasks:
task, value = self._state.send_tasks.popitem(last=False)
task.custom_sleep_data._tasks.remove(task)
trio.lowlevel.reschedule(task)
self._state.data.append(value)
# Fall through
if self._state.data:
return self._state.data.popleft()
if not self._state.open_send_channels:
raise trio.EndOfChannel
raise trio.WouldBlock
@enable_ki_protection
async def receive(self):
"""See `ReceiveChannel.receive <trio.abc.ReceiveChannel.receive>`.
Memory channels allow multiple tasks to call `receive` at the same
time. The first task will get the first item sent, the second task
will get the second item sent, and so on.
"""
await trio.lowlevel.checkpoint_if_cancelled()
try:
value = self.receive_nowait()
except trio.WouldBlock:
pass
else:
await trio.lowlevel.cancel_shielded_checkpoint()
return value
task = trio.lowlevel.current_task()
self._tasks.add(task)
self._state.receive_tasks[task] = None
task.custom_sleep_data = self
def abort_fn(_):
self._tasks.remove(task)
del self._state.receive_tasks[task]
return trio.lowlevel.Abort.SUCCEEDED
return await trio.lowlevel.wait_task_rescheduled(abort_fn)
@enable_ki_protection
def clone(self):
"""Clone this receive channel object.
This returns a new `MemoryReceiveChannel` object, which acts as a
duplicate of the original: receiving on the new object does exactly
the same thing as receiving on the old object.
However, closing one of the objects does not close the other, and the
underlying channel is not closed until all clones are closed. (If
you're familiar with `os.dup`, then this is a similar idea.)
This is useful for communication patterns that involve multiple
consumers all receiving objects from the same underlying channel. See
:ref:`channel-mpmc` for examples.
.. warning:: The clones all share the same underlying channel.
Whenever a clone :meth:`receive`\\s a value, it is removed from the
channel and the other clones do *not* receive that value. If you
want to send multiple copies of the same stream of values to
multiple destinations, like :func:`itertools.tee`, then you need to
find some other solution; this method does *not* do that.
Raises:
trio.ClosedResourceError: if you already closed this
`MemoryReceiveChannel` object.
"""
if self._closed:
raise trio.ClosedResourceError
return MemoryReceiveChannel._create(self._state)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
@enable_ki_protection
def close(self):
"""Close this receive channel object synchronously.
All channel objects have an asynchronous `~.AsyncResource.aclose` method.
Memory channels can also be closed synchronously. This has the same
effect on the channel and other tasks using it, but `close` is not a
trio checkpoint. This simplifies cleaning up in cancelled tasks.
Using ``with receive_channel:`` will close the channel object on
leaving the with block.
"""
if self._closed:
return
self._closed = True
for task in self._tasks:
trio.lowlevel.reschedule(task, Error(trio.ClosedResourceError()))
del self._state.receive_tasks[task]
self._tasks.clear()
self._state.open_receive_channels -= 1
if self._state.open_receive_channels == 0:
assert not self._state.receive_tasks
for task in self._state.send_tasks:
task.custom_sleep_data._tasks.remove(task)
trio.lowlevel.reschedule(task, Error(trio.BrokenResourceError()))
self._state.send_tasks.clear()
self._state.data.clear()
@enable_ki_protection
async def aclose(self):
self.close()
await trio.lowlevel.checkpoint()

View File

@@ -0,0 +1,92 @@
"""
This namespace represents the core functionality that has to be built-in
and deal with private internal data structures. Things in this namespace
are publicly available in either trio, trio.lowlevel, or trio.testing.
"""
import sys
from ._exceptions import (
TrioInternalError,
RunFinishedError,
WouldBlock,
Cancelled,
BusyResourceError,
ClosedResourceError,
BrokenResourceError,
EndOfChannel,
)
from ._multierror import MultiError
from ._ki import (
enable_ki_protection,
disable_ki_protection,
currently_ki_protected,
)
# Imports that always exist
from ._run import (
Task,
CancelScope,
run,
open_nursery,
checkpoint,
current_task,
current_effective_deadline,
checkpoint_if_cancelled,
TASK_STATUS_IGNORED,
current_statistics,
current_trio_token,
reschedule,
remove_instrument,
add_instrument,
current_clock,
current_root_task,
spawn_system_task,
current_time,
wait_all_tasks_blocked,
wait_readable,
wait_writable,
notify_closing,
Nursery,
start_guest_run,
)
# Has to come after _run to resolve a circular import
from ._traps import (
cancel_shielded_checkpoint,
Abort,
wait_task_rescheduled,
temporarily_detach_coroutine_object,
permanently_detach_coroutine_object,
reattach_detached_coroutine_object,
)
from ._entry_queue import TrioToken
from ._parking_lot import ParkingLot
from ._unbounded_queue import UnboundedQueue
from ._local import RunVar
from ._thread_cache import start_thread_soon
from ._mock_clock import MockClock
# Windows imports
if sys.platform == "win32":
from ._run import (
monitor_completion_key,
current_iocp,
register_with_iocp,
wait_overlapped,
write_overlapped,
readinto_overlapped,
)
# Kqueue imports
elif sys.platform != "linux" and sys.platform != "win32":
from ._run import current_kqueue, monitor_kevent, wait_kevent
del sys # It would be better to import sys as _sys, but mypy does not understand it

View File

@@ -0,0 +1,193 @@
import attr
import logging
import sys
import warnings
import weakref
from .._util import name_asyncgen
from . import _run
from .. import _core
# Used to log exceptions in async generator finalizers
ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors")
@attr.s(eq=False, slots=True)
class AsyncGenerators:
# Async generators are added to this set when first iterated. Any
# left after the main task exits will be closed before trio.run()
# returns. During most of the run, this is a WeakSet so GC works.
# During shutdown, when we're finalizing all the remaining
# asyncgens after the system nursery has been closed, it's a
# regular set so we don't have to deal with GC firing at
# unexpected times.
alive = attr.ib(factory=weakref.WeakSet)
# This collects async generators that get garbage collected during
# the one-tick window between the system nursery closing and the
# init task starting end-of-run asyncgen finalization.
trailing_needs_finalize = attr.ib(factory=set)
prev_hooks = attr.ib(init=False)
def install_hooks(self, runner):
def firstiter(agen):
if hasattr(_run.GLOBAL_RUN_CONTEXT, "task"):
self.alive.add(agen)
else:
# An async generator first iterated outside of a Trio
# task doesn't belong to Trio. Probably we're in guest
# mode and the async generator belongs to our host.
# The locals dictionary is the only good place to
# remember this fact, at least until
# https://bugs.python.org/issue40916 is implemented.
agen.ag_frame.f_locals["@trio_foreign_asyncgen"] = True
if self.prev_hooks.firstiter is not None:
self.prev_hooks.firstiter(agen)
def finalize_in_trio_context(agen, agen_name):
try:
runner.spawn_system_task(
self._finalize_one,
agen,
agen_name,
name=f"close asyncgen {agen_name} (abandoned)",
)
except RuntimeError:
# There is a one-tick window where the system nursery
# is closed but the init task hasn't yet made
# self.asyncgens a strong set to disable GC. We seem to
# have hit it.
self.trailing_needs_finalize.add(agen)
def finalizer(agen):
agen_name = name_asyncgen(agen)
try:
is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen")
except AttributeError: # pragma: no cover
is_ours = True
if is_ours:
runner.entry_queue.run_sync_soon(
finalize_in_trio_context, agen, agen_name
)
# Do this last, because it might raise an exception
# depending on the user's warnings filter. (That
# exception will be printed to the terminal and
# ignored, since we're running in GC context.)
warnings.warn(
f"Async generator {agen_name!r} was garbage collected before it "
f"had been exhausted. Surround its use in 'async with "
f"aclosing(...):' to ensure that it gets cleaned up as soon as "
f"you're done using it.",
ResourceWarning,
stacklevel=2,
source=agen,
)
else:
# Not ours -> forward to the host loop's async generator finalizer
if self.prev_hooks.finalizer is not None:
self.prev_hooks.finalizer(agen)
else:
# Host has no finalizer. Reimplement the default
# Python behavior with no hooks installed: throw in
# GeneratorExit, step once, raise RuntimeError if
# it doesn't exit.
closer = agen.aclose()
try:
# If the next thing is a yield, this will raise RuntimeError
# which we allow to propagate
closer.send(None)
except StopIteration:
pass
else:
# If the next thing is an await, we get here. Give a nicer
# error than the default "async generator ignored GeneratorExit"
raise RuntimeError(
f"Non-Trio async generator {agen_name!r} awaited something "
f"during finalization; install a finalization hook to "
f"support this, or wrap it in 'async with aclosing(...):'"
)
self.prev_hooks = sys.get_asyncgen_hooks()
sys.set_asyncgen_hooks(firstiter=firstiter, finalizer=finalizer)
async def finalize_remaining(self, runner):
# This is called from init after shutting down the system nursery.
# The only tasks running at this point are init and
# the run_sync_soon task, and since the system nursery is closed,
# there's no way for user code to spawn more.
assert _core.current_task() is runner.init_task
assert len(runner.tasks) == 2
# To make async generator finalization easier to reason
# about, we'll shut down asyncgen garbage collection by turning
# the alive WeakSet into a regular set.
self.alive = set(self.alive)
# Process all pending run_sync_soon callbacks, in case one of
# them was an asyncgen finalizer that snuck in under the wire.
runner.entry_queue.run_sync_soon(runner.reschedule, runner.init_task)
await _core.wait_task_rescheduled(
lambda _: _core.Abort.FAILED # pragma: no cover
)
self.alive.update(self.trailing_needs_finalize)
self.trailing_needs_finalize.clear()
# None of the still-living tasks use async generators, so
# every async generator must be suspended at a yield point --
# there's no one to be doing the iteration. That's good,
# because aclose() only works on an asyncgen that's suspended
# at a yield point. (If it's suspended at an event loop trap,
# because someone is in the middle of iterating it, then you
# get a RuntimeError on 3.8+, and a nasty surprise on earlier
# versions due to https://bugs.python.org/issue32526.)
#
# However, once we start aclose() of one async generator, it
# might start fetching the next value from another, thus
# preventing us from closing that other (at least until
# aclose() of the first one is complete). This constraint
# effectively requires us to finalize the remaining asyncgens
# in arbitrary order, rather than doing all of them at the
# same time. On 3.8+ we could defer any generator with
# ag_running=True to a later batch, but that only catches
# the case where our aclose() starts after the user's
# asend()/etc. If our aclose() starts first, then the
# user's asend()/etc will raise RuntimeError, since they're
# probably not checking ag_running.
#
# It might be possible to allow some parallelized cleanup if
# we can determine that a certain set of asyncgens have no
# interdependencies, using gc.get_referents() and such.
# But just doing one at a time will typically work well enough
# (since each aclose() executes in a cancelled scope) and
# is much easier to reason about.
# It's possible that that cleanup code will itself create
# more async generators, so we iterate repeatedly until
# all are gone.
while self.alive:
batch = self.alive
self.alive = set()
for agen in batch:
await self._finalize_one(agen, name_asyncgen(agen))
def close(self):
sys.set_asyncgen_hooks(*self.prev_hooks)
async def _finalize_one(self, agen, name):
try:
# This shield ensures that finalize_asyncgen never exits
# with an exception, not even a Cancelled. The inside
# is cancelled so there's no deadlock risk.
with _core.CancelScope(shield=True) as cancel_scope:
cancel_scope.cancel()
await agen.aclose()
except BaseException:
ASYNCGEN_LOGGER.exception(
"Exception ignored during finalization of async generator %r -- "
"surround your use of the generator in 'async with aclosing(...):' "
"to raise exceptions like this in the context where they're generated",
name,
)

View File

@@ -0,0 +1,195 @@
from collections import deque
import threading
import attr
from .. import _core
from .._util import NoPublicConstructor
from ._wakeup_socketpair import WakeupSocketpair
@attr.s(slots=True)
class EntryQueue:
# This used to use a queue.Queue. but that was broken, because Queues are
# implemented in Python, and not reentrant -- so it was thread-safe, but
# not signal-safe. deque is implemented in C, so each operation is atomic
# WRT threads (and this is guaranteed in the docs), AND each operation is
# atomic WRT signal delivery (signal handlers can run on either side, but
# not *during* a deque operation). dict makes similar guarantees - and
# it's even ordered!
queue = attr.ib(factory=deque)
idempotent_queue = attr.ib(factory=dict)
wakeup = attr.ib(factory=WakeupSocketpair)
done = attr.ib(default=False)
# Must be a reentrant lock, because it's acquired from signal handlers.
# RLock is signal-safe as of cpython 3.2. NB that this does mean that the
# lock is effectively *disabled* when we enter from signal context. The
# way we use the lock this is OK though, because when
# run_sync_soon is called from a signal it's atomic WRT the
# main thread -- it just might happen at some inconvenient place. But if
# you look at the one place where the main thread holds the lock, it's
# just to make 1 assignment, so that's atomic WRT a signal anyway.
lock = attr.ib(factory=threading.RLock)
async def task(self):
assert _core.currently_ki_protected()
# RLock has two implementations: a signal-safe version in _thread, and
# and signal-UNsafe version in threading. We need the signal safe
# version. Python 3.2 and later should always use this anyway, but,
# since the symptoms if this goes wrong are just "weird rare
# deadlocks", then let's make a little check.
# See:
# https://bugs.python.org/issue13697#msg237140
assert self.lock.__class__.__module__ == "_thread"
def run_cb(job):
# We run this with KI protection enabled; it's the callback's
# job to disable it if it wants it disabled. Exceptions are
# treated like system task exceptions (i.e., converted into
# TrioInternalError and cause everything to shut down).
sync_fn, args = job
try:
sync_fn(*args)
except BaseException as exc:
async def kill_everything(exc):
raise exc
try:
_core.spawn_system_task(kill_everything, exc)
except RuntimeError:
# We're quite late in the shutdown process and the
# system nursery is already closed.
# TODO(2020-06): this is a gross hack and should
# be fixed soon when we address #1607.
_core.current_task().parent_nursery.start_soon(kill_everything, exc)
return True
# This has to be carefully written to be safe in the face of new items
# being queued while we iterate, and to do a bounded amount of work on
# each pass:
def run_all_bounded():
for _ in range(len(self.queue)):
run_cb(self.queue.popleft())
for job in list(self.idempotent_queue):
del self.idempotent_queue[job]
run_cb(job)
try:
while True:
run_all_bounded()
if not self.queue and not self.idempotent_queue:
await self.wakeup.wait_woken()
else:
await _core.checkpoint()
except _core.Cancelled:
# Keep the work done with this lock held as minimal as possible,
# because it doesn't protect us against concurrent signal delivery
# (see the comment above). Notice that this code would still be
# correct if written like:
# self.done = True
# with self.lock:
# pass
# because all we want is to force run_sync_soon
# to either be completely before or completely after the write to
# done. That's why we don't need the lock to protect
# against signal handlers.
with self.lock:
self.done = True
# No more jobs will be submitted, so just clear out any residual
# ones:
run_all_bounded()
assert not self.queue
assert not self.idempotent_queue
def close(self):
self.wakeup.close()
def size(self):
return len(self.queue) + len(self.idempotent_queue)
def run_sync_soon(self, sync_fn, *args, idempotent=False):
with self.lock:
if self.done:
raise _core.RunFinishedError("run() has exited")
# We have to hold the lock all the way through here, because
# otherwise the main thread might exit *while* we're doing these
# calls, and then our queue item might not be processed, or the
# wakeup call might trigger an OSError b/c the IO manager has
# already been shut down.
if idempotent:
self.idempotent_queue[(sync_fn, args)] = None
else:
self.queue.append((sync_fn, args))
self.wakeup.wakeup_thread_and_signal_safe()
@attr.s(eq=False, hash=False, slots=True)
class TrioToken(metaclass=NoPublicConstructor):
"""An opaque object representing a single call to :func:`trio.run`.
It has no public constructor; instead, see :func:`current_trio_token`.
This object has two uses:
1. It lets you re-enter the Trio run loop from external threads or signal
handlers. This is the low-level primitive that :func:`trio.to_thread`
and `trio.from_thread` use to communicate with worker threads, that
`trio.open_signal_receiver` uses to receive notifications about
signals, and so forth.
2. Each call to :func:`trio.run` has exactly one associated
:class:`TrioToken` object, so you can use it to identify a particular
call.
"""
_reentry_queue = attr.ib()
def run_sync_soon(self, sync_fn, *args, idempotent=False):
"""Schedule a call to ``sync_fn(*args)`` to occur in the context of a
Trio task.
This is safe to call from the main thread, from other threads, and
from signal handlers. This is the fundamental primitive used to
re-enter the Trio run loop from outside of it.
The call will happen "soon", but there's no guarantee about exactly
when, and no mechanism provided for finding out when it's happened.
If you need this, you'll have to build your own.
The call is effectively run as part of a system task (see
:func:`~trio.lowlevel.spawn_system_task`). In particular this means
that:
* :exc:`KeyboardInterrupt` protection is *enabled* by default; if
you want ``sync_fn`` to be interruptible by control-C, then you
need to use :func:`~trio.lowlevel.disable_ki_protection`
explicitly.
* If ``sync_fn`` raises an exception, then it's converted into a
:exc:`~trio.TrioInternalError` and *all* tasks are cancelled. You
should be careful that ``sync_fn`` doesn't crash.
All calls with ``idempotent=False`` are processed in strict
first-in first-out order.
If ``idempotent=True``, then ``sync_fn`` and ``args`` must be
hashable, and Trio will make a best-effort attempt to discard any
call submission which is equal to an already-pending call. Trio
will process these in first-in first-out order.
Any ordering guarantees apply separately to ``idempotent=False``
and ``idempotent=True`` calls; there's no rule for how calls in the
different categories are ordered with respect to each other.
:raises trio.RunFinishedError:
if the associated call to :func:`trio.run`
has already exited. (Any call that *doesn't* raise this error
is guaranteed to be fully processed before :func:`trio.run`
exits.)
"""
self._reentry_queue.run_sync_soon(sync_fn, *args, idempotent=idempotent)

View File

@@ -0,0 +1,114 @@
import attr
from trio._util import NoPublicConstructor
class TrioInternalError(Exception):
"""Raised by :func:`run` if we encounter a bug in Trio, or (possibly) a
misuse of one of the low-level :mod:`trio.lowlevel` APIs.
This should never happen! If you get this error, please file a bug.
Unfortunately, if you get this error it also means that all bets are off
Trio doesn't know what is going on and its normal invariants may be void.
(For example, we might have "lost track" of a task. Or lost track of all
tasks.) Again, though, this shouldn't happen.
"""
class RunFinishedError(RuntimeError):
"""Raised by `trio.from_thread.run` and similar functions if the
corresponding call to :func:`trio.run` has already finished.
"""
class WouldBlock(Exception):
"""Raised by ``X_nowait`` functions if ``X`` would block."""
class Cancelled(BaseException, metaclass=NoPublicConstructor):
"""Raised by blocking calls if the surrounding scope has been cancelled.
You should let this exception propagate, to be caught by the relevant
cancel scope. To remind you of this, it inherits from :exc:`BaseException`
instead of :exc:`Exception`, just like :exc:`KeyboardInterrupt` and
:exc:`SystemExit` do. This means that if you write something like::
try:
...
except Exception:
...
then this *won't* catch a :exc:`Cancelled` exception.
You cannot raise :exc:`Cancelled` yourself. Attempting to do so
will produce a :exc:`TypeError`. Use :meth:`cancel_scope.cancel()
<trio.CancelScope.cancel>` instead.
.. note::
In the US it's also common to see this word spelled "canceled", with
only one "l". This is a `recent
<https://books.google.com/ngrams/graph?content=canceled%2Ccancelled&year_start=1800&year_end=2000&corpus=5&smoothing=3&direct_url=t1%3B%2Ccanceled%3B%2Cc0%3B.t1%3B%2Ccancelled%3B%2Cc0>`__
and `US-specific
<https://books.google.com/ngrams/graph?content=canceled%2Ccancelled&year_start=1800&year_end=2000&corpus=18&smoothing=3&share=&direct_url=t1%3B%2Ccanceled%3B%2Cc0%3B.t1%3B%2Ccancelled%3B%2Cc0>`__
innovation, and even in the US both forms are still commonly used. So
for consistency with the rest of the world and with "cancellation"
(which always has two "l"s), Trio uses the two "l" spelling
everywhere.
"""
def __str__(self):
return "Cancelled"
class BusyResourceError(Exception):
"""Raised when a task attempts to use a resource that some other task is
already using, and this would lead to bugs and nonsense.
For example, if two tasks try to send data through the same socket at the
same time, Trio will raise :class:`BusyResourceError` instead of letting
the data get scrambled.
"""
class ClosedResourceError(Exception):
"""Raised when attempting to use a resource after it has been closed.
Note that "closed" here means that *your* code closed the resource,
generally by calling a method with a name like ``close`` or ``aclose``, or
by exiting a context manager. If a problem arises elsewhere for example,
because of a network failure, or because a remote peer closed their end of
a connection then that should be indicated by a different exception
class, like :exc:`BrokenResourceError` or an :exc:`OSError` subclass.
"""
class BrokenResourceError(Exception):
"""Raised when an attempt to use a resource fails due to external
circumstances.
For example, you might get this if you try to send data on a stream where
the remote side has already closed the connection.
You *don't* get this error if *you* closed the resource in that case you
get :class:`ClosedResourceError`.
This exception's ``__cause__`` attribute will often contain more
information about the underlying error.
"""
class EndOfChannel(Exception):
"""Raised when trying to receive from a :class:`trio.abc.ReceiveChannel`
that has no more data to receive.
This is analogous to an "end-of-file" condition, but for channels.
"""

View File

@@ -0,0 +1,47 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument
# fmt: off
def add_instrument(instrument: Instrument) ->None:
"""Start instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to activate.
If ``instrument`` is already active, does nothing.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.instruments.add_instrument(instrument)
except AttributeError:
raise RuntimeError("must be called from async context")
def remove_instrument(instrument: Instrument) ->None:
"""Stop instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to de-activate.
Raises:
KeyError: if the instrument is not currently active. This could
occur either because you never added it, or because you added it
and then it raised an unhandled exception and was automatically
deactivated.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.instruments.remove_instrument(instrument)
except AttributeError:
raise RuntimeError("must be called from async context")
# fmt: on

View File

@@ -0,0 +1,35 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument
# fmt: off
async def wait_readable(fd):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_writable(fd):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
except AttributeError:
raise RuntimeError("must be called from async context")
def notify_closing(fd):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
except AttributeError:
raise RuntimeError("must be called from async context")
# fmt: on

View File

@@ -0,0 +1,59 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument
# fmt: off
def current_kqueue():
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue()
except AttributeError:
raise RuntimeError("must be called from async context")
def monitor_kevent(ident, filter):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter)
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_kevent(ident, filter, abort_func):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(ident, filter, abort_func)
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_readable(fd):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_writable(fd):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
except AttributeError:
raise RuntimeError("must be called from async context")
def notify_closing(fd):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
except AttributeError:
raise RuntimeError("must be called from async context")
# fmt: on

View File

@@ -0,0 +1,83 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument
# fmt: off
async def wait_readable(sock):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock)
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_writable(sock):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock)
except AttributeError:
raise RuntimeError("must be called from async context")
def notify_closing(handle):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle)
except AttributeError:
raise RuntimeError("must be called from async context")
def register_with_iocp(handle):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle)
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_overlapped(handle, lpOverlapped):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped(handle, lpOverlapped)
except AttributeError:
raise RuntimeError("must be called from async context")
async def write_overlapped(handle, data, file_offset=0):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped(handle, data, file_offset)
except AttributeError:
raise RuntimeError("must be called from async context")
async def readinto_overlapped(handle, buffer, file_offset=0):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped(handle, buffer, file_offset)
except AttributeError:
raise RuntimeError("must be called from async context")
def current_iocp():
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp()
except AttributeError:
raise RuntimeError("must be called from async context")
def monitor_completion_key():
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key()
except AttributeError:
raise RuntimeError("must be called from async context")
# fmt: on

View File

@@ -0,0 +1,241 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument
# fmt: off
def current_statistics():
"""Returns an object containing run-loop-level debugging information.
Currently the following fields are defined:
* ``tasks_living`` (int): The number of tasks that have been spawned
and not yet exited.
* ``tasks_runnable`` (int): The number of tasks that are currently
queued on the run queue (as opposed to blocked waiting for something
to happen).
* ``seconds_to_next_deadline`` (float): The time until the next
pending cancel scope deadline. May be negative if the deadline has
expired but we haven't yet processed cancellations. May be
:data:`~math.inf` if there are no pending deadlines.
* ``run_sync_soon_queue_size`` (int): The number of
unprocessed callbacks queued via
:meth:`trio.lowlevel.TrioToken.run_sync_soon`.
* ``io_statistics`` (object): Some statistics from Trio's I/O
backend. This always has an attribute ``backend`` which is a string
naming which operating-system-specific I/O backend is in use; the
other attributes vary between backends.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_statistics()
except AttributeError:
raise RuntimeError("must be called from async context")
def current_time():
"""Returns the current time according to Trio's internal clock.
Returns:
float: The current time.
Raises:
RuntimeError: if not inside a call to :func:`trio.run`.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_time()
except AttributeError:
raise RuntimeError("must be called from async context")
def current_clock():
"""Returns the current :class:`~trio.abc.Clock`."""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_clock()
except AttributeError:
raise RuntimeError("must be called from async context")
def current_root_task():
"""Returns the current root :class:`Task`.
This is the task that is the ultimate parent of all other tasks.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_root_task()
except AttributeError:
raise RuntimeError("must be called from async context")
def reschedule(task, next_send=_NO_SEND):
"""Reschedule the given task with the given
:class:`outcome.Outcome`.
See :func:`wait_task_rescheduled` for the gory details.
There must be exactly one call to :func:`reschedule` for every call to
:func:`wait_task_rescheduled`. (And when counting, keep in mind that
returning :data:`Abort.SUCCEEDED` from an abort callback is equivalent
to calling :func:`reschedule` once.)
Args:
task (trio.lowlevel.Task): the task to be rescheduled. Must be blocked
in a call to :func:`wait_task_rescheduled`.
next_send (outcome.Outcome): the value (or error) to return (or
raise) from :func:`wait_task_rescheduled`.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send)
except AttributeError:
raise RuntimeError("must be called from async context")
def spawn_system_task(async_fn, *args, name=None, context=None):
"""Spawn a "system" task.
System tasks have a few differences from regular tasks:
* They don't need an explicit nursery; instead they go into the
internal "system nursery".
* If a system task raises an exception, then it's converted into a
:exc:`~trio.TrioInternalError` and *all* tasks are cancelled. If you
write a system task, you should be careful to make sure it doesn't
crash.
* System tasks are automatically cancelled when the main task exits.
* By default, system tasks have :exc:`KeyboardInterrupt` protection
*enabled*. If you want your task to be interruptible by control-C,
then you need to use :func:`disable_ki_protection` explicitly (and
come up with some plan for what to do with a
:exc:`KeyboardInterrupt`, given that system tasks aren't allowed to
raise exceptions).
* System tasks do not inherit context variables from their creator.
Towards the end of a call to :meth:`trio.run`, after the main
task and all system tasks have exited, the system nursery
becomes closed. At this point, new calls to
:func:`spawn_system_task` will raise ``RuntimeError("Nursery
is closed to new arrivals")`` instead of creating a system
task. It's possible to encounter this state either in
a ``finally`` block in an async generator, or in a callback
passed to :meth:`TrioToken.run_sync_soon` at the right moment.
Args:
async_fn: An async callable.
args: Positional arguments for ``async_fn``. If you want to pass
keyword arguments, use :func:`functools.partial`.
name: The name for this task. Only used for debugging/introspection
(e.g. ``repr(task_obj)``). If this isn't a string,
:func:`spawn_system_task` will try to make it one. A common use
case is if you're wrapping a function before spawning a new
task, you might pass the original function as the ``name=`` to
make debugging easier.
context: An optional ``contextvars.Context`` object with context variables
to use for this task. You would normally get a copy of the current
context with ``context = contextvars.copy_context()`` and then you would
pass that ``context`` object here.
Returns:
Task: the newly spawned task
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.spawn_system_task(async_fn, *args, name=name, context=context)
except AttributeError:
raise RuntimeError("must be called from async context")
def current_trio_token():
"""Retrieve the :class:`TrioToken` for the current call to
:func:`trio.run`.
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_trio_token()
except AttributeError:
raise RuntimeError("must be called from async context")
async def wait_all_tasks_blocked(cushion=0.0):
"""Block until there are no runnable tasks.
This is useful in testing code when you want to give other tasks a
chance to "settle down". The calling task is blocked, and doesn't wake
up until all other tasks are also blocked for at least ``cushion``
seconds. (Setting a non-zero ``cushion`` is intended to handle cases
like two tasks talking to each other over a local socket, where we
want to ignore the potential brief moment between a send and receive
when all tasks are blocked.)
Note that ``cushion`` is measured in *real* time, not the Trio clock
time.
If there are multiple tasks blocked in :func:`wait_all_tasks_blocked`,
then the one with the shortest ``cushion`` is the one woken (and
this task becoming unblocked resets the timers for the remaining
tasks). If there are multiple tasks that have exactly the same
``cushion``, then all are woken.
You should also consider :class:`trio.testing.Sequencer`, which
provides a more explicit way to control execution ordering within a
test, and will often produce more readable tests.
Example:
Here's an example of one way to test that Trio's locks are fair: we
take the lock in the parent, start a child, wait for the child to be
blocked waiting for the lock (!), and then check that we can't
release and immediately re-acquire the lock::
async def lock_taker(lock):
await lock.acquire()
lock.release()
async def test_lock_fairness():
lock = trio.Lock()
await lock.acquire()
async with trio.open_nursery() as nursery:
nursery.start_soon(lock_taker, lock)
# child hasn't run yet, we have the lock
assert lock.locked()
assert lock._owner is trio.lowlevel.current_task()
await trio.testing.wait_all_tasks_blocked()
# now the child has run and is blocked on lock.acquire(), we
# still have the lock
assert lock.locked()
assert lock._owner is trio.lowlevel.current_task()
lock.release()
try:
# The child has a prior claim, so we can't have it
lock.acquire_nowait()
except trio.WouldBlock:
assert lock._owner is not trio.lowlevel.current_task()
print("PASS")
else:
print("FAIL")
"""
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion)
except AttributeError:
raise RuntimeError("must be called from async context")
# fmt: on

View File

@@ -0,0 +1,108 @@
import logging
import types
import attr
from typing import Any, Callable, Dict, List, Sequence, Iterator, TypeVar
from .._abc import Instrument
# Used to log exceptions in instruments
INSTRUMENT_LOGGER = logging.getLogger("trio.abc.Instrument")
F = TypeVar("F", bound=Callable[..., Any])
# Decorator to mark methods public. This does nothing by itself, but
# trio/_tools/gen_exports.py looks for it.
def _public(fn: F) -> F:
return fn
class Instruments(Dict[str, Dict[Instrument, None]]):
"""A collection of `trio.abc.Instrument` organized by hook.
Instrumentation calls are rather expensive, and we don't want a
rarely-used instrument (like before_run()) to slow down hot
operations (like before_task_step()). Thus, we cache the set of
instruments to be called for each hook, and skip the instrumentation
call if there's nothing currently installed for that hook.
"""
__slots__ = ()
def __init__(self, incoming: Sequence[Instrument]):
self["_all"] = {}
for instrument in incoming:
self.add_instrument(instrument)
@_public
def add_instrument(self, instrument: Instrument) -> None:
"""Start instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to activate.
If ``instrument`` is already active, does nothing.
"""
if instrument in self["_all"]:
return
self["_all"][instrument] = None
try:
for name in dir(instrument):
if name.startswith("_"):
continue
try:
prototype = getattr(Instrument, name)
except AttributeError:
continue
impl = getattr(instrument, name)
if isinstance(impl, types.MethodType) and impl.__func__ is prototype:
# Inherited unchanged from _abc.Instrument
continue
self.setdefault(name, {})[instrument] = None
except:
self.remove_instrument(instrument)
raise
@_public
def remove_instrument(self, instrument: Instrument) -> None:
"""Stop instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to de-activate.
Raises:
KeyError: if the instrument is not currently active. This could
occur either because you never added it, or because you added it
and then it raised an unhandled exception and was automatically
deactivated.
"""
# If instrument isn't present, the KeyError propagates out
self["_all"].pop(instrument)
for hookname, instruments in list(self.items()):
if instrument in instruments:
del instruments[instrument]
if not instruments:
del self[hookname]
def call(self, hookname: str, *args: Any) -> None:
"""Call hookname(*args) on each applicable instrument.
You must first check whether there are any instruments installed for
that hook, e.g.::
if "before_task_step" in instruments:
instruments.call("before_task_step", task)
"""
for instrument in list(self[hookname]):
try:
getattr(instrument, hookname)(*args)
except:
self.remove_instrument(instrument)
INSTRUMENT_LOGGER.exception(
"Exception raised when calling %r on instrument %r. "
"Instrument has been disabled.",
hookname,
instrument,
)

View File

@@ -0,0 +1,22 @@
import copy
import outcome
from .. import _core
# Utility function shared between _io_epoll and _io_windows
def wake_all(waiters, exc):
try:
current_task = _core.current_task()
except RuntimeError:
current_task = None
raise_at_end = False
for attr_name in ["read_task", "write_task"]:
task = getattr(waiters, attr_name)
if task is not None:
if task is current_task:
raise_at_end = True
else:
_core.reschedule(task, outcome.Error(copy.copy(exc)))
setattr(waiters, attr_name, None)
if raise_at_end:
raise exc

View File

@@ -0,0 +1,317 @@
import select
import sys
import attr
from collections import defaultdict
from typing import Dict, TYPE_CHECKING
from .. import _core
from ._run import _public
from ._io_common import wake_all
from ._wakeup_socketpair import WakeupSocketpair
assert not TYPE_CHECKING or sys.platform == "linux"
@attr.s(slots=True, eq=False, frozen=True)
class _EpollStatistics:
tasks_waiting_read = attr.ib()
tasks_waiting_write = attr.ib()
backend = attr.ib(default="epoll")
# Some facts about epoll
# ----------------------
#
# Internally, an epoll object is sort of like a WeakKeyDictionary where the
# keys are tuples of (fd number, file object). When you call epoll_ctl, you
# pass in an fd; that gets converted to an (fd number, file object) tuple by
# looking up the fd in the process's fd table at the time of the call. When an
# event happens on the file object, epoll_wait drops the file object part, and
# just returns the fd number in its event. So from the outside it looks like
# it's keeping a table of fds, but really it's a bit more complicated. This
# has some subtle consequences.
#
# In general, file objects inside the kernel are reference counted. Each entry
# in a process's fd table holds a strong reference to the corresponding file
# object, and most operations that use file objects take a temporary strong
# reference while they're working. So when you call close() on an fd, that
# might or might not cause the file object to be deallocated -- it depends on
# whether there are any other references to that file object. Some common ways
# this can happen:
#
# - after calling dup(), you have two fds in the same process referring to the
# same file object. Even if you close one fd (= remove that entry from the
# fd table), the file object will be kept alive by the other fd.
# - when calling fork(), the child inherits a copy of the parent's fd table,
# so all the file objects get another reference. (But if the fork() is
# followed by exec(), then all of the child's fds that have the CLOEXEC flag
# set will be closed at that point.)
# - most syscalls that work on fds take a strong reference to the underlying
# file object while they're using it. So there's one thread blocked in
# read(fd), and then another thread calls close() on the last fd referring
# to that object, the underlying file won't actually be closed until
# after read() returns.
#
# However, epoll does *not* take a reference to any of the file objects in its
# interest set (that's what makes it similar to a WeakKeyDictionary). File
# objects inside an epoll interest set will be deallocated if all *other*
# references to them are closed. And when that happens, the epoll object will
# automatically deregister that file object and stop reporting events on it.
# So that's quite handy.
#
# But, what happens if we do this?
#
# fd1 = open(...)
# epoll_ctl(EPOLL_CTL_ADD, fd1, ...)
# fd2 = dup(fd1)
# close(fd1)
#
# In this case, the dup() keeps the underlying file object alive, so it
# remains registered in the epoll object's interest set, as the tuple (fd1,
# file object). But, fd1 no longer refers to this file object! You might think
# there was some magic to handle this, but unfortunately no; the consequences
# are totally predictable from what I said above:
#
# If any events occur on the file object, then epoll will report them as
# happening on fd1, even though that doesn't make sense.
#
# Perhaps we would like to deregister fd1 to stop getting nonsensical events.
# But how? When we call epoll_ctl, we have to pass an fd number, which will
# get expanded to an (fd number, file object) tuple. We can't pass fd1,
# because when epoll_ctl tries to look it up, it won't find our file object.
# And we can't pass fd2, because that will get expanded to (fd2, file object),
# which is a different lookup key. In fact, it's *impossible* to de-register
# this fd!
#
# We could even have fd1 get assigned to another file object, and then we can
# have multiple keys registered simultaneously using the same fd number, like:
# (fd1, file object 1), (fd1, file object 2). And if events happen on either
# file object, then epoll will happily report that something happened to
# "fd1".
#
# Now here's what makes this especially nasty: suppose the old file object
# becomes, say, readable. That means that every time we call epoll_wait, it
# will return immediately to tell us that "fd1" is readable. Normally, we
# would handle this by de-registering fd1, waking up the corresponding call to
# wait_readable, then the user will call read() or recv() or something, and
# we're fine. But if this happens on a stale fd where we can't remove the
# registration, then we might get stuck in a state where epoll_wait *always*
# returns immediately, so our event loop becomes unable to sleep, and now our
# program is burning 100% of the CPU doing nothing, with no way out.
#
#
# What does this mean for Trio?
# -----------------------------
#
# Since we don't control the user's code, we have no way to guarantee that we
# don't get stuck with stale fd's in our epoll interest set. For example, a
# user could call wait_readable(fd) in one task, and then while that's
# running, they might close(fd) from another task. In this situation, they're
# *supposed* to call notify_closing(fd) to let us know what's happening, so we
# can interrupt the wait_readable() call and avoid getting into this mess. And
# that's the only thing that can possibly work correctly in all cases. But
# sometimes user code has bugs. So if this does happen, we'd like to degrade
# gracefully, and survive without corrupting Trio's internal state or
# otherwise causing the whole program to explode messily.
#
# Our solution: we always use EPOLLONESHOT. This way, we might get *one*
# spurious event on a stale fd, but then epoll will automatically silence it
# until we explicitly say that we want more events... and if we have a stale
# fd, then we actually can't re-enable it! So we can't get stuck in an
# infinite busy-loop. If there's a stale fd hanging around, then it might
# cause a spurious `BusyResourceError`, or cause one wait_* call to return
# before it should have... but in general, the wait_* functions are allowed to
# have some spurious wakeups; the user code will just attempt the operation,
# get EWOULDBLOCK, and call wait_* again. And the program as a whole will
# survive, any exceptions will propagate, etc.
#
# As a bonus, EPOLLONESHOT also saves us having to explicitly deregister fds
# on the normal wakeup path, so it's a bit more efficient in general.
#
# However, EPOLLONESHOT has a few trade-offs to consider:
#
# First, you can't combine EPOLLONESHOT with EPOLLEXCLUSIVE. This is a bit sad
# in one somewhat rare case: if you have a multi-process server where a group
# of processes all share the same listening socket, then EPOLLEXCLUSIVE can be
# used to avoid "thundering herd" problems when a new connection comes in. But
# this isn't too bad. It's not clear if EPOLLEXCLUSIVE even works for us
# anyway:
#
# https://stackoverflow.com/questions/41582560/how-does-epolls-epollexclusive-mode-interact-with-level-triggering
#
# And it's not clear that EPOLLEXCLUSIVE is a great approach either:
#
# https://blog.cloudflare.com/the-sad-state-of-linux-socket-balancing/
#
# And if we do need to support this, we could always add support through some
# more-specialized API in the future. So this isn't a blocker to using
# EPOLLONESHOT.
#
# Second, EPOLLONESHOT does not actually *deregister* the fd after delivering
# an event (EPOLL_CTL_DEL). Instead, it keeps the fd registered, but
# effectively does an EPOLL_CTL_MOD to set the fd's interest flags to
# all-zeros. So we could still end up with an fd hanging around in the
# interest set for a long time, even if we're not using it.
#
# Fortunately, this isn't a problem, because it's only a weak reference if
# we have a stale fd that's been silenced by EPOLLONESHOT, then it wastes a
# tiny bit of kernel memory remembering this fd that can never be revived, but
# when the underlying file object is eventually closed, that memory will be
# reclaimed. So that's OK.
#
# The other issue is that when someone calls wait_*, using EPOLLONESHOT means
# that if we have ever waited for this fd before, we have to use EPOLL_CTL_MOD
# to re-enable it; but if it's a new fd, we have to use EPOLL_CTL_ADD. How do
# we know which one to use? There's no reasonable way to track which fds are
# currently registered -- remember, we're assuming the user might have gone
# and rearranged their fds without telling us!
#
# Fortunately, this also has a simple solution: if we wait on a socket or
# other fd once, then we'll probably wait on it lots of times. And the epoll
# object itself knows which fds it already has registered. So when an fd comes
# in, we optimistically assume that it's been waited on before, and try doing
# EPOLL_CTL_MOD. And if that fails with an ENOENT error, then we try again
# with EPOLL_CTL_ADD.
#
# So that's why this code is the way it is. And now you know more than you
# wanted to about how epoll works.
@attr.s(slots=True, eq=False)
class EpollWaiters:
read_task = attr.ib(default=None)
write_task = attr.ib(default=None)
current_flags = attr.ib(default=0)
@attr.s(slots=True, eq=False, hash=False)
class EpollIOManager:
_epoll = attr.ib(factory=select.epoll)
# {fd: EpollWaiters}
_registered = attr.ib(
factory=lambda: defaultdict(EpollWaiters), type=Dict[int, EpollWaiters]
)
_force_wakeup = attr.ib(factory=WakeupSocketpair)
_force_wakeup_fd = attr.ib(default=None)
def __attrs_post_init__(self):
self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN)
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
def statistics(self):
tasks_waiting_read = 0
tasks_waiting_write = 0
for waiter in self._registered.values():
if waiter.read_task is not None:
tasks_waiting_read += 1
if waiter.write_task is not None:
tasks_waiting_write += 1
return _EpollStatistics(
tasks_waiting_read=tasks_waiting_read,
tasks_waiting_write=tasks_waiting_write,
)
def close(self):
self._epoll.close()
self._force_wakeup.close()
def force_wakeup(self):
self._force_wakeup.wakeup_thread_and_signal_safe()
# Return value must be False-y IFF the timeout expired, NOT if any I/O
# happened or force_wakeup was called. Otherwise it can be anything; gets
# passed straight through to process_events.
def get_events(self, timeout):
# max_events must be > 0 or epoll gets cranky
# accessing self._registered from a thread looks dangerous, but it's
# OK because it doesn't matter if our value is a little bit off.
max_events = max(1, len(self._registered))
return self._epoll.poll(timeout, max_events)
def process_events(self, events):
for fd, flags in events:
if fd == self._force_wakeup_fd:
self._force_wakeup.drain()
continue
waiters = self._registered[fd]
# EPOLLONESHOT always clears the flags when an event is delivered
waiters.current_flags = 0
# Clever hack stolen from selectors.EpollSelector: an event
# with EPOLLHUP or EPOLLERR flags wakes both readers and
# writers.
if flags & ~select.EPOLLIN and waiters.write_task is not None:
_core.reschedule(waiters.write_task)
waiters.write_task = None
if flags & ~select.EPOLLOUT and waiters.read_task is not None:
_core.reschedule(waiters.read_task)
waiters.read_task = None
self._update_registrations(fd)
def _update_registrations(self, fd):
waiters = self._registered[fd]
wanted_flags = 0
if waiters.read_task is not None:
wanted_flags |= select.EPOLLIN
if waiters.write_task is not None:
wanted_flags |= select.EPOLLOUT
if wanted_flags != waiters.current_flags:
try:
try:
# First try EPOLL_CTL_MOD
self._epoll.modify(fd, wanted_flags | select.EPOLLONESHOT)
except OSError:
# If that fails, it might be a new fd; try EPOLL_CTL_ADD
self._epoll.register(fd, wanted_flags | select.EPOLLONESHOT)
waiters.current_flags = wanted_flags
except OSError as exc:
# If everything fails, probably it's a bad fd, e.g. because
# the fd was closed behind our back. In this case we don't
# want to try to unregister the fd, because that will probably
# fail too. Just clear our state and wake everyone up.
del self._registered[fd]
# This could raise (in case we're calling this inside one of
# the to-be-woken tasks), so we have to do it last.
wake_all(waiters, exc)
return
if not wanted_flags:
del self._registered[fd]
async def _epoll_wait(self, fd, attr_name):
if not isinstance(fd, int):
fd = fd.fileno()
waiters = self._registered[fd]
if getattr(waiters, attr_name) is not None:
raise _core.BusyResourceError(
"another task is already reading / writing this fd"
)
setattr(waiters, attr_name, _core.current_task())
self._update_registrations(fd)
def abort(_):
setattr(waiters, attr_name, None)
self._update_registrations(fd)
return _core.Abort.SUCCEEDED
await _core.wait_task_rescheduled(abort)
@_public
async def wait_readable(self, fd):
await self._epoll_wait(fd, "read_task")
@_public
async def wait_writable(self, fd):
await self._epoll_wait(fd, "write_task")
@_public
def notify_closing(self, fd):
if not isinstance(fd, int):
fd = fd.fileno()
wake_all(
self._registered[fd],
_core.ClosedResourceError("another task closed this fd"),
)
del self._registered[fd]
try:
self._epoll.unregister(fd)
except (OSError, ValueError):
pass

View File

@@ -0,0 +1,196 @@
import select
import sys
from typing import TYPE_CHECKING
import outcome
from contextlib import contextmanager
import attr
import errno
from .. import _core
from ._run import _public
from ._wakeup_socketpair import WakeupSocketpair
assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32")
@attr.s(slots=True, eq=False, frozen=True)
class _KqueueStatistics:
tasks_waiting = attr.ib()
monitors = attr.ib()
backend = attr.ib(default="kqueue")
@attr.s(slots=True, eq=False)
class KqueueIOManager:
_kqueue = attr.ib(factory=select.kqueue)
# {(ident, filter): Task or UnboundedQueue}
_registered = attr.ib(factory=dict)
_force_wakeup = attr.ib(factory=WakeupSocketpair)
_force_wakeup_fd = attr.ib(default=None)
def __attrs_post_init__(self):
force_wakeup_event = select.kevent(
self._force_wakeup.wakeup_sock, select.KQ_FILTER_READ, select.KQ_EV_ADD
)
self._kqueue.control([force_wakeup_event], 0)
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
def statistics(self):
tasks_waiting = 0
monitors = 0
for receiver in self._registered.values():
if type(receiver) is _core.Task:
tasks_waiting += 1
else:
monitors += 1
return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors)
def close(self):
self._kqueue.close()
self._force_wakeup.close()
def force_wakeup(self):
self._force_wakeup.wakeup_thread_and_signal_safe()
def get_events(self, timeout):
# max_events must be > 0 or kqueue gets cranky
# and we generally want this to be strictly larger than the actual
# number of events we get, so that we can tell that we've gotten
# all the events in just 1 call.
max_events = len(self._registered) + 1
events = []
while True:
batch = self._kqueue.control([], max_events, timeout)
events += batch
if len(batch) < max_events:
break
else:
timeout = 0
# and loop back to the start
return events
def process_events(self, events):
for event in events:
key = (event.ident, event.filter)
if event.ident == self._force_wakeup_fd:
self._force_wakeup.drain()
continue
receiver = self._registered[key]
if event.flags & select.KQ_EV_ONESHOT:
del self._registered[key]
if type(receiver) is _core.Task:
_core.reschedule(receiver, outcome.Value(event))
else:
receiver.put_nowait(event)
# kevent registration is complicated -- e.g. aio submission can
# implicitly perform a EV_ADD, and EVFILT_PROC with NOTE_TRACK will
# automatically register filters for child processes. So our lowlevel
# API is *very* low-level: we expose the kqueue itself for adding
# events or sticking into AIO submission structs, and split waiting
# off into separate methods. It's your responsibility to make sure
# that handle_io never receives an event without a corresponding
# registration! This may be challenging if you want to be careful
# about e.g. KeyboardInterrupt. Possibly this API could be improved to
# be more ergonomic...
@_public
def current_kqueue(self):
return self._kqueue
@contextmanager
@_public
def monitor_kevent(self, ident, filter):
key = (ident, filter)
if key in self._registered:
raise _core.BusyResourceError(
"attempt to register multiple listeners for same ident/filter pair"
)
q = _core.UnboundedQueue()
self._registered[key] = q
try:
yield q
finally:
del self._registered[key]
@_public
async def wait_kevent(self, ident, filter, abort_func):
key = (ident, filter)
if key in self._registered:
raise _core.BusyResourceError(
"attempt to register multiple listeners for same ident/filter pair"
)
self._registered[key] = _core.current_task()
def abort(raise_cancel):
r = abort_func(raise_cancel)
if r is _core.Abort.SUCCEEDED:
del self._registered[key]
return r
return await _core.wait_task_rescheduled(abort)
async def _wait_common(self, fd, filter):
if not isinstance(fd, int):
fd = fd.fileno()
flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT
event = select.kevent(fd, filter, flags)
self._kqueue.control([event], 0)
def abort(_):
event = select.kevent(fd, filter, select.KQ_EV_DELETE)
try:
self._kqueue.control([event], 0)
except OSError as exc:
# kqueue tracks individual fds (*not* the underlying file
# object, see _io_epoll.py for a long discussion of why this
# distinction matters), and automatically deregisters an event
# if the fd is closed. So if kqueue.control says that it
# doesn't know about this event, then probably it's because
# the fd was closed behind our backs. (Too bad we can't ask it
# to wake us up when this happens, versus discovering it after
# the fact... oh well, you can't have everything.)
#
# FreeBSD reports this using EBADF. macOS uses ENOENT.
if exc.errno in (errno.EBADF, errno.ENOENT): # pragma: no branch
pass
else: # pragma: no cover
# As far as we know, this branch can't happen.
raise
return _core.Abort.SUCCEEDED
await self.wait_kevent(fd, filter, abort)
@_public
async def wait_readable(self, fd):
await self._wait_common(fd, select.KQ_FILTER_READ)
@_public
async def wait_writable(self, fd):
await self._wait_common(fd, select.KQ_FILTER_WRITE)
@_public
def notify_closing(self, fd):
if not isinstance(fd, int):
fd = fd.fileno()
for filter in [select.KQ_FILTER_READ, select.KQ_FILTER_WRITE]:
key = (fd, filter)
receiver = self._registered.get(key)
if receiver is None:
continue
if type(receiver) is _core.Task:
event = select.kevent(fd, filter, select.KQ_EV_DELETE)
self._kqueue.control([event], 0)
exc = _core.ClosedResourceError("another task closed this fd")
_core.reschedule(receiver, outcome.Error(exc))
del self._registered[key]
else:
# XX this is an interesting example of a case where being able
# to close a queue would be useful...
raise NotImplementedError(
"can't close an fd that monitor_kevent is using"
)

View File

@@ -0,0 +1,868 @@
import itertools
from contextlib import contextmanager
import enum
import socket
import sys
from typing import TYPE_CHECKING
import attr
from outcome import Value
from .. import _core
from ._run import _public
from ._io_common import wake_all
from ._windows_cffi import (
ffi,
kernel32,
ntdll,
ws2_32,
INVALID_HANDLE_VALUE,
raise_winerror,
_handle,
ErrorCodes,
FileFlags,
AFDPollFlags,
WSAIoctls,
CompletionModes,
IoControlCodes,
)
assert not TYPE_CHECKING or sys.platform == "win32"
# There's a lot to be said about the overall design of a Windows event
# loop. See
#
# https://github.com/python-trio/trio/issues/52
#
# for discussion. This now just has some lower-level notes:
#
# How IOCP fits together:
#
# The general model is that you call some function like ReadFile or WriteFile
# to tell the kernel that you want it to perform some operation, and the
# kernel goes off and does that in the background, then at some point later it
# sends you a notification that the operation is complete. There are some more
# exotic APIs that don't quite fit this pattern, but most APIs do.
#
# Each background operation is tracked using an OVERLAPPED struct, that
# uniquely identifies that particular operation.
#
# An "IOCP" (or "I/O completion port") is an object that lets the kernel send
# us these notifications -- basically it's just a kernel->userspace queue.
#
# Each IOCP notification is represented by an OVERLAPPED_ENTRY struct, which
# contains 3 fields:
# - The "completion key". This is an opaque integer that we pick, and use
# however is convenient.
# - pointer to the OVERLAPPED struct for the completed operation.
# - dwNumberOfBytesTransferred (an integer).
#
# And in addition, for regular I/O, the OVERLAPPED structure gets filled in
# with:
# - result code (named "Internal")
# - number of bytes transferred (named "InternalHigh"); usually redundant
# with dwNumberOfBytesTransferred.
#
# There are also some other entries in OVERLAPPED which only matter on input:
# - Offset and OffsetHigh which are inputs to {Read,Write}File and
# otherwise always zero
# - hEvent which is for if you aren't using IOCP; we always set it to zero.
#
# That describes the usual pattern for operations and the usual meaning of
# these struct fields, but really these are just some arbitrary chunks of
# bytes that get passed back and forth, so some operations like to overload
# them to mean something else.
#
# You can also directly queue an OVERLAPPED_ENTRY object to an IOCP by calling
# PostQueuedCompletionStatus. When you use this you get to set all the
# OVERLAPPED_ENTRY fields to arbitrary values.
#
# You can request to cancel any operation if you know which handle it was
# issued on + the OVERLAPPED struct that identifies it (via CancelIoEx). This
# request might fail because the operation has already completed, or it might
# be queued to happen in the background, so you only find out whether it
# succeeded or failed later, when we get back the notification for the
# operation being complete.
#
# There are three types of operations that we support:
#
# == Regular I/O operations on handles (e.g. files or named pipes) ==
#
# Implemented by: register_with_iocp, wait_overlapped
#
# To use these, you have to register the handle with your IOCP first. Once
# it's registered, any operations on that handle will automatically send
# completion events to that IOCP, with a completion key that you specify *when
# the handle is registered* (so you can't use different completion keys for
# different operations).
#
# We give these two dedicated completion keys: CKeys.WAIT_OVERLAPPED for
# regular operations, and CKeys.LATE_CANCEL that's used to make
# wait_overlapped cancellable even if the user forgot to call
# register_with_iocp. The problem here is that after we request the cancel,
# wait_overlapped keeps blocking until it sees the completion notification...
# but if the user forgot to register_with_iocp, then the completion will never
# come, so the cancellation will never resolve. To avoid this, whenever we try
# to cancel an I/O operation and the cancellation fails, we use
# PostQueuedCompletionStatus to send a CKeys.LATE_CANCEL notification. If this
# arrives before the real completion, we assume the user forgot to call
# register_with_iocp on their handle, and raise an error accordingly.
#
# == Socket state notifications ==
#
# Implemented by: wait_readable, wait_writable
#
# The public APIs that windows provides for this are all really awkward and
# don't integrate with IOCP. So we drop down to a lower level, and talk
# directly to the socket device driver in the kernel, which is called "AFD".
# Unfortunately, this is a totally undocumented internal API. Fortunately
# libuv also does this, so we can be pretty confident that MS won't break it
# on us, and there is a *little* bit of information out there if you go
# digging.
#
# Basically: we open a magic file that refers to the AFD driver, register the
# magic file with our IOCP, and then we can issue regular overlapped I/O
# operations on that handle. Specifically, the operation we use is called
# IOCTL_AFD_POLL, which lets us pass in a buffer describing which events we're
# interested in on a given socket (readable, writable, etc.). Later, when the
# operation completes, the kernel rewrites the buffer we passed in to record
# which events happened, and uses IOCP as normal to notify us that this
# operation has completed.
#
# Unfortunately, the Windows kernel seems to have bugs if you try to issue
# multiple simultaneous IOCTL_AFD_POLL operations on the same socket (see
# notes-to-self/afd-lab.py). So if a user calls wait_readable and
# wait_writable at the same time, we have to combine those into a single
# IOCTL_AFD_POLL. This means we can't just use the wait_overlapped machinery.
# Instead we have some dedicated code to handle these operations, and a
# dedicated completion key CKeys.AFD_POLL.
#
# Sources of information:
# - https://github.com/python-trio/trio/issues/52
# - Wepoll: https://github.com/piscisaureus/wepoll/
# - libuv: https://github.com/libuv/libuv/
# - ReactOS: https://github.com/reactos/reactos/
# - Ancient leaked copies of the Windows NT and Winsock source code:
# https://github.com/pustladi/Windows-2000/blob/661d000d50637ed6fab2329d30e31775046588a9/private/net/sockets/winsock2/wsp/msafd/select.c#L59-L655
# https://github.com/metoo10987/WinNT4/blob/f5c14e6b42c8f45c20fe88d14c61f9d6e0386b8e/private/ntos/afd/poll.c#L68-L707
# - The WSAEventSelect docs (this exposes a finer-grained set of events than
# select(), so if you squint you can treat it as a source of information on
# the fine-grained AFD poll types)
#
#
# == Everything else ==
#
# There are also some weirder APIs for interacting with IOCP. For example, the
# "Job" API lets you specify an IOCP handle and "completion key", and then in
# the future whenever certain events happen it sends uses IOCP to send a
# notification. These notifications don't correspond to any particular
# operation; they're just spontaneous messages you get. The
# "dwNumberOfBytesTransferred" field gets repurposed to carry an identifier
# for the message type (e.g. JOB_OBJECT_MSG_EXIT_PROCESS), and the
# "lpOverlapped" field gets repurposed to carry some arbitrary data that
# depends on the message type (e.g. the pid of the process that exited).
#
# To handle these, we have monitor_completion_key, where we hand out an
# unassigned completion key, let users set it up however they want, and then
# get any events that arrive on that key.
#
# (Note: monitor_completion_key is not documented or fully baked; expect it to
# change in the future.)
# Our completion keys
class CKeys(enum.IntEnum):
AFD_POLL = 0
WAIT_OVERLAPPED = 1
LATE_CANCEL = 2
FORCE_WAKEUP = 3
USER_DEFINED = 4 # and above
def _check(success):
if not success:
raise_winerror()
return success
def _get_underlying_socket(sock, *, which=WSAIoctls.SIO_BASE_HANDLE):
if hasattr(sock, "fileno"):
sock = sock.fileno()
base_ptr = ffi.new("HANDLE *")
out_size = ffi.new("DWORD *")
failed = ws2_32.WSAIoctl(
ffi.cast("SOCKET", sock),
which,
ffi.NULL,
0,
base_ptr,
ffi.sizeof("HANDLE"),
out_size,
ffi.NULL,
ffi.NULL,
)
if failed:
code = ws2_32.WSAGetLastError()
raise_winerror(code)
return base_ptr[0]
def _get_base_socket(sock):
# There is a development kit for LSPs called Komodia Redirector.
# It does some unusual (some might say evil) things like intercepting
# SIO_BASE_HANDLE (fails) and SIO_BSP_HANDLE_SELECT (returns the same
# socket) in a misguided attempt to prevent bypassing it. It's been used
# in malware including the infamous Lenovo Superfish incident from 2015,
# but unfortunately is also used in some legitimate products such as
# parental control tools and Astrill VPN. Komodia happens to not
# block SIO_BSP_HANDLE_POLL, so we'll try SIO_BASE_HANDLE and fall back
# to SIO_BSP_HANDLE_POLL if it doesn't work.
# References:
# - https://github.com/piscisaureus/wepoll/blob/0598a791bf9cbbf480793d778930fc635b044980/wepoll.c#L2223
# - https://github.com/tokio-rs/mio/issues/1314
while True:
try:
# If this is not a Komodia-intercepted socket, we can just use
# SIO_BASE_HANDLE.
return _get_underlying_socket(sock)
except OSError as ex:
if ex.winerror == ErrorCodes.ERROR_NOT_SOCKET:
# SIO_BASE_HANDLE might fail even without LSP intervention,
# if we get something that's not a socket.
raise
if hasattr(sock, "fileno"):
sock = sock.fileno()
sock = _handle(sock)
next_sock = _get_underlying_socket(
sock, which=WSAIoctls.SIO_BSP_HANDLE_POLL
)
if next_sock == sock:
# If BSP_HANDLE_POLL returns the same socket we already had,
# then there's no layering going on and we need to fail
# to prevent an infinite loop.
raise RuntimeError(
"Unexpected network configuration detected: "
"SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't "
"return a different socket. Please file a bug at "
"https://github.com/python-trio/trio/issues/new, "
"and include the output of running: "
"netsh winsock show catalog"
)
# Otherwise we've gotten at least one layer deeper, so
# loop back around to keep digging.
sock = next_sock
def _afd_helper_handle():
# The "AFD" driver is exposed at the NT path "\Device\Afd". We're using
# the Win32 CreateFile, though, so we have to pass a Win32 path. \\.\ is
# how Win32 refers to the NT \GLOBAL??\ directory, and GLOBALROOT is a
# symlink inside that directory that points to the root of the NT path
# system. So by sticking that in front of the NT path, we get a Win32
# path. Alternatively, we could use NtCreateFile directly, since it takes
# an NT path. But we already wrap CreateFileW so this was easier.
# References:
# https://blogs.msdn.microsoft.com/jeremykuhne/2016/05/02/dos-to-nt-a-paths-journey/
# https://stackoverflow.com/a/21704022
#
# I'm actually not sure what the \Trio part at the end of the path does.
# Wepoll uses \Device\Afd\Wepoll, so I just copied them. (I'm guessing it
# might be visible in some debug tools, and is otherwise arbitrary?)
rawname = r"\\.\GLOBALROOT\Device\Afd\Trio".encode("utf-16le") + b"\0\0"
rawname_buf = ffi.from_buffer(rawname)
handle = kernel32.CreateFileW(
ffi.cast("LPCWSTR", rawname_buf),
FileFlags.SYNCHRONIZE,
FileFlags.FILE_SHARE_READ | FileFlags.FILE_SHARE_WRITE,
ffi.NULL, # no security attributes
FileFlags.OPEN_EXISTING,
FileFlags.FILE_FLAG_OVERLAPPED,
ffi.NULL, # no template file
)
if handle == INVALID_HANDLE_VALUE: # pragma: no cover
raise_winerror()
return handle
# AFD_POLL has a finer-grained set of events than other APIs. We collapse them
# down into Unix-style "readable" and "writable".
#
# Note: AFD_POLL_LOCAL_CLOSE isn't a reliable substitute for notify_closing(),
# because even if the user closes the socket *handle*, the socket *object*
# could still remain open, e.g. if the socket was dup'ed (possibly into
# another process). Explicitly calling notify_closing() guarantees that
# everyone waiting on the *handle* wakes up, which is what you'd expect.
#
# However, we can't avoid getting LOCAL_CLOSE notifications -- the kernel
# delivers them whether we ask for them or not -- so better to include them
# here for documentation, and so that when we check (delivered & requested) we
# get a match.
READABLE_FLAGS = (
AFDPollFlags.AFD_POLL_RECEIVE
| AFDPollFlags.AFD_POLL_ACCEPT
| AFDPollFlags.AFD_POLL_DISCONNECT # other side sent an EOF
| AFDPollFlags.AFD_POLL_ABORT
| AFDPollFlags.AFD_POLL_LOCAL_CLOSE
)
WRITABLE_FLAGS = (
AFDPollFlags.AFD_POLL_SEND
| AFDPollFlags.AFD_POLL_CONNECT_FAIL
| AFDPollFlags.AFD_POLL_ABORT
| AFDPollFlags.AFD_POLL_LOCAL_CLOSE
)
# Annoyingly, while the API makes it *seem* like you can happily issue as many
# independent AFD_POLL operations as you want without them interfering with
# each other, in fact if you issue two AFD_POLL operations for the same socket
# at the same time with notification going to the same IOCP port, then Windows
# gets super confused. For example, if we issue one operation from
# wait_readable, and another independent operation from wait_writable, then
# Windows may complete the wait_writable operation when the socket becomes
# readable.
#
# To avoid this, we have to coalesce all the operations on a single socket
# into one, and when the set of waiters changes we have to throw away the old
# operation and start a new one.
@attr.s(slots=True, eq=False)
class AFDWaiters:
read_task = attr.ib(default=None)
write_task = attr.ib(default=None)
current_op = attr.ib(default=None)
# We also need to bundle up all the info for a single op into a standalone
# object, because we need to keep all these objects alive until the operation
# finishes, even if we're throwing it away.
@attr.s(slots=True, eq=False, frozen=True)
class AFDPollOp:
lpOverlapped = attr.ib()
poll_info = attr.ib()
waiters = attr.ib()
afd_group = attr.ib()
# The Windows kernel has a weird issue when using AFD handles. If you have N
# instances of wait_readable/wait_writable registered with a single AFD handle,
# then cancelling any one of them takes something like O(N**2) time. So if we
# used just a single AFD handle, then cancellation would quickly become very
# expensive, e.g. a program with N active sockets would take something like
# O(N**3) time to unwind after control-C. The solution is to spread our sockets
# out over multiple AFD handles, so that N doesn't grow too large for any
# individual handle.
MAX_AFD_GROUP_SIZE = 500 # at 1000, the cubic scaling is just starting to bite
@attr.s(slots=True, eq=False)
class AFDGroup:
size = attr.ib()
handle = attr.ib()
@attr.s(slots=True, eq=False, frozen=True)
class _WindowsStatistics:
tasks_waiting_read = attr.ib()
tasks_waiting_write = attr.ib()
tasks_waiting_overlapped = attr.ib()
completion_key_monitors = attr.ib()
backend = attr.ib(default="windows")
# Maximum number of events to dequeue from the completion port on each pass
# through the run loop. Somewhat arbitrary. Should be large enough to collect
# a good set of tasks on each loop, but not so large to waste tons of memory.
# (Each WindowsIOManager holds a buffer whose size is ~32x this number.)
MAX_EVENTS = 1000
@attr.s(frozen=True)
class CompletionKeyEventInfo:
lpOverlapped = attr.ib()
dwNumberOfBytesTransferred = attr.ib()
class WindowsIOManager:
def __init__(self):
# If this method raises an exception, then __del__ could run on a
# half-initialized object. So we initialize everything that __del__
# touches to safe values up front, before we do anything that can
# fail.
self._iocp = None
self._all_afd_handles = []
self._iocp = _check(
kernel32.CreateIoCompletionPort(INVALID_HANDLE_VALUE, ffi.NULL, 0, 0)
)
self._events = ffi.new("OVERLAPPED_ENTRY[]", MAX_EVENTS)
self._vacant_afd_groups = set()
# {lpOverlapped: AFDPollOp}
self._afd_ops = {}
# {socket handle: AFDWaiters}
self._afd_waiters = {}
# {lpOverlapped: task}
self._overlapped_waiters = {}
self._posted_too_late_to_cancel = set()
self._completion_key_queues = {}
self._completion_key_counter = itertools.count(CKeys.USER_DEFINED)
with socket.socket() as s:
# We assume we're not working with any LSP that changes
# how select() is supposed to work. Validate this by
# ensuring that the result of SIO_BSP_HANDLE_SELECT (the
# LSP-hookable mechanism for "what should I use for
# select()?") matches that of SIO_BASE_HANDLE ("what is
# the real non-hooked underlying socket here?").
#
# This doesn't work for Komodia-based LSPs; see the comments
# in _get_base_socket() for details. But we have special
# logic for those, so we just skip this check if
# SIO_BASE_HANDLE fails.
# LSPs can in theory override this, but we believe that it never
# actually happens in the wild (except Komodia)
select_handle = _get_underlying_socket(
s, which=WSAIoctls.SIO_BSP_HANDLE_SELECT
)
try:
# LSPs shouldn't override this...
base_handle = _get_underlying_socket(s, which=WSAIoctls.SIO_BASE_HANDLE)
except OSError:
# But Komodia-based LSPs do anyway, in a way that causes
# a failure with WSAEFAULT. We have special handling for
# them in _get_base_socket(). Make sure it works.
_get_base_socket(s)
else:
if base_handle != select_handle:
raise RuntimeError(
"Unexpected network configuration detected: "
"SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ. "
"Please file a bug at "
"https://github.com/python-trio/trio/issues/new, "
"and include the output of running: "
"netsh winsock show catalog"
)
def close(self):
try:
if self._iocp is not None:
iocp = self._iocp
self._iocp = None
_check(kernel32.CloseHandle(iocp))
finally:
while self._all_afd_handles:
afd_handle = self._all_afd_handles.pop()
_check(kernel32.CloseHandle(afd_handle))
def __del__(self):
self.close()
def statistics(self):
tasks_waiting_read = 0
tasks_waiting_write = 0
for waiter in self._afd_waiters.values():
if waiter.read_task is not None:
tasks_waiting_read += 1
if waiter.write_task is not None:
tasks_waiting_write += 1
return _WindowsStatistics(
tasks_waiting_read=tasks_waiting_read,
tasks_waiting_write=tasks_waiting_write,
tasks_waiting_overlapped=len(self._overlapped_waiters),
completion_key_monitors=len(self._completion_key_queues),
)
def force_wakeup(self):
_check(
kernel32.PostQueuedCompletionStatus(
self._iocp, 0, CKeys.FORCE_WAKEUP, ffi.NULL
)
)
def get_events(self, timeout):
received = ffi.new("PULONG")
milliseconds = round(1000 * timeout)
if timeout > 0 and milliseconds == 0:
milliseconds = 1
try:
_check(
kernel32.GetQueuedCompletionStatusEx(
self._iocp, self._events, MAX_EVENTS, received, milliseconds, 0
)
)
except OSError as exc:
if exc.winerror != ErrorCodes.WAIT_TIMEOUT: # pragma: no cover
raise
return 0
return received[0]
def process_events(self, received):
for i in range(received):
entry = self._events[i]
if entry.lpCompletionKey == CKeys.AFD_POLL:
lpo = entry.lpOverlapped
op = self._afd_ops.pop(lpo)
waiters = op.waiters
if waiters.current_op is not op:
# Stale op, nothing to do
pass
else:
waiters.current_op = None
# I don't think this can happen, so if it does let's crash
# and get a debug trace.
if lpo.Internal != 0: # pragma: no cover
code = ntdll.RtlNtStatusToDosError(lpo.Internal)
raise_winerror(code)
flags = op.poll_info.Handles[0].Events
if waiters.read_task and flags & READABLE_FLAGS:
_core.reschedule(waiters.read_task)
waiters.read_task = None
if waiters.write_task and flags & WRITABLE_FLAGS:
_core.reschedule(waiters.write_task)
waiters.write_task = None
self._refresh_afd(op.poll_info.Handles[0].Handle)
elif entry.lpCompletionKey == CKeys.WAIT_OVERLAPPED:
# Regular I/O event, dispatch on lpOverlapped
waiter = self._overlapped_waiters.pop(entry.lpOverlapped)
overlapped = entry.lpOverlapped
transferred = entry.dwNumberOfBytesTransferred
info = CompletionKeyEventInfo(
lpOverlapped=overlapped, dwNumberOfBytesTransferred=transferred
)
_core.reschedule(waiter, Value(info))
elif entry.lpCompletionKey == CKeys.LATE_CANCEL:
# Post made by a regular I/O event's abort_fn
# after it failed to cancel the I/O. If we still
# have a waiter with this lpOverlapped, we didn't
# get the regular I/O completion and almost
# certainly the user forgot to call
# register_with_iocp.
self._posted_too_late_to_cancel.remove(entry.lpOverlapped)
try:
waiter = self._overlapped_waiters.pop(entry.lpOverlapped)
except KeyError:
# Looks like the actual completion got here before this
# fallback post did -- we're in the "expected" case of
# too-late-to-cancel, where the user did nothing wrong.
# Nothing more to do.
pass
else:
exc = _core.TrioInternalError(
"Failed to cancel overlapped I/O in {} and didn't "
"receive the completion either. Did you forget to "
"call register_with_iocp()?".format(waiter.name)
)
# Raising this out of handle_io ensures that
# the user will see our message even if some
# other task is in an uncancellable wait due
# to the same underlying forgot-to-register
# issue (if their CancelIoEx succeeds, we
# have no way of noticing that their completion
# won't arrive). Unfortunately it loses the
# task traceback. If you're debugging this
# error and can't tell where it's coming from,
# try changing this line to
# _core.reschedule(waiter, outcome.Error(exc))
raise exc
elif entry.lpCompletionKey == CKeys.FORCE_WAKEUP:
pass
else:
# dispatch on lpCompletionKey
queue = self._completion_key_queues[entry.lpCompletionKey]
overlapped = int(ffi.cast("uintptr_t", entry.lpOverlapped))
transferred = entry.dwNumberOfBytesTransferred
info = CompletionKeyEventInfo(
lpOverlapped=overlapped, dwNumberOfBytesTransferred=transferred
)
queue.put_nowait(info)
def _register_with_iocp(self, handle, completion_key):
handle = _handle(handle)
_check(kernel32.CreateIoCompletionPort(handle, self._iocp, completion_key, 0))
# Supposedly this makes things slightly faster, by disabling the
# ability to do WaitForSingleObject(handle). We would never want to do
# that anyway, so might as well get the extra speed (if any).
# Ref: http://www.lenholgate.com/blog/2009/09/interesting-blog-posts-on-high-performance-servers.html
_check(
kernel32.SetFileCompletionNotificationModes(
handle, CompletionModes.FILE_SKIP_SET_EVENT_ON_HANDLE
)
)
################################################################
# AFD stuff
################################################################
def _refresh_afd(self, base_handle):
waiters = self._afd_waiters[base_handle]
if waiters.current_op is not None:
afd_group = waiters.current_op.afd_group
try:
_check(
kernel32.CancelIoEx(
afd_group.handle, waiters.current_op.lpOverlapped
)
)
except OSError as exc:
if exc.winerror != ErrorCodes.ERROR_NOT_FOUND:
# I don't think this is possible, so if it happens let's
# crash noisily.
raise # pragma: no cover
waiters.current_op = None
afd_group.size -= 1
self._vacant_afd_groups.add(afd_group)
flags = 0
if waiters.read_task is not None:
flags |= READABLE_FLAGS
if waiters.write_task is not None:
flags |= WRITABLE_FLAGS
if not flags:
del self._afd_waiters[base_handle]
else:
try:
afd_group = self._vacant_afd_groups.pop()
except KeyError:
afd_group = AFDGroup(0, _afd_helper_handle())
self._register_with_iocp(afd_group.handle, CKeys.AFD_POLL)
self._all_afd_handles.append(afd_group.handle)
self._vacant_afd_groups.add(afd_group)
lpOverlapped = ffi.new("LPOVERLAPPED")
poll_info = ffi.new("AFD_POLL_INFO *")
poll_info.Timeout = 2**63 - 1 # INT64_MAX
poll_info.NumberOfHandles = 1
poll_info.Exclusive = 0
poll_info.Handles[0].Handle = base_handle
poll_info.Handles[0].Status = 0
poll_info.Handles[0].Events = flags
try:
_check(
kernel32.DeviceIoControl(
afd_group.handle,
IoControlCodes.IOCTL_AFD_POLL,
poll_info,
ffi.sizeof("AFD_POLL_INFO"),
poll_info,
ffi.sizeof("AFD_POLL_INFO"),
ffi.NULL,
lpOverlapped,
)
)
except OSError as exc:
if exc.winerror != ErrorCodes.ERROR_IO_PENDING:
# This could happen if the socket handle got closed behind
# our back while a wait_* call was pending, and we tried
# to re-issue the call. Clear our state and wake up any
# pending calls.
del self._afd_waiters[base_handle]
# Do this last, because it could raise.
wake_all(waiters, exc)
return
op = AFDPollOp(lpOverlapped, poll_info, waiters, afd_group)
waiters.current_op = op
self._afd_ops[lpOverlapped] = op
afd_group.size += 1
if afd_group.size >= MAX_AFD_GROUP_SIZE:
self._vacant_afd_groups.remove(afd_group)
async def _afd_poll(self, sock, mode):
base_handle = _get_base_socket(sock)
waiters = self._afd_waiters.get(base_handle)
if waiters is None:
waiters = AFDWaiters()
self._afd_waiters[base_handle] = waiters
if getattr(waiters, mode) is not None:
raise _core.BusyResourceError
setattr(waiters, mode, _core.current_task())
# Could potentially raise if the handle is somehow invalid; that's OK,
# we let it escape.
self._refresh_afd(base_handle)
def abort_fn(_):
setattr(waiters, mode, None)
self._refresh_afd(base_handle)
return _core.Abort.SUCCEEDED
await _core.wait_task_rescheduled(abort_fn)
@_public
async def wait_readable(self, sock):
await self._afd_poll(sock, "read_task")
@_public
async def wait_writable(self, sock):
await self._afd_poll(sock, "write_task")
@_public
def notify_closing(self, handle):
handle = _get_base_socket(handle)
waiters = self._afd_waiters.get(handle)
if waiters is not None:
wake_all(waiters, _core.ClosedResourceError())
self._refresh_afd(handle)
################################################################
# Regular overlapped operations
################################################################
@_public
def register_with_iocp(self, handle):
self._register_with_iocp(handle, CKeys.WAIT_OVERLAPPED)
@_public
async def wait_overlapped(self, handle, lpOverlapped):
handle = _handle(handle)
if isinstance(lpOverlapped, int):
lpOverlapped = ffi.cast("LPOVERLAPPED", lpOverlapped)
if lpOverlapped in self._overlapped_waiters:
raise _core.BusyResourceError(
"another task is already waiting on that lpOverlapped"
)
task = _core.current_task()
self._overlapped_waiters[lpOverlapped] = task
raise_cancel = None
def abort(raise_cancel_):
nonlocal raise_cancel
raise_cancel = raise_cancel_
try:
_check(kernel32.CancelIoEx(handle, lpOverlapped))
except OSError as exc:
if exc.winerror == ErrorCodes.ERROR_NOT_FOUND:
# Too late to cancel. If this happens because the
# operation is already completed, we don't need to do
# anything; we'll get a notification of that completion
# soon. But another possibility is that the operation was
# performed on a handle that wasn't registered with our
# IOCP (ie, the user forgot to call register_with_iocp),
# in which case we're just never going to see the
# completion. To avoid an uncancellable infinite sleep in
# the latter case, we'll PostQueuedCompletionStatus here,
# and if our post arrives before the original completion
# does, we'll assume the handle wasn't registered.
_check(
kernel32.PostQueuedCompletionStatus(
self._iocp, 0, CKeys.LATE_CANCEL, lpOverlapped
)
)
# Keep the lpOverlapped referenced so its address
# doesn't get reused until our posted completion
# status has been processed. Otherwise, we can
# get confused about which completion goes with
# which I/O.
self._posted_too_late_to_cancel.add(lpOverlapped)
else: # pragma: no cover
raise _core.TrioInternalError(
"CancelIoEx failed with unexpected error"
) from exc
return _core.Abort.FAILED
info = await _core.wait_task_rescheduled(abort)
if lpOverlapped.Internal != 0:
# the lpOverlapped reports the error as an NT status code,
# which we must convert back to a Win32 error code before
# it will produce the right sorts of exceptions
code = ntdll.RtlNtStatusToDosError(lpOverlapped.Internal)
if code == ErrorCodes.ERROR_OPERATION_ABORTED:
if raise_cancel is not None:
raise_cancel()
else:
# We didn't request this cancellation, so assume
# it happened due to the underlying handle being
# closed before the operation could complete.
raise _core.ClosedResourceError("another task closed this resource")
else:
raise_winerror(code)
return info
async def _perform_overlapped(self, handle, submit_fn):
# submit_fn(lpOverlapped) submits some I/O
# it may raise an OSError with ERROR_IO_PENDING
# the handle must already be registered using
# register_with_iocp(handle)
# This always does a schedule point, but it's possible that the
# operation will not be cancellable, depending on how Windows is
# feeling today. So we need to check for cancellation manually.
await _core.checkpoint_if_cancelled()
lpOverlapped = ffi.new("LPOVERLAPPED")
try:
submit_fn(lpOverlapped)
except OSError as exc:
if exc.winerror != ErrorCodes.ERROR_IO_PENDING:
raise
await self.wait_overlapped(handle, lpOverlapped)
return lpOverlapped
@_public
async def write_overlapped(self, handle, data, file_offset=0):
with ffi.from_buffer(data) as cbuf:
def submit_write(lpOverlapped):
# yes, these are the real documented names
offset_fields = lpOverlapped.DUMMYUNIONNAME.DUMMYSTRUCTNAME
offset_fields.Offset = file_offset & 0xFFFFFFFF
offset_fields.OffsetHigh = file_offset >> 32
_check(
kernel32.WriteFile(
_handle(handle),
ffi.cast("LPCVOID", cbuf),
len(cbuf),
ffi.NULL,
lpOverlapped,
)
)
lpOverlapped = await self._perform_overlapped(handle, submit_write)
# this is "number of bytes transferred"
return lpOverlapped.InternalHigh
@_public
async def readinto_overlapped(self, handle, buffer, file_offset=0):
with ffi.from_buffer(buffer, require_writable=True) as cbuf:
def submit_read(lpOverlapped):
offset_fields = lpOverlapped.DUMMYUNIONNAME.DUMMYSTRUCTNAME
offset_fields.Offset = file_offset & 0xFFFFFFFF
offset_fields.OffsetHigh = file_offset >> 32
_check(
kernel32.ReadFile(
_handle(handle),
ffi.cast("LPVOID", cbuf),
len(cbuf),
ffi.NULL,
lpOverlapped,
)
)
lpOverlapped = await self._perform_overlapped(handle, submit_read)
return lpOverlapped.InternalHigh
################################################################
# Raw IOCP operations
################################################################
@_public
def current_iocp(self):
return int(ffi.cast("uintptr_t", self._iocp))
@contextmanager
@_public
def monitor_completion_key(self):
key = next(self._completion_key_counter)
queue = _core.UnboundedQueue()
self._completion_key_queues[key] = queue
try:
yield (key, queue)
finally:
del self._completion_key_queues[key]

View File

@@ -0,0 +1,200 @@
import inspect
import signal
import sys
from functools import wraps
import attr
import async_generator
from .._util import is_main_thread
if False:
from typing import Any, TypeVar, Callable
F = TypeVar("F", bound=Callable[..., Any])
# In ordinary single-threaded Python code, when you hit control-C, it raises
# an exception and automatically does all the regular unwinding stuff.
#
# In Trio code, we would like hitting control-C to raise an exception and
# automatically do all the regular unwinding stuff. In particular, we would
# like to maintain our invariant that all tasks always run to completion (one
# way or another), by unwinding all of them.
#
# But it's basically impossible to write the core task running code in such a
# way that it can maintain this invariant in the face of KeyboardInterrupt
# exceptions arising at arbitrary bytecode positions. Similarly, if a
# KeyboardInterrupt happened at the wrong moment inside pretty much any of our
# inter-task synchronization or I/O primitives, then the system state could
# get corrupted and prevent our being able to clean up properly.
#
# So, we need a way to defer KeyboardInterrupt processing from these critical
# sections.
#
# Things that don't work:
#
# - Listen for SIGINT and process it in a system task: works fine for
# well-behaved programs that regularly pass through the event loop, but if
# user-code goes into an infinite loop then it can't be interrupted. Which
# is unfortunate, since dealing with infinite loops is what
# KeyboardInterrupt is for!
#
# - Use pthread_sigmask to disable signal delivery during critical section:
# (a) windows has no pthread_sigmask, (b) python threads start with all
# signals unblocked, so if there are any threads around they'll receive the
# signal and then tell the main thread to run the handler, even if the main
# thread has that signal blocked.
#
# - Install a signal handler which checks a global variable to decide whether
# to raise the exception immediately (if we're in a non-critical section),
# or to schedule it on the event loop (if we're in a critical section). The
# problem here is that it's impossible to transition safely out of user code:
#
# with keyboard_interrupt_enabled:
# msg = coro.send(value)
#
# If this raises a KeyboardInterrupt, it might be because the coroutine got
# interrupted and has unwound... or it might be the KeyboardInterrupt
# arrived just *after* 'send' returned, so the coroutine is still running
# but we just lost the message it sent. (And worse, in our actual task
# runner, the send is hidden inside a utility function etc.)
#
# Solution:
#
# Mark *stack frames* as being interrupt-safe or interrupt-unsafe, and from
# the signal handler check which kind of frame we're currently in when
# deciding whether to raise or schedule the exception.
#
# There are still some cases where this can fail, like if someone hits
# control-C while the process is in the event loop, and then it immediately
# enters an infinite loop in user code. In this case the user has to hit
# control-C a second time. And of course if the user code is written so that
# it doesn't actually exit after a task crashes and everything gets cancelled,
# then there's not much to be done. (Hitting control-C repeatedly might help,
# but in general the solution is to kill the process some other way, just like
# for any Python program that's written to catch and ignore
# KeyboardInterrupt.)
# We use this special string as a unique key into the frame locals dictionary.
# The @ ensures it is not a valid identifier and can't clash with any possible
# real local name. See: https://github.com/python-trio/trio/issues/469
LOCALS_KEY_KI_PROTECTION_ENABLED = "@TRIO_KI_PROTECTION_ENABLED"
# NB: according to the signal.signal docs, 'frame' can be None on entry to
# this function:
def ki_protection_enabled(frame):
while frame is not None:
if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals:
return frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED]
if frame.f_code.co_name == "__del__":
return True
frame = frame.f_back
return True
def currently_ki_protected():
r"""Check whether the calling code has :exc:`KeyboardInterrupt` protection
enabled.
It's surprisingly easy to think that one's :exc:`KeyboardInterrupt`
protection is enabled when it isn't, or vice-versa. This function tells
you what Trio thinks of the matter, which makes it useful for ``assert``\s
and unit tests.
Returns:
bool: True if protection is enabled, and False otherwise.
"""
return ki_protection_enabled(sys._getframe())
def _ki_protection_decorator(enabled):
def decorator(fn):
# In some version of Python, isgeneratorfunction returns true for
# coroutine functions, so we have to check for coroutine functions
# first.
if inspect.iscoroutinefunction(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
# See the comment for regular generators below
coro = fn(*args, **kwargs)
coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return coro
return wrapper
elif inspect.isgeneratorfunction(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
# It's important that we inject this directly into the
# generator's locals, as opposed to setting it here and then
# doing 'yield from'. The reason is, if a generator is
# throw()n into, then it may magically pop to the top of the
# stack. And @contextmanager generators in particular are a
# case where we often want KI protection, and which are often
# thrown into! See:
# https://bugs.python.org/issue29590
gen = fn(*args, **kwargs)
gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return gen
return wrapper
elif async_generator.isasyncgenfunction(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
# See the comment for regular generators above
agen = fn(*args, **kwargs)
agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return agen
return wrapper
else:
@wraps(fn)
def wrapper(*args, **kwargs):
locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return fn(*args, **kwargs)
return wrapper
return decorator
enable_ki_protection = _ki_protection_decorator(True) # type: Callable[[F], F]
enable_ki_protection.__name__ = "enable_ki_protection"
disable_ki_protection = _ki_protection_decorator(False) # type: Callable[[F], F]
disable_ki_protection.__name__ = "disable_ki_protection"
@attr.s
class KIManager:
handler = attr.ib(default=None)
def install(self, deliver_cb, restrict_keyboard_interrupt_to_checkpoints):
assert self.handler is None
if (
not is_main_thread()
or signal.getsignal(signal.SIGINT) != signal.default_int_handler
):
return
def handler(signum, frame):
assert signum == signal.SIGINT
protection_enabled = ki_protection_enabled(frame)
if protection_enabled or restrict_keyboard_interrupt_to_checkpoints:
deliver_cb()
else:
raise KeyboardInterrupt
self.handler = handler
signal.signal(signal.SIGINT, handler)
def close(self):
if self.handler is not None:
if signal.getsignal(signal.SIGINT) is self.handler:
signal.signal(signal.SIGINT, signal.default_int_handler)
self.handler = None

View File

@@ -0,0 +1,95 @@
# Runvar implementations
import attr
from . import _run
from .._util import Final
@attr.s(eq=False, hash=False, slots=True)
class _RunVarToken:
_no_value = object()
_var = attr.ib()
previous_value = attr.ib(default=_no_value)
redeemed = attr.ib(default=False, init=False)
@classmethod
def empty(cls, var):
return cls(var)
@attr.s(eq=False, hash=False, slots=True)
class RunVar(metaclass=Final):
"""The run-local variant of a context variable.
:class:`RunVar` objects are similar to context variable objects,
except that they are shared across a single call to :func:`trio.run`
rather than a single task.
"""
_NO_DEFAULT = object()
_name = attr.ib()
_default = attr.ib(default=_NO_DEFAULT)
def get(self, default=_NO_DEFAULT):
"""Gets the value of this :class:`RunVar` for the current run call."""
try:
return _run.GLOBAL_RUN_CONTEXT.runner._locals[self]
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context") from None
except KeyError:
# contextvars consistency
if default is not self._NO_DEFAULT:
return default
if self._default is not self._NO_DEFAULT:
return self._default
raise LookupError(self) from None
def set(self, value):
"""Sets the value of this :class:`RunVar` for this current run
call.
"""
try:
old_value = self.get()
except LookupError:
token = _RunVarToken.empty(self)
else:
token = _RunVarToken(self, old_value)
# This can't fail, because if we weren't in Trio context then the
# get() above would have failed.
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value
return token
def reset(self, token):
"""Resets the value of this :class:`RunVar` to what it was
previously specified by the token.
"""
if token is None:
raise TypeError("token must not be none")
if token.redeemed:
raise ValueError("token has already been used")
if token._var is not self:
raise ValueError("token is not for us")
previous = token.previous_value
try:
if previous is _RunVarToken._no_value:
_run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self)
else:
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context")
token.redeemed = True
def __repr__(self):
return "<RunVar name={!r}>".format(self._name)

View File

@@ -0,0 +1,165 @@
import time
from math import inf
from .. import _core
from ._run import GLOBAL_RUN_CONTEXT
from .._abc import Clock
from .._util import Final
################################################################
# The glorious MockClock
################################################################
# Prior art:
# https://twistedmatrix.com/documents/current/api/twisted.internet.task.Clock.html
# https://github.com/ztellman/manifold/issues/57
class MockClock(Clock, metaclass=Final):
"""A user-controllable clock suitable for writing tests.
Args:
rate (float): the initial :attr:`rate`.
autojump_threshold (float): the initial :attr:`autojump_threshold`.
.. attribute:: rate
How many seconds of clock time pass per second of real time. Default is
0.0, i.e. the clock only advances through manuals calls to :meth:`jump`
or when the :attr:`autojump_threshold` is triggered. You can assign to
this attribute to change it.
.. attribute:: autojump_threshold
The clock keeps an eye on the run loop, and if at any point it detects
that all tasks have been blocked for this many real seconds (i.e.,
according to the actual clock, not this clock), then the clock
automatically jumps ahead to the run loop's next scheduled
timeout. Default is :data:`math.inf`, i.e., to never autojump. You can
assign to this attribute to change it.
Basically the idea is that if you have code or tests that use sleeps
and timeouts, you can use this to make it run much faster, totally
automatically. (At least, as long as those sleeps/timeouts are
happening inside Trio; if your test involves talking to external
service and waiting for it to timeout then obviously we can't help you
there.)
You should set this to the smallest value that lets you reliably avoid
"false alarms" where some I/O is in flight (e.g. between two halves of
a socketpair) but the threshold gets triggered and time gets advanced
anyway. This will depend on the details of your tests and test
environment. If you aren't doing any I/O (like in our sleeping example
above) then just set it to zero, and the clock will jump whenever all
tasks are blocked.
.. note:: If you use ``autojump_threshold`` and
`wait_all_tasks_blocked` at the same time, then you might wonder how
they interact, since they both cause things to happen after the run
loop goes idle for some time. The answer is:
`wait_all_tasks_blocked` takes priority. If there's a task blocked
in `wait_all_tasks_blocked`, then the autojump feature treats that
as active task and does *not* jump the clock.
"""
def __init__(self, rate=0.0, autojump_threshold=inf):
# when the real clock said 'real_base', the virtual time was
# 'virtual_base', and since then it's advanced at 'rate' virtual
# seconds per real second.
self._real_base = 0.0
self._virtual_base = 0.0
self._rate = 0.0
self._autojump_threshold = 0.0
# kept as an attribute so that our tests can monkeypatch it
self._real_clock = time.perf_counter
# use the property update logic to set initial values
self.rate = rate
self.autojump_threshold = autojump_threshold
def __repr__(self):
return "<MockClock, time={:.7f}, rate={} @ {:#x}>".format(
self.current_time(), self._rate, id(self)
)
@property
def rate(self):
return self._rate
@rate.setter
def rate(self, new_rate):
if new_rate < 0:
raise ValueError("rate must be >= 0")
else:
real = self._real_clock()
virtual = self._real_to_virtual(real)
self._virtual_base = virtual
self._real_base = real
self._rate = float(new_rate)
@property
def autojump_threshold(self):
return self._autojump_threshold
@autojump_threshold.setter
def autojump_threshold(self, new_autojump_threshold):
self._autojump_threshold = float(new_autojump_threshold)
self._try_resync_autojump_threshold()
# runner.clock_autojump_threshold is an internal API that isn't easily
# usable by custom third-party Clock objects. If you need access to this
# functionality, let us know, and we'll figure out how to make a public
# API. Discussion:
#
# https://github.com/python-trio/trio/issues/1587
def _try_resync_autojump_threshold(self):
try:
runner = GLOBAL_RUN_CONTEXT.runner
if runner.is_guest:
runner.force_guest_tick_asap()
except AttributeError:
pass
else:
runner.clock_autojump_threshold = self._autojump_threshold
# Invoked by the run loop when runner.clock_autojump_threshold is
# exceeded.
def _autojump(self):
statistics = _core.current_statistics()
jump = statistics.seconds_to_next_deadline
if 0 < jump < inf:
self.jump(jump)
def _real_to_virtual(self, real):
real_offset = real - self._real_base
virtual_offset = self._rate * real_offset
return self._virtual_base + virtual_offset
def start_clock(self):
self._try_resync_autojump_threshold()
def current_time(self):
return self._real_to_virtual(self._real_clock())
def deadline_to_sleep_time(self, deadline):
virtual_timeout = deadline - self.current_time()
if virtual_timeout <= 0:
return 0
elif self._rate > 0:
return virtual_timeout / self._rate
else:
return 999999999
def jump(self, seconds):
"""Manually advance the clock by the given number of seconds.
Args:
seconds (float): the number of seconds to jump the clock forward.
Raises:
ValueError: if you try to pass a negative value for ``seconds``.
"""
if seconds < 0:
raise ValueError("time can't go backwards")
self._virtual_base += seconds

View File

@@ -0,0 +1,516 @@
import sys
import traceback
import textwrap
import warnings
import attr
################################################################
# MultiError
################################################################
def _filter_impl(handler, root_exc):
# We have a tree of MultiError's, like:
#
# MultiError([
# ValueError,
# MultiError([
# KeyError,
# ValueError,
# ]),
# ])
#
# or similar.
#
# We want to
# 1) apply the filter to each of the leaf exceptions -- each leaf
# might stay the same, be replaced (with the original exception
# potentially sticking around as __context__ or __cause__), or
# disappear altogether.
# 2) simplify the resulting tree -- remove empty nodes, and replace
# singleton MultiError's with their contents, e.g.:
# MultiError([KeyError]) -> KeyError
# (This can happen recursively, e.g. if the two ValueErrors above
# get caught then we'll just be left with a bare KeyError.)
# 3) preserve sensible tracebacks
#
# It's the tracebacks that are most confusing. As a MultiError
# propagates through the stack, it accumulates traceback frames, but
# the exceptions inside it don't. Semantically, the traceback for a
# leaf exception is the concatenation the tracebacks of all the
# exceptions you see when traversing the exception tree from the root
# to that leaf. Our correctness invariant is that this concatenated
# traceback should be the same before and after.
#
# The easy way to do that would be to, at the beginning of this
# function, "push" all tracebacks down to the leafs, so all the
# MultiErrors have __traceback__=None, and all the leafs have complete
# tracebacks. But whenever possible, we'd actually prefer to keep
# tracebacks as high up in the tree as possible, because this lets us
# keep only a single copy of the common parts of these exception's
# tracebacks. This is cheaper (in memory + time -- tracebacks are
# unpleasantly quadratic-ish to work with, and this might matter if
# you have thousands of exceptions, which can happen e.g. after
# cancelling a large task pool, and no-one will ever look at their
# tracebacks!), and more importantly, factoring out redundant parts of
# the tracebacks makes them more readable if/when users do see them.
#
# So instead our strategy is:
# - first go through and construct the new tree, preserving any
# unchanged subtrees
# - then go through the original tree (!) and push tracebacks down
# until either we hit a leaf, or we hit a subtree which was
# preserved in the new tree.
# This used to also support async handler functions. But that runs into:
# https://bugs.python.org/issue29600
# which is difficult to fix on our end.
# Filters a subtree, ignoring tracebacks, while keeping a record of
# which MultiErrors were preserved unchanged
def filter_tree(exc, preserved):
if isinstance(exc, MultiError):
new_exceptions = []
changed = False
for child_exc in exc.exceptions:
new_child_exc = filter_tree(child_exc, preserved)
if new_child_exc is not child_exc:
changed = True
if new_child_exc is not None:
new_exceptions.append(new_child_exc)
if not new_exceptions:
return None
elif changed:
return MultiError(new_exceptions)
else:
preserved.add(id(exc))
return exc
else:
new_exc = handler(exc)
# Our version of implicit exception chaining
if new_exc is not None and new_exc is not exc:
new_exc.__context__ = exc
return new_exc
def push_tb_down(tb, exc, preserved):
if id(exc) in preserved:
return
new_tb = concat_tb(tb, exc.__traceback__)
if isinstance(exc, MultiError):
for child_exc in exc.exceptions:
push_tb_down(new_tb, child_exc, preserved)
exc.__traceback__ = None
else:
exc.__traceback__ = new_tb
preserved = set()
new_root_exc = filter_tree(root_exc, preserved)
push_tb_down(None, root_exc, preserved)
# Delete the local functions to avoid a reference cycle (see
# test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage)
del filter_tree, push_tb_down
return new_root_exc
# Normally I'm a big fan of (a)contextmanager, but in this case I found it
# easier to use the raw context manager protocol, because it makes it a lot
# easier to reason about how we're mutating the traceback as we go. (End
# result: if the exception gets modified, then the 'raise' here makes this
# frame show up in the traceback; otherwise, we leave no trace.)
@attr.s(frozen=True)
class MultiErrorCatcher:
_handler = attr.ib()
def __enter__(self):
pass
def __exit__(self, etype, exc, tb):
if exc is not None:
filtered_exc = MultiError.filter(self._handler, exc)
if filtered_exc is exc:
# Let the interpreter re-raise it
return False
if filtered_exc is None:
# Swallow the exception
return True
# When we raise filtered_exc, Python will unconditionally blow
# away its __context__ attribute and replace it with the original
# exc we caught. So after we raise it, we have to pause it while
# it's in flight to put the correct __context__ back.
old_context = filtered_exc.__context__
try:
raise filtered_exc
finally:
_, value, _ = sys.exc_info()
assert value is filtered_exc
value.__context__ = old_context
# delete references from locals to avoid creating cycles
# see test_MultiError_catch_doesnt_create_cyclic_garbage
del _, filtered_exc, value
class MultiError(BaseException):
"""An exception that contains other exceptions; also known as an
"inception".
It's main use is to represent the situation when multiple child tasks all
raise errors "in parallel".
Args:
exceptions (list): The exceptions
Returns:
If ``len(exceptions) == 1``, returns that exception. This means that a
call to ``MultiError(...)`` is not guaranteed to return a
:exc:`MultiError` object!
Otherwise, returns a new :exc:`MultiError` object.
Raises:
TypeError: if any of the passed in objects are not instances of
:exc:`BaseException`.
"""
def __init__(self, exceptions):
# Avoid recursion when exceptions[0] returned by __new__() happens
# to be a MultiError and subsequently __init__() is called.
if hasattr(self, "exceptions"):
# __init__ was already called on this object
assert len(exceptions) == 1 and exceptions[0] is self
return
self.exceptions = exceptions
def __new__(cls, exceptions):
exceptions = list(exceptions)
for exc in exceptions:
if not isinstance(exc, BaseException):
raise TypeError("Expected an exception object, not {!r}".format(exc))
if len(exceptions) == 1:
# If this lone object happens to itself be a MultiError, then
# Python will implicitly call our __init__ on it again. See
# special handling in __init__.
return exceptions[0]
else:
# The base class __new__() implicitly invokes our __init__, which
# is what we want.
#
# In an earlier version of the code, we didn't define __init__ and
# simply set the `exceptions` attribute directly on the new object.
# However, linters expect attributes to be initialized in __init__.
return BaseException.__new__(cls, exceptions)
def __str__(self):
return ", ".join(repr(exc) for exc in self.exceptions)
def __repr__(self):
return "<MultiError: {}>".format(self)
@classmethod
def filter(cls, handler, root_exc):
"""Apply the given ``handler`` to all the exceptions in ``root_exc``.
Args:
handler: A callable that takes an atomic (non-MultiError) exception
as input, and returns either a new exception object or None.
root_exc: An exception, often (though not necessarily) a
:exc:`MultiError`.
Returns:
A new exception object in which each component exception ``exc`` has
been replaced by the result of running ``handler(exc)`` or, if
``handler`` returned None for all the inputs, returns None.
"""
return _filter_impl(handler, root_exc)
@classmethod
def catch(cls, handler):
"""Return a context manager that catches and re-throws exceptions
after running :meth:`filter` on them.
Args:
handler: as for :meth:`filter`
"""
return MultiErrorCatcher(handler)
# Clean up exception printing:
MultiError.__module__ = "trio"
################################################################
# concat_tb
################################################################
# We need to compute a new traceback that is the concatenation of two existing
# tracebacks. This requires copying the entries in 'head' and then pointing
# the final tb_next to 'tail'.
#
# NB: 'tail' might be None, which requires some special handling in the ctypes
# version.
#
# The complication here is that Python doesn't actually support copying or
# modifying traceback objects, so we have to get creative...
#
# On CPython, we use ctypes. On PyPy, we use "transparent proxies".
#
# Jinja2 is a useful source of inspiration:
# https://github.com/pallets/jinja/blob/master/jinja2/debug.py
try:
import tputil
except ImportError:
have_tproxy = False
else:
have_tproxy = True
if have_tproxy:
# http://doc.pypy.org/en/latest/objspace-proxies.html
def copy_tb(base_tb, tb_next):
def controller(operation):
# Rationale for pragma: I looked fairly carefully and tried a few
# things, and AFAICT it's not actually possible to get any
# 'opname' that isn't __getattr__ or __getattribute__. So there's
# no missing test we could add, and no value in coverage nagging
# us about adding one.
if operation.opname in [
"__getattribute__",
"__getattr__",
]: # pragma: no cover
if operation.args[0] == "tb_next":
return tb_next
return operation.delegate()
return tputil.make_proxy(controller, type(base_tb), base_tb)
else:
# ctypes it is
import ctypes
# How to handle refcounting? I don't want to use ctypes.py_object because
# I don't understand or trust it, and I don't want to use
# ctypes.pythonapi.Py_{Inc,Dec}Ref because we might clash with user code
# that also tries to use them but with different types. So private _ctypes
# APIs it is!
import _ctypes
class CTraceback(ctypes.Structure):
_fields_ = [
("PyObject_HEAD", ctypes.c_byte * object().__sizeof__()),
("tb_next", ctypes.c_void_p),
("tb_frame", ctypes.c_void_p),
("tb_lasti", ctypes.c_int),
("tb_lineno", ctypes.c_int),
]
def copy_tb(base_tb, tb_next):
# TracebackType has no public constructor, so allocate one the hard way
try:
raise ValueError
except ValueError as exc:
new_tb = exc.__traceback__
c_new_tb = CTraceback.from_address(id(new_tb))
# At the C level, tb_next either pointer to the next traceback or is
# NULL. c_void_p and the .tb_next accessor both convert NULL to None,
# but we shouldn't DECREF None just because we assigned to a NULL
# pointer! Here we know that our new traceback has only 1 frame in it,
# so we can assume the tb_next field is NULL.
assert c_new_tb.tb_next is None
# If tb_next is None, then we want to set c_new_tb.tb_next to NULL,
# which it already is, so we're done. Otherwise, we have to actually
# do some work:
if tb_next is not None:
_ctypes.Py_INCREF(tb_next)
c_new_tb.tb_next = id(tb_next)
assert c_new_tb.tb_frame is not None
_ctypes.Py_INCREF(base_tb.tb_frame)
old_tb_frame = new_tb.tb_frame
c_new_tb.tb_frame = id(base_tb.tb_frame)
_ctypes.Py_DECREF(old_tb_frame)
c_new_tb.tb_lasti = base_tb.tb_lasti
c_new_tb.tb_lineno = base_tb.tb_lineno
try:
return new_tb
finally:
# delete references from locals to avoid creating cycles
# see test_MultiError_catch_doesnt_create_cyclic_garbage
del new_tb, old_tb_frame
def concat_tb(head, tail):
# We have to use an iterative algorithm here, because in the worst case
# this might be a RecursionError stack that is by definition too deep to
# process by recursion!
head_tbs = []
pointer = head
while pointer is not None:
head_tbs.append(pointer)
pointer = pointer.tb_next
current_head = tail
for head_tb in reversed(head_tbs):
current_head = copy_tb(head_tb, tb_next=current_head)
return current_head
################################################################
# MultiError traceback formatting
#
# What follows is terrible, terrible monkey patching of
# traceback.TracebackException to add support for handling
# MultiErrors
################################################################
traceback_exception_original_init = traceback.TracebackException.__init__
def traceback_exception_init(
self,
exc_type,
exc_value,
exc_traceback,
*,
limit=None,
lookup_lines=True,
capture_locals=False,
compact=False,
_seen=None,
):
if sys.version_info >= (3, 10):
kwargs = {"compact": compact}
else:
kwargs = {}
# Capture the original exception and its cause and context as TracebackExceptions
traceback_exception_original_init(
self,
exc_type,
exc_value,
exc_traceback,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
_seen=_seen,
**kwargs,
)
seen_was_none = _seen is None
if _seen is None:
_seen = set()
# Capture each of the exceptions in the MultiError along with each of their causes and contexts
if isinstance(exc_value, MultiError):
embedded = []
for exc in exc_value.exceptions:
if id(exc) not in _seen:
embedded.append(
traceback.TracebackException.from_exception(
exc,
limit=limit,
lookup_lines=lookup_lines,
capture_locals=capture_locals,
# copy the set of _seen exceptions so that duplicates
# shared between sub-exceptions are not omitted
_seen=None if seen_was_none else set(_seen),
)
)
self.embedded = embedded
else:
self.embedded = []
traceback.TracebackException.__init__ = traceback_exception_init # type: ignore
traceback_exception_original_format = traceback.TracebackException.format
def traceback_exception_format(self, *, chain=True):
yield from traceback_exception_original_format(self, chain=chain)
for i, exc in enumerate(self.embedded):
yield "\nDetails of embedded exception {}:\n\n".format(i + 1)
yield from (textwrap.indent(line, " " * 2) for line in exc.format(chain=chain))
traceback.TracebackException.format = traceback_exception_format # type: ignore
def trio_excepthook(etype, value, tb):
for chunk in traceback.format_exception(etype, value, tb):
sys.stderr.write(chunk)
monkeypatched_or_warned = False
if "IPython" in sys.modules:
import IPython
ip = IPython.get_ipython()
if ip is not None:
if ip.custom_exceptions != ():
warnings.warn(
"IPython detected, but you already have a custom exception "
"handler installed. I'll skip installing Trio's custom "
"handler, but this means MultiErrors will not show full "
"tracebacks.",
category=RuntimeWarning,
)
monkeypatched_or_warned = True
else:
def trio_show_traceback(self, etype, value, tb, tb_offset=None):
# XX it would be better to integrate with IPython's fancy
# exception formatting stuff (and not ignore tb_offset)
trio_excepthook(etype, value, tb)
ip.set_custom_exc((MultiError,), trio_show_traceback)
monkeypatched_or_warned = True
if sys.excepthook is sys.__excepthook__:
sys.excepthook = trio_excepthook
monkeypatched_or_warned = True
# Ubuntu's system Python has a sitecustomize.py file that import
# apport_python_hook and replaces sys.excepthook.
#
# The custom hook captures the error for crash reporting, and then calls
# sys.__excepthook__ to actually print the error.
#
# We don't mind it capturing the error for crash reporting, but we want to
# take over printing the error. So we monkeypatch the apport_python_hook
# module so that instead of calling sys.__excepthook__, it calls our custom
# hook.
#
# More details: https://github.com/python-trio/trio/issues/1065
if getattr(sys.excepthook, "__name__", None) == "apport_excepthook":
import apport_python_hook
assert sys.excepthook is apport_python_hook.apport_excepthook
# Give it a descriptive name as a hint for anyone who's stuck trying to
# debug this mess later.
class TrioFakeSysModuleForApport:
pass
fake_sys = TrioFakeSysModuleForApport()
fake_sys.__dict__.update(sys.__dict__)
fake_sys.__excepthook__ = trio_excepthook # type: ignore
apport_python_hook.sys = fake_sys
monkeypatched_or_warned = True
if not monkeypatched_or_warned:
warnings.warn(
"You seem to already have a custom sys.excepthook handler "
"installed. I'll skip installing Trio's custom handler, but this "
"means MultiErrors will not show full tracebacks.",
category=RuntimeWarning,
)

View File

@@ -0,0 +1,215 @@
# ParkingLot provides an abstraction for a fair waitqueue with cancellation
# and requeueing support. Inspiration:
#
# https://webkit.org/blog/6161/locking-in-webkit/
# https://amanieu.github.io/parking_lot/
#
# which were in turn heavily influenced by
#
# http://gee.cs.oswego.edu/dl/papers/aqs.pdf
#
# Compared to these, our use of cooperative scheduling allows some
# simplifications (no need for internal locking). On the other hand, the need
# to support Trio's strong cancellation semantics adds some complications
# (tasks need to know where they're queued so they can cancel). Also, in the
# above work, the ParkingLot is a global structure that holds a collection of
# waitqueues keyed by lock address, and which are opportunistically allocated
# and destroyed as contention arises; this allows the worst-case memory usage
# for all waitqueues to be O(#tasks). Here we allocate a separate wait queue
# for each synchronization object, so we're O(#objects + #tasks). This isn't
# *so* bad since compared to our synchronization objects are heavier than
# theirs and our tasks are lighter, so for us #objects is smaller and #tasks
# is larger.
#
# This is in the core because for two reasons. First, it's used by
# UnboundedQueue, and UnboundedQueue is used for a number of things in the
# core. And second, it's responsible for providing fairness to all of our
# high-level synchronization primitives (locks, queues, etc.). For now with
# our FIFO scheduler this is relatively trivial (it's just a FIFO waitqueue),
# but in the future we ever start support task priorities or fair scheduling
#
# https://github.com/python-trio/trio/issues/32
#
# then all we'll have to do is update this. (Well, full-fledged task
# priorities might also require priority inheritance, which would require more
# work.)
#
# For discussion of data structures to use here, see:
#
# https://github.com/dabeaz/curio/issues/136
#
# (and also the articles above). Currently we use a SortedDict ordered by a
# global monotonic counter that ensures FIFO ordering. The main advantage of
# this is that it's easy to implement :-). An intrusive doubly-linked list
# would also be a natural approach, so long as we only handle FIFO ordering.
#
# XX: should we switch to the shared global ParkingLot approach?
#
# XX: we should probably add support for "parking tokens" to allow for
# task-fair RWlock (basically: when parking a task needs to be able to mark
# itself as a reader or a writer, and then a task-fair wakeup policy is, wake
# the next task, and if it's a reader than keep waking tasks so long as they
# are readers). Without this I think you can implement write-biased or
# read-biased RWlocks (by using two parking lots and drawing from whichever is
# preferred), but not task-fair -- and task-fair plays much more nicely with
# WFQ. (Consider what happens in the two-lot implementation if you're
# write-biased but all the pending writers are blocked at the scheduler level
# by the WFQ logic...)
# ...alternatively, "phase-fair" RWlocks are pretty interesting:
# http://www.cs.unc.edu/~anderson/papers/ecrts09b.pdf
# Useful summary:
# https://docs.oracle.com/javase/7/docs/api/java/util/concurrent/locks/ReadWriteLock.html
#
# XX: if we do add WFQ, then we might have to drop the current feature where
# unpark returns the tasks that were unparked. Rationale: suppose that at the
# time we call unpark, the next task is deprioritized... and then, before it
# becomes runnable, a new task parks which *is* runnable. Ideally we should
# immediately wake the new task, and leave the old task on the queue for
# later. But this means we can't commit to which task we are unparking when
# unpark is called.
#
# See: https://github.com/python-trio/trio/issues/53
import attr
from collections import OrderedDict
from .. import _core
from .._util import Final
@attr.s(frozen=True, slots=True)
class _ParkingLotStatistics:
tasks_waiting = attr.ib()
@attr.s(eq=False, hash=False, slots=True)
class ParkingLot(metaclass=Final):
"""A fair wait queue with cancellation and requeueing.
This class encapsulates the tricky parts of implementing a wait
queue. It's useful for implementing higher-level synchronization
primitives like queues and locks.
In addition to the methods below, you can use ``len(parking_lot)`` to get
the number of parked tasks, and ``if parking_lot: ...`` to check whether
there are any parked tasks.
"""
# {task: None}, we just want a deque where we can quickly delete random
# items
_parked = attr.ib(factory=OrderedDict, init=False)
def __len__(self):
"""Returns the number of parked tasks."""
return len(self._parked)
def __bool__(self):
"""True if there are parked tasks, False otherwise."""
return bool(self._parked)
# XX this currently returns None
# if we ever add the ability to repark while one's resuming place in
# line (for false wakeups), then we could have it return a ticket that
# abstracts the "place in line" concept.
@_core.enable_ki_protection
async def park(self):
"""Park the current task until woken by a call to :meth:`unpark` or
:meth:`unpark_all`.
"""
task = _core.current_task()
self._parked[task] = None
task.custom_sleep_data = self
def abort_fn(_):
del task.custom_sleep_data._parked[task]
return _core.Abort.SUCCEEDED
await _core.wait_task_rescheduled(abort_fn)
def _pop_several(self, count):
for _ in range(min(count, len(self._parked))):
task, _ = self._parked.popitem(last=False)
yield task
@_core.enable_ki_protection
def unpark(self, *, count=1):
"""Unpark one or more tasks.
This wakes up ``count`` tasks that are blocked in :meth:`park`. If
there are fewer than ``count`` tasks parked, then wakes as many tasks
are available and then returns successfully.
Args:
count (int): the number of tasks to unpark.
"""
tasks = list(self._pop_several(count))
for task in tasks:
_core.reschedule(task)
return tasks
def unpark_all(self):
"""Unpark all parked tasks."""
return self.unpark(count=len(self))
@_core.enable_ki_protection
def repark(self, new_lot, *, count=1):
"""Move parked tasks from one :class:`ParkingLot` object to another.
This dequeues ``count`` tasks from one lot, and requeues them on
another, preserving order. For example::
async def parker(lot):
print("sleeping")
await lot.park()
print("woken")
async def main():
lot1 = trio.lowlevel.ParkingLot()
lot2 = trio.lowlevel.ParkingLot()
async with trio.open_nursery() as nursery:
nursery.start_soon(parker, lot1)
await trio.testing.wait_all_tasks_blocked()
assert len(lot1) == 1
assert len(lot2) == 0
lot1.repark(lot2)
assert len(lot1) == 0
assert len(lot2) == 1
# This wakes up the task that was originally parked in lot1
lot2.unpark()
If there are fewer than ``count`` tasks parked, then reparks as many
tasks as are available and then returns successfully.
Args:
new_lot (ParkingLot): the parking lot to move tasks to.
count (int): the number of tasks to move.
"""
if not isinstance(new_lot, ParkingLot):
raise TypeError("new_lot must be a ParkingLot")
for task in self._pop_several(count):
new_lot._parked[task] = None
task.custom_sleep_data = new_lot
def repark_all(self, new_lot):
"""Move all parked tasks from one :class:`ParkingLot` object to
another.
See :meth:`repark` for details.
"""
return self.repark(new_lot, count=len(self))
def statistics(self):
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``tasks_waiting``: The number of tasks blocked on this lot's
:meth:`park` method.
"""
return _ParkingLotStatistics(tasks_waiting=len(self._parked))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,171 @@
from threading import Thread, Lock
import outcome
from itertools import count
# The "thread cache" is a simple unbounded thread pool, i.e., it automatically
# spawns as many threads as needed to handle all the requests its given. Its
# only purpose is to cache worker threads so that they don't have to be
# started from scratch every time we want to delegate some work to a thread.
# It's expected that some higher-level code will track how many threads are in
# use to avoid overwhelming the system (e.g. the limiter= argument to
# trio.to_thread.run_sync).
#
# To maximize sharing, there's only one thread cache per process, even if you
# have multiple calls to trio.run.
#
# Guarantees:
#
# It's safe to call start_thread_soon simultaneously from
# multiple threads.
#
# Idle threads are chosen in LIFO order, i.e. we *don't* spread work evenly
# over all threads. Instead we try to let some threads do most of the work
# while others sit idle as much as possible. Compared to FIFO, this has better
# memory cache behavior, and it makes it easier to detect when we have too
# many threads, so idle ones can exit.
#
# This code assumes that 'dict' has the following properties:
#
# - __setitem__, __delitem__, and popitem are all thread-safe and atomic with
# respect to each other. This is guaranteed by the GIL.
#
# - popitem returns the most-recently-added item (i.e., __setitem__ + popitem
# give you a LIFO queue). This relies on dicts being insertion-ordered, like
# they are in py36+.
# How long a thread will idle waiting for new work before gives up and exits.
# This value is pretty arbitrary; I don't think it matters too much.
IDLE_TIMEOUT = 10 # seconds
name_counter = count()
class WorkerThread:
def __init__(self, thread_cache):
self._job = None
self._thread_cache = thread_cache
# This Lock is used in an unconventional way.
#
# "Unlocked" means we have a pending job that's been assigned to us;
# "locked" means that we don't.
#
# Initially we have no job, so it starts out in locked state.
self._worker_lock = Lock()
self._worker_lock.acquire()
thread = Thread(target=self._work, daemon=True)
thread.name = f"Trio worker thread {next(name_counter)}"
thread.start()
def _handle_job(self):
# Handle job in a separate method to ensure user-created
# objects are cleaned up in a consistent manner.
fn, deliver = self._job
self._job = None
result = outcome.capture(fn)
# Tell the cache that we're available to be assigned a new
# job. We do this *before* calling 'deliver', so that if
# 'deliver' triggers a new job, it can be assigned to us
# instead of spawning a new thread.
self._thread_cache._idle_workers[self] = None
deliver(result)
def _work(self):
while True:
if self._worker_lock.acquire(timeout=IDLE_TIMEOUT):
# We got a job
self._handle_job()
else:
# Timeout acquiring lock, so we can probably exit. But,
# there's a race condition: we might be assigned a job *just*
# as we're about to exit. So we have to check.
try:
del self._thread_cache._idle_workers[self]
except KeyError:
# Someone else removed us from the idle worker queue, so
# they must be in the process of assigning us a job - loop
# around and wait for it.
continue
else:
# We successfully removed ourselves from the idle
# worker queue, so no more jobs are incoming; it's safe to
# exit.
return
class ThreadCache:
def __init__(self):
self._idle_workers = {}
def start_thread_soon(self, fn, deliver):
try:
worker, _ = self._idle_workers.popitem()
except KeyError:
worker = WorkerThread(self)
worker._job = (fn, deliver)
worker._worker_lock.release()
THREAD_CACHE = ThreadCache()
def start_thread_soon(fn, deliver):
"""Runs ``deliver(outcome.capture(fn))`` in a worker thread.
Generally ``fn`` does some blocking work, and ``deliver`` delivers the
result back to whoever is interested.
This is a low-level, no-frills interface, very similar to using
`threading.Thread` to spawn a thread directly. The main difference is
that this function tries to re-use threads when possible, so it can be
a bit faster than `threading.Thread`.
Worker threads have the `~threading.Thread.daemon` flag set, which means
that if your main thread exits, worker threads will automatically be
killed. If you want to make sure that your ``fn`` runs to completion, then
you should make sure that the main thread remains alive until ``deliver``
is called.
It is safe to call this function simultaneously from multiple threads.
Args:
fn (sync function): Performs arbitrary blocking work.
deliver (sync function): Takes the `outcome.Outcome` of ``fn``, and
delivers it. *Must not block.*
Because worker threads are cached and reused for multiple calls, neither
function should mutate thread-level state, like `threading.local` objects
or if they do, they should be careful to revert their changes before
returning.
Note:
The split between ``fn`` and ``deliver`` serves two purposes. First,
it's convenient, since most callers need something like this anyway.
Second, it avoids a small race condition that could cause too many
threads to be spawned. Consider a program that wants to run several
jobs sequentially on a thread, so the main thread submits a job, waits
for it to finish, submits another job, etc. In theory, this program
should only need one worker thread. But what could happen is:
1. Worker thread: First job finishes, and calls ``deliver``.
2. Main thread: receives notification that the job finished, and calls
``start_thread_soon``.
3. Main thread: sees that no worker threads are marked idle, so spawns
a second worker thread.
4. Original worker thread: marks itself as idle.
To avoid this, threads mark themselves as idle *before* calling
``deliver``.
Is this potential extra thread a major problem? Maybe not, but it's
easy enough to avoid, and we figure that if the user is trying to
limit how many threads they're using then it's polite to respect that.
"""
THREAD_CACHE.start_thread_soon(fn, deliver)

View File

@@ -0,0 +1,270 @@
# These are the only functions that ever yield back to the task runner.
import types
import enum
import attr
import outcome
from . import _run
# Helper for the bottommost 'yield'. You can't use 'yield' inside an async
# function, but you can inside a generator, and if you decorate your generator
# with @types.coroutine, then it's even awaitable. However, it's still not a
# real async function: in particular, it isn't recognized by
# inspect.iscoroutinefunction, and it doesn't trigger the unawaited coroutine
# tracking machinery. Since our traps are public APIs, we make them real async
# functions, and then this helper takes care of the actual yield:
@types.coroutine
def _async_yield(obj):
return (yield obj)
# This class object is used as a singleton.
# Not exported in the trio._core namespace, but imported directly by _run.
class CancelShieldedCheckpoint:
pass
async def cancel_shielded_checkpoint():
"""Introduce a schedule point, but not a cancel point.
This is *not* a :ref:`checkpoint <checkpoints>`, but it is half of a
checkpoint, and when combined with :func:`checkpoint_if_cancelled` it can
make a full checkpoint.
Equivalent to (but potentially more efficient than)::
with trio.CancelScope(shield=True):
await trio.lowlevel.checkpoint()
"""
return (await _async_yield(CancelShieldedCheckpoint)).unwrap()
# Return values for abort functions
class Abort(enum.Enum):
""":class:`enum.Enum` used as the return value from abort functions.
See :func:`wait_task_rescheduled` for details.
.. data:: SUCCEEDED
FAILED
"""
SUCCEEDED = 1
FAILED = 2
# Not exported in the trio._core namespace, but imported directly by _run.
@attr.s(frozen=True)
class WaitTaskRescheduled:
abort_func = attr.ib()
async def wait_task_rescheduled(abort_func):
"""Put the current task to sleep, with cancellation support.
This is the lowest-level API for blocking in Trio. Every time a
:class:`~trio.lowlevel.Task` blocks, it does so by calling this function
(usually indirectly via some higher-level API).
This is a tricky interface with no guard rails. If you can use
:class:`ParkingLot` or the built-in I/O wait functions instead, then you
should.
Generally the way it works is that before calling this function, you make
arrangements for "someone" to call :func:`reschedule` on the current task
at some later point.
Then you call :func:`wait_task_rescheduled`, passing in ``abort_func``, an
"abort callback".
(Terminology: in Trio, "aborting" is the process of attempting to
interrupt a blocked task to deliver a cancellation.)
There are two possibilities for what happens next:
1. "Someone" calls :func:`reschedule` on the current task, and
:func:`wait_task_rescheduled` returns or raises whatever value or error
was passed to :func:`reschedule`.
2. The call's context transitions to a cancelled state (e.g. due to a
timeout expiring). When this happens, the ``abort_func`` is called. Its
interface looks like::
def abort_func(raise_cancel):
...
return trio.lowlevel.Abort.SUCCEEDED # or FAILED
It should attempt to clean up any state associated with this call, and
in particular, arrange that :func:`reschedule` will *not* be called
later. If (and only if!) it is successful, then it should return
:data:`Abort.SUCCEEDED`, in which case the task will automatically be
rescheduled with an appropriate :exc:`~trio.Cancelled` error.
Otherwise, it should return :data:`Abort.FAILED`. This means that the
task can't be cancelled at this time, and still has to make sure that
"someone" eventually calls :func:`reschedule`.
At that point there are again two possibilities. You can simply ignore
the cancellation altogether: wait for the operation to complete and
then reschedule and continue as normal. (For example, this is what
:func:`trio.to_thread.run_sync` does if cancellation is disabled.)
The other possibility is that the ``abort_func`` does succeed in
cancelling the operation, but for some reason isn't able to report that
right away. (Example: on Windows, it's possible to request that an
async ("overlapped") I/O operation be cancelled, but this request is
*also* asynchronous you don't find out until later whether the
operation was actually cancelled or not.) To report a delayed
cancellation, then you should reschedule the task yourself, and call
the ``raise_cancel`` callback passed to ``abort_func`` to raise a
:exc:`~trio.Cancelled` (or possibly :exc:`KeyboardInterrupt`) exception
into this task. Either of the approaches sketched below can work::
# Option 1:
# Catch the exception from raise_cancel and inject it into the task.
# (This is what Trio does automatically for you if you return
# Abort.SUCCEEDED.)
trio.lowlevel.reschedule(task, outcome.capture(raise_cancel))
# Option 2:
# wait to be woken by "someone", and then decide whether to raise
# the error from inside the task.
outer_raise_cancel = None
def abort(inner_raise_cancel):
nonlocal outer_raise_cancel
outer_raise_cancel = inner_raise_cancel
TRY_TO_CANCEL_OPERATION()
return trio.lowlevel.Abort.FAILED
await wait_task_rescheduled(abort)
if OPERATION_WAS_SUCCESSFULLY_CANCELLED:
# raises the error
outer_raise_cancel()
In any case it's guaranteed that we only call the ``abort_func`` at most
once per call to :func:`wait_task_rescheduled`.
Sometimes, it's useful to be able to share some mutable sleep-related data
between the sleeping task, the abort function, and the waking task. You
can use the sleeping task's :data:`~Task.custom_sleep_data` attribute to
store this data, and Trio won't touch it, except to make sure that it gets
cleared when the task is rescheduled.
.. warning::
If your ``abort_func`` raises an error, or returns any value other than
:data:`Abort.SUCCEEDED` or :data:`Abort.FAILED`, then Trio will crash
violently. Be careful! Similarly, it is entirely possible to deadlock a
Trio program by failing to reschedule a blocked task, or cause havoc by
calling :func:`reschedule` too many times. Remember what we said up
above about how you should use a higher-level API if at all possible?
"""
return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap()
# Not exported in the trio._core namespace, but imported directly by _run.
@attr.s(frozen=True)
class PermanentlyDetachCoroutineObject:
final_outcome = attr.ib()
async def permanently_detach_coroutine_object(final_outcome):
"""Permanently detach the current task from the Trio scheduler.
Normally, a Trio task doesn't exit until its coroutine object exits. When
you call this function, Trio acts like the coroutine object just exited
and the task terminates with the given outcome. This is useful if you want
to permanently switch the coroutine object over to a different coroutine
runner.
When the calling coroutine enters this function it's running under Trio,
and when the function returns it's running under the foreign coroutine
runner.
You should make sure that the coroutine object has released any
Trio-specific resources it has acquired (e.g. nurseries).
Args:
final_outcome (outcome.Outcome): Trio acts as if the current task exited
with the given return value or exception.
Returns or raises whatever value or exception the new coroutine runner
uses to resume the coroutine.
"""
if _run.current_task().child_nurseries:
raise RuntimeError(
"can't permanently detach a coroutine object with open nurseries"
)
return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome))
async def temporarily_detach_coroutine_object(abort_func):
"""Temporarily detach the current coroutine object from the Trio
scheduler.
When the calling coroutine enters this function it's running under Trio,
and when the function returns it's running under the foreign coroutine
runner.
The Trio :class:`Task` will continue to exist, but will be suspended until
you use :func:`reattach_detached_coroutine_object` to resume it. In the
mean time, you can use another coroutine runner to schedule the coroutine
object. In fact, you have to the function doesn't return until the
coroutine is advanced from outside.
Note that you'll need to save the current :class:`Task` object to later
resume; you can retrieve it with :func:`current_task`. You can also use
this :class:`Task` object to retrieve the coroutine object see
:data:`Task.coro`.
Args:
abort_func: Same as for :func:`wait_task_rescheduled`, except that it
must return :data:`Abort.FAILED`. (If it returned
:data:`Abort.SUCCEEDED`, then Trio would attempt to reschedule the
detached task directly without going through
:func:`reattach_detached_coroutine_object`, which would be bad.)
Your ``abort_func`` should still arrange for whatever the coroutine
object is doing to be cancelled, and then reattach to Trio and call
the ``raise_cancel`` callback, if possible.
Returns or raises whatever value or exception the new coroutine runner
uses to resume the coroutine.
"""
return await _async_yield(WaitTaskRescheduled(abort_func))
async def reattach_detached_coroutine_object(task, yield_value):
"""Reattach a coroutine object that was detached using
:func:`temporarily_detach_coroutine_object`.
When the calling coroutine enters this function it's running under the
foreign coroutine runner, and when the function returns it's running under
Trio.
This must be called from inside the coroutine being resumed, and yields
whatever value you pass in. (Presumably you'll pass a value that will
cause the current coroutine runner to stop scheduling this task.) Then the
coroutine is resumed by the Trio scheduler at the next opportunity.
Args:
task (Task): The Trio task object that the current coroutine was
detached from.
yield_value (object): The object to yield to the current coroutine
runner.
"""
# This is a kind of crude check in particular, it can fail if the
# passed-in task is where the coroutine *runner* is running. But this is
# an experts-only interface, and there's no easy way to do a more accurate
# check, so I guess that's OK.
if not task.coro.cr_running:
raise RuntimeError("given task does not match calling coroutine")
_run.reschedule(task, outcome.Value("reattaching"))
value = await _async_yield(yield_value)
assert value == outcome.Value("reattaching")

View File

@@ -0,0 +1,149 @@
import attr
from .. import _core
from .._deprecate import deprecated
from .._util import Final
@attr.s(frozen=True)
class _UnboundedQueueStats:
qsize = attr.ib()
tasks_waiting = attr.ib()
class UnboundedQueue(metaclass=Final):
"""An unbounded queue suitable for certain unusual forms of inter-task
communication.
This class is designed for use as a queue in cases where the producer for
some reason cannot be subjected to back-pressure, i.e., :meth:`put_nowait`
has to always succeed. In order to prevent the queue backlog from actually
growing without bound, the consumer API is modified to dequeue items in
"batches". If a consumer task processes each batch without yielding, then
this helps achieve (but does not guarantee) an effective bound on the
queue's memory use, at the cost of potentially increasing system latencies
in general. You should generally prefer to use a memory channel
instead if you can.
Currently each batch completely empties the queue, but `this may change in
the future <https://github.com/python-trio/trio/issues/51>`__.
A :class:`UnboundedQueue` object can be used as an asynchronous iterator,
where each iteration returns a new batch of items. I.e., these two loops
are equivalent::
async for batch in queue:
...
while True:
obj = await queue.get_batch()
...
"""
@deprecated(
"0.9.0",
issue=497,
thing="trio.lowlevel.UnboundedQueue",
instead="trio.open_memory_channel(math.inf)",
)
def __init__(self):
self._lot = _core.ParkingLot()
self._data = []
# used to allow handoff from put to the first task in the lot
self._can_get = False
def __repr__(self):
return "<UnboundedQueue holding {} items>".format(len(self._data))
def qsize(self):
"""Returns the number of items currently in the queue."""
return len(self._data)
def empty(self):
"""Returns True if the queue is empty, False otherwise.
There is some subtlety to interpreting this method's return value: see
`issue #63 <https://github.com/python-trio/trio/issues/63>`__.
"""
return not self._data
@_core.enable_ki_protection
def put_nowait(self, obj):
"""Put an object into the queue, without blocking.
This always succeeds, because the queue is unbounded. We don't provide
a blocking ``put`` method, because it would never need to block.
Args:
obj (object): The object to enqueue.
"""
if not self._data:
assert not self._can_get
if self._lot:
self._lot.unpark(count=1)
else:
self._can_get = True
self._data.append(obj)
def _get_batch_protected(self):
data = self._data.copy()
self._data.clear()
self._can_get = False
return data
def get_batch_nowait(self):
"""Attempt to get the next batch from the queue, without blocking.
Returns:
list: A list of dequeued items, in order. On a successful call this
list is always non-empty; if it would be empty we raise
:exc:`~trio.WouldBlock` instead.
Raises:
~trio.WouldBlock: if the queue is empty.
"""
if not self._can_get:
raise _core.WouldBlock
return self._get_batch_protected()
async def get_batch(self):
"""Get the next batch from the queue, blocking as necessary.
Returns:
list: A list of dequeued items, in order. This list is always
non-empty.
"""
await _core.checkpoint_if_cancelled()
if not self._can_get:
await self._lot.park()
return self._get_batch_protected()
else:
try:
return self._get_batch_protected()
finally:
await _core.cancel_shielded_checkpoint()
def statistics(self):
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``qsize``: The number of items currently in the queue.
* ``tasks_waiting``: The number of tasks blocked on this queue's
:meth:`get_batch` method.
"""
return _UnboundedQueueStats(
qsize=len(self._data), tasks_waiting=self._lot.statistics().tasks_waiting
)
def __aiter__(self):
return self
async def __anext__(self):
return await self.get_batch()

View File

@@ -0,0 +1,71 @@
import socket
import signal
import warnings
from .. import _core
from .._util import is_main_thread
class WakeupSocketpair:
def __init__(self):
self.wakeup_sock, self.write_sock = socket.socketpair()
self.wakeup_sock.setblocking(False)
self.write_sock.setblocking(False)
# This somewhat reduces the amount of memory wasted queueing up data
# for wakeups. With these settings, maximum number of 1-byte sends
# before getting BlockingIOError:
# Linux 4.8: 6
# macOS (darwin 15.5): 1
# Windows 10: 525347
# Windows you're weird. (And on Windows setting SNDBUF to 0 makes send
# blocking, even on non-blocking sockets, so don't do that.)
self.wakeup_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1)
self.write_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1)
# On Windows this is a TCP socket so this might matter. On other
# platforms this fails b/c AF_UNIX sockets aren't actually TCP.
try:
self.write_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
except OSError:
pass
self.old_wakeup_fd = None
def wakeup_thread_and_signal_safe(self):
try:
self.write_sock.send(b"\x00")
except BlockingIOError:
pass
async def wait_woken(self):
await _core.wait_readable(self.wakeup_sock)
self.drain()
def drain(self):
try:
while True:
self.wakeup_sock.recv(2**16)
except BlockingIOError:
pass
def wakeup_on_signals(self):
assert self.old_wakeup_fd is None
if not is_main_thread():
return
fd = self.write_sock.fileno()
self.old_wakeup_fd = signal.set_wakeup_fd(fd, warn_on_full_buffer=False)
if self.old_wakeup_fd != -1:
warnings.warn(
RuntimeWarning(
"It looks like Trio's signal handling code might have "
"collided with another library you're using. If you're "
"running Trio in guest mode, then this might mean you "
"should set host_uses_signal_set_wakeup_fd=True. "
"Otherwise, file a bug on Trio and we'll help you figure "
"out what's going on."
)
)
def close(self):
self.wakeup_sock.close()
self.write_sock.close()
if self.old_wakeup_fd is not None:
signal.set_wakeup_fd(self.old_wakeup_fd)

View File

@@ -0,0 +1,323 @@
import cffi
import re
import enum
################################################################
# Functions and types
################################################################
LIB = """
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa383751(v=vs.85).aspx
typedef int BOOL;
typedef unsigned char BYTE;
typedef BYTE BOOLEAN;
typedef void* PVOID;
typedef PVOID HANDLE;
typedef unsigned long DWORD;
typedef unsigned long ULONG;
typedef unsigned int NTSTATUS;
typedef unsigned long u_long;
typedef ULONG *PULONG;
typedef const void *LPCVOID;
typedef void *LPVOID;
typedef const wchar_t *LPCWSTR;
typedef uintptr_t ULONG_PTR;
typedef uintptr_t UINT_PTR;
typedef UINT_PTR SOCKET;
typedef struct _OVERLAPPED {
ULONG_PTR Internal;
ULONG_PTR InternalHigh;
union {
struct {
DWORD Offset;
DWORD OffsetHigh;
} DUMMYSTRUCTNAME;
PVOID Pointer;
} DUMMYUNIONNAME;
HANDLE hEvent;
} OVERLAPPED, *LPOVERLAPPED;
typedef OVERLAPPED WSAOVERLAPPED;
typedef LPOVERLAPPED LPWSAOVERLAPPED;
typedef PVOID LPSECURITY_ATTRIBUTES;
typedef PVOID LPCSTR;
typedef struct _OVERLAPPED_ENTRY {
ULONG_PTR lpCompletionKey;
LPOVERLAPPED lpOverlapped;
ULONG_PTR Internal;
DWORD dwNumberOfBytesTransferred;
} OVERLAPPED_ENTRY, *LPOVERLAPPED_ENTRY;
// kernel32.dll
HANDLE WINAPI CreateIoCompletionPort(
_In_ HANDLE FileHandle,
_In_opt_ HANDLE ExistingCompletionPort,
_In_ ULONG_PTR CompletionKey,
_In_ DWORD NumberOfConcurrentThreads
);
BOOL SetFileCompletionNotificationModes(
HANDLE FileHandle,
UCHAR Flags
);
HANDLE CreateFileW(
LPCWSTR lpFileName,
DWORD dwDesiredAccess,
DWORD dwShareMode,
LPSECURITY_ATTRIBUTES lpSecurityAttributes,
DWORD dwCreationDisposition,
DWORD dwFlagsAndAttributes,
HANDLE hTemplateFile
);
BOOL WINAPI CloseHandle(
_In_ HANDLE hObject
);
BOOL WINAPI PostQueuedCompletionStatus(
_In_ HANDLE CompletionPort,
_In_ DWORD dwNumberOfBytesTransferred,
_In_ ULONG_PTR dwCompletionKey,
_In_opt_ LPOVERLAPPED lpOverlapped
);
BOOL WINAPI GetQueuedCompletionStatusEx(
_In_ HANDLE CompletionPort,
_Out_ LPOVERLAPPED_ENTRY lpCompletionPortEntries,
_In_ ULONG ulCount,
_Out_ PULONG ulNumEntriesRemoved,
_In_ DWORD dwMilliseconds,
_In_ BOOL fAlertable
);
BOOL WINAPI CancelIoEx(
_In_ HANDLE hFile,
_In_opt_ LPOVERLAPPED lpOverlapped
);
BOOL WriteFile(
HANDLE hFile,
LPCVOID lpBuffer,
DWORD nNumberOfBytesToWrite,
LPDWORD lpNumberOfBytesWritten,
LPOVERLAPPED lpOverlapped
);
BOOL ReadFile(
HANDLE hFile,
LPVOID lpBuffer,
DWORD nNumberOfBytesToRead,
LPDWORD lpNumberOfBytesRead,
LPOVERLAPPED lpOverlapped
);
BOOL WINAPI SetConsoleCtrlHandler(
_In_opt_ void* HandlerRoutine,
_In_ BOOL Add
);
HANDLE CreateEventA(
LPSECURITY_ATTRIBUTES lpEventAttributes,
BOOL bManualReset,
BOOL bInitialState,
LPCSTR lpName
);
BOOL SetEvent(
HANDLE hEvent
);
BOOL ResetEvent(
HANDLE hEvent
);
DWORD WaitForSingleObject(
HANDLE hHandle,
DWORD dwMilliseconds
);
DWORD WaitForMultipleObjects(
DWORD nCount,
HANDLE *lpHandles,
BOOL bWaitAll,
DWORD dwMilliseconds
);
ULONG RtlNtStatusToDosError(
NTSTATUS Status
);
int WSAIoctl(
SOCKET s,
DWORD dwIoControlCode,
LPVOID lpvInBuffer,
DWORD cbInBuffer,
LPVOID lpvOutBuffer,
DWORD cbOutBuffer,
LPDWORD lpcbBytesReturned,
LPWSAOVERLAPPED lpOverlapped,
// actually LPWSAOVERLAPPED_COMPLETION_ROUTINE
void* lpCompletionRoutine
);
int WSAGetLastError();
BOOL DeviceIoControl(
HANDLE hDevice,
DWORD dwIoControlCode,
LPVOID lpInBuffer,
DWORD nInBufferSize,
LPVOID lpOutBuffer,
DWORD nOutBufferSize,
LPDWORD lpBytesReturned,
LPOVERLAPPED lpOverlapped
);
// From https://github.com/piscisaureus/wepoll/blob/master/src/afd.h
typedef struct _AFD_POLL_HANDLE_INFO {
HANDLE Handle;
ULONG Events;
NTSTATUS Status;
} AFD_POLL_HANDLE_INFO, *PAFD_POLL_HANDLE_INFO;
// This is really defined as a messy union to allow stuff like
// i.DUMMYSTRUCTNAME.LowPart, but we don't need those complications.
// Under all that it's just an int64.
typedef int64_t LARGE_INTEGER;
typedef struct _AFD_POLL_INFO {
LARGE_INTEGER Timeout;
ULONG NumberOfHandles;
ULONG Exclusive;
AFD_POLL_HANDLE_INFO Handles[1];
} AFD_POLL_INFO, *PAFD_POLL_INFO;
"""
# cribbed from pywincffi
# programmatically strips out those annotations MSDN likes, like _In_
REGEX_SAL_ANNOTATION = re.compile(
r"\b(_In_|_Inout_|_Out_|_Outptr_|_Reserved_)(opt_)?\b"
)
LIB = REGEX_SAL_ANNOTATION.sub(" ", LIB)
# Other fixups:
# - get rid of FAR, cffi doesn't like it
LIB = re.sub(r"\bFAR\b", " ", LIB)
# - PASCAL is apparently an alias for __stdcall (on modern compilers - modern
# being _MSC_VER >= 800)
LIB = re.sub(r"\bPASCAL\b", "__stdcall", LIB)
ffi = cffi.FFI()
ffi.cdef(LIB)
kernel32 = ffi.dlopen("kernel32.dll")
ntdll = ffi.dlopen("ntdll.dll")
ws2_32 = ffi.dlopen("ws2_32.dll")
################################################################
# Magic numbers
################################################################
# Here's a great resource for looking these up:
# https://www.magnumdb.com
# (Tip: check the box to see "Hex value")
INVALID_HANDLE_VALUE = ffi.cast("HANDLE", -1)
class ErrorCodes(enum.IntEnum):
STATUS_TIMEOUT = 0x102
WAIT_TIMEOUT = 0x102
WAIT_ABANDONED = 0x80
WAIT_OBJECT_0 = 0x00 # object is signaled
WAIT_FAILED = 0xFFFFFFFF
ERROR_IO_PENDING = 997
ERROR_OPERATION_ABORTED = 995
ERROR_ABANDONED_WAIT_0 = 735
ERROR_INVALID_HANDLE = 6
ERROR_INVALID_PARMETER = 87
ERROR_NOT_FOUND = 1168
ERROR_NOT_SOCKET = 10038
class FileFlags(enum.IntEnum):
GENERIC_READ = 0x80000000
SYNCHRONIZE = 0x00100000
FILE_FLAG_OVERLAPPED = 0x40000000
FILE_SHARE_READ = 1
FILE_SHARE_WRITE = 2
FILE_SHARE_DELETE = 4
CREATE_NEW = 1
CREATE_ALWAYS = 2
OPEN_EXISTING = 3
OPEN_ALWAYS = 4
TRUNCATE_EXISTING = 5
class AFDPollFlags(enum.IntFlag):
# These are drawn from a combination of:
# https://github.com/piscisaureus/wepoll/blob/master/src/afd.h
# https://github.com/reactos/reactos/blob/master/sdk/include/reactos/drivers/afd/shared.h
AFD_POLL_RECEIVE = 0x0001
AFD_POLL_RECEIVE_EXPEDITED = 0x0002 # OOB/urgent data
AFD_POLL_SEND = 0x0004
AFD_POLL_DISCONNECT = 0x0008 # received EOF (FIN)
AFD_POLL_ABORT = 0x0010 # received RST
AFD_POLL_LOCAL_CLOSE = 0x0020 # local socket object closed
AFD_POLL_CONNECT = 0x0040 # socket is successfully connected
AFD_POLL_ACCEPT = 0x0080 # you can call accept on this socket
AFD_POLL_CONNECT_FAIL = 0x0100 # connect() terminated unsuccessfully
# See WSAEventSelect docs for more details on these four:
AFD_POLL_QOS = 0x0200
AFD_POLL_GROUP_QOS = 0x0400
AFD_POLL_ROUTING_INTERFACE_CHANGE = 0x0800
AFD_POLL_EVENT_ADDRESS_LIST_CHANGE = 0x1000
class WSAIoctls(enum.IntEnum):
SIO_BASE_HANDLE = 0x48000022
SIO_BSP_HANDLE_SELECT = 0x4800001C
SIO_BSP_HANDLE_POLL = 0x4800001D
class CompletionModes(enum.IntFlag):
FILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 0x1
FILE_SKIP_SET_EVENT_ON_HANDLE = 0x2
class IoControlCodes(enum.IntEnum):
IOCTL_AFD_POLL = 0x00012024
################################################################
# Generic helpers
################################################################
def _handle(obj):
# For now, represent handles as either cffi HANDLEs or as ints. If you
# try to pass in a file descriptor instead, it's not going to work
# out. (For that msvcrt.get_osfhandle does the trick, but I don't know if
# we'll actually need that for anything...) For sockets this doesn't
# matter, Python never allocates an fd. So let's wait until we actually
# encounter the problem before worrying about it.
if type(obj) is int:
return ffi.cast("HANDLE", obj)
else:
return obj
def raise_winerror(winerror=None, *, filename=None, filename2=None):
if winerror is None:
winerror, msg = ffi.getwinerror()
else:
_, msg = ffi.getwinerror(winerror)
# https://docs.python.org/3/library/exceptions.html#OSError
raise OSError(0, msg, filename, winerror, filename2)

View File

@@ -0,0 +1,25 @@
import pytest
import inspect
# XX this should move into a global something
from ...testing import MockClock, trio_test
@pytest.fixture
def mock_clock():
return MockClock()
@pytest.fixture
def autojump_clock():
return MockClock(autojump_threshold=0)
# FIXME: split off into a package (or just make part of Trio's public
# interface?), with config file to enable? and I guess a mark option too; I
# guess it's useful with the class- and file-level marking machinery (where
# the raw @trio_test decorator isn't enough).
@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem):
if inspect.iscoroutinefunction(pyfuncitem.obj):
pyfuncitem.obj = trio_test(pyfuncitem.obj)

View File

@@ -0,0 +1,320 @@
import sys
import weakref
import pytest
from math import inf
from functools import partial
from async_generator import aclosing
from ... import _core
from .tutil import gc_collect_harder, buggy_pypy_asyncgens, restore_unraisablehook
def test_asyncgen_basics():
collected = []
async def example(cause):
try:
try:
yield 42
except GeneratorExit:
pass
await _core.checkpoint()
except _core.Cancelled:
assert "exhausted" not in cause
task_name = _core.current_task().name
assert cause in task_name or task_name == "<init>"
assert _core.current_effective_deadline() == -inf
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
collected.append(cause)
else:
assert "async_main" in _core.current_task().name
assert "exhausted" in cause
assert _core.current_effective_deadline() == inf
await _core.checkpoint()
collected.append(cause)
saved = []
async def async_main():
# GC'ed before exhausted
with pytest.warns(
ResourceWarning, match="Async generator.*collected before.*exhausted"
):
assert 42 == await example("abandoned").asend(None)
gc_collect_harder()
await _core.wait_all_tasks_blocked()
assert collected.pop() == "abandoned"
# aclosing() ensures it's cleaned up at point of use
async with aclosing(example("exhausted 1")) as aiter:
assert 42 == await aiter.asend(None)
assert collected.pop() == "exhausted 1"
# Also fine if you exhaust it at point of use
async for val in example("exhausted 2"):
assert val == 42
assert collected.pop() == "exhausted 2"
gc_collect_harder()
# No problems saving the geniter when using either of these patterns
async with aclosing(example("exhausted 3")) as aiter:
saved.append(aiter)
assert 42 == await aiter.asend(None)
assert collected.pop() == "exhausted 3"
# Also fine if you exhaust it at point of use
saved.append(example("exhausted 4"))
async for val in saved[-1]:
assert val == 42
assert collected.pop() == "exhausted 4"
# Leave one referenced-but-unexhausted and make sure it gets cleaned up
if buggy_pypy_asyncgens:
collected.append("outlived run")
else:
saved.append(example("outlived run"))
assert 42 == await saved[-1].asend(None)
assert collected == []
_core.run(async_main)
assert collected.pop() == "outlived run"
for agen in saved:
assert agen.ag_frame is None # all should now be exhausted
async def test_asyncgen_throws_during_finalization(caplog):
record = []
async def agen():
try:
yield 1
finally:
await _core.cancel_shielded_checkpoint()
record.append("crashing")
raise ValueError("oops")
with restore_unraisablehook():
await agen().asend(None)
gc_collect_harder()
await _core.wait_all_tasks_blocked()
assert record == ["crashing"]
exc_type, exc_value, exc_traceback = caplog.records[0].exc_info
assert exc_type is ValueError
assert str(exc_value) == "oops"
assert "during finalization of async generator" in caplog.records[0].message
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy")
def test_firstiter_after_closing():
saved = []
record = []
async def funky_agen():
try:
yield 1
except GeneratorExit:
record.append("cleanup 1")
raise
try:
yield 2
finally:
record.append("cleanup 2")
await funky_agen().asend(None)
async def async_main():
aiter = funky_agen()
saved.append(aiter)
assert 1 == await aiter.asend(None)
assert 2 == await aiter.asend(None)
_core.run(async_main)
assert record == ["cleanup 2", "cleanup 1"]
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy")
def test_interdependent_asyncgen_cleanup_order():
saved = []
record = []
async def innermost():
try:
yield 1
finally:
await _core.cancel_shielded_checkpoint()
record.append("innermost")
async def agen(label, inner):
try:
yield await inner.asend(None)
finally:
# Either `inner` has already been cleaned up, or
# we're about to exhaust it. Either way, we wind
# up with `record` containing the labels in
# innermost-to-outermost order.
with pytest.raises(StopAsyncIteration):
await inner.asend(None)
record.append(label)
async def async_main():
# This makes a chain of 101 interdependent asyncgens:
# agen(99)'s cleanup will iterate agen(98)'s will iterate
# ... agen(0)'s will iterate innermost()'s
ag_chain = innermost()
for idx in range(100):
ag_chain = agen(idx, ag_chain)
saved.append(ag_chain)
assert 1 == await ag_chain.asend(None)
assert record == []
_core.run(async_main)
assert record == ["innermost"] + list(range(100))
@restore_unraisablehook()
def test_last_minute_gc_edge_case():
saved = []
record = []
needs_retry = True
async def agen():
try:
yield 1
finally:
record.append("cleaned up")
def collect_at_opportune_moment(token):
runner = _core._run.GLOBAL_RUN_CONTEXT.runner
if runner.system_nursery._closed and isinstance(
runner.asyncgens.alive, weakref.WeakSet
):
saved.clear()
record.append("final collection")
gc_collect_harder()
record.append("done")
else:
try:
token.run_sync_soon(collect_at_opportune_moment, token)
except _core.RunFinishedError: # pragma: no cover
nonlocal needs_retry
needs_retry = True
async def async_main():
token = _core.current_trio_token()
token.run_sync_soon(collect_at_opportune_moment, token)
saved.append(agen())
await saved[-1].asend(None)
# Actually running into the edge case requires that the run_sync_soon task
# execute in between the system nursery's closure and the strong-ification
# of runner.asyncgens. There's about a 25% chance that it doesn't
# (if the run_sync_soon task runs before init on one tick and after init
# on the next tick); if we try enough times, we can make the chance of
# failure as small as we want.
for attempt in range(50):
needs_retry = False
del record[:]
del saved[:]
_core.run(async_main)
if needs_retry: # pragma: no cover
if not buggy_pypy_asyncgens:
assert record == ["cleaned up"]
else:
assert record == ["final collection", "done", "cleaned up"]
break
else: # pragma: no cover
pytest.fail(
f"Didn't manage to hit the trailing_finalizer_asyncgens case "
f"despite trying {attempt} times"
)
async def step_outside_async_context(aiter):
# abort_fns run outside of task context, at least if they're
# triggered by a deadline expiry rather than a direct
# cancellation. Thus, an asyncgen first iterated inside one
# will appear non-Trio, and since no other hooks were installed,
# will use the last-ditch fallback handling (that tries to mimic
# CPython's behavior with no hooks).
#
# NB: the strangeness with aiter being an attribute of abort_fn is
# to make it as easy as possible to ensure we don't hang onto a
# reference to aiter inside the guts of the run loop.
def abort_fn(_):
with pytest.raises(StopIteration, match="42"):
abort_fn.aiter.asend(None).send(None)
del abort_fn.aiter
return _core.Abort.SUCCEEDED
abort_fn.aiter = aiter
async with _core.open_nursery() as nursery:
nursery.start_soon(_core.wait_task_rescheduled, abort_fn)
await _core.wait_all_tasks_blocked()
nursery.cancel_scope.deadline = _core.current_time()
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy")
async def test_fallback_when_no_hook_claims_it(capsys):
async def well_behaved():
yield 42
async def yields_after_yield():
with pytest.raises(GeneratorExit):
yield 42
yield 100
async def awaits_after_yield():
with pytest.raises(GeneratorExit):
yield 42
await _core.cancel_shielded_checkpoint()
with restore_unraisablehook():
await step_outside_async_context(well_behaved())
gc_collect_harder()
assert capsys.readouterr().err == ""
await step_outside_async_context(yields_after_yield())
gc_collect_harder()
assert "ignored GeneratorExit" in capsys.readouterr().err
await step_outside_async_context(awaits_after_yield())
gc_collect_harder()
assert "awaited something during finalization" in capsys.readouterr().err
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy")
def test_delegation_to_existing_hooks():
record = []
def my_firstiter(agen):
record.append("firstiter " + agen.ag_frame.f_locals["arg"])
def my_finalizer(agen):
record.append("finalizer " + agen.ag_frame.f_locals["arg"])
async def example(arg):
try:
yield 42
finally:
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
record.append("trio collected " + arg)
async def async_main():
await step_outside_async_context(example("theirs"))
assert 42 == await example("ours").asend(None)
gc_collect_harder()
assert record == ["firstiter theirs", "finalizer theirs"]
record[:] = []
await _core.wait_all_tasks_blocked()
assert record == ["trio collected ours"]
with restore_unraisablehook():
old_hooks = sys.get_asyncgen_hooks()
sys.set_asyncgen_hooks(my_firstiter, my_finalizer)
try:
_core.run(async_main)
finally:
assert sys.get_asyncgen_hooks() == (my_firstiter, my_finalizer)
sys.set_asyncgen_hooks(*old_hooks)

View File

@@ -0,0 +1,546 @@
import pytest
import asyncio
import contextvars
import sys
import traceback
import queue
from functools import partial
from math import inf
import signal
import socket
import threading
import time
import warnings
import trio
import trio.testing
from .tutil import gc_collect_harder, buggy_pypy_asyncgens, restore_unraisablehook
from ..._util import signal_raise
# The simplest possible "host" loop.
# Nice features:
# - we can run code "outside" of trio using the schedule function passed to
# our main
# - final result is returned
# - any unhandled exceptions cause an immediate crash
def trivial_guest_run(trio_fn, **start_guest_run_kwargs):
todo = queue.Queue()
host_thread = threading.current_thread()
def run_sync_soon_threadsafe(fn):
if host_thread is threading.current_thread(): # pragma: no cover
crash = partial(
pytest.fail, "run_sync_soon_threadsafe called from host thread"
)
todo.put(("run", crash))
todo.put(("run", fn))
def run_sync_soon_not_threadsafe(fn):
if host_thread is not threading.current_thread(): # pragma: no cover
crash = partial(
pytest.fail, "run_sync_soon_not_threadsafe called from worker thread"
)
todo.put(("run", crash))
todo.put(("run", fn))
def done_callback(outcome):
todo.put(("unwrap", outcome))
trio.lowlevel.start_guest_run(
trio_fn,
run_sync_soon_not_threadsafe,
run_sync_soon_threadsafe=run_sync_soon_threadsafe,
run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe,
done_callback=done_callback,
**start_guest_run_kwargs,
)
try:
while True:
op, obj = todo.get()
if op == "run":
obj()
elif op == "unwrap":
return obj.unwrap()
else: # pragma: no cover
assert False
finally:
# Make sure that exceptions raised here don't capture these, so that
# if an exception does cause us to abandon a run then the Trio state
# has a chance to be GC'ed and warn about it.
del todo, run_sync_soon_threadsafe, done_callback
def test_guest_trivial():
async def trio_return(in_host):
await trio.sleep(0)
return "ok"
assert trivial_guest_run(trio_return) == "ok"
async def trio_fail(in_host):
raise KeyError("whoopsiedaisy")
with pytest.raises(KeyError, match="whoopsiedaisy"):
trivial_guest_run(trio_fail)
def test_guest_can_do_io():
async def trio_main(in_host):
record = []
a, b = trio.socket.socketpair()
with a, b:
async with trio.open_nursery() as nursery:
async def do_receive():
record.append(await a.recv(1))
nursery.start_soon(do_receive)
await trio.testing.wait_all_tasks_blocked()
await b.send(b"x")
assert record == [b"x"]
trivial_guest_run(trio_main)
def test_host_can_directly_wake_trio_task():
async def trio_main(in_host):
ev = trio.Event()
in_host(ev.set)
await ev.wait()
return "ok"
assert trivial_guest_run(trio_main) == "ok"
def test_host_altering_deadlines_wakes_trio_up():
def set_deadline(cscope, new_deadline):
cscope.deadline = new_deadline
async def trio_main(in_host):
with trio.CancelScope() as cscope:
in_host(lambda: set_deadline(cscope, -inf))
await trio.sleep_forever()
assert cscope.cancelled_caught
with trio.CancelScope() as cscope:
# also do a change that doesn't affect the next deadline, just to
# exercise that path
in_host(lambda: set_deadline(cscope, 1e6))
in_host(lambda: set_deadline(cscope, -inf))
await trio.sleep(999)
assert cscope.cancelled_caught
return "ok"
assert trivial_guest_run(trio_main) == "ok"
def test_warn_set_wakeup_fd_overwrite():
assert signal.set_wakeup_fd(-1) == -1
async def trio_main(in_host):
return "ok"
a, b = socket.socketpair()
with a, b:
a.setblocking(False)
# Warn if there's already a wakeup fd
signal.set_wakeup_fd(a.fileno())
try:
with pytest.warns(RuntimeWarning, match="signal handling code.*collided"):
assert trivial_guest_run(trio_main) == "ok"
finally:
assert signal.set_wakeup_fd(-1) == a.fileno()
signal.set_wakeup_fd(a.fileno())
try:
with pytest.warns(RuntimeWarning, match="signal handling code.*collided"):
assert (
trivial_guest_run(trio_main, host_uses_signal_set_wakeup_fd=False)
== "ok"
)
finally:
assert signal.set_wakeup_fd(-1) == a.fileno()
# Don't warn if there isn't already a wakeup fd
with warnings.catch_warnings():
warnings.simplefilter("error")
assert trivial_guest_run(trio_main) == "ok"
with warnings.catch_warnings():
warnings.simplefilter("error")
assert (
trivial_guest_run(trio_main, host_uses_signal_set_wakeup_fd=True)
== "ok"
)
# If there's already a wakeup fd, but we've been told to trust it,
# then it's left alone and there's no warning
signal.set_wakeup_fd(a.fileno())
try:
async def trio_check_wakeup_fd_unaltered(in_host):
fd = signal.set_wakeup_fd(-1)
assert fd == a.fileno()
signal.set_wakeup_fd(fd)
return "ok"
with warnings.catch_warnings():
warnings.simplefilter("error")
assert (
trivial_guest_run(
trio_check_wakeup_fd_unaltered,
host_uses_signal_set_wakeup_fd=True,
)
== "ok"
)
finally:
assert signal.set_wakeup_fd(-1) == a.fileno()
def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked():
# This is designed to hit the branch in unrolled_run where:
# idle_primed=True
# runner.runq is empty
# events is Truth-y
# ...and confirm that in this case, wait_all_tasks_blocked does not get
# triggered.
def set_deadline(cscope, new_deadline):
print(f"setting deadline {new_deadline}")
cscope.deadline = new_deadline
async def trio_main(in_host):
async def sit_in_wait_all_tasks_blocked(watb_cscope):
with watb_cscope:
# Overall point of this test is that this
# wait_all_tasks_blocked should *not* return normally, but
# only by cancellation.
await trio.testing.wait_all_tasks_blocked(cushion=9999)
assert False # pragma: no cover
assert watb_cscope.cancelled_caught
async def get_woken_by_host_deadline(watb_cscope):
with trio.CancelScope() as cscope:
print("scheduling stuff to happen")
# Altering the deadline from the host, to something in the
# future, will cause the run loop to wake up, but then
# discover that there is nothing to do and go back to sleep.
# This should *not* trigger wait_all_tasks_blocked.
#
# So the 'before_io_wait' here will wait until we're blocking
# with the wait_all_tasks_blocked primed, and then schedule a
# deadline change. The critical test is that this should *not*
# wake up 'sit_in_wait_all_tasks_blocked'.
#
# The after we've had a chance to wake up
# 'sit_in_wait_all_tasks_blocked', we want the test to
# actually end. So in after_io_wait we schedule a second host
# call to tear things down.
class InstrumentHelper:
def __init__(self):
self.primed = False
def before_io_wait(self, timeout):
print(f"before_io_wait({timeout})")
if timeout == 9999: # pragma: no branch
assert not self.primed
in_host(lambda: set_deadline(cscope, 1e9))
self.primed = True
def after_io_wait(self, timeout):
if self.primed: # pragma: no branch
print("instrument triggered")
in_host(lambda: cscope.cancel())
trio.lowlevel.remove_instrument(self)
trio.lowlevel.add_instrument(InstrumentHelper())
await trio.sleep_forever()
assert cscope.cancelled_caught
watb_cscope.cancel()
async with trio.open_nursery() as nursery:
watb_cscope = trio.CancelScope()
nursery.start_soon(sit_in_wait_all_tasks_blocked, watb_cscope)
await trio.testing.wait_all_tasks_blocked()
nursery.start_soon(get_woken_by_host_deadline, watb_cscope)
return "ok"
assert trivial_guest_run(trio_main) == "ok"
@restore_unraisablehook()
def test_guest_warns_if_abandoned():
# This warning is emitted from the garbage collector. So we have to make
# sure that our abandoned run is garbage. The easiest way to do this is to
# put it into a function, so that we're sure all the local state,
# traceback frames, etc. are garbage once it returns.
def do_abandoned_guest_run():
async def abandoned_main(in_host):
in_host(lambda: 1 / 0)
while True:
await trio.sleep(0)
with pytest.raises(ZeroDivisionError):
trivial_guest_run(abandoned_main)
with pytest.warns(RuntimeWarning, match="Trio guest run got abandoned"):
do_abandoned_guest_run()
gc_collect_harder()
# If you have problems some day figuring out what's holding onto a
# reference to the unrolled_run generator and making this test fail,
# then this might be useful to help track it down. (It assumes you
# also hack start_guest_run so that it does 'global W; W =
# weakref(unrolled_run_gen)'.)
#
# import gc
# print(trio._core._run.W)
# targets = [trio._core._run.W()]
# for i in range(15):
# new_targets = []
# for target in targets:
# new_targets += gc.get_referrers(target)
# new_targets.remove(targets)
# print("#####################")
# print(f"depth {i}: {len(new_targets)}")
# print(new_targets)
# targets = new_targets
with pytest.raises(RuntimeError):
trio.current_time()
def aiotrio_run(trio_fn, *, pass_not_threadsafe=True, **start_guest_run_kwargs):
loop = asyncio.new_event_loop()
async def aio_main():
trio_done_fut = loop.create_future()
def trio_done_callback(main_outcome):
print(f"trio_fn finished: {main_outcome!r}")
trio_done_fut.set_result(main_outcome)
if pass_not_threadsafe:
start_guest_run_kwargs["run_sync_soon_not_threadsafe"] = loop.call_soon
trio.lowlevel.start_guest_run(
trio_fn,
run_sync_soon_threadsafe=loop.call_soon_threadsafe,
done_callback=trio_done_callback,
**start_guest_run_kwargs,
)
return (await trio_done_fut).unwrap()
try:
return loop.run_until_complete(aio_main())
finally:
loop.close()
def test_guest_mode_on_asyncio():
async def trio_main():
print("trio_main!")
to_trio, from_aio = trio.open_memory_channel(float("inf"))
from_trio = asyncio.Queue()
aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio))
# Make sure we have at least one tick where we don't need to go into
# the thread
await trio.sleep(0)
from_trio.put_nowait(0)
async for n in from_aio:
print(f"trio got: {n}")
from_trio.put_nowait(n + 1)
if n >= 10:
aio_task.cancel()
return "trio-main-done"
async def aio_pingpong(from_trio, to_trio):
print("aio_pingpong!")
try:
while True:
n = await from_trio.get()
print(f"aio got: {n}")
to_trio.send_nowait(n + 1)
except asyncio.CancelledError:
raise
except: # pragma: no cover
traceback.print_exc()
raise
assert (
aiotrio_run(
trio_main,
# Not all versions of asyncio we test on can actually be trusted,
# but this test doesn't care about signal handling, and it's
# easier to just avoid the warnings.
host_uses_signal_set_wakeup_fd=True,
)
== "trio-main-done"
)
assert (
aiotrio_run(
trio_main,
# Also check that passing only call_soon_threadsafe works, via the
# fallback path where we use it for everything.
pass_not_threadsafe=False,
host_uses_signal_set_wakeup_fd=True,
)
== "trio-main-done"
)
def test_guest_mode_internal_errors(monkeypatch, recwarn):
with monkeypatch.context() as m:
async def crash_in_run_loop(in_host):
m.setattr("trio._core._run.GLOBAL_RUN_CONTEXT.runner.runq", "HI")
await trio.sleep(1)
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(crash_in_run_loop)
with monkeypatch.context() as m:
async def crash_in_io(in_host):
m.setattr("trio._core._run.TheIOManager.get_events", None)
await trio.sleep(0)
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(crash_in_io)
with monkeypatch.context() as m:
async def crash_in_worker_thread_io(in_host):
t = threading.current_thread()
old_get_events = trio._core._run.TheIOManager.get_events
def bad_get_events(*args):
if threading.current_thread() is not t:
raise ValueError("oh no!")
else:
return old_get_events(*args)
m.setattr("trio._core._run.TheIOManager.get_events", bad_get_events)
await trio.sleep(1)
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(crash_in_worker_thread_io)
gc_collect_harder()
def test_guest_mode_ki():
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
# Check SIGINT in Trio func and in host func
async def trio_main(in_host):
with pytest.raises(KeyboardInterrupt):
signal_raise(signal.SIGINT)
# Host SIGINT should get injected into Trio
in_host(partial(signal_raise, signal.SIGINT))
await trio.sleep(10)
with pytest.raises(KeyboardInterrupt) as excinfo:
trivial_guest_run(trio_main)
assert excinfo.value.__context__ is None
# Signal handler should be restored properly on exit
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
# Also check chaining in the case where KI is injected after main exits
final_exc = KeyError("whoa")
async def trio_main_raising(in_host):
in_host(partial(signal_raise, signal.SIGINT))
raise final_exc
with pytest.raises(KeyboardInterrupt) as excinfo:
trivial_guest_run(trio_main_raising)
assert excinfo.value.__context__ is final_exc
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
def test_guest_mode_autojump_clock_threshold_changing():
# This is super obscure and probably no-one will ever notice, but
# technically mutating the MockClock.autojump_threshold from the host
# should wake up the guest, so let's test it.
clock = trio.testing.MockClock()
DURATION = 120
async def trio_main(in_host):
assert trio.current_time() == 0
in_host(lambda: setattr(clock, "autojump_threshold", 0))
await trio.sleep(DURATION)
assert trio.current_time() == DURATION
start = time.monotonic()
trivial_guest_run(trio_main, clock=clock)
end = time.monotonic()
# Should be basically instantaneous, but we'll leave a generous buffer to
# account for any CI weirdness
assert end - start < DURATION / 2
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="PyPy 7.2 is buggy")
@pytest.mark.xfail(
sys.implementation.name == "pypy",
reason="async generator issue under investigation",
)
@restore_unraisablehook()
def test_guest_mode_asyncgens():
import sniffio
record = set()
async def agen(label):
assert sniffio.current_async_library() == label
try:
yield 1
finally:
library = sniffio.current_async_library()
try:
await sys.modules[library].sleep(0)
except trio.Cancelled:
pass
record.add((label, library))
async def iterate_in_aio():
# "trio" gets inherited from our Trio caller if we don't set this
sniffio.current_async_library_cvar.set("asyncio")
await agen("asyncio").asend(None)
async def trio_main():
task = asyncio.ensure_future(iterate_in_aio())
done_evt = trio.Event()
task.add_done_callback(lambda _: done_evt.set())
with trio.fail_after(1):
await done_evt.wait()
await agen("trio").asend(None)
gc_collect_harder()
# Ensure we don't pollute the thread-level context if run under
# an asyncio without contextvars support (3.6)
context = contextvars.copy_context()
context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True)
assert record == {("asyncio", "asyncio"), ("trio", "trio")}

View File

@@ -0,0 +1,253 @@
import attr
import pytest
from ... import _core, _abc
from .tutil import check_sequence_matches
@attr.s(eq=False, hash=False)
class TaskRecorder:
record = attr.ib(factory=list)
def before_run(self):
self.record.append(("before_run",))
def task_scheduled(self, task):
self.record.append(("schedule", task))
def before_task_step(self, task):
assert task is _core.current_task()
self.record.append(("before", task))
def after_task_step(self, task):
assert task is _core.current_task()
self.record.append(("after", task))
def after_run(self):
self.record.append(("after_run",))
def filter_tasks(self, tasks):
for item in self.record:
if item[0] in ("schedule", "before", "after") and item[1] in tasks:
yield item
if item[0] in ("before_run", "after_run"):
yield item
def test_instruments(recwarn):
r1 = TaskRecorder()
r2 = TaskRecorder()
r3 = TaskRecorder()
task = None
# We use a child task for this, because the main task does some extra
# bookkeeping stuff that can leak into the instrument results, and we
# don't want to deal with it.
async def task_fn():
nonlocal task
task = _core.current_task()
for _ in range(4):
await _core.checkpoint()
# replace r2 with r3, to test that we can manipulate them as we go
_core.remove_instrument(r2)
with pytest.raises(KeyError):
_core.remove_instrument(r2)
# add is idempotent
_core.add_instrument(r3)
_core.add_instrument(r3)
for _ in range(1):
await _core.checkpoint()
async def main():
async with _core.open_nursery() as nursery:
nursery.start_soon(task_fn)
_core.run(main, instruments=[r1, r2])
# It sleeps 5 times, so it runs 6 times. Note that checkpoint()
# reschedules the task immediately upon yielding, before the
# after_task_step event fires.
expected = (
[("before_run",), ("schedule", task)]
+ [("before", task), ("schedule", task), ("after", task)] * 5
+ [("before", task), ("after", task), ("after_run",)]
)
assert r1.record == r2.record + r3.record
assert list(r1.filter_tasks([task])) == expected
def test_instruments_interleave():
tasks = {}
async def two_step1():
tasks["t1"] = _core.current_task()
await _core.checkpoint()
async def two_step2():
tasks["t2"] = _core.current_task()
await _core.checkpoint()
async def main():
async with _core.open_nursery() as nursery:
nursery.start_soon(two_step1)
nursery.start_soon(two_step2)
r = TaskRecorder()
_core.run(main, instruments=[r])
expected = [
("before_run",),
("schedule", tasks["t1"]),
("schedule", tasks["t2"]),
{
("before", tasks["t1"]),
("schedule", tasks["t1"]),
("after", tasks["t1"]),
("before", tasks["t2"]),
("schedule", tasks["t2"]),
("after", tasks["t2"]),
},
{
("before", tasks["t1"]),
("after", tasks["t1"]),
("before", tasks["t2"]),
("after", tasks["t2"]),
},
("after_run",),
]
print(list(r.filter_tasks(tasks.values())))
check_sequence_matches(list(r.filter_tasks(tasks.values())), expected)
def test_null_instrument():
# undefined instrument methods are skipped
class NullInstrument:
def something_unrelated(self):
pass # pragma: no cover
async def main():
await _core.checkpoint()
_core.run(main, instruments=[NullInstrument()])
def test_instrument_before_after_run():
record = []
class BeforeAfterRun:
def before_run(self):
record.append("before_run")
def after_run(self):
record.append("after_run")
async def main():
pass
_core.run(main, instruments=[BeforeAfterRun()])
assert record == ["before_run", "after_run"]
def test_instrument_task_spawn_exit():
record = []
class SpawnExitRecorder:
def task_spawned(self, task):
record.append(("spawned", task))
def task_exited(self, task):
record.append(("exited", task))
async def main():
return _core.current_task()
main_task = _core.run(main, instruments=[SpawnExitRecorder()])
assert ("spawned", main_task) in record
assert ("exited", main_task) in record
# This test also tests having a crash before the initial task is even spawned,
# which is very difficult to handle.
def test_instruments_crash(caplog):
record = []
class BrokenInstrument:
def task_scheduled(self, task):
record.append("scheduled")
raise ValueError("oops")
def close(self):
# Shouldn't be called -- tests that the instrument disabling logic
# works right.
record.append("closed") # pragma: no cover
async def main():
record.append("main ran")
return _core.current_task()
r = TaskRecorder()
main_task = _core.run(main, instruments=[r, BrokenInstrument()])
assert record == ["scheduled", "main ran"]
# the TaskRecorder kept going throughout, even though the BrokenInstrument
# was disabled
assert ("after", main_task) in r.record
assert ("after_run",) in r.record
# And we got a log message
exc_type, exc_value, exc_traceback = caplog.records[0].exc_info
assert exc_type is ValueError
assert str(exc_value) == "oops"
assert "Instrument has been disabled" in caplog.records[0].message
def test_instruments_monkeypatch():
class NullInstrument(_abc.Instrument):
pass
instrument = NullInstrument()
async def main():
record = []
# Changing the set of hooks implemented by an instrument after
# it's installed doesn't make them start being called right away
instrument.before_task_step = record.append
await _core.checkpoint()
await _core.checkpoint()
assert len(record) == 0
# But if we remove and re-add the instrument, the new hooks are
# picked up
_core.remove_instrument(instrument)
_core.add_instrument(instrument)
await _core.checkpoint()
await _core.checkpoint()
assert record.count(_core.current_task()) == 2
_core.remove_instrument(instrument)
await _core.checkpoint()
await _core.checkpoint()
assert record.count(_core.current_task()) == 2
_core.run(main, instruments=[instrument])
def test_instrument_that_raises_on_getattr():
class EvilInstrument:
def task_exited(self, task):
assert False # pragma: no cover
@property
def after_run(self):
raise ValueError("oops")
async def main():
with pytest.raises(ValueError):
_core.add_instrument(EvilInstrument())
# Make sure the instrument is fully removed from the per-method lists
runner = _core.current_task()._runner
assert "after_run" not in runner.instruments
assert "task_exited" not in runner.instruments
_core.run(main)

View File

@@ -0,0 +1,447 @@
import pytest
import socket as stdlib_socket
import select
import random
import errno
from contextlib import suppress
from ... import _core
from ...testing import wait_all_tasks_blocked, Sequencer, assert_checkpoints
import trio
# Cross-platform tests for IO handling
def fill_socket(sock):
try:
while True:
sock.send(b"x" * 65536)
except BlockingIOError:
pass
def drain_socket(sock):
try:
while True:
sock.recv(65536)
except BlockingIOError:
pass
@pytest.fixture
def socketpair():
pair = stdlib_socket.socketpair()
for sock in pair:
sock.setblocking(False)
yield pair
for sock in pair:
sock.close()
def using_fileno(fn):
def fileno_wrapper(fileobj):
return fn(fileobj.fileno())
name = "<{} on fileno>".format(fn.__name__)
fileno_wrapper.__name__ = fileno_wrapper.__qualname__ = name
return fileno_wrapper
wait_readable_options = [trio.lowlevel.wait_readable]
wait_writable_options = [trio.lowlevel.wait_writable]
notify_closing_options = [trio.lowlevel.notify_closing]
for options_list in [
wait_readable_options,
wait_writable_options,
notify_closing_options,
]:
options_list += [using_fileno(f) for f in options_list]
# Decorators that feed in different settings for wait_readable / wait_writable
# / notify_closing.
# Note that if you use all three decorators on the same test, it will run all
# N**3 *combinations*
read_socket_test = pytest.mark.parametrize(
"wait_readable", wait_readable_options, ids=lambda fn: fn.__name__
)
write_socket_test = pytest.mark.parametrize(
"wait_writable", wait_writable_options, ids=lambda fn: fn.__name__
)
notify_closing_test = pytest.mark.parametrize(
"notify_closing", notify_closing_options, ids=lambda fn: fn.__name__
)
# XX These tests are all a bit dicey because they can't distinguish between
# wait_on_{read,writ}able blocking the way it should, versus blocking
# momentarily and then immediately resuming.
@read_socket_test
@write_socket_test
async def test_wait_basic(socketpair, wait_readable, wait_writable):
a, b = socketpair
# They start out writable()
with assert_checkpoints():
await wait_writable(a)
# But readable() blocks until data arrives
record = []
async def block_on_read():
try:
with assert_checkpoints():
await wait_readable(a)
except _core.Cancelled:
record.append("cancelled")
else:
record.append("readable")
assert a.recv(10) == b"x"
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_read)
await wait_all_tasks_blocked()
assert record == []
b.send(b"x")
fill_socket(a)
# Now writable will block, but readable won't
with assert_checkpoints():
await wait_readable(b)
record = []
async def block_on_write():
try:
with assert_checkpoints():
await wait_writable(a)
except _core.Cancelled:
record.append("cancelled")
else:
record.append("writable")
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_write)
await wait_all_tasks_blocked()
assert record == []
drain_socket(b)
# check cancellation
record = []
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_read)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert record == ["cancelled"]
fill_socket(a)
record = []
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_write)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert record == ["cancelled"]
@read_socket_test
async def test_double_read(socketpair, wait_readable):
a, b = socketpair
# You can't have two tasks trying to read from a socket at the same time
async with _core.open_nursery() as nursery:
nursery.start_soon(wait_readable, a)
await wait_all_tasks_blocked()
with pytest.raises(_core.BusyResourceError):
await wait_readable(a)
nursery.cancel_scope.cancel()
@write_socket_test
async def test_double_write(socketpair, wait_writable):
a, b = socketpair
# You can't have two tasks trying to write to a socket at the same time
fill_socket(a)
async with _core.open_nursery() as nursery:
nursery.start_soon(wait_writable, a)
await wait_all_tasks_blocked()
with pytest.raises(_core.BusyResourceError):
await wait_writable(a)
nursery.cancel_scope.cancel()
@read_socket_test
@write_socket_test
@notify_closing_test
async def test_interrupted_by_close(
socketpair, wait_readable, wait_writable, notify_closing
):
a, b = socketpair
async def reader():
with pytest.raises(_core.ClosedResourceError):
await wait_readable(a)
async def writer():
with pytest.raises(_core.ClosedResourceError):
await wait_writable(a)
fill_socket(a)
async with _core.open_nursery() as nursery:
nursery.start_soon(reader)
nursery.start_soon(writer)
await wait_all_tasks_blocked()
notify_closing(a)
@read_socket_test
@write_socket_test
async def test_socket_simultaneous_read_write(socketpair, wait_readable, wait_writable):
record = []
async def r_task(sock):
await wait_readable(sock)
record.append("r_task")
async def w_task(sock):
await wait_writable(sock)
record.append("w_task")
a, b = socketpair
fill_socket(a)
async with _core.open_nursery() as nursery:
nursery.start_soon(r_task, a)
nursery.start_soon(w_task, a)
await wait_all_tasks_blocked()
assert record == []
b.send(b"x")
await wait_all_tasks_blocked()
assert record == ["r_task"]
drain_socket(b)
await wait_all_tasks_blocked()
assert record == ["r_task", "w_task"]
@read_socket_test
@write_socket_test
async def test_socket_actual_streaming(socketpair, wait_readable, wait_writable):
a, b = socketpair
# Use a small send buffer on one of the sockets to increase the chance of
# getting partial writes
a.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_SNDBUF, 10000)
N = 1000000 # 1 megabyte
MAX_CHUNK = 65536
results = {}
async def sender(sock, seed, key):
r = random.Random(seed)
sent = 0
while sent < N:
print("sent", sent)
chunk = bytearray(r.randrange(MAX_CHUNK))
while chunk:
with assert_checkpoints():
await wait_writable(sock)
this_chunk_size = sock.send(chunk)
sent += this_chunk_size
del chunk[:this_chunk_size]
sock.shutdown(stdlib_socket.SHUT_WR)
results[key] = sent
async def receiver(sock, key):
received = 0
while True:
print("received", received)
with assert_checkpoints():
await wait_readable(sock)
this_chunk_size = len(sock.recv(MAX_CHUNK))
if not this_chunk_size:
break
received += this_chunk_size
results[key] = received
async with _core.open_nursery() as nursery:
nursery.start_soon(sender, a, 0, "send_a")
nursery.start_soon(sender, b, 1, "send_b")
nursery.start_soon(receiver, a, "recv_a")
nursery.start_soon(receiver, b, "recv_b")
assert results["send_a"] == results["recv_b"]
assert results["send_b"] == results["recv_a"]
async def test_notify_closing_on_invalid_object():
# It should either be a no-op (generally on Unix, where we don't know
# which fds are valid), or an OSError (on Windows, where we currently only
# support sockets, so we have to do some validation to figure out whether
# it's a socket or a regular handle).
got_oserror = False
got_no_error = False
try:
trio.lowlevel.notify_closing(-1)
except OSError:
got_oserror = True
else:
got_no_error = True
assert got_oserror or got_no_error
async def test_wait_on_invalid_object():
# We definitely want to raise an error everywhere if you pass in an
# invalid fd to wait_*
for wait in [trio.lowlevel.wait_readable, trio.lowlevel.wait_writable]:
with stdlib_socket.socket() as s:
fileno = s.fileno()
# We just closed the socket and don't do anything else in between, so
# we can be confident that the fileno hasn't be reassigned.
with pytest.raises(OSError):
await wait(fileno)
async def test_io_manager_statistics():
def check(*, expected_readers, expected_writers):
statistics = _core.current_statistics()
print(statistics)
iostats = statistics.io_statistics
if iostats.backend in ["epoll", "windows"]:
assert iostats.tasks_waiting_read == expected_readers
assert iostats.tasks_waiting_write == expected_writers
else:
assert iostats.backend == "kqueue"
assert iostats.tasks_waiting == expected_readers + expected_writers
a1, b1 = stdlib_socket.socketpair()
a2, b2 = stdlib_socket.socketpair()
a3, b3 = stdlib_socket.socketpair()
for sock in [a1, b1, a2, b2, a3, b3]:
sock.setblocking(False)
with a1, b1, a2, b2, a3, b3:
# let the call_soon_task settle down
await wait_all_tasks_blocked()
# 1 for call_soon_task
check(expected_readers=1, expected_writers=0)
# We want:
# - one socket with a writer blocked
# - two sockets with a reader blocked
# - a socket with both blocked
fill_socket(a1)
fill_socket(a3)
async with _core.open_nursery() as nursery:
nursery.start_soon(_core.wait_writable, a1)
nursery.start_soon(_core.wait_readable, a2)
nursery.start_soon(_core.wait_readable, b2)
nursery.start_soon(_core.wait_writable, a3)
nursery.start_soon(_core.wait_readable, a3)
await wait_all_tasks_blocked()
# +1 for call_soon_task
check(expected_readers=3 + 1, expected_writers=2)
nursery.cancel_scope.cancel()
# 1 for call_soon_task
check(expected_readers=1, expected_writers=0)
async def test_can_survive_unnotified_close():
# An "unnotified" close is when the user closes an fd/socket/handle
# directly, without calling notify_closing first. This should never happen
# -- users should call notify_closing before closing things. But, just in
# case they don't, we would still like to avoid exploding.
#
# Acceptable behaviors:
# - wait_* never return, but can be cancelled cleanly
# - wait_* exit cleanly
# - wait_* raise an OSError
#
# Not acceptable:
# - getting stuck in an uncancellable state
# - TrioInternalError blowing up the whole run
#
# This test exercises some tricky "unnotified close" scenarios, to make
# sure we get the "acceptable" behaviors.
async def allow_OSError(async_func, *args):
with suppress(OSError):
await async_func(*args)
with stdlib_socket.socket() as s:
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s)
await wait_all_tasks_blocked()
s.close()
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# We hit different paths on Windows depending on whether we close the last
# handle to the object (which produces a LOCAL_CLOSE notification and
# wakes up wait_readable), or only close one of the handles (which leaves
# wait_readable pending until cancelled).
with stdlib_socket.socket() as s, s.dup() as s2: # noqa: F841
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s)
await wait_all_tasks_blocked()
s.close()
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# A more elaborate case, with two tasks waiting. On windows and epoll,
# the two tasks get muxed together onto a single underlying wait
# operation. So when they're cancelled, there's a brief moment where one
# of the tasks is cancelled but the other isn't, so we try to re-issue the
# underlying wait operation. But here, the handle we were going to use to
# do that has been pulled out from under our feet... so test that we can
# survive this.
a, b = stdlib_socket.socketpair()
with a, b, a.dup() as a2: # noqa: F841
a.setblocking(False)
b.setblocking(False)
fill_socket(a)
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a)
nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a)
await wait_all_tasks_blocked()
a.close()
nursery.cancel_scope.cancel()
# A similar case, but now the single-task-wakeup happens due to I/O
# arriving, not a cancellation, so the operation gets re-issued from
# handle_io context rather than abort context.
a, b = stdlib_socket.socketpair()
with a, b, a.dup() as a2: # noqa: F841
print("a={}, b={}, a2={}".format(a.fileno(), b.fileno(), a2.fileno()))
a.setblocking(False)
b.setblocking(False)
fill_socket(a)
e = trio.Event()
# We want to wait for the kernel to process the wakeup on 'a', if any.
# But depending on the platform, we might not get a wakeup on 'a'. So
# we put one task to sleep waiting on 'a', and we put a second task to
# sleep waiting on 'a2', with the idea that the 'a2' notification will
# definitely arrive, and when it does then we can assume that whatever
# notification was going to arrive for 'a' has also arrived.
async def wait_readable_a2_then_set():
await trio.lowlevel.wait_readable(a2)
e.set()
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a)
nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a)
nursery.start_soon(wait_readable_a2_then_set)
await wait_all_tasks_blocked()
a.close()
b.send(b"x")
# Make sure that the wakeup has been received and everything has
# settled before cancelling the wait_writable.
await e.wait()
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()

View File

@@ -0,0 +1,501 @@
import outcome
import pytest
import sys
import os
import signal
import threading
import contextlib
import time
from async_generator import (
async_generator,
yield_,
isasyncgenfunction,
asynccontextmanager,
)
from ... import _core
from ...testing import wait_all_tasks_blocked
from ..._util import signal_raise, is_main_thread
from ..._timeouts import sleep
from .tutil import slow
def ki_self():
signal_raise(signal.SIGINT)
def test_ki_self():
with pytest.raises(KeyboardInterrupt):
ki_self()
async def test_ki_enabled():
# Regular tasks aren't KI-protected
assert not _core.currently_ki_protected()
# Low-level call-soon callbacks are KI-protected
token = _core.current_trio_token()
record = []
def check():
record.append(_core.currently_ki_protected())
token.run_sync_soon(check)
await wait_all_tasks_blocked()
assert record == [True]
@_core.enable_ki_protection
def protected():
assert _core.currently_ki_protected()
unprotected()
@_core.disable_ki_protection
def unprotected():
assert not _core.currently_ki_protected()
protected()
@_core.enable_ki_protection
async def aprotected():
assert _core.currently_ki_protected()
await aunprotected()
@_core.disable_ki_protection
async def aunprotected():
assert not _core.currently_ki_protected()
await aprotected()
# make sure that the decorator here overrides the automatic manipulation
# that start_soon() does:
async with _core.open_nursery() as nursery:
nursery.start_soon(aprotected)
nursery.start_soon(aunprotected)
@_core.enable_ki_protection
def gen_protected():
assert _core.currently_ki_protected()
yield
for _ in gen_protected():
pass
@_core.disable_ki_protection
def gen_unprotected():
assert not _core.currently_ki_protected()
yield
for _ in gen_unprotected():
pass
# This used to be broken due to
#
# https://bugs.python.org/issue29590
#
# Specifically, after a coroutine is resumed with .throw(), then the stack
# makes it look like the immediate caller is the function that called
# .throw(), not the actual caller. So child() here would have a caller deep in
# the guts of the run loop, and always be protected, even when it shouldn't
# have been. (Solution: we don't use .throw() anymore.)
async def test_ki_enabled_after_yield_briefly():
@_core.enable_ki_protection
async def protected():
await child(True)
@_core.disable_ki_protection
async def unprotected():
await child(False)
async def child(expected):
import traceback
traceback.print_stack()
assert _core.currently_ki_protected() == expected
await _core.checkpoint()
traceback.print_stack()
assert _core.currently_ki_protected() == expected
await protected()
await unprotected()
# This also used to be broken due to
# https://bugs.python.org/issue29590
async def test_generator_based_context_manager_throw():
@contextlib.contextmanager
@_core.enable_ki_protection
def protected_manager():
assert _core.currently_ki_protected()
try:
yield
finally:
assert _core.currently_ki_protected()
with protected_manager():
assert not _core.currently_ki_protected()
with pytest.raises(KeyError):
# This is the one that used to fail
with protected_manager():
raise KeyError
async def test_agen_protection():
@_core.enable_ki_protection
@async_generator
async def agen_protected1():
assert _core.currently_ki_protected()
try:
await yield_()
finally:
assert _core.currently_ki_protected()
@_core.disable_ki_protection
@async_generator
async def agen_unprotected1():
assert not _core.currently_ki_protected()
try:
await yield_()
finally:
assert not _core.currently_ki_protected()
# Swap the order of the decorators:
@async_generator
@_core.enable_ki_protection
async def agen_protected2():
assert _core.currently_ki_protected()
try:
await yield_()
finally:
assert _core.currently_ki_protected()
@async_generator
@_core.disable_ki_protection
async def agen_unprotected2():
assert not _core.currently_ki_protected()
try:
await yield_()
finally:
assert not _core.currently_ki_protected()
# Native async generators
@_core.enable_ki_protection
async def agen_protected3():
assert _core.currently_ki_protected()
try:
yield
finally:
assert _core.currently_ki_protected()
@_core.disable_ki_protection
async def agen_unprotected3():
assert not _core.currently_ki_protected()
try:
yield
finally:
assert not _core.currently_ki_protected()
for agen_fn in [
agen_protected1,
agen_protected2,
agen_protected3,
agen_unprotected1,
agen_unprotected2,
agen_unprotected3,
]:
async for _ in agen_fn(): # noqa
assert not _core.currently_ki_protected()
# asynccontextmanager insists that the function passed must itself be an
# async gen function, not a wrapper around one
if isasyncgenfunction(agen_fn):
async with asynccontextmanager(agen_fn)():
assert not _core.currently_ki_protected()
# Another case that's tricky due to:
# https://bugs.python.org/issue29590
with pytest.raises(KeyError):
async with asynccontextmanager(agen_fn)():
raise KeyError
# Test the case where there's no magic local anywhere in the call stack
def test_ki_disabled_out_of_context():
assert _core.currently_ki_protected()
def test_ki_disabled_in_del():
def nestedfunction():
return _core.currently_ki_protected()
def __del__():
assert _core.currently_ki_protected()
assert nestedfunction()
@_core.disable_ki_protection
def outerfunction():
assert not _core.currently_ki_protected()
assert not nestedfunction()
__del__()
__del__()
outerfunction()
assert nestedfunction()
def test_ki_protection_works():
async def sleeper(name, record):
try:
while True:
await _core.checkpoint()
except _core.Cancelled:
record.add(name + " ok")
async def raiser(name, record):
try:
# os.kill runs signal handlers before returning, so we don't need
# to worry that the handler will be delayed
print("killing, protection =", _core.currently_ki_protected())
ki_self()
except KeyboardInterrupt:
print("raised!")
# Make sure we aren't getting cancelled as well as siginted
await _core.checkpoint()
record.add(name + " raise ok")
raise
else:
print("didn't raise!")
# If we didn't raise (b/c protected), then we *should* get
# cancelled at the next opportunity
try:
await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)
except _core.Cancelled:
record.add(name + " cancel ok")
# simulated control-C during raiser, which is *unprotected*
print("check 1")
record = set()
async def check_unprotected_kill():
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper, "s1", record)
nursery.start_soon(sleeper, "s2", record)
nursery.start_soon(raiser, "r1", record)
with pytest.raises(KeyboardInterrupt):
_core.run(check_unprotected_kill)
assert record == {"s1 ok", "s2 ok", "r1 raise ok"}
# simulated control-C during raiser, which is *protected*, so the KI gets
# delivered to the main task instead
print("check 2")
record = set()
async def check_protected_kill():
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper, "s1", record)
nursery.start_soon(sleeper, "s2", record)
nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record)
# __aexit__ blocks, and then receives the KI
with pytest.raises(KeyboardInterrupt):
_core.run(check_protected_kill)
assert record == {"s1 ok", "s2 ok", "r1 cancel ok"}
# kill at last moment still raises (run_sync_soon until it raises an
# error, then kill)
print("check 3")
async def check_kill_during_shutdown():
token = _core.current_trio_token()
def kill_during_shutdown():
assert _core.currently_ki_protected()
try:
token.run_sync_soon(kill_during_shutdown)
except _core.RunFinishedError:
# it's too late for regular handling! handle this!
print("kill! kill!")
ki_self()
token.run_sync_soon(kill_during_shutdown)
with pytest.raises(KeyboardInterrupt):
_core.run(check_kill_during_shutdown)
# KI arrives very early, before main is even spawned
print("check 4")
class InstrumentOfDeath:
def before_run(self):
ki_self()
async def main():
await _core.checkpoint()
with pytest.raises(KeyboardInterrupt):
_core.run(main, instruments=[InstrumentOfDeath()])
# checkpoint_if_cancelled notices pending KI
print("check 5")
@_core.enable_ki_protection
async def main():
assert _core.currently_ki_protected()
ki_self()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint_if_cancelled()
_core.run(main)
# KI arrives while main task is not abortable, b/c already scheduled
print("check 6")
@_core.enable_ki_protection
async def main():
assert _core.currently_ki_protected()
ki_self()
await _core.cancel_shielded_checkpoint()
await _core.cancel_shielded_checkpoint()
await _core.cancel_shielded_checkpoint()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
_core.run(main)
# KI arrives while main task is not abortable, b/c refuses to be aborted
print("check 7")
@_core.enable_ki_protection
async def main():
assert _core.currently_ki_protected()
ki_self()
task = _core.current_task()
def abort(_):
_core.reschedule(task, outcome.Value(1))
return _core.Abort.FAILED
assert await _core.wait_task_rescheduled(abort) == 1
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
_core.run(main)
# KI delivered via slow abort
print("check 8")
@_core.enable_ki_protection
async def main():
assert _core.currently_ki_protected()
ki_self()
task = _core.current_task()
def abort(raise_cancel):
result = outcome.capture(raise_cancel)
_core.reschedule(task, result)
return _core.Abort.FAILED
with pytest.raises(KeyboardInterrupt):
assert await _core.wait_task_rescheduled(abort)
await _core.checkpoint()
_core.run(main)
# KI arrives just before main task exits, so the run_sync_soon machinery
# is still functioning and will accept the callback to deliver the KI, but
# by the time the callback is actually run, main has exited and can't be
# aborted.
print("check 9")
@_core.enable_ki_protection
async def main():
ki_self()
with pytest.raises(KeyboardInterrupt):
_core.run(main)
print("check 10")
# KI in unprotected code, with
# restrict_keyboard_interrupt_to_checkpoints=True
record = []
async def main():
# We're not KI protected...
assert not _core.currently_ki_protected()
ki_self()
# ...but even after the KI, we keep running uninterrupted...
record.append("ok")
# ...until we hit a checkpoint:
with pytest.raises(KeyboardInterrupt):
await sleep(10)
_core.run(main, restrict_keyboard_interrupt_to_checkpoints=True)
assert record == ["ok"]
record = []
# Exact same code raises KI early if we leave off the argument, doesn't
# even reach the record.append call:
with pytest.raises(KeyboardInterrupt):
_core.run(main)
assert record == []
# KI arrives while main task is inside a cancelled cancellation scope
# the KeyboardInterrupt should take priority
print("check 11")
@_core.enable_ki_protection
async def main():
assert _core.currently_ki_protected()
with _core.CancelScope() as cancel_scope:
cancel_scope.cancel()
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
ki_self()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
_core.run(main)
def test_ki_is_good_neighbor():
# in the unlikely event someone overwrites our signal handler, we leave
# the overwritten one be
try:
orig = signal.getsignal(signal.SIGINT)
def my_handler(signum, frame): # pragma: no cover
pass
async def main():
signal.signal(signal.SIGINT, my_handler)
_core.run(main)
assert signal.getsignal(signal.SIGINT) is my_handler
finally:
signal.signal(signal.SIGINT, orig)
# Regression test for #461
def test_ki_with_broken_threads():
thread = threading.main_thread()
# scary!
original = threading._active[thread.ident]
# put this in a try finally so we don't have a chance of cascading a
# breakage down to everything else
try:
del threading._active[thread.ident]
@_core.enable_ki_protection
async def inner():
assert signal.getsignal(signal.SIGINT) != signal.default_int_handler
_core.run(inner)
finally:
threading._active[thread.ident] = original

View File

@@ -0,0 +1,115 @@
import pytest
from ... import _core
# scary runvar tests
def test_runvar_smoketest():
t1 = _core.RunVar("test1")
t2 = _core.RunVar("test2", default="catfish")
assert "RunVar" in repr(t1)
async def first_check():
with pytest.raises(LookupError):
t1.get()
t1.set("swordfish")
assert t1.get() == "swordfish"
assert t2.get() == "catfish"
assert t2.get(default="eel") == "eel"
t2.set("goldfish")
assert t2.get() == "goldfish"
assert t2.get(default="tuna") == "goldfish"
async def second_check():
with pytest.raises(LookupError):
t1.get()
assert t2.get() == "catfish"
_core.run(first_check)
_core.run(second_check)
def test_runvar_resetting():
t1 = _core.RunVar("test1")
t2 = _core.RunVar("test2", default="dogfish")
t3 = _core.RunVar("test3")
async def reset_check():
token = t1.set("moonfish")
assert t1.get() == "moonfish"
t1.reset(token)
with pytest.raises(TypeError):
t1.reset(None)
with pytest.raises(LookupError):
t1.get()
token2 = t2.set("catdogfish")
assert t2.get() == "catdogfish"
t2.reset(token2)
assert t2.get() == "dogfish"
with pytest.raises(ValueError):
t2.reset(token2)
token3 = t3.set("basculin")
assert t3.get() == "basculin"
with pytest.raises(ValueError):
t1.reset(token3)
_core.run(reset_check)
def test_runvar_sync():
t1 = _core.RunVar("test1")
async def sync_check():
async def task1():
t1.set("plaice")
assert t1.get() == "plaice"
async def task2(tok):
t1.reset(token)
with pytest.raises(LookupError):
t1.get()
t1.set("cod")
async with _core.open_nursery() as n:
token = t1.set("cod")
assert t1.get() == "cod"
n.start_soon(task1)
await _core.wait_all_tasks_blocked()
assert t1.get() == "plaice"
n.start_soon(task2, token)
await _core.wait_all_tasks_blocked()
assert t1.get() == "cod"
_core.run(sync_check)
def test_accessing_runvar_outside_run_call_fails():
t1 = _core.RunVar("test1")
with pytest.raises(RuntimeError):
t1.set("asdf")
with pytest.raises(RuntimeError):
t1.get()
async def get_token():
return t1.set("ok")
token = _core.run(get_token)
with pytest.raises(RuntimeError):
t1.reset(token)

View File

@@ -0,0 +1,170 @@
from math import inf
import time
import pytest
from trio import sleep
from ... import _core
from .. import wait_all_tasks_blocked
from .._mock_clock import MockClock
from .tutil import slow
def test_mock_clock():
REAL_NOW = 123.0
c = MockClock()
c._real_clock = lambda: REAL_NOW
repr(c) # smoke test
assert c.rate == 0
assert c.current_time() == 0
c.jump(1.2)
assert c.current_time() == 1.2
with pytest.raises(ValueError):
c.jump(-1)
assert c.current_time() == 1.2
assert c.deadline_to_sleep_time(1.1) == 0
assert c.deadline_to_sleep_time(1.2) == 0
assert c.deadline_to_sleep_time(1.3) > 999999
with pytest.raises(ValueError):
c.rate = -1
assert c.rate == 0
c.rate = 2
assert c.current_time() == 1.2
REAL_NOW += 1
assert c.current_time() == 3.2
assert c.deadline_to_sleep_time(3.1) == 0
assert c.deadline_to_sleep_time(3.2) == 0
assert c.deadline_to_sleep_time(4.2) == 0.5
c.rate = 0.5
assert c.current_time() == 3.2
assert c.deadline_to_sleep_time(3.1) == 0
assert c.deadline_to_sleep_time(3.2) == 0
assert c.deadline_to_sleep_time(4.2) == 2.0
c.jump(0.8)
assert c.current_time() == 4.0
REAL_NOW += 1
assert c.current_time() == 4.5
c2 = MockClock(rate=3)
assert c2.rate == 3
assert c2.current_time() < 10
async def test_mock_clock_autojump(mock_clock):
assert mock_clock.autojump_threshold == inf
mock_clock.autojump_threshold = 0
assert mock_clock.autojump_threshold == 0
real_start = time.perf_counter()
virtual_start = _core.current_time()
for i in range(10):
print("sleeping {} seconds".format(10 * i))
await sleep(10 * i)
print("woke up!")
assert virtual_start + 10 * i == _core.current_time()
virtual_start = _core.current_time()
real_duration = time.perf_counter() - real_start
print("Slept {} seconds in {} seconds".format(10 * sum(range(10)), real_duration))
assert real_duration < 1
mock_clock.autojump_threshold = 0.02
t = _core.current_time()
# this should wake up before the autojump threshold triggers, so time
# shouldn't change
await wait_all_tasks_blocked()
assert t == _core.current_time()
# this should too
await wait_all_tasks_blocked(0.01)
assert t == _core.current_time()
# set up a situation where the autojump task is blocked for a long long
# time, to make sure that cancel-and-adjust-threshold logic is working
mock_clock.autojump_threshold = 10000
await wait_all_tasks_blocked()
mock_clock.autojump_threshold = 0
# if the above line didn't take affect immediately, then this would be
# bad:
await sleep(100000)
async def test_mock_clock_autojump_interference(mock_clock):
mock_clock.autojump_threshold = 0.02
mock_clock2 = MockClock()
# messing with the autojump threshold of a clock that isn't actually
# installed in the run loop shouldn't do anything.
mock_clock2.autojump_threshold = 0.01
# if the autojump_threshold of 0.01 were in effect, then the next line
# would block forever, as the autojump task kept waking up to try to
# jump the clock.
await wait_all_tasks_blocked(0.015)
# but the 0.02 limit does apply
await sleep(100000)
def test_mock_clock_autojump_preset():
# Check that we can set the autojump_threshold before the clock is
# actually in use, and it gets picked up
mock_clock = MockClock(autojump_threshold=0.1)
mock_clock.autojump_threshold = 0.01
real_start = time.perf_counter()
_core.run(sleep, 10000, clock=mock_clock)
assert time.perf_counter() - real_start < 1
async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock):
# Checks that autojump_threshold=0 doesn't interfere with
# calling wait_all_tasks_blocked with the default cushion=0.
mock_clock.autojump_threshold = 0
record = []
async def sleeper():
await sleep(100)
record.append("yawn")
async def waiter():
await wait_all_tasks_blocked()
record.append("waiter woke")
await sleep(1000)
record.append("waiter done")
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper)
nursery.start_soon(waiter)
assert record == ["waiter woke", "yawn", "waiter done"]
@slow
async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(mock_clock):
# Checks that autojump_threshold=0 doesn't interfere with
# calling wait_all_tasks_blocked with a non-zero cushion.
mock_clock.autojump_threshold = 0
record = []
async def sleeper():
await sleep(100)
record.append("yawn")
async def waiter():
await wait_all_tasks_blocked(1)
record.append("waiter done")
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper)
nursery.start_soon(waiter)
assert record == ["waiter done", "yawn"]

View File

@@ -0,0 +1,772 @@
import gc
import logging
import pytest
from traceback import (
extract_tb,
print_exception,
format_exception,
)
from traceback import _cause_message # type: ignore
import sys
import os
import re
from pathlib import Path
import subprocess
from .tutil import slow
from .._multierror import MultiError, concat_tb
from ..._core import open_nursery
class NotHashableException(Exception):
code = None
def __init__(self, code):
super().__init__()
self.code = code
def __eq__(self, other):
if not isinstance(other, NotHashableException):
return False
return self.code == other.code
async def raise_nothashable(code):
raise NotHashableException(code)
def raiser1():
raiser1_2()
def raiser1_2():
raiser1_3()
def raiser1_3():
raise ValueError("raiser1_string")
def raiser2():
raiser2_2()
def raiser2_2():
raise KeyError("raiser2_string")
def raiser3():
raise NameError
def get_exc(raiser):
try:
raiser()
except Exception as exc:
return exc
def get_tb(raiser):
return get_exc(raiser).__traceback__
def einfo(exc):
return (type(exc), exc, exc.__traceback__)
def test_concat_tb():
tb1 = get_tb(raiser1)
tb2 = get_tb(raiser2)
# These return a list of (filename, lineno, fn name, text) tuples
# https://docs.python.org/3/library/traceback.html#traceback.extract_tb
entries1 = extract_tb(tb1)
entries2 = extract_tb(tb2)
tb12 = concat_tb(tb1, tb2)
assert extract_tb(tb12) == entries1 + entries2
tb21 = concat_tb(tb2, tb1)
assert extract_tb(tb21) == entries2 + entries1
# Check degenerate cases
assert extract_tb(concat_tb(None, tb1)) == entries1
assert extract_tb(concat_tb(tb1, None)) == entries1
assert concat_tb(None, None) is None
# Make sure the original tracebacks didn't get mutated by mistake
assert extract_tb(get_tb(raiser1)) == entries1
assert extract_tb(get_tb(raiser2)) == entries2
def test_MultiError():
exc1 = get_exc(raiser1)
exc2 = get_exc(raiser2)
assert MultiError([exc1]) is exc1
m = MultiError([exc1, exc2])
assert m.exceptions == [exc1, exc2]
assert "ValueError" in str(m)
assert "ValueError" in repr(m)
with pytest.raises(TypeError):
MultiError(object())
with pytest.raises(TypeError):
MultiError([KeyError(), ValueError])
def test_MultiErrorOfSingleMultiError():
# For MultiError([MultiError]), ensure there is no bad recursion by the
# constructor where __init__ is called if __new__ returns a bare MultiError.
exceptions = [KeyError(), ValueError()]
a = MultiError(exceptions)
b = MultiError([a])
assert b == a
assert b.exceptions == exceptions
async def test_MultiErrorNotHashable():
exc1 = NotHashableException(42)
exc2 = NotHashableException(4242)
exc3 = ValueError()
assert exc1 != exc2
assert exc1 != exc3
with pytest.raises(MultiError):
async with open_nursery() as nursery:
nursery.start_soon(raise_nothashable, 42)
nursery.start_soon(raise_nothashable, 4242)
def test_MultiError_filter_NotHashable():
excs = MultiError([NotHashableException(42), ValueError()])
def handle_ValueError(exc):
if isinstance(exc, ValueError):
return None
else:
return exc
filtered_excs = MultiError.filter(handle_ValueError, excs)
assert isinstance(filtered_excs, NotHashableException)
def test_traceback_recursion():
exc1 = RuntimeError()
exc2 = KeyError()
exc3 = NotHashableException(42)
# Note how this creates a loop, where exc1 refers to exc1
# This could trigger an infinite recursion; the 'seen' set is supposed to prevent
# this.
exc1.__cause__ = MultiError([exc1, exc2, exc3])
format_exception(*einfo(exc1))
def make_tree():
# Returns an object like:
# MultiError([
# MultiError([
# ValueError,
# KeyError,
# ]),
# NameError,
# ])
# where all exceptions except the root have a non-trivial traceback.
exc1 = get_exc(raiser1)
exc2 = get_exc(raiser2)
exc3 = get_exc(raiser3)
# Give m12 a non-trivial traceback
try:
raise MultiError([exc1, exc2])
except BaseException as m12:
return MultiError([m12, exc3])
def assert_tree_eq(m1, m2):
if m1 is None or m2 is None:
assert m1 is m2
return
assert type(m1) is type(m2)
assert extract_tb(m1.__traceback__) == extract_tb(m2.__traceback__)
assert_tree_eq(m1.__cause__, m2.__cause__)
assert_tree_eq(m1.__context__, m2.__context__)
if isinstance(m1, MultiError):
assert len(m1.exceptions) == len(m2.exceptions)
for e1, e2 in zip(m1.exceptions, m2.exceptions):
assert_tree_eq(e1, e2)
def test_MultiError_filter():
def null_handler(exc):
return exc
m = make_tree()
assert_tree_eq(m, m)
assert MultiError.filter(null_handler, m) is m
assert_tree_eq(m, make_tree())
# Make sure we don't pick up any detritus if run in a context where
# implicit exception chaining would like to kick in
m = make_tree()
try:
raise ValueError
except ValueError:
assert MultiError.filter(null_handler, m) is m
assert_tree_eq(m, make_tree())
def simple_filter(exc):
if isinstance(exc, ValueError):
return None
if isinstance(exc, KeyError):
return RuntimeError()
return exc
new_m = MultiError.filter(simple_filter, make_tree())
assert isinstance(new_m, MultiError)
assert len(new_m.exceptions) == 2
# was: [[ValueError, KeyError], NameError]
# ValueError disappeared & KeyError became RuntimeError, so now:
assert isinstance(new_m.exceptions[0], RuntimeError)
assert isinstance(new_m.exceptions[1], NameError)
# implicit chaining:
assert isinstance(new_m.exceptions[0].__context__, KeyError)
# also, the traceback on the KeyError incorporates what used to be the
# traceback on its parent MultiError
orig = make_tree()
# make sure we have the right path
assert isinstance(orig.exceptions[0].exceptions[1], KeyError)
# get original traceback summary
orig_extracted = (
extract_tb(orig.__traceback__)
+ extract_tb(orig.exceptions[0].__traceback__)
+ extract_tb(orig.exceptions[0].exceptions[1].__traceback__)
)
def p(exc):
print_exception(type(exc), exc, exc.__traceback__)
p(orig)
p(orig.exceptions[0])
p(orig.exceptions[0].exceptions[1])
p(new_m.exceptions[0].__context__)
# compare to the new path
assert new_m.__traceback__ is None
new_extracted = extract_tb(new_m.exceptions[0].__context__.__traceback__)
assert orig_extracted == new_extracted
# check preserving partial tree
def filter_NameError(exc):
if isinstance(exc, NameError):
return None
return exc
m = make_tree()
new_m = MultiError.filter(filter_NameError, m)
# with the NameError gone, the other branch gets promoted
assert new_m is m.exceptions[0]
# check fully handling everything
def filter_all(exc):
return None
assert MultiError.filter(filter_all, make_tree()) is None
def test_MultiError_catch():
# No exception to catch
def noop(_):
pass # pragma: no cover
with MultiError.catch(noop):
pass
# Simple pass-through of all exceptions
m = make_tree()
with pytest.raises(MultiError) as excinfo:
with MultiError.catch(lambda exc: exc):
raise m
assert excinfo.value is m
# Should be unchanged, except that we added a traceback frame by raising
# it here
assert m.__traceback__ is not None
assert m.__traceback__.tb_frame.f_code.co_name == "test_MultiError_catch"
assert m.__traceback__.tb_next is None
m.__traceback__ = None
assert_tree_eq(m, make_tree())
# Swallows everything
with MultiError.catch(lambda _: None):
raise make_tree()
def simple_filter(exc):
if isinstance(exc, ValueError):
return None
if isinstance(exc, KeyError):
return RuntimeError()
return exc
with pytest.raises(MultiError) as excinfo:
with MultiError.catch(simple_filter):
raise make_tree()
new_m = excinfo.value
assert isinstance(new_m, MultiError)
assert len(new_m.exceptions) == 2
# was: [[ValueError, KeyError], NameError]
# ValueError disappeared & KeyError became RuntimeError, so now:
assert isinstance(new_m.exceptions[0], RuntimeError)
assert isinstance(new_m.exceptions[1], NameError)
# Make sure that Python did not successfully attach the old MultiError to
# our new MultiError's __context__
assert not new_m.__suppress_context__
assert new_m.__context__ is None
# check preservation of __cause__ and __context__
v = ValueError()
v.__cause__ = KeyError()
with pytest.raises(ValueError) as excinfo:
with MultiError.catch(lambda exc: exc):
raise v
assert isinstance(excinfo.value.__cause__, KeyError)
v = ValueError()
context = KeyError()
v.__context__ = context
with pytest.raises(ValueError) as excinfo:
with MultiError.catch(lambda exc: exc):
raise v
assert excinfo.value.__context__ is context
assert not excinfo.value.__suppress_context__
for suppress_context in [True, False]:
v = ValueError()
context = KeyError()
v.__context__ = context
v.__suppress_context__ = suppress_context
distractor = RuntimeError()
with pytest.raises(ValueError) as excinfo:
def catch_RuntimeError(exc):
if isinstance(exc, RuntimeError):
return None
else:
return exc
with MultiError.catch(catch_RuntimeError):
raise MultiError([v, distractor])
assert excinfo.value.__context__ is context
assert excinfo.value.__suppress_context__ == suppress_context
@pytest.mark.skipif(
sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC"
)
def test_MultiError_catch_doesnt_create_cyclic_garbage():
# https://github.com/python-trio/trio/pull/2063
gc.collect()
old_flags = gc.get_debug()
def make_multi():
# make_tree creates cycles itself, so a simple
raise MultiError([get_exc(raiser1), get_exc(raiser2)])
def simple_filter(exc):
if isinstance(exc, ValueError):
return Exception()
if isinstance(exc, KeyError):
return RuntimeError()
assert False, "only ValueError and KeyError should exist" # pragma: no cover
try:
gc.set_debug(gc.DEBUG_SAVEALL)
with pytest.raises(MultiError):
# covers MultiErrorCatcher.__exit__ and _multierror.copy_tb
with MultiError.catch(simple_filter):
raise make_multi()
gc.collect()
assert not gc.garbage
finally:
gc.set_debug(old_flags)
gc.garbage.clear()
def assert_match_in_seq(pattern_list, string):
offset = 0
print("looking for pattern matches...")
for pattern in pattern_list:
print("checking pattern:", pattern)
reobj = re.compile(pattern)
match = reobj.search(string, offset)
assert match is not None
offset = match.end()
def test_assert_match_in_seq():
assert_match_in_seq(["a", "b"], "xx a xx b xx")
assert_match_in_seq(["b", "a"], "xx b xx a xx")
with pytest.raises(AssertionError):
assert_match_in_seq(["a", "b"], "xx b xx a xx")
def test_format_exception():
exc = get_exc(raiser1)
formatted = "".join(format_exception(*einfo(exc)))
assert "raiser1_string" in formatted
assert "in raiser1_3" in formatted
assert "raiser2_string" not in formatted
assert "in raiser2_2" not in formatted
assert "direct cause" not in formatted
assert "During handling" not in formatted
exc = get_exc(raiser1)
exc.__cause__ = get_exc(raiser2)
formatted = "".join(format_exception(*einfo(exc)))
assert "raiser1_string" in formatted
assert "in raiser1_3" in formatted
assert "raiser2_string" in formatted
assert "in raiser2_2" in formatted
assert "direct cause" in formatted
assert "During handling" not in formatted
# ensure cause included
assert _cause_message in formatted
exc = get_exc(raiser1)
exc.__context__ = get_exc(raiser2)
formatted = "".join(format_exception(*einfo(exc)))
assert "raiser1_string" in formatted
assert "in raiser1_3" in formatted
assert "raiser2_string" in formatted
assert "in raiser2_2" in formatted
assert "direct cause" not in formatted
assert "During handling" in formatted
exc.__suppress_context__ = True
formatted = "".join(format_exception(*einfo(exc)))
assert "raiser1_string" in formatted
assert "in raiser1_3" in formatted
assert "raiser2_string" not in formatted
assert "in raiser2_2" not in formatted
assert "direct cause" not in formatted
assert "During handling" not in formatted
# chain=False
exc = get_exc(raiser1)
exc.__context__ = get_exc(raiser2)
formatted = "".join(format_exception(*einfo(exc), chain=False))
assert "raiser1_string" in formatted
assert "in raiser1_3" in formatted
assert "raiser2_string" not in formatted
assert "in raiser2_2" not in formatted
assert "direct cause" not in formatted
assert "During handling" not in formatted
# limit
exc = get_exc(raiser1)
exc.__context__ = get_exc(raiser2)
# get_exc adds a frame that counts against the limit, so limit=2 means we
# get 1 deep into the raiser stack
formatted = "".join(format_exception(*einfo(exc), limit=2))
print(formatted)
assert "raiser1_string" in formatted
assert "in raiser1" in formatted
assert "in raiser1_2" not in formatted
assert "raiser2_string" in formatted
assert "in raiser2" in formatted
assert "in raiser2_2" not in formatted
exc = get_exc(raiser1)
exc.__context__ = get_exc(raiser2)
formatted = "".join(format_exception(*einfo(exc), limit=1))
print(formatted)
assert "raiser1_string" in formatted
assert "in raiser1" not in formatted
assert "raiser2_string" in formatted
assert "in raiser2" not in formatted
# handles loops
exc = get_exc(raiser1)
exc.__cause__ = exc
formatted = "".join(format_exception(*einfo(exc)))
assert "raiser1_string" in formatted
assert "in raiser1_3" in formatted
assert "raiser2_string" not in formatted
assert "in raiser2_2" not in formatted
# ensure duplicate exception is not included as cause
assert _cause_message not in formatted
# MultiError
formatted = "".join(format_exception(*einfo(make_tree())))
print(formatted)
assert_match_in_seq(
[
# Outer exception is MultiError
r"MultiError:",
# First embedded exception is the embedded MultiError
r"\nDetails of embedded exception 1",
# Which has a single stack frame from make_tree raising it
r"in make_tree",
# Then it has two embedded exceptions
r" Details of embedded exception 1",
r"in raiser1_2",
# for some reason ValueError has no quotes
r"ValueError: raiser1_string",
r" Details of embedded exception 2",
r"in raiser2_2",
# But KeyError does have quotes
r"KeyError: 'raiser2_string'",
# And finally the NameError, which is a sibling of the embedded
# MultiError
r"\nDetails of embedded exception 2:",
r"in raiser3",
r"NameError",
],
formatted,
)
# Prints duplicate exceptions in sub-exceptions
exc1 = get_exc(raiser1)
def raise1_raiser1():
try:
raise exc1
except:
raise ValueError("foo")
def raise2_raiser1():
try:
raise exc1
except:
raise KeyError("bar")
exc2 = get_exc(raise1_raiser1)
exc3 = get_exc(raise2_raiser1)
try:
raise MultiError([exc2, exc3])
except MultiError as e:
exc = e
formatted = "".join(format_exception(*einfo(exc)))
print(formatted)
assert_match_in_seq(
[
r"Traceback",
# Outer exception is MultiError
r"MultiError:",
# First embedded exception is the embedded ValueError with cause of raiser1
r"\nDetails of embedded exception 1",
# Print details of exc1
r" Traceback",
r"in get_exc",
r"in raiser1",
r"ValueError: raiser1_string",
# Print details of exc2
r"\n During handling of the above exception, another exception occurred:",
r" Traceback",
r"in get_exc",
r"in raise1_raiser1",
r" ValueError: foo",
# Second embedded exception is the embedded KeyError with cause of raiser1
r"\nDetails of embedded exception 2",
# Print details of exc1 again
r" Traceback",
r"in get_exc",
r"in raiser1",
r"ValueError: raiser1_string",
# Print details of exc3
r"\n During handling of the above exception, another exception occurred:",
r" Traceback",
r"in get_exc",
r"in raise2_raiser1",
r" KeyError: 'bar'",
],
formatted,
)
def test_logging(caplog):
exc1 = get_exc(raiser1)
exc2 = get_exc(raiser2)
m = MultiError([exc1, exc2])
message = "test test test"
try:
raise m
except MultiError as exc:
logging.getLogger().exception(message)
# Join lines together
formatted = "".join(format_exception(type(exc), exc, exc.__traceback__))
assert message in caplog.text
assert formatted in caplog.text
def run_script(name, use_ipython=False):
import trio
trio_path = Path(trio.__file__).parent.parent
script_path = Path(__file__).parent / "test_multierror_scripts" / name
env = dict(os.environ)
print("parent PYTHONPATH:", env.get("PYTHONPATH"))
if "PYTHONPATH" in env: # pragma: no cover
pp = env["PYTHONPATH"].split(os.pathsep)
else:
pp = []
pp.insert(0, str(trio_path))
pp.insert(0, str(script_path.parent))
env["PYTHONPATH"] = os.pathsep.join(pp)
print("subprocess PYTHONPATH:", env.get("PYTHONPATH"))
if use_ipython:
lines = [script_path.read_text(), "exit()"]
cmd = [
sys.executable,
"-u",
"-m",
"IPython",
# no startup files
"--quick",
"--TerminalIPythonApp.code_to_run=" + "\n".join(lines),
]
else:
cmd = [sys.executable, "-u", str(script_path)]
print("running:", cmd)
completed = subprocess.run(
cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)
print("process output:")
print(completed.stdout.decode("utf-8"))
return completed
def check_simple_excepthook(completed):
assert_match_in_seq(
[
"in <module>",
"MultiError",
"Details of embedded exception 1",
"in exc1_fn",
"ValueError",
"Details of embedded exception 2",
"in exc2_fn",
"KeyError",
],
completed.stdout.decode("utf-8"),
)
def test_simple_excepthook():
completed = run_script("simple_excepthook.py")
check_simple_excepthook(completed)
def test_custom_excepthook():
# Check that user-defined excepthooks aren't overridden
completed = run_script("custom_excepthook.py")
assert_match_in_seq(
[
# The warning
"RuntimeWarning",
"already have a custom",
# The message printed by the custom hook, proving we didn't
# override it
"custom running!",
# The MultiError
"MultiError:",
],
completed.stdout.decode("utf-8"),
)
# This warning is triggered by ipython 7.5.0 on python 3.8
import warnings
warnings.filterwarnings(
"ignore",
message='.*"@coroutine" decorator is deprecated',
category=DeprecationWarning,
module="IPython.*",
)
try:
import IPython
except ImportError: # pragma: no cover
have_ipython = False
else:
have_ipython = True
need_ipython = pytest.mark.skipif(not have_ipython, reason="need IPython")
@slow
@need_ipython
def test_ipython_exc_handler():
completed = run_script("simple_excepthook.py", use_ipython=True)
check_simple_excepthook(completed)
@slow
@need_ipython
def test_ipython_imported_but_unused():
completed = run_script("simple_excepthook_IPython.py")
check_simple_excepthook(completed)
@slow
def test_partial_imported_but_unused():
# Check that a functools.partial as sys.excepthook doesn't cause an exception when
# importing trio. This was a problem due to the lack of a .__name__ attribute and
# happens when inside a pytest-qt test case for example.
completed = run_script("simple_excepthook_partial.py")
completed.check_returncode()
@slow
@need_ipython
def test_ipython_custom_exc_handler():
# Check we get a nice warning (but only one!) if the user is using IPython
# and already has some other set_custom_exc handler installed.
completed = run_script("ipython_custom_exc.py", use_ipython=True)
assert_match_in_seq(
[
# The warning
"RuntimeWarning",
"IPython detected",
"skip installing Trio",
# The MultiError
"MultiError",
"ValueError",
"KeyError",
],
completed.stdout.decode("utf-8"),
)
# Make sure our other warning doesn't show up
assert "custom sys.excepthook" not in completed.stdout.decode("utf-8")
@slow
@pytest.mark.skipif(
not Path("/usr/lib/python3/dist-packages/apport_python_hook.py").exists(),
reason="need Ubuntu with python3-apport installed",
)
def test_apport_excepthook_monkeypatch_interaction():
completed = run_script("apport_excepthook.py")
stdout = completed.stdout.decode("utf-8")
# No warning
assert "custom sys.excepthook" not in stdout
# Proper traceback
assert_match_in_seq(
["Details of embedded", "KeyError", "Details of embedded", "ValueError"],
stdout,
)

View File

@@ -0,0 +1,2 @@
# This isn't really a package, everything in here is a standalone script. This
# __init__.py is just to fool setup.py into actually installing the things.

View File

@@ -0,0 +1,7 @@
# https://coverage.readthedocs.io/en/latest/subprocess.html
try:
import coverage
except ImportError: # pragma: no cover
pass
else:
coverage.process_startup()

View File

@@ -0,0 +1,13 @@
# The apport_python_hook package is only installed as part of Ubuntu's system
# python, and not available in venvs. So before we can import it we have to
# make sure it's on sys.path.
import sys
sys.path.append("/usr/lib/python3/dist-packages")
import apport_python_hook
apport_python_hook.install()
import trio
raise trio.MultiError([KeyError("key_error"), ValueError("value_error")])

View File

@@ -0,0 +1,18 @@
import _common
import sys
def custom_excepthook(*args):
print("custom running!")
return sys.__excepthook__(*args)
sys.excepthook = custom_excepthook
# Should warn that we'll get kinda-broken tracebacks
import trio
# The custom excepthook should run, because Trio was polite and didn't
# override it
raise trio.MultiError([ValueError(), KeyError()])

View File

@@ -0,0 +1,36 @@
import _common
# Override the regular excepthook too -- it doesn't change anything either way
# because ipython doesn't use it, but we want to make sure Trio doesn't warn
# about it.
import sys
def custom_excepthook(*args):
print("custom running!")
return sys.__excepthook__(*args)
sys.excepthook = custom_excepthook
import IPython
ip = IPython.get_ipython()
# Set this to some random nonsense
class SomeError(Exception):
pass
def custom_exc_hook(etype, value, tb, tb_offset=None):
ip.showtraceback()
ip.set_custom_exc((SomeError,), custom_exc_hook)
import trio
# The custom excepthook should run, because Trio was polite and didn't
# override it
raise trio.MultiError([ValueError(), KeyError()])

View File

@@ -0,0 +1,21 @@
import _common
import trio
def exc1_fn():
try:
raise ValueError
except Exception as exc:
return exc
def exc2_fn():
try:
raise KeyError
except Exception as exc:
return exc
# This should be printed nicely, because Trio overrode sys.excepthook
raise trio.MultiError([exc1_fn(), exc2_fn()])

View File

@@ -0,0 +1,7 @@
import _common
# To tickle the "is IPython loaded?" logic, make sure that Trio tolerates
# IPython loaded but not actually in use
import IPython
import simple_excepthook

View File

@@ -0,0 +1,13 @@
import functools
import sys
import _common
# just making sure importing Trio doesn't fail if sys.excepthook doesn't have a
# .__name__ attribute
sys.excepthook = functools.partial(sys.excepthook)
assert not hasattr(sys.excepthook, "__name__")
import trio

View File

@@ -0,0 +1,197 @@
import pytest
from ... import _core
from ...testing import wait_all_tasks_blocked
from .._parking_lot import ParkingLot
from .tutil import check_sequence_matches
async def test_parking_lot_basic():
record = []
async def waiter(i, lot):
record.append("sleep {}".format(i))
await lot.park()
record.append("wake {}".format(i))
async with _core.open_nursery() as nursery:
lot = ParkingLot()
assert not lot
assert len(lot) == 0
assert lot.statistics().tasks_waiting == 0
for i in range(3):
nursery.start_soon(waiter, i, lot)
await wait_all_tasks_blocked()
assert len(record) == 3
assert bool(lot)
assert len(lot) == 3
assert lot.statistics().tasks_waiting == 3
lot.unpark_all()
assert lot.statistics().tasks_waiting == 0
await wait_all_tasks_blocked()
assert len(record) == 6
check_sequence_matches(
record, [{"sleep 0", "sleep 1", "sleep 2"}, {"wake 0", "wake 1", "wake 2"}]
)
async with _core.open_nursery() as nursery:
record = []
for i in range(3):
nursery.start_soon(waiter, i, lot)
await wait_all_tasks_blocked()
assert len(record) == 3
for i in range(3):
lot.unpark()
await wait_all_tasks_blocked()
# 1-by-1 wakeups are strict FIFO
assert record == [
"sleep 0",
"sleep 1",
"sleep 2",
"wake 0",
"wake 1",
"wake 2",
]
# It's legal (but a no-op) to try and unpark while there's nothing parked
lot.unpark()
lot.unpark(count=1)
lot.unpark(count=100)
# Check unpark with count
async with _core.open_nursery() as nursery:
record = []
for i in range(3):
nursery.start_soon(waiter, i, lot)
await wait_all_tasks_blocked()
lot.unpark(count=2)
await wait_all_tasks_blocked()
check_sequence_matches(
record, ["sleep 0", "sleep 1", "sleep 2", {"wake 0", "wake 1"}]
)
lot.unpark_all()
async def cancellable_waiter(name, lot, scopes, record):
with _core.CancelScope() as scope:
scopes[name] = scope
record.append("sleep {}".format(name))
try:
await lot.park()
except _core.Cancelled:
record.append("cancelled {}".format(name))
else:
record.append("wake {}".format(name))
async def test_parking_lot_cancel():
record = []
scopes = {}
async with _core.open_nursery() as nursery:
lot = ParkingLot()
nursery.start_soon(cancellable_waiter, 1, lot, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 2, lot, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 3, lot, scopes, record)
await wait_all_tasks_blocked()
assert len(record) == 3
scopes[2].cancel()
await wait_all_tasks_blocked()
assert len(record) == 4
lot.unpark_all()
await wait_all_tasks_blocked()
assert len(record) == 6
check_sequence_matches(
record, ["sleep 1", "sleep 2", "sleep 3", "cancelled 2", {"wake 1", "wake 3"}]
)
async def test_parking_lot_repark():
record = []
scopes = {}
lot1 = ParkingLot()
lot2 = ParkingLot()
with pytest.raises(TypeError):
lot1.repark([])
async with _core.open_nursery() as nursery:
nursery.start_soon(cancellable_waiter, 1, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 2, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 3, lot1, scopes, record)
await wait_all_tasks_blocked()
assert len(record) == 3
assert len(lot1) == 3
lot1.repark(lot2)
assert len(lot1) == 2
assert len(lot2) == 1
lot2.unpark_all()
await wait_all_tasks_blocked()
assert len(record) == 4
assert record == ["sleep 1", "sleep 2", "sleep 3", "wake 1"]
lot1.repark_all(lot2)
assert len(lot1) == 0
assert len(lot2) == 2
scopes[2].cancel()
await wait_all_tasks_blocked()
assert len(lot2) == 1
assert record == [
"sleep 1",
"sleep 2",
"sleep 3",
"wake 1",
"cancelled 2",
]
lot2.unpark_all()
await wait_all_tasks_blocked()
assert record == [
"sleep 1",
"sleep 2",
"sleep 3",
"wake 1",
"cancelled 2",
"wake 3",
]
async def test_parking_lot_repark_with_count():
record = []
scopes = {}
lot1 = ParkingLot()
lot2 = ParkingLot()
async with _core.open_nursery() as nursery:
nursery.start_soon(cancellable_waiter, 1, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 2, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 3, lot1, scopes, record)
await wait_all_tasks_blocked()
assert len(record) == 3
assert len(lot1) == 3
assert len(lot2) == 0
lot1.repark(lot2, count=2)
assert len(lot1) == 1
assert len(lot2) == 2
while lot2:
lot2.unpark()
await wait_all_tasks_blocked()
assert record == [
"sleep 1",
"sleep 2",
"sleep 3",
"wake 1",
"wake 2",
]
lot1.unpark_all()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,164 @@
import pytest
import threading
from queue import Queue
import time
import sys
from contextlib import contextmanager
from .tutil import slow, gc_collect_harder, disable_threading_excepthook
from .. import _thread_cache
from .._thread_cache import start_thread_soon, ThreadCache
def test_thread_cache_basics():
q = Queue()
def fn():
raise RuntimeError("hi")
def deliver(outcome):
q.put(outcome)
start_thread_soon(fn, deliver)
outcome = q.get()
with pytest.raises(RuntimeError, match="hi"):
outcome.unwrap()
def test_thread_cache_deref():
res = [False]
class del_me:
def __call__(self):
return 42
def __del__(self):
res[0] = True
q = Queue()
def deliver(outcome):
q.put(outcome)
start_thread_soon(del_me(), deliver)
outcome = q.get()
assert outcome.unwrap() == 42
gc_collect_harder()
assert res[0]
@slow
def test_spawning_new_thread_from_deliver_reuses_starting_thread():
# We know that no-one else is using the thread cache, so if we keep
# submitting new jobs the instant the previous one is finished, we should
# keep getting the same thread over and over. This tests both that the
# thread cache is LIFO, and that threads can be assigned new work *before*
# deliver exits.
# Make sure there are a few threads running, so if we weren't LIFO then we
# could grab the wrong one.
q = Queue()
COUNT = 5
for _ in range(COUNT):
start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result))
for _ in range(COUNT):
q.get().unwrap()
seen_threads = set()
done = threading.Event()
def deliver(n, _):
print(n)
seen_threads.add(threading.current_thread())
if n == 0:
done.set()
else:
start_thread_soon(lambda: None, lambda _: deliver(n - 1, _))
start_thread_soon(lambda: None, lambda _: deliver(5, _))
done.wait()
assert len(seen_threads) == 1
@slow
def test_idle_threads_exit(monkeypatch):
# Temporarily set the idle timeout to something tiny, to speed up the
# test. (But non-zero, so that the worker loop will at least yield the
# CPU.)
monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001)
q = Queue()
start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread()))
seen_thread = q.get()
# Since the idle timeout is 0, after sleeping for 1 second, the thread
# should have exited
time.sleep(1)
assert not seen_thread.is_alive()
@contextmanager
def _join_started_threads():
before = frozenset(threading.enumerate())
try:
yield
finally:
for thread in threading.enumerate():
if thread not in before:
thread.join()
def test_race_between_idle_exit_and_job_assignment(monkeypatch):
# This is a lock where the first few times you try to acquire it with a
# timeout, it waits until the lock is available and then pretends to time
# out. Using this in our thread cache implementation causes the following
# sequence:
#
# 1. start_thread_soon grabs the worker thread, assigns it a job, and
# releases its lock.
# 2. The worker thread wakes up (because the lock has been released), but
# the JankyLock lies to it and tells it that the lock timed out. So the
# worker thread tries to exit.
# 3. The worker thread checks for the race between exiting and being
# assigned a job, and discovers that it *is* in the process of being
# assigned a job, so it loops around and tries to acquire the lock
# again.
# 4. Eventually the JankyLock admits that the lock is available, and
# everything proceeds as normal.
class JankyLock:
def __init__(self):
self._lock = threading.Lock()
self._counter = 3
def acquire(self, timeout=None):
self._lock.acquire()
if timeout is None:
return True
else:
if self._counter > 0:
self._counter -= 1
self._lock.release()
return False
return True
def release(self):
self._lock.release()
monkeypatch.setattr(_thread_cache, "Lock", JankyLock)
with disable_threading_excepthook(), _join_started_threads():
tc = ThreadCache()
done = threading.Event()
tc.start_thread_soon(lambda: None, lambda _: done.set())
done.wait()
# Let's kill the thread we started, so it doesn't hang around until the
# test suite finishes. Doesn't really do any harm, but it can be confusing
# to see it in debug output. This is hacky, and leaves our ThreadCache
# object in an inconsistent state... but it doesn't matter, because we're
# not going to use it again anyway.
tc.start_thread_soon(lambda: None, lambda _: sys.exit())

View File

@@ -0,0 +1,13 @@
import pytest
from .tutil import check_sequence_matches
def test_check_sequence_matches():
check_sequence_matches([1, 2, 3], [1, 2, 3])
with pytest.raises(AssertionError):
check_sequence_matches([1, 3, 2], [1, 2, 3])
check_sequence_matches([1, 2, 3, 4], [1, {2, 3}, 4])
check_sequence_matches([1, 3, 2, 4], [1, {2, 3}, 4])
with pytest.raises(AssertionError):
check_sequence_matches([1, 2, 4, 3], [1, {2, 3}, 4])

View File

@@ -0,0 +1,152 @@
import itertools
import pytest
from ... import _core
from ...testing import assert_checkpoints, wait_all_tasks_blocked
pytestmark = pytest.mark.filterwarnings(
"ignore:.*UnboundedQueue:trio.TrioDeprecationWarning"
)
async def test_UnboundedQueue_basic():
q = _core.UnboundedQueue()
q.put_nowait("hi")
assert await q.get_batch() == ["hi"]
with pytest.raises(_core.WouldBlock):
q.get_batch_nowait()
q.put_nowait(1)
q.put_nowait(2)
q.put_nowait(3)
assert q.get_batch_nowait() == [1, 2, 3]
assert q.empty()
assert q.qsize() == 0
q.put_nowait(None)
assert not q.empty()
assert q.qsize() == 1
stats = q.statistics()
assert stats.qsize == 1
assert stats.tasks_waiting == 0
# smoke test
repr(q)
async def test_UnboundedQueue_blocking():
record = []
q = _core.UnboundedQueue()
async def get_batch_consumer():
while True:
batch = await q.get_batch()
assert batch
record.append(batch)
async def aiter_consumer():
async for batch in q:
assert batch
record.append(batch)
for consumer in (get_batch_consumer, aiter_consumer):
record.clear()
async with _core.open_nursery() as nursery:
nursery.start_soon(consumer)
await _core.wait_all_tasks_blocked()
stats = q.statistics()
assert stats.qsize == 0
assert stats.tasks_waiting == 1
q.put_nowait(10)
q.put_nowait(11)
await _core.wait_all_tasks_blocked()
q.put_nowait(12)
await _core.wait_all_tasks_blocked()
assert record == [[10, 11], [12]]
nursery.cancel_scope.cancel()
async def test_UnboundedQueue_fairness():
q = _core.UnboundedQueue()
# If there's no-one else around, we can put stuff in and take it out
# again, no problem
q.put_nowait(1)
q.put_nowait(2)
assert q.get_batch_nowait() == [1, 2]
result = None
async def get_batch(q):
nonlocal result
result = await q.get_batch()
# But if someone else is waiting to read, then they get dibs
async with _core.open_nursery() as nursery:
nursery.start_soon(get_batch, q)
await _core.wait_all_tasks_blocked()
q.put_nowait(3)
q.put_nowait(4)
with pytest.raises(_core.WouldBlock):
q.get_batch_nowait()
assert result == [3, 4]
# If two tasks are trying to read, they alternate
record = []
async def reader(name):
while True:
record.append((name, await q.get_batch()))
async with _core.open_nursery() as nursery:
nursery.start_soon(reader, "a")
await _core.wait_all_tasks_blocked()
nursery.start_soon(reader, "b")
await _core.wait_all_tasks_blocked()
for i in range(20):
q.put_nowait(i)
await _core.wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert record == list(zip(itertools.cycle("ab"), [[i] for i in range(20)]))
async def test_UnboundedQueue_trivial_yields():
q = _core.UnboundedQueue()
q.put_nowait(None)
with assert_checkpoints():
await q.get_batch()
q.put_nowait(None)
with assert_checkpoints():
async for _ in q: # noqa # pragma: no branch
break
async def test_UnboundedQueue_no_spurious_wakeups():
# If we have two tasks waiting, and put two items into the queue... then
# only one task wakes up
record = []
async def getter(q, i):
got = await q.get_batch()
record.append((i, got))
async with _core.open_nursery() as nursery:
q = _core.UnboundedQueue()
nursery.start_soon(getter, q, 1)
await wait_all_tasks_blocked()
nursery.start_soon(getter, q, 2)
await wait_all_tasks_blocked()
for i in range(10):
q.put_nowait(i)
await wait_all_tasks_blocked()
assert record == [(1, list(range(10)))]
nursery.cancel_scope.cancel()

View File

@@ -0,0 +1 @@
import pytest

View File

@@ -0,0 +1,218 @@
import os
import tempfile
from contextlib import contextmanager
import pytest
on_windows = os.name == "nt"
# Mark all the tests in this file as being windows-only
pytestmark = pytest.mark.skipif(not on_windows, reason="windows only")
from .tutil import slow, gc_collect_harder, restore_unraisablehook
from ... import _core, sleep, move_on_after
from ...testing import wait_all_tasks_blocked
if on_windows:
from .._windows_cffi import (
ffi,
kernel32,
INVALID_HANDLE_VALUE,
raise_winerror,
FileFlags,
)
# The undocumented API that this is testing should be changed to stop using
# UnboundedQueue (or just removed until we have time to redo it), but until
# then we filter out the warning.
@pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning")
async def test_completion_key_listen():
async def post(key):
iocp = ffi.cast("HANDLE", _core.current_iocp())
for i in range(10):
print("post", i)
if i % 3 == 0:
await _core.checkpoint()
success = kernel32.PostQueuedCompletionStatus(iocp, i, key, ffi.NULL)
assert success
with _core.monitor_completion_key() as (key, queue):
async with _core.open_nursery() as nursery:
nursery.start_soon(post, key)
i = 0
print("loop")
async for batch in queue: # pragma: no branch
print("got some", batch)
for info in batch:
assert info.lpOverlapped == 0
assert info.dwNumberOfBytesTransferred == i
i += 1
if i == 10:
break
print("end loop")
async def test_readinto_overlapped():
data = b"1" * 1024 + b"2" * 1024 + b"3" * 1024 + b"4" * 1024
buffer = bytearray(len(data))
with tempfile.TemporaryDirectory() as tdir:
tfile = os.path.join(tdir, "numbers.txt")
with open(tfile, "wb") as fp:
fp.write(data)
fp.flush()
rawname = tfile.encode("utf-16le") + b"\0\0"
rawname_buf = ffi.from_buffer(rawname)
handle = kernel32.CreateFileW(
ffi.cast("LPCWSTR", rawname_buf),
FileFlags.GENERIC_READ,
FileFlags.FILE_SHARE_READ,
ffi.NULL, # no security attributes
FileFlags.OPEN_EXISTING,
FileFlags.FILE_FLAG_OVERLAPPED,
ffi.NULL, # no template file
)
if handle == INVALID_HANDLE_VALUE: # pragma: no cover
raise_winerror()
try:
with memoryview(buffer) as buffer_view:
async def read_region(start, end):
await _core.readinto_overlapped(
handle, buffer_view[start:end], start
)
_core.register_with_iocp(handle)
async with _core.open_nursery() as nursery:
for start in range(0, 4096, 512):
nursery.start_soon(read_region, start, start + 512)
assert buffer == data
with pytest.raises(BufferError):
await _core.readinto_overlapped(handle, b"immutable")
finally:
kernel32.CloseHandle(handle)
@contextmanager
def pipe_with_overlapped_read():
from asyncio.windows_utils import pipe
import msvcrt
read_handle, write_handle = pipe(overlapped=(True, False))
try:
write_fd = msvcrt.open_osfhandle(write_handle, 0)
yield os.fdopen(write_fd, "wb", closefd=False), read_handle
finally:
kernel32.CloseHandle(ffi.cast("HANDLE", read_handle))
kernel32.CloseHandle(ffi.cast("HANDLE", write_handle))
@restore_unraisablehook()
def test_forgot_to_register_with_iocp():
with pipe_with_overlapped_read() as (write_fp, read_handle):
with write_fp:
write_fp.write(b"test\n")
left_run_yet = False
async def main():
target = bytearray(1)
try:
async with _core.open_nursery() as nursery:
nursery.start_soon(
_core.readinto_overlapped, read_handle, target, name="xyz"
)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
finally:
# Run loop is exited without unwinding running tasks, so
# we don't get here until the main() coroutine is GC'ed
assert left_run_yet
with pytest.raises(_core.TrioInternalError) as exc_info:
_core.run(main)
left_run_yet = True
assert "Failed to cancel overlapped I/O in xyz " in str(exc_info.value)
assert "forget to call register_with_iocp()?" in str(exc_info.value)
# Make sure the Nursery.__del__ assertion about dangling children
# gets put with the correct test
del exc_info
gc_collect_harder()
@slow
async def test_too_late_to_cancel():
import time
with pipe_with_overlapped_read() as (write_fp, read_handle):
_core.register_with_iocp(read_handle)
target = bytearray(6)
async with _core.open_nursery() as nursery:
# Start an async read in the background
nursery.start_soon(_core.readinto_overlapped, read_handle, target)
await wait_all_tasks_blocked()
# Synchronous write to the other end of the pipe
with write_fp:
write_fp.write(b"test1\ntest2\n")
# Note: not trio.sleep! We're making sure the OS level
# ReadFile completes, before Trio has a chance to execute
# another checkpoint and notice it completed.
time.sleep(1)
nursery.cancel_scope.cancel()
assert target[:6] == b"test1\n"
# Do another I/O to make sure we've actually processed the
# fallback completion that was posted when CancelIoEx failed.
assert await _core.readinto_overlapped(read_handle, target) == 6
assert target[:6] == b"test2\n"
def test_lsp_that_hooks_select_gives_good_error(monkeypatch):
from .._windows_cffi import WSAIoctls, _handle
from .. import _io_windows
def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE):
if hasattr(sock, "fileno"): # pragma: no branch
sock = sock.fileno()
if which == WSAIoctls.SIO_BSP_HANDLE_SELECT:
return _handle(sock + 1)
else:
return _handle(sock)
monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying)
with pytest.raises(
RuntimeError, match="SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ"
):
_core.run(sleep, 0)
def test_lsp_that_completely_hides_base_socket_gives_good_error(monkeypatch):
# This tests behavior with an LSP that fails SIO_BASE_HANDLE and returns
# self for SIO_BSP_HANDLE_SELECT (like Komodia), but also returns
# self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to
# make sure we get an error rather than an infinite loop.
from .._windows_cffi import WSAIoctls, _handle
from .. import _io_windows
def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE):
if hasattr(sock, "fileno"): # pragma: no branch
sock = sock.fileno()
if which == WSAIoctls.SIO_BASE_HANDLE:
raise OSError("nope")
else:
return _handle(sock)
monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying)
with pytest.raises(
RuntimeError,
match="SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't return a diff",
):
_core.run(sleep, 0)

View File

@@ -0,0 +1,146 @@
# Utilities for testing
import asyncio
import socket as stdlib_socket
import threading
import os
import sys
from typing import TYPE_CHECKING
import pytest
import warnings
from contextlib import contextmanager, closing
import gc
# See trio/tests/conftest.py for the other half of this
from trio.tests.conftest import RUN_SLOW
slow = pytest.mark.skipif(not RUN_SLOW, reason="use --run-slow to run slow tests")
# PyPy 7.2 was released with a bug that just never called the async
# generator 'firstiter' hook at all. This impacts tests of end-of-run
# finalization (nothing gets added to runner.asyncgens) and tests of
# "foreign" async generator behavior (since the firstiter hook is what
# marks the asyncgen as foreign), but most tests of GC-mediated
# finalization still work.
buggy_pypy_asyncgens = (
not TYPE_CHECKING
and sys.implementation.name == "pypy"
and sys.pypy_version_info < (7, 3)
)
try:
s = stdlib_socket.socket(stdlib_socket.AF_INET6, stdlib_socket.SOCK_STREAM, 0)
except OSError: # pragma: no cover
# Some systems don't even support creating an IPv6 socket, let alone
# binding it. (ex: Linux with 'ipv6.disable=1' in the kernel command line)
# We don't have any of those in our CI, and there's nothing that gets
# tested _only_ if can_create_ipv6 = False, so we'll just no-cover this.
can_create_ipv6 = False
can_bind_ipv6 = False
else:
can_create_ipv6 = True
with s:
try:
s.bind(("::1", 0))
except OSError:
can_bind_ipv6 = False
else:
can_bind_ipv6 = True
creates_ipv6 = pytest.mark.skipif(not can_create_ipv6, reason="need IPv6")
binds_ipv6 = pytest.mark.skipif(not can_bind_ipv6, reason="need IPv6")
def gc_collect_harder():
# In the test suite we sometimes want to call gc.collect() to make sure
# that any objects with noisy __del__ methods (e.g. unawaited coroutines)
# get collected before we continue, so their noise doesn't leak into
# unrelated tests.
#
# On PyPy, coroutine objects (for example) can survive at least 1 round of
# garbage collection, because executing their __del__ method to print the
# warning can cause them to be resurrected. So we call collect a few times
# to make sure.
for _ in range(4):
gc.collect()
# Some of our tests need to leak coroutines, and thus trigger the
# "RuntimeWarning: coroutine '...' was never awaited" message. This context
# manager should be used anywhere this happens to hide those messages, because
# when expected they're clutter.
@contextmanager
def ignore_coroutine_never_awaited_warnings():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="coroutine '.*' was never awaited")
try:
yield
finally:
# Make sure to trigger any coroutine __del__ methods now, before
# we leave the context manager.
gc_collect_harder()
def _noop(*args, **kwargs):
pass
if sys.version_info >= (3, 8):
@contextmanager
def restore_unraisablehook():
sys.unraisablehook, prev = sys.__unraisablehook__, sys.unraisablehook
try:
yield
finally:
sys.unraisablehook = prev
@contextmanager
def disable_threading_excepthook():
if sys.version_info >= (3, 10):
threading.excepthook, prev = threading.__excepthook__, threading.excepthook
else:
threading.excepthook, prev = _noop, threading.excepthook
try:
yield
finally:
threading.excepthook = prev
else:
@contextmanager
def restore_unraisablehook(): # pragma: no cover
yield
@contextmanager
def disable_threading_excepthook(): # pragma: no cover
yield
# template is like:
# [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 2.2, 3] or [1, 2.2, 2.1, 3]
def check_sequence_matches(seq, template):
i = 0
for pattern in template:
if not isinstance(pattern, set):
pattern = {pattern}
got = set(seq[i : i + len(pattern)])
assert got == pattern
i += len(got)
# https://bugs.freebsd.org/bugzilla/show_bug.cgi?id=246350
skip_if_fbsd_pipes_broken = pytest.mark.skipif(
sys.platform != "win32" # prevent mypy from complaining about missing uname
and hasattr(os, "uname")
and os.uname().sysname == "FreeBSD"
and os.uname().release[:4] < "12.2",
reason="hangs on FreeBSD 12.1 and earlier, due to FreeBSD bug #246350",
)
def create_asyncio_future_in_new_loop():
with closing(asyncio.new_event_loop()) as loop:
return loop.create_future()

View File

@@ -0,0 +1,133 @@
import sys
from functools import wraps
from types import ModuleType
import warnings
import attr
# We want our warnings to be visible by default (at least for now), but we
# also want it to be possible to override that using the -W switch. AFAICT
# this means we cannot inherit from DeprecationWarning, because the only way
# to make it visible by default then would be to add our own filter at import
# time, but that would override -W switches...
class TrioDeprecationWarning(FutureWarning):
"""Warning emitted if you use deprecated Trio functionality.
As a young project, Trio is currently quite aggressive about deprecating
and/or removing functionality that we realize was a bad idea. If you use
Trio, you should subscribe to `issue #1
<https://github.com/python-trio/trio/issues/1>`__ to get information about
upcoming deprecations and other backwards compatibility breaking changes.
Despite the name, this class currently inherits from
:class:`FutureWarning`, not :class:`DeprecationWarning`, because while
we're in young-and-aggressive mode we want these warnings to be visible by
default. You can hide them by installing a filter or with the ``-W``
switch: see the :mod:`warnings` documentation for details.
"""
def _url_for_issue(issue):
return "https://github.com/python-trio/trio/issues/{}".format(issue)
def _stringify(thing):
if hasattr(thing, "__module__") and hasattr(thing, "__qualname__"):
return "{}.{}".format(thing.__module__, thing.__qualname__)
return str(thing)
def warn_deprecated(thing, version, *, issue, instead, stacklevel=2):
stacklevel += 1
msg = "{} is deprecated since Trio {}".format(_stringify(thing), version)
if instead is None:
msg += " with no replacement"
else:
msg += "; use {} instead".format(_stringify(instead))
if issue is not None:
msg += " ({})".format(_url_for_issue(issue))
warnings.warn(TrioDeprecationWarning(msg), stacklevel=stacklevel)
# @deprecated("0.2.0", issue=..., instead=...)
# def ...
def deprecated(version, *, thing=None, issue, instead):
def do_wrap(fn):
nonlocal thing
@wraps(fn)
def wrapper(*args, **kwargs):
warn_deprecated(thing, version, instead=instead, issue=issue)
return fn(*args, **kwargs)
# If our __module__ or __qualname__ get modified, we want to pick up
# on that, so we read them off the wrapper object instead of the (now
# hidden) fn object
if thing is None:
thing = wrapper
if wrapper.__doc__ is not None:
doc = wrapper.__doc__
doc = doc.rstrip()
doc += "\n\n"
doc += ".. deprecated:: {}\n".format(version)
if instead is not None:
doc += " Use {} instead.\n".format(_stringify(instead))
if issue is not None:
doc += " For details, see `issue #{} <{}>`__.\n".format(
issue, _url_for_issue(issue)
)
doc += "\n"
wrapper.__doc__ = doc
return wrapper
return do_wrap
def deprecated_alias(old_qualname, new_fn, version, *, issue):
@deprecated(version, issue=issue, instead=new_fn)
@wraps(new_fn, assigned=("__module__", "__annotations__"))
def wrapper(*args, **kwargs):
"Deprecated alias."
return new_fn(*args, **kwargs)
wrapper.__qualname__ = old_qualname
wrapper.__name__ = old_qualname.rpartition(".")[-1]
return wrapper
@attr.s(frozen=True)
class DeprecatedAttribute:
_not_set = object()
value = attr.ib()
version = attr.ib()
issue = attr.ib()
instead = attr.ib(default=_not_set)
class _ModuleWithDeprecations(ModuleType):
def __getattr__(self, name):
if name in self.__deprecated_attributes__:
info = self.__deprecated_attributes__[name]
instead = info.instead
if instead is DeprecatedAttribute._not_set:
instead = info.value
thing = "{}.{}".format(self.__name__, name)
warn_deprecated(thing, info.version, issue=info.issue, instead=instead)
return info.value
msg = "module '{}' has no attribute '{}'"
raise AttributeError(msg.format(self.__name__, name))
def enable_attribute_deprecations(module_name):
module = sys.modules[module_name]
module.__class__ = _ModuleWithDeprecations
# Make sure that this is always defined so that
# _ModuleWithDeprecations.__getattr__ can access it without jumping
# through hoops or risking infinite recursion.
module.__deprecated_attributes__ = {}

View File

@@ -0,0 +1,191 @@
from functools import partial
import io
from .abc import AsyncResource
from ._util import async_wraps
import trio
# This list is also in the docs, make sure to keep them in sync
_FILE_SYNC_ATTRS = {
"closed",
"encoding",
"errors",
"fileno",
"isatty",
"newlines",
"readable",
"seekable",
"writable",
# not defined in *IOBase:
"buffer",
"raw",
"line_buffering",
"closefd",
"name",
"mode",
"getvalue",
"getbuffer",
}
# This list is also in the docs, make sure to keep them in sync
_FILE_ASYNC_METHODS = {
"flush",
"read",
"read1",
"readall",
"readinto",
"readline",
"readlines",
"seek",
"tell",
"truncate",
"write",
"writelines",
# not defined in *IOBase:
"readinto1",
"peek",
}
class AsyncIOWrapper(AsyncResource):
"""A generic :class:`~io.IOBase` wrapper that implements the :term:`asynchronous
file object` interface. Wrapped methods that could block are executed in
:meth:`trio.to_thread.run_sync`.
All properties and methods defined in in :mod:`~io` are exposed by this
wrapper, if they exist in the wrapped file object.
"""
def __init__(self, file):
self._wrapped = file
@property
def wrapped(self):
"""object: A reference to the wrapped file object"""
return self._wrapped
def __getattr__(self, name):
if name in _FILE_SYNC_ATTRS:
return getattr(self._wrapped, name)
if name in _FILE_ASYNC_METHODS:
meth = getattr(self._wrapped, name)
@async_wraps(self.__class__, self._wrapped.__class__, name)
async def wrapper(*args, **kwargs):
func = partial(meth, *args, **kwargs)
return await trio.to_thread.run_sync(func)
# cache the generated method
setattr(self, name, wrapper)
return wrapper
raise AttributeError(name)
def __dir__(self):
attrs = set(super().__dir__())
attrs.update(a for a in _FILE_SYNC_ATTRS if hasattr(self.wrapped, a))
attrs.update(a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a))
return attrs
def __aiter__(self):
return self
async def __anext__(self):
line = await self.readline()
if line:
return line
else:
raise StopAsyncIteration
async def detach(self):
"""Like :meth:`io.BufferedIOBase.detach`, but async.
This also re-wraps the result in a new :term:`asynchronous file object`
wrapper.
"""
raw = await trio.to_thread.run_sync(self._wrapped.detach)
return wrap_file(raw)
async def aclose(self):
"""Like :meth:`io.IOBase.close`, but async.
This is also shielded from cancellation; if a cancellation scope is
cancelled, the wrapped file object will still be safely closed.
"""
# ensure the underling file is closed during cancellation
with trio.CancelScope(shield=True):
await trio.to_thread.run_sync(self._wrapped.close)
await trio.lowlevel.checkpoint_if_cancelled()
async def open_file(
file,
mode="r",
buffering=-1,
encoding=None,
errors=None,
newline=None,
closefd=True,
opener=None,
):
"""Asynchronous version of :func:`io.open`.
Returns:
An :term:`asynchronous file object`
Example::
async with await trio.open_file(filename) as f:
async for line in f:
pass
assert f.closed
See also:
:func:`trio.Path.open`
"""
_file = wrap_file(
await trio.to_thread.run_sync(
io.open, file, mode, buffering, encoding, errors, newline, closefd, opener
)
)
return _file
def wrap_file(file):
"""This wraps any file object in a wrapper that provides an asynchronous
file object interface.
Args:
file: a :term:`file object`
Returns:
An :term:`asynchronous file object` that wraps ``file``
Example::
async_file = trio.wrap_file(StringIO('asdf'))
assert await async_file.read() == 'asdf'
"""
def has(attr):
return hasattr(file, attr) and callable(getattr(file, attr))
if not (has("close") and (has("read") or has("write"))):
raise TypeError(
"{} does not implement required duck-file methods: "
"close and (read or write)".format(file)
)
return AsyncIOWrapper(file)

View File

@@ -0,0 +1,107 @@
import attr
import trio
from .abc import HalfCloseableStream
from trio._util import Final
async def aclose_forcefully(resource):
"""Close an async resource or async generator immediately, without
blocking to do any graceful cleanup.
:class:`~trio.abc.AsyncResource` objects guarantee that if their
:meth:`~trio.abc.AsyncResource.aclose` method is cancelled, then they will
still close the resource (albeit in a potentially ungraceful
fashion). :func:`aclose_forcefully` is a convenience function that
exploits this behavior to let you force a resource to be closed without
blocking: it works by calling ``await resource.aclose()`` and then
cancelling it immediately.
Most users won't need this, but it may be useful on cleanup paths where
you can't afford to block, or if you want to close a resource and don't
care about handling it gracefully. For example, if
:class:`~trio.SSLStream` encounters an error and cannot perform its
own graceful close, then there's no point in waiting to gracefully shut
down the underlying transport either, so it calls ``await
aclose_forcefully(self.transport_stream)``.
Note that this function is async, and that it acts as a checkpoint, but
unlike most async functions it cannot block indefinitely (at least,
assuming the underlying resource object is correctly implemented).
"""
with trio.CancelScope() as cs:
cs.cancel()
await resource.aclose()
@attr.s(eq=False, hash=False)
class StapledStream(HalfCloseableStream, metaclass=Final):
"""This class `staples <https://en.wikipedia.org/wiki/Staple_(fastener)>`__
together two unidirectional streams to make single bidirectional stream.
Args:
send_stream (~trio.abc.SendStream): The stream to use for sending.
receive_stream (~trio.abc.ReceiveStream): The stream to use for
receiving.
Example:
A silly way to make a stream that echoes back whatever you write to
it::
left, right = trio.testing.memory_stream_pair()
echo_stream = StapledStream(SocketStream(left), SocketStream(right))
await echo_stream.send_all(b"x")
assert await echo_stream.receive_some() == b"x"
:class:`StapledStream` objects implement the methods in the
:class:`~trio.abc.HalfCloseableStream` interface. They also have two
additional public attributes:
.. attribute:: send_stream
The underlying :class:`~trio.abc.SendStream`. :meth:`send_all` and
:meth:`wait_send_all_might_not_block` are delegated to this object.
.. attribute:: receive_stream
The underlying :class:`~trio.abc.ReceiveStream`. :meth:`receive_some`
is delegated to this object.
"""
send_stream = attr.ib()
receive_stream = attr.ib()
async def send_all(self, data):
"""Calls ``self.send_stream.send_all``."""
return await self.send_stream.send_all(data)
async def wait_send_all_might_not_block(self):
"""Calls ``self.send_stream.wait_send_all_might_not_block``."""
return await self.send_stream.wait_send_all_might_not_block()
async def send_eof(self):
"""Shuts down the send side of the stream.
If ``self.send_stream.send_eof`` exists, then calls it. Otherwise,
calls ``self.send_stream.aclose()``.
"""
if hasattr(self.send_stream, "send_eof"):
return await self.send_stream.send_eof()
else:
return await self.send_stream.aclose()
async def receive_some(self, max_bytes=None):
"""Calls ``self.receive_stream.receive_some``."""
return await self.receive_stream.receive_some(max_bytes)
async def aclose(self):
"""Calls ``aclose`` on both underlying streams."""
try:
await self.send_stream.aclose()
finally:
await self.receive_stream.aclose()

View File

@@ -0,0 +1,221 @@
import errno
import sys
from math import inf
import trio
from . import socket as tsocket
# Default backlog size:
#
# Having the backlog too low can cause practical problems (a perfectly healthy
# service that starts failing to accept connections if they arrive in a
# burst).
#
# Having it too high doesn't really cause any problems. Like any buffer, you
# want backlog queue to be zero usually, and it won't save you if you're
# getting connection attempts faster than you can call accept() on an ongoing
# basis. But unlike other buffers, this one doesn't really provide any
# backpressure. If a connection gets stuck waiting in the backlog queue, then
# from the peer's point of view the connection succeeded but then their
# send/recv will stall until we get to it, possibly for a long time. OTOH if
# there isn't room in the backlog queue... then their connect stalls, possibly
# for a long time, which is pretty much the same thing.
#
# A large backlog can also use a bit more kernel memory, but this seems fairly
# negligible these days.
#
# So this suggests we should make the backlog as large as possible. This also
# matches what Golang does. However, they do it in a weird way, where they
# have a bunch of code to sniff out the configured upper limit for backlog on
# different operating systems. But on every system, passing in a too-large
# backlog just causes it to be silently truncated to the configured maximum,
# so this is unnecessary -- we can just pass in "infinity" and get the maximum
# that way. (Verified on Windows, Linux, macOS using
# notes-to-self/measure-listen-backlog.py)
def _compute_backlog(backlog):
if backlog is None:
backlog = inf
# Many systems (Linux, BSDs, ...) store the backlog in a uint16 and are
# missing overflow protection, so we apply our own overflow protection.
# https://github.com/golang/go/issues/5030
return min(backlog, 0xFFFF)
async def open_tcp_listeners(port, *, host=None, backlog=None):
"""Create :class:`SocketListener` objects to listen for TCP connections.
Args:
port (int): The port to listen on.
If you use 0 as your port, then the kernel will automatically pick
an arbitrary open port. But be careful: if you use this feature when
binding to multiple IP addresses, then each IP address will get its
own random port, and the returned listeners will probably be
listening on different ports. In particular, this will happen if you
use ``host=None`` which is the default because in this case
:func:`open_tcp_listeners` will bind to both the IPv4 wildcard
address (``0.0.0.0``) and also the IPv6 wildcard address (``::``).
host (str, bytes-like, or None): The local interface to bind to. This is
passed to :func:`~socket.getaddrinfo` with the ``AI_PASSIVE`` flag
set.
If you want to bind to the wildcard address on both IPv4 and IPv6,
in order to accept connections on all available interfaces, then
pass ``None``. This is the default.
If you have a specific interface you want to bind to, pass its IP
address or hostname here. If a hostname resolves to multiple IP
addresses, this function will open one listener on each of them.
If you want to use only IPv4, or only IPv6, but want to accept on
all interfaces, pass the family-specific wildcard address:
``"0.0.0.0"`` for IPv4-only and ``"::"`` for IPv6-only.
backlog (int or None): The listen backlog to use. If you leave this as
``None`` then Trio will pick a good default. (Currently: whatever
your system has configured as the maximum backlog.)
Returns:
list of :class:`SocketListener`
"""
# getaddrinfo sometimes allows port=None, sometimes not (depending on
# whether host=None). And on some systems it treats "" as 0, others it
# doesn't:
# http://klickverbot.at/blog/2012/01/getaddrinfo-edge-case-behavior-on-windows-linux-and-osx/
if not isinstance(port, int):
raise TypeError("port must be an int not {!r}".format(port))
backlog = _compute_backlog(backlog)
addresses = await tsocket.getaddrinfo(
host, port, type=tsocket.SOCK_STREAM, flags=tsocket.AI_PASSIVE
)
listeners = []
unsupported_address_families = []
try:
for family, type, proto, _, sockaddr in addresses:
try:
sock = tsocket.socket(family, type, proto)
except OSError as ex:
if ex.errno == errno.EAFNOSUPPORT:
# If a system only supports IPv4, or only IPv6, it
# is still likely that getaddrinfo will return
# both an IPv4 and an IPv6 address. As long as at
# least one of the returned addresses can be
# turned into a socket, we won't complain about a
# failure to create the other.
unsupported_address_families.append(ex)
continue
else:
raise
try:
# See https://github.com/python-trio/trio/issues/39
if sys.platform != "win32":
sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_REUSEADDR, 1)
if family == tsocket.AF_INET6:
sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, 1)
await sock.bind(sockaddr)
sock.listen(backlog)
listeners.append(trio.SocketListener(sock))
except:
sock.close()
raise
except:
for listener in listeners:
listener.socket.close()
raise
if unsupported_address_families and not listeners:
raise OSError(
errno.EAFNOSUPPORT,
"This system doesn't support any of the kinds of "
"socket that that address could use",
) from trio.MultiError(unsupported_address_families)
return listeners
async def serve_tcp(
handler,
port,
*,
host=None,
backlog=None,
handler_nursery=None,
task_status=trio.TASK_STATUS_IGNORED,
):
"""Listen for incoming TCP connections, and for each one start a task
running ``handler(stream)``.
This is a thin convenience wrapper around :func:`open_tcp_listeners` and
:func:`serve_listeners` see them for full details.
.. warning::
If ``handler`` raises an exception, then this function doesn't do
anything special to catch it so by default the exception will
propagate out and crash your server. If you don't want this, then catch
exceptions inside your ``handler``, or use a ``handler_nursery`` object
that responds to exceptions in some other way.
When used with ``nursery.start`` you get back the newly opened listeners.
So, for example, if you want to start a server in your test suite and then
connect to it to check that it's working properly, you can use something
like::
from trio.testing import open_stream_to_socket_listener
async with trio.open_nursery() as nursery:
listeners = await nursery.start(serve_tcp, handler, 0)
client_stream = await open_stream_to_socket_listener(listeners[0])
# Then send and receive data on 'client_stream', for example:
await client_stream.send_all(b"GET / HTTP/1.0\\r\\n\\r\\n")
This avoids several common pitfalls:
1. It lets the kernel pick a random open port, so your test suite doesn't
depend on any particular port being open.
2. It waits for the server to be accepting connections on that port before
``start`` returns, so there's no race condition where the incoming
connection arrives before the server is ready.
3. It uses the Listener object to find out which port was picked, so it
can connect to the right place.
Args:
handler: The handler to start for each incoming connection. Passed to
:func:`serve_listeners`.
port: The port to listen on. Use 0 to let the kernel pick an open port.
Passed to :func:`open_tcp_listeners`.
host (str, bytes, or None): The host interface to listen on; use
``None`` to bind to the wildcard address. Passed to
:func:`open_tcp_listeners`.
backlog: The listen backlog, or None to have a good default picked.
Passed to :func:`open_tcp_listeners`.
handler_nursery: The nursery to start handlers in, or None to use an
internal nursery. Passed to :func:`serve_listeners`.
task_status: This function can be used with ``nursery.start``.
Returns:
This function only returns when cancelled.
"""
listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog)
await trio.serve_listeners(
handler, listeners, handler_nursery=handler_nursery, task_status=task_status
)

View File

@@ -0,0 +1,371 @@
from contextlib import contextmanager
import trio
from trio.socket import getaddrinfo, SOCK_STREAM, socket
# Implementation of RFC 6555 "Happy eyeballs"
# https://tools.ietf.org/html/rfc6555
#
# Basically, the problem here is that if we want to connect to some host, and
# DNS returns multiple IP addresses, then we don't know which of them will
# actually work -- it can happen that some of them are reachable, and some of
# them are not. One particularly common situation where this happens is on a
# host that thinks it has ipv6 connectivity, but really doesn't. But in
# principle this could happen for any kind of multi-home situation (e.g. the
# route to one mirror is down but another is up).
#
# The naive algorithm (e.g. the stdlib's socket.create_connection) would be to
# pick one of the IP addresses and try to connect; if that fails, try the
# next; etc. The problem with this is that TCP is stubborn, and if the first
# address is a blackhole then it might take a very long time (tens of seconds)
# before that connection attempt fails.
#
# That's where RFC 6555 comes in. It tells us that what we do is:
# - get the list of IPs from getaddrinfo, trusting the order it gives us (with
# one exception noted in section 5.4)
# - start a connection attempt to the first IP
# - when this fails OR if it's still going after DELAY seconds, then start a
# connection attempt to the second IP
# - when this fails OR if it's still going after another DELAY seconds, then
# start a connection attempt to the third IP
# - ... repeat until we run out of IPs.
#
# Our implementation is similarly straightforward: we spawn a chain of tasks,
# where each one (a) waits until the previous connection has failed or DELAY
# seconds have passed, (b) spawns the next task, (c) attempts to connect. As
# soon as any task crashes or succeeds, we cancel all the tasks and return.
#
# Note: this currently doesn't attempt to cache any results, so if you make
# multiple connections to the same host it'll re-run the happy-eyeballs
# algorithm each time. RFC 6555 is pretty confusing about whether this is
# allowed. Section 4 describes an algorithm that attempts ipv4 and ipv6
# simultaneously, and then says "The client MUST cache information regarding
# the outcome of each connection attempt, and it uses that information to
# avoid thrashing the network with subsequent attempts." Then section 4.2 says
# "implementations MUST prefer the first IP address family returned by the
# host's address preference policy, unless implementing a stateful
# algorithm". Here "stateful" means "one that caches information about
# previous attempts". So my reading of this is that IF you're starting ipv4
# and ipv6 at the same time then you MUST cache the result for ~ten minutes,
# but IF you're "preferring" one protocol by trying it first (like we are),
# then you don't need to cache.
#
# Caching is quite tricky: to get it right you need to do things like detect
# when the network interfaces are reconfigured, and if you get it wrong then
# connection attempts basically just don't work. So we don't even try.
# "Firefox and Chrome use 300 ms"
# https://tools.ietf.org/html/rfc6555#section-6
# Though
# https://www.researchgate.net/profile/Vaibhav_Bajpai3/publication/304568993_Measuring_the_Effects_of_Happy_Eyeballs/links/5773848e08ae6f328f6c284c/Measuring-the-Effects-of-Happy-Eyeballs.pdf
# claims that Firefox actually uses 0 ms, unless an about:config option is
# toggled and then it uses 250 ms.
DEFAULT_DELAY = 0.250
# How should we call getaddrinfo? In particular, should we use AI_ADDRCONFIG?
#
# The idea of AI_ADDRCONFIG is that it only returns addresses that might
# work. E.g., if getaddrinfo knows that you don't have any IPv6 connectivity,
# then it doesn't return any IPv6 addresses. And this is kinda nice, because
# it means maybe you can skip sending AAAA requests entirely. But in practice,
# it doesn't really work right.
#
# - on Linux/glibc, empirically, the default is to return all addresses, and
# with AI_ADDRCONFIG then it only returns IPv6 addresses if there is at least
# one non-loopback IPv6 address configured... but this can be a link-local
# address, so in practice I guess this is basically always configured if IPv6
# is enabled at all. OTOH if you pass in "::1" as the target address with
# AI_ADDRCONFIG and there's no *external* IPv6 address configured, you get an
# error. So AI_ADDRCONFIG mostly doesn't do anything, even when you would want
# it to, and when it does do something it might break things that would have
# worked.
#
# - on Windows 10, empirically, if no IPv6 address is configured then by
# default they are also suppressed from getaddrinfo (flags=0 and
# flags=AI_ADDRCONFIG seem to do the same thing). If you pass AI_ALL, then you
# get the full list.
# ...except for localhost! getaddrinfo("localhost", "80") gives me ::1, even
# though there's no ipv6 and other queries only return ipv4.
# If you pass in and IPv6 IP address as the target address, then that's always
# returned OK, even with AI_ADDRCONFIG set and no IPv6 configured.
#
# But I guess other versions of windows messed this up, judging from these bug
# reports:
# https://bugs.chromium.org/p/chromium/issues/detail?id=5234
# https://bugs.chromium.org/p/chromium/issues/detail?id=32522#c50
#
# So basically the options are either to use AI_ADDRCONFIG and then add some
# complicated special cases to work around its brokenness, or else don't use
# AI_ADDRCONFIG and accept that sometimes on legacy/misconfigured networks
# we'll waste 300 ms trying to connect to a blackholed destination.
#
# Twisted and Tornado always uses default flags. I think we'll do the same.
@contextmanager
def close_all():
sockets_to_close = set()
try:
yield sockets_to_close
finally:
errs = []
for sock in sockets_to_close:
try:
sock.close()
except BaseException as exc:
errs.append(exc)
if errs:
raise trio.MultiError(errs)
def reorder_for_rfc_6555_section_5_4(targets):
# RFC 6555 section 5.4 says that if getaddrinfo returns multiple address
# families (e.g. IPv4 and IPv6), then you should make sure that your first
# and second attempts use different families:
#
# https://tools.ietf.org/html/rfc6555#section-5.4
#
# This function post-processes the results from getaddrinfo, in-place, to
# satisfy this requirement.
for i in range(1, len(targets)):
if targets[i][0] != targets[0][0]:
# Found the first entry with a different address family; move it
# so that it becomes the second item on the list.
if i != 1:
targets.insert(1, targets.pop(i))
break
def format_host_port(host, port):
host = host.decode("ascii") if isinstance(host, bytes) else host
if ":" in host:
return "[{}]:{}".format(host, port)
else:
return "{}:{}".format(host, port)
# Twisted's HostnameEndpoint has a good set of configurables:
# https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.HostnameEndpoint.html
#
# - per-connection timeout
# this doesn't seem useful -- we let you set a timeout on the whole thing
# using Trio's normal mechanisms, and that seems like enough
# - delay between attempts
# - bind address (but not port!)
# they *don't* support multiple address bindings, like giving the ipv4 and
# ipv6 addresses of the host.
# I think maybe our semantics should be: we accept a list of bind addresses,
# and we bind to the first one that is compatible with the
# connection attempt we want to make, and if none are compatible then we
# don't try to connect to that target.
#
# XX TODO: implement bind address support
#
# Actually, the best option is probably to be explicit: {AF_INET: "...",
# AF_INET6: "..."}
# this might be simpler after
async def open_tcp_stream(
host, port, *, happy_eyeballs_delay=DEFAULT_DELAY, local_address=None
):
"""Connect to the given host and port over TCP.
If the given ``host`` has multiple IP addresses associated with it, then
we have a problem: which one do we use?
One approach would be to attempt to connect to the first one, and then if
that fails, attempt to connect to the second one ... until we've tried all
of them. But the problem with this is that if the first IP address is
unreachable (for example, because it's an IPv6 address and our network
discards IPv6 packets), then we might end up waiting tens of seconds for
the first connection attempt to timeout before we try the second address.
Another approach would be to attempt to connect to all of the addresses at
the same time, in parallel, and then use whichever connection succeeds
first, abandoning the others. This would be fast, but create a lot of
unnecessary load on the network and the remote server.
This function strikes a balance between these two extremes: it works its
way through the available addresses one at a time, like the first
approach; but, if ``happy_eyeballs_delay`` seconds have passed and it's
still waiting for an attempt to succeed or fail, then it gets impatient
and starts the next connection attempt in parallel. As soon as any one
connection attempt succeeds, all the other attempts are cancelled. This
avoids unnecessary load because most connections will succeed after just
one or two attempts, but if one of the addresses is unreachable then it
doesn't slow us down too much.
This is known as a "happy eyeballs" algorithm, and our particular variant
is modelled after how Chrome connects to webservers; see `RFC 6555
<https://tools.ietf.org/html/rfc6555>`__ for more details.
Args:
host (str or bytes): The host to connect to. Can be an IPv4 address,
IPv6 address, or a hostname.
port (int): The port to connect to.
happy_eyeballs_delay (float): How many seconds to wait for each
connection attempt to succeed or fail before getting impatient and
starting another one in parallel. Set to `math.inf` if you want
to limit to only one connection attempt at a time (like
:func:`socket.create_connection`). Default: 0.25 (250 ms).
local_address (None or str): The local IP address or hostname to use as
the source for outgoing connections. If ``None``, we let the OS pick
the source IP.
This is useful in some exotic networking configurations where your
host has multiple IP addresses, and you want to force the use of a
specific one.
Note that if you pass an IPv4 ``local_address``, then you won't be
able to connect to IPv6 hosts, and vice-versa. If you want to take
advantage of this to force the use of IPv4 or IPv6 without
specifying an exact source address, you can use the IPv4 wildcard
address ``local_address="0.0.0.0"``, or the IPv6 wildcard address
``local_address="::"``.
Returns:
SocketStream: a :class:`~trio.abc.Stream` connected to the given server.
Raises:
OSError: if the connection fails.
See also:
open_ssl_over_tcp_stream
"""
# To keep our public API surface smaller, rule out some cases that
# getaddrinfo will accept in some circumstances, but that act weird or
# have non-portable behavior or are just plain not useful.
# No type check on host though b/c we want to allow bytes-likes.
if host is None:
raise ValueError("host cannot be None")
if not isinstance(port, int):
raise TypeError("port must be int, not {!r}".format(port))
if happy_eyeballs_delay is None:
happy_eyeballs_delay = DEFAULT_DELAY
targets = await getaddrinfo(host, port, type=SOCK_STREAM)
# I don't think this can actually happen -- if there are no results,
# getaddrinfo should have raised OSError instead of returning an empty
# list. But let's be paranoid and handle it anyway:
if not targets:
msg = "no results found for hostname lookup: {}".format(
format_host_port(host, port)
)
raise OSError(msg)
reorder_for_rfc_6555_section_5_4(targets)
# This list records all the connection failures that we ignored.
oserrors = []
# Keeps track of the socket that we're going to complete with,
# need to make sure this isn't automatically closed
winning_socket = None
# Try connecting to the specified address. Possible outcomes:
# - success: record connected socket in winning_socket and cancel
# concurrent attempts
# - failure: record exception in oserrors, set attempt_failed allowing
# the next connection attempt to start early
# code needs to ensure sockets can be closed appropriately in the
# face of crash or cancellation
async def attempt_connect(socket_args, sockaddr, attempt_failed):
nonlocal winning_socket
try:
sock = socket(*socket_args)
open_sockets.add(sock)
if local_address is not None:
# TCP connections are identified by a 4-tuple:
#
# (local IP, local port, remote IP, remote port)
#
# So if a single local IP wants to make multiple connections
# to the same (remote IP, remote port) pair, then those
# connections have to use different local ports, or else TCP
# won't be able to tell them apart. OTOH, if you have multiple
# connections to different remote IP/ports, then those
# connections can share a local port.
#
# Normally, when you call bind(), the kernel will immediately
# assign a specific local port to your socket. At this point
# the kernel doesn't know which (remote IP, remote port)
# you're going to use, so it has to pick a local port that
# *no* other connection is using. That's the only way to
# guarantee that this local port will be usable later when we
# call connect(). (Alternatively, you can set SO_REUSEADDR to
# allow multiple nascent connections to share the same port,
# but then connect() might fail with EADDRNOTAVAIL if we get
# unlucky and our TCP 4-tuple ends up colliding with another
# unrelated connection.)
#
# So calling bind() before connect() works, but it disables
# sharing of local ports. This is inefficient: it makes you
# more likely to run out of local ports.
#
# But on some versions of Linux, we can re-enable sharing of
# local ports by setting a special flag. This flag tells
# bind() to only bind the IP, and not the port. That way,
# connect() is allowed to pick the the port, and it can do a
# better job of it because it knows the remote IP/port.
try:
sock.setsockopt(
trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT, 1
)
except (OSError, AttributeError):
pass
try:
await sock.bind((local_address, 0))
except OSError:
raise OSError(
f"local_address={local_address!r} is incompatible "
f"with remote address {sockaddr}"
)
await sock.connect(sockaddr)
# Success! Save the winning socket and cancel all outstanding
# connection attempts.
winning_socket = sock
nursery.cancel_scope.cancel()
except OSError as exc:
# This connection attempt failed, but the next one might
# succeed. Save the error for later so we can report it if
# everything fails, and tell the next attempt that it should go
# ahead (if it hasn't already).
oserrors.append(exc)
attempt_failed.set()
with close_all() as open_sockets:
# nursery spawns a task for each connection attempt, will be
# cancelled by the task that gets a successful connection
async with trio.open_nursery() as nursery:
for *sa, _, addr in targets:
# create an event to indicate connection failure,
# allowing the next target to be tried early
attempt_failed = trio.Event()
nursery.start_soon(attempt_connect, sa, addr, attempt_failed)
# give this attempt at most this time before moving on
with trio.move_on_after(happy_eyeballs_delay):
await attempt_failed.wait()
# nothing succeeded
if winning_socket is None:
assert len(oserrors) == len(targets)
msg = "all attempts to connect to {} failed".format(
format_host_port(host, port)
)
raise OSError(msg) from trio.MultiError(oserrors)
else:
stream = trio.SocketStream(winning_socket)
open_sockets.remove(winning_socket)
return stream

View File

@@ -0,0 +1,49 @@
import os
from contextlib import contextmanager
import trio
from trio.socket import socket, SOCK_STREAM
try:
from trio.socket import AF_UNIX
has_unix = True
except ImportError:
has_unix = False
@contextmanager
def close_on_error(obj):
try:
yield obj
except:
obj.close()
raise
async def open_unix_socket(filename):
"""Opens a connection to the specified
`Unix domain socket <https://en.wikipedia.org/wiki/Unix_domain_socket>`__.
You must have read/write permission on the specified file to connect.
Args:
filename (str or bytes): The filename to open the connection to.
Returns:
SocketStream: a :class:`~trio.abc.Stream` connected to the given file.
Raises:
OSError: If the socket file could not be connected to.
RuntimeError: If AF_UNIX sockets are not supported.
"""
if not has_unix:
raise RuntimeError("Unix sockets are not supported on this platform")
# much more simplified logic vs tcp sockets - one socket type and only one
# possible location to connect to
sock = socket(AF_UNIX, SOCK_STREAM)
with close_on_error(sock):
await sock.connect(os.fspath(filename))
return trio.SocketStream(sock)

View File

@@ -0,0 +1,121 @@
import errno
import logging
import os
import trio
# Errors that accept(2) can return, and which indicate that the system is
# overloaded
ACCEPT_CAPACITY_ERRNOS = {
errno.EMFILE,
errno.ENFILE,
errno.ENOMEM,
errno.ENOBUFS,
}
# How long to sleep when we get one of those errors
SLEEP_TIME = 0.100
# The logger we use to complain when this happens
LOGGER = logging.getLogger("trio.serve_listeners")
async def _run_handler(stream, handler):
try:
await handler(stream)
finally:
await trio.aclose_forcefully(stream)
async def _serve_one_listener(listener, handler_nursery, handler):
async with listener:
while True:
try:
stream = await listener.accept()
except OSError as exc:
if exc.errno in ACCEPT_CAPACITY_ERRNOS:
LOGGER.error(
"accept returned %s (%s); retrying in %s seconds",
errno.errorcode[exc.errno],
os.strerror(exc.errno),
SLEEP_TIME,
exc_info=True,
)
await trio.sleep(SLEEP_TIME)
else:
raise
else:
handler_nursery.start_soon(_run_handler, stream, handler)
async def serve_listeners(
handler, listeners, *, handler_nursery=None, task_status=trio.TASK_STATUS_IGNORED
):
r"""Listen for incoming connections on ``listeners``, and for each one
start a task running ``handler(stream)``.
.. warning::
If ``handler`` raises an exception, then this function doesn't do
anything special to catch it so by default the exception will
propagate out and crash your server. If you don't want this, then catch
exceptions inside your ``handler``, or use a ``handler_nursery`` object
that responds to exceptions in some other way.
Args:
handler: An async callable, that will be invoked like
``handler_nursery.start_soon(handler, stream)`` for each incoming
connection.
listeners: A list of :class:`~trio.abc.Listener` objects.
:func:`serve_listeners` takes responsibility for closing them.
handler_nursery: The nursery used to start handlers, or any object with
a ``start_soon`` method. If ``None`` (the default), then
:func:`serve_listeners` will create a new nursery internally and use
that.
task_status: This function can be used with ``nursery.start``, which
will return ``listeners``.
Returns:
This function never returns unless cancelled.
Resource handling:
If ``handler`` neglects to close the ``stream``, then it will be closed
using :func:`trio.aclose_forcefully`.
Error handling:
Most errors coming from :meth:`~trio.abc.Listener.accept` are allowed to
propagate out (crashing the server in the process). However, some errors
those which indicate that the server is temporarily overloaded are
handled specially. These are :class:`OSError`\s with one of the following
errnos:
* ``EMFILE``: process is out of file descriptors
* ``ENFILE``: system is out of file descriptors
* ``ENOBUFS``, ``ENOMEM``: the kernel hit some sort of memory limitation
when trying to create a socket object
When :func:`serve_listeners` gets one of these errors, then it:
* Logs the error to the standard library logger ``trio.serve_listeners``
(level = ERROR, with exception information included). By default this
causes it to be printed to stderr.
* Waits 100 ms before calling ``accept`` again, in hopes that the
system will recover.
"""
async with trio.open_nursery() as nursery:
if handler_nursery is None:
handler_nursery = nursery
for listener in listeners:
nursery.start_soon(_serve_one_listener, listener, handler_nursery, handler)
# The listeners are already queueing connections when we're called,
# but we wait until the end to call started() just in case we get an
# error or whatever.
task_status.started(listeners)

View File

@@ -0,0 +1,382 @@
# "High-level" networking interface
import errno
from contextlib import contextmanager
import trio
from . import socket as tsocket
from ._util import ConflictDetector, Final
from .abc import HalfCloseableStream, Listener
# XX TODO: this number was picked arbitrarily. We should do experiments to
# tune it. (Or make it dynamic -- one idea is to start small and increase it
# if we observe single reads filling up the whole buffer, at least within some
# limits.)
DEFAULT_RECEIVE_SIZE = 65536
_closed_stream_errnos = {
# Unix
errno.EBADF,
# Windows
errno.ENOTSOCK,
}
@contextmanager
def _translate_socket_errors_to_stream_errors():
try:
yield
except OSError as exc:
if exc.errno in _closed_stream_errnos:
raise trio.ClosedResourceError("this socket was already closed") from None
else:
raise trio.BrokenResourceError(
"socket connection broken: {}".format(exc)
) from exc
class SocketStream(HalfCloseableStream, metaclass=Final):
"""An implementation of the :class:`trio.abc.HalfCloseableStream`
interface based on a raw network socket.
Args:
socket: The Trio socket object to wrap. Must have type ``SOCK_STREAM``,
and be connected.
By default for TCP sockets, :class:`SocketStream` enables ``TCP_NODELAY``,
and (on platforms where it's supported) enables ``TCP_NOTSENT_LOWAT`` with
a reasonable buffer size (currently 16 KiB) see `issue #72
<https://github.com/python-trio/trio/issues/72>`__ for discussion. You can
of course override these defaults by calling :meth:`setsockopt`.
Once a :class:`SocketStream` object is constructed, it implements the full
:class:`trio.abc.HalfCloseableStream` interface. In addition, it provides
a few extra features:
.. attribute:: socket
The Trio socket object that this stream wraps.
"""
def __init__(self, socket):
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketStream requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
raise ValueError("SocketStream requires a SOCK_STREAM socket")
self.socket = socket
self._send_conflict_detector = ConflictDetector(
"another task is currently sending data on this SocketStream"
)
# Socket defaults:
# Not supported on e.g. unix domain sockets
try:
self.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, True)
except OSError:
pass
if hasattr(tsocket, "TCP_NOTSENT_LOWAT"):
try:
# 16 KiB is pretty arbitrary and could probably do with some
# tuning. (Apple is also setting this by default in CFNetwork
# apparently -- I'm curious what value they're using, though I
# couldn't find it online trivially. CFNetwork-129.20 source
# has no mentions of TCP_NOTSENT_LOWAT. This presentation says
# "typically 8 kilobytes":
# http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1
# ). The theory is that you want it to be bandwidth *
# rescheduling interval.
self.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NOTSENT_LOWAT, 2**14)
except OSError:
pass
async def send_all(self, data):
if self.socket.did_shutdown_SHUT_WR:
raise trio.ClosedResourceError("can't send data after sending EOF")
with self._send_conflict_detector:
with _translate_socket_errors_to_stream_errors():
with memoryview(data) as data:
if not data:
if self.socket.fileno() == -1:
raise trio.ClosedResourceError("socket was already closed")
await trio.lowlevel.checkpoint()
return
total_sent = 0
while total_sent < len(data):
with data[total_sent:] as remaining:
sent = await self.socket.send(remaining)
total_sent += sent
async def wait_send_all_might_not_block(self):
with self._send_conflict_detector:
if self.socket.fileno() == -1:
raise trio.ClosedResourceError
with _translate_socket_errors_to_stream_errors():
await self.socket.wait_writable()
async def send_eof(self):
with self._send_conflict_detector:
await trio.lowlevel.checkpoint()
# On macOS, calling shutdown a second time raises ENOTCONN, but
# send_eof needs to be idempotent.
if self.socket.did_shutdown_SHUT_WR:
return
with _translate_socket_errors_to_stream_errors():
self.socket.shutdown(tsocket.SHUT_WR)
async def receive_some(self, max_bytes=None):
if max_bytes is None:
max_bytes = DEFAULT_RECEIVE_SIZE
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
with _translate_socket_errors_to_stream_errors():
return await self.socket.recv(max_bytes)
async def aclose(self):
self.socket.close()
await trio.lowlevel.checkpoint()
# __aenter__, __aexit__ inherited from HalfCloseableStream are OK
def setsockopt(self, level, option, value):
"""Set an option on the underlying socket.
See :meth:`socket.socket.setsockopt` for details.
"""
return self.socket.setsockopt(level, option, value)
def getsockopt(self, level, option, buffersize=0):
"""Check the current value of an option on the underlying socket.
See :meth:`socket.socket.getsockopt` for details.
"""
# This is to work around
# https://bitbucket.org/pypy/pypy/issues/2561
# We should be able to drop it when the next PyPy3 beta is released.
if buffersize == 0:
return self.socket.getsockopt(level, option)
else:
return self.socket.getsockopt(level, option, buffersize)
################################################################
# SocketListener
################################################################
# Accept error handling
# =====================
#
# Literature review
# -----------------
#
# Here's a list of all the possible errors that accept() can return, according
# to the POSIX spec or the Linux, FreeBSD, macOS, and Windows docs:
#
# Can't happen with a Trio socket:
# - EAGAIN/(WSA)EWOULDBLOCK
# - EINTR
# - WSANOTINITIALISED
# - WSAEINPROGRESS: a blocking call is already in progress
# - WSAEINTR: someone called WSACancelBlockingCall, but we don't make blocking
# calls in the first place
#
# Something is wrong with our call:
# - EBADF: not a file descriptor
# - (WSA)EINVAL: socket isn't listening, or (Linux, BSD) bad flags
# - (WSA)ENOTSOCK: not a socket
# - (WSA)EOPNOTSUPP: this kind of socket doesn't support accept
# - (Linux, FreeBSD, Windows) EFAULT: the sockaddr pointer points to readonly
# memory
#
# Something is wrong with the environment:
# - (WSA)EMFILE: this process hit its fd limit
# - ENFILE: the system hit its fd limit
# - (WSA)ENOBUFS, ENOMEM: unspecified memory problems
#
# Something is wrong with the connection we were going to accept. There's a
# ton of variability between systems here:
# - ECONNABORTED: documented everywhere, but apparently only the BSDs do this
# (signals a connection was closed/reset before being accepted)
# - EPROTO: unspecified protocol error
# - (Linux) EPERM: firewall rule prevented connection
# - (Linux) ENETDOWN, EPROTO, ENOPROTOOPT, EHOSTDOWN, ENONET, EHOSTUNREACH,
# EOPNOTSUPP, ENETUNREACH, ENOSR, ESOCKTNOSUPPORT, EPROTONOSUPPORT,
# ETIMEDOUT, ... or any other error that the socket could give, because
# apparently if an error happens on a connection before it's accept()ed,
# Linux will report that error from accept().
# - (Windows) WSAECONNRESET, WSAENETDOWN
#
#
# Code review
# -----------
#
# What do other libraries do?
#
# Twisted on Unix or when using nonblocking I/O on Windows:
# - ignores EPERM, with comment about Linux firewalls
# - logs and ignores EMFILE, ENOBUFS, ENFILE, ENOMEM, ECONNABORTED
# Comment notes that ECONNABORTED is a BSDism and that Linux returns the
# socket before having it fail, and macOS just silently discards it.
# - other errors are raised, which is logged + kills the socket
# ref: src/twisted/internet/tcp.py, Port.doRead
#
# Twisted using IOCP on Windows:
# - logs and ignores all errors
# ref: src/twisted/internet/iocpreactor/tcp.py, Port.handleAccept
#
# Tornado:
# - ignore ECONNABORTED (comments notes that it was observed on FreeBSD)
# - everything else raised, but all this does (by default) is cause it to be
# logged and then ignored
# (ref: tornado/netutil.py, tornado/ioloop.py)
#
# libuv on Unix:
# - ignores ECONNABORTED
# - does a "trick" for EMFILE or ENFILE
# - all other errors passed to the connection_cb to be handled
# (ref: src/unix/stream.c:uv__server_io, uv__emfile_trick)
#
# libuv on Windows:
# src/win/tcp.c:uv_tcp_queue_accept
# this calls AcceptEx, and then arranges to call:
# src/win/tcp.c:uv_process_tcp_accept_req
# this gets the result from AcceptEx. If the original AcceptEx call failed,
# then "we stop accepting connections and report this error to the
# connection callback". I think this is for things like ENOTSOCK. If
# AcceptEx successfully queues an overlapped operation, and then that
# reports an error, it's just discarded.
#
# asyncio, selector mode:
# - ignores EWOULDBLOCK, EINTR, ECONNABORTED
# - on EMFILE, ENFILE, ENOBUFS, ENOMEM, logs an error and then disables the
# listening loop for 1 second
# - everything else raises, but then the event loop just logs and ignores it
# (selector_events.py: BaseSelectorEventLoop._accept_connection)
#
#
# What should we do?
# ------------------
#
# When accept() returns an error, we can either ignore it or raise it.
#
# We have a long list of errors that should be ignored, and a long list of
# errors that should be raised. The big question is what to do with an error
# that isn't on either list. On Linux apparently you can get nearly arbitrary
# errors from accept() and they should be ignored, because it just indicates a
# socket that crashed before it began, and there isn't really anything to be
# done about this, plus on other platforms you may not get any indication at
# all, so programs have to tolerate not getting any indication too. OTOH if we
# get an unexpected error then it could indicate something arbitrarily bad --
# after all, it's unexpected.
#
# Given that we know that other libraries seem to be getting along fine with a
# fairly minimal list of errors to ignore, I think we'll be OK if we write
# down that list and then raise on everything else.
#
# The other question is what to do about the capacity problem errors: EMFILE,
# ENFILE, ENOBUFS, ENOMEM. Just flat out ignoring these is clearly not optimal
# -- at the very least you want to log them, and probably you want to take
# some remedial action. And if we ignore them then it prevents higher levels
# from doing anything clever with them. So we raise them.
_ignorable_accept_errno_names = [
# Linux can do this when the a connection is denied by the firewall
"EPERM",
# BSDs with an early close/reset
"ECONNABORTED",
# All the other miscellany noted above -- may not happen in practice, but
# whatever.
"EPROTO",
"ENETDOWN",
"ENOPROTOOPT",
"EHOSTDOWN",
"ENONET",
"EHOSTUNREACH",
"EOPNOTSUPP",
"ENETUNREACH",
"ENOSR",
"ESOCKTNOSUPPORT",
"EPROTONOSUPPORT",
"ETIMEDOUT",
"ECONNRESET",
]
# Not all errnos are defined on all platforms
_ignorable_accept_errnos = set()
for name in _ignorable_accept_errno_names:
try:
_ignorable_accept_errnos.add(getattr(errno, name))
except AttributeError:
pass
class SocketListener(Listener[SocketStream], metaclass=Final):
"""A :class:`~trio.abc.Listener` that uses a listening socket to accept
incoming connections as :class:`SocketStream` objects.
Args:
socket: The Trio socket object to wrap. Must have type ``SOCK_STREAM``,
and be listening.
Note that the :class:`SocketListener` "takes ownership" of the given
socket; closing the :class:`SocketListener` will also close the socket.
.. attribute:: socket
The Trio socket object that this stream wraps.
"""
def __init__(self, socket):
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketListener requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
raise ValueError("SocketListener requires a SOCK_STREAM socket")
try:
listening = socket.getsockopt(tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN)
except OSError:
# SO_ACCEPTCONN fails on macOS; we just have to trust the user.
pass
else:
if not listening:
raise ValueError("SocketListener requires a listening socket")
self.socket = socket
async def accept(self):
"""Accept an incoming connection.
Returns:
:class:`SocketStream`
Raises:
OSError: if the underlying call to ``accept`` raises an unexpected
error.
ClosedResourceError: if you already closed the socket.
This method handles routine errors like ``ECONNABORTED``, but passes
other errors on to its caller. In particular, it does *not* make any
special effort to handle resource exhaustion errors like ``EMFILE``,
``ENFILE``, ``ENOBUFS``, ``ENOMEM``.
"""
while True:
try:
sock, _ = await self.socket.accept()
except OSError as exc:
if exc.errno in _closed_stream_errnos:
raise trio.ClosedResourceError
if exc.errno not in _ignorable_accept_errnos:
raise
else:
return SocketStream(sock)
async def aclose(self):
"""Close this listener and its underlying socket."""
self.socket.close()
await trio.lowlevel.checkpoint()

View File

@@ -0,0 +1,154 @@
import trio
import ssl
from ._highlevel_open_tcp_stream import DEFAULT_DELAY
# It might have been nice to take a ssl_protocols= argument here to set up
# NPN/ALPN, but to do this we have to mutate the context object, which is OK
# if it's one we created, but not OK if it's one that was passed in... and
# the one major protocol using NPN/ALPN is HTTP/2, which mandates that you use
# a specially configured SSLContext anyway! I also thought maybe we could copy
# the given SSLContext and then mutate the copy, but it's no good as SSLContext
# objects can't be copied: https://bugs.python.org/issue33023.
# So... let's punt on that for now. Hopefully we'll be getting a new Python
# TLS API soon and can revisit this then.
async def open_ssl_over_tcp_stream(
host,
port,
*,
https_compatible=False,
ssl_context=None,
happy_eyeballs_delay=DEFAULT_DELAY,
):
"""Make a TLS-encrypted Connection to the given host and port over TCP.
This is a convenience wrapper that calls :func:`open_tcp_stream` and
wraps the result in an :class:`~trio.SSLStream`.
This function does not perform the TLS handshake; you can do it
manually by calling :meth:`~trio.SSLStream.do_handshake`, or else
it will be performed automatically the first time you send or receive
data.
Args:
host (bytes or str): The host to connect to. We require the server
to have a TLS certificate valid for this hostname.
port (int): The port to connect to.
https_compatible (bool): Set this to True if you're connecting to a web
server. See :class:`~trio.SSLStream` for details. Default:
False.
ssl_context (:class:`~ssl.SSLContext` or None): The SSL context to
use. If None (the default), :func:`ssl.create_default_context`
will be called to create a context.
happy_eyeballs_delay (float): See :func:`open_tcp_stream`.
Returns:
trio.SSLStream: the encrypted connection to the server.
"""
tcp_stream = await trio.open_tcp_stream(
host, port, happy_eyeballs_delay=happy_eyeballs_delay
)
if ssl_context is None:
ssl_context = ssl.create_default_context()
if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
return trio.SSLStream(
tcp_stream, ssl_context, server_hostname=host, https_compatible=https_compatible
)
async def open_ssl_over_tcp_listeners(
port, ssl_context, *, host=None, https_compatible=False, backlog=None
):
"""Start listening for SSL/TLS-encrypted TCP connections to the given port.
Args:
port (int): The port to listen on. See :func:`open_tcp_listeners`.
ssl_context (~ssl.SSLContext): The SSL context to use for all incoming
connections.
host (str, bytes, or None): The address to bind to; use ``None`` to bind
to the wildcard address. See :func:`open_tcp_listeners`.
https_compatible (bool): See :class:`~trio.SSLStream` for details.
backlog (int or None): See :func:`open_tcp_listeners` for details.
"""
tcp_listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog)
ssl_listeners = [
trio.SSLListener(tcp_listener, ssl_context, https_compatible=https_compatible)
for tcp_listener in tcp_listeners
]
return ssl_listeners
async def serve_ssl_over_tcp(
handler,
port,
ssl_context,
*,
host=None,
https_compatible=False,
backlog=None,
handler_nursery=None,
task_status=trio.TASK_STATUS_IGNORED,
):
"""Listen for incoming TCP connections, and for each one start a task
running ``handler(stream)``.
This is a thin convenience wrapper around
:func:`open_ssl_over_tcp_listeners` and :func:`serve_listeners` see them
for full details.
.. warning::
If ``handler`` raises an exception, then this function doesn't do
anything special to catch it so by default the exception will
propagate out and crash your server. If you don't want this, then catch
exceptions inside your ``handler``, or use a ``handler_nursery`` object
that responds to exceptions in some other way.
When used with ``nursery.start`` you get back the newly opened listeners.
See the documentation for :func:`serve_tcp` for an example where this is
useful.
Args:
handler: The handler to start for each incoming connection. Passed to
:func:`serve_listeners`.
port (int): The port to listen on. Use 0 to let the kernel pick
an open port. Ultimately passed to :func:`open_tcp_listeners`.
ssl_context (~ssl.SSLContext): The SSL context to use for all incoming
connections. Passed to :func:`open_ssl_over_tcp_listeners`.
host (str, bytes, or None): The address to bind to; use ``None`` to bind
to the wildcard address. Ultimately passed to
:func:`open_tcp_listeners`.
https_compatible (bool): Set this to True if you want to use
"HTTPS-style" TLS. See :class:`~trio.SSLStream` for details.
backlog (int or None): See :class:`~trio.SSLStream` for details.
handler_nursery: The nursery to start handlers in, or None to use an
internal nursery. Passed to :func:`serve_listeners`.
task_status: This function can be used with ``nursery.start``.
Returns:
This function only returns when cancelled.
"""
listeners = await trio.open_ssl_over_tcp_listeners(
port,
ssl_context,
host=host,
https_compatible=https_compatible,
backlog=backlog,
)
await trio.serve_listeners(
handler, listeners, handler_nursery=handler_nursery, task_status=task_status
)

View File

@@ -0,0 +1,206 @@
# type: ignore
from functools import wraps, partial
import os
import types
import pathlib
import trio
from trio._util import async_wraps, Final
# re-wrap return value from methods that return new instances of pathlib.Path
def rewrap_path(value):
if isinstance(value, pathlib.Path):
value = Path(value)
return value
def _forward_factory(cls, attr_name, attr):
@wraps(attr)
def wrapper(self, *args, **kwargs):
attr = getattr(self._wrapped, attr_name)
value = attr(*args, **kwargs)
return rewrap_path(value)
return wrapper
def _forward_magic(cls, attr):
sentinel = object()
@wraps(attr)
def wrapper(self, other=sentinel):
if other is sentinel:
return attr(self._wrapped)
if isinstance(other, cls):
other = other._wrapped
value = attr(self._wrapped, other)
return rewrap_path(value)
return wrapper
def iter_wrapper_factory(cls, meth_name):
@async_wraps(cls, cls._wraps, meth_name)
async def wrapper(self, *args, **kwargs):
meth = getattr(self._wrapped, meth_name)
func = partial(meth, *args, **kwargs)
# Make sure that the full iteration is performed in the thread
# by converting the generator produced by pathlib into a list
items = await trio.to_thread.run_sync(lambda: list(func()))
return (rewrap_path(item) for item in items)
return wrapper
def thread_wrapper_factory(cls, meth_name):
@async_wraps(cls, cls._wraps, meth_name)
async def wrapper(self, *args, **kwargs):
meth = getattr(self._wrapped, meth_name)
func = partial(meth, *args, **kwargs)
value = await trio.to_thread.run_sync(func)
return rewrap_path(value)
return wrapper
def classmethod_wrapper_factory(cls, meth_name):
@classmethod
@async_wraps(cls, cls._wraps, meth_name)
async def wrapper(cls, *args, **kwargs):
meth = getattr(cls._wraps, meth_name)
func = partial(meth, *args, **kwargs)
value = await trio.to_thread.run_sync(func)
return rewrap_path(value)
return wrapper
class AsyncAutoWrapperType(Final):
def __init__(cls, name, bases, attrs):
super().__init__(name, bases, attrs)
cls._forward = []
type(cls).generate_forwards(cls, attrs)
type(cls).generate_wraps(cls, attrs)
type(cls).generate_magic(cls, attrs)
type(cls).generate_iter(cls, attrs)
def generate_forwards(cls, attrs):
# forward functions of _forwards
for attr_name, attr in cls._forwards.__dict__.items():
if attr_name.startswith("_") or attr_name in attrs:
continue
if isinstance(attr, property):
cls._forward.append(attr_name)
elif isinstance(attr, types.FunctionType):
wrapper = _forward_factory(cls, attr_name, attr)
setattr(cls, attr_name, wrapper)
else:
raise TypeError(attr_name, type(attr))
def generate_wraps(cls, attrs):
# generate wrappers for functions of _wraps
for attr_name, attr in cls._wraps.__dict__.items():
# .z. exclude cls._wrap_iter
if attr_name.startswith("_") or attr_name in attrs:
continue
if isinstance(attr, classmethod):
wrapper = classmethod_wrapper_factory(cls, attr_name)
setattr(cls, attr_name, wrapper)
elif isinstance(attr, types.FunctionType):
wrapper = thread_wrapper_factory(cls, attr_name)
setattr(cls, attr_name, wrapper)
else:
raise TypeError(attr_name, type(attr))
def generate_magic(cls, attrs):
# generate wrappers for magic
for attr_name in cls._forward_magic:
attr = getattr(cls._forwards, attr_name)
wrapper = _forward_magic(cls, attr)
setattr(cls, attr_name, wrapper)
def generate_iter(cls, attrs):
# generate wrappers for methods that return iterators
for attr_name, attr in cls._wraps.__dict__.items():
if attr_name in cls._wrap_iter:
wrapper = iter_wrapper_factory(cls, attr_name)
setattr(cls, attr_name, wrapper)
class Path(metaclass=AsyncAutoWrapperType):
"""A :class:`pathlib.Path` wrapper that executes blocking methods in
:meth:`trio.to_thread.run_sync`.
"""
_wraps = pathlib.Path
_forwards = pathlib.PurePath
_forward_magic = [
"__str__",
"__bytes__",
"__truediv__",
"__rtruediv__",
"__eq__",
"__lt__",
"__le__",
"__gt__",
"__ge__",
"__hash__",
]
_wrap_iter = ["glob", "rglob", "iterdir"]
def __init__(self, *args):
self._wrapped = pathlib.Path(*args)
def __getattr__(self, name):
if name in self._forward:
value = getattr(self._wrapped, name)
return rewrap_path(value)
raise AttributeError(name)
def __dir__(self):
return super().__dir__() + self._forward
def __repr__(self):
return "trio.Path({})".format(repr(str(self)))
def __fspath__(self):
return os.fspath(self._wrapped)
@wraps(pathlib.Path.open)
async def open(self, *args, **kwargs):
"""Open the file pointed to by the path, like the :func:`trio.open_file`
function does.
"""
func = partial(self._wrapped.open, *args, **kwargs)
value = await trio.to_thread.run_sync(func)
return trio.wrap_file(value)
Path.iterdir.__doc__ = """
Like :meth:`pathlib.Path.iterdir`, but async.
This is an async method that returns a synchronous iterator, so you
use it like::
for subpath in await mypath.iterdir():
...
Note that it actually loads the whole directory list into memory
immediately, during the initial call. (See `issue #501
<https://github.com/python-trio/trio/issues/501>`__ for discussion.)
"""
# The value of Path.absolute.__doc__ makes a reference to
# :meth:~pathlib.Path.absolute, which does not exist. Removing this makes more
# sense than inventing our own special docstring for this.
del Path.absolute.__doc__
os.PathLike.register(Path)

View File

@@ -0,0 +1,167 @@
import signal
from contextlib import contextmanager
from collections import OrderedDict
import trio
from ._util import signal_raise, is_main_thread, ConflictDetector
# Discussion of signal handling strategies:
#
# - On Windows signals barely exist. There are no options; signal handlers are
# the only available API.
#
# - On Linux signalfd is arguably the natural way. Semantics: signalfd acts as
# an *alternative* signal delivery mechanism. The way you use it is to mask
# out the relevant signals process-wide (so that they don't get delivered
# the normal way), and then when you read from signalfd that actually counts
# as delivering it (despite the mask). The problem with this is that we
# don't have any reliable way to mask out signals process-wide -- the only
# way to do that in Python is to call pthread_sigmask from the main thread
# *before starting any other threads*, and as a library we can't really
# impose that, and the failure mode is annoying (signals get delivered via
# signal handlers whether we want them to or not).
#
# - on macOS/*BSD, kqueue is the natural way. Semantics: kqueue acts as an
# *extra* signal delivery mechanism. Signals are delivered the normal
# way, *and* are delivered to kqueue. So you want to set them to SIG_IGN so
# that they don't end up pending forever (I guess?). I can't find any actual
# docs on how masking and EVFILT_SIGNAL interact. I did see someone note
# that if a signal is pending when the kqueue filter is added then you
# *don't* get notified of that, which makes sense. But still, we have to
# manipulate signal state (e.g. setting SIG_IGN) which as far as Python is
# concerned means we have to do this from the main thread.
#
# So in summary, there don't seem to be any compelling advantages to using the
# platform-native signal notification systems; they're kinda nice, but it's
# simpler to implement the naive signal-handler-based system once and be
# done. (The big advantage would be if there were a reliable way to monitor
# for SIGCHLD from outside the main thread and without interfering with other
# libraries that also want to monitor for SIGCHLD. But there isn't. I guess
# kqueue might give us that, but in kqueue we don't need it, because kqueue
# can directly monitor for child process state changes.)
@contextmanager
def _signal_handler(signals, handler):
original_handlers = {}
try:
for signum in set(signals):
original_handlers[signum] = signal.signal(signum, handler)
yield
finally:
for signum, original_handler in original_handlers.items():
signal.signal(signum, original_handler)
class SignalReceiver:
def __init__(self):
# {signal num: None}
self._pending = OrderedDict()
self._lot = trio.lowlevel.ParkingLot()
self._conflict_detector = ConflictDetector(
"only one task can iterate on a signal receiver at a time"
)
self._closed = False
def _add(self, signum):
if self._closed:
signal_raise(signum)
else:
self._pending[signum] = None
self._lot.unpark()
def _redeliver_remaining(self):
# First make sure that any signals still in the delivery pipeline will
# get redelivered
self._closed = True
# And then redeliver any that are sitting in pending. This is done
# using a weird recursive construct to make sure we process everything
# even if some of the handlers raise exceptions.
def deliver_next():
if self._pending:
signum, _ = self._pending.popitem(last=False)
try:
signal_raise(signum)
finally:
deliver_next()
deliver_next()
# Helper for tests, not public or otherwise used
def _pending_signal_count(self):
return len(self._pending)
def __aiter__(self):
return self
async def __anext__(self):
if self._closed:
raise RuntimeError("open_signal_receiver block already exited")
# In principle it would be possible to support multiple concurrent
# calls to __anext__, but doing it without race conditions is quite
# tricky, and there doesn't seem to be any point in trying.
with self._conflict_detector:
if not self._pending:
await self._lot.park()
else:
await trio.lowlevel.checkpoint()
signum, _ = self._pending.popitem(last=False)
return signum
@contextmanager
def open_signal_receiver(*signals):
"""A context manager for catching signals.
Entering this context manager starts listening for the given signals and
returns an async iterator; exiting the context manager stops listening.
The async iterator blocks until a signal arrives, and then yields it.
Note that if you leave the ``with`` block while the iterator has
unextracted signals still pending inside it, then they will be
re-delivered using Python's regular signal handling logic. This avoids a
race condition when signals arrives just before we exit the ``with``
block.
Args:
signals: the signals to listen for.
Raises:
TypeError: if no signals were provided.
RuntimeError: if you try to use this anywhere except Python's main
thread. (This is a Python limitation.)
Example:
A common convention for Unix daemons is that they should reload their
configuration when they receive a ``SIGHUP``. Here's a sketch of what
that might look like using :func:`open_signal_receiver`::
with trio.open_signal_receiver(signal.SIGHUP) as signal_aiter:
async for signum in signal_aiter:
assert signum == signal.SIGHUP
reload_configuration()
"""
if not signals:
raise TypeError("No signals were provided")
if not is_main_thread():
raise RuntimeError(
"Sorry, open_signal_receiver is only possible when running in "
"Python interpreter's main thread"
)
token = trio.lowlevel.current_trio_token()
queue = SignalReceiver()
def handler(signum, _):
token.run_sync_soon(queue._add, signum, idempotent=True)
try:
with _signal_handler(signals, handler):
yield queue
finally:
queue._redeliver_remaining()

View File

@@ -0,0 +1,787 @@
import os
import sys
import select
import socket as _stdlib_socket
from functools import wraps as _wraps
from typing import TYPE_CHECKING
import idna as _idna
import trio
from . import _core
# Usage:
#
# async with _try_sync():
# return sync_call_that_might_fail_with_exception()
# # we only get here if the sync call in fact did fail with a
# # BlockingIOError
# return await do_it_properly_with_a_check_point()
#
class _try_sync:
def __init__(self, blocking_exc_override=None):
self._blocking_exc_override = blocking_exc_override
def _is_blocking_io_error(self, exc):
if self._blocking_exc_override is None:
return isinstance(exc, BlockingIOError)
else:
return self._blocking_exc_override(exc)
async def __aenter__(self):
await trio.lowlevel.checkpoint_if_cancelled()
async def __aexit__(self, etype, value, tb):
if value is not None and self._is_blocking_io_error(value):
# Discard the exception and fall through to the code below the
# block
return True
else:
await trio.lowlevel.cancel_shielded_checkpoint()
# Let the return or exception propagate
return False
################################################################
# CONSTANTS
################################################################
try:
from socket import IPPROTO_IPV6
except ImportError:
# Before Python 3.8, Windows is missing IPPROTO_IPV6
# https://bugs.python.org/issue29515
if sys.platform == "win32": # pragma: no branch
IPPROTO_IPV6 = 41
################################################################
# Overrides
################################################################
_resolver = _core.RunVar("hostname_resolver")
_socket_factory = _core.RunVar("socket_factory")
def set_custom_hostname_resolver(hostname_resolver):
"""Set a custom hostname resolver.
By default, Trio's :func:`getaddrinfo` and :func:`getnameinfo` functions
use the standard system resolver functions. This function allows you to
customize that behavior. The main intended use case is for testing, but it
might also be useful for using third-party resolvers like `c-ares
<https://c-ares.haxx.se/>`__ (though be warned that these rarely make
perfect drop-in replacements for the system resolver). See
:class:`trio.abc.HostnameResolver` for more details.
Setting a custom hostname resolver affects all future calls to
:func:`getaddrinfo` and :func:`getnameinfo` within the enclosing call to
:func:`trio.run`. All other hostname resolution in Trio is implemented in
terms of these functions.
Generally you should call this function just once, right at the beginning
of your program.
Args:
hostname_resolver (trio.abc.HostnameResolver or None): The new custom
hostname resolver, or None to restore the default behavior.
Returns:
The previous hostname resolver (which may be None).
"""
old = _resolver.get(None)
_resolver.set(hostname_resolver)
return old
def set_custom_socket_factory(socket_factory):
"""Set a custom socket object factory.
This function allows you to replace Trio's normal socket class with a
custom class. This is very useful for testing, and probably a bad idea in
any other circumstance. See :class:`trio.abc.HostnameResolver` for more
details.
Setting a custom socket factory affects all future calls to :func:`socket`
within the enclosing call to :func:`trio.run`.
Generally you should call this function just once, right at the beginning
of your program.
Args:
socket_factory (trio.abc.SocketFactory or None): The new custom
socket factory, or None to restore the default behavior.
Returns:
The previous socket factory (which may be None).
"""
old = _socket_factory.get(None)
_socket_factory.set(socket_factory)
return old
################################################################
# getaddrinfo and friends
################################################################
_NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST | _stdlib_socket.AI_NUMERICSERV
async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
"""Look up a numeric address given a name.
Arguments and return values are identical to :func:`socket.getaddrinfo`,
except that this version is async.
Also, :func:`trio.socket.getaddrinfo` correctly uses IDNA 2008 to process
non-ASCII domain names. (:func:`socket.getaddrinfo` uses IDNA 2003, which
can give the wrong result in some cases and cause you to connect to a
different host than the one you intended; see `bpo-17305
<https://bugs.python.org/issue17305>`__.)
This function's behavior can be customized using
:func:`set_custom_hostname_resolver`.
"""
# If host and port are numeric, then getaddrinfo doesn't block and we can
# skip the whole thread thing, which seems worthwhile. So we try first
# with the _NUMERIC_ONLY flags set, and then only spawn a thread if that
# fails with EAI_NONAME:
def numeric_only_failure(exc):
return (
isinstance(exc, _stdlib_socket.gaierror)
and exc.errno == _stdlib_socket.EAI_NONAME
)
async with _try_sync(numeric_only_failure):
return _stdlib_socket.getaddrinfo(
host, port, family, type, proto, flags | _NUMERIC_ONLY
)
# That failed; it's a real hostname. We better use a thread.
#
# Also, it might be a unicode hostname, in which case we want to do our
# own encoding using the idna module, rather than letting Python do
# it. (Python will use the old IDNA 2003 standard, and possibly get the
# wrong answer - see bpo-17305). However, the idna module is picky, and
# will refuse to process some valid hostname strings, like "::1". So if
# it's already ascii, we pass it through; otherwise, we encode it to.
if isinstance(host, str):
try:
host = host.encode("ascii")
except UnicodeEncodeError:
# UTS-46 defines various normalizations; in particular, by default
# idna.encode will error out if the hostname has Capital Letters
# in it; with uts46=True it will lowercase them instead.
host = _idna.encode(host, uts46=True)
hr = _resolver.get(None)
if hr is not None:
return await hr.getaddrinfo(host, port, family, type, proto, flags)
else:
return await trio.to_thread.run_sync(
_stdlib_socket.getaddrinfo,
host,
port,
family,
type,
proto,
flags,
cancellable=True,
)
async def getnameinfo(sockaddr, flags):
"""Look up a name given a numeric address.
Arguments and return values are identical to :func:`socket.getnameinfo`,
except that this version is async.
This function's behavior can be customized using
:func:`set_custom_hostname_resolver`.
"""
hr = _resolver.get(None)
if hr is not None:
return await hr.getnameinfo(sockaddr, flags)
else:
return await trio.to_thread.run_sync(
_stdlib_socket.getnameinfo, sockaddr, flags, cancellable=True
)
async def getprotobyname(name):
"""Look up a protocol number by name. (Rarely used.)
Like :func:`socket.getprotobyname`, but async.
"""
return await trio.to_thread.run_sync(
_stdlib_socket.getprotobyname, name, cancellable=True
)
# obsolete gethostbyname etc. intentionally omitted
# likewise for create_connection (use open_tcp_stream instead)
################################################################
# Socket "constructors"
################################################################
def from_stdlib_socket(sock):
"""Convert a standard library :class:`socket.socket` object into a Trio
socket object.
"""
return _SocketType(sock)
@_wraps(_stdlib_socket.fromfd, assigned=(), updated=())
def fromfd(fd, family, type, proto=0):
"""Like :func:`socket.fromfd`, but returns a Trio socket object."""
family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fd)
return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto))
if sys.platform == "win32" or (
not TYPE_CHECKING and hasattr(_stdlib_socket, "fromshare")
):
@_wraps(_stdlib_socket.fromshare, assigned=(), updated=())
def fromshare(*args, **kwargs):
return from_stdlib_socket(_stdlib_socket.fromshare(*args, **kwargs))
@_wraps(_stdlib_socket.socketpair, assigned=(), updated=())
def socketpair(*args, **kwargs):
"""Like :func:`socket.socketpair`, but returns a pair of Trio socket
objects.
"""
left, right = _stdlib_socket.socketpair(*args, **kwargs)
return (from_stdlib_socket(left), from_stdlib_socket(right))
@_wraps(_stdlib_socket.socket, assigned=(), updated=())
def socket(
family=_stdlib_socket.AF_INET,
type=_stdlib_socket.SOCK_STREAM,
proto=0,
fileno=None,
):
"""Create a new Trio socket, like :class:`socket.socket`.
This function's behavior can be customized using
:func:`set_custom_socket_factory`.
"""
if fileno is None:
sf = _socket_factory.get(None)
if sf is not None:
return sf.socket(family, type, proto)
else:
family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fileno)
stdlib_socket = _stdlib_socket.socket(family, type, proto, fileno)
return from_stdlib_socket(stdlib_socket)
def _sniff_sockopts_for_fileno(family, type, proto, fileno):
"""Correct SOCKOPTS for given fileno, falling back to provided values."""
# Wrap the raw fileno into a Python socket object
# This object might have the wrong metadata, but it lets us easily call getsockopt
# and then we'll throw it away and construct a new one with the correct metadata.
if sys.platform != "linux":
return family, type, proto
from socket import SO_DOMAIN, SO_PROTOCOL, SOL_SOCKET, SO_TYPE
sockobj = _stdlib_socket.socket(family, type, proto, fileno=fileno)
try:
family = sockobj.getsockopt(SOL_SOCKET, SO_DOMAIN)
proto = sockobj.getsockopt(SOL_SOCKET, SO_PROTOCOL)
type = sockobj.getsockopt(SOL_SOCKET, SO_TYPE)
finally:
# Unwrap it again, so that sockobj.__del__ doesn't try to close our socket
sockobj.detach()
return family, type, proto
################################################################
# _SocketType
################################################################
# sock.type gets weird stuff set in it, in particular on Linux:
#
# https://bugs.python.org/issue21327
#
# But on other platforms (e.g. Windows) SOCK_NONBLOCK and SOCK_CLOEXEC aren't
# even defined. To recover the actual socket type (e.g. SOCK_STREAM) from a
# socket.type attribute, mask with this:
_SOCK_TYPE_MASK = ~(
getattr(_stdlib_socket, "SOCK_NONBLOCK", 0)
| getattr(_stdlib_socket, "SOCK_CLOEXEC", 0)
)
# This function will modify the given socket to match the behavior in python
# 3.7. This will become unnecessary and can be removed when support for versions
# older than 3.7 is dropped.
def real_socket_type(type_num):
return type_num & _SOCK_TYPE_MASK
def _make_simple_sock_method_wrapper(methname, wait_fn, maybe_avail=False):
fn = getattr(_stdlib_socket.socket, methname)
@_wraps(fn, assigned=("__name__",), updated=())
async def wrapper(self, *args, **kwargs):
return await self._nonblocking_helper(fn, args, kwargs, wait_fn)
wrapper.__doc__ = f"""Like :meth:`socket.socket.{methname}`, but async.
"""
if maybe_avail:
wrapper.__doc__ += (
f"Only available on platforms where :meth:`socket.socket.{methname}` is "
"available."
)
return wrapper
class SocketType:
def __init__(self):
raise TypeError(
"SocketType is an abstract class; use trio.socket.socket if you "
"want to construct a socket object"
)
class _SocketType(SocketType):
def __init__(self, sock):
if type(sock) is not _stdlib_socket.socket:
# For example, ssl.SSLSocket subclasses socket.socket, but we
# certainly don't want to blindly wrap one of those.
raise TypeError(
"expected object of type 'socket.socket', not '{}".format(
type(sock).__name__
)
)
self._sock = sock
self._sock.setblocking(False)
self._did_shutdown_SHUT_WR = False
################################################################
# Simple + portable methods and attributes
################################################################
# NB this doesn't work because for loops don't create a scope
# for _name in [
# ]:
# _meth = getattr(_stdlib_socket.socket, _name)
# @_wraps(_meth, assigned=("__name__", "__doc__"), updated=())
# def _wrapped(self, *args, **kwargs):
# return getattr(self._sock, _meth)(*args, **kwargs)
# locals()[_meth] = _wrapped
# del _name, _meth, _wrapped
_forward = {
"detach",
"get_inheritable",
"set_inheritable",
"fileno",
"getpeername",
"getsockname",
"getsockopt",
"setsockopt",
"listen",
"share",
}
def __getattr__(self, name):
if name in self._forward:
return getattr(self._sock, name)
raise AttributeError(name)
def __dir__(self):
return super().__dir__() + list(self._forward)
def __enter__(self):
return self
def __exit__(self, *exc_info):
return self._sock.__exit__(*exc_info)
@property
def family(self):
return self._sock.family
@property
def type(self):
# Modify the socket type do match what is done on python 3.7. When
# support for versions older than 3.7 is dropped, this can be updated
# to just return self._sock.type
return real_socket_type(self._sock.type)
@property
def proto(self):
return self._sock.proto
@property
def did_shutdown_SHUT_WR(self):
return self._did_shutdown_SHUT_WR
def __repr__(self):
return repr(self._sock).replace("socket.socket", "trio.socket.socket")
def dup(self):
"""Same as :meth:`socket.socket.dup`."""
return _SocketType(self._sock.dup())
def close(self):
if self._sock.fileno() != -1:
trio.lowlevel.notify_closing(self._sock)
self._sock.close()
async def bind(self, address):
address = await self._resolve_local_address_nocp(address)
if (
hasattr(_stdlib_socket, "AF_UNIX")
and self.family == _stdlib_socket.AF_UNIX
and address[0]
):
# Use a thread for the filesystem traversal (unless it's an
# abstract domain socket)
return await trio.to_thread.run_sync(self._sock.bind, address)
else:
# POSIX actually says that bind can return EWOULDBLOCK and
# complete asynchronously, like connect. But in practice AFAICT
# there aren't yet any real systems that do this, so we'll worry
# about it when it happens.
await trio.lowlevel.checkpoint()
return self._sock.bind(address)
def shutdown(self, flag):
# no need to worry about return value b/c always returns None:
self._sock.shutdown(flag)
# only do this if the call succeeded:
if flag in [_stdlib_socket.SHUT_WR, _stdlib_socket.SHUT_RDWR]:
self._did_shutdown_SHUT_WR = True
def is_readable(self):
# use select.select on Windows, and select.poll everywhere else
if sys.platform == "win32":
rready, _, _ = select.select([self._sock], [], [], 0)
return bool(rready)
p = select.poll()
p.register(self._sock, select.POLLIN)
return bool(p.poll(0))
async def wait_writable(self):
await _core.wait_writable(self._sock)
################################################################
# Address handling
################################################################
# Take an address in Python's representation, and returns a new address in
# the same representation, but with names resolved to numbers,
# etc.
#
# NOTE: this function does not always checkpoint
async def _resolve_address_nocp(self, address, flags):
# Do some pre-checking (or exit early for non-IP sockets)
if self._sock.family == _stdlib_socket.AF_INET:
if not isinstance(address, tuple) or not len(address) == 2:
raise ValueError("address should be a (host, port) tuple")
elif self._sock.family == _stdlib_socket.AF_INET6:
if not isinstance(address, tuple) or not 2 <= len(address) <= 4:
raise ValueError(
"address should be a (host, port, [flowinfo, [scopeid]]) tuple"
)
elif self._sock.family == _stdlib_socket.AF_UNIX:
# unwrap path-likes
return os.fspath(address)
else:
return address
# -- From here on we know we have IPv4 or IPV6 --
host, port, *_ = address
# Fast path for the simple case: already-resolved IP address,
# already-resolved port. This is particularly important for UDP, since
# every sendto call goes through here.
if isinstance(port, int):
try:
_stdlib_socket.inet_pton(self._sock.family, address[0])
except (OSError, TypeError):
pass
else:
return address
# Special cases to match the stdlib, see gh-277
if host == "":
host = None
if host == "<broadcast>":
host = "255.255.255.255"
# Since we always pass in an explicit family here, AI_ADDRCONFIG
# doesn't add any value -- if we have no ipv6 connectivity and are
# working with an ipv6 socket, then things will break soon enough! And
# if we do enable it, then it makes it impossible to even run tests
# for ipv6 address resolution on travis-ci, which as of 2017-03-07 has
# no ipv6.
# flags |= AI_ADDRCONFIG
if self._sock.family == _stdlib_socket.AF_INET6:
if not self._sock.getsockopt(IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY):
flags |= _stdlib_socket.AI_V4MAPPED
gai_res = await getaddrinfo(
host, port, self._sock.family, self.type, self._sock.proto, flags
)
# AFAICT from the spec it's not possible for getaddrinfo to return an
# empty list.
assert len(gai_res) >= 1
# Address is the last item in the first entry
(*_, normed), *_ = gai_res
# The above ignored any flowid and scopeid in the passed-in address,
# so restore them if present:
if self._sock.family == _stdlib_socket.AF_INET6:
normed = list(normed)
assert len(normed) == 4
if len(address) >= 3:
normed[2] = address[2]
if len(address) >= 4:
normed[3] = address[3]
normed = tuple(normed)
return normed
# Returns something appropriate to pass to bind()
#
# NOTE: this function does not always checkpoint
async def _resolve_local_address_nocp(self, address):
return await self._resolve_address_nocp(address, _stdlib_socket.AI_PASSIVE)
# Returns something appropriate to pass to connect()/sendto()/sendmsg()
#
# NOTE: this function does not always checkpoint
async def _resolve_remote_address_nocp(self, address):
return await self._resolve_address_nocp(address, 0)
async def _nonblocking_helper(self, fn, args, kwargs, wait_fn):
# We have to reconcile two conflicting goals:
# - We want to make it look like we always blocked in doing these
# operations. The obvious way is to always do an IO wait before
# calling the function.
# - But, we also want to provide the correct semantics, and part
# of that means giving correct errors. So, for example, if you
# haven't called .listen(), then .accept() raises an error
# immediately. But in this same circumstance, then on macOS, the
# socket does not register as readable. So if we block waiting
# for read *before* we call accept, then we'll be waiting
# forever instead of properly raising an error. (On Linux,
# interestingly, AFAICT a socket that can't possible read/write
# *does* count as readable/writable for select() purposes. But
# not on macOS.)
#
# So, we have to call the function once, with the appropriate
# cancellation/yielding sandwich if it succeeds, and if it gives
# BlockingIOError *then* we fall back to IO wait.
#
# XX think if this can be combined with the similar logic for IOCP
# submission...
async with _try_sync():
return fn(self._sock, *args, **kwargs)
# First attempt raised BlockingIOError:
while True:
await wait_fn(self._sock)
try:
return fn(self._sock, *args, **kwargs)
except BlockingIOError:
pass
################################################################
# accept
################################################################
_accept = _make_simple_sock_method_wrapper("accept", _core.wait_readable)
async def accept(self):
"""Like :meth:`socket.socket.accept`, but async."""
sock, addr = await self._accept()
return from_stdlib_socket(sock), addr
################################################################
# connect
################################################################
async def connect(self, address):
# nonblocking connect is weird -- you call it to start things
# off, then the socket becomes writable as a completion
# notification. This means it isn't really cancellable... we close the
# socket if cancelled, to avoid confusion.
try:
address = await self._resolve_remote_address_nocp(address)
async with _try_sync():
# An interesting puzzle: can a non-blocking connect() return EINTR
# (= raise InterruptedError)? PEP 475 specifically left this as
# the one place where it lets an InterruptedError escape instead
# of automatically retrying. This is based on the idea that EINTR
# from connect means that the connection was already started, and
# will continue in the background. For a blocking connect, this
# sort of makes sense: if it returns EINTR then the connection
# attempt is continuing in the background, and on many system you
# can't then call connect() again because there is already a
# connect happening. See:
#
# http://www.madore.org/~david/computers/connect-intr.html
#
# For a non-blocking connect, it doesn't make as much sense --
# surely the interrupt didn't happen after we successfully
# initiated the connect and are just waiting for it to complete,
# because a non-blocking connect does not wait! And the spec
# describes the interaction between EINTR/blocking connect, but
# doesn't have anything useful to say about non-blocking connect:
#
# http://pubs.opengroup.org/onlinepubs/007904975/functions/connect.html
#
# So we have a conundrum: if EINTR means that the connect() hasn't
# happened (like it does for essentially every other syscall),
# then InterruptedError should be caught and retried. If EINTR
# means that the connect() has successfully started, then
# InterruptedError should be caught and ignored. Which should we
# do?
#
# In practice, the resolution is probably that non-blocking
# connect simply never returns EINTR, so the question of how to
# handle it is moot. Someone spelunked macOS/FreeBSD and
# confirmed this is true there:
#
# https://stackoverflow.com/questions/14134440/eintr-and-non-blocking-calls
#
# and exarkun seems to think it's true in general of non-blocking
# calls:
#
# https://twistedmatrix.com/pipermail/twisted-python/2010-September/022864.html
# (and indeed, AFAICT twisted doesn't try to handle
# InterruptedError).
#
# So we don't try to catch InterruptedError. This way if it
# happens, someone will hopefully tell us, and then hopefully we
# can investigate their system to figure out what its semantics
# are.
return self._sock.connect(address)
# It raised BlockingIOError, meaning that it's started the
# connection attempt. We wait for it to complete:
await _core.wait_writable(self._sock)
except trio.Cancelled:
# We can't really cancel a connect, and the socket is in an
# indeterminate state. Better to close it so we don't get
# confused.
self._sock.close()
raise
# Okay, the connect finished, but it might have failed:
err = self._sock.getsockopt(_stdlib_socket.SOL_SOCKET, _stdlib_socket.SO_ERROR)
if err != 0:
raise OSError(err, "Error in connect: " + os.strerror(err))
################################################################
# recv
################################################################
recv = _make_simple_sock_method_wrapper("recv", _core.wait_readable)
################################################################
# recv_into
################################################################
recv_into = _make_simple_sock_method_wrapper("recv_into", _core.wait_readable)
################################################################
# recvfrom
################################################################
recvfrom = _make_simple_sock_method_wrapper("recvfrom", _core.wait_readable)
################################################################
# recvfrom_into
################################################################
recvfrom_into = _make_simple_sock_method_wrapper(
"recvfrom_into", _core.wait_readable
)
################################################################
# recvmsg
################################################################
if hasattr(_stdlib_socket.socket, "recvmsg"):
recvmsg = _make_simple_sock_method_wrapper(
"recvmsg", _core.wait_readable, maybe_avail=True
)
################################################################
# recvmsg_into
################################################################
if hasattr(_stdlib_socket.socket, "recvmsg_into"):
recvmsg_into = _make_simple_sock_method_wrapper(
"recvmsg_into", _core.wait_readable, maybe_avail=True
)
################################################################
# send
################################################################
send = _make_simple_sock_method_wrapper("send", _core.wait_writable)
################################################################
# sendto
################################################################
@_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=())
async def sendto(self, *args):
"""Similar to :meth:`socket.socket.sendto`, but async."""
# args is: data[, flags], address)
# and kwargs are not accepted
args = list(args)
args[-1] = await self._resolve_remote_address_nocp(args[-1])
return await self._nonblocking_helper(
_stdlib_socket.socket.sendto, args, {}, _core.wait_writable
)
################################################################
# sendmsg
################################################################
if sys.platform != "win32" or (
not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "sendmsg")
):
@_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=())
async def sendmsg(self, *args):
"""Similar to :meth:`socket.socket.sendmsg`, but async.
Only available on platforms where :meth:`socket.socket.sendmsg` is
available.
"""
# args is: buffers[, ancdata[, flags[, address]]]
# and kwargs are not accepted
if len(args) == 4 and args[-1] is not None:
args = list(args)
args[-1] = await self._resolve_remote_address_nocp(args[-1])
return await self._nonblocking_helper(
_stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable
)
################################################################
# sendfile
################################################################
# Not implemented yet:
# async def sendfile(self, file, offset=0, count=None):
# XX
# Intentionally omitted:
# sendall
# makefile
# setblocking/getblocking
# settimeout/gettimeout
# timeout

View File

@@ -0,0 +1,927 @@
# General theory of operation:
#
# We implement an API that closely mirrors the stdlib ssl module's blocking
# API, and we do it using the stdlib ssl module's non-blocking in-memory API.
# The stdlib non-blocking in-memory API is barely documented, and acts as a
# thin wrapper around openssl, whose documentation also leaves something to be
# desired. So here's the main things you need to know to understand the code
# in this file:
#
# We use an ssl.SSLObject, which exposes the four main I/O operations:
#
# - do_handshake: performs the initial handshake. Must be called once at the
# beginning of each connection; is a no-op once it's completed once.
#
# - write: takes some unencrypted data and attempts to send it to the remote
# peer.
# - read: attempts to decrypt and return some data from the remote peer.
#
# - unwrap: this is weirdly named; maybe it helps to realize that the thing it
# wraps is called SSL_shutdown. It sends a cryptographically signed message
# saying "I'm closing this connection now", and then waits to receive the
# same from the remote peer (unless we already received one, in which case
# it returns immediately).
#
# All of these operations read and write from some in-memory buffers called
# "BIOs", which are an opaque OpenSSL-specific object that's basically
# semantically equivalent to a Python bytearray. When they want to send some
# bytes to the remote peer, they append them to the outgoing BIO, and when
# they want to receive some bytes from the remote peer, they try to pull them
# out of the incoming BIO. "Sending" always succeeds, because the outgoing BIO
# can always be extended to hold more data. "Receiving" acts sort of like a
# non-blocking socket: it might manage to get some data immediately, or it
# might fail and need to be tried again later. We can also directly add or
# remove data from the BIOs whenever we want.
#
# Now the problem is that while these I/O operations are opaque atomic
# operations from the point of view of us calling them, under the hood they
# might require some arbitrary sequence of sends and receives from the remote
# peer. This is particularly true for do_handshake, which generally requires a
# few round trips, but it's also true for write and read, due to an evil thing
# called "renegotiation".
#
# Renegotiation is the process by which one of the peers might arbitrarily
# decide to redo the handshake at any time. Did I mention it's evil? It's
# pretty evil, and almost universally hated. The HTTP/2 spec forbids the use
# of TLS renegotiation for HTTP/2 connections. TLS 1.3 removes it from the
# protocol entirely. It's impossible to trigger a renegotiation if using
# Python's ssl module. OpenSSL's renegotiation support is pretty buggy [1].
# Nonetheless, it does get used in real life, mostly in two cases:
#
# 1) Normally in TLS 1.2 and below, when the client side of a connection wants
# to present a certificate to prove their identity, that certificate gets sent
# in plaintext. This is bad, because it means that anyone eavesdropping can
# see who's connecting it's like sending your username in plain text. Not as
# bad as sending your password in plain text, but still, pretty bad. However,
# renegotiations *are* encrypted. So as a workaround, it's not uncommon for
# systems that want to use client certificates to first do an anonymous
# handshake, and then to turn around and do a second handshake (=
# renegotiation) and this time ask for a client cert. Or sometimes this is
# done on a case-by-case basis, e.g. a web server might accept a connection,
# read the request, and then once it sees the page you're asking for it might
# stop and ask you for a certificate.
#
# 2) In principle the same TLS connection can be used for an arbitrarily long
# time, and might transmit arbitrarily large amounts of data. But this creates
# a cryptographic problem: an attacker who has access to arbitrarily large
# amounts of data that's all encrypted using the same key may eventually be
# able to use this to figure out the key. Is this a real practical problem? I
# have no idea, I'm not a cryptographer. In any case, some people worry that
# it's a problem, so their TLS libraries are designed to automatically trigger
# a renegotiation every once in a while on some sort of timer.
#
# The end result is that you might be going along, minding your own business,
# and then *bam*! a wild renegotiation appears! And you just have to cope.
#
# The reason that coping with renegotiations is difficult is that some
# unassuming "read" or "write" call might find itself unable to progress until
# it does a handshake, which remember is a process with multiple round
# trips. So read might have to send data, and write might have to receive
# data, and this might happen multiple times. And some of those attempts might
# fail because there isn't any data yet, and need to be retried. Managing all
# this is pretty complicated.
#
# Here's how openssl (and thus the stdlib ssl module) handle this. All of the
# I/O operations above follow the same rules. When you call one of them:
#
# - it might write some data to the outgoing BIO
# - it might read some data from the incoming BIO
# - it might raise SSLWantReadError if it can't complete without reading more
# data from the incoming BIO. This is important: the "read" in ReadError
# refers to reading from the *underlying* stream.
# - (and in principle it might raise SSLWantWriteError too, but that never
# happens when using memory BIOs, so never mind)
#
# If it doesn't raise an error, then the operation completed successfully
# (though we still need to take any outgoing data out of the memory buffer and
# put it onto the wire). If it *does* raise an error, then we need to retry
# *exactly that method call* later in particular, if a 'write' failed, we
# need to try again later *with the same data*, because openssl might have
# already committed some of the initial parts of our data to its output even
# though it didn't tell us that, and has remembered that the next time we call
# write it needs to skip the first 1024 bytes or whatever it is. (Well,
# technically, we're actually allowed to call 'write' again with a data buffer
# which is the same as our old one PLUS some extra stuff added onto the end,
# but in Trio that never comes up so never mind.)
#
# There are some people online who claim that once you've gotten a Want*Error
# then the *very next call* you make to openssl *must* be the same as the
# previous one. I'm pretty sure those people are wrong. In particular, it's
# okay to call write, get a WantReadError, and then call read a few times;
# it's just that *the next time you call write*, it has to be with the same
# data.
#
# One final wrinkle: we want our SSLStream to support full-duplex operation,
# i.e. it should be possible for one task to be calling send_all while another
# task is calling receive_some. But renegotiation makes this a big hassle, because
# even if SSLStream's restricts themselves to one task calling send_all and one
# task calling receive_some, those two tasks might end up both wanting to call
# send_all, or both to call receive_some at the same time *on the underlying
# stream*. So we have to do some careful locking to hide this problem from our
# users.
#
# (Renegotiation is evil.)
#
# So our basic strategy is to define a single helper method called "_retry",
# which has generic logic for dealing with SSLWantReadError, pushing data from
# the outgoing BIO to the wire, reading data from the wire to the incoming
# BIO, retrying an I/O call until it works, and synchronizing with other tasks
# that might be calling _retry concurrently. Basically it takes an SSLObject
# non-blocking in-memory method and converts it into a Trio async blocking
# method. _retry is only about 30 lines of code, but all these cases
# multiplied by concurrent calls make it extremely tricky, so there are lots
# of comments down below on the details, and a really extensive test suite in
# test_ssl.py. And now you know *why* it's so tricky, and can probably
# understand how it works.
#
# [1] https://rt.openssl.org/Ticket/Display.html?id=3712
# XX how closely should we match the stdlib API?
# - maybe suppress_ragged_eofs=False is a better default?
# - maybe check crypto folks for advice?
# - this is also interesting: https://bugs.python.org/issue8108#msg102867
# Definitely keep an eye on Cory's TLS API ideas on security-sig etc.
# XX document behavior on cancellation/error (i.e.: all is lost abandon
# stream)
# docs will need to make very clear that this is different from all the other
# cancellations in core Trio
import operator as _operator
import ssl as _stdlib_ssl
from enum import Enum as _Enum
import trio
from .abc import Stream, Listener
from ._highlevel_generic import aclose_forcefully
from . import _sync
from ._util import ConflictDetector, Final
################################################################
# SSLStream
################################################################
# Ideally, when the user calls SSLStream.receive_some() with no argument, then
# we should do exactly one call to self.transport_stream.receive_some(),
# decrypt everything we got, and return it. Unfortunately, the way openssl's
# API works, we have to pick how much data we want to allow when we call
# read(), and then it (potentially) triggers a call to
# transport_stream.receive_some(). So at the time we pick the amount of data
# to decrypt, we don't know how much data we've read. As a simple heuristic,
# we record the max amount of data returned by previous calls to
# transport_stream.receive_some(), and we use that for future calls to read().
# But what do we use for the very first call? That's what this constant sets.
#
# Note that the value passed to read() is a limit on the amount of
# *decrypted* data, but we can only see the size of the *encrypted* data
# returned by transport_stream.receive_some(). TLS adds a small amount of
# framing overhead, and TLS compression is rarely used these days because it's
# insecure. So the size of the encrypted data should be a slight over-estimate
# of the size of the decrypted data, which is exactly what we want.
#
# The specific value is not really based on anything; it might be worth tuning
# at some point. But, if you have an TCP connection with the typical 1500 byte
# MTU and an initial window of 10 (see RFC 6928), then the initial burst of
# data will be limited to ~15000 bytes (or a bit less due to IP-level framing
# overhead), so this is chosen to be larger than that.
STARTING_RECEIVE_SIZE = 16384
def _is_eof(exc):
# There appears to be a bug on Python 3.10, where SSLErrors
# aren't properly translated into SSLEOFErrors.
# This stringly-typed error check is borrowed from the AnyIO
# project.
return isinstance(exc, _stdlib_ssl.SSLEOFError) or (
hasattr(exc, "strerror") and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
)
class NeedHandshakeError(Exception):
"""Some :class:`SSLStream` methods can't return any meaningful data until
after the handshake. If you call them before the handshake, they raise
this error.
"""
class _Once:
def __init__(self, afn, *args):
self._afn = afn
self._args = args
self.started = False
self._done = _sync.Event()
async def ensure(self, *, checkpoint):
if not self.started:
self.started = True
await self._afn(*self._args)
self._done.set()
elif not checkpoint and self._done.is_set():
return
else:
await self._done.wait()
@property
def done(self):
return self._done.is_set()
_State = _Enum("_State", ["OK", "BROKEN", "CLOSED"])
class SSLStream(Stream, metaclass=Final):
r"""Encrypted communication using SSL/TLS.
:class:`SSLStream` wraps an arbitrary :class:`~trio.abc.Stream`, and
allows you to perform encrypted communication over it using the usual
:class:`~trio.abc.Stream` interface. You pass regular data to
:meth:`send_all`, then it encrypts it and sends the encrypted data on the
underlying :class:`~trio.abc.Stream`; :meth:`receive_some` takes encrypted
data out of the underlying :class:`~trio.abc.Stream` and decrypts it
before returning it.
You should read the standard library's :mod:`ssl` documentation carefully
before attempting to use this class, and probably other general
documentation on SSL/TLS as well. SSL/TLS is subtle and quick to
anger. Really. I'm not kidding.
Args:
transport_stream (~trio.abc.Stream): The stream used to transport
encrypted data. Required.
ssl_context (~ssl.SSLContext): The :class:`~ssl.SSLContext` used for
this connection. Required. Usually created by calling
:func:`ssl.create_default_context`.
server_hostname (str or None): The name of the server being connected
to. Used for `SNI
<https://en.wikipedia.org/wiki/Server_Name_Indication>`__ and for
validating the server's certificate (if hostname checking is
enabled). This is effectively mandatory for clients, and actually
mandatory if ``ssl_context.check_hostname`` is ``True``.
server_side (bool): Whether this stream is acting as a client or
server. Defaults to False, i.e. client mode.
https_compatible (bool): There are two versions of SSL/TLS commonly
encountered in the wild: the standard version, and the version used
for HTTPS (HTTP-over-SSL/TLS).
Standard-compliant SSL/TLS implementations always send a
cryptographically signed ``close_notify`` message before closing the
connection. This is important because if the underlying transport
were simply closed, then there wouldn't be any way for the other
side to know whether the connection was intentionally closed by the
peer that they negotiated a cryptographic connection to, or by some
`man-in-the-middle
<https://en.wikipedia.org/wiki/Man-in-the-middle_attack>`__ attacker
who can't manipulate the cryptographic stream, but can manipulate
the transport layer (a so-called "truncation attack").
However, this part of the standard is widely ignored by real-world
HTTPS implementations, which means that if you want to interoperate
with them, then you NEED to ignore it too.
Fortunately this isn't as bad as it sounds, because the HTTP
protocol already includes its own equivalent of ``close_notify``, so
doing this again at the SSL/TLS level is redundant. But not all
protocols do! Therefore, by default Trio implements the safer
standard-compliant version (``https_compatible=False``). But if
you're speaking HTTPS or some other protocol where
``close_notify``\s are commonly skipped, then you should set
``https_compatible=True``; with this setting, Trio will neither
expect nor send ``close_notify`` messages.
If you have code that was written to use :class:`ssl.SSLSocket` and
now you're porting it to Trio, then it may be useful to know that a
difference between :class:`SSLStream` and :class:`ssl.SSLSocket` is
that :class:`~ssl.SSLSocket` implements the
``https_compatible=True`` behavior by default.
Attributes:
transport_stream (trio.abc.Stream): The underlying transport stream
that was passed to ``__init__``. An example of when this would be
useful is if you're using :class:`SSLStream` over a
:class:`~trio.SocketStream` and want to call the
:class:`~trio.SocketStream`'s :meth:`~trio.SocketStream.setsockopt`
method.
Internally, this class is implemented using an instance of
:class:`ssl.SSLObject`, and all of :class:`~ssl.SSLObject`'s methods and
attributes are re-exported as methods and attributes on this class.
However, there is one difference: :class:`~ssl.SSLObject` has several
methods that return information about the encrypted connection, like
:meth:`~ssl.SSLSocket.cipher` or
:meth:`~ssl.SSLSocket.selected_alpn_protocol`. If you call them before the
handshake, when they can't possibly return useful data, then
:class:`ssl.SSLObject` returns None, but :class:`trio.SSLStream`
raises :exc:`NeedHandshakeError`.
This also means that if you register a SNI callback using
`~ssl.SSLContext.sni_callback`, then the first argument your callback
receives will be a :class:`ssl.SSLObject`.
"""
# Note: any new arguments here should likely also be added to
# SSLListener.__init__, and maybe the open_ssl_over_tcp_* helpers.
def __init__(
self,
transport_stream,
ssl_context,
*,
server_hostname=None,
server_side=False,
https_compatible=False,
):
self.transport_stream = transport_stream
self._state = _State.OK
self._https_compatible = https_compatible
self._outgoing = _stdlib_ssl.MemoryBIO()
self._delayed_outgoing = None
self._incoming = _stdlib_ssl.MemoryBIO()
self._ssl_object = ssl_context.wrap_bio(
self._incoming,
self._outgoing,
server_side=server_side,
server_hostname=server_hostname,
)
# Tracks whether we've already done the initial handshake
self._handshook = _Once(self._do_handshake)
# These are used to synchronize access to self.transport_stream
self._inner_send_lock = _sync.StrictFIFOLock()
self._inner_recv_count = 0
self._inner_recv_lock = _sync.Lock()
# These are used to make sure that our caller doesn't attempt to make
# multiple concurrent calls to send_all/wait_send_all_might_not_block
# or to receive_some.
self._outer_send_conflict_detector = ConflictDetector(
"another task is currently sending data on this SSLStream"
)
self._outer_recv_conflict_detector = ConflictDetector(
"another task is currently receiving data on this SSLStream"
)
self._estimated_receive_size = STARTING_RECEIVE_SIZE
_forwarded = {
"context",
"server_side",
"server_hostname",
"session",
"session_reused",
"getpeercert",
"selected_npn_protocol",
"cipher",
"shared_ciphers",
"compression",
"pending",
"get_channel_binding",
"selected_alpn_protocol",
"version",
}
_after_handshake = {
"session_reused",
"getpeercert",
"selected_npn_protocol",
"cipher",
"shared_ciphers",
"compression",
"get_channel_binding",
"selected_alpn_protocol",
"version",
}
def __getattr__(self, name):
if name in self._forwarded:
if name in self._after_handshake and not self._handshook.done:
raise NeedHandshakeError(
"call do_handshake() before calling {!r}".format(name)
)
return getattr(self._ssl_object, name)
else:
raise AttributeError(name)
def __setattr__(self, name, value):
if name in self._forwarded:
setattr(self._ssl_object, name, value)
else:
super().__setattr__(name, value)
def __dir__(self):
return super().__dir__() + list(self._forwarded)
def _check_status(self):
if self._state is _State.OK:
return
elif self._state is _State.BROKEN:
raise trio.BrokenResourceError
elif self._state is _State.CLOSED:
raise trio.ClosedResourceError
else: # pragma: no cover
assert False
# This is probably the single trickiest function in Trio. It has lots of
# comments, though, just make sure to think carefully if you ever have to
# touch it. The big comment at the top of this file will help explain
# too.
async def _retry(self, fn, *args, ignore_want_read=False, is_handshake=False):
await trio.lowlevel.checkpoint_if_cancelled()
yielded = False
finished = False
while not finished:
# WARNING: this code needs to be very careful with when it
# calls 'await'! There might be multiple tasks calling this
# function at the same time trying to do different operations,
# so we need to be careful to:
#
# 1) interact with the SSLObject, then
# 2) await on exactly one thing that lets us make forward
# progress, then
# 3) loop or exit
#
# In particular we don't want to yield while interacting with
# the SSLObject (because it's shared state, so someone else
# might come in and mess with it while we're suspended), and
# we don't want to yield *before* starting the operation that
# will help us make progress, because then someone else might
# come in and leapfrog us.
# Call the SSLObject method, and get its result.
#
# NB: despite what the docs say, SSLWantWriteError can't
# happen "Writes to memory BIOs will always succeed if
# memory is available: that is their size can grow
# indefinitely."
# https://wiki.openssl.org/index.php/Manual:BIO_s_mem(3)
want_read = False
ret = None
try:
ret = fn(*args)
except _stdlib_ssl.SSLWantReadError:
want_read = True
except (_stdlib_ssl.SSLError, _stdlib_ssl.CertificateError) as exc:
self._state = _State.BROKEN
raise trio.BrokenResourceError from exc
else:
finished = True
if ignore_want_read:
want_read = False
finished = True
to_send = self._outgoing.read()
# Some versions of SSL_do_handshake have a bug in how they handle
# the TLS 1.3 handshake on the server side: after the handshake
# finishes, they automatically send session tickets, even though
# the client may not be expecting data to arrive at this point and
# sending it could cause a deadlock or lost data. This applies at
# least to OpenSSL 1.1.1c and earlier, and the OpenSSL devs
# currently have no plans to fix it:
#
# https://github.com/openssl/openssl/issues/7948
# https://github.com/openssl/openssl/issues/7967
#
# The correct behavior is to wait to send session tickets on the
# first call to SSL_write. (This is what BoringSSL does.) So, we
# use a heuristic to detect when OpenSSL has tried to send session
# tickets, and we manually delay sending them until the
# appropriate moment. For more discussion see:
#
# https://github.com/python-trio/trio/issues/819#issuecomment-517529763
if (
is_handshake
and not want_read
and self._ssl_object.server_side
and self._ssl_object.version() == "TLSv1.3"
):
assert self._delayed_outgoing is None
self._delayed_outgoing = to_send
to_send = b""
# Outputs from the above code block are:
#
# - to_send: bytestring; if non-empty then we need to send
# this data to make forward progress
#
# - want_read: True if we need to receive_some some data to make
# forward progress
#
# - finished: False means that we need to retry the call to
# fn(*args) again, after having pushed things forward. True
# means we still need to do whatever was said (in particular
# send any data in to_send), but once we do then we're
# done.
#
# - ret: the operation's return value. (Meaningless unless
# finished is True.)
#
# Invariant: want_read and finished can't both be True at the
# same time.
#
# Now we need to move things forward. There are two things we
# might have to do, and any given operation might require
# either, both, or neither to proceed:
#
# - send the data in to_send
#
# - receive_some some data and put it into the incoming BIO
#
# Our strategy is: if there's data to send, send it;
# *otherwise* if there's data to receive_some, receive_some it.
#
# If both need to happen, then we only send. Why? Well, we
# know that *right now* we have to both send and receive_some
# before the operation can complete. But as soon as we yield,
# that information becomes potentially stale e.g. while
# we're sending, some other task might go and receive_some the
# data we need and put it into the incoming BIO. And if it
# does, then we *definitely don't* want to do a receive_some
# there might not be any more data coming, and we'd deadlock!
# We could do something tricky to keep track of whether a
# receive_some happens while we're sending, but the case where
# we have to do both is very unusual (only during a
# renegotiation), so it's better to keep things simple. So we
# do just one potentially-blocking operation, then check again
# for fresh information.
#
# And we prioritize sending over receiving because, if there
# are multiple tasks that want to receive_some, then it
# doesn't matter what order they go in. But if there are
# multiple tasks that want to send, then they each have
# different data, and the data needs to get put onto the wire
# in the same order that it was retrieved from the outgoing
# BIO. So if we have data to send, that *needs* to be the
# *very* *next* *thing* we do, to make sure no-one else sneaks
# in before us. Or if we can't send immediately because
# someone else is, then we at least need to get in line
# immediately.
if to_send:
# NOTE: This relies on the lock being strict FIFO fair!
async with self._inner_send_lock:
yielded = True
try:
if self._delayed_outgoing is not None:
to_send = self._delayed_outgoing + to_send
self._delayed_outgoing = None
await self.transport_stream.send_all(to_send)
except:
# Some unknown amount of our data got sent, and we
# don't know how much. This stream is doomed.
self._state = _State.BROKEN
raise
elif want_read:
# It's possible that someone else is already blocked in
# transport_stream.receive_some. If so then we want to
# wait for them to finish, but we don't want to call
# transport_stream.receive_some again ourselves; we just
# want to loop around and check if their contribution
# helped anything. So we make a note of how many times
# some task has been through here before taking the lock,
# and if it's changed by the time we get the lock, then we
# skip calling transport_stream.receive_some and loop
# around immediately.
recv_count = self._inner_recv_count
async with self._inner_recv_lock:
yielded = True
if recv_count == self._inner_recv_count:
data = await self.transport_stream.receive_some()
if not data:
self._incoming.write_eof()
else:
self._estimated_receive_size = max(
self._estimated_receive_size, len(data)
)
self._incoming.write(data)
self._inner_recv_count += 1
if not yielded:
await trio.lowlevel.cancel_shielded_checkpoint()
return ret
async def _do_handshake(self):
try:
await self._retry(self._ssl_object.do_handshake, is_handshake=True)
except:
self._state = _State.BROKEN
raise
async def do_handshake(self):
"""Ensure that the initial handshake has completed.
The SSL protocol requires an initial handshake to exchange
certificates, select cryptographic keys, and so forth, before any
actual data can be sent or received. You don't have to call this
method; if you don't, then :class:`SSLStream` will automatically
perform the handshake as needed, the first time you try to send or
receive data. But if you want to trigger it manually for example,
because you want to look at the peer's certificate before you start
talking to them then you can call this method.
If the initial handshake is already in progress in another task, this
waits for it to complete and then returns.
If the initial handshake has already completed, this returns
immediately without doing anything (except executing a checkpoint).
.. warning:: If this method is cancelled, then it may leave the
:class:`SSLStream` in an unusable state. If this happens then any
future attempt to use the object will raise
:exc:`trio.BrokenResourceError`.
"""
self._check_status()
await self._handshook.ensure(checkpoint=True)
# Most things work if we don't explicitly force do_handshake to be called
# before calling receive_some or send_all, because openssl will
# automatically perform the handshake on the first SSL_{read,write}
# call. BUT, allowing openssl to do this will disable Python's hostname
# checking!!! See:
# https://bugs.python.org/issue30141
# So we *definitely* have to make sure that do_handshake is called
# before doing anything else.
async def receive_some(self, max_bytes=None):
"""Read some data from the underlying transport, decrypt it, and
return it.
See :meth:`trio.abc.ReceiveStream.receive_some` for details.
.. warning:: If this method is cancelled while the initial handshake
or a renegotiation are in progress, then it may leave the
:class:`SSLStream` in an unusable state. If this happens then any
future attempt to use the object will raise
:exc:`trio.BrokenResourceError`.
"""
with self._outer_recv_conflict_detector:
self._check_status()
try:
await self._handshook.ensure(checkpoint=False)
except trio.BrokenResourceError as exc:
# For some reason, EOF before handshake sometimes raises
# SSLSyscallError instead of SSLEOFError (e.g. on my linux
# laptop, but not on appveyor). Thanks openssl.
if self._https_compatible and (
isinstance(exc.__cause__, _stdlib_ssl.SSLSyscallError)
or _is_eof(exc.__cause__)
):
await trio.lowlevel.checkpoint()
return b""
else:
raise
if max_bytes is None:
# If we somehow have more data already in our pending buffer
# than the estimate receive size, bump up our size a bit for
# this read only.
max_bytes = max(self._estimated_receive_size, self._incoming.pending)
else:
max_bytes = _operator.index(max_bytes)
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
try:
return await self._retry(self._ssl_object.read, max_bytes)
except trio.BrokenResourceError as exc:
# This isn't quite equivalent to just returning b"" in the
# first place, because we still end up with self._state set to
# BROKEN. But that's actually fine, because after getting an
# EOF on TLS then the only thing you can do is close the
# stream, and closing doesn't care about the state.
if self._https_compatible and _is_eof(exc.__cause__):
await trio.lowlevel.checkpoint()
return b""
else:
raise
async def send_all(self, data):
"""Encrypt some data and then send it on the underlying transport.
See :meth:`trio.abc.SendStream.send_all` for details.
.. warning:: If this method is cancelled, then it may leave the
:class:`SSLStream` in an unusable state. If this happens then any
attempt to use the object will raise
:exc:`trio.BrokenResourceError`.
"""
with self._outer_send_conflict_detector:
self._check_status()
await self._handshook.ensure(checkpoint=False)
# SSLObject interprets write(b"") as an EOF for some reason, which
# is not what we want.
if not data:
await trio.lowlevel.checkpoint()
return
await self._retry(self._ssl_object.write, data)
async def unwrap(self):
"""Cleanly close down the SSL/TLS encryption layer, allowing the
underlying stream to be used for unencrypted communication.
You almost certainly don't need this.
Returns:
A pair ``(transport_stream, trailing_bytes)``, where
``transport_stream`` is the underlying transport stream, and
``trailing_bytes`` is a byte string. Since :class:`SSLStream`
doesn't necessarily know where the end of the encrypted data will
be, it can happen that it accidentally reads too much from the
underlying stream. ``trailing_bytes`` contains this extra data; you
should process it as if it was returned from a call to
``transport_stream.receive_some(...)``.
"""
with self._outer_recv_conflict_detector, self._outer_send_conflict_detector:
self._check_status()
await self._handshook.ensure(checkpoint=False)
await self._retry(self._ssl_object.unwrap)
transport_stream = self.transport_stream
self.transport_stream = None
self._state = _State.CLOSED
return (transport_stream, self._incoming.read())
async def aclose(self):
"""Gracefully shut down this connection, and close the underlying
transport.
If ``https_compatible`` is False (the default), then this attempts to
first send a ``close_notify`` and then close the underlying stream by
calling its :meth:`~trio.abc.AsyncResource.aclose` method.
If ``https_compatible`` is set to True, then this simply closes the
underlying stream and marks this stream as closed.
"""
if self._state is _State.CLOSED:
await trio.lowlevel.checkpoint()
return
if self._state is _State.BROKEN or self._https_compatible:
self._state = _State.CLOSED
await self.transport_stream.aclose()
return
try:
# https_compatible=False, so we're in spec-compliant mode and have
# to send close_notify so that the other side gets a cryptographic
# assurance that we've called aclose. Of course, we can't do
# anything cryptographic until after we've completed the
# handshake:
await self._handshook.ensure(checkpoint=False)
# Then, we call SSL_shutdown *once*, because we want to send a
# close_notify but *not* wait for the other side to send back a
# response. In principle it would be more polite to wait for the
# other side to reply with their own close_notify. However, if
# they aren't paying attention (e.g., if they're just sending
# data and not receiving) then we will never notice our
# close_notify and we'll be waiting forever. Eventually we'll time
# out (hopefully), but it's still kind of nasty. And we can't
# require the other side to always be receiving, because (a)
# backpressure is kind of important, and (b) I bet there are
# broken TLS implementations out there that don't receive all the
# time. (Like e.g. anyone using Python ssl in synchronous mode.)
#
# The send-then-immediately-close behavior is explicitly allowed
# by the TLS specs, so we're ok on that.
#
# Subtlety: SSLObject.unwrap will immediately call it a second
# time, and the second time will raise SSLWantReadError because
# there hasn't been time for the other side to respond
# yet. (Unless they spontaneously sent a close_notify before we
# called this, and it's either already been processed or gets
# pulled out of the buffer by Python's second call.) So the way to
# do what we want is to ignore SSLWantReadError on this call.
#
# Also, because the other side might have already sent
# close_notify and closed their connection then it's possible that
# our attempt to send close_notify will raise
# BrokenResourceError. This is totally legal, and in fact can happen
# with two well-behaved Trio programs talking to each other, so we
# don't want to raise an error. So we suppress BrokenResourceError
# here. (This is safe, because literally the only thing this call
# to _retry will do is send the close_notify alert, so that's
# surely where the error comes from.)
#
# FYI in some cases this could also raise SSLSyscallError which I
# think is because SSL_shutdown is terrible. (Check out that note
# at the bottom of the man page saying that it sometimes gets
# raised spuriously.) I haven't seen this since we switched to
# immediately closing the socket, and I don't know exactly what
# conditions cause it and how to respond, so for now we're just
# letting that happen. But if you start seeing it, then hopefully
# this will give you a little head start on tracking it down,
# because whoa did this puzzle us at the 2017 PyCon sprints.
#
# Also, if someone else is blocked in send/receive, then we aren't
# going to be able to do a clean shutdown. If that happens, we'll
# just do an unclean shutdown.
try:
await self._retry(self._ssl_object.unwrap, ignore_want_read=True)
except (trio.BrokenResourceError, trio.BusyResourceError):
pass
except:
# Failure! Kill the stream and move on.
await aclose_forcefully(self.transport_stream)
raise
else:
# Success! Gracefully close the underlying stream.
await self.transport_stream.aclose()
finally:
self._state = _State.CLOSED
async def wait_send_all_might_not_block(self):
"""See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`."""
# This method's implementation is deceptively simple.
#
# First, we take the outer send lock, because of Trio's standard
# semantics that wait_send_all_might_not_block and send_all
# conflict.
with self._outer_send_conflict_detector:
self._check_status()
# Then we take the inner send lock. We know that no other tasks
# are calling self.send_all or self.wait_send_all_might_not_block,
# because we have the outer_send_lock. But! There might be another
# task calling self.receive_some -> transport_stream.send_all, in
# which case if we were to call
# transport_stream.wait_send_all_might_not_block directly we'd
# have two tasks doing write-related operations on
# transport_stream simultaneously, which is not allowed. We
# *don't* want to raise this conflict to our caller, because it's
# purely an internal affair all they did was call
# wait_send_all_might_not_block and receive_some at the same time,
# which is totally valid. And waiting for the lock is OK, because
# a call to send_all certainly wouldn't complete while the other
# task holds the lock.
async with self._inner_send_lock:
# Now we have the lock, which creates another potential
# problem: what if a call to self.receive_some attempts to do
# transport_stream.send_all now? It'll have to wait for us to
# finish! But that's OK, because we release the lock as soon
# as the underlying stream becomes writable, and the
# self.receive_some call wasn't going to make any progress
# until then anyway.
#
# Of course, this does mean we might return *before* the
# stream is logically writable, because immediately after we
# return self.receive_some might write some data and make it
# non-writable again. But that's OK too,
# wait_send_all_might_not_block only guarantees that it
# doesn't return late.
await self.transport_stream.wait_send_all_might_not_block()
class SSLListener(Listener[SSLStream], metaclass=Final):
"""A :class:`~trio.abc.Listener` for SSL/TLS-encrypted servers.
:class:`SSLListener` wraps around another Listener, and converts
all incoming connections to encrypted connections by wrapping them
in a :class:`SSLStream`.
Args:
transport_listener (~trio.abc.Listener): The listener whose incoming
connections will be wrapped in :class:`SSLStream`.
ssl_context (~ssl.SSLContext): The :class:`~ssl.SSLContext` that will be
used for incoming connections.
https_compatible (bool): Passed on to :class:`SSLStream`.
Attributes:
transport_listener (trio.abc.Listener): The underlying listener that was
passed to ``__init__``.
"""
def __init__(
self,
transport_listener,
ssl_context,
*,
https_compatible=False,
):
self.transport_listener = transport_listener
self._ssl_context = ssl_context
self._https_compatible = https_compatible
async def accept(self):
"""Accept the next connection and wrap it in an :class:`SSLStream`.
See :meth:`trio.abc.Listener.accept` for details.
"""
transport_stream = await self.transport_listener.accept()
return SSLStream(
transport_stream,
self._ssl_context,
server_side=True,
https_compatible=self._https_compatible,
)
async def aclose(self):
"""Close the transport listener."""
await self.transport_listener.aclose()

View File

@@ -0,0 +1,744 @@
# coding: utf-8
import os
import subprocess
import sys
from contextlib import ExitStack
from typing import Optional
from functools import partial
import warnings
from typing import TYPE_CHECKING
from ._abc import AsyncResource, SendStream, ReceiveStream
from ._core import ClosedResourceError
from ._highlevel_generic import StapledStream
from ._sync import Lock
from ._subprocess_platform import (
wait_child_exiting,
create_pipe_to_child_stdin,
create_pipe_from_child_output,
)
from ._deprecate import deprecated
from ._util import NoPublicConstructor
import trio
# Linux-specific, but has complex lifetime management stuff so we hard-code it
# here instead of hiding it behind the _subprocess_platform abstraction
can_try_pidfd_open: bool
if TYPE_CHECKING:
def pidfd_open(fd: int, flags: int) -> int:
...
from ._subprocess_platform import ClosableReceiveStream, ClosableSendStream
else:
can_try_pidfd_open = True
try:
from os import pidfd_open
except ImportError:
if sys.platform == "linux":
import ctypes
_cdll_for_pidfd_open = ctypes.CDLL(None, use_errno=True)
_cdll_for_pidfd_open.syscall.restype = ctypes.c_long
# pid and flags are actually int-sized, but the syscall() function
# always takes longs. (Except on x32 where long is 32-bits and syscall
# takes 64-bit arguments. But in the unlikely case that anyone is
# using x32, this will still work, b/c we only need to pass in 32 bits
# of data, and the C ABI doesn't distinguish between passing 32-bit vs
# 64-bit integers; our 32-bit values will get loaded into 64-bit
# registers where syscall() will find them.)
_cdll_for_pidfd_open.syscall.argtypes = [
ctypes.c_long, # syscall number
ctypes.c_long, # pid
ctypes.c_long, # flags
]
__NR_pidfd_open = 434
def pidfd_open(fd: int, flags: int) -> int:
result = _cdll_for_pidfd_open.syscall(__NR_pidfd_open, fd, flags)
if result < 0:
err = ctypes.get_errno()
raise OSError(err, os.strerror(err))
return result
else:
can_try_pidfd_open = False
class Process(AsyncResource, metaclass=NoPublicConstructor):
r"""A child process. Like :class:`subprocess.Popen`, but async.
This class has no public constructor. The most common way to get a
`Process` object is to combine `Nursery.start` with `run_process`::
process_object = await nursery.start(run_process, ...)
This way, `run_process` supervises the process and makes sure that it is
cleaned up properly, while optionally checking the return value, feeding
it input, and so on.
If you need more control for example, because you want to spawn a child
process that outlives your program then another option is to use
`trio.lowlevel.open_process`::
process_object = await trio.lowlevel.open_process(...)
Attributes:
args (str or list): The ``command`` passed at construction time,
specifying the process to execute and its arguments.
pid (int): The process ID of the child process managed by this object.
stdin (trio.abc.SendStream or None): A stream connected to the child's
standard input stream: when you write bytes here, they become available
for the child to read. Only available if the :class:`Process`
was constructed using ``stdin=PIPE``; otherwise this will be None.
stdout (trio.abc.ReceiveStream or None): A stream connected to
the child's standard output stream: when the child writes to
standard output, the written bytes become available for you
to read here. Only available if the :class:`Process` was
constructed using ``stdout=PIPE``; otherwise this will be None.
stderr (trio.abc.ReceiveStream or None): A stream connected to
the child's standard error stream: when the child writes to
standard error, the written bytes become available for you
to read here. Only available if the :class:`Process` was
constructed using ``stderr=PIPE``; otherwise this will be None.
stdio (trio.StapledStream or None): A stream that sends data to
the child's standard input and receives from the child's standard
output. Only available if both :attr:`stdin` and :attr:`stdout` are
available; otherwise this will be None.
"""
universal_newlines = False
encoding = None
errors = None
# Available for the per-platform wait_child_exiting() implementations
# to stash some state; waitid platforms use this to avoid spawning
# arbitrarily many threads if wait() keeps getting cancelled.
_wait_for_exit_data = None
def __init__(self, popen, stdin, stdout, stderr):
self._proc = popen
self.stdin = stdin # type: Optional[SendStream]
self.stdout = stdout # type: Optional[ReceiveStream]
self.stderr = stderr # type: Optional[ReceiveStream]
self.stdio = None # type: Optional[StapledStream]
if self.stdin is not None and self.stdout is not None:
self.stdio = StapledStream(self.stdin, self.stdout)
self._wait_lock = Lock()
self._pidfd = None
if can_try_pidfd_open:
try:
fd = pidfd_open(self._proc.pid, 0)
except OSError:
# Well, we tried, but it didn't work (probably because we're
# running on an older kernel, or in an older sandbox, that
# hasn't been updated to support pidfd_open). We'll fall back
# on waitid instead.
pass
else:
# It worked! Wrap the raw fd up in a Python file object to
# make sure it'll get closed.
self._pidfd = open(fd)
self.args = self._proc.args
self.pid = self._proc.pid
def __repr__(self):
returncode = self.returncode
if returncode is None:
status = "running with PID {}".format(self.pid)
else:
if returncode < 0:
status = "exited with signal {}".format(-returncode)
else:
status = "exited with status {}".format(returncode)
return "<trio.Process {!r}: {}>".format(self.args, status)
@property
def returncode(self):
"""The exit status of the process (an integer), or ``None`` if it's
still running.
By convention, a return code of zero indicates success. On
UNIX, negative values indicate termination due to a signal,
e.g., -11 if terminated by signal 11 (``SIGSEGV``). On
Windows, a process that exits due to a call to
:meth:`Process.terminate` will have an exit status of 1.
Unlike the standard library `subprocess.Popen.returncode`, you don't
have to call `poll` or `wait` to update this attribute; it's
automatically updated as needed, and will always give you the latest
information.
"""
result = self._proc.poll()
if result is not None:
self._close_pidfd()
return result
@deprecated(
"0.20.0",
thing="using trio.Process as an async context manager",
issue=1104,
instead="run_process or nursery.start(run_process, ...)",
)
async def __aenter__(self):
return self
@deprecated(
"0.20.0", issue=1104, instead="run_process or nursery.start(run_process, ...)"
)
async def aclose(self):
"""Close any pipes we have to the process (both input and output)
and wait for it to exit.
If cancelled, kills the process and waits for it to finish
exiting before propagating the cancellation.
"""
with trio.CancelScope(shield=True):
if self.stdin is not None:
await self.stdin.aclose()
if self.stdout is not None:
await self.stdout.aclose()
if self.stderr is not None:
await self.stderr.aclose()
try:
await self.wait()
finally:
if self._proc.returncode is None:
self.kill()
with trio.CancelScope(shield=True):
await self.wait()
def _close_pidfd(self):
if self._pidfd is not None:
trio.lowlevel.notify_closing(self._pidfd.fileno())
self._pidfd.close()
self._pidfd = None
async def wait(self):
"""Block until the process exits.
Returns:
The exit status of the process; see :attr:`returncode`.
"""
async with self._wait_lock:
if self.poll() is None:
if self._pidfd is not None:
try:
await trio.lowlevel.wait_readable(self._pidfd)
except ClosedResourceError:
# something else (probably a call to poll) already closed the
# pidfd
pass
else:
await wait_child_exiting(self)
# We have to use .wait() here, not .poll(), because on macOS
# (and maybe other systems, who knows), there's a race
# condition inside the kernel that creates a tiny window where
# kqueue reports that the process has exited, but
# waitpid(WNOHANG) can't yet reap it. So this .wait() may
# actually block for a tiny fraction of a second.
self._proc.wait()
self._close_pidfd()
assert self._proc.returncode is not None
return self._proc.returncode
def poll(self):
"""Returns the exit status of the process (an integer), or ``None`` if
it's still running.
Note that on Trio (unlike the standard library `subprocess.Popen`),
``process.poll()`` and ``process.returncode`` always give the same
result. See `returncode` for more details. This method is only
included to make it easier to port code from `subprocess`.
"""
return self.returncode
def send_signal(self, sig):
"""Send signal ``sig`` to the process.
On UNIX, ``sig`` may be any signal defined in the
:mod:`signal` module, such as ``signal.SIGINT`` or
``signal.SIGTERM``. On Windows, it may be anything accepted by
the standard library :meth:`subprocess.Popen.send_signal`.
"""
self._proc.send_signal(sig)
def terminate(self):
"""Terminate the process, politely if possible.
On UNIX, this is equivalent to
``send_signal(signal.SIGTERM)``; by convention this requests
graceful termination, but a misbehaving or buggy process might
ignore it. On Windows, :meth:`terminate` forcibly terminates the
process in the same manner as :meth:`kill`.
"""
self._proc.terminate()
def kill(self):
"""Immediately terminate the process.
On UNIX, this is equivalent to
``send_signal(signal.SIGKILL)``. On Windows, it calls
``TerminateProcess``. In both cases, the process cannot
prevent itself from being killed, but the termination will be
delivered asynchronously; use :meth:`wait` if you want to
ensure the process is actually dead before proceeding.
"""
self._proc.kill()
async def open_process(
command, *, stdin=None, stdout=None, stderr=None, **options
) -> Process:
r"""Execute a child program in a new process.
After construction, you can interact with the child process by writing data to its
`~trio.Process.stdin` stream (a `~trio.abc.SendStream`), reading data from its
`~trio.Process.stdout` and/or `~trio.Process.stderr` streams (both
`~trio.abc.ReceiveStream`\s), sending it signals using `~trio.Process.terminate`,
`~trio.Process.kill`, or `~trio.Process.send_signal`, and waiting for it to exit
using `~trio.Process.wait`. See `trio.Process` for details.
Each standard stream is only available if you specify that a pipe should be created
for it. For example, if you pass ``stdin=subprocess.PIPE``, you can write to the
`~trio.Process.stdin` stream, else `~trio.Process.stdin` will be ``None``.
Unlike `trio.run_process`, this function doesn't do any kind of automatic
management of the child process. It's up to you to implement whatever semantics you
want.
Args:
command (list or str): The command to run. Typically this is a
sequence of strings such as ``['ls', '-l', 'directory with spaces']``,
where the first element names the executable to invoke and the other
elements specify its arguments. With ``shell=True`` in the
``**options``, or on Windows, ``command`` may alternatively
be a string, which will be parsed following platform-dependent
:ref:`quoting rules <subprocess-quoting>`.
stdin: Specifies what the child process's standard input
stream should connect to: output written by the parent
(``subprocess.PIPE``), nothing (``subprocess.DEVNULL``),
or an open file (pass a file descriptor or something whose
``fileno`` method returns one). If ``stdin`` is unspecified,
the child process will have the same standard input stream
as its parent.
stdout: Like ``stdin``, but for the child process's standard output
stream.
stderr: Like ``stdin``, but for the child process's standard error
stream. An additional value ``subprocess.STDOUT`` is supported,
which causes the child's standard output and standard error
messages to be intermixed on a single standard output stream,
attached to whatever the ``stdout`` option says to attach it to.
**options: Other :ref:`general subprocess options <subprocess-options>`
are also accepted.
Returns:
A new `trio.Process` object.
Raises:
OSError: if the process spawning fails, for example because the
specified command could not be found.
"""
for key in ("universal_newlines", "text", "encoding", "errors", "bufsize"):
if options.get(key):
raise TypeError(
"trio.Process only supports communicating over "
"unbuffered byte streams; the '{}' option is not supported".format(key)
)
if os.name == "posix":
if isinstance(command, str) and not options.get("shell"):
raise TypeError(
"command must be a sequence (not a string) if shell=False "
"on UNIX systems"
)
if not isinstance(command, str) and options.get("shell"):
raise TypeError(
"command must be a string (not a sequence) if shell=True "
"on UNIX systems"
)
trio_stdin = None # type: Optional[ClosableSendStream]
trio_stdout = None # type: Optional[ClosableReceiveStream]
trio_stderr = None # type: Optional[ClosableReceiveStream]
# Close the parent's handle for each child side of a pipe; we want the child to
# have the only copy, so that when it exits we can read EOF on our side. The
# trio ends of pipes will be transferred to the Process object, which will be
# responsible for their lifetime. If process spawning fails, though, we still
# want to close them before letting the failure bubble out
with ExitStack() as always_cleanup, ExitStack() as cleanup_on_fail:
if stdin == subprocess.PIPE:
trio_stdin, stdin = create_pipe_to_child_stdin()
always_cleanup.callback(os.close, stdin)
cleanup_on_fail.callback(trio_stdin.close)
if stdout == subprocess.PIPE:
trio_stdout, stdout = create_pipe_from_child_output()
always_cleanup.callback(os.close, stdout)
cleanup_on_fail.callback(trio_stdout.close)
if stderr == subprocess.STDOUT:
# If we created a pipe for stdout, pass the same pipe for
# stderr. If stdout was some non-pipe thing (DEVNULL or a
# given FD), pass the same thing. If stdout was passed as
# None, keep stderr as STDOUT to allow subprocess to dup
# our stdout. Regardless of which of these is applicable,
# don't create a new Trio stream for stderr -- if stdout
# is piped, stderr will be intermixed on the stdout stream.
if stdout is not None:
stderr = stdout
elif stderr == subprocess.PIPE:
trio_stderr, stderr = create_pipe_from_child_output()
always_cleanup.callback(os.close, stderr)
cleanup_on_fail.callback(trio_stderr.close)
popen = await trio.to_thread.run_sync(
partial(
subprocess.Popen,
command,
stdin=stdin,
stdout=stdout,
stderr=stderr,
**options,
)
)
# We did not fail, so dismiss the stack for the trio ends
cleanup_on_fail.pop_all()
return Process._create(popen, trio_stdin, trio_stdout, trio_stderr)
async def _windows_deliver_cancel(p):
try:
p.terminate()
except OSError as exc:
warnings.warn(RuntimeWarning(f"TerminateProcess on {p!r} failed with: {exc!r}"))
async def _posix_deliver_cancel(p):
try:
p.terminate()
await trio.sleep(5)
warnings.warn(
RuntimeWarning(
f"process {p!r} ignored SIGTERM for 5 seconds. "
f"(Maybe you should pass a custom deliver_cancel?) "
f"Trying SIGKILL."
)
)
p.kill()
except OSError as exc:
warnings.warn(
RuntimeWarning(f"tried to kill process {p!r}, but failed with: {exc!r}")
)
async def run_process(
command,
*,
stdin=b"",
capture_stdout=False,
capture_stderr=False,
check=True,
deliver_cancel=None,
task_status=trio.TASK_STATUS_IGNORED,
**options,
):
"""Run ``command`` in a subprocess and wait for it to complete.
This function can be called in two different ways.
One option is a direct call, like::
completed_process_info = await trio.run_process(...)
In this case, it returns a :class:`subprocess.CompletedProcess` instance
describing the results. Use this if you want to treat a process like a
function call.
The other option is to run it as a task using `Nursery.start` the enhanced version
of `~Nursery.start_soon` that lets a task pass back a value during startup::
process = await nursery.start(trio.run_process, ...)
In this case, `~Nursery.start` returns a `Process` object that you can use
to interact with the process while it's running. Use this if you want to
treat a process like a background task.
Either way, `run_process` makes sure that the process has exited before
returning, handles cancellation, optionally checks for errors, and
provides some convenient shorthands for dealing with the child's
input/output.
**Input:** `run_process` supports all the same ``stdin=`` arguments as
`subprocess.Popen`. In addition, if you simply want to pass in some fixed
data, you can pass a plain `bytes` object, and `run_process` will take
care of setting up a pipe, feeding in the data you gave, and then sending
end-of-file. The default is ``b""``, which means that the child will receive
an empty stdin. If you want the child to instead read from the parent's
stdin, use ``stdin=None``.
**Output:** By default, any output produced by the subprocess is
passed through to the standard output and error streams of the
parent Trio process.
When calling `run_process` directly, you can capture the subprocess's output by
passing ``capture_stdout=True`` to capture the subprocess's standard output, and/or
``capture_stderr=True`` to capture its standard error. Captured data is collected up
by Trio into an in-memory buffer, and then provided as the
:attr:`~subprocess.CompletedProcess.stdout` and/or
:attr:`~subprocess.CompletedProcess.stderr` attributes of the returned
:class:`~subprocess.CompletedProcess` object. The value for any stream that was not
captured will be ``None``.
If you want to capture both stdout and stderr while keeping them
separate, pass ``capture_stdout=True, capture_stderr=True``.
If you want to capture both stdout and stderr but mixed together
in the order they were printed, use: ``capture_stdout=True, stderr=subprocess.STDOUT``.
This directs the child's stderr into its stdout, so the combined
output will be available in the `~subprocess.CompletedProcess.stdout`
attribute.
If you're using ``await nursery.start(trio.run_process, ...)`` and want to capture
the subprocess's output for further processing, then use ``stdout=subprocess.PIPE``
and then make sure to read the data out of the `Process.stdout` stream. If you want
to capture stderr separately, use ``stderr=subprocess.PIPE``. If you want to capture
both, but mixed together in the correct order, use ``stdout=subproces.PIPE,
stderr=subprocess.STDOUT``.
**Error checking:** If the subprocess exits with a nonzero status
code, indicating failure, :func:`run_process` raises a
:exc:`subprocess.CalledProcessError` exception rather than
returning normally. The captured outputs are still available as
the :attr:`~subprocess.CalledProcessError.stdout` and
:attr:`~subprocess.CalledProcessError.stderr` attributes of that
exception. To disable this behavior, so that :func:`run_process`
returns normally even if the subprocess exits abnormally, pass ``check=False``.
Note that this can make the ``capture_stdout`` and ``capture_stderr``
arguments useful even when starting `run_process` as a task: if you only
care about the output if the process fails, then you can enable capturing
and then read the output off of the `~subprocess.CalledProcessError`.
**Cancellation:** If cancelled, `run_process` sends a termination
request to the subprocess, then waits for it to fully exit. The
``deliver_cancel`` argument lets you control how the process is terminated.
.. note:: `run_process` is intentionally similar to the standard library
`subprocess.run`, but some of the defaults are different. Specifically, we
default to:
- ``check=True``, because `"errors should never pass silently / unless
explicitly silenced" <https://www.python.org/dev/peps/pep-0020/>`__.
- ``stdin=b""``, because it produces less-confusing results if a subprocess
unexpectedly tries to read from stdin.
To get the `subprocess.run` semantics, use ``check=False, stdin=None``.
Args:
command (list or str): The command to run. Typically this is a
sequence of strings such as ``['ls', '-l', 'directory with spaces']``,
where the first element names the executable to invoke and the other
elements specify its arguments. With ``shell=True`` in the
``**options``, or on Windows, ``command`` may alternatively
be a string, which will be parsed following platform-dependent
:ref:`quoting rules <subprocess-quoting>`.
stdin (:obj:`bytes`, subprocess.PIPE, file descriptor, or None): The
bytes to provide to the subprocess on its standard input stream, or
``None`` if the subprocess's standard input should come from the
same place as the parent Trio process's standard input. As is the
case with the :mod:`subprocess` module, you can also pass a file
descriptor or an object with a ``fileno()`` method, in which case
the subprocess's standard input will come from that file.
When starting `run_process` as a background task, you can also use
``stdin=subprocess.PIPE``, in which case `Process.stdin` will be a
`~trio.abc.SendStream` that you can use to send data to the child.
capture_stdout (bool): If true, capture the bytes that the subprocess
writes to its standard output stream and return them in the
`~subprocess.CompletedProcess.stdout` attribute of the returned
`subprocess.CompletedProcess` or `subprocess.CalledProcessError`.
capture_stderr (bool): If true, capture the bytes that the subprocess
writes to its standard error stream and return them in the
`~subprocess.CompletedProcess.stderr` attribute of the returned
`~subprocess.CompletedProcess` or `subprocess.CalledProcessError`.
check (bool): If false, don't validate that the subprocess exits
successfully. You should be sure to check the
``returncode`` attribute of the returned object if you pass
``check=False``, so that errors don't pass silently.
deliver_cancel (async function or None): If `run_process` is cancelled,
then it needs to kill the child process. There are multiple ways to
do this, so we let you customize it.
If you pass None (the default), then the behavior depends on the
platform:
- On Windows, Trio calls ``TerminateProcess``, which should kill the
process immediately.
- On Unix-likes, the default behavior is to send a ``SIGTERM``, wait
5 seconds, and send a ``SIGKILL``.
Alternatively, you can customize this behavior by passing in an
arbitrary async function, which will be called with the `Process`
object as an argument. For example, the default Unix behavior could
be implemented like this::
async def my_deliver_cancel(process):
process.send_signal(signal.SIGTERM)
await trio.sleep(5)
process.send_signal(signal.SIGKILL)
When the process actually exits, the ``deliver_cancel`` function
will automatically be cancelled so if the process exits after
``SIGTERM``, then we'll never reach the ``SIGKILL``.
In any case, `run_process` will always wait for the child process to
exit before raising `Cancelled`.
**options: :func:`run_process` also accepts any :ref:`general subprocess
options <subprocess-options>` and passes them on to the
:class:`~trio.Process` constructor. This includes the
``stdout`` and ``stderr`` options, which provide additional
redirection possibilities such as ``stderr=subprocess.STDOUT``,
``stdout=subprocess.DEVNULL``, or file descriptors.
Returns:
When called normally a `subprocess.CompletedProcess` instance
describing the return code and outputs.
When called via `Nursery.start` a `trio.Process` instance.
Raises:
UnicodeError: if ``stdin`` is specified as a Unicode string, rather
than bytes
ValueError: if multiple redirections are specified for the same
stream, e.g., both ``capture_stdout=True`` and
``stdout=subprocess.DEVNULL``
subprocess.CalledProcessError: if ``check=False`` is not passed
and the process exits with a nonzero exit status
OSError: if an error is encountered starting or communicating with
the process
.. note:: The child process runs in the same process group as the parent
Trio process, so a Ctrl+C will be delivered simultaneously to both
parent and child. If you don't want this behavior, consult your
platform's documentation for starting child processes in a different
process group.
"""
if isinstance(stdin, str):
raise UnicodeError("process stdin must be bytes, not str")
if task_status is trio.TASK_STATUS_IGNORED:
if stdin is subprocess.PIPE:
raise ValueError(
"stdout=subprocess.PIPE is only valid with nursery.start, "
"since that's the only way to access the pipe; use nursery.start "
"or pass the data you want to write directly"
)
if options.get("stdout") is subprocess.PIPE:
raise ValueError(
"stdout=subprocess.PIPE is only valid with nursery.start, "
"since that's the only way to access the pipe"
)
if options.get("stderr") is subprocess.PIPE:
raise ValueError(
"stderr=subprocess.PIPE is only valid with nursery.start, "
"since that's the only way to access the pipe"
)
if isinstance(stdin, (bytes, bytearray, memoryview)):
input = stdin
options["stdin"] = subprocess.PIPE
else:
# stdin should be something acceptable to Process
# (None, DEVNULL, a file descriptor, etc) and Process
# will raise if it's not
input = None
options["stdin"] = stdin
if capture_stdout:
if "stdout" in options:
raise ValueError("can't specify both stdout and capture_stdout")
options["stdout"] = subprocess.PIPE
if capture_stderr:
if "stderr" in options:
raise ValueError("can't specify both stderr and capture_stderr")
options["stderr"] = subprocess.PIPE
if deliver_cancel is None:
if os.name == "nt":
deliver_cancel = _windows_deliver_cancel
else:
assert os.name == "posix"
deliver_cancel = _posix_deliver_cancel
stdout_chunks = []
stderr_chunks = []
async def feed_input(stream):
async with stream:
try:
await stream.send_all(input)
except trio.BrokenResourceError:
pass
async def read_output(stream, chunks):
async with stream:
async for chunk in stream:
chunks.append(chunk)
async with trio.open_nursery() as nursery:
proc = await open_process(command, **options)
try:
if input is not None:
nursery.start_soon(feed_input, proc.stdin)
proc.stdin = None
proc.stdio = None
if capture_stdout:
nursery.start_soon(read_output, proc.stdout, stdout_chunks)
proc.stdout = None
proc.stdio = None
if capture_stderr:
nursery.start_soon(read_output, proc.stderr, stderr_chunks)
proc.stderr = None
task_status.started(proc)
await proc.wait()
except BaseException:
with trio.CancelScope(shield=True):
killer_cscope = trio.CancelScope(shield=True)
async def killer():
with killer_cscope:
await deliver_cancel(proc)
nursery.start_soon(killer)
await proc.wait()
killer_cscope.cancel()
raise
stdout = b"".join(stdout_chunks) if capture_stdout else None
stderr = b"".join(stderr_chunks) if capture_stderr else None
if proc.returncode and check:
raise subprocess.CalledProcessError(
proc.returncode, proc.args, output=stdout, stderr=stderr
)
else:
return subprocess.CompletedProcess(proc.args, proc.returncode, stdout, stderr)

View File

@@ -0,0 +1,122 @@
# Platform-specific subprocess bits'n'pieces.
import os
import sys
from typing import Optional, Tuple, TYPE_CHECKING
import trio
from .. import _core, _subprocess
from .._abc import SendStream, ReceiveStream
_wait_child_exiting_error: Optional[ImportError] = None
_create_child_pipe_error: Optional[ImportError] = None
if TYPE_CHECKING:
# internal types for the pipe representations used in type checking only
class ClosableSendStream(SendStream):
def close(self) -> None:
...
class ClosableReceiveStream(ReceiveStream):
def close(self) -> None:
...
# Fallback versions of the functions provided -- implementations
# per OS are imported atop these at the bottom of the module.
async def wait_child_exiting(process: "_subprocess.Process") -> None:
"""Block until the child process managed by ``process`` is exiting.
It is invalid to call this function if the process has already
been waited on; that is, ``process.returncode`` must be None.
When this function returns, it indicates that a call to
:meth:`subprocess.Popen.wait` will immediately be able to
return the process's exit status. The actual exit status is not
consumed by this call, since :class:`~subprocess.Popen` wants
to be able to do that itself.
"""
raise NotImplementedError from _wait_child_exiting_error # pragma: no cover
def create_pipe_to_child_stdin() -> Tuple["ClosableSendStream", int]:
"""Create a new pipe suitable for sending data from this
process to the standard input of a child we're about to spawn.
Returns:
A pair ``(trio_end, subprocess_end)`` where ``trio_end`` is a
:class:`~trio.abc.SendStream` and ``subprocess_end`` is
something suitable for passing as the ``stdin`` argument of
:class:`subprocess.Popen`.
"""
raise NotImplementedError from _create_child_pipe_error # pragma: no cover
def create_pipe_from_child_output() -> Tuple["ClosableReceiveStream", int]:
"""Create a new pipe suitable for receiving data into this
process from the standard output or error stream of a child
we're about to spawn.
Returns:
A pair ``(trio_end, subprocess_end)`` where ``trio_end`` is a
:class:`~trio.abc.ReceiveStream` and ``subprocess_end`` is
something suitable for passing as the ``stdin`` argument of
:class:`subprocess.Popen`.
"""
raise NotImplementedError from _create_child_pipe_error # pragma: no cover
try:
if sys.platform == "win32":
from .windows import wait_child_exiting # noqa: F811
elif sys.platform != "linux" and (TYPE_CHECKING or hasattr(_core, "wait_kevent")):
from .kqueue import wait_child_exiting # noqa: F811
else:
from .waitid import wait_child_exiting # noqa: F811
except ImportError as ex: # pragma: no cover
_wait_child_exiting_error = ex
try:
if TYPE_CHECKING:
# Not worth type checking these definitions
pass
elif os.name == "posix":
def create_pipe_to_child_stdin(): # noqa: F811
rfd, wfd = os.pipe()
return trio.lowlevel.FdStream(wfd), rfd
def create_pipe_from_child_output(): # noqa: F811
rfd, wfd = os.pipe()
return trio.lowlevel.FdStream(rfd), wfd
elif os.name == "nt":
from .._windows_pipes import PipeSendStream, PipeReceiveStream
# This isn't exported or documented, but it's also not
# underscore-prefixed, and seems kosher to use. The asyncio docs
# for 3.5 included an example that imported socketpair from
# windows_utils (before socket.socketpair existed on Windows), and
# when asyncio.windows_utils.socketpair was removed in 3.7, the
# removal was mentioned in the release notes.
from asyncio.windows_utils import pipe as windows_pipe
import msvcrt
def create_pipe_to_child_stdin(): # noqa: F811
# for stdin, we want the write end (our end) to use overlapped I/O
rh, wh = windows_pipe(overlapped=(False, True))
return PipeSendStream(wh), msvcrt.open_osfhandle(rh, os.O_RDONLY)
def create_pipe_from_child_output(): # noqa: F811
# for stdout/err, it's the read end that's overlapped
rh, wh = windows_pipe(overlapped=(True, False))
return PipeReceiveStream(rh), msvcrt.open_osfhandle(wh, 0)
else: # pragma: no cover
raise ImportError("pipes not implemented on this platform")
except ImportError as ex: # pragma: no cover
_create_child_pipe_error = ex

View File

@@ -0,0 +1,41 @@
import sys
import select
from typing import TYPE_CHECKING
from .. import _core, _subprocess
assert (sys.platform != "win32" and sys.platform != "linux") or not TYPE_CHECKING
async def wait_child_exiting(process: "_subprocess.Process") -> None:
kqueue = _core.current_kqueue()
try:
from select import KQ_NOTE_EXIT
except ImportError: # pragma: no cover
# pypy doesn't define KQ_NOTE_EXIT:
# https://bitbucket.org/pypy/pypy/issues/2921/
# I verified this value against both Darwin and FreeBSD
KQ_NOTE_EXIT = 0x80000000
make_event = lambda flags: select.kevent(
process.pid, filter=select.KQ_FILTER_PROC, flags=flags, fflags=KQ_NOTE_EXIT
)
try:
kqueue.control([make_event(select.KQ_EV_ADD | select.KQ_EV_ONESHOT)], 0)
except ProcessLookupError: # pragma: no cover
# This can supposedly happen if the process is in the process
# of exiting, and it can even be the case that kqueue says the
# process doesn't exist before waitpid(WNOHANG) says it hasn't
# exited yet. See the discussion in https://chromium.googlesource.com/
# chromium/src/base/+/master/process/kill_mac.cc .
# We haven't actually seen this error occur since we added
# locking to prevent multiple calls to wait_child_exiting()
# for the same process simultaneously, but given the explanation
# in Chromium it seems we should still keep the check.
return
def abort(_):
kqueue.control([make_event(select.KQ_EV_DELETE)], 0)
return _core.Abort.SUCCEEDED
await _core.wait_kevent(process.pid, select.KQ_FILTER_PROC, abort)

View File

@@ -0,0 +1,107 @@
import errno
import math
import os
import sys
from .. import _core, _subprocess
from .._sync import CapacityLimiter, Event
from .._threads import to_thread_run_sync
try:
from os import waitid
def sync_wait_reapable(pid):
waitid(os.P_PID, pid, os.WEXITED | os.WNOWAIT)
except ImportError:
# pypy doesn't define os.waitid so we need to pull it out ourselves
# using cffi: https://bitbucket.org/pypy/pypy/issues/2922/
import cffi
waitid_ffi = cffi.FFI()
# Believe it or not, siginfo_t starts with fields in the
# same layout on both Linux and Darwin. The Linux structure
# is bigger so that's what we use to size `pad`; while
# there are a few extra fields in there, most of it is
# true padding which would not be written by the syscall.
waitid_ffi.cdef(
"""
typedef struct siginfo_s {
int si_signo;
int si_errno;
int si_code;
int si_pid;
int si_uid;
int si_status;
int pad[26];
} siginfo_t;
int waitid(int idtype, int id, siginfo_t* result, int options);
"""
)
waitid = waitid_ffi.dlopen(None).waitid
def sync_wait_reapable(pid):
P_PID = 1
WEXITED = 0x00000004
if sys.platform == "darwin": # pragma: no cover
# waitid() is not exposed on Python on Darwin but does
# work through CFFI; note that we typically won't get
# here since Darwin also defines kqueue
WNOWAIT = 0x00000020
else:
WNOWAIT = 0x01000000
result = waitid_ffi.new("siginfo_t *")
while waitid(P_PID, pid, result, WEXITED | WNOWAIT) < 0:
got_errno = waitid_ffi.errno
if got_errno == errno.EINTR:
continue
raise OSError(got_errno, os.strerror(got_errno))
# adapted from
# https://github.com/python-trio/trio/issues/4#issuecomment-398967572
waitid_limiter = CapacityLimiter(math.inf)
async def _waitid_system_task(pid: int, event: Event) -> None:
"""Spawn a thread that waits for ``pid`` to exit, then wake any tasks
that were waiting on it.
"""
# cancellable=True: if this task is cancelled, then we abandon the
# thread to keep running waitpid in the background. Since this is
# always run as a system task, this will only happen if the whole
# call to trio.run is shutting down.
try:
await to_thread_run_sync(
sync_wait_reapable, pid, cancellable=True, limiter=waitid_limiter
)
except OSError:
# If waitid fails, waitpid will fail too, so it still makes
# sense to wake up the callers of wait_process_exiting(). The
# most likely reason for this error in practice is a child
# exiting when wait() is not possible because SIGCHLD is
# ignored.
pass
finally:
event.set()
async def wait_child_exiting(process: "_subprocess.Process") -> None:
# Logic of this function:
# - The first time we get called, we create an Event and start
# an instance of _waitid_system_task that will set the Event
# when waitid() completes. If that Event is set before
# we get cancelled, we're good.
# - Otherwise, a following call after the cancellation must
# reuse the Event created during the first call, lest we
# create an arbitrary number of threads waiting on the same
# process.
if process._wait_for_exit_data is None:
process._wait_for_exit_data = event = Event() # type: ignore
_core.spawn_system_task(_waitid_system_task, process.pid, event)
assert isinstance(process._wait_for_exit_data, Event)
await process._wait_for_exit_data.wait()

View File

@@ -0,0 +1,6 @@
from .. import _subprocess
from .._wait_for_object import WaitForSingleObject
async def wait_child_exiting(process: "_subprocess.Process") -> None:
await WaitForSingleObject(int(process._proc._handle))

View File

@@ -0,0 +1,786 @@
import math
import attr
import trio
from . import _core
from ._core import enable_ki_protection, ParkingLot
from ._util import Final
@attr.s(frozen=True)
class _EventStatistics:
tasks_waiting = attr.ib()
@attr.s(repr=False, eq=False, hash=False, slots=True)
class Event(metaclass=Final):
"""A waitable boolean value useful for inter-task synchronization,
inspired by :class:`threading.Event`.
An event object has an internal boolean flag, representing whether
the event has happened yet. The flag is initially False, and the
:meth:`wait` method waits until the flag is True. If the flag is
already True, then :meth:`wait` returns immediately. (If the event has
already happened, there's nothing to wait for.) The :meth:`set` method
sets the flag to True, and wakes up any waiters.
This behavior is useful because it helps avoid race conditions and
lost wakeups: it doesn't matter whether :meth:`set` gets called just
before or after :meth:`wait`. If you want a lower-level wakeup
primitive that doesn't have this protection, consider :class:`Condition`
or :class:`trio.lowlevel.ParkingLot`.
.. note:: Unlike `threading.Event`, `trio.Event` has no
`~threading.Event.clear` method. In Trio, once an `Event` has happened,
it cannot un-happen. If you need to represent a series of events,
consider creating a new `Event` object for each one (they're cheap!),
or other synchronization methods like :ref:`channels <channels>` or
`trio.lowlevel.ParkingLot`.
"""
_tasks = attr.ib(factory=set, init=False)
_flag = attr.ib(default=False, init=False)
def is_set(self):
"""Return the current value of the internal flag."""
return self._flag
@enable_ki_protection
def set(self):
"""Set the internal flag value to True, and wake any waiting tasks."""
if not self._flag:
self._flag = True
for task in self._tasks:
_core.reschedule(task)
self._tasks.clear()
async def wait(self):
"""Block until the internal flag value becomes True.
If it's already True, then this method returns immediately.
"""
if self._flag:
await trio.lowlevel.checkpoint()
else:
task = _core.current_task()
self._tasks.add(task)
def abort_fn(_):
self._tasks.remove(task)
return _core.Abort.SUCCEEDED
await _core.wait_task_rescheduled(abort_fn)
def statistics(self):
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``tasks_waiting``: The number of tasks blocked on this event's
:meth:`wait` method.
"""
return _EventStatistics(tasks_waiting=len(self._tasks))
def async_cm(cls):
@enable_ki_protection
async def __aenter__(self):
await self.acquire()
__aenter__.__qualname__ = cls.__qualname__ + ".__aenter__"
cls.__aenter__ = __aenter__
@enable_ki_protection
async def __aexit__(self, *args):
self.release()
__aexit__.__qualname__ = cls.__qualname__ + ".__aexit__"
cls.__aexit__ = __aexit__
return cls
@attr.s(frozen=True)
class _CapacityLimiterStatistics:
borrowed_tokens = attr.ib()
total_tokens = attr.ib()
borrowers = attr.ib()
tasks_waiting = attr.ib()
@async_cm
class CapacityLimiter(metaclass=Final):
"""An object for controlling access to a resource with limited capacity.
Sometimes you need to put a limit on how many tasks can do something at
the same time. For example, you might want to use some threads to run
multiple blocking I/O operations in parallel... but if you use too many
threads at once, then your system can become overloaded and it'll actually
make things slower. One popular solution is to impose a policy like "run
up to 40 threads at the same time, but no more". But how do you implement
a policy like this?
That's what :class:`CapacityLimiter` is for. You can think of a
:class:`CapacityLimiter` object as a sack that starts out holding some fixed
number of tokens::
limit = trio.CapacityLimiter(40)
Then tasks can come along and borrow a token out of the sack::
# Borrow a token:
async with limit:
# We are holding a token!
await perform_expensive_operation()
# Exiting the 'async with' block puts the token back into the sack
And crucially, if you try to borrow a token but the sack is empty, then
you have to wait for another task to finish what it's doing and put its
token back first before you can take it and continue.
Another way to think of it: a :class:`CapacityLimiter` is like a sofa with a
fixed number of seats, and if they're all taken then you have to wait for
someone to get up before you can sit down.
By default, :func:`trio.to_thread.run_sync` uses a
:class:`CapacityLimiter` to limit the number of threads running at once;
see `trio.to_thread.current_default_thread_limiter` for details.
If you're familiar with semaphores, then you can think of this as a
restricted semaphore that's specialized for one common use case, with
additional error checking. For a more traditional semaphore, see
:class:`Semaphore`.
.. note::
Don't confuse this with the `"leaky bucket"
<https://en.wikipedia.org/wiki/Leaky_bucket>`__ or `"token bucket"
<https://en.wikipedia.org/wiki/Token_bucket>`__ algorithms used to
limit bandwidth usage on networks. The basic idea of using tokens to
track a resource limit is similar, but this is a very simple sack where
tokens aren't automatically created or destroyed over time; they're
just borrowed and then put back.
"""
def __init__(self, total_tokens):
self._lot = ParkingLot()
self._borrowers = set()
# Maps tasks attempting to acquire -> borrower, to handle on-behalf-of
self._pending_borrowers = {}
# invoke the property setter for validation
self.total_tokens = total_tokens
assert self._total_tokens == total_tokens
def __repr__(self):
return "<trio.CapacityLimiter at {:#x}, {}/{} with {} waiting>".format(
id(self), len(self._borrowers), self._total_tokens, len(self._lot)
)
@property
def total_tokens(self):
"""The total capacity available.
You can change :attr:`total_tokens` by assigning to this attribute. If
you make it larger, then the appropriate number of waiting tasks will
be woken immediately to take the new tokens. If you decrease
total_tokens below the number of tasks that are currently using the
resource, then all current tasks will be allowed to finish as normal,
but no new tasks will be allowed in until the total number of tasks
drops below the new total_tokens.
"""
return self._total_tokens
@total_tokens.setter
def total_tokens(self, new_total_tokens):
if not isinstance(new_total_tokens, int) and new_total_tokens != math.inf:
raise TypeError("total_tokens must be an int or math.inf")
if new_total_tokens < 1:
raise ValueError("total_tokens must be >= 1")
self._total_tokens = new_total_tokens
self._wake_waiters()
def _wake_waiters(self):
available = self._total_tokens - len(self._borrowers)
for woken in self._lot.unpark(count=available):
self._borrowers.add(self._pending_borrowers.pop(woken))
@property
def borrowed_tokens(self):
"""The amount of capacity that's currently in use."""
return len(self._borrowers)
@property
def available_tokens(self):
"""The amount of capacity that's available to use."""
return self.total_tokens - self.borrowed_tokens
@enable_ki_protection
def acquire_nowait(self):
"""Borrow a token from the sack, without blocking.
Raises:
WouldBlock: if no tokens are available.
RuntimeError: if the current task already holds one of this sack's
tokens.
"""
self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task())
@enable_ki_protection
def acquire_on_behalf_of_nowait(self, borrower):
"""Borrow a token from the sack on behalf of ``borrower``, without
blocking.
Args:
borrower: A :class:`trio.lowlevel.Task` or arbitrary opaque object
used to record who is borrowing this token. This is used by
:func:`trio.to_thread.run_sync` to allow threads to "hold
tokens", with the intention in the future of using it to `allow
deadlock detection and other useful things
<https://github.com/python-trio/trio/issues/182>`__
Raises:
WouldBlock: if no tokens are available.
RuntimeError: if ``borrower`` already holds one of this sack's
tokens.
"""
if borrower in self._borrowers:
raise RuntimeError(
"this borrower is already holding one of this "
"CapacityLimiter's tokens"
)
if len(self._borrowers) < self._total_tokens and not self._lot:
self._borrowers.add(borrower)
else:
raise trio.WouldBlock
@enable_ki_protection
async def acquire(self):
"""Borrow a token from the sack, blocking if necessary.
Raises:
RuntimeError: if the current task already holds one of this sack's
tokens.
"""
await self.acquire_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
async def acquire_on_behalf_of(self, borrower):
"""Borrow a token from the sack on behalf of ``borrower``, blocking if
necessary.
Args:
borrower: A :class:`trio.lowlevel.Task` or arbitrary opaque object
used to record who is borrowing this token; see
:meth:`acquire_on_behalf_of_nowait` for details.
Raises:
RuntimeError: if ``borrower`` task already holds one of this sack's
tokens.
"""
await trio.lowlevel.checkpoint_if_cancelled()
try:
self.acquire_on_behalf_of_nowait(borrower)
except trio.WouldBlock:
task = trio.lowlevel.current_task()
self._pending_borrowers[task] = borrower
try:
await self._lot.park()
except trio.Cancelled:
self._pending_borrowers.pop(task)
raise
else:
await trio.lowlevel.cancel_shielded_checkpoint()
@enable_ki_protection
def release(self):
"""Put a token back into the sack.
Raises:
RuntimeError: if the current task has not acquired one of this
sack's tokens.
"""
self.release_on_behalf_of(trio.lowlevel.current_task())
@enable_ki_protection
def release_on_behalf_of(self, borrower):
"""Put a token back into the sack on behalf of ``borrower``.
Raises:
RuntimeError: if the given borrower has not acquired one of this
sack's tokens.
"""
if borrower not in self._borrowers:
raise RuntimeError(
"this borrower isn't holding any of this CapacityLimiter's tokens"
)
self._borrowers.remove(borrower)
self._wake_waiters()
def statistics(self):
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``borrowed_tokens``: The number of tokens currently borrowed from
the sack.
* ``total_tokens``: The total number of tokens in the sack. Usually
this will be larger than ``borrowed_tokens``, but it's possibly for
it to be smaller if :attr:`total_tokens` was recently decreased.
* ``borrowers``: A list of all tasks or other entities that currently
hold a token.
* ``tasks_waiting``: The number of tasks blocked on this
:class:`CapacityLimiter`\'s :meth:`acquire` or
:meth:`acquire_on_behalf_of` methods.
"""
return _CapacityLimiterStatistics(
borrowed_tokens=len(self._borrowers),
total_tokens=self._total_tokens,
# Use a list instead of a frozenset just in case we start to allow
# one borrower to hold multiple tokens in the future
borrowers=list(self._borrowers),
tasks_waiting=len(self._lot),
)
@async_cm
class Semaphore(metaclass=Final):
"""A `semaphore <https://en.wikipedia.org/wiki/Semaphore_(programming)>`__.
A semaphore holds an integer value, which can be incremented by
calling :meth:`release` and decremented by calling :meth:`acquire` but
the value is never allowed to drop below zero. If the value is zero, then
:meth:`acquire` will block until someone calls :meth:`release`.
If you're looking for a :class:`Semaphore` to limit the number of tasks
that can access some resource simultaneously, then consider using a
:class:`CapacityLimiter` instead.
This object's interface is similar to, but different from, that of
:class:`threading.Semaphore`.
A :class:`Semaphore` object can be used as an async context manager; it
blocks on entry but not on exit.
Args:
initial_value (int): A non-negative integer giving semaphore's initial
value.
max_value (int or None): If given, makes this a "bounded" semaphore that
raises an error if the value is about to exceed the given
``max_value``.
"""
def __init__(self, initial_value, *, max_value=None):
if not isinstance(initial_value, int):
raise TypeError("initial_value must be an int")
if initial_value < 0:
raise ValueError("initial value must be >= 0")
if max_value is not None:
if not isinstance(max_value, int):
raise TypeError("max_value must be None or an int")
if max_value < initial_value:
raise ValueError("max_values must be >= initial_value")
# Invariants:
# bool(self._lot) implies self._value == 0
# (or equivalently: self._value > 0 implies not self._lot)
self._lot = trio.lowlevel.ParkingLot()
self._value = initial_value
self._max_value = max_value
def __repr__(self):
if self._max_value is None:
max_value_str = ""
else:
max_value_str = ", max_value={}".format(self._max_value)
return "<trio.Semaphore({}{}) at {:#x}>".format(
self._value, max_value_str, id(self)
)
@property
def value(self):
"""The current value of the semaphore."""
return self._value
@property
def max_value(self):
"""The maximum allowed value. May be None to indicate no limit."""
return self._max_value
@enable_ki_protection
def acquire_nowait(self):
"""Attempt to decrement the semaphore value, without blocking.
Raises:
WouldBlock: if the value is zero.
"""
if self._value > 0:
assert not self._lot
self._value -= 1
else:
raise trio.WouldBlock
@enable_ki_protection
async def acquire(self):
"""Decrement the semaphore value, blocking if necessary to avoid
letting it drop below zero.
"""
await trio.lowlevel.checkpoint_if_cancelled()
try:
self.acquire_nowait()
except trio.WouldBlock:
await self._lot.park()
else:
await trio.lowlevel.cancel_shielded_checkpoint()
@enable_ki_protection
def release(self):
"""Increment the semaphore value, possibly waking a task blocked in
:meth:`acquire`.
Raises:
ValueError: if incrementing the value would cause it to exceed
:attr:`max_value`.
"""
if self._lot:
assert self._value == 0
self._lot.unpark(count=1)
else:
if self._max_value is not None and self._value == self._max_value:
raise ValueError("semaphore released too many times")
self._value += 1
def statistics(self):
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``tasks_waiting``: The number of tasks blocked on this semaphore's
:meth:`acquire` method.
"""
return self._lot.statistics()
@attr.s(frozen=True)
class _LockStatistics:
locked = attr.ib()
owner = attr.ib()
tasks_waiting = attr.ib()
@async_cm
@attr.s(eq=False, hash=False, repr=False)
class _LockImpl:
_lot = attr.ib(factory=ParkingLot, init=False)
_owner = attr.ib(default=None, init=False)
def __repr__(self):
if self.locked():
s1 = "locked"
s2 = " with {} waiters".format(len(self._lot))
else:
s1 = "unlocked"
s2 = ""
return "<{} {} object at {:#x}{}>".format(
s1, self.__class__.__name__, id(self), s2
)
def locked(self):
"""Check whether the lock is currently held.
Returns:
bool: True if the lock is held, False otherwise.
"""
return self._owner is not None
@enable_ki_protection
def acquire_nowait(self):
"""Attempt to acquire the lock, without blocking.
Raises:
WouldBlock: if the lock is held.
"""
task = trio.lowlevel.current_task()
if self._owner is task:
raise RuntimeError("attempt to re-acquire an already held Lock")
elif self._owner is None and not self._lot:
# No-one owns it
self._owner = task
else:
raise trio.WouldBlock
@enable_ki_protection
async def acquire(self):
"""Acquire the lock, blocking if necessary."""
await trio.lowlevel.checkpoint_if_cancelled()
try:
self.acquire_nowait()
except trio.WouldBlock:
# NOTE: it's important that the contended acquire path is just
# "_lot.park()", because that's how Condition.wait() acquires the
# lock as well.
await self._lot.park()
else:
await trio.lowlevel.cancel_shielded_checkpoint()
@enable_ki_protection
def release(self):
"""Release the lock.
Raises:
RuntimeError: if the calling task does not hold the lock.
"""
task = trio.lowlevel.current_task()
if task is not self._owner:
raise RuntimeError("can't release a Lock you don't own")
if self._lot:
(self._owner,) = self._lot.unpark(count=1)
else:
self._owner = None
def statistics(self):
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``locked``: boolean indicating whether the lock is held.
* ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock,
or None if the lock is not held.
* ``tasks_waiting``: The number of tasks blocked on this lock's
:meth:`acquire` method.
"""
return _LockStatistics(
locked=self.locked(), owner=self._owner, tasks_waiting=len(self._lot)
)
class Lock(_LockImpl, metaclass=Final):
"""A classic `mutex
<https://en.wikipedia.org/wiki/Lock_(computer_science)>`__.
This is a non-reentrant, single-owner lock. Unlike
:class:`threading.Lock`, only the owner of the lock is allowed to release
it.
A :class:`Lock` object can be used as an async context manager; it
blocks on entry but not on exit.
"""
class StrictFIFOLock(_LockImpl, metaclass=Final):
r"""A variant of :class:`Lock` where tasks are guaranteed to acquire the
lock in strict first-come-first-served order.
An example of when this is useful is if you're implementing something like
:class:`trio.SSLStream` or an HTTP/2 server using `h2
<https://hyper-h2.readthedocs.io/>`__, where you have multiple concurrent
tasks that are interacting with a shared state machine, and at
unpredictable moments the state machine requests that a chunk of data be
sent over the network. (For example, when using h2 simply reading incoming
data can occasionally `create outgoing data to send
<https://http2.github.io/http2-spec/#PING>`__.) The challenge is to make
sure that these chunks are sent in the correct order, without being
garbled.
One option would be to use a regular :class:`Lock`, and wrap it around
every interaction with the state machine::
# This approach is sometimes workable but often sub-optimal; see below
async with lock:
state_machine.do_something()
if state_machine.has_data_to_send():
await conn.sendall(state_machine.get_data_to_send())
But this can be problematic. If you're using h2 then *usually* reading
incoming data doesn't create the need to send any data, so we don't want
to force every task that tries to read from the network to sit and wait
a potentially long time for ``sendall`` to finish. And in some situations
this could even potentially cause a deadlock, if the remote peer is
waiting for you to read some data before it accepts the data you're
sending.
:class:`StrictFIFOLock` provides an alternative. We can rewrite our
example like::
# Note: no awaits between when we start using the state machine and
# when we block to take the lock!
state_machine.do_something()
if state_machine.has_data_to_send():
# Notice that we fetch the data to send out of the state machine
# *before* sleeping, so that other tasks won't see it.
chunk = state_machine.get_data_to_send()
async with strict_fifo_lock:
await conn.sendall(chunk)
First we do all our interaction with the state machine in a single
scheduling quantum (notice there are no ``await``\s in there), so it's
automatically atomic with respect to other tasks. And then if and only if
we have data to send, we get in line to send it and
:class:`StrictFIFOLock` guarantees that each task will send its data in
the same order that the state machine generated it.
Currently, :class:`StrictFIFOLock` is identical to :class:`Lock`,
but (a) this may not always be true in the future, especially if Trio ever
implements `more sophisticated scheduling policies
<https://github.com/python-trio/trio/issues/32>`__, and (b) the above code
is relying on a pretty subtle property of its lock. Using a
:class:`StrictFIFOLock` acts as an executable reminder that you're relying
on this property.
"""
@attr.s(frozen=True)
class _ConditionStatistics:
tasks_waiting = attr.ib()
lock_statistics = attr.ib()
@async_cm
class Condition(metaclass=Final):
"""A classic `condition variable
<https://en.wikipedia.org/wiki/Monitor_(synchronization)>`__, similar to
:class:`threading.Condition`.
A :class:`Condition` object can be used as an async context manager to
acquire the underlying lock; it blocks on entry but not on exit.
Args:
lock (Lock): the lock object to use. If given, must be a
:class:`trio.Lock`. If None, a new :class:`Lock` will be allocated
and used.
"""
def __init__(self, lock=None):
if lock is None:
lock = Lock()
if not type(lock) is Lock:
raise TypeError("lock must be a trio.Lock")
self._lock = lock
self._lot = trio.lowlevel.ParkingLot()
def locked(self):
"""Check whether the underlying lock is currently held.
Returns:
bool: True if the lock is held, False otherwise.
"""
return self._lock.locked()
def acquire_nowait(self):
"""Attempt to acquire the underlying lock, without blocking.
Raises:
WouldBlock: if the lock is currently held.
"""
return self._lock.acquire_nowait()
async def acquire(self):
"""Acquire the underlying lock, blocking if necessary."""
await self._lock.acquire()
def release(self):
"""Release the underlying lock."""
self._lock.release()
@enable_ki_protection
async def wait(self):
"""Wait for another task to call :meth:`notify` or
:meth:`notify_all`.
When calling this method, you must hold the lock. It releases the lock
while waiting, and then re-acquires it before waking up.
There is a subtlety with how this method interacts with cancellation:
when cancelled it will block to re-acquire the lock before raising
:exc:`Cancelled`. This may cause cancellation to be less prompt than
expected. The advantage is that it makes code like this work::
async with condition:
await condition.wait()
If we didn't re-acquire the lock before waking up, and :meth:`wait`
were cancelled here, then we'd crash in ``condition.__aexit__`` when
we tried to release the lock we no longer held.
Raises:
RuntimeError: if the calling task does not hold the lock.
"""
if trio.lowlevel.current_task() is not self._lock._owner:
raise RuntimeError("must hold the lock to wait")
self.release()
# NOTE: we go to sleep on self._lot, but we'll wake up on
# self._lock._lot. That's all that's required to acquire a Lock.
try:
await self._lot.park()
except:
with trio.CancelScope(shield=True):
await self.acquire()
raise
def notify(self, n=1):
"""Wake one or more tasks that are blocked in :meth:`wait`.
Args:
n (int): The number of tasks to wake.
Raises:
RuntimeError: if the calling task does not hold the lock.
"""
if trio.lowlevel.current_task() is not self._lock._owner:
raise RuntimeError("must hold the lock to notify")
self._lot.repark(self._lock._lot, count=n)
def notify_all(self):
"""Wake all tasks that are currently blocked in :meth:`wait`.
Raises:
RuntimeError: if the calling task does not hold the lock.
"""
if trio.lowlevel.current_task() is not self._lock._owner:
raise RuntimeError("must hold the lock to notify")
self._lot.repark_all(self._lock._lot)
def statistics(self):
r"""Return an object containing debugging information.
Currently the following fields are defined:
* ``tasks_waiting``: The number of tasks blocked on this condition's
:meth:`wait` method.
* ``lock_statistics``: The result of calling the underlying
:class:`Lock`\s :meth:`~Lock.statistics` method.
"""
return _ConditionStatistics(
tasks_waiting=len(self._lot), lock_statistics=self._lock.statistics()
)

View File

@@ -0,0 +1,372 @@
# coding: utf-8
import contextvars
import threading
import queue as stdlib_queue
import functools
from itertools import count
import attr
import inspect
import outcome
from sniffio import current_async_library_cvar
import trio
from ._sync import CapacityLimiter
from ._core import (
enable_ki_protection,
disable_ki_protection,
RunVar,
TrioToken,
start_thread_soon,
)
from ._util import coroutine_or_error
# Global due to Threading API, thread local storage for trio token
TOKEN_LOCAL = threading.local()
_limiter_local = RunVar("limiter")
# I pulled this number out of the air; it isn't based on anything. Probably we
# should make some kind of measurements to pick a good value.
DEFAULT_LIMIT = 40
_thread_counter = count()
def current_default_thread_limiter():
"""Get the default `~trio.CapacityLimiter` used by
`trio.to_thread.run_sync`.
The most common reason to call this would be if you want to modify its
:attr:`~trio.CapacityLimiter.total_tokens` attribute.
"""
try:
limiter = _limiter_local.get()
except LookupError:
limiter = CapacityLimiter(DEFAULT_LIMIT)
_limiter_local.set(limiter)
return limiter
# Eventually we might build this into a full-fledged deadlock-detection
# system; see https://github.com/python-trio/trio/issues/182
# But for now we just need an object to stand in for the thread, so we can
# keep track of who's holding the CapacityLimiter's token.
@attr.s(frozen=True, eq=False, hash=False)
class ThreadPlaceholder:
name = attr.ib()
@enable_ki_protection
async def to_thread_run_sync(sync_fn, *args, cancellable=False, limiter=None):
"""Convert a blocking operation into an async operation using a thread.
These two lines are equivalent::
sync_fn(*args)
await trio.to_thread.run_sync(sync_fn, *args)
except that if ``sync_fn`` takes a long time, then the first line will
block the Trio loop while it runs, while the second line allows other Trio
tasks to continue working while ``sync_fn`` runs. This is accomplished by
pushing the call to ``sync_fn(*args)`` off into a worker thread.
From inside the worker thread, you can get back into Trio using the
functions in `trio.from_thread`.
Args:
sync_fn: An arbitrary synchronous callable.
*args: Positional arguments to pass to sync_fn. If you need keyword
arguments, use :func:`functools.partial`.
cancellable (bool): Whether to allow cancellation of this operation. See
discussion below.
limiter (None, or CapacityLimiter-like object):
An object used to limit the number of simultaneous threads. Most
commonly this will be a `~trio.CapacityLimiter`, but it could be
anything providing compatible
:meth:`~trio.CapacityLimiter.acquire_on_behalf_of` and
:meth:`~trio.CapacityLimiter.release_on_behalf_of` methods. This
function will call ``acquire_on_behalf_of`` before starting the
thread, and ``release_on_behalf_of`` after the thread has finished.
If None (the default), uses the default `~trio.CapacityLimiter`, as
returned by :func:`current_default_thread_limiter`.
**Cancellation handling**: Cancellation is a tricky issue here, because
neither Python nor the operating systems it runs on provide any general
mechanism for cancelling an arbitrary synchronous function running in a
thread. This function will always check for cancellation on entry, before
starting the thread. But once the thread is running, there are two ways it
can handle being cancelled:
* If ``cancellable=False``, the function ignores the cancellation and
keeps going, just like if we had called ``sync_fn`` synchronously. This
is the default behavior.
* If ``cancellable=True``, then this function immediately raises
`~trio.Cancelled`. In this case **the thread keeps running in
background** we just abandon it to do whatever it's going to do, and
silently discard any return value or errors that it raises. Only use
this if you know that the operation is safe and side-effect free. (For
example: :func:`trio.socket.getaddrinfo` uses a thread with
``cancellable=True``, because it doesn't really affect anything if a
stray hostname lookup keeps running in the background.)
The ``limiter`` is only released after the thread has *actually*
finished which in the case of cancellation may be some time after this
function has returned. If :func:`trio.run` finishes before the thread
does, then the limiter release method will never be called at all.
.. warning::
You should not use this function to call long-running CPU-bound
functions! In addition to the usual GIL-related reasons why using
threads for CPU-bound work is not very effective in Python, there is an
additional problem: on CPython, `CPU-bound threads tend to "starve out"
IO-bound threads <https://bugs.python.org/issue7946>`__, so using
threads for CPU-bound work is likely to adversely affect the main
thread running Trio. If you need to do this, you're better off using a
worker process, or perhaps PyPy (which still has a GIL, but may do a
better job of fairly allocating CPU time between threads).
Returns:
Whatever ``sync_fn(*args)`` returns.
Raises:
Exception: Whatever ``sync_fn(*args)`` raises.
"""
await trio.lowlevel.checkpoint_if_cancelled()
if limiter is None:
limiter = current_default_thread_limiter()
# Holds a reference to the task that's blocked in this function waiting
# for the result or None if this function was cancelled and we should
# discard the result.
task_register = [trio.lowlevel.current_task()]
name = f"trio.to_thread.run_sync-{next(_thread_counter)}"
placeholder = ThreadPlaceholder(name)
# This function gets scheduled into the Trio run loop to deliver the
# thread's result.
def report_back_in_trio_thread_fn(result):
def do_release_then_return_result():
# release_on_behalf_of is an arbitrary user-defined method, so it
# might raise an error. If it does, we want that error to
# replace the regular return value, and if the regular return was
# already an exception then we want them to chain.
try:
return result.unwrap()
finally:
limiter.release_on_behalf_of(placeholder)
result = outcome.capture(do_release_then_return_result)
if task_register[0] is not None:
trio.lowlevel.reschedule(task_register[0], result)
current_trio_token = trio.lowlevel.current_trio_token()
def worker_fn():
current_async_library_cvar.set(None)
TOKEN_LOCAL.token = current_trio_token
try:
ret = sync_fn(*args)
if inspect.iscoroutine(ret):
# Manually close coroutine to avoid RuntimeWarnings
ret.close()
raise TypeError(
"Trio expected a sync function, but {!r} appears to be "
"asynchronous".format(getattr(sync_fn, "__qualname__", sync_fn))
)
return ret
finally:
del TOKEN_LOCAL.token
context = contextvars.copy_context()
contextvars_aware_worker_fn = functools.partial(context.run, worker_fn)
def deliver_worker_fn_result(result):
try:
current_trio_token.run_sync_soon(report_back_in_trio_thread_fn, result)
except trio.RunFinishedError:
# The entire run finished, so the task we're trying to contact is
# certainly long gone -- it must have been cancelled and abandoned
# us.
pass
await limiter.acquire_on_behalf_of(placeholder)
try:
start_thread_soon(contextvars_aware_worker_fn, deliver_worker_fn_result)
except:
limiter.release_on_behalf_of(placeholder)
raise
def abort(_):
if cancellable:
task_register[0] = None
return trio.lowlevel.Abort.SUCCEEDED
else:
return trio.lowlevel.Abort.FAILED
return await trio.lowlevel.wait_task_rescheduled(abort)
def _run_fn_as_system_task(cb, fn, *args, context, trio_token=None):
"""Helper function for from_thread.run and from_thread.run_sync.
Since this internally uses TrioToken.run_sync_soon, all warnings about
raised exceptions canceling all tasks should be noted.
"""
if trio_token and not isinstance(trio_token, TrioToken):
raise RuntimeError("Passed kwarg trio_token is not of type TrioToken")
if not trio_token:
try:
trio_token = TOKEN_LOCAL.token
except AttributeError:
raise RuntimeError(
"this thread wasn't created by Trio, pass kwarg trio_token=..."
)
# Avoid deadlock by making sure we're not called from Trio thread
try:
trio.lowlevel.current_task()
except RuntimeError:
pass
else:
raise RuntimeError("this is a blocking function; call it from a thread")
q = stdlib_queue.Queue()
trio_token.run_sync_soon(context.run, cb, q, fn, args)
return q.get().unwrap()
def from_thread_run(afn, *args, trio_token=None):
"""Run the given async function in the parent Trio thread, blocking until it
is complete.
Returns:
Whatever ``afn(*args)`` returns.
Returns or raises whatever the given function returns or raises. It
can also raise exceptions of its own:
Raises:
RunFinishedError: if the corresponding call to :func:`trio.run` has
already completed, or if the run has started its final cleanup phase
and can no longer spawn new system tasks.
Cancelled: if the corresponding call to :func:`trio.run` completes
while ``afn(*args)`` is running, then ``afn`` is likely to raise
:exc:`trio.Cancelled`, and this will propagate out into
RuntimeError: if you try calling this from inside the Trio thread,
which would otherwise cause a deadlock.
AttributeError: if no ``trio_token`` was provided, and we can't infer
one from context.
TypeError: if ``afn`` is not an asynchronous function.
**Locating a Trio Token**: There are two ways to specify which
`trio.run` loop to reenter:
- Spawn this thread from `trio.to_thread.run_sync`. Trio will
automatically capture the relevant Trio token and use it when you
want to re-enter Trio.
- Pass a keyword argument, ``trio_token`` specifying a specific
`trio.run` loop to re-enter. This is useful in case you have a
"foreign" thread, spawned using some other framework, and still want
to enter Trio.
"""
def callback(q, afn, args):
@disable_ki_protection
async def unprotected_afn():
coro = coroutine_or_error(afn, *args)
return await coro
async def await_in_trio_thread_task():
q.put_nowait(await outcome.acapture(unprotected_afn))
context = contextvars.copy_context()
try:
trio.lowlevel.spawn_system_task(
await_in_trio_thread_task, name=afn, context=context
)
except RuntimeError: # system nursery is closed
q.put_nowait(
outcome.Error(trio.RunFinishedError("system nursery is closed"))
)
context = contextvars.copy_context()
context.run(current_async_library_cvar.set, "trio")
return _run_fn_as_system_task(
callback,
afn,
*args,
context=context,
trio_token=trio_token,
)
def from_thread_run_sync(fn, *args, trio_token=None):
"""Run the given sync function in the parent Trio thread, blocking until it
is complete.
Returns:
Whatever ``fn(*args)`` returns.
Returns or raises whatever the given function returns or raises. It
can also raise exceptions of its own:
Raises:
RunFinishedError: if the corresponding call to `trio.run` has
already completed.
RuntimeError: if you try calling this from inside the Trio thread,
which would otherwise cause a deadlock.
AttributeError: if no ``trio_token`` was provided, and we can't infer
one from context.
TypeError: if ``fn`` is an async function.
**Locating a Trio Token**: There are two ways to specify which
`trio.run` loop to reenter:
- Spawn this thread from `trio.to_thread.run_sync`. Trio will
automatically capture the relevant Trio token and use it when you
want to re-enter Trio.
- Pass a keyword argument, ``trio_token`` specifying a specific
`trio.run` loop to re-enter. This is useful in case you have a
"foreign" thread, spawned using some other framework, and still want
to enter Trio.
"""
def callback(q, fn, args):
current_async_library_cvar.set("trio")
@disable_ki_protection
def unprotected_fn():
ret = fn(*args)
if inspect.iscoroutine(ret):
# Manually close coroutine to avoid RuntimeWarnings
ret.close()
raise TypeError(
"Trio expected a sync function, but {!r} appears to be "
"asynchronous".format(getattr(fn, "__qualname__", fn))
)
return ret
res = outcome.capture(unprotected_fn)
q.put_nowait(res)
context = contextvars.copy_context()
return _run_fn_as_system_task(
callback,
fn,
*args,
context=context,
trio_token=trio_token,
)

View File

@@ -0,0 +1,130 @@
from contextlib import contextmanager
import trio
def move_on_at(deadline):
"""Use as a context manager to create a cancel scope with the given
absolute deadline.
Args:
deadline (float): The deadline.
"""
return trio.CancelScope(deadline=deadline)
def move_on_after(seconds):
"""Use as a context manager to create a cancel scope whose deadline is
set to now + *seconds*.
Args:
seconds (float): The timeout.
Raises:
ValueError: if timeout is less than zero.
"""
if seconds < 0:
raise ValueError("timeout must be non-negative")
return move_on_at(trio.current_time() + seconds)
async def sleep_forever():
"""Pause execution of the current task forever (or until cancelled).
Equivalent to calling ``await sleep(math.inf)``.
"""
await trio.lowlevel.wait_task_rescheduled(lambda _: trio.lowlevel.Abort.SUCCEEDED)
async def sleep_until(deadline):
"""Pause execution of the current task until the given time.
The difference between :func:`sleep` and :func:`sleep_until` is that the
former takes a relative time and the latter takes an absolute time
according to Trio's internal clock (as returned by :func:`current_time`).
Args:
deadline (float): The time at which we should wake up again. May be in
the past, in which case this function executes a checkpoint but
does not block.
"""
with move_on_at(deadline):
await sleep_forever()
async def sleep(seconds):
"""Pause execution of the current task for the given number of seconds.
Args:
seconds (float): The number of seconds to sleep. May be zero to
insert a checkpoint without actually blocking.
Raises:
ValueError: if *seconds* is negative.
"""
if seconds < 0:
raise ValueError("duration must be non-negative")
if seconds == 0:
await trio.lowlevel.checkpoint()
else:
await sleep_until(trio.current_time() + seconds)
class TooSlowError(Exception):
"""Raised by :func:`fail_after` and :func:`fail_at` if the timeout
expires.
"""
@contextmanager
def fail_at(deadline):
"""Creates a cancel scope with the given deadline, and raises an error if it
is actually cancelled.
This function and :func:`move_on_at` are similar in that both create a
cancel scope with a given absolute deadline, and if the deadline expires
then both will cause :exc:`Cancelled` to be raised within the scope. The
difference is that when the :exc:`Cancelled` exception reaches
:func:`move_on_at`, it's caught and discarded. When it reaches
:func:`fail_at`, then it's caught and :exc:`TooSlowError` is raised in its
place.
Raises:
TooSlowError: if a :exc:`Cancelled` exception is raised in this scope
and caught by the context manager.
"""
with move_on_at(deadline) as scope:
yield scope
if scope.cancelled_caught:
raise TooSlowError
def fail_after(seconds):
"""Creates a cancel scope with the given timeout, and raises an error if
it is actually cancelled.
This function and :func:`move_on_after` are similar in that both create a
cancel scope with a given timeout, and if the timeout expires then both
will cause :exc:`Cancelled` to be raised within the scope. The difference
is that when the :exc:`Cancelled` exception reaches :func:`move_on_after`,
it's caught and discarded. When it reaches :func:`fail_after`, then it's
caught and :exc:`TooSlowError` is raised in its place.
Raises:
TooSlowError: if a :exc:`Cancelled` exception is raised in this scope
and caught by the context manager.
ValueError: if *seconds* is less than zero.
"""
if seconds < 0:
raise ValueError("timeout must be non-negative")
return fail_at(trio.current_time() + seconds)

View File

@@ -0,0 +1,192 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -`-
"""
Code generation script for class methods
to be exported as public API
"""
import argparse
import ast
import astor
import os
from pathlib import Path
import sys
from textwrap import indent
PREFIX = "_generated"
HEADER = """# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument
# fmt: off
"""
FOOTER = """# fmt: on
"""
TEMPLATE = """locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return{}GLOBAL_RUN_CONTEXT.{}.{}
except AttributeError:
raise RuntimeError("must be called from async context")
"""
def is_function(node):
"""Check if the AST node is either a function
or an async function
"""
if isinstance(node, ast.FunctionDef) or isinstance(node, ast.AsyncFunctionDef):
return True
return False
def is_public(node):
"""Check if the AST node has a _public decorator"""
if not is_function(node):
return False
for decorator in node.decorator_list:
if isinstance(decorator, ast.Name) and decorator.id == "_public":
return True
return False
def get_public_methods(tree):
"""Return a list of methods marked as public.
The function walks the given tree and extracts
all objects that are functions which are marked
public.
"""
for node in ast.walk(tree):
if is_public(node):
yield node
def create_passthrough_args(funcdef):
"""Given a function definition, create a string that represents taking all
the arguments from the function, and passing them through to another
invocation of the same function.
Example input: ast.parse("def f(a, *, b): ...")
Example output: "(a, b=b)"
"""
call_args = []
for arg in funcdef.args.args:
call_args.append(arg.arg)
if funcdef.args.vararg:
call_args.append("*" + funcdef.args.vararg.arg)
for arg in funcdef.args.kwonlyargs:
call_args.append(arg.arg + "=" + arg.arg)
if funcdef.args.kwarg:
call_args.append("**" + funcdef.args.kwarg.arg)
return "({})".format(", ".join(call_args))
def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
"""Scan the given .py file for @_public decorators, and generate wrapper
functions.
"""
generated = [HEADER]
source = astor.code_to_ast.parse_file(source_path)
for method in get_public_methods(source):
# Remove self from arguments
assert method.args.args[0].arg == "self"
del method.args.args[0]
# Remove decorators
method.decorator_list = []
# Create pass through arguments
new_args = create_passthrough_args(method)
# Remove method body without the docstring
if ast.get_docstring(method) is None:
del method.body[:]
else:
# The first entry is always the docstring
del method.body[1:]
# Create the function definition including the body
func = astor.to_source(method, indent_with=" " * 4)
# Create export function body
template = TEMPLATE.format(
" await " if isinstance(method, ast.AsyncFunctionDef) else " ",
lookup_path,
method.name + new_args,
)
# Assemble function definition arguments and body
snippet = func + indent(template, " " * 4)
# Append the snippet to the corresponding module
generated.append(snippet)
generated.append(FOOTER)
return "\n\n".join(generated)
def matches_disk_files(new_files):
for new_path, new_source in new_files.items():
if not os.path.exists(new_path):
return False
with open(new_path, "r", encoding="utf-8") as old_file:
old_source = old_file.read()
if old_source != new_source:
return False
return True
def process(sources_and_lookups, *, do_test):
new_files = {}
for source_path, lookup_path in sources_and_lookups:
print("Scanning:", source_path)
new_source = gen_public_wrappers_source(source_path, lookup_path)
dirname, basename = os.path.split(source_path)
new_path = os.path.join(dirname, PREFIX + basename)
new_files[new_path] = new_source
if do_test:
if not matches_disk_files(new_files):
print("Generated sources are outdated. Please regenerate.")
sys.exit(1)
else:
print("Generated sources are up to date.")
else:
for new_path, new_source in new_files.items():
with open(new_path, "w", encoding="utf-8") as f:
f.write(new_source)
print("Regenerated sources successfully.")
# This is in fact run in CI, but only in the formatting check job, which
# doesn't collect coverage.
def main(): # pragma: no cover
parser = argparse.ArgumentParser(
description="Generate python code for public api wrappers"
)
parser.add_argument(
"--test", "-t", action="store_true", help="test if code is still up to date"
)
parsed_args = parser.parse_args()
source_root = Path.cwd()
# Double-check we found the right directory
assert (source_root / "LICENSE").exists()
core = source_root / "trio/_core"
to_wrap = [
(core / "_run.py", "runner"),
(core / "_instrumentation.py", "runner.instruments"),
(core / "_io_windows.py", "runner.io_manager"),
(core / "_io_epoll.py", "runner.io_manager"),
(core / "_io_kqueue.py", "runner.io_manager"),
]
process(to_wrap, do_test=parsed_args.test)
if __name__ == "__main__": # pragma: no cover
main()

View File

@@ -0,0 +1,190 @@
import os
import errno
from ._abc import Stream
from ._util import ConflictDetector, Final
import trio
if os.name != "posix":
# We raise an error here rather than gating the import in lowlevel.py
# in order to keep jedi static analysis happy.
raise ImportError
# XX TODO: is this a good number? who knows... it does match the default Linux
# pipe capacity though.
DEFAULT_RECEIVE_SIZE = 65536
class _FdHolder:
# This class holds onto a raw file descriptor, in non-blocking mode, and
# is responsible for managing its lifecycle. In particular, it's
# responsible for making sure it gets closed, and also for tracking
# whether it's been closed.
#
# The way we track closure is to set the .fd field to -1, discarding the
# original value. You might think that this is a strange idea, since it
# overloads the same field to do two different things. Wouldn't it be more
# natural to have a dedicated .closed field? But that would be more
# error-prone. Fds are represented by small integers, and once an fd is
# closed, its integer value may be reused immediately. If we accidentally
# used the old fd after being closed, we might end up doing something to
# another unrelated fd that happened to get assigned the same integer
# value. By throwing away the integer value immediately, it becomes
# impossible to make this mistake we'll just get an EBADF.
#
# (This trick was copied from the stdlib socket module.)
def __init__(self, fd: int):
# make sure self.fd is always initialized to *something*, because even
# if we error out here then __del__ will run and access it.
self.fd = -1
if not isinstance(fd, int):
raise TypeError("file descriptor must be an int")
self.fd = fd
# Store original state, and ensure non-blocking mode is enabled
self._original_is_blocking = os.get_blocking(fd)
os.set_blocking(fd, False)
@property
def closed(self):
return self.fd == -1
def _raw_close(self):
# This doesn't assume it's in a Trio context, so it can be called from
# __del__. You should never call it from Trio context, because it
# skips calling notify_fd_close. But from __del__, skipping that is
# OK, because notify_fd_close just wakes up other tasks that are
# waiting on this fd, and those tasks hold a reference to this object.
# So if __del__ is being called, we know there aren't any tasks that
# need to be woken.
if self.closed:
return
fd = self.fd
self.fd = -1
os.set_blocking(fd, self._original_is_blocking)
os.close(fd)
def __del__(self):
self._raw_close()
def close(self):
if not self.closed:
trio.lowlevel.notify_closing(self.fd)
self._raw_close()
class FdStream(Stream, metaclass=Final):
"""
Represents a stream given the file descriptor to a pipe, TTY, etc.
*fd* must refer to a file that is open for reading and/or writing and
supports non-blocking I/O (pipes and TTYs will work, on-disk files probably
not). The returned stream takes ownership of the fd, so closing the stream
will close the fd too. As with `os.fdopen`, you should not directly use
an fd after you have wrapped it in a stream using this function.
To be used as a Trio stream, an open file must be placed in non-blocking
mode. Unfortunately, this impacts all I/O that goes through the
underlying open file, including I/O that uses a different
file descriptor than the one that was passed to Trio. If other threads
or processes are using file descriptors that are related through `os.dup`
or inheritance across `os.fork` to the one that Trio is using, they are
unlikely to be prepared to have non-blocking I/O semantics suddenly
thrust upon them. For example, you can use
``FdStream(os.dup(sys.stdin.fileno()))`` to obtain a stream for reading
from standard input, but it is only safe to do so with heavy caveats: your
stdin must not be shared by any other processes and you must not make any
calls to synchronous methods of `sys.stdin` until the stream returned by
`FdStream` is closed. See `issue #174
<https://github.com/python-trio/trio/issues/174>`__ for a discussion of the
challenges involved in relaxing this restriction.
Args:
fd (int): The fd to be wrapped.
Returns:
A new `FdStream` object.
"""
def __init__(self, fd: int):
self._fd_holder = _FdHolder(fd)
self._send_conflict_detector = ConflictDetector(
"another task is using this stream for send"
)
self._receive_conflict_detector = ConflictDetector(
"another task is using this stream for receive"
)
async def send_all(self, data: bytes):
with self._send_conflict_detector:
# have to check up front, because send_all(b"") on a closed pipe
# should raise
if self._fd_holder.closed:
raise trio.ClosedResourceError("file was already closed")
await trio.lowlevel.checkpoint()
length = len(data)
# adapted from the SocketStream code
with memoryview(data) as view:
sent = 0
while sent < length:
with view[sent:] as remaining:
try:
sent += os.write(self._fd_holder.fd, remaining)
except BlockingIOError:
await trio.lowlevel.wait_writable(self._fd_holder.fd)
except OSError as e:
if e.errno == errno.EBADF:
raise trio.ClosedResourceError(
"file was already closed"
) from None
else:
raise trio.BrokenResourceError from e
async def wait_send_all_might_not_block(self) -> None:
with self._send_conflict_detector:
if self._fd_holder.closed:
raise trio.ClosedResourceError("file was already closed")
try:
await trio.lowlevel.wait_writable(self._fd_holder.fd)
except BrokenPipeError as e:
# kqueue: raises EPIPE on wait_writable instead
# of sending, which is annoying
raise trio.BrokenResourceError from e
async def receive_some(self, max_bytes=None) -> bytes:
with self._receive_conflict_detector:
if max_bytes is None:
max_bytes = DEFAULT_RECEIVE_SIZE
else:
if not isinstance(max_bytes, int):
raise TypeError("max_bytes must be integer >= 1")
if max_bytes < 1:
raise ValueError("max_bytes must be integer >= 1")
await trio.lowlevel.checkpoint()
while True:
try:
data = os.read(self._fd_holder.fd, max_bytes)
except BlockingIOError:
await trio.lowlevel.wait_readable(self._fd_holder.fd)
except OSError as e:
if e.errno == errno.EBADF:
raise trio.ClosedResourceError(
"file was already closed"
) from None
else:
raise trio.BrokenResourceError from e
else:
break
return data
def close(self):
self._fd_holder.close()
async def aclose(self):
self.close()
await trio.lowlevel.checkpoint()
def fileno(self):
return self._fd_holder.fd

View File

@@ -0,0 +1,341 @@
# coding: utf-8
# Little utilities we use internally
from abc import ABCMeta
import os
import signal
import sys
import pathlib
from functools import wraps, update_wrapper
import typing as t
import threading
import collections
from async_generator import isasyncgen
import trio
# Equivalent to the C function raise(), which Python doesn't wrap
if os.name == "nt":
# On windows, os.kill exists but is really weird.
#
# If you give it CTRL_C_EVENT or CTRL_BREAK_EVENT, it tries to deliver
# those using GenerateConsoleCtrlEvent. But I found that when I tried
# to run my test normally, it would freeze waiting... unless I added
# print statements, in which case the test suddenly worked. So I guess
# these signals are only delivered if/when you access the console? I
# don't really know what was going on there. From reading the
# GenerateConsoleCtrlEvent docs I don't know how it worked at all.
#
# I later spent a bunch of time trying to make GenerateConsoleCtrlEvent
# work for creating synthetic control-C events, and... failed
# utterly. There are lots of details in the code and comments
# removed/added at this commit:
# https://github.com/python-trio/trio/commit/95843654173e3e826c34d70a90b369ba6edf2c23
#
# OTOH, if you pass os.kill any *other* signal number... then CPython
# just calls TerminateProcess (wtf).
#
# So, anyway, os.kill is not so useful for testing purposes. Instead
# we use raise():
#
# https://msdn.microsoft.com/en-us/library/dwwzkt4c.aspx
#
# Have to import cffi inside the 'if os.name' block because we don't
# depend on cffi on non-Windows platforms. (It would be easy to switch
# this to ctypes though if we ever remove the cffi dependency.)
#
# Some more information:
# https://bugs.python.org/issue26350
#
# Anyway, we use this for two things:
# - redelivering unhandled signals
# - generating synthetic signals for tests
# and for both of those purposes, 'raise' works fine.
import cffi
_ffi = cffi.FFI()
_ffi.cdef("int raise(int);")
_lib = _ffi.dlopen("api-ms-win-crt-runtime-l1-1-0.dll")
signal_raise = getattr(_lib, "raise")
else:
def signal_raise(signum):
signal.pthread_kill(threading.get_ident(), signum)
# See: #461 as to why this is needed.
# The gist is that threading.main_thread() has the capability to lie to us
# if somebody else edits the threading ident cache to replace the main
# thread; causing threading.current_thread() to return a _DummyThread,
# causing the C-c check to fail, and so on.
# Trying to use signal out of the main thread will fail, so we can then
# reliably check if this is the main thread without relying on a
# potentially modified threading.
def is_main_thread():
"""Attempt to reliably check if we are in the main thread."""
try:
signal.signal(signal.SIGINT, signal.getsignal(signal.SIGINT))
return True
except ValueError:
return False
######
# Call the function and get the coroutine object, while giving helpful
# errors for common mistakes. Returns coroutine object.
######
def coroutine_or_error(async_fn, *args):
def _return_value_looks_like_wrong_library(value):
# Returned by legacy @asyncio.coroutine functions, which includes
# a surprising proportion of asyncio builtins.
if isinstance(value, collections.abc.Generator):
return True
# The protocol for detecting an asyncio Future-like object
if getattr(value, "_asyncio_future_blocking", None) is not None:
return True
# This janky check catches tornado Futures and twisted Deferreds.
# By the time we're calling this function, we already know
# something has gone wrong, so a heuristic is pretty safe.
if value.__class__.__name__ in ("Future", "Deferred"):
return True
return False
try:
coro = async_fn(*args)
except TypeError:
# Give good error for: nursery.start_soon(trio.sleep(1))
if isinstance(async_fn, collections.abc.Coroutine):
# explicitly close coroutine to avoid RuntimeWarning
async_fn.close()
raise TypeError(
"Trio was expecting an async function, but instead it got "
"a coroutine object {async_fn!r}\n"
"\n"
"Probably you did something like:\n"
"\n"
" trio.run({async_fn.__name__}(...)) # incorrect!\n"
" nursery.start_soon({async_fn.__name__}(...)) # incorrect!\n"
"\n"
"Instead, you want (notice the parentheses!):\n"
"\n"
" trio.run({async_fn.__name__}, ...) # correct!\n"
" nursery.start_soon({async_fn.__name__}, ...) # correct!".format(
async_fn=async_fn
)
) from None
# Give good error for: nursery.start_soon(future)
if _return_value_looks_like_wrong_library(async_fn):
raise TypeError(
"Trio was expecting an async function, but instead it got "
"{!r} are you trying to use a library written for "
"asyncio/twisted/tornado or similar? That won't work "
"without some sort of compatibility shim.".format(async_fn)
) from None
raise
# We can't check iscoroutinefunction(async_fn), because that will fail
# for things like functools.partial objects wrapping an async
# function. So we have to just call it and then check whether the
# return value is a coroutine object.
if not isinstance(coro, collections.abc.Coroutine):
# Give good error for: nursery.start_soon(func_returning_future)
if _return_value_looks_like_wrong_library(coro):
raise TypeError(
"Trio got unexpected {!r} are you trying to use a "
"library written for asyncio/twisted/tornado or similar? "
"That won't work without some sort of compatibility shim.".format(coro)
)
if isasyncgen(coro):
raise TypeError(
"start_soon expected an async function but got an async "
"generator {!r}".format(coro)
)
# Give good error for: nursery.start_soon(some_sync_fn)
raise TypeError(
"Trio expected an async function, but {!r} appears to be "
"synchronous".format(getattr(async_fn, "__qualname__", async_fn))
)
return coro
class ConflictDetector:
"""Detect when two tasks are about to perform operations that would
conflict.
Use as a synchronous context manager; if two tasks enter it at the same
time then the second one raises an error. You can use it when there are
two pieces of code that *would* collide and need a lock if they ever were
called at the same time, but that should never happen.
We use this in particular for things like, making sure that two different
tasks don't call sendall simultaneously on the same stream.
"""
def __init__(self, msg):
self._msg = msg
self._held = False
def __enter__(self):
if self._held:
raise trio.BusyResourceError(self._msg)
else:
self._held = True
def __exit__(self, *args):
self._held = False
def async_wraps(cls, wrapped_cls, attr_name):
"""Similar to wraps, but for async wrappers of non-async functions."""
def decorator(func):
func.__name__ = attr_name
func.__qualname__ = ".".join((cls.__qualname__, attr_name))
func.__doc__ = """Like :meth:`~{}.{}.{}`, but async.
""".format(
wrapped_cls.__module__, wrapped_cls.__qualname__, attr_name
)
return func
return decorator
def fixup_module_metadata(module_name, namespace):
seen_ids = set()
def fix_one(qualname, name, obj):
# avoid infinite recursion (relevant when using
# typing.Generic, for example)
if id(obj) in seen_ids:
return
seen_ids.add(id(obj))
mod = getattr(obj, "__module__", None)
if mod is not None and mod.startswith("trio."):
obj.__module__ = module_name
# Modules, unlike everything else in Python, put fully-qualitied
# names into their __name__ attribute. We check for "." to avoid
# rewriting these.
if hasattr(obj, "__name__") and "." not in obj.__name__:
obj.__name__ = name
obj.__qualname__ = qualname
if isinstance(obj, type):
for attr_name, attr_value in obj.__dict__.items():
fix_one(objname + "." + attr_name, attr_name, attr_value)
for objname, obj in namespace.items():
if not objname.startswith("_"): # ignore private attributes
fix_one(objname, objname, obj)
class generic_function:
"""Decorator that makes a function indexable, to communicate
non-inferrable generic type parameters to a static type checker.
If you write::
@generic_function
def open_memory_channel(max_buffer_size: int) -> Tuple[
SendChannel[T], ReceiveChannel[T]
]: ...
it is valid at runtime to say ``open_memory_channel[bytes](5)``.
This behaves identically to ``open_memory_channel(5)`` at runtime,
and currently won't type-check without a mypy plugin or clever stubs,
but at least it becomes possible to write those.
"""
def __init__(self, fn):
update_wrapper(self, fn)
self._fn = fn
def __call__(self, *args, **kwargs):
return self._fn(*args, **kwargs)
def __getitem__(self, _):
return self
class Final(ABCMeta):
"""Metaclass that enforces a class to be final (i.e., subclass not allowed).
If a class uses this metaclass like this::
class SomeClass(metaclass=Final):
pass
The metaclass will ensure that no sub class can be created.
Raises
------
- TypeError if a sub class is created
"""
def __new__(cls, name, bases, cls_namespace):
for base in bases:
if isinstance(base, Final):
raise TypeError(
f"{base.__module__}.{base.__qualname__} does not support subclassing"
)
return super().__new__(cls, name, bases, cls_namespace)
T = t.TypeVar("T")
class NoPublicConstructor(Final):
"""Metaclass that enforces a class to be final (i.e., subclass not allowed)
and ensures a private constructor.
If a class uses this metaclass like this::
class SomeClass(metaclass=NoPublicConstructor):
pass
The metaclass will ensure that no sub class can be created, and that no instance
can be initialized.
If you try to instantiate your class (SomeClass()), a TypeError will be thrown.
Raises
------
- TypeError if a sub class or an instance is created.
"""
def __call__(cls, *args, **kwargs):
raise TypeError(
f"{cls.__module__}.{cls.__qualname__} has no public constructor"
)
def _create(cls: t.Type[T], *args: t.Any, **kwargs: t.Any) -> T:
return super().__call__(*args, **kwargs) # type: ignore
def name_asyncgen(agen):
"""Return the fully-qualified name of the async generator function
that produced the async generator iterator *agen*.
"""
if not hasattr(agen, "ag_code"): # pragma: no cover
return repr(agen)
try:
module = agen.ag_frame.f_globals["__name__"]
except (AttributeError, KeyError):
module = "<{}>".format(agen.ag_code.co_filename)
try:
qualname = agen.__qualname__
except AttributeError:
qualname = agen.ag_code.co_name
return f"{module}.{qualname}"

View File

@@ -0,0 +1,3 @@
# This file is imported from __init__.py and exec'd from setup.py
__version__ = "0.20.0"

View File

@@ -0,0 +1,62 @@
import math
from . import _timeouts
import trio
from ._core._windows_cffi import (
ffi,
kernel32,
ErrorCodes,
raise_winerror,
_handle,
)
async def WaitForSingleObject(obj):
"""Async and cancellable variant of WaitForSingleObject. Windows only.
Args:
handle: A Win32 handle, as a Python integer.
Raises:
OSError: If the handle is invalid, e.g. when it is already closed.
"""
# Allow ints or whatever we can convert to a win handle
handle = _handle(obj)
# Quick check; we might not even need to spawn a thread. The zero
# means a zero timeout; this call never blocks. We also exit here
# if the handle is already closed for some reason.
retcode = kernel32.WaitForSingleObject(handle, 0)
if retcode == ErrorCodes.WAIT_FAILED:
raise_winerror()
elif retcode != ErrorCodes.WAIT_TIMEOUT:
return
# Wait for a thread that waits for two handles: the handle plus a handle
# that we can use to cancel the thread.
cancel_handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
try:
await trio.to_thread.run_sync(
WaitForMultipleObjects_sync,
handle,
cancel_handle,
cancellable=True,
limiter=trio.CapacityLimiter(math.inf),
)
finally:
# Clean up our cancel handle. In case we get here because this task was
# cancelled, we also want to set the cancel_handle to stop the thread.
kernel32.SetEvent(cancel_handle)
kernel32.CloseHandle(cancel_handle)
def WaitForMultipleObjects_sync(*handles):
"""Wait for any of the given Windows handles to be signaled."""
n = len(handles)
handle_arr = ffi.new("HANDLE[{}]".format(n))
for i in range(n):
handle_arr[i] = handles[i]
timeout = 0xFFFFFFFF # INFINITE
retcode = kernel32.WaitForMultipleObjects(n, handle_arr, False, timeout) # blocking
if retcode == ErrorCodes.WAIT_FAILED:
raise_winerror()

View File

@@ -0,0 +1,138 @@
import sys
from typing import TYPE_CHECKING
from . import _core
from ._abc import SendStream, ReceiveStream
from ._util import ConflictDetector, Final
from ._core._windows_cffi import _handle, raise_winerror, kernel32, ffi
assert sys.platform == "win32" or not TYPE_CHECKING
# XX TODO: don't just make this up based on nothing.
DEFAULT_RECEIVE_SIZE = 65536
# See the comments on _unix_pipes._FdHolder for discussion of why we set the
# handle to -1 when it's closed.
class _HandleHolder:
def __init__(self, handle: int) -> None:
self.handle = -1
if not isinstance(handle, int):
raise TypeError("handle must be an int")
self.handle = handle
_core.register_with_iocp(self.handle)
@property
def closed(self):
return self.handle == -1
def close(self):
if self.closed:
return
handle = self.handle
self.handle = -1
if not kernel32.CloseHandle(_handle(handle)):
raise_winerror()
def __del__(self):
self.close()
class PipeSendStream(SendStream, metaclass=Final):
"""Represents a send stream over a Windows named pipe that has been
opened in OVERLAPPED mode.
"""
def __init__(self, handle: int) -> None:
self._handle_holder = _HandleHolder(handle)
self._conflict_detector = ConflictDetector(
"another task is currently using this pipe"
)
async def send_all(self, data: bytes):
with self._conflict_detector:
if self._handle_holder.closed:
raise _core.ClosedResourceError("this pipe is already closed")
if not data:
await _core.checkpoint()
return
try:
written = await _core.write_overlapped(self._handle_holder.handle, data)
except BrokenPipeError as ex:
raise _core.BrokenResourceError from ex
# By my reading of MSDN, this assert is guaranteed to pass so long
# as the pipe isn't in nonblocking mode, but... let's just
# double-check.
assert written == len(data)
async def wait_send_all_might_not_block(self) -> None:
with self._conflict_detector:
if self._handle_holder.closed:
raise _core.ClosedResourceError("This pipe is already closed")
# not implemented yet, and probably not needed
await _core.checkpoint()
def close(self):
self._handle_holder.close()
async def aclose(self):
self.close()
await _core.checkpoint()
class PipeReceiveStream(ReceiveStream, metaclass=Final):
"""Represents a receive stream over an os.pipe object."""
def __init__(self, handle: int) -> None:
self._handle_holder = _HandleHolder(handle)
self._conflict_detector = ConflictDetector(
"another task is currently using this pipe"
)
async def receive_some(self, max_bytes=None) -> bytes:
with self._conflict_detector:
if self._handle_holder.closed:
raise _core.ClosedResourceError("this pipe is already closed")
if max_bytes is None:
max_bytes = DEFAULT_RECEIVE_SIZE
else:
if not isinstance(max_bytes, int):
raise TypeError("max_bytes must be integer >= 1")
if max_bytes < 1:
raise ValueError("max_bytes must be integer >= 1")
buffer = bytearray(max_bytes)
try:
size = await _core.readinto_overlapped(
self._handle_holder.handle, buffer
)
except BrokenPipeError:
if self._handle_holder.closed:
raise _core.ClosedResourceError(
"another task closed this pipe"
) from None
# Windows raises BrokenPipeError on one end of a pipe
# whenever the other end closes, regardless of direction.
# Convert this to the Unix behavior of returning EOF to the
# reader when the writer closes.
#
# And since we're not raising an exception, we have to
# checkpoint. But readinto_overlapped did raise an exception,
# so it might not have checkpointed for us. So we have to
# checkpoint manually.
await _core.checkpoint()
return b""
else:
del buffer[size:]
return buffer
def close(self):
self._handle_holder.close()
async def aclose(self):
self.close()
await _core.checkpoint()

View File

@@ -0,0 +1,21 @@
# This is a public namespace, so we don't want to expose any non-underscored
# attributes that aren't actually part of our public API. But it's very
# annoying to carefully always use underscored names for module-level
# temporaries, imports, etc. when implementing the module. So we put the
# implementation in an underscored module, and then re-export the public parts
# here.
from ._abc import (
Clock,
Instrument,
AsyncResource,
SendStream,
ReceiveStream,
Stream,
HalfCloseableStream,
SocketFactory,
HostnameResolver,
Listener,
SendChannel,
ReceiveChannel,
Channel,
)

View File

@@ -0,0 +1,7 @@
"""
This namespace represents special functions that can call back into Trio from
an external thread by means of a Trio Token present in Thread Local Storage
"""
from ._threads import from_thread_run as run
from ._threads import from_thread_run_sync as run_sync

View File

@@ -0,0 +1,74 @@
"""
This namespace represents low-level functionality not intended for daily use,
but useful for extending Trio's functionality.
"""
import select as _select
import sys
import typing as _t
# This is the union of a subset of trio/_core/ and some things from trio/*.py.
# See comments in trio/__init__.py for details. To make static analysis easier,
# this lists all possible symbols from trio._core, and then we prune those that
# aren't available on this system. After that we add some symbols from trio/*.py.
# Generally available symbols
from ._core import (
cancel_shielded_checkpoint,
Abort,
wait_task_rescheduled,
enable_ki_protection,
disable_ki_protection,
currently_ki_protected,
Task,
checkpoint,
current_task,
ParkingLot,
UnboundedQueue,
RunVar,
TrioToken,
current_trio_token,
temporarily_detach_coroutine_object,
permanently_detach_coroutine_object,
reattach_detached_coroutine_object,
current_statistics,
reschedule,
remove_instrument,
add_instrument,
current_clock,
current_root_task,
checkpoint_if_cancelled,
spawn_system_task,
wait_readable,
wait_writable,
notify_closing,
start_thread_soon,
start_guest_run,
)
from ._subprocess import open_process
if sys.platform == "win32":
# Windows symbols
from ._core import (
current_iocp,
register_with_iocp,
wait_overlapped,
monitor_completion_key,
readinto_overlapped,
write_overlapped,
)
from ._wait_for_object import WaitForSingleObject
else:
# Unix symbols
from ._unix_pipes import FdStream
# Kqueue-specific symbols
if sys.platform != "linux" and (_t.TYPE_CHECKING or not hasattr(_select, "epoll")):
from ._core import (
current_kqueue,
monitor_kevent,
wait_kevent,
)
del sys

View File

@@ -0,0 +1,200 @@
# This is a public namespace, so we don't want to expose any non-underscored
# attributes that aren't actually part of our public API. But it's very
# annoying to carefully always use underscored names for module-level
# temporaries, imports, etc. when implementing the module. So we put the
# implementation in an underscored module, and then re-export the public parts
# here.
# We still have some underscore names though but only a few.
from . import _socket
import sys
import typing as _t
# The socket module exports a bunch of platform-specific constants. We want to
# re-export them. Since the exact set of constants varies depending on Python
# version, platform, the libc installed on the system where Python was built,
# etc., we figure out which constants to re-export dynamically at runtime (see
# below). But that confuses static analysis tools like jedi and mypy. So this
# import statement statically lists every constant that *could* be
# exported. It always fails at runtime, since no single Python build exports
# all these constants, but it lets static analysis tools understand what's
# going on. There's a test in test_exports.py to make sure that the list is
# kept up to date.
try:
# fmt: off
from socket import ( # type: ignore
CMSG_LEN, CMSG_SPACE, CAPI, AF_UNSPEC, AF_INET, AF_UNIX, AF_IPX,
AF_APPLETALK, AF_INET6, AF_ROUTE, AF_LINK, AF_SNA, PF_SYSTEM,
AF_SYSTEM, SOCK_STREAM, SOCK_DGRAM, SOCK_RAW, SOCK_SEQPACKET, SOCK_RDM,
SO_DEBUG, SO_ACCEPTCONN, SO_REUSEADDR, SO_KEEPALIVE, SO_DONTROUTE,
SO_BROADCAST, SO_USELOOPBACK, SO_LINGER, SO_OOBINLINE, SO_REUSEPORT,
SO_SNDBUF, SO_RCVBUF, SO_SNDLOWAT, SO_RCVLOWAT, SO_SNDTIMEO,
SO_RCVTIMEO, SO_ERROR, SO_TYPE, LOCAL_PEERCRED, SOMAXCONN, SCM_RIGHTS,
SCM_CREDS, MSG_OOB, MSG_PEEK, MSG_DONTROUTE, MSG_DONTWAIT, MSG_EOR,
MSG_TRUNC, MSG_CTRUNC, MSG_WAITALL, MSG_EOF, SOL_SOCKET, SOL_IP,
SOL_TCP, SOL_UDP, IPPROTO_IP, IPPROTO_HOPOPTS, IPPROTO_ICMP,
IPPROTO_IGMP, IPPROTO_GGP, IPPROTO_IPV4, IPPROTO_IPIP, IPPROTO_TCP,
IPPROTO_EGP, IPPROTO_PUP, IPPROTO_UDP, IPPROTO_IDP, IPPROTO_HELLO,
IPPROTO_ND, IPPROTO_TP, IPPROTO_ROUTING, IPPROTO_FRAGMENT,
IPPROTO_RSVP, IPPROTO_GRE, IPPROTO_ESP, IPPROTO_AH, IPPROTO_ICMPV6,
IPPROTO_NONE, IPPROTO_DSTOPTS, IPPROTO_XTP, IPPROTO_EON, IPPROTO_PIM,
IPPROTO_IPCOMP, IPPROTO_SCTP, IPPROTO_RAW, IPPROTO_MAX,
SYSPROTO_CONTROL, IPPORT_RESERVED, IPPORT_USERRESERVED, INADDR_ANY,
INADDR_BROADCAST, INADDR_LOOPBACK, INADDR_UNSPEC_GROUP,
INADDR_ALLHOSTS_GROUP, INADDR_MAX_LOCAL_GROUP, INADDR_NONE, IP_OPTIONS,
IP_HDRINCL, IP_TOS, IP_TTL, IP_RECVOPTS, IP_RECVRETOPTS,
IP_RECVDSTADDR, IP_RETOPTS, IP_MULTICAST_IF, IP_MULTICAST_TTL,
IP_MULTICAST_LOOP, IP_ADD_MEMBERSHIP, IP_DROP_MEMBERSHIP,
IP_DEFAULT_MULTICAST_TTL, IP_DEFAULT_MULTICAST_LOOP,
IP_MAX_MEMBERSHIPS, IPV6_JOIN_GROUP, IPV6_LEAVE_GROUP,
IPV6_MULTICAST_HOPS, IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP,
IPV6_UNICAST_HOPS, IPV6_V6ONLY, IPV6_CHECKSUM, IPV6_RECVTCLASS,
IPV6_RTHDR_TYPE_0, IPV6_TCLASS, TCP_NODELAY, TCP_MAXSEG, TCP_KEEPINTVL,
TCP_KEEPCNT, TCP_FASTOPEN, TCP_NOTSENT_LOWAT, EAI_ADDRFAMILY,
EAI_AGAIN, EAI_BADFLAGS, EAI_FAIL, EAI_FAMILY, EAI_MEMORY, EAI_NODATA,
EAI_NONAME, EAI_OVERFLOW, EAI_SERVICE, EAI_SOCKTYPE, EAI_SYSTEM,
EAI_BADHINTS, EAI_PROTOCOL, EAI_MAX, AI_PASSIVE, AI_CANONNAME,
AI_NUMERICHOST, AI_NUMERICSERV, AI_MASK, AI_ALL, AI_V4MAPPED_CFG,
AI_ADDRCONFIG, AI_V4MAPPED, AI_DEFAULT, NI_MAXHOST, NI_MAXSERV,
NI_NOFQDN, NI_NUMERICHOST, NI_NAMEREQD, NI_NUMERICSERV, NI_DGRAM,
SHUT_RD, SHUT_WR, SHUT_RDWR, EBADF, EAGAIN, EWOULDBLOCK, AF_ASH,
AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_ECONET,
AF_IRDA, AF_KEY, AF_LLC, AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET,
AF_PPPOX, AF_ROSE, AF_SECURITY, AF_WANPIPE, AF_X25, BDADDR_ANY,
BDADDR_LOCAL, FD_SETSIZE, IPV6_DSTOPTS, IPV6_HOPLIMIT, IPV6_HOPOPTS,
IPV6_NEXTHOP, IPV6_PKTINFO, IPV6_RECVDSTOPTS, IPV6_RECVHOPLIMIT,
IPV6_RECVHOPOPTS, IPV6_RECVPKTINFO, IPV6_RECVRTHDR, IPV6_RTHDR,
IPV6_RTHDRDSTOPTS, MSG_ERRQUEUE, NETLINK_DNRTMSG, NETLINK_FIREWALL,
NETLINK_IP6_FW, NETLINK_NFLOG, NETLINK_ROUTE, NETLINK_USERSOCK,
NETLINK_XFRM, PACKET_BROADCAST, PACKET_FASTROUTE, PACKET_HOST,
PACKET_LOOPBACK, PACKET_MULTICAST, PACKET_OTHERHOST, PACKET_OUTGOING,
POLLERR, POLLHUP, POLLIN, POLLMSG, POLLNVAL, POLLOUT, POLLPRI,
POLLRDBAND, POLLRDNORM, POLLWRNORM, SIOCGIFINDEX, SIOCGIFNAME,
SOCK_CLOEXEC, TCP_CORK, TCP_DEFER_ACCEPT, TCP_INFO, TCP_KEEPIDLE,
TCP_LINGER2, TCP_QUICKACK, TCP_SYNCNT, TCP_WINDOW_CLAMP, AF_ALG,
AF_CAN, AF_RDS, AF_TIPC, AF_VSOCK, ALG_OP_DECRYPT, ALG_OP_ENCRYPT,
ALG_OP_SIGN, ALG_OP_VERIFY, ALG_SET_AEAD_ASSOCLEN,
ALG_SET_AEAD_AUTHSIZE, ALG_SET_IV, ALG_SET_KEY, ALG_SET_OP,
ALG_SET_PUBKEY, CAN_BCM, CAN_BCM_RX_CHANGED, CAN_BCM_RX_DELETE,
CAN_BCM_RX_READ, CAN_BCM_RX_SETUP, CAN_BCM_RX_STATUS,
CAN_BCM_RX_TIMEOUT, CAN_BCM_TX_DELETE, CAN_BCM_TX_EXPIRED,
CAN_BCM_TX_READ, CAN_BCM_TX_SEND, CAN_BCM_TX_SETUP, CAN_BCM_TX_STATUS,
CAN_EFF_FLAG, CAN_EFF_MASK, CAN_ERR_FLAG, CAN_ERR_MASK, CAN_ISOTP,
CAN_RAW, CAN_RAW_ERR_FILTER, CAN_RAW_FD_FRAMES, CAN_RAW_FILTER,
CAN_RAW_LOOPBACK, CAN_RAW_RECV_OWN_MSGS, CAN_RTR_FLAG, CAN_SFF_MASK,
IOCTL_VM_SOCKETS_GET_LOCAL_CID, IPV6_DONTFRAG, IPV6_PATHMTU,
IPV6_RECVPATHMTU, IP_TRANSPARENT, MSG_CMSG_CLOEXEC, MSG_CONFIRM,
MSG_FASTOPEN, MSG_MORE, MSG_NOSIGNAL, NETLINK_CRYPTO, PF_CAN,
PF_PACKET, PF_RDS, SCM_CREDENTIALS, SOCK_NONBLOCK, SOL_ALG,
SOL_CAN_BASE, SOL_CAN_RAW, SOL_TIPC, SO_BINDTODEVICE, SO_DOMAIN,
SO_MARK, SO_PASSCRED, SO_PASSSEC, SO_PEERCRED, SO_PEERSEC, SO_PRIORITY,
SO_PROTOCOL, SO_VM_SOCKETS_BUFFER_MAX_SIZE,
SO_VM_SOCKETS_BUFFER_MIN_SIZE, SO_VM_SOCKETS_BUFFER_SIZE,
TCP_CONGESTION, TCP_USER_TIMEOUT, TIPC_ADDR_ID, TIPC_ADDR_NAME,
TIPC_ADDR_NAMESEQ, TIPC_CFG_SRV, TIPC_CLUSTER_SCOPE, TIPC_CONN_TIMEOUT,
TIPC_CRITICAL_IMPORTANCE, TIPC_DEST_DROPPABLE, TIPC_HIGH_IMPORTANCE,
TIPC_IMPORTANCE, TIPC_LOW_IMPORTANCE, TIPC_MEDIUM_IMPORTANCE,
TIPC_NODE_SCOPE, TIPC_PUBLISHED, TIPC_SRC_DROPPABLE,
TIPC_SUBSCR_TIMEOUT, TIPC_SUB_CANCEL, TIPC_SUB_PORTS, TIPC_SUB_SERVICE,
TIPC_TOP_SRV, TIPC_WAIT_FOREVER, TIPC_WITHDRAWN, TIPC_ZONE_SCOPE,
VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_PORT_ANY,
VM_SOCKETS_INVALID_VERSION, MSG_BCAST, MSG_MCAST, RCVALL_MAX,
RCVALL_OFF, RCVALL_ON, RCVALL_SOCKETLEVELONLY, SIO_KEEPALIVE_VALS,
SIO_LOOPBACK_FAST_PATH, SIO_RCVALL, SO_EXCLUSIVEADDRUSE, HCI_FILTER,
BTPROTO_SCO, BTPROTO_HCI, HCI_TIME_STAMP, SOL_RDS, BTPROTO_L2CAP,
BTPROTO_RFCOMM, HCI_DATA_DIR, SOL_HCI, CAN_BCM_RX_ANNOUNCE_RESUME,
CAN_BCM_RX_CHECK_DLC, CAN_BCM_RX_FILTER_ID, CAN_BCM_RX_NO_AUTOTIMER,
CAN_BCM_RX_RTR_FRAME, CAN_BCM_SETTIMER, CAN_BCM_STARTTIMER,
CAN_BCM_TX_ANNOUNCE, CAN_BCM_TX_COUNTEVT, CAN_BCM_TX_CP_CAN_ID,
CAN_BCM_TX_RESET_MULTI_IDX, IPPROTO_CBT, IPPROTO_ICLFXBM, IPPROTO_IGP,
IPPROTO_L2TP, IPPROTO_PGM, IPPROTO_RDP, IPPROTO_ST, AF_QIPCRTR,
CAN_BCM_CAN_FD_FRAME, IPPROTO_MOBILE, IPV6_USE_MIN_MTU,
MSG_NOTIFICATION, SO_SETFIB, CAN_J1939, CAN_RAW_JOIN_FILTERS,
IPPROTO_UDPLITE, J1939_EE_INFO_NONE, J1939_EE_INFO_TX_ABORT,
J1939_FILTER_MAX, J1939_IDLE_ADDR, J1939_MAX_UNICAST_ADDR,
J1939_NLA_BYTES_ACKED, J1939_NLA_PAD, J1939_NO_ADDR, J1939_NO_NAME,
J1939_NO_PGN, J1939_PGN_ADDRESS_CLAIMED, J1939_PGN_ADDRESS_COMMANDED,
J1939_PGN_MAX, J1939_PGN_PDU1_MAX, J1939_PGN_REQUEST,
SCM_J1939_DEST_ADDR, SCM_J1939_DEST_NAME, SCM_J1939_ERRQUEUE,
SCM_J1939_PRIO, SO_J1939_ERRQUEUE, SO_J1939_FILTER, SO_J1939_PROMISC,
SO_J1939_SEND_PRIO, UDPLITE_RECV_CSCOV, UDPLITE_SEND_CSCOV, IP_RECVTOS,
TCP_KEEPALIVE
)
# fmt: on
except ImportError:
pass
# Dynamically re-export whatever constants this particular Python happens to
# have:
import socket as _stdlib_socket
_bad_symbols: _t.Set[str] = set()
if sys.platform == "win32":
# See https://github.com/python-trio/trio/issues/39
# Do not import for windows platform
# (you can still get it from stdlib socket, of course, if you want it)
_bad_symbols.add("SO_REUSEADDR")
globals().update(
{
_name: getattr(_stdlib_socket, _name)
for _name in _stdlib_socket.__all__ # type: ignore
if _name.isupper() and _name not in _bad_symbols
}
)
# import the overwrites
from ._socket import (
fromfd,
from_stdlib_socket,
getprotobyname,
socketpair,
getnameinfo,
socket,
getaddrinfo,
set_custom_hostname_resolver,
set_custom_socket_factory,
SocketType,
)
# not always available so expose only if
if sys.platform == "win32" or not _t.TYPE_CHECKING:
try:
from ._socket import fromshare
except ImportError:
pass
# expose these functions to trio.socket
from socket import (
gaierror,
herror,
gethostname,
ntohs,
htonl,
htons,
inet_aton,
inet_ntoa,
inet_pton,
inet_ntop,
)
# not always available so expose only if
if sys.platform != "win32" or not _t.TYPE_CHECKING:
try:
from socket import sethostname, if_nameindex, if_nametoindex, if_indextoname
except ImportError:
pass
# get names used by Trio that we define on our own
from ._socket import IPPROTO_IPV6
if _t.TYPE_CHECKING:
IP_BIND_ADDRESS_NO_PORT: int
else:
try:
IP_BIND_ADDRESS_NO_PORT
except NameError:
if sys.platform == "linux":
IP_BIND_ADDRESS_NO_PORT = 24
del sys

View File

@@ -0,0 +1,32 @@
from .._core import wait_all_tasks_blocked, MockClock
from ._trio_test import trio_test
from ._checkpoints import assert_checkpoints, assert_no_checkpoints
from ._sequencer import Sequencer
from ._check_streams import (
check_one_way_stream,
check_two_way_stream,
check_half_closeable_stream,
)
from ._memory_streams import (
MemorySendStream,
MemoryReceiveStream,
memory_stream_pump,
memory_stream_one_way_pair,
memory_stream_pair,
lockstep_stream_one_way_pair,
lockstep_stream_pair,
)
from ._network import open_stream_to_socket_listener
################################################################
from .._util import fixup_module_metadata
fixup_module_metadata(__name__, globals())
del fixup_module_metadata

View File

@@ -0,0 +1,512 @@
# Generic stream tests
from contextlib import contextmanager
import random
from .. import _core
from .._highlevel_generic import aclose_forcefully
from .._abc import SendStream, ReceiveStream, Stream, HalfCloseableStream
from ._checkpoints import assert_checkpoints
class _ForceCloseBoth:
def __init__(self, both):
self._both = list(both)
async def __aenter__(self):
return self._both
async def __aexit__(self, *args):
try:
await aclose_forcefully(self._both[0])
finally:
await aclose_forcefully(self._both[1])
@contextmanager
def _assert_raises(exc):
__tracebackhide__ = True
try:
yield
except exc:
pass
else:
raise AssertionError("expected exception: {}".format(exc))
async def check_one_way_stream(stream_maker, clogged_stream_maker):
"""Perform a number of generic tests on a custom one-way stream
implementation.
Args:
stream_maker: An async (!) function which returns a connected
(:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`)
pair.
clogged_stream_maker: Either None, or an async function similar to
stream_maker, but with the extra property that the returned stream
is in a state where ``send_all`` and
``wait_send_all_might_not_block`` will block until ``receive_some``
has been called. This allows for more thorough testing of some edge
cases, especially around ``wait_send_all_might_not_block``.
Raises:
AssertionError: if a test fails.
"""
async with _ForceCloseBoth(await stream_maker()) as (s, r):
assert isinstance(s, SendStream)
assert isinstance(r, ReceiveStream)
async def do_send_all(data):
with assert_checkpoints():
assert await s.send_all(data) is None
async def do_receive_some(*args):
with assert_checkpoints():
return await r.receive_some(*args)
async def checked_receive_1(expected):
assert await do_receive_some(1) == expected
async def do_aclose(resource):
with assert_checkpoints():
await resource.aclose()
# Simple sending/receiving
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, b"x")
nursery.start_soon(checked_receive_1, b"x")
async def send_empty_then_y():
# Streams should tolerate sending b"" without giving it any
# special meaning.
await do_send_all(b"")
await do_send_all(b"y")
async with _core.open_nursery() as nursery:
nursery.start_soon(send_empty_then_y)
nursery.start_soon(checked_receive_1, b"y")
# ---- Checking various argument types ----
# send_all accepts bytearray and memoryview
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, bytearray(b"1"))
nursery.start_soon(checked_receive_1, b"1")
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, memoryview(b"2"))
nursery.start_soon(checked_receive_1, b"2")
# max_bytes must be a positive integer
with _assert_raises(ValueError):
await r.receive_some(-1)
with _assert_raises(ValueError):
await r.receive_some(0)
with _assert_raises(TypeError):
await r.receive_some(1.5)
# it can also be missing or None
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, b"x")
assert await do_receive_some() == b"x"
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, b"x")
assert await do_receive_some(None) == b"x"
with _assert_raises(_core.BusyResourceError):
async with _core.open_nursery() as nursery:
nursery.start_soon(do_receive_some, 1)
nursery.start_soon(do_receive_some, 1)
# Method always has to exist, and an empty stream with a blocked
# receive_some should *always* allow send_all. (Technically it's legal
# for send_all to wait until receive_some is called to run, though; a
# stream doesn't *have* to have any internal buffering. That's why we
# start a concurrent receive_some call, then cancel it.)
async def simple_check_wait_send_all_might_not_block(scope):
with assert_checkpoints():
await s.wait_send_all_might_not_block()
scope.cancel()
async with _core.open_nursery() as nursery:
nursery.start_soon(
simple_check_wait_send_all_might_not_block, nursery.cancel_scope
)
nursery.start_soon(do_receive_some, 1)
# closing the r side leads to BrokenResourceError on the s side
# (eventually)
async def expect_broken_stream_on_send():
with _assert_raises(_core.BrokenResourceError):
while True:
await do_send_all(b"x" * 100)
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_broken_stream_on_send)
nursery.start_soon(do_aclose, r)
# once detected, the stream stays broken
with _assert_raises(_core.BrokenResourceError):
await do_send_all(b"x" * 100)
# r closed -> ClosedResourceError on the receive side
with _assert_raises(_core.ClosedResourceError):
await do_receive_some(4096)
# we can close the same stream repeatedly, it's fine
await do_aclose(r)
await do_aclose(r)
# closing the sender side
await do_aclose(s)
# now trying to send raises ClosedResourceError
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"x" * 100)
# even if it's an empty send
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"")
# ditto for wait_send_all_might_not_block
with _assert_raises(_core.ClosedResourceError):
with assert_checkpoints():
await s.wait_send_all_might_not_block()
# and again, repeated closing is fine
await do_aclose(s)
await do_aclose(s)
async with _ForceCloseBoth(await stream_maker()) as (s, r):
# if send-then-graceful-close, receiver gets data then b""
async def send_then_close():
await do_send_all(b"y")
await do_aclose(s)
async def receive_send_then_close():
# We want to make sure that if the sender closes the stream before
# we read anything, then we still get all the data. But some
# streams might block on the do_send_all call. So we let the
# sender get as far as it can, then we receive.
await _core.wait_all_tasks_blocked()
await checked_receive_1(b"y")
await checked_receive_1(b"")
await do_aclose(r)
async with _core.open_nursery() as nursery:
nursery.start_soon(send_then_close)
nursery.start_soon(receive_send_then_close)
async with _ForceCloseBoth(await stream_maker()) as (s, r):
await aclose_forcefully(r)
with _assert_raises(_core.BrokenResourceError):
while True:
await do_send_all(b"x" * 100)
with _assert_raises(_core.ClosedResourceError):
await do_receive_some(4096)
async with _ForceCloseBoth(await stream_maker()) as (s, r):
await aclose_forcefully(s)
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"123")
# after the sender does a forceful close, the receiver might either
# get BrokenResourceError or a clean b""; either is OK. Not OK would be
# if it freezes, or returns data.
try:
await checked_receive_1(b"")
except _core.BrokenResourceError:
pass
# cancelled aclose still closes
async with _ForceCloseBoth(await stream_maker()) as (s, r):
with _core.CancelScope() as scope:
scope.cancel()
await r.aclose()
with _core.CancelScope() as scope:
scope.cancel()
await s.aclose()
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"123")
with _assert_raises(_core.ClosedResourceError):
await do_receive_some(4096)
# Check that we can still gracefully close a stream after an operation has
# been cancelled. This can be challenging if cancellation can leave the
# stream internals in an inconsistent state, e.g. for
# SSLStream. Unfortunately this test isn't very thorough; the really
# challenging case for something like SSLStream is it gets cancelled
# *while* it's sending data on the underlying, not before. But testing
# that requires some special-case handling of the particular stream setup;
# we can't do it here. Maybe we could do a bit better with
# https://github.com/python-trio/trio/issues/77
async with _ForceCloseBoth(await stream_maker()) as (s, r):
async def expect_cancelled(afn, *args):
with _assert_raises(_core.Cancelled):
await afn(*args)
with _core.CancelScope() as scope:
scope.cancel()
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_cancelled, do_send_all, b"x")
nursery.start_soon(expect_cancelled, do_receive_some, 1)
async with _core.open_nursery() as nursery:
nursery.start_soon(do_aclose, s)
nursery.start_soon(do_aclose, r)
# Check that if a task is blocked in receive_some, then closing the
# receive stream causes it to wake up.
async with _ForceCloseBoth(await stream_maker()) as (s, r):
async def receive_expecting_closed():
with _assert_raises(_core.ClosedResourceError):
await r.receive_some(10)
async with _core.open_nursery() as nursery:
nursery.start_soon(receive_expecting_closed)
await _core.wait_all_tasks_blocked()
await aclose_forcefully(r)
# check wait_send_all_might_not_block, if we can
if clogged_stream_maker is not None:
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
record = []
async def waiter(cancel_scope):
record.append("waiter sleeping")
with assert_checkpoints():
await s.wait_send_all_might_not_block()
record.append("waiter wokeup")
cancel_scope.cancel()
async def receiver():
# give wait_send_all_might_not_block a chance to block
await _core.wait_all_tasks_blocked()
record.append("receiver starting")
while True:
await r.receive_some(16834)
async with _core.open_nursery() as nursery:
nursery.start_soon(waiter, nursery.cancel_scope)
await _core.wait_all_tasks_blocked()
nursery.start_soon(receiver)
assert record == [
"waiter sleeping",
"receiver starting",
"waiter wokeup",
]
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
# simultaneous wait_send_all_might_not_block fails
with _assert_raises(_core.BusyResourceError):
async with _core.open_nursery() as nursery:
nursery.start_soon(s.wait_send_all_might_not_block)
nursery.start_soon(s.wait_send_all_might_not_block)
# and simultaneous send_all and wait_send_all_might_not_block (NB
# this test might destroy the stream b/c we end up cancelling
# send_all and e.g. SSLStream can't handle that, so we have to
# recreate afterwards)
with _assert_raises(_core.BusyResourceError):
async with _core.open_nursery() as nursery:
nursery.start_soon(s.wait_send_all_might_not_block)
nursery.start_soon(s.send_all, b"123")
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
# send_all and send_all blocked simultaneously should also raise
# (but again this might destroy the stream)
with _assert_raises(_core.BusyResourceError):
async with _core.open_nursery() as nursery:
nursery.start_soon(s.send_all, b"123")
nursery.start_soon(s.send_all, b"123")
# closing the receiver causes wait_send_all_might_not_block to return,
# with or without an exception
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
async def sender():
try:
with assert_checkpoints():
await s.wait_send_all_might_not_block()
except _core.BrokenResourceError: # pragma: no cover
pass
async def receiver():
await _core.wait_all_tasks_blocked()
await aclose_forcefully(r)
async with _core.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(receiver)
# and again with the call starting after the close
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
await aclose_forcefully(r)
try:
with assert_checkpoints():
await s.wait_send_all_might_not_block()
except _core.BrokenResourceError: # pragma: no cover
pass
# Check that if a task is blocked in a send-side method, then closing
# the send stream causes it to wake up.
async def close_soon(s):
await _core.wait_all_tasks_blocked()
await aclose_forcefully(s)
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
async with _core.open_nursery() as nursery:
nursery.start_soon(close_soon, s)
with _assert_raises(_core.ClosedResourceError):
await s.send_all(b"xyzzy")
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
async with _core.open_nursery() as nursery:
nursery.start_soon(close_soon, s)
with _assert_raises(_core.ClosedResourceError):
await s.wait_send_all_might_not_block()
async def check_two_way_stream(stream_maker, clogged_stream_maker):
"""Perform a number of generic tests on a custom two-way stream
implementation.
This is similar to :func:`check_one_way_stream`, except that the maker
functions are expected to return objects implementing the
:class:`~trio.abc.Stream` interface.
This function tests a *superset* of what :func:`check_one_way_stream`
checks if you call this, then you don't need to also call
:func:`check_one_way_stream`.
"""
await check_one_way_stream(stream_maker, clogged_stream_maker)
async def flipped_stream_maker():
return reversed(await stream_maker())
if clogged_stream_maker is not None:
async def flipped_clogged_stream_maker():
return reversed(await clogged_stream_maker())
else:
flipped_clogged_stream_maker = None
await check_one_way_stream(flipped_stream_maker, flipped_clogged_stream_maker)
async with _ForceCloseBoth(await stream_maker()) as (s1, s2):
assert isinstance(s1, Stream)
assert isinstance(s2, Stream)
# Duplex can be a bit tricky, might as well check it as well
DUPLEX_TEST_SIZE = 2**20
CHUNK_SIZE_MAX = 2**14
r = random.Random(0)
i = r.getrandbits(8 * DUPLEX_TEST_SIZE)
test_data = i.to_bytes(DUPLEX_TEST_SIZE, "little")
async def sender(s, data, seed):
r = random.Random(seed)
m = memoryview(data)
while m:
chunk_size = r.randint(1, CHUNK_SIZE_MAX)
await s.send_all(m[:chunk_size])
m = m[chunk_size:]
async def receiver(s, data, seed):
r = random.Random(seed)
got = bytearray()
while len(got) < len(data):
chunk = await s.receive_some(r.randint(1, CHUNK_SIZE_MAX))
assert chunk
got += chunk
assert got == data
async with _core.open_nursery() as nursery:
nursery.start_soon(sender, s1, test_data, 0)
nursery.start_soon(sender, s2, test_data[::-1], 1)
nursery.start_soon(receiver, s1, test_data[::-1], 2)
nursery.start_soon(receiver, s2, test_data, 3)
async def expect_receive_some_empty():
assert await s2.receive_some(10) == b""
await s2.aclose()
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_receive_some_empty)
nursery.start_soon(s1.aclose)
async def check_half_closeable_stream(stream_maker, clogged_stream_maker):
"""Perform a number of generic tests on a custom half-closeable stream
implementation.
This is similar to :func:`check_two_way_stream`, except that the maker
functions are expected to return objects that implement the
:class:`~trio.abc.HalfCloseableStream` interface.
This function tests a *superset* of what :func:`check_two_way_stream`
checks if you call this, then you don't need to also call
:func:`check_two_way_stream`.
"""
await check_two_way_stream(stream_maker, clogged_stream_maker)
async with _ForceCloseBoth(await stream_maker()) as (s1, s2):
assert isinstance(s1, HalfCloseableStream)
assert isinstance(s2, HalfCloseableStream)
async def send_x_then_eof(s):
await s.send_all(b"x")
with assert_checkpoints():
await s.send_eof()
async def expect_x_then_eof(r):
await _core.wait_all_tasks_blocked()
assert await r.receive_some(10) == b"x"
assert await r.receive_some(10) == b""
async with _core.open_nursery() as nursery:
nursery.start_soon(send_x_then_eof, s1)
nursery.start_soon(expect_x_then_eof, s2)
# now sending is disallowed
with _assert_raises(_core.ClosedResourceError):
await s1.send_all(b"y")
# but we can do send_eof again
with assert_checkpoints():
await s1.send_eof()
# and we can still send stuff back the other way
async with _core.open_nursery() as nursery:
nursery.start_soon(send_x_then_eof, s2)
nursery.start_soon(expect_x_then_eof, s1)
if clogged_stream_maker is not None:
async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2):
# send_all and send_eof simultaneously is not ok
with _assert_raises(_core.BusyResourceError):
async with _core.open_nursery() as nursery:
nursery.start_soon(s1.send_all, b"x")
await _core.wait_all_tasks_blocked()
nursery.start_soon(s1.send_eof)
async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2):
# wait_send_all_might_not_block and send_eof simultaneously is not
# ok either
with _assert_raises(_core.BusyResourceError):
async with _core.open_nursery() as nursery:
nursery.start_soon(s1.wait_send_all_might_not_block)
await _core.wait_all_tasks_blocked()
nursery.start_soon(s1.send_eof)

View File

@@ -0,0 +1,62 @@
from contextlib import contextmanager
from .. import _core
@contextmanager
def _assert_yields_or_not(expected):
__tracebackhide__ = True
task = _core.current_task()
orig_cancel = task._cancel_points
orig_schedule = task._schedule_points
try:
yield
if expected and (
task._cancel_points == orig_cancel or task._schedule_points == orig_schedule
):
raise AssertionError("assert_checkpoints block did not yield!")
finally:
if not expected and (
task._cancel_points != orig_cancel or task._schedule_points != orig_schedule
):
raise AssertionError("assert_no_checkpoints block yielded!")
def assert_checkpoints():
"""Use as a context manager to check that the code inside the ``with``
block either exits with an exception or executes at least one
:ref:`checkpoint <checkpoints>`.
Raises:
AssertionError: if no checkpoint was executed.
Example:
Check that :func:`trio.sleep` is a checkpoint, even if it doesn't
block::
with trio.testing.assert_checkpoints():
await trio.sleep(0)
"""
__tracebackhide__ = True
return _assert_yields_or_not(True)
def assert_no_checkpoints():
"""Use as a context manager to check that the code inside the ``with``
block does not execute any :ref:`checkpoints <checkpoints>`.
Raises:
AssertionError: if a checkpoint was executed.
Example:
Synchronous code never contains any checkpoints, but we can double-check
that::
send_channel, receive_channel = trio.open_memory_channel(10)
with trio.testing.assert_no_checkpoints():
send_channel.send_nowait(None)
"""
__tracebackhide__ = True
return _assert_yields_or_not(False)

View File

@@ -0,0 +1,591 @@
import operator
from .. import _core
from .._highlevel_generic import StapledStream
from .. import _util
from ..abc import SendStream, ReceiveStream
################################################################
# In-memory streams - Unbounded buffer version
################################################################
class _UnboundedByteQueue:
def __init__(self):
self._data = bytearray()
self._closed = False
self._lot = _core.ParkingLot()
self._fetch_lock = _util.ConflictDetector(
"another task is already fetching data"
)
# This object treats "close" as being like closing the send side of a
# channel: so after close(), calling put() raises ClosedResourceError, and
# calling the get() variants drains the buffer and then returns an empty
# bytearray.
def close(self):
self._closed = True
self._lot.unpark_all()
def close_and_wipe(self):
self._data = bytearray()
self.close()
def put(self, data):
if self._closed:
raise _core.ClosedResourceError("virtual connection closed")
self._data += data
self._lot.unpark_all()
def _check_max_bytes(self, max_bytes):
if max_bytes is None:
return
max_bytes = operator.index(max_bytes)
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
def _get_impl(self, max_bytes):
assert self._closed or self._data
if max_bytes is None:
max_bytes = len(self._data)
if self._data:
chunk = self._data[:max_bytes]
del self._data[:max_bytes]
assert chunk
return chunk
else:
return bytearray()
def get_nowait(self, max_bytes=None):
with self._fetch_lock:
self._check_max_bytes(max_bytes)
if not self._closed and not self._data:
raise _core.WouldBlock
return self._get_impl(max_bytes)
async def get(self, max_bytes=None):
with self._fetch_lock:
self._check_max_bytes(max_bytes)
if not self._closed and not self._data:
await self._lot.park()
else:
await _core.checkpoint()
return self._get_impl(max_bytes)
class MemorySendStream(SendStream, metaclass=_util.Final):
"""An in-memory :class:`~trio.abc.SendStream`.
Args:
send_all_hook: An async function, or None. Called from
:meth:`send_all`. Can do whatever you like.
wait_send_all_might_not_block_hook: An async function, or None. Called
from :meth:`wait_send_all_might_not_block`. Can do whatever you
like.
close_hook: A synchronous function, or None. Called from :meth:`close`
and :meth:`aclose`. Can do whatever you like.
.. attribute:: send_all_hook
wait_send_all_might_not_block_hook
close_hook
All of these hooks are also exposed as attributes on the object, and
you can change them at any time.
"""
def __init__(
self,
send_all_hook=None,
wait_send_all_might_not_block_hook=None,
close_hook=None,
):
self._conflict_detector = _util.ConflictDetector(
"another task is using this stream"
)
self._outgoing = _UnboundedByteQueue()
self.send_all_hook = send_all_hook
self.wait_send_all_might_not_block_hook = wait_send_all_might_not_block_hook
self.close_hook = close_hook
async def send_all(self, data):
"""Places the given data into the object's internal buffer, and then
calls the :attr:`send_all_hook` (if any).
"""
# Execute two checkpoints so we have more of a chance to detect
# buggy user code that calls this twice at the same time.
with self._conflict_detector:
await _core.checkpoint()
await _core.checkpoint()
self._outgoing.put(data)
if self.send_all_hook is not None:
await self.send_all_hook()
async def wait_send_all_might_not_block(self):
"""Calls the :attr:`wait_send_all_might_not_block_hook` (if any), and
then returns immediately.
"""
# Execute two checkpoints so we have more of a chance to detect
# buggy user code that calls this twice at the same time.
with self._conflict_detector:
await _core.checkpoint()
await _core.checkpoint()
# check for being closed:
self._outgoing.put(b"")
if self.wait_send_all_might_not_block_hook is not None:
await self.wait_send_all_might_not_block_hook()
def close(self):
"""Marks this stream as closed, and then calls the :attr:`close_hook`
(if any).
"""
# XXX should this cancel any pending calls to the send_all_hook and
# wait_send_all_might_not_block_hook? Those are the only places where
# send_all and wait_send_all_might_not_block can be blocked.
#
# The way we set things up, send_all_hook is memory_stream_pump, and
# wait_send_all_might_not_block_hook is unset. memory_stream_pump is
# synchronous. So normally, send_all and wait_send_all_might_not_block
# cannot block at all.
self._outgoing.close()
if self.close_hook is not None:
self.close_hook()
async def aclose(self):
"""Same as :meth:`close`, but async."""
self.close()
await _core.checkpoint()
async def get_data(self, max_bytes=None):
"""Retrieves data from the internal buffer, blocking if necessary.
Args:
max_bytes (int or None): The maximum amount of data to
retrieve. None (the default) means to retrieve all the data
that's present (but still blocks until at least one byte is
available).
Returns:
If this stream has been closed, an empty bytearray. Otherwise, the
requested data.
"""
return await self._outgoing.get(max_bytes)
def get_data_nowait(self, max_bytes=None):
"""Retrieves data from the internal buffer, but doesn't block.
See :meth:`get_data` for details.
Raises:
trio.WouldBlock: if no data is available to retrieve.
"""
return self._outgoing.get_nowait(max_bytes)
class MemoryReceiveStream(ReceiveStream, metaclass=_util.Final):
"""An in-memory :class:`~trio.abc.ReceiveStream`.
Args:
receive_some_hook: An async function, or None. Called from
:meth:`receive_some`. Can do whatever you like.
close_hook: A synchronous function, or None. Called from :meth:`close`
and :meth:`aclose`. Can do whatever you like.
.. attribute:: receive_some_hook
close_hook
Both hooks are also exposed as attributes on the object, and you can
change them at any time.
"""
def __init__(self, receive_some_hook=None, close_hook=None):
self._conflict_detector = _util.ConflictDetector(
"another task is using this stream"
)
self._incoming = _UnboundedByteQueue()
self._closed = False
self.receive_some_hook = receive_some_hook
self.close_hook = close_hook
async def receive_some(self, max_bytes=None):
"""Calls the :attr:`receive_some_hook` (if any), and then retrieves
data from the internal buffer, blocking if necessary.
"""
# Execute two checkpoints so we have more of a chance to detect
# buggy user code that calls this twice at the same time.
with self._conflict_detector:
await _core.checkpoint()
await _core.checkpoint()
if self._closed:
raise _core.ClosedResourceError
if self.receive_some_hook is not None:
await self.receive_some_hook()
# self._incoming's closure state tracks whether we got an EOF.
# self._closed tracks whether we, ourselves, are closed.
# self.close() sends an EOF to wake us up and sets self._closed,
# so after we wake up we have to check self._closed again.
data = await self._incoming.get(max_bytes)
if self._closed:
raise _core.ClosedResourceError
return data
def close(self):
"""Discards any pending data from the internal buffer, and marks this
stream as closed.
"""
self._closed = True
self._incoming.close_and_wipe()
if self.close_hook is not None:
self.close_hook()
async def aclose(self):
"""Same as :meth:`close`, but async."""
self.close()
await _core.checkpoint()
def put_data(self, data):
"""Appends the given data to the internal buffer."""
self._incoming.put(data)
def put_eof(self):
"""Adds an end-of-file marker to the internal buffer."""
self._incoming.close()
def memory_stream_pump(memory_send_stream, memory_receive_stream, *, max_bytes=None):
"""Take data out of the given :class:`MemorySendStream`'s internal buffer,
and put it into the given :class:`MemoryReceiveStream`'s internal buffer.
Args:
memory_send_stream (MemorySendStream): The stream to get data from.
memory_receive_stream (MemoryReceiveStream): The stream to put data into.
max_bytes (int or None): The maximum amount of data to transfer in this
call, or None to transfer all available data.
Returns:
True if it successfully transferred some data, or False if there was no
data to transfer.
This is used to implement :func:`memory_stream_one_way_pair` and
:func:`memory_stream_pair`; see the latter's docstring for an example
of how you might use it yourself.
"""
try:
data = memory_send_stream.get_data_nowait(max_bytes)
except _core.WouldBlock:
return False
try:
if not data:
memory_receive_stream.put_eof()
else:
memory_receive_stream.put_data(data)
except _core.ClosedResourceError:
raise _core.BrokenResourceError("MemoryReceiveStream was closed")
return True
def memory_stream_one_way_pair():
"""Create a connected, pure-Python, unidirectional stream with infinite
buffering and flexible configuration options.
You can think of this as being a no-operating-system-involved
Trio-streamsified version of :func:`os.pipe` (except that :func:`os.pipe`
returns the streams in the wrong order we follow the superior convention
that data flows from left to right).
Returns:
A tuple (:class:`MemorySendStream`, :class:`MemoryReceiveStream`), where
the :class:`MemorySendStream` has its hooks set up so that it calls
:func:`memory_stream_pump` from its
:attr:`~MemorySendStream.send_all_hook` and
:attr:`~MemorySendStream.close_hook`.
The end result is that data automatically flows from the
:class:`MemorySendStream` to the :class:`MemoryReceiveStream`. But you're
also free to rearrange things however you like. For example, you can
temporarily set the :attr:`~MemorySendStream.send_all_hook` to None if you
want to simulate a stall in data transmission. Or see
:func:`memory_stream_pair` for a more elaborate example.
"""
send_stream = MemorySendStream()
recv_stream = MemoryReceiveStream()
def pump_from_send_stream_to_recv_stream():
memory_stream_pump(send_stream, recv_stream)
async def async_pump_from_send_stream_to_recv_stream():
pump_from_send_stream_to_recv_stream()
send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream
send_stream.close_hook = pump_from_send_stream_to_recv_stream
return send_stream, recv_stream
def _make_stapled_pair(one_way_pair):
pipe1_send, pipe1_recv = one_way_pair()
pipe2_send, pipe2_recv = one_way_pair()
stream1 = StapledStream(pipe1_send, pipe2_recv)
stream2 = StapledStream(pipe2_send, pipe1_recv)
return stream1, stream2
def memory_stream_pair():
"""Create a connected, pure-Python, bidirectional stream with infinite
buffering and flexible configuration options.
This is a convenience function that creates two one-way streams using
:func:`memory_stream_one_way_pair`, and then uses
:class:`~trio.StapledStream` to combine them into a single bidirectional
stream.
This is like a no-operating-system-involved, Trio-streamsified version of
:func:`socket.socketpair`.
Returns:
A pair of :class:`~trio.StapledStream` objects that are connected so
that data automatically flows from one to the other in both directions.
After creating a stream pair, you can send data back and forth, which is
enough for simple tests::
left, right = memory_stream_pair()
await left.send_all(b"123")
assert await right.receive_some() == b"123"
await right.send_all(b"456")
assert await left.receive_some() == b"456"
But if you read the docs for :class:`~trio.StapledStream` and
:func:`memory_stream_one_way_pair`, you'll see that all the pieces
involved in wiring this up are public APIs, so you can adjust to suit the
requirements of your tests. For example, here's how to tweak a stream so
that data flowing from left to right trickles in one byte at a time (but
data flowing from right to left proceeds at full speed)::
left, right = memory_stream_pair()
async def trickle():
# left is a StapledStream, and left.send_stream is a MemorySendStream
# right is a StapledStream, and right.recv_stream is a MemoryReceiveStream
while memory_stream_pump(left.send_stream, right.recv_stream, max_bytes=1):
# Pause between each byte
await trio.sleep(1)
# Normally this send_all_hook calls memory_stream_pump directly without
# passing in a max_bytes. We replace it with our custom version:
left.send_stream.send_all_hook = trickle
And here's a simple test using our modified stream objects::
async def sender():
await left.send_all(b"12345")
await left.send_eof()
async def receiver():
async for data in right:
print(data)
async with trio.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(receiver)
By default, this will print ``b"12345"`` and then immediately exit; with
our trickle stream it instead sleeps 1 second, then prints ``b"1"``, then
sleeps 1 second, then prints ``b"2"``, etc.
Pro-tip: you can insert sleep calls (like in our example above) to
manipulate the flow of data across tasks... and then use
:class:`MockClock` and its :attr:`~MockClock.autojump_threshold`
functionality to keep your test suite running quickly.
If you want to stress test a protocol implementation, one nice trick is to
use the :mod:`random` module (preferably with a fixed seed) to move random
numbers of bytes at a time, and insert random sleeps in between them. You
can also set up a custom :attr:`~MemoryReceiveStream.receive_some_hook` if
you want to manipulate things on the receiving side, and not just the
sending side.
"""
return _make_stapled_pair(memory_stream_one_way_pair)
################################################################
# In-memory streams - Lockstep version
################################################################
class _LockstepByteQueue:
def __init__(self):
self._data = bytearray()
self._sender_closed = False
self._receiver_closed = False
self._receiver_waiting = False
self._waiters = _core.ParkingLot()
self._send_conflict_detector = _util.ConflictDetector(
"another task is already sending"
)
self._receive_conflict_detector = _util.ConflictDetector(
"another task is already receiving"
)
def _something_happened(self):
self._waiters.unpark_all()
# Always wakes up when one side is closed, because everyone always reacts
# to that.
async def _wait_for(self, fn):
while True:
if fn():
break
if self._sender_closed or self._receiver_closed:
break
await self._waiters.park()
await _core.checkpoint()
def close_sender(self):
self._sender_closed = True
self._something_happened()
def close_receiver(self):
self._receiver_closed = True
self._something_happened()
async def send_all(self, data):
with self._send_conflict_detector:
if self._sender_closed:
raise _core.ClosedResourceError
if self._receiver_closed:
raise _core.BrokenResourceError
assert not self._data
self._data += data
self._something_happened()
await self._wait_for(lambda: not self._data)
if self._sender_closed:
raise _core.ClosedResourceError
if self._data and self._receiver_closed:
raise _core.BrokenResourceError
async def wait_send_all_might_not_block(self):
with self._send_conflict_detector:
if self._sender_closed:
raise _core.ClosedResourceError
if self._receiver_closed:
await _core.checkpoint()
return
await self._wait_for(lambda: self._receiver_waiting)
if self._sender_closed:
raise _core.ClosedResourceError
async def receive_some(self, max_bytes=None):
with self._receive_conflict_detector:
# Argument validation
if max_bytes is not None:
max_bytes = operator.index(max_bytes)
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
# State validation
if self._receiver_closed:
raise _core.ClosedResourceError
# Wake wait_send_all_might_not_block and wait for data
self._receiver_waiting = True
self._something_happened()
try:
await self._wait_for(lambda: self._data)
finally:
self._receiver_waiting = False
if self._receiver_closed:
raise _core.ClosedResourceError
# Get data, possibly waking send_all
if self._data:
# Neat trick: if max_bytes is None, then obj[:max_bytes] is
# the same as obj[:].
got = self._data[:max_bytes]
del self._data[:max_bytes]
self._something_happened()
return got
else:
assert self._sender_closed
return b""
class _LockstepSendStream(SendStream):
def __init__(self, lbq):
self._lbq = lbq
def close(self):
self._lbq.close_sender()
async def aclose(self):
self.close()
await _core.checkpoint()
async def send_all(self, data):
await self._lbq.send_all(data)
async def wait_send_all_might_not_block(self):
await self._lbq.wait_send_all_might_not_block()
class _LockstepReceiveStream(ReceiveStream):
def __init__(self, lbq):
self._lbq = lbq
def close(self):
self._lbq.close_receiver()
async def aclose(self):
self.close()
await _core.checkpoint()
async def receive_some(self, max_bytes=None):
return await self._lbq.receive_some(max_bytes)
def lockstep_stream_one_way_pair():
"""Create a connected, pure Python, unidirectional stream where data flows
in lockstep.
Returns:
A tuple
(:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`).
This stream has *absolutely no* buffering. Each call to
:meth:`~trio.abc.SendStream.send_all` will block until all the given data
has been returned by a call to
:meth:`~trio.abc.ReceiveStream.receive_some`.
This can be useful for testing flow control mechanisms in an extreme case,
or for setting up "clogged" streams to use with
:func:`check_one_way_stream` and friends.
In addition to fulfilling the :class:`~trio.abc.SendStream` and
:class:`~trio.abc.ReceiveStream` interfaces, the return objects
also have a synchronous ``close`` method.
"""
lbq = _LockstepByteQueue()
return _LockstepSendStream(lbq), _LockstepReceiveStream(lbq)
def lockstep_stream_pair():
"""Create a connected, pure-Python, bidirectional stream where data flows
in lockstep.
Returns:
A tuple (:class:`~trio.StapledStream`, :class:`~trio.StapledStream`).
This is a convenience function that creates two one-way streams using
:func:`lockstep_stream_one_way_pair`, and then uses
:class:`~trio.StapledStream` to combine them into a single bidirectional
stream.
"""
return _make_stapled_pair(lockstep_stream_one_way_pair)

View File

@@ -0,0 +1,34 @@
from .. import socket as tsocket
from .._highlevel_socket import SocketStream
async def open_stream_to_socket_listener(socket_listener):
"""Connect to the given :class:`~trio.SocketListener`.
This is particularly useful in tests when you want to let a server pick
its own port, and then connect to it::
listeners = await trio.open_tcp_listeners(0)
client = await trio.testing.open_stream_to_socket_listener(listeners[0])
Args:
socket_listener (~trio.SocketListener): The
:class:`~trio.SocketListener` to connect to.
Returns:
SocketStream: a stream connected to the given listener.
"""
family = socket_listener.socket.family
sockaddr = socket_listener.socket.getsockname()
if family in (tsocket.AF_INET, tsocket.AF_INET6):
sockaddr = list(sockaddr)
if sockaddr[0] == "0.0.0.0":
sockaddr[0] = "127.0.0.1"
if sockaddr[0] == "::":
sockaddr[0] = "::1"
sockaddr = tuple(sockaddr)
sock = tsocket.socket(family=family)
await sock.connect(sockaddr)
return SocketStream(sock)

View File

@@ -0,0 +1,82 @@
from collections import defaultdict
import attr
from async_generator import asynccontextmanager
from .. import _core
from .. import _util
from .. import Event
if False:
from typing import DefaultDict, Set
@attr.s(eq=False, hash=False)
class Sequencer(metaclass=_util.Final):
"""A convenience class for forcing code in different tasks to run in an
explicit linear order.
Instances of this class implement a ``__call__`` method which returns an
async context manager. The idea is that you pass a sequence number to
``__call__`` to say where this block of code should go in the linear
sequence. Block 0 starts immediately, and then block N doesn't start until
block N-1 has finished.
Example:
An extremely elaborate way to print the numbers 0-5, in order::
async def worker1(seq):
async with seq(0):
print(0)
async with seq(4):
print(4)
async def worker2(seq):
async with seq(2):
print(2)
async with seq(5):
print(5)
async def worker3(seq):
async with seq(1):
print(1)
async with seq(3):
print(3)
async def main():
seq = trio.testing.Sequencer()
async with trio.open_nursery() as nursery:
nursery.start_soon(worker1, seq)
nursery.start_soon(worker2, seq)
nursery.start_soon(worker3, seq)
"""
_sequence_points = attr.ib(
factory=lambda: defaultdict(Event), init=False
) # type: DefaultDict[int, Event]
_claimed = attr.ib(factory=set, init=False) # type: Set[int]
_broken = attr.ib(default=False, init=False)
@asynccontextmanager
async def __call__(self, position: int):
if position in self._claimed:
raise RuntimeError("Attempted to re-use sequence point {}".format(position))
if self._broken:
raise RuntimeError("sequence broken!")
self._claimed.add(position)
if position != 0:
try:
await self._sequence_points[position].wait()
except _core.Cancelled:
self._broken = True
for event in self._sequence_points.values():
event.set()
raise RuntimeError("Sequencer wait cancelled -- sequence broken")
else:
if self._broken:
raise RuntimeError("sequence broken!")
try:
yield
finally:
self._sequence_points[position + 1].set()

View File

@@ -0,0 +1,29 @@
from functools import wraps, partial
from .. import _core
from ..abc import Clock, Instrument
# Use:
#
# @trio_test
# async def test_whatever():
# await ...
#
# Also: if a pytest fixture is passed in that subclasses the Clock abc, then
# that clock is passed to trio.run().
def trio_test(fn):
@wraps(fn)
def wrapper(**kwargs):
__tracebackhide__ = True
clocks = [c for c in kwargs.values() if isinstance(c, Clock)]
if not clocks:
clock = None
elif len(clocks) == 1:
clock = clocks[0]
else:
raise ValueError("too many clocks spoil the broth!")
instruments = [i for i in kwargs.values() if isinstance(i, Instrument)]
return _core.run(partial(fn, **kwargs), clock=clock, instruments=instruments)
return wrapper

View File

@@ -0,0 +1,41 @@
# XX this does not belong here -- b/c it's here, these things only apply to
# the tests in trio/_core/tests, not in trio/tests. For now there's some
# copy-paste...
#
# this stuff should become a proper pytest plugin
import pytest
import inspect
from ..testing import trio_test, MockClock
RUN_SLOW = True
def pytest_addoption(parser):
parser.addoption("--run-slow", action="store_true", help="run slow tests")
def pytest_configure(config):
global RUN_SLOW
RUN_SLOW = config.getoption("--run-slow", True)
@pytest.fixture
def mock_clock():
return MockClock()
@pytest.fixture
def autojump_clock():
return MockClock(autojump_threshold=0)
# FIXME: split off into a package (or just make part of Trio's public
# interface?), with config file to enable? and I guess a mark option too; I
# guess it's useful with the class- and file-level marking machinery (where
# the raw @trio_test decorator isn't enough).
@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem):
if inspect.iscoroutinefunction(pyfuncitem.obj):
pyfuncitem.obj = trio_test(pyfuncitem.obj)

View File

@@ -0,0 +1,21 @@
regular = "hi"
from .. import _deprecate
_deprecate.enable_attribute_deprecations(__name__)
# Make sure that we don't trigger infinite recursion when accessing module
# attributes in between calling enable_attribute_deprecations and defining
# __deprecated_attributes__:
import sys
this_mod = sys.modules[__name__]
assert this_mod.regular == "hi"
assert not hasattr(this_mod, "dep1")
__deprecated_attributes__ = {
"dep1": _deprecate.DeprecatedAttribute("value1", "1.1", issue=1),
"dep2": _deprecate.DeprecatedAttribute(
"value2", "1.2", issue=1, instead="instead-string"
),
}

View File

@@ -0,0 +1,49 @@
import pytest
import attr
from ..testing import assert_checkpoints
from .. import abc as tabc
async def test_AsyncResource_defaults():
@attr.s
class MyAR(tabc.AsyncResource):
record = attr.ib(factory=list)
async def aclose(self):
self.record.append("ac")
async with MyAR() as myar:
assert isinstance(myar, MyAR)
assert myar.record == []
assert myar.record == ["ac"]
def test_abc_generics():
# Pythons below 3.5.2 had a typing.Generic that would throw
# errors when instantiating or subclassing a parameterized
# version of a class with any __slots__. This is why RunVar
# (which has slots) is not generic. This tests that
# the generic ABCs are fine, because while they are slotted
# they don't actually define any slots.
class SlottedChannel(tabc.SendChannel[tabc.Stream]):
__slots__ = ("x",)
def send_nowait(self, value):
raise RuntimeError
async def send(self, value):
raise RuntimeError # pragma: no cover
def clone(self):
raise RuntimeError # pragma: no cover
async def aclose(self):
pass # pragma: no cover
channel = SlottedChannel()
with pytest.raises(RuntimeError):
channel.send_nowait(None)

View File

@@ -0,0 +1,407 @@
import pytest
from ..testing import wait_all_tasks_blocked, assert_checkpoints
import trio
from trio import open_memory_channel, EndOfChannel
async def test_channel():
with pytest.raises(TypeError):
open_memory_channel(1.0)
with pytest.raises(ValueError):
open_memory_channel(-1)
s, r = open_memory_channel(2)
repr(s) # smoke test
repr(r) # smoke test
s.send_nowait(1)
with assert_checkpoints():
await s.send(2)
with pytest.raises(trio.WouldBlock):
s.send_nowait(None)
with assert_checkpoints():
assert await r.receive() == 1
assert r.receive_nowait() == 2
with pytest.raises(trio.WouldBlock):
r.receive_nowait()
s.send_nowait("last")
await s.aclose()
with pytest.raises(trio.ClosedResourceError):
await s.send("too late")
with pytest.raises(trio.ClosedResourceError):
s.send_nowait("too late")
with pytest.raises(trio.ClosedResourceError):
s.clone()
await s.aclose()
assert r.receive_nowait() == "last"
with pytest.raises(EndOfChannel):
await r.receive()
await r.aclose()
with pytest.raises(trio.ClosedResourceError):
await r.receive()
with pytest.raises(trio.ClosedResourceError):
await r.receive_nowait()
await r.aclose()
async def test_553(autojump_clock):
s, r = open_memory_channel(1)
with trio.move_on_after(10) as timeout_scope:
await r.receive()
assert timeout_scope.cancelled_caught
await s.send("Test for PR #553")
async def test_channel_multiple_producers():
async def producer(send_channel, i):
# We close our handle when we're done with it
async with send_channel:
for j in range(3 * i, 3 * (i + 1)):
await send_channel.send(j)
send_channel, receive_channel = open_memory_channel(0)
async with trio.open_nursery() as nursery:
# We hand out clones to all the new producers, and then close the
# original.
async with send_channel:
for i in range(10):
nursery.start_soon(producer, send_channel.clone(), i)
got = []
async for value in receive_channel:
got.append(value)
got.sort()
assert got == list(range(30))
async def test_channel_multiple_consumers():
successful_receivers = set()
received = []
async def consumer(receive_channel, i):
async for value in receive_channel:
successful_receivers.add(i)
received.append(value)
async with trio.open_nursery() as nursery:
send_channel, receive_channel = trio.open_memory_channel(1)
async with send_channel:
for i in range(5):
nursery.start_soon(consumer, receive_channel, i)
await wait_all_tasks_blocked()
for i in range(10):
await send_channel.send(i)
assert successful_receivers == set(range(5))
assert len(received) == 10
assert set(received) == set(range(10))
async def test_close_basics():
async def send_block(s, expect):
with pytest.raises(expect):
await s.send(None)
# closing send -> other send gets ClosedResourceError
s, r = open_memory_channel(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.ClosedResourceError)
await wait_all_tasks_blocked()
await s.aclose()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
s.send_nowait(None)
with pytest.raises(trio.ClosedResourceError):
await s.send(None)
# and receive gets EndOfChannel
with pytest.raises(EndOfChannel):
r.receive_nowait()
with pytest.raises(EndOfChannel):
await r.receive()
# closing receive -> send gets BrokenResourceError
s, r = open_memory_channel(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.BrokenResourceError)
await wait_all_tasks_blocked()
await r.aclose()
# and it's persistent
with pytest.raises(trio.BrokenResourceError):
s.send_nowait(None)
with pytest.raises(trio.BrokenResourceError):
await s.send(None)
# closing receive -> other receive gets ClosedResourceError
async def receive_block(r):
with pytest.raises(trio.ClosedResourceError):
await r.receive()
s, r = open_memory_channel(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(receive_block, r)
await wait_all_tasks_blocked()
await r.aclose()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
r.receive_nowait()
with pytest.raises(trio.ClosedResourceError):
await r.receive()
async def test_close_sync():
async def send_block(s, expect):
with pytest.raises(expect):
await s.send(None)
# closing send -> other send gets ClosedResourceError
s, r = open_memory_channel(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.ClosedResourceError)
await wait_all_tasks_blocked()
s.close()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
s.send_nowait(None)
with pytest.raises(trio.ClosedResourceError):
await s.send(None)
# and receive gets EndOfChannel
with pytest.raises(EndOfChannel):
r.receive_nowait()
with pytest.raises(EndOfChannel):
await r.receive()
# closing receive -> send gets BrokenResourceError
s, r = open_memory_channel(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.BrokenResourceError)
await wait_all_tasks_blocked()
r.close()
# and it's persistent
with pytest.raises(trio.BrokenResourceError):
s.send_nowait(None)
with pytest.raises(trio.BrokenResourceError):
await s.send(None)
# closing receive -> other receive gets ClosedResourceError
async def receive_block(r):
with pytest.raises(trio.ClosedResourceError):
await r.receive()
s, r = open_memory_channel(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(receive_block, r)
await wait_all_tasks_blocked()
r.close()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
r.receive_nowait()
with pytest.raises(trio.ClosedResourceError):
await r.receive()
async def test_receive_channel_clone_and_close():
s, r = open_memory_channel(10)
r2 = r.clone()
r3 = r.clone()
s.send_nowait(None)
await r.aclose()
with r2:
pass
with pytest.raises(trio.ClosedResourceError):
r.clone()
with pytest.raises(trio.ClosedResourceError):
r2.clone()
# Can still send, r3 is still open
s.send_nowait(None)
await r3.aclose()
# But now the receiver is really closed
with pytest.raises(trio.BrokenResourceError):
s.send_nowait(None)
async def test_close_multiple_send_handles():
# With multiple send handles, closing one handle only wakes senders on
# that handle, but others can continue just fine
s1, r = open_memory_channel(0)
s2 = s1.clone()
async def send_will_close():
with pytest.raises(trio.ClosedResourceError):
await s1.send("nope")
async def send_will_succeed():
await s2.send("ok")
async with trio.open_nursery() as nursery:
nursery.start_soon(send_will_close)
nursery.start_soon(send_will_succeed)
await wait_all_tasks_blocked()
await s1.aclose()
assert await r.receive() == "ok"
async def test_close_multiple_receive_handles():
# With multiple receive handles, closing one handle only wakes receivers on
# that handle, but others can continue just fine
s, r1 = open_memory_channel(0)
r2 = r1.clone()
async def receive_will_close():
with pytest.raises(trio.ClosedResourceError):
await r1.receive()
async def receive_will_succeed():
assert await r2.receive() == "ok"
async with trio.open_nursery() as nursery:
nursery.start_soon(receive_will_close)
nursery.start_soon(receive_will_succeed)
await wait_all_tasks_blocked()
await r1.aclose()
await s.send("ok")
async def test_inf_capacity():
s, r = open_memory_channel(float("inf"))
# It's accepted, and we can send all day without blocking
with s:
for i in range(10):
s.send_nowait(i)
got = []
async for i in r:
got.append(i)
assert got == list(range(10))
async def test_statistics():
s, r = open_memory_channel(2)
assert s.statistics() == r.statistics()
stats = s.statistics()
assert stats.current_buffer_used == 0
assert stats.max_buffer_size == 2
assert stats.open_send_channels == 1
assert stats.open_receive_channels == 1
assert stats.tasks_waiting_send == 0
assert stats.tasks_waiting_receive == 0
s.send_nowait(None)
assert s.statistics().current_buffer_used == 1
s2 = s.clone()
assert s.statistics().open_send_channels == 2
await s.aclose()
assert s2.statistics().open_send_channels == 1
r2 = r.clone()
assert s2.statistics().open_receive_channels == 2
await r2.aclose()
assert s2.statistics().open_receive_channels == 1
async with trio.open_nursery() as nursery:
s2.send_nowait(None) # fill up the buffer
assert s.statistics().current_buffer_used == 2
nursery.start_soon(s2.send, None)
nursery.start_soon(s2.send, None)
await wait_all_tasks_blocked()
assert s.statistics().tasks_waiting_send == 2
nursery.cancel_scope.cancel()
assert s.statistics().tasks_waiting_send == 0
# empty out the buffer again
try:
while True:
r.receive_nowait()
except trio.WouldBlock:
pass
async with trio.open_nursery() as nursery:
nursery.start_soon(r.receive)
await wait_all_tasks_blocked()
assert s.statistics().tasks_waiting_receive == 1
nursery.cancel_scope.cancel()
assert s.statistics().tasks_waiting_receive == 0
async def test_channel_fairness():
# We can remove an item we just sent, and send an item back in after, if
# no-one else is waiting.
s, r = open_memory_channel(1)
s.send_nowait(1)
assert r.receive_nowait() == 1
s.send_nowait(2)
assert r.receive_nowait() == 2
# But if someone else is waiting to receive, then they "own" the item we
# send, so we can't receive it (even though we run first):
result = None
async def do_receive(r):
nonlocal result
result = await r.receive()
async with trio.open_nursery() as nursery:
nursery.start_soon(do_receive, r)
await wait_all_tasks_blocked()
s.send_nowait(2)
with pytest.raises(trio.WouldBlock):
r.receive_nowait()
assert result == 2
# And the analogous situation for send: if we free up a space, we can't
# immediately send something in it if someone is already waiting to do
# that
s, r = open_memory_channel(1)
s.send_nowait(1)
with pytest.raises(trio.WouldBlock):
s.send_nowait(None)
async with trio.open_nursery() as nursery:
nursery.start_soon(s.send, 2)
await wait_all_tasks_blocked()
assert r.receive_nowait() == 1
with pytest.raises(trio.WouldBlock):
s.send_nowait(3)
assert (await r.receive()) == 2
async def test_unbuffered():
s, r = open_memory_channel(0)
with pytest.raises(trio.WouldBlock):
r.receive_nowait()
with pytest.raises(trio.WouldBlock):
s.send_nowait(1)
async def do_send(s, v):
with assert_checkpoints():
await s.send(v)
async with trio.open_nursery() as nursery:
nursery.start_soon(do_send, s, 1)
with assert_checkpoints():
assert await r.receive() == 1
with pytest.raises(trio.WouldBlock):
r.receive_nowait()

View File

@@ -0,0 +1,52 @@
import contextvars
from .. import _core
trio_testing_contextvar = contextvars.ContextVar("trio_testing_contextvar")
async def test_contextvars_default():
trio_testing_contextvar.set("main")
record = []
async def child():
value = trio_testing_contextvar.get()
record.append(value)
async with _core.open_nursery() as nursery:
nursery.start_soon(child)
assert record == ["main"]
async def test_contextvars_set():
trio_testing_contextvar.set("main")
record = []
async def child():
trio_testing_contextvar.set("child")
value = trio_testing_contextvar.get()
record.append(value)
async with _core.open_nursery() as nursery:
nursery.start_soon(child)
value = trio_testing_contextvar.get()
assert record == ["child"]
assert value == "main"
async def test_contextvars_copy():
trio_testing_contextvar.set("main")
context = contextvars.copy_context()
trio_testing_contextvar.set("second_main")
record = []
async def child():
value = trio_testing_contextvar.get()
record.append(value)
async with _core.open_nursery() as nursery:
context.run(nursery.start_soon, child)
nursery.start_soon(child)
value = trio_testing_contextvar.get()
assert set(record) == {"main", "second_main"}
assert value == "second_main"

View File

@@ -0,0 +1,243 @@
import pytest
import inspect
import warnings
from .._deprecate import (
TrioDeprecationWarning,
warn_deprecated,
deprecated,
deprecated_alias,
)
from . import module_with_deprecations
@pytest.fixture
def recwarn_always(recwarn):
warnings.simplefilter("always")
# ResourceWarnings about unclosed sockets can occur nondeterministically
# (during GC) which throws off the tests in this file
warnings.simplefilter("ignore", ResourceWarning)
return recwarn
def _here():
info = inspect.getframeinfo(inspect.currentframe().f_back)
return (info.filename, info.lineno)
def test_warn_deprecated(recwarn_always):
def deprecated_thing():
warn_deprecated("ice", "1.2", issue=1, instead="water")
deprecated_thing()
filename, lineno = _here()
assert len(recwarn_always) == 1
got = recwarn_always.pop(TrioDeprecationWarning)
assert "ice is deprecated" in got.message.args[0]
assert "Trio 1.2" in got.message.args[0]
assert "water instead" in got.message.args[0]
assert "/issues/1" in got.message.args[0]
assert got.filename == filename
assert got.lineno == lineno - 1
def test_warn_deprecated_no_instead_or_issue(recwarn_always):
# Explicitly no instead or issue
warn_deprecated("water", "1.3", issue=None, instead=None)
assert len(recwarn_always) == 1
got = recwarn_always.pop(TrioDeprecationWarning)
assert "water is deprecated" in got.message.args[0]
assert "no replacement" in got.message.args[0]
assert "Trio 1.3" in got.message.args[0]
def test_warn_deprecated_stacklevel(recwarn_always):
def nested1():
nested2()
def nested2():
warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3)
filename, lineno = _here()
nested1()
got = recwarn_always.pop(TrioDeprecationWarning)
assert got.filename == filename
assert got.lineno == lineno + 1
def old(): # pragma: no cover
pass
def new(): # pragma: no cover
pass
def test_warn_deprecated_formatting(recwarn_always):
warn_deprecated(old, "1.0", issue=1, instead=new)
got = recwarn_always.pop(TrioDeprecationWarning)
assert "test_deprecate.old is deprecated" in got.message.args[0]
assert "test_deprecate.new instead" in got.message.args[0]
@deprecated("1.5", issue=123, instead=new)
def deprecated_old():
return 3
def test_deprecated_decorator(recwarn_always):
assert deprecated_old() == 3
got = recwarn_always.pop(TrioDeprecationWarning)
assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0]
assert "1.5" in got.message.args[0]
assert "test_deprecate.new" in got.message.args[0]
assert "issues/123" in got.message.args[0]
class Foo:
@deprecated("1.0", issue=123, instead="crying")
def method(self):
return 7
def test_deprecated_decorator_method(recwarn_always):
f = Foo()
assert f.method() == 7
got = recwarn_always.pop(TrioDeprecationWarning)
assert "test_deprecate.Foo.method is deprecated" in got.message.args[0]
@deprecated("1.2", thing="the thing", issue=None, instead=None)
def deprecated_with_thing():
return 72
def test_deprecated_decorator_with_explicit_thing(recwarn_always):
assert deprecated_with_thing() == 72
got = recwarn_always.pop(TrioDeprecationWarning)
assert "the thing is deprecated" in got.message.args[0]
def new_hotness():
return "new hotness"
old_hotness = deprecated_alias("old_hotness", new_hotness, "1.23", issue=1)
def test_deprecated_alias(recwarn_always):
assert old_hotness() == "new hotness"
got = recwarn_always.pop(TrioDeprecationWarning)
assert "test_deprecate.old_hotness is deprecated" in got.message.args[0]
assert "1.23" in got.message.args[0]
assert "test_deprecate.new_hotness instead" in got.message.args[0]
assert "issues/1" in got.message.args[0]
assert ".. deprecated:: 1.23" in old_hotness.__doc__
assert "test_deprecate.new_hotness instead" in old_hotness.__doc__
assert "issues/1>`__" in old_hotness.__doc__
class Alias:
def new_hotness_method(self):
return "new hotness method"
old_hotness_method = deprecated_alias(
"Alias.old_hotness_method", new_hotness_method, "3.21", issue=1
)
def test_deprecated_alias_method(recwarn_always):
obj = Alias()
assert obj.old_hotness_method() == "new hotness method"
got = recwarn_always.pop(TrioDeprecationWarning)
msg = got.message.args[0]
assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg
assert "test_deprecate.Alias.new_hotness_method instead" in msg
@deprecated("2.1", issue=1, instead="hi")
def docstring_test1(): # pragma: no cover
"""Hello!"""
@deprecated("2.1", issue=None, instead="hi")
def docstring_test2(): # pragma: no cover
"""Hello!"""
@deprecated("2.1", issue=1, instead=None)
def docstring_test3(): # pragma: no cover
"""Hello!"""
@deprecated("2.1", issue=None, instead=None)
def docstring_test4(): # pragma: no cover
"""Hello!"""
def test_deprecated_docstring_munging():
assert (
docstring_test1.__doc__
== """Hello!
.. deprecated:: 2.1
Use hi instead.
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
"""
)
assert (
docstring_test2.__doc__
== """Hello!
.. deprecated:: 2.1
Use hi instead.
"""
)
assert (
docstring_test3.__doc__
== """Hello!
.. deprecated:: 2.1
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
"""
)
assert (
docstring_test4.__doc__
== """Hello!
.. deprecated:: 2.1
"""
)
def test_module_with_deprecations(recwarn_always):
assert module_with_deprecations.regular == "hi"
assert len(recwarn_always) == 0
filename, lineno = _here()
assert module_with_deprecations.dep1 == "value1"
got = recwarn_always.pop(TrioDeprecationWarning)
assert got.filename == filename
assert got.lineno == lineno + 1
assert "module_with_deprecations.dep1" in got.message.args[0]
assert "Trio 1.1" in got.message.args[0]
assert "/issues/1" in got.message.args[0]
assert "value1 instead" in got.message.args[0]
assert module_with_deprecations.dep2 == "value2"
got = recwarn_always.pop(TrioDeprecationWarning)
assert "instead-string instead" in got.message.args[0]
with pytest.raises(AttributeError):
module_with_deprecations.asdf

Some files were not shown because too many files have changed in this diff Show More