second commit
0
venv/Lib/site-packages/trio/tests/__init__.py
Normal file
41
venv/Lib/site-packages/trio/tests/conftest.py
Normal file
@@ -0,0 +1,41 @@
# XX this does not belong here -- b/c it's here, these things only apply to
# the tests in trio/_core/tests, not in trio/tests. For now there's some
# copy-paste...
#
# this stuff should become a proper pytest plugin

import pytest
import inspect

from ..testing import trio_test, MockClock

RUN_SLOW = True


def pytest_addoption(parser):
    parser.addoption("--run-slow", action="store_true", help="run slow tests")


def pytest_configure(config):
    global RUN_SLOW
    RUN_SLOW = config.getoption("--run-slow", True)


@pytest.fixture
def mock_clock():
    return MockClock()


@pytest.fixture
def autojump_clock():
    return MockClock(autojump_threshold=0)


# FIXME: split off into a package (or just make part of Trio's public
# interface?), with config file to enable? and I guess a mark option too; I
# guess it's useful with the class- and file-level marking machinery (where
# the raw @trio_test decorator isn't enough).
@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem):
    if inspect.iscoroutinefunction(pyfuncitem.obj):
        pyfuncitem.obj = trio_test(pyfuncitem.obj)
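For orientation, a minimal sketch (not part of the committed file) of how a test might use the autojump_clock fixture defined above: with autojump_threshold=0 the MockClock jumps virtual time forward as soon as every task is blocked, so even very long sleeps finish almost instantly in real time.

import trio

async def test_long_sleep_finishes_quickly(autojump_clock):
    start = trio.current_time()
    # Virtual time: the clock auto-jumps instead of waiting a real hour.
    await trio.sleep(3600)
    assert trio.current_time() - start >= 3600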
21
venv/Lib/site-packages/trio/tests/module_with_deprecations.py
Normal file
@@ -0,0 +1,21 @@
regular = "hi"

from .. import _deprecate

_deprecate.enable_attribute_deprecations(__name__)

# Make sure that we don't trigger infinite recursion when accessing module
# attributes in between calling enable_attribute_deprecations and defining
# __deprecated_attributes__:
import sys

this_mod = sys.modules[__name__]
assert this_mod.regular == "hi"
assert not hasattr(this_mod, "dep1")

__deprecated_attributes__ = {
    "dep1": _deprecate.DeprecatedAttribute("value1", "1.1", issue=1),
    "dep2": _deprecate.DeprecatedAttribute(
        "value2", "1.2", issue=1, instead="instead-string"
    ),
}
49
venv/Lib/site-packages/trio/tests/test_abc.py
Normal file
@@ -0,0 +1,49 @@
import pytest

import attr

from ..testing import assert_checkpoints
from .. import abc as tabc


async def test_AsyncResource_defaults():
    @attr.s
    class MyAR(tabc.AsyncResource):
        record = attr.ib(factory=list)

        async def aclose(self):
            self.record.append("ac")

    async with MyAR() as myar:
        assert isinstance(myar, MyAR)
        assert myar.record == []

    assert myar.record == ["ac"]


def test_abc_generics():
    # Pythons below 3.5.2 had a typing.Generic that would throw
    # errors when instantiating or subclassing a parameterized
    # version of a class with any __slots__. This is why RunVar
    # (which has slots) is not generic. This tests that
    # the generic ABCs are fine, because while they are slotted
    # they don't actually define any slots.

    class SlottedChannel(tabc.SendChannel[tabc.Stream]):
        __slots__ = ("x",)

        def send_nowait(self, value):
            raise RuntimeError

        async def send(self, value):
            raise RuntimeError  # pragma: no cover

        def clone(self):
            raise RuntimeError  # pragma: no cover

        async def aclose(self):
            pass  # pragma: no cover

    channel = SlottedChannel()
    with pytest.raises(RuntimeError):
        channel.send_nowait(None)
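As a quick illustration of the pattern exercised by test_AsyncResource_defaults above (the Connection class here is hypothetical, not from the commit): trio.abc.AsyncResource only requires an aclose() method, and async-with support then comes for free.

import trio

class Connection(trio.abc.AsyncResource):
    async def aclose(self):
        # Release whatever the resource holds; a checkpoint keeps it cooperative.
        await trio.sleep(0)

async def main():
    async with Connection():
        pass  # aclose() runs automatically when the block exits

trio.run(main)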
407
venv/Lib/site-packages/trio/tests/test_channel.py
Normal file
@@ -0,0 +1,407 @@
import pytest

from ..testing import wait_all_tasks_blocked, assert_checkpoints
import trio
from trio import open_memory_channel, EndOfChannel


async def test_channel():
    with pytest.raises(TypeError):
        open_memory_channel(1.0)
    with pytest.raises(ValueError):
        open_memory_channel(-1)

    s, r = open_memory_channel(2)
    repr(s)  # smoke test
    repr(r)  # smoke test

    s.send_nowait(1)
    with assert_checkpoints():
        await s.send(2)
    with pytest.raises(trio.WouldBlock):
        s.send_nowait(None)

    with assert_checkpoints():
        assert await r.receive() == 1
    assert r.receive_nowait() == 2
    with pytest.raises(trio.WouldBlock):
        r.receive_nowait()

    s.send_nowait("last")
    await s.aclose()
    with pytest.raises(trio.ClosedResourceError):
        await s.send("too late")
    with pytest.raises(trio.ClosedResourceError):
        s.send_nowait("too late")
    with pytest.raises(trio.ClosedResourceError):
        s.clone()
    await s.aclose()

    assert r.receive_nowait() == "last"
    with pytest.raises(EndOfChannel):
        await r.receive()
    await r.aclose()
    with pytest.raises(trio.ClosedResourceError):
        await r.receive()
    with pytest.raises(trio.ClosedResourceError):
        await r.receive_nowait()
    await r.aclose()


async def test_553(autojump_clock):
    s, r = open_memory_channel(1)
    with trio.move_on_after(10) as timeout_scope:
        await r.receive()
    assert timeout_scope.cancelled_caught
    await s.send("Test for PR #553")


async def test_channel_multiple_producers():
    async def producer(send_channel, i):
        # We close our handle when we're done with it
        async with send_channel:
            for j in range(3 * i, 3 * (i + 1)):
                await send_channel.send(j)

    send_channel, receive_channel = open_memory_channel(0)
    async with trio.open_nursery() as nursery:
        # We hand out clones to all the new producers, and then close the
        # original.
        async with send_channel:
            for i in range(10):
                nursery.start_soon(producer, send_channel.clone(), i)

        got = []
        async for value in receive_channel:
            got.append(value)

        got.sort()
        assert got == list(range(30))


async def test_channel_multiple_consumers():
    successful_receivers = set()
    received = []

    async def consumer(receive_channel, i):
        async for value in receive_channel:
            successful_receivers.add(i)
            received.append(value)

    async with trio.open_nursery() as nursery:
        send_channel, receive_channel = trio.open_memory_channel(1)
        async with send_channel:
            for i in range(5):
                nursery.start_soon(consumer, receive_channel, i)
            await wait_all_tasks_blocked()
            for i in range(10):
                await send_channel.send(i)

    assert successful_receivers == set(range(5))
    assert len(received) == 10
    assert set(received) == set(range(10))


async def test_close_basics():
    async def send_block(s, expect):
        with pytest.raises(expect):
            await s.send(None)

    # closing send -> other send gets ClosedResourceError
    s, r = open_memory_channel(0)
    async with trio.open_nursery() as nursery:
        nursery.start_soon(send_block, s, trio.ClosedResourceError)
        await wait_all_tasks_blocked()
        await s.aclose()

    # and it's persistent
    with pytest.raises(trio.ClosedResourceError):
        s.send_nowait(None)
    with pytest.raises(trio.ClosedResourceError):
        await s.send(None)

    # and receive gets EndOfChannel
    with pytest.raises(EndOfChannel):
        r.receive_nowait()
    with pytest.raises(EndOfChannel):
        await r.receive()

    # closing receive -> send gets BrokenResourceError
    s, r = open_memory_channel(0)
    async with trio.open_nursery() as nursery:
        nursery.start_soon(send_block, s, trio.BrokenResourceError)
        await wait_all_tasks_blocked()
        await r.aclose()

    # and it's persistent
    with pytest.raises(trio.BrokenResourceError):
        s.send_nowait(None)
    with pytest.raises(trio.BrokenResourceError):
        await s.send(None)

    # closing receive -> other receive gets ClosedResourceError
    async def receive_block(r):
        with pytest.raises(trio.ClosedResourceError):
            await r.receive()

    s, r = open_memory_channel(0)
    async with trio.open_nursery() as nursery:
        nursery.start_soon(receive_block, r)
        await wait_all_tasks_blocked()
        await r.aclose()

    # and it's persistent
    with pytest.raises(trio.ClosedResourceError):
        r.receive_nowait()
    with pytest.raises(trio.ClosedResourceError):
        await r.receive()


async def test_close_sync():
    async def send_block(s, expect):
        with pytest.raises(expect):
            await s.send(None)

    # closing send -> other send gets ClosedResourceError
    s, r = open_memory_channel(0)
    async with trio.open_nursery() as nursery:
        nursery.start_soon(send_block, s, trio.ClosedResourceError)
        await wait_all_tasks_blocked()
        s.close()

    # and it's persistent
    with pytest.raises(trio.ClosedResourceError):
        s.send_nowait(None)
    with pytest.raises(trio.ClosedResourceError):
        await s.send(None)

    # and receive gets EndOfChannel
    with pytest.raises(EndOfChannel):
        r.receive_nowait()
    with pytest.raises(EndOfChannel):
        await r.receive()

    # closing receive -> send gets BrokenResourceError
    s, r = open_memory_channel(0)
    async with trio.open_nursery() as nursery:
        nursery.start_soon(send_block, s, trio.BrokenResourceError)
        await wait_all_tasks_blocked()
        r.close()

    # and it's persistent
    with pytest.raises(trio.BrokenResourceError):
        s.send_nowait(None)
    with pytest.raises(trio.BrokenResourceError):
        await s.send(None)

    # closing receive -> other receive gets ClosedResourceError
    async def receive_block(r):
        with pytest.raises(trio.ClosedResourceError):
            await r.receive()

    s, r = open_memory_channel(0)
    async with trio.open_nursery() as nursery:
        nursery.start_soon(receive_block, r)
        await wait_all_tasks_blocked()
        r.close()

    # and it's persistent
    with pytest.raises(trio.ClosedResourceError):
        r.receive_nowait()
    with pytest.raises(trio.ClosedResourceError):
        await r.receive()


async def test_receive_channel_clone_and_close():
    s, r = open_memory_channel(10)

    r2 = r.clone()
    r3 = r.clone()

    s.send_nowait(None)
    await r.aclose()
    with r2:
        pass

    with pytest.raises(trio.ClosedResourceError):
        r.clone()

    with pytest.raises(trio.ClosedResourceError):
        r2.clone()

    # Can still send, r3 is still open
    s.send_nowait(None)

    await r3.aclose()

    # But now the receiver is really closed
    with pytest.raises(trio.BrokenResourceError):
        s.send_nowait(None)


async def test_close_multiple_send_handles():
    # With multiple send handles, closing one handle only wakes senders on
    # that handle, but others can continue just fine
    s1, r = open_memory_channel(0)
    s2 = s1.clone()

    async def send_will_close():
        with pytest.raises(trio.ClosedResourceError):
            await s1.send("nope")

    async def send_will_succeed():
        await s2.send("ok")

    async with trio.open_nursery() as nursery:
        nursery.start_soon(send_will_close)
        nursery.start_soon(send_will_succeed)
        await wait_all_tasks_blocked()
        await s1.aclose()
        assert await r.receive() == "ok"


async def test_close_multiple_receive_handles():
    # With multiple receive handles, closing one handle only wakes receivers on
    # that handle, but others can continue just fine
    s, r1 = open_memory_channel(0)
    r2 = r1.clone()

    async def receive_will_close():
        with pytest.raises(trio.ClosedResourceError):
            await r1.receive()

    async def receive_will_succeed():
        assert await r2.receive() == "ok"

    async with trio.open_nursery() as nursery:
        nursery.start_soon(receive_will_close)
        nursery.start_soon(receive_will_succeed)
        await wait_all_tasks_blocked()
        await r1.aclose()
        await s.send("ok")


async def test_inf_capacity():
    s, r = open_memory_channel(float("inf"))

    # It's accepted, and we can send all day without blocking
    with s:
        for i in range(10):
            s.send_nowait(i)

    got = []
    async for i in r:
        got.append(i)
    assert got == list(range(10))


async def test_statistics():
    s, r = open_memory_channel(2)

    assert s.statistics() == r.statistics()
    stats = s.statistics()
    assert stats.current_buffer_used == 0
    assert stats.max_buffer_size == 2
    assert stats.open_send_channels == 1
    assert stats.open_receive_channels == 1
    assert stats.tasks_waiting_send == 0
    assert stats.tasks_waiting_receive == 0

    s.send_nowait(None)
    assert s.statistics().current_buffer_used == 1

    s2 = s.clone()
    assert s.statistics().open_send_channels == 2
    await s.aclose()
    assert s2.statistics().open_send_channels == 1

    r2 = r.clone()
    assert s2.statistics().open_receive_channels == 2
    await r2.aclose()
    assert s2.statistics().open_receive_channels == 1

    async with trio.open_nursery() as nursery:
        s2.send_nowait(None)  # fill up the buffer
        assert s.statistics().current_buffer_used == 2
        nursery.start_soon(s2.send, None)
        nursery.start_soon(s2.send, None)
        await wait_all_tasks_blocked()
        assert s.statistics().tasks_waiting_send == 2
        nursery.cancel_scope.cancel()
    assert s.statistics().tasks_waiting_send == 0

    # empty out the buffer again
    try:
        while True:
            r.receive_nowait()
    except trio.WouldBlock:
        pass

    async with trio.open_nursery() as nursery:
        nursery.start_soon(r.receive)
        await wait_all_tasks_blocked()
        assert s.statistics().tasks_waiting_receive == 1
        nursery.cancel_scope.cancel()
    assert s.statistics().tasks_waiting_receive == 0


async def test_channel_fairness():

    # We can remove an item we just sent, and send an item back in after, if
    # no-one else is waiting.
    s, r = open_memory_channel(1)
    s.send_nowait(1)
    assert r.receive_nowait() == 1
    s.send_nowait(2)
    assert r.receive_nowait() == 2

    # But if someone else is waiting to receive, then they "own" the item we
    # send, so we can't receive it (even though we run first):

    result = None

    async def do_receive(r):
        nonlocal result
        result = await r.receive()

    async with trio.open_nursery() as nursery:
        nursery.start_soon(do_receive, r)
        await wait_all_tasks_blocked()
        s.send_nowait(2)
        with pytest.raises(trio.WouldBlock):
            r.receive_nowait()
    assert result == 2

    # And the analogous situation for send: if we free up a space, we can't
    # immediately send something in it if someone is already waiting to do
    # that
    s, r = open_memory_channel(1)
    s.send_nowait(1)
    with pytest.raises(trio.WouldBlock):
        s.send_nowait(None)
    async with trio.open_nursery() as nursery:
        nursery.start_soon(s.send, 2)
        await wait_all_tasks_blocked()
        assert r.receive_nowait() == 1
        with pytest.raises(trio.WouldBlock):
            s.send_nowait(3)
        assert (await r.receive()) == 2


async def test_unbuffered():
    s, r = open_memory_channel(0)
    with pytest.raises(trio.WouldBlock):
        r.receive_nowait()
    with pytest.raises(trio.WouldBlock):
        s.send_nowait(1)

    async def do_send(s, v):
        with assert_checkpoints():
            await s.send(v)

    async with trio.open_nursery() as nursery:
        nursery.start_soon(do_send, s, 1)
        with assert_checkpoints():
            assert await r.receive() == 1
    with pytest.raises(trio.WouldBlock):
        r.receive_nowait()
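The clone-and-close ownership pattern that test_channel_multiple_producers exercises can be summarized in a short, illustrative sketch (not part of the commit):

import trio

async def main():
    send_channel, receive_channel = trio.open_memory_channel(0)

    async def producer(handle, label):
        async with handle:  # closing our clone signals that this producer is done
            for i in range(3):
                await handle.send((label, i))

    async with trio.open_nursery() as nursery:
        async with send_channel:  # close the original once all clones are handed out
            for label in ("a", "b"):
                nursery.start_soon(producer, send_channel.clone(), label)
        async with receive_channel:
            async for item in receive_channel:  # ends when every send handle is closed
                print(item)

trio.run(main)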
52
venv/Lib/site-packages/trio/tests/test_contextvars.py
Normal file
@@ -0,0 +1,52 @@
import contextvars

from .. import _core

trio_testing_contextvar = contextvars.ContextVar("trio_testing_contextvar")


async def test_contextvars_default():
    trio_testing_contextvar.set("main")
    record = []

    async def child():
        value = trio_testing_contextvar.get()
        record.append(value)

    async with _core.open_nursery() as nursery:
        nursery.start_soon(child)
    assert record == ["main"]


async def test_contextvars_set():
    trio_testing_contextvar.set("main")
    record = []

    async def child():
        trio_testing_contextvar.set("child")
        value = trio_testing_contextvar.get()
        record.append(value)

    async with _core.open_nursery() as nursery:
        nursery.start_soon(child)
    value = trio_testing_contextvar.get()
    assert record == ["child"]
    assert value == "main"


async def test_contextvars_copy():
    trio_testing_contextvar.set("main")
    context = contextvars.copy_context()
    trio_testing_contextvar.set("second_main")
    record = []

    async def child():
        value = trio_testing_contextvar.get()
        record.append(value)

    async with _core.open_nursery() as nursery:
        context.run(nursery.start_soon, child)
        nursery.start_soon(child)
    value = trio_testing_contextvar.get()
    assert set(record) == {"main", "second_main"}
    assert value == "second_main"
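In short (an illustrative sketch, not from the commit): a task started with nursery.start_soon sees a copy of the context that was current when it was spawned, and values it sets do not leak back into the parent.

import contextvars
import trio

request_id = contextvars.ContextVar("request_id", default="unset")

async def child():
    print("child sees:", request_id.get())  # "req-1"
    request_id.set("child-only")            # stays local to this task

async def main():
    request_id.set("req-1")
    async with trio.open_nursery() as nursery:
        nursery.start_soon(child)
    print("parent still sees:", request_id.get())  # "req-1"

trio.run(main)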
243
venv/Lib/site-packages/trio/tests/test_deprecate.py
Normal file
@@ -0,0 +1,243 @@
import pytest

import inspect
import warnings

from .._deprecate import (
    TrioDeprecationWarning,
    warn_deprecated,
    deprecated,
    deprecated_alias,
)

from . import module_with_deprecations


@pytest.fixture
def recwarn_always(recwarn):
    warnings.simplefilter("always")
    # ResourceWarnings about unclosed sockets can occur nondeterministically
    # (during GC) which throws off the tests in this file
    warnings.simplefilter("ignore", ResourceWarning)
    return recwarn


def _here():
    info = inspect.getframeinfo(inspect.currentframe().f_back)
    return (info.filename, info.lineno)


def test_warn_deprecated(recwarn_always):
    def deprecated_thing():
        warn_deprecated("ice", "1.2", issue=1, instead="water")

    deprecated_thing()
    filename, lineno = _here()
    assert len(recwarn_always) == 1
    got = recwarn_always.pop(TrioDeprecationWarning)
    assert "ice is deprecated" in got.message.args[0]
    assert "Trio 1.2" in got.message.args[0]
    assert "water instead" in got.message.args[0]
    assert "/issues/1" in got.message.args[0]
    assert got.filename == filename
    assert got.lineno == lineno - 1


def test_warn_deprecated_no_instead_or_issue(recwarn_always):
    # Explicitly no instead or issue
    warn_deprecated("water", "1.3", issue=None, instead=None)
    assert len(recwarn_always) == 1
    got = recwarn_always.pop(TrioDeprecationWarning)
    assert "water is deprecated" in got.message.args[0]
    assert "no replacement" in got.message.args[0]
    assert "Trio 1.3" in got.message.args[0]


def test_warn_deprecated_stacklevel(recwarn_always):
    def nested1():
        nested2()

    def nested2():
        warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3)

    filename, lineno = _here()
    nested1()
    got = recwarn_always.pop(TrioDeprecationWarning)
    assert got.filename == filename
    assert got.lineno == lineno + 1


def old():  # pragma: no cover
    pass


def new():  # pragma: no cover
    pass


def test_warn_deprecated_formatting(recwarn_always):
    warn_deprecated(old, "1.0", issue=1, instead=new)
    got = recwarn_always.pop(TrioDeprecationWarning)
    assert "test_deprecate.old is deprecated" in got.message.args[0]
    assert "test_deprecate.new instead" in got.message.args[0]


@deprecated("1.5", issue=123, instead=new)
def deprecated_old():
    return 3


def test_deprecated_decorator(recwarn_always):
    assert deprecated_old() == 3
    got = recwarn_always.pop(TrioDeprecationWarning)
    assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0]
    assert "1.5" in got.message.args[0]
    assert "test_deprecate.new" in got.message.args[0]
    assert "issues/123" in got.message.args[0]


class Foo:
    @deprecated("1.0", issue=123, instead="crying")
    def method(self):
        return 7


def test_deprecated_decorator_method(recwarn_always):
    f = Foo()
    assert f.method() == 7
    got = recwarn_always.pop(TrioDeprecationWarning)
    assert "test_deprecate.Foo.method is deprecated" in got.message.args[0]


@deprecated("1.2", thing="the thing", issue=None, instead=None)
def deprecated_with_thing():
    return 72


def test_deprecated_decorator_with_explicit_thing(recwarn_always):
    assert deprecated_with_thing() == 72
    got = recwarn_always.pop(TrioDeprecationWarning)
    assert "the thing is deprecated" in got.message.args[0]


def new_hotness():
    return "new hotness"


old_hotness = deprecated_alias("old_hotness", new_hotness, "1.23", issue=1)


def test_deprecated_alias(recwarn_always):
    assert old_hotness() == "new hotness"
    got = recwarn_always.pop(TrioDeprecationWarning)
    assert "test_deprecate.old_hotness is deprecated" in got.message.args[0]
    assert "1.23" in got.message.args[0]
    assert "test_deprecate.new_hotness instead" in got.message.args[0]
    assert "issues/1" in got.message.args[0]

    assert ".. deprecated:: 1.23" in old_hotness.__doc__
    assert "test_deprecate.new_hotness instead" in old_hotness.__doc__
    assert "issues/1>`__" in old_hotness.__doc__


class Alias:
    def new_hotness_method(self):
        return "new hotness method"

    old_hotness_method = deprecated_alias(
        "Alias.old_hotness_method", new_hotness_method, "3.21", issue=1
    )


def test_deprecated_alias_method(recwarn_always):
    obj = Alias()
    assert obj.old_hotness_method() == "new hotness method"
    got = recwarn_always.pop(TrioDeprecationWarning)
    msg = got.message.args[0]
    assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg
    assert "test_deprecate.Alias.new_hotness_method instead" in msg


@deprecated("2.1", issue=1, instead="hi")
def docstring_test1():  # pragma: no cover
    """Hello!"""


@deprecated("2.1", issue=None, instead="hi")
def docstring_test2():  # pragma: no cover
    """Hello!"""


@deprecated("2.1", issue=1, instead=None)
def docstring_test3():  # pragma: no cover
    """Hello!"""


@deprecated("2.1", issue=None, instead=None)
def docstring_test4():  # pragma: no cover
    """Hello!"""


def test_deprecated_docstring_munging():
    assert (
        docstring_test1.__doc__
        == """Hello!

.. deprecated:: 2.1
   Use hi instead.
   For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.

"""
    )

    assert (
        docstring_test2.__doc__
        == """Hello!

.. deprecated:: 2.1
   Use hi instead.

"""
    )

    assert (
        docstring_test3.__doc__
        == """Hello!

.. deprecated:: 2.1
   For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.

"""
    )

    assert (
        docstring_test4.__doc__
        == """Hello!

.. deprecated:: 2.1

"""
    )


def test_module_with_deprecations(recwarn_always):
    assert module_with_deprecations.regular == "hi"
    assert len(recwarn_always) == 0

    filename, lineno = _here()
    assert module_with_deprecations.dep1 == "value1"
    got = recwarn_always.pop(TrioDeprecationWarning)
    assert got.filename == filename
    assert got.lineno == lineno + 1

    assert "module_with_deprecations.dep1" in got.message.args[0]
    assert "Trio 1.1" in got.message.args[0]
    assert "/issues/1" in got.message.args[0]
    assert "value1 instead" in got.message.args[0]

    assert module_with_deprecations.dep2 == "value2"
    got = recwarn_always.pop(TrioDeprecationWarning)
    assert "instead-string instead" in got.message.args[0]

    with pytest.raises(AttributeError):
        module_with_deprecations.asdf
156
venv/Lib/site-packages/trio/tests/test_exports.py
Normal file
@@ -0,0 +1,156 @@
import re
import sys
import importlib
import types
import inspect
import enum

import pytest

import trio
import trio.testing

from .. import _core
from .. import _util


def test_core_is_properly_reexported():
    # Each export from _core should be re-exported by exactly one of these
    # three modules:
    sources = [trio, trio.lowlevel, trio.testing]
    for symbol in dir(_core):
        if symbol.startswith("_") or symbol == "tests":
            continue
        found = 0
        for source in sources:
            if symbol in dir(source) and getattr(source, symbol) is getattr(
                _core, symbol
            ):
                found += 1
        print(symbol, found)
        assert found == 1


def public_modules(module):
    yield module
    for name, class_ in module.__dict__.items():
        if name.startswith("_"):  # pragma: no cover
            continue
        if not isinstance(class_, types.ModuleType):
            continue
        if not class_.__name__.startswith(module.__name__):  # pragma: no cover
            continue
        if class_ is module:
            continue
        # We should rename the trio.tests module (#274), but until then we use
        # a special-case hack:
        if class_.__name__ == "trio.tests":
            continue
        yield from public_modules(class_)


PUBLIC_MODULES = list(public_modules(trio))
PUBLIC_MODULE_NAMES = [m.__name__ for m in PUBLIC_MODULES]


# It doesn't make sense for downstream redistributors to run this test, since
# they might be using a newer version of Python with additional symbols which
# won't be reflected in trio.socket, and this shouldn't cause downstream test
# runs to start failing.
@pytest.mark.redistributors_should_skip
# pylint/jedi often have trouble with alpha releases, where Python's internals
# are in flux, grammar may not have settled down, etc.
@pytest.mark.skipif(
    sys.version_info.releaselevel == "alpha",
    reason="skip static introspection tools on Python dev/alpha releases",
)
@pytest.mark.filterwarnings(
    # https://github.com/PyCQA/astroid/issues/681
    "ignore:the imp module is deprecated.*:DeprecationWarning"
)
@pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES)
@pytest.mark.parametrize("tool", ["pylint", "jedi"])
@pytest.mark.filterwarnings(
    "ignore:"
    + re.escape(
        "The distutils package is deprecated and slated for removal in Python 3.12. "
        "Use setuptools or check PEP 632 for potential alternatives"
    )
    + ":DeprecationWarning",
    "ignore:"
    + re.escape("The distutils.sysconfig module is deprecated, use sysconfig instead")
    + ":DeprecationWarning",
)
def test_static_tool_sees_all_symbols(tool, modname):
    module = importlib.import_module(modname)

    def no_underscores(symbols):
        return {symbol for symbol in symbols if not symbol.startswith("_")}

    runtime_names = no_underscores(dir(module))

    # We should rename the trio.tests module (#274), but until then we use a
    # special-case hack:
    if modname == "trio":
        runtime_names.remove("tests")

    if tool == "pylint":
        from pylint.lint import PyLinter

        linter = PyLinter()
        ast = linter.get_ast(module.__file__, modname)
        static_names = no_underscores(ast)
    elif tool == "jedi":
        import jedi

        # Simulate typing "import trio; trio.<TAB>"
        script = jedi.Script("import {}; {}.".format(modname, modname))
        completions = script.complete()
        static_names = no_underscores(c.name for c in completions)
    else:  # pragma: no cover
        assert False

    # It's expected that the static set will contain more names than the
    # runtime set:
    # - static tools are sometimes sloppy and include deleted names
    # - some symbols are platform-specific at runtime, but always show up in
    #   static analysis (e.g. in trio.socket or trio.lowlevel)
    # So we check that the runtime names are a subset of the static names.
    missing_names = runtime_names - static_names
    if missing_names:  # pragma: no cover
        print("{} can't see the following names in {}:".format(tool, modname))
        print()
        for name in sorted(missing_names):
            print("  {}".format(name))
        assert False


def test_classes_are_final():
    for module in PUBLIC_MODULES:
        for name, class_ in module.__dict__.items():
            if not isinstance(class_, type):
                continue
            # Deprecated classes are exported with a leading underscore
            if name.startswith("_"):  # pragma: no cover
                continue

            # Abstract classes can be subclassed, because that's the whole
            # point of ABCs
            if inspect.isabstract(class_):
                continue
            # Exceptions are allowed to be subclassed, because exception
            # subclassing isn't used to inherit behavior.
            if issubclass(class_, BaseException):
                continue
            # These are classes that are conceptually abstract, but
            # inspect.isabstract returns False for boring reasons.
            if class_ in {trio.abc.Instrument, trio.socket.SocketType}:
                continue
            # Enums have their own metaclass, so we can't use our metaclasses.
            # And I don't think there's a lot of risk from people subclassing
            # enums...
            if issubclass(class_, enum.Enum):
                continue
            # ... insert other special cases here ...

            assert isinstance(class_, _util.Final)
198
venv/Lib/site-packages/trio/tests/test_file_io.py
Normal file
@@ -0,0 +1,198 @@
import io
import os

import pytest
from unittest import mock
from unittest.mock import sentinel

import trio
from trio import _core
from trio._file_io import AsyncIOWrapper, _FILE_SYNC_ATTRS, _FILE_ASYNC_METHODS


@pytest.fixture
def path(tmpdir):
    return os.fspath(tmpdir.join("test"))


@pytest.fixture
def wrapped():
    return mock.Mock(spec_set=io.StringIO)


@pytest.fixture
def async_file(wrapped):
    return trio.wrap_file(wrapped)


def test_wrap_invalid():
    with pytest.raises(TypeError):
        trio.wrap_file(str())


def test_wrap_non_iobase():
    class FakeFile:
        def close(self):  # pragma: no cover
            pass

        def write(self):  # pragma: no cover
            pass

    wrapped = FakeFile()
    assert not isinstance(wrapped, io.IOBase)

    async_file = trio.wrap_file(wrapped)
    assert isinstance(async_file, AsyncIOWrapper)

    del FakeFile.write

    with pytest.raises(TypeError):
        trio.wrap_file(FakeFile())


def test_wrapped_property(async_file, wrapped):
    assert async_file.wrapped is wrapped


def test_dir_matches_wrapped(async_file, wrapped):
    attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS)

    # all supported attrs in wrapped should be available in async_file
    assert all(attr in dir(async_file) for attr in attrs if attr in dir(wrapped))
    # all supported attrs not in wrapped should not be available in async_file
    assert not any(
        attr in dir(async_file) for attr in attrs if attr not in dir(wrapped)
    )


def test_unsupported_not_forwarded():
    class FakeFile(io.RawIOBase):
        def unsupported_attr(self):  # pragma: no cover
            pass

    async_file = trio.wrap_file(FakeFile())

    assert hasattr(async_file.wrapped, "unsupported_attr")

    with pytest.raises(AttributeError):
        getattr(async_file, "unsupported_attr")


def test_sync_attrs_forwarded(async_file, wrapped):
    for attr_name in _FILE_SYNC_ATTRS:
        if attr_name not in dir(async_file):
            continue

        assert getattr(async_file, attr_name) is getattr(wrapped, attr_name)


def test_sync_attrs_match_wrapper(async_file, wrapped):
    for attr_name in _FILE_SYNC_ATTRS:
        if attr_name in dir(async_file):
            continue

        with pytest.raises(AttributeError):
            getattr(async_file, attr_name)

        with pytest.raises(AttributeError):
            getattr(wrapped, attr_name)


def test_async_methods_generated_once(async_file):
    for meth_name in _FILE_ASYNC_METHODS:
        if meth_name not in dir(async_file):
            continue

        assert getattr(async_file, meth_name) is getattr(async_file, meth_name)


def test_async_methods_signature(async_file):
    # use read as a representative of all async methods
    assert async_file.read.__name__ == "read"
    assert async_file.read.__qualname__ == "AsyncIOWrapper.read"

    assert "io.StringIO.read" in async_file.read.__doc__


async def test_async_methods_wrap(async_file, wrapped):
    for meth_name in _FILE_ASYNC_METHODS:
        if meth_name not in dir(async_file):
            continue

        meth = getattr(async_file, meth_name)
        wrapped_meth = getattr(wrapped, meth_name)

        value = await meth(sentinel.argument, keyword=sentinel.keyword)

        wrapped_meth.assert_called_once_with(
            sentinel.argument, keyword=sentinel.keyword
        )
        assert value == wrapped_meth()

        wrapped.reset_mock()


async def test_async_methods_match_wrapper(async_file, wrapped):
    for meth_name in _FILE_ASYNC_METHODS:
        if meth_name in dir(async_file):
            continue

        with pytest.raises(AttributeError):
            getattr(async_file, meth_name)

        with pytest.raises(AttributeError):
            getattr(wrapped, meth_name)


async def test_open(path):
    f = await trio.open_file(path, "w")

    assert isinstance(f, AsyncIOWrapper)

    await f.aclose()


async def test_open_context_manager(path):
    async with await trio.open_file(path, "w") as f:
        assert isinstance(f, AsyncIOWrapper)
        assert not f.closed

    assert f.closed


async def test_async_iter():
    async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar"))
    expected = list(async_file.wrapped)
    result = []
    async_file.wrapped.seek(0)

    async for line in async_file:
        result.append(line)

    assert result == expected


async def test_aclose_cancelled(path):
    with _core.CancelScope() as cscope:
        f = await trio.open_file(path, "w")
        cscope.cancel()

        with pytest.raises(_core.Cancelled):
            await f.write("a")

        with pytest.raises(_core.Cancelled):
            await f.aclose()

    assert f.closed


async def test_detach_rewraps_asynciobase():
    raw = io.BytesIO()
    buffered = io.BufferedReader(raw)

    async_file = trio.wrap_file(buffered)

    detached = await async_file.detach()

    assert isinstance(detached, AsyncIOWrapper)
    assert detached.wrapped is raw
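A minimal usage sketch of the async file API exercised above (illustrative only; the example.txt path is an assumption):

import trio

async def main():
    # The wrapper runs blocking file operations in a worker thread.
    async with await trio.open_file("example.txt", "w") as f:
        await f.write("hello\n")
    async with await trio.open_file("example.txt") as f:
        async for line in f:
            print(line.rstrip())

trio.run(main)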
94
venv/Lib/site-packages/trio/tests/test_highlevel_generic.py
Normal file
@@ -0,0 +1,94 @@
import pytest

import attr

from ..abc import SendStream, ReceiveStream
from .._highlevel_generic import StapledStream


@attr.s
class RecordSendStream(SendStream):
    record = attr.ib(factory=list)

    async def send_all(self, data):
        self.record.append(("send_all", data))

    async def wait_send_all_might_not_block(self):
        self.record.append("wait_send_all_might_not_block")

    async def aclose(self):
        self.record.append("aclose")


@attr.s
class RecordReceiveStream(ReceiveStream):
    record = attr.ib(factory=list)

    async def receive_some(self, max_bytes=None):
        self.record.append(("receive_some", max_bytes))

    async def aclose(self):
        self.record.append("aclose")


async def test_StapledStream():
    send_stream = RecordSendStream()
    receive_stream = RecordReceiveStream()
    stapled = StapledStream(send_stream, receive_stream)

    assert stapled.send_stream is send_stream
    assert stapled.receive_stream is receive_stream

    await stapled.send_all(b"foo")
    await stapled.wait_send_all_might_not_block()
    assert send_stream.record == [
        ("send_all", b"foo"),
        "wait_send_all_might_not_block",
    ]
    send_stream.record.clear()

    await stapled.send_eof()
    assert send_stream.record == ["aclose"]
    send_stream.record.clear()

    async def fake_send_eof():
        send_stream.record.append("send_eof")

    send_stream.send_eof = fake_send_eof
    await stapled.send_eof()
    assert send_stream.record == ["send_eof"]

    send_stream.record.clear()
    assert receive_stream.record == []

    await stapled.receive_some(1234)
    assert receive_stream.record == [("receive_some", 1234)]
    assert send_stream.record == []
    receive_stream.record.clear()

    await stapled.aclose()
    assert receive_stream.record == ["aclose"]
    assert send_stream.record == ["aclose"]


async def test_StapledStream_with_erroring_close():
    # Make sure that if one of the aclose methods errors out, then the other
    # one still gets called.
    class BrokenSendStream(RecordSendStream):
        async def aclose(self):
            await super().aclose()
            raise ValueError

    class BrokenReceiveStream(RecordReceiveStream):
        async def aclose(self):
            await super().aclose()
            raise ValueError

    stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())

    with pytest.raises(ValueError) as excinfo:
        await stapled.aclose()
    assert isinstance(excinfo.value.__context__, ValueError)

    assert stapled.send_stream.record == ["aclose"]
    assert stapled.receive_stream.record == ["aclose"]
@@ -0,0 +1,295 @@
import pytest

import socket as stdlib_socket
import errno

import attr

import trio
from trio import open_tcp_listeners, serve_tcp, SocketListener, open_tcp_stream
from trio.testing import open_stream_to_socket_listener
from .. import socket as tsocket
from .._core.tests.tutil import slow, creates_ipv6, binds_ipv6


async def test_open_tcp_listeners_basic():
    listeners = await open_tcp_listeners(0)
    assert isinstance(listeners, list)
    for obj in listeners:
        assert isinstance(obj, SocketListener)
        # Binds to wildcard address by default
        assert obj.socket.family in [tsocket.AF_INET, tsocket.AF_INET6]
        assert obj.socket.getsockname()[0] in ["0.0.0.0", "::"]

    listener = listeners[0]
    # Make sure the backlog is at least 2
    c1 = await open_stream_to_socket_listener(listener)
    c2 = await open_stream_to_socket_listener(listener)

    s1 = await listener.accept()
    s2 = await listener.accept()

    # Note that we don't know which client stream is connected to which server
    # stream
    await s1.send_all(b"x")
    await s2.send_all(b"x")
    assert await c1.receive_some(1) == b"x"
    assert await c2.receive_some(1) == b"x"

    for resource in [c1, c2, s1, s2] + listeners:
        await resource.aclose()


async def test_open_tcp_listeners_specific_port_specific_host():
    # Pick a port
    sock = tsocket.socket()
    await sock.bind(("127.0.0.1", 0))
    host, port = sock.getsockname()
    sock.close()

    (listener,) = await open_tcp_listeners(port, host=host)
    async with listener:
        assert listener.socket.getsockname() == (host, port)


@binds_ipv6
async def test_open_tcp_listeners_ipv6_v6only():
    # Check IPV6_V6ONLY is working properly
    (ipv6_listener,) = await open_tcp_listeners(0, host="::1")
    async with ipv6_listener:
        _, port, *_ = ipv6_listener.socket.getsockname()

        with pytest.raises(OSError):
            await open_tcp_stream("127.0.0.1", port)


async def test_open_tcp_listeners_rebind():
    (l1,) = await open_tcp_listeners(0, host="127.0.0.1")
    sockaddr1 = l1.socket.getsockname()

    # Plain old rebinding while it's still there should fail, even if we have
    # SO_REUSEADDR set
    with stdlib_socket.socket() as probe:
        probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1)
        with pytest.raises(OSError):
            probe.bind(sockaddr1)

    # Now use the first listener to set up some connections in various states,
    # and make sure that they don't create any obstacle to rebinding a second
    # listener after the first one is closed.
    c_established = await open_stream_to_socket_listener(l1)
    s_established = await l1.accept()

    c_time_wait = await open_stream_to_socket_listener(l1)
    s_time_wait = await l1.accept()
    # Server-initiated close leaves socket in TIME_WAIT
    await s_time_wait.aclose()

    await l1.aclose()
    (l2,) = await open_tcp_listeners(sockaddr1[1], host="127.0.0.1")
    sockaddr2 = l2.socket.getsockname()

    assert sockaddr1 == sockaddr2
    assert s_established.socket.getsockname() == sockaddr2
    assert c_time_wait.socket.getpeername() == sockaddr2

    for resource in [
        l1,
        l2,
        c_established,
        s_established,
        c_time_wait,
        s_time_wait,
    ]:
        await resource.aclose()


class FakeOSError(OSError):
    pass


@attr.s
class FakeSocket(tsocket.SocketType):
    family = attr.ib()
    type = attr.ib()
    proto = attr.ib()

    closed = attr.ib(default=False)
    poison_listen = attr.ib(default=False)
    backlog = attr.ib(default=None)

    def getsockopt(self, level, option):
        if (level, option) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN):
            return True
        assert False  # pragma: no cover

    def setsockopt(self, level, option, value):
        pass

    async def bind(self, sockaddr):
        pass

    def listen(self, backlog):
        assert self.backlog is None
        assert backlog is not None
        self.backlog = backlog
        if self.poison_listen:
            raise FakeOSError("whoops")

    def close(self):
        self.closed = True


@attr.s
class FakeSocketFactory:
    poison_after = attr.ib()
    sockets = attr.ib(factory=list)
    raise_on_family = attr.ib(factory=dict)  # family => errno

    def socket(self, family, type, proto):
        if family in self.raise_on_family:
            raise OSError(self.raise_on_family[family], "nope")
        sock = FakeSocket(family, type, proto)
        self.poison_after -= 1
        if self.poison_after == 0:
            sock.poison_listen = True
        self.sockets.append(sock)
        return sock


@attr.s
class FakeHostnameResolver:
    family_addr_pairs = attr.ib()

    async def getaddrinfo(self, host, port, family, type, proto, flags):
        return [
            (family, tsocket.SOCK_STREAM, 0, "", (addr, port))
            for family, addr in self.family_addr_pairs
        ]


async def test_open_tcp_listeners_multiple_host_cleanup_on_error():
    # If we were trying to bind to multiple hosts and one of them failed, they
    # call get cleaned up before returning
    fsf = FakeSocketFactory(3)
    tsocket.set_custom_socket_factory(fsf)
    tsocket.set_custom_hostname_resolver(
        FakeHostnameResolver(
            [
                (tsocket.AF_INET, "1.1.1.1"),
                (tsocket.AF_INET, "2.2.2.2"),
                (tsocket.AF_INET, "3.3.3.3"),
            ]
        )
    )

    with pytest.raises(FakeOSError):
        await open_tcp_listeners(80, host="example.org")

    assert len(fsf.sockets) == 3
    for sock in fsf.sockets:
        assert sock.closed


async def test_open_tcp_listeners_port_checking():
    for host in ["127.0.0.1", None]:
        with pytest.raises(TypeError):
            await open_tcp_listeners(None, host=host)
        with pytest.raises(TypeError):
            await open_tcp_listeners(b"80", host=host)
        with pytest.raises(TypeError):
            await open_tcp_listeners("http", host=host)


async def test_serve_tcp():
    async def handler(stream):
        await stream.send_all(b"x")

    async with trio.open_nursery() as nursery:
        listeners = await nursery.start(serve_tcp, handler, 0)
        stream = await open_stream_to_socket_listener(listeners[0])
        async with stream:
            await stream.receive_some(1) == b"x"
            nursery.cancel_scope.cancel()


@pytest.mark.parametrize(
    "try_families",
    [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
)
@pytest.mark.parametrize(
    "fail_families",
    [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
)
async def test_open_tcp_listeners_some_address_families_unavailable(
    try_families, fail_families
):
    fsf = FakeSocketFactory(
        10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families}
    )
    tsocket.set_custom_socket_factory(fsf)
    tsocket.set_custom_hostname_resolver(
        FakeHostnameResolver([(family, "foo") for family in try_families])
    )

    should_succeed = try_families - fail_families

    if not should_succeed:
        with pytest.raises(OSError) as exc_info:
            await open_tcp_listeners(80, host="example.org")

        assert "This system doesn't support" in str(exc_info.value)
        if isinstance(exc_info.value.__cause__, trio.MultiError):
            for subexc in exc_info.value.__cause__.exceptions:
                assert "nope" in str(subexc)
        else:
            assert isinstance(exc_info.value.__cause__, OSError)
            assert "nope" in str(exc_info.value.__cause__)
    else:
        listeners = await open_tcp_listeners(80)
        for listener in listeners:
            should_succeed.remove(listener.socket.family)
        assert not should_succeed


async def test_open_tcp_listeners_socket_fails_not_afnosupport():
    fsf = FakeSocketFactory(
        10,
        raise_on_family={
            tsocket.AF_INET: errno.EAFNOSUPPORT,
            tsocket.AF_INET6: errno.EINVAL,
        },
    )
    tsocket.set_custom_socket_factory(fsf)
    tsocket.set_custom_hostname_resolver(
        FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")])
    )

    with pytest.raises(OSError) as exc_info:
        await open_tcp_listeners(80, host="example.org")
    assert exc_info.value.errno == errno.EINVAL
    assert exc_info.value.__cause__ is None
    assert "nope" in str(exc_info.value)


# We used to have an elaborate test that opened a real TCP listening socket
# and then tried to measure its backlog by making connections to it. And most
# of the time, it worked. But no matter what we tried, it was always fragile,
# because it had to do things like use timeouts to guess when the listening
# queue was full, sometimes the CI hosts go into SYN-cookie mode (where there
# effectively is no backlog), sometimes the host might not be enough resources
# to give us the full requested backlog... it was a mess. So now we just check
# that the backlog argument is passed through correctly.
async def test_open_tcp_listeners_backlog():
    fsf = FakeSocketFactory(99)
    tsocket.set_custom_socket_factory(fsf)
    for (given, expected) in [
        (None, 0xFFFF),
        (99999999, 0xFFFF),
        (10, 10),
        (1, 1),
    ]:
        listeners = await open_tcp_listeners(0, backlog=given)
        assert listeners
        for listener in listeners:
            assert listener.socket.backlog == expected
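For contrast with the fake-socket tests above, a small illustrative sketch of serve_tcp itself (the port number is an arbitrary assumption, not from the commit):

import trio

async def echo_handler(stream):
    # serve_tcp spawns one handler task per accepted connection.
    async for chunk in stream:
        await stream.send_all(chunk)

async def main():
    # Runs until cancelled; listeners are bound before this call blocks.
    await trio.serve_tcp(echo_handler, 12345)

if __name__ == "__main__":
    trio.run(main)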
@@ -0,0 +1,571 @@
|
||||
import pytest
|
||||
import sys
|
||||
import socket
|
||||
|
||||
import attr
|
||||
|
||||
import trio
|
||||
from trio.socket import AF_INET, AF_INET6, SOCK_STREAM, IPPROTO_TCP
|
||||
from trio._highlevel_open_tcp_stream import (
|
||||
reorder_for_rfc_6555_section_5_4,
|
||||
close_all,
|
||||
open_tcp_stream,
|
||||
format_host_port,
|
||||
)
|
||||
|
||||
|
||||
def test_close_all():
|
||||
class CloseMe:
|
||||
closed = False
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
class CloseKiller:
|
||||
def close(self):
|
||||
raise OSError
|
||||
|
||||
c = CloseMe()
|
||||
with close_all() as to_close:
|
||||
to_close.add(c)
|
||||
assert c.closed
|
||||
|
||||
c = CloseMe()
|
||||
with pytest.raises(RuntimeError):
|
||||
with close_all() as to_close:
|
||||
to_close.add(c)
|
||||
raise RuntimeError
|
||||
assert c.closed
|
||||
|
||||
c = CloseMe()
|
||||
with pytest.raises(OSError):
|
||||
with close_all() as to_close:
|
||||
to_close.add(CloseKiller())
|
||||
to_close.add(c)
|
||||
assert c.closed
|
||||
|
||||
|
||||
def test_reorder_for_rfc_6555_section_5_4():
|
||||
def fake4(i):
|
||||
return (
|
||||
AF_INET,
|
||||
SOCK_STREAM,
|
||||
IPPROTO_TCP,
|
||||
"",
|
||||
("10.0.0.{}".format(i), 80),
|
||||
)
|
||||
|
||||
def fake6(i):
|
||||
return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", ("::{}".format(i), 80))
|
||||
|
||||
for fake in fake4, fake6:
|
||||
# No effect on homogeneous lists
|
||||
targets = [fake(0), fake(1), fake(2)]
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == [fake(0), fake(1), fake(2)]
|
||||
|
||||
# Single item lists also OK
|
||||
targets = [fake(0)]
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == [fake(0)]
|
||||
|
||||
# If the list starts out with different families in positions 0 and 1,
|
||||
# then it's left alone
|
||||
orig = [fake4(0), fake6(0), fake4(1), fake6(1)]
|
||||
targets = list(orig)
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == orig
|
||||
|
||||
# If not, it's reordered
|
||||
targets = [fake4(0), fake4(1), fake4(2), fake6(0), fake6(1)]
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)]
|
||||
|
||||
|
||||
def test_format_host_port():
|
||||
assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80"
|
||||
assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80"
|
||||
assert format_host_port("example.com", 443) == "example.com:443"
|
||||
assert format_host_port(b"example.com", 443) == "example.com:443"
|
||||
assert format_host_port("::1", "http") == "[::1]:http"
|
||||
assert format_host_port(b"::1", "http") == "[::1]:http"
|
||||
|
||||
|
||||
# Make sure we can connect to localhost using real kernel sockets
|
||||
async def test_open_tcp_stream_real_socket_smoketest():
|
||||
listen_sock = trio.socket.socket()
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
_, listen_port = listen_sock.getsockname()
|
||||
listen_sock.listen(1)
|
||||
client_stream = await open_tcp_stream("127.0.0.1", listen_port)
|
||||
server_sock, _ = await listen_sock.accept()
|
||||
await client_stream.send_all(b"x")
|
||||
assert await server_sock.recv(1) == b"x"
|
||||
await client_stream.aclose()
|
||||
server_sock.close()
|
||||
|
||||
listen_sock.close()
|
||||
|
||||
|
||||
async def test_open_tcp_stream_input_validation():
|
||||
with pytest.raises(ValueError):
|
||||
await open_tcp_stream(None, 80)
|
||||
with pytest.raises(TypeError):
|
||||
await open_tcp_stream("127.0.0.1", b"80")
|
||||
|
||||
|
||||
def can_bind_127_0_0_2():
|
||||
with socket.socket() as s:
|
||||
try:
|
||||
s.bind(("127.0.0.2", 0))
|
||||
except OSError:
|
||||
return False
|
||||
return s.getsockname()[0] == "127.0.0.2"
|
||||
|
||||
|
||||
async def test_local_address_real():
|
||||
with trio.socket.socket() as listener:
|
||||
await listener.bind(("127.0.0.1", 0))
|
||||
listener.listen()
|
||||
|
||||
# It's hard to test local_address properly, because you need multiple
|
||||
# local addresses that you can bind to. Fortunately, on most Linux
|
||||
# systems, you can bind to any 127.*.*.* address, and they all go
|
||||
# through the loopback interface. So we can use a non-standard
|
||||
# loopback address. On other systems, the only address we know for
|
||||
# certain we have is 127.0.0.1, so we can't really test local_address=
|
||||
# properly -- passing local_address=127.0.0.1 is indistinguishable
|
||||
# from not passing local_address= at all. But, we can still do a smoke
|
||||
# test to make sure the local_address= code doesn't crash.
|
||||
if can_bind_127_0_0_2():
|
||||
local_address = "127.0.0.2"
|
||||
else:
|
||||
local_address = "127.0.0.1"
|
||||
|
||||
async with await open_tcp_stream(
|
||||
*listener.getsockname(), local_address=local_address
|
||||
) as client_stream:
|
||||
assert client_stream.socket.getsockname()[0] == local_address
|
||||
if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"):
|
||||
assert client_stream.socket.getsockopt(
|
||||
trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT
|
||||
)
|
||||
|
||||
server_sock, remote_addr = await listener.accept()
|
||||
await client_stream.aclose()
|
||||
server_sock.close()
|
||||
assert remote_addr[0] == local_address
|
||||
|
||||
# Trying to connect to an ipv4 address with the ipv6 wildcard
|
||||
# local_address should fail
|
||||
with pytest.raises(OSError):
|
||||
await open_tcp_stream(*listener.getsockname(), local_address="::")
|
||||
|
||||
# But the ipv4 wildcard address should work
|
||||
async with await open_tcp_stream(
|
||||
*listener.getsockname(), local_address="0.0.0.0"
|
||||
) as client_stream:
|
||||
server_sock, remote_addr = await listener.accept()
|
||||
server_sock.close()
|
||||
assert remote_addr == client_stream.socket.getsockname()
|
||||
|
||||
|
||||
# Now, thorough tests using fake sockets
|
||||
|
||||
|
||||
@attr.s(eq=False)
|
||||
class FakeSocket(trio.socket.SocketType):
|
||||
scenario = attr.ib()
|
||||
family = attr.ib()
|
||||
type = attr.ib()
|
||||
proto = attr.ib()
|
||||
|
||||
ip = attr.ib(default=None)
|
||||
port = attr.ib(default=None)
|
||||
succeeded = attr.ib(default=False)
|
||||
closed = attr.ib(default=False)
|
||||
failing = attr.ib(default=False)
|
||||
|
||||
async def connect(self, sockaddr):
|
||||
self.ip = sockaddr[0]
|
||||
self.port = sockaddr[1]
|
||||
assert self.ip not in self.scenario.sockets
|
||||
self.scenario.sockets[self.ip] = self
|
||||
self.scenario.connect_times[self.ip] = trio.current_time()
|
||||
delay, result = self.scenario.ip_dict[self.ip]
|
||||
await trio.sleep(delay)
|
||||
if result == "error":
|
||||
raise OSError("sorry")
|
||||
if result == "postconnect_fail":
|
||||
self.failing = True
|
||||
self.succeeded = True
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
# called when SocketStream is constructed
|
||||
def setsockopt(self, *args, **kwargs):
|
||||
if self.failing:
|
||||
# raise something that isn't OSError as SocketStream
|
||||
# ignores those
|
||||
raise KeyboardInterrupt
|
||||
|
||||
|
||||
class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver):
|
||||
def __init__(self, port, ip_list, supported_families):
|
||||
# ip_list entries have to be unique
|
||||
ip_order = [ip for (ip, _, _) in ip_list]
|
||||
assert len(set(ip_order)) == len(ip_list)
|
||||
ip_dict = {}
|
||||
for ip, delay, result in ip_list:
|
||||
assert 0 <= delay
|
||||
assert result in ["error", "success", "postconnect_fail"]
|
||||
ip_dict[ip] = (delay, result)
|
||||
|
||||
self.port = port
|
||||
self.ip_order = ip_order
|
||||
self.ip_dict = ip_dict
|
||||
self.supported_families = supported_families
|
||||
self.socket_count = 0
|
||||
self.sockets = {}
|
||||
self.connect_times = {}
|
||||
|
||||
def socket(self, family, type, proto):
|
||||
if family not in self.supported_families:
|
||||
raise OSError("pretending not to support this family")
|
||||
self.socket_count += 1
|
||||
return FakeSocket(self, family, type, proto)
|
||||
|
||||
def _ip_to_gai_entry(self, ip):
|
||||
if ":" in ip:
|
||||
family = trio.socket.AF_INET6
|
||||
sockaddr = (ip, self.port, 0, 0)
|
||||
else:
|
||||
family = trio.socket.AF_INET
|
||||
sockaddr = (ip, self.port)
|
||||
return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr)
|
||||
|
||||
async def getaddrinfo(self, host, port, family, type, proto, flags):
|
||||
assert host == b"test.example.com"
|
||||
assert port == self.port
|
||||
assert family == trio.socket.AF_UNSPEC
|
||||
assert type == trio.socket.SOCK_STREAM
|
||||
assert proto == 0
|
||||
assert flags == 0
|
||||
return [self._ip_to_gai_entry(ip) for ip in self.ip_order]
|
||||
|
||||
async def getnameinfo(self, sockaddr, flags): # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
def check(self, succeeded):
|
||||
# sockets only go into self.sockets when connect is called; make sure
|
||||
# all the sockets that were created did in fact go in there.
|
||||
assert self.socket_count == len(self.sockets)
|
||||
|
||||
for ip, socket in self.sockets.items():
|
||||
assert ip in self.ip_dict
|
||||
if socket is not succeeded:
|
||||
assert socket.closed
|
||||
assert socket.port == self.port
|
||||
|
||||
|
||||
async def run_scenario(
|
||||
# The port to connect to
|
||||
port,
|
||||
# A list of
|
||||
# (ip, delay, result)
|
||||
# tuples, where delay is in seconds and result is "success", "error", or "postconnect_fail".
|
||||
# The ip's will be returned from getaddrinfo in this order, and then
|
||||
# connect() calls to them will have the given result.
|
||||
ip_list,
|
||||
*,
|
||||
# If False, AF_INET / AF_INET6 sockets error out on creation, before connect is
|
||||
# even called.
|
||||
ipv4_supported=True,
|
||||
ipv6_supported=True,
|
||||
# Normally, we return (winning_sock, scenario object)
|
||||
# If this is set (to an exception type or tuple of types), we require there to be an exception, and return
|
||||
# (exception, scenario object)
|
||||
expect_error=(),
|
||||
**kwargs,
|
||||
):
|
||||
supported_families = set()
|
||||
if ipv4_supported:
|
||||
supported_families.add(trio.socket.AF_INET)
|
||||
if ipv6_supported:
|
||||
supported_families.add(trio.socket.AF_INET6)
|
||||
scenario = Scenario(port, ip_list, supported_families)
|
||||
trio.socket.set_custom_hostname_resolver(scenario)
|
||||
trio.socket.set_custom_socket_factory(scenario)
|
||||
|
||||
try:
|
||||
stream = await open_tcp_stream("test.example.com", port, **kwargs)
|
||||
assert expect_error == ()
|
||||
scenario.check(stream.socket)
|
||||
return (stream.socket, scenario)
|
||||
except AssertionError: # pragma: no cover
|
||||
raise
|
||||
except expect_error as exc:
|
||||
scenario.check(None)
|
||||
return (exc, scenario)
|
||||
|
||||
|
||||
async def test_one_host_quick_success(autojump_clock):
|
||||
sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")])
|
||||
assert sock.ip == "1.2.3.4"
|
||||
assert trio.current_time() == 0.123
|
||||
|
||||
|
||||
async def test_one_host_slow_success(autojump_clock):
|
||||
sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")])
|
||||
assert sock.ip == "1.2.3.4"
|
||||
assert trio.current_time() == 100
|
||||
|
||||
|
||||
async def test_one_host_quick_fail(autojump_clock):
|
||||
exc, scenario = await run_scenario(
|
||||
82, [("1.2.3.4", 0.123, "error")], expect_error=OSError
|
||||
)
|
||||
assert isinstance(exc, OSError)
|
||||
assert trio.current_time() == 0.123
|
||||
|
||||
|
||||
async def test_one_host_slow_fail(autojump_clock):
|
||||
exc, scenario = await run_scenario(
|
||||
83, [("1.2.3.4", 100, "error")], expect_error=OSError
|
||||
)
|
||||
assert isinstance(exc, OSError)
|
||||
assert trio.current_time() == 100
|
||||
|
||||
|
||||
async def test_one_host_failed_after_connect(autojump_clock):
|
||||
exc, scenario = await run_scenario(
|
||||
83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt
|
||||
)
|
||||
assert isinstance(exc, KeyboardInterrupt)
|
||||
|
||||
|
||||
# With the default 0.250 second delay, the third attempt will win
|
||||
async def test_basic_fallthrough(autojump_clock):
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
)
|
||||
assert sock.ip == "3.3.3.3"
|
||||
# current time is default time + default time + connection time
|
||||
assert trio.current_time() == (0.250 + 0.250 + 0.2)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.250,
|
||||
"3.3.3.3": 0.500,
|
||||
}
|
||||
|
||||
|
||||
async def test_early_success(autojump_clock):
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 0.1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
)
|
||||
assert sock.ip == "2.2.2.2"
|
||||
assert trio.current_time() == (0.250 + 0.1)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.250,
|
||||
# 3.3.3.3 was never even started
|
||||
}
|
||||
|
||||
|
||||
# With a 0.450 second delay, the first attempt will win
|
||||
async def test_custom_delay(autojump_clock):
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=0.450,
|
||||
)
|
||||
assert sock.ip == "1.1.1.1"
|
||||
assert trio.current_time() == 1
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.450,
|
||||
"3.3.3.3": 0.900,
|
||||
}
|
||||
|
||||
|
||||
async def test_custom_errors_expedite(autojump_clock):
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 0.1, "error"),
|
||||
("2.2.2.2", 0.2, "error"),
|
||||
("3.3.3.3", 10, "success"),
|
||||
# 0.25 is the default happy_eyeballs_delay
|
||||
("4.4.4.4", 0.25, "success"),
|
||||
],
|
||||
)
|
||||
assert sock.ip == "4.4.4.4"
|
||||
assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.1,
|
||||
"3.3.3.3": 0.1 + 0.2,
|
||||
"4.4.4.4": 0.1 + 0.2 + 0.25,
|
||||
}
|
||||
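# The connect_times asserted in the tests above all follow one scheduling
# rule: a new connection attempt starts happy_eyeballs_delay after the
# previous one, or immediately once the previous attempt errors out, and the
# first success cancels everything else. A hedged sketch of just that rule
# (hypothetical `attempt` callable, no error aggregation; this is not Trio's
# actual open_tcp_stream code, and it relies on the module-level trio import):
async def _staggered_connect_sketch(targets, attempt, happy_eyeballs_delay=0.250):
    winner = None

    async def try_one(target, failed):
        nonlocal winner
        try:
            result = await attempt(target)
        except OSError:
            failed.set()  # let the next attempt start right away
            return
        winner = result
        nursery.cancel_scope.cancel()  # first success cancels the rest

    async with trio.open_nursery() as nursery:
        for target in targets:
            failed = trio.Event()
            nursery.start_soon(try_one, target, failed)
            # Wait out the delay before starting the next attempt, unless
            # this one fails sooner.
            with trio.move_on_after(happy_eyeballs_delay):
                await failed.wait()
    return winner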
|
||||
|
||||
async def test_all_fail(autojump_clock):
|
||||
exc, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 0.1, "error"),
|
||||
("2.2.2.2", 0.2, "error"),
|
||||
("3.3.3.3", 10, "error"),
|
||||
("4.4.4.4", 0.250, "error"),
|
||||
],
|
||||
expect_error=OSError,
|
||||
)
|
||||
assert isinstance(exc, OSError)
|
||||
assert isinstance(exc.__cause__, trio.MultiError)
|
||||
assert len(exc.__cause__.exceptions) == 4
|
||||
assert trio.current_time() == (0.1 + 0.2 + 10)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.1,
|
||||
"3.3.3.3": 0.1 + 0.2,
|
||||
"4.4.4.4": 0.1 + 0.2 + 0.25,
|
||||
}
|
||||
|
||||
|
||||
async def test_multi_success(autojump_clock):
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 0.5, "error"),
|
||||
("2.2.2.2", 10, "success"),
|
||||
("3.3.3.3", 10 - 1, "success"),
|
||||
("4.4.4.4", 10 - 2, "success"),
|
||||
("5.5.5.5", 0.5, "error"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
)
|
||||
assert not scenario.sockets["1.1.1.1"].succeeded
|
||||
assert (
|
||||
scenario.sockets["2.2.2.2"].succeeded
|
||||
or scenario.sockets["3.3.3.3"].succeeded
|
||||
or scenario.sockets["4.4.4.4"].succeeded
|
||||
)
|
||||
assert not scenario.sockets["5.5.5.5"].succeeded
|
||||
assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"]
|
||||
assert trio.current_time() == (0.5 + 10)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.5,
|
||||
"3.3.3.3": 1.5,
|
||||
"4.4.4.4": 2.5,
|
||||
"5.5.5.5": 3.5,
|
||||
}
|
||||
|
||||
|
||||
async def test_does_reorder(autojump_clock):
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 10, "error"),
|
||||
# This would win if we tried it first...
|
||||
("2.2.2.2", 1, "success"),
|
||||
# But in fact we try this first, because of section 5.4
|
||||
("::3", 0.5, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
)
|
||||
assert sock.ip == "::3"
|
||||
assert trio.current_time() == 1 + 0.5
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"::3": 1,
|
||||
}
|
||||
|
||||
|
||||
async def test_handles_no_ipv4(autojump_clock):
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
# Here the ipv4 addresses fail at socket creation time, so the connect
|
||||
# configuration doesn't matter
|
||||
[
|
||||
("::1", 10, "success"),
|
||||
("2.2.2.2", 0, "success"),
|
||||
("::3", 0.1, "success"),
|
||||
("4.4.4.4", 0, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
ipv4_supported=False,
|
||||
)
|
||||
assert sock.ip == "::3"
|
||||
assert trio.current_time() == 1 + 0.1
|
||||
assert scenario.connect_times == {
|
||||
"::1": 0,
|
||||
"::3": 1.0,
|
||||
}
|
||||
|
||||
|
||||
async def test_handles_no_ipv6(autojump_clock):
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
# Here the ipv6 addresses fail at socket creation time, so the connect
|
||||
# configuration doesn't matter
|
||||
[
|
||||
("::1", 0, "success"),
|
||||
("2.2.2.2", 10, "success"),
|
||||
("::3", 0, "success"),
|
||||
("4.4.4.4", 0.1, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
ipv6_supported=False,
|
||||
)
|
||||
assert sock.ip == "4.4.4.4"
|
||||
assert trio.current_time() == 1 + 0.1
|
||||
assert scenario.connect_times == {
|
||||
"2.2.2.2": 0,
|
||||
"4.4.4.4": 1.0,
|
||||
}
|
||||
|
||||
|
||||
async def test_no_hosts(autojump_clock):
|
||||
exc, scenario = await run_scenario(80, [], expect_error=OSError)
|
||||
assert "no results found" in str(exc)
|
||||
|
||||
|
||||
async def test_cancel(autojump_clock):
|
||||
with trio.move_on_after(5) as cancel_scope:
|
||||
exc, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 10, "success"),
|
||||
("2.2.2.2", 10, "success"),
|
||||
("3.3.3.3", 10, "success"),
|
||||
("4.4.4.4", 10, "success"),
|
||||
],
|
||||
expect_error=trio.MultiError,
|
||||
)
|
||||
# What comes out should be 1 or more Cancelled errors that all belong
|
||||
# to this cancel_scope; this is the easiest way to check that
|
||||
raise exc
|
||||
assert cancel_scope.cancelled_caught
|
||||
|
||||
assert trio.current_time() == 5
|
||||
|
||||
# This should have been called already, but just to make sure, since the
|
||||
# exception-handling logic in run_scenario is a bit complicated and the
|
||||
# main thing we care about here is that all the sockets were cleaned up.
|
||||
scenario.check(succeeded=False)
|
||||
@@ -0,0 +1,67 @@
|
||||
import os
|
||||
import socket
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from trio import open_unix_socket, Path
|
||||
from trio._highlevel_open_unix_stream import close_on_error
|
||||
|
||||
if not hasattr(socket, "AF_UNIX"):
|
||||
pytestmark = pytest.mark.skip("Needs unix socket support")
|
||||
|
||||
|
||||
def test_close_on_error():
|
||||
class CloseMe:
|
||||
closed = False
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
with close_on_error(CloseMe()) as c:
|
||||
pass
|
||||
assert not c.closed
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
with close_on_error(CloseMe()) as c:
|
||||
raise RuntimeError
|
||||
assert c.closed
|
||||
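# A hedged sketch of the behaviour this test checks: close_on_error yields the
# resource and closes it only if the body raises. Hypothetical re-creation for
# illustration (the real helper is imported above from
# trio._highlevel_open_unix_stream):
from contextlib import contextmanager


@contextmanager
def _close_on_error_sketch(obj):
    try:
        yield obj
    except BaseException:
        obj.close()
        raise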
|
||||
|
||||
@pytest.mark.parametrize("filename", [4, 4.5])
|
||||
async def test_open_with_bad_filename_type(filename):
|
||||
with pytest.raises(TypeError):
|
||||
await open_unix_socket(filename)
|
||||
|
||||
|
||||
async def test_open_bad_socket():
|
||||
# mktemp is marked as insecure, but that's okay, we don't want the file to
|
||||
# exist
|
||||
name = tempfile.mktemp()
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await open_unix_socket(name)
|
||||
|
||||
|
||||
async def test_open_unix_socket():
|
||||
for name_type in [Path, str]:
|
||||
name = tempfile.mktemp()
|
||||
serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
with serv_sock:
|
||||
serv_sock.bind(name)
|
||||
try:
|
||||
serv_sock.listen(1)
|
||||
|
||||
# The actual function we're testing
|
||||
unix_socket = await open_unix_socket(name_type(name))
|
||||
|
||||
async with unix_socket:
|
||||
client, _ = serv_sock.accept()
|
||||
with client:
|
||||
await unix_socket.send_all(b"test")
|
||||
assert client.recv(2048) == b"test"
|
||||
|
||||
client.sendall(b"response")
|
||||
received = await unix_socket.receive_some(2048)
|
||||
assert received == b"response"
|
||||
finally:
|
||||
os.unlink(name)
|
||||
@@ -0,0 +1,145 @@
|
||||
import pytest
|
||||
|
||||
from functools import partial
|
||||
import errno
|
||||
|
||||
import attr
|
||||
|
||||
import trio
|
||||
from trio.testing import memory_stream_pair, wait_all_tasks_blocked
|
||||
|
||||
|
||||
@attr.s(hash=False, eq=False)
|
||||
class MemoryListener(trio.abc.Listener):
|
||||
closed = attr.ib(default=False)
|
||||
accepted_streams = attr.ib(factory=list)
|
||||
queued_streams = attr.ib(factory=(lambda: trio.open_memory_channel(1)))
|
||||
accept_hook = attr.ib(default=None)
|
||||
|
||||
async def connect(self):
|
||||
assert not self.closed
|
||||
client, server = memory_stream_pair()
|
||||
await self.queued_streams[0].send(server)
|
||||
return client
|
||||
|
||||
async def accept(self):
|
||||
await trio.lowlevel.checkpoint()
|
||||
assert not self.closed
|
||||
if self.accept_hook is not None:
|
||||
await self.accept_hook()
|
||||
stream = await self.queued_streams[1].receive()
|
||||
self.accepted_streams.append(stream)
|
||||
return stream
|
||||
|
||||
async def aclose(self):
|
||||
self.closed = True
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
async def test_serve_listeners_basic():
|
||||
listeners = [MemoryListener(), MemoryListener()]
|
||||
|
||||
record = []
|
||||
|
||||
def close_hook():
|
||||
# Make sure this is a forceful close
|
||||
assert trio.current_effective_deadline() == float("-inf")
|
||||
record.append("closed")
|
||||
|
||||
async def handler(stream):
|
||||
await stream.send_all(b"123")
|
||||
assert await stream.receive_some(10) == b"456"
|
||||
stream.send_stream.close_hook = close_hook
|
||||
stream.receive_stream.close_hook = close_hook
|
||||
|
||||
async def client(listener):
|
||||
s = await listener.connect()
|
||||
assert await s.receive_some(10) == b"123"
|
||||
await s.send_all(b"456")
|
||||
|
||||
async def do_tests(parent_nursery):
|
||||
async with trio.open_nursery() as nursery:
|
||||
for listener in listeners:
|
||||
for _ in range(3):
|
||||
nursery.start_soon(client, listener)
|
||||
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# verifies that all 6 streams x 2 directions each were closed ok
|
||||
assert len(record) == 12
|
||||
|
||||
parent_nursery.cancel_scope.cancel()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
l2 = await nursery.start(trio.serve_listeners, handler, listeners)
|
||||
assert l2 == listeners
|
||||
# This is just split into another function because gh-136 isn't
|
||||
# implemented yet
|
||||
nursery.start_soon(do_tests, nursery)
|
||||
|
||||
for listener in listeners:
|
||||
assert listener.closed
|
||||
|
||||
|
||||
async def test_serve_listeners_accept_unrecognized_error():
|
||||
for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]:
|
||||
listener = MemoryListener()
|
||||
|
||||
async def raise_error():
|
||||
raise error
|
||||
|
||||
listener.accept_hook = raise_error
|
||||
|
||||
with pytest.raises(type(error)) as excinfo:
|
||||
await trio.serve_listeners(None, [listener])
|
||||
assert excinfo.value is error
|
||||
|
||||
|
||||
async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog):
|
||||
listener = MemoryListener()
|
||||
|
||||
async def raise_EMFILE():
|
||||
raise OSError(errno.EMFILE, "out of file descriptors")
|
||||
|
||||
listener.accept_hook = raise_EMFILE
|
||||
|
||||
# It retries every 100 ms, so in 950 ms it will retry at 0, 100, ..., 900
|
||||
# = 10 times total
|
||||
with trio.move_on_after(0.950):
|
||||
await trio.serve_listeners(None, [listener])
|
||||
|
||||
assert len(caplog.records) == 10
|
||||
for record in caplog.records:
|
||||
assert "retrying" in record.msg
|
||||
assert record.exc_info[1].errno == errno.EMFILE
|
||||
|
||||
|
||||
async def test_serve_listeners_connection_nursery(autojump_clock):
|
||||
listener = MemoryListener()
|
||||
|
||||
async def handler(stream):
|
||||
await trio.sleep(1)
|
||||
|
||||
class Done(Exception):
|
||||
pass
|
||||
|
||||
async def connection_watcher(*, task_status=trio.TASK_STATUS_IGNORED):
|
||||
async with trio.open_nursery() as nursery:
|
||||
task_status.started(nursery)
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(nursery.child_tasks) == 10
|
||||
raise Done
|
||||
|
||||
with pytest.raises(Done):
|
||||
async with trio.open_nursery() as nursery:
|
||||
handler_nursery = await nursery.start(connection_watcher)
|
||||
await nursery.start(
|
||||
partial(
|
||||
trio.serve_listeners,
|
||||
handler,
|
||||
[listener],
|
||||
handler_nursery=handler_nursery,
|
||||
)
|
||||
)
|
||||
for _ in range(10):
|
||||
nursery.start_soon(listener.connect)
|
||||
267
venv/Lib/site-packages/trio/tests/test_highlevel_socket.py
Normal file
267
venv/Lib/site-packages/trio/tests/test_highlevel_socket.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import pytest
|
||||
|
||||
import sys
|
||||
import socket as stdlib_socket
|
||||
import errno
|
||||
|
||||
from .. import _core
|
||||
from ..testing import (
|
||||
check_half_closeable_stream,
|
||||
wait_all_tasks_blocked,
|
||||
assert_checkpoints,
|
||||
)
|
||||
from .._highlevel_socket import *
|
||||
from .. import socket as tsocket
|
||||
|
||||
|
||||
async def test_SocketStream_basics():
|
||||
# stdlib socket bad (even if connected)
|
||||
a, b = stdlib_socket.socketpair()
|
||||
with a, b:
|
||||
with pytest.raises(TypeError):
|
||||
SocketStream(a)
|
||||
|
||||
# DGRAM socket bad
|
||||
with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock:
|
||||
with pytest.raises(ValueError):
|
||||
SocketStream(sock)
|
||||
|
||||
a, b = tsocket.socketpair()
|
||||
with a, b:
|
||||
s = SocketStream(a)
|
||||
assert s.socket is a
|
||||
|
||||
# Use a real, connected socket to test socket options, because
|
||||
# socketpair() might give us a unix socket that doesn't support any of
|
||||
# these options
|
||||
with tsocket.socket() as listen_sock:
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
listen_sock.listen(1)
|
||||
with tsocket.socket() as client_sock:
|
||||
await client_sock.connect(listen_sock.getsockname())
|
||||
|
||||
s = SocketStream(client_sock)
|
||||
|
||||
# TCP_NODELAY enabled by default
|
||||
assert s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
|
||||
# We can disable it though
|
||||
s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
|
||||
assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
|
||||
|
||||
b = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1)
|
||||
assert isinstance(b, bytes)
|
||||
|
||||
|
||||
async def test_SocketStream_send_all():
|
||||
BIG = 10000000
|
||||
|
||||
a_sock, b_sock = tsocket.socketpair()
|
||||
with a_sock, b_sock:
|
||||
a = SocketStream(a_sock)
|
||||
b = SocketStream(b_sock)
|
||||
|
||||
# Check a send_all that has to be split into multiple parts (on most
|
||||
# platforms... on Windows every send() either succeeds or fails as a
|
||||
# whole)
|
||||
async def sender():
|
||||
data = bytearray(BIG)
|
||||
await a.send_all(data)
|
||||
# send_all uses memoryviews internally, which temporarily "lock"
|
||||
# the object they view. If it doesn't clean them up properly, then
|
||||
# some bytearray operations might raise an error afterwards, which
|
||||
# would be a pretty weird and annoying side-effect to spring on
|
||||
# users. So test that this doesn't happen, by forcing the
|
||||
# bytearray's underlying buffer to be realloc'ed:
|
||||
data += bytes(BIG)
|
||||
# (Note: the above line of code doesn't do a very good job at
|
||||
# testing anything, because:
|
||||
# - on CPython, the refcount GC generally cleans up memoryviews
|
||||
# for us even if we're sloppy.
|
||||
# - on PyPy3, at least as of 5.7.0, the memoryview code and the
|
||||
# bytearray code conspire so that resizing never fails – if
|
||||
# resizing forces the bytearray's internal buffer to move, then
|
||||
# all memoryview references are automagically updated (!!).
|
||||
# See:
|
||||
# https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227
|
||||
# But I'm leaving the test here in hopes that if this ever changes
|
||||
# and we break our implementation of send_all, then we'll get some
|
||||
# early warning...)
|
||||
|
||||
async def receiver():
|
||||
# Make sure the sender fills up the kernel buffers and blocks
|
||||
await wait_all_tasks_blocked()
|
||||
nbytes = 0
|
||||
while nbytes < BIG:
|
||||
nbytes += len(await b.receive_some(BIG))
|
||||
assert nbytes == BIG
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sender)
|
||||
nursery.start_soon(receiver)
|
||||
|
||||
# We know that we received BIG bytes of NULs so far. Make sure that
|
||||
# was all the data in there.
|
||||
await a.send_all(b"e")
|
||||
assert await b.receive_some(10) == b"e"
|
||||
await a.send_eof()
|
||||
assert await b.receive_some(10) == b""
|
||||
|
||||
|
||||
async def fill_stream(s):
|
||||
async def sender():
|
||||
while True:
|
||||
await s.send_all(b"x" * 10000)
|
||||
|
||||
async def waiter(nursery):
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sender)
|
||||
nursery.start_soon(waiter, nursery)
|
||||
|
||||
|
||||
async def test_SocketStream_generic():
|
||||
async def stream_maker():
|
||||
left, right = tsocket.socketpair()
|
||||
return SocketStream(left), SocketStream(right)
|
||||
|
||||
async def clogged_stream_maker():
|
||||
left, right = await stream_maker()
|
||||
await fill_stream(left)
|
||||
await fill_stream(right)
|
||||
return left, right
|
||||
|
||||
await check_half_closeable_stream(stream_maker, clogged_stream_maker)
|
||||
|
||||
|
||||
async def test_SocketListener():
|
||||
# Not a Trio socket
|
||||
with stdlib_socket.socket() as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
s.listen(10)
|
||||
with pytest.raises(TypeError):
|
||||
SocketListener(s)
|
||||
|
||||
# Not a SOCK_STREAM
|
||||
with tsocket.socket(type=tsocket.SOCK_DGRAM) as s:
|
||||
await s.bind(("127.0.0.1", 0))
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
SocketListener(s)
|
||||
excinfo.match(r".*SOCK_STREAM")
|
||||
|
||||
# Didn't call .listen()
|
||||
# macOS has no way to check for this, so skip testing it there.
|
||||
if sys.platform != "darwin":
|
||||
with tsocket.socket() as s:
|
||||
await s.bind(("127.0.0.1", 0))
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
SocketListener(s)
|
||||
excinfo.match(r".*listen")
|
||||
|
||||
listen_sock = tsocket.socket()
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
listen_sock.listen(10)
|
||||
listener = SocketListener(listen_sock)
|
||||
|
||||
assert listener.socket is listen_sock
|
||||
|
||||
client_sock = tsocket.socket()
|
||||
await client_sock.connect(listen_sock.getsockname())
|
||||
with assert_checkpoints():
|
||||
server_stream = await listener.accept()
|
||||
assert isinstance(server_stream, SocketStream)
|
||||
assert server_stream.socket.getsockname() == listen_sock.getsockname()
|
||||
assert server_stream.socket.getpeername() == client_sock.getsockname()
|
||||
|
||||
with assert_checkpoints():
|
||||
await listener.aclose()
|
||||
|
||||
with assert_checkpoints():
|
||||
await listener.aclose()
|
||||
|
||||
with assert_checkpoints():
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await listener.accept()
|
||||
|
||||
client_sock.close()
|
||||
await server_stream.aclose()
|
||||
|
||||
|
||||
async def test_SocketListener_socket_closed_underfoot():
|
||||
listen_sock = tsocket.socket()
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
listen_sock.listen(10)
|
||||
listener = SocketListener(listen_sock)
|
||||
|
||||
# Close the socket, not the listener
|
||||
listen_sock.close()
|
||||
|
||||
# SocketListener gives correct error
|
||||
with assert_checkpoints():
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await listener.accept()
|
||||
|
||||
|
||||
async def test_SocketListener_accept_errors():
|
||||
class FakeSocket(tsocket.SocketType):
|
||||
def __init__(self, events):
|
||||
self._events = iter(events)
|
||||
|
||||
type = tsocket.SOCK_STREAM
|
||||
|
||||
# Fool the check for SO_ACCEPTCONN in SocketListener.__init__
|
||||
def getsockopt(self, level, opt):
|
||||
return True
|
||||
|
||||
def setsockopt(self, level, opt, value):
|
||||
pass
|
||||
|
||||
async def accept(self):
|
||||
await _core.checkpoint()
|
||||
event = next(self._events)
|
||||
if isinstance(event, BaseException):
|
||||
raise event
|
||||
else:
|
||||
return event, None
|
||||
|
||||
fake_server_sock = FakeSocket([])
|
||||
|
||||
fake_listen_sock = FakeSocket(
|
||||
[
|
||||
OSError(errno.ECONNABORTED, "Connection aborted"),
|
||||
OSError(errno.EPERM, "Permission denied"),
|
||||
OSError(errno.EPROTO, "Bad protocol"),
|
||||
fake_server_sock,
|
||||
OSError(errno.EMFILE, "Out of file descriptors"),
|
||||
OSError(errno.EFAULT, "attempt to write to read-only memory"),
|
||||
OSError(errno.ENOBUFS, "out of buffers"),
|
||||
fake_server_sock,
|
||||
]
|
||||
)
|
||||
|
||||
l = SocketListener(fake_listen_sock)
|
||||
|
||||
with assert_checkpoints():
|
||||
s = await l.accept()
|
||||
assert s.socket is fake_server_sock
|
||||
|
||||
for code in [errno.EMFILE, errno.EFAULT, errno.ENOBUFS]:
|
||||
with assert_checkpoints():
|
||||
with pytest.raises(OSError) as excinfo:
|
||||
await l.accept()
|
||||
assert excinfo.value.errno == code
|
||||
|
||||
with assert_checkpoints():
|
||||
s = await l.accept()
|
||||
assert s.socket is fake_server_sock
|
||||
|
||||
|
||||
async def test_socket_stream_works_when_peer_has_already_closed():
|
||||
sock_a, sock_b = tsocket.socketpair()
|
||||
with sock_a, sock_b:
|
||||
await sock_b.send(b"x")
|
||||
sock_b.close()
|
||||
stream = SocketStream(sock_a)
|
||||
assert await stream.receive_some(1) == b"x"
|
||||
assert await stream.receive_some(1) == b""
|
||||
113
venv/Lib/site-packages/trio/tests/test_highlevel_ssl_helpers.py
Normal file
113
venv/Lib/site-packages/trio/tests/test_highlevel_ssl_helpers.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import pytest
|
||||
|
||||
from functools import partial
|
||||
|
||||
import attr
|
||||
|
||||
import trio
|
||||
from trio.socket import AF_INET, SOCK_STREAM, IPPROTO_TCP
|
||||
import trio.testing
|
||||
from .test_ssl import client_ctx, SERVER_CTX
|
||||
|
||||
from .._highlevel_ssl_helpers import (
|
||||
open_ssl_over_tcp_stream,
|
||||
open_ssl_over_tcp_listeners,
|
||||
serve_ssl_over_tcp,
|
||||
)
|
||||
|
||||
|
||||
async def echo_handler(stream):
|
||||
async with stream:
|
||||
try:
|
||||
while True:
|
||||
data = await stream.receive_some(10000)
|
||||
if not data:
|
||||
break
|
||||
await stream.send_all(data)
|
||||
except trio.BrokenResourceError:
|
||||
pass
|
||||
|
||||
|
||||
# Resolver that always returns the given sockaddr, no matter what host/port
|
||||
# you ask for.
|
||||
@attr.s
|
||||
class FakeHostnameResolver(trio.abc.HostnameResolver):
|
||||
sockaddr = attr.ib()
|
||||
|
||||
async def getaddrinfo(self, *args):
|
||||
return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)]
|
||||
|
||||
async def getnameinfo(self, *args): # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
|
||||
# noqa is needed because flake8 doesn't understand how pytest fixtures work.
|
||||
async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa: F811
|
||||
async with trio.open_nursery() as nursery:
|
||||
(listener,) = await nursery.start(
|
||||
partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1")
|
||||
)
|
||||
async with listener:
|
||||
sockaddr = listener.transport_listener.socket.getsockname()
|
||||
hostname_resolver = FakeHostnameResolver(sockaddr)
|
||||
trio.socket.set_custom_hostname_resolver(hostname_resolver)
|
||||
|
||||
# We don't have the right trust set up
|
||||
# (checks that ssl_context=None is doing some validation)
|
||||
stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80)
|
||||
async with stream:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await stream.do_handshake()
|
||||
|
||||
# We have the trust but not the hostname
|
||||
# (checks custom ssl_context + hostname checking)
|
||||
stream = await open_ssl_over_tcp_stream(
|
||||
"xyzzy.example.org", 80, ssl_context=client_ctx
|
||||
)
|
||||
async with stream:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await stream.do_handshake()
|
||||
|
||||
# This one should work!
|
||||
stream = await open_ssl_over_tcp_stream(
|
||||
"trio-test-1.example.org", 80, ssl_context=client_ctx
|
||||
)
|
||||
async with stream:
|
||||
assert isinstance(stream, trio.SSLStream)
|
||||
assert stream.server_hostname == "trio-test-1.example.org"
|
||||
await stream.send_all(b"x")
|
||||
assert await stream.receive_some(1) == b"x"
|
||||
|
||||
# Check https_compatible settings are being passed through
|
||||
assert not stream._https_compatible
|
||||
stream = await open_ssl_over_tcp_stream(
|
||||
"trio-test-1.example.org",
|
||||
80,
|
||||
ssl_context=client_ctx,
|
||||
https_compatible=True,
|
||||
# also, smoke test happy_eyeballs_delay
|
||||
happy_eyeballs_delay=1,
|
||||
)
|
||||
async with stream:
|
||||
assert stream._https_compatible
|
||||
|
||||
# Stop the echo server
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_open_ssl_over_tcp_listeners():
|
||||
(listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1")
|
||||
async with listener:
|
||||
assert isinstance(listener, trio.SSLListener)
|
||||
tl = listener.transport_listener
|
||||
assert isinstance(tl, trio.SocketListener)
|
||||
assert tl.socket.getsockname()[0] == "127.0.0.1"
|
||||
|
||||
assert not listener._https_compatible
|
||||
|
||||
(listener,) = await open_ssl_over_tcp_listeners(
|
||||
0, SERVER_CTX, host="127.0.0.1", https_compatible=True
|
||||
)
|
||||
async with listener:
|
||||
assert listener._https_compatible
|
||||
262
venv/Lib/site-packages/trio/tests/test_path.py
Normal file
262
venv/Lib/site-packages/trio/tests/test_path.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio._path import AsyncAutoWrapperType as Type
|
||||
from trio._file_io import AsyncIOWrapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def path(tmpdir):
|
||||
p = str(tmpdir.join("test"))
|
||||
return trio.Path(p)
|
||||
|
||||
|
||||
def method_pair(path, method_name):
|
||||
path = pathlib.Path(path)
|
||||
async_path = trio.Path(path)
|
||||
return getattr(path, method_name), getattr(async_path, method_name)
|
||||
|
||||
|
||||
async def test_open_is_async_context_manager(path):
|
||||
async with await path.open("w") as f:
|
||||
assert isinstance(f, AsyncIOWrapper)
|
||||
|
||||
assert f.closed
|
||||
|
||||
|
||||
async def test_magic():
|
||||
path = trio.Path("test")
|
||||
|
||||
assert str(path) == "test"
|
||||
assert bytes(path) == b"test"
|
||||
|
||||
|
||||
cls_pairs = [
|
||||
(trio.Path, pathlib.Path),
|
||||
(pathlib.Path, trio.Path),
|
||||
(trio.Path, trio.Path),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls_a,cls_b", cls_pairs)
|
||||
async def test_cmp_magic(cls_a, cls_b):
|
||||
a, b = cls_a(""), cls_b("")
|
||||
assert a == b
|
||||
assert not a != b
|
||||
|
||||
a, b = cls_a("a"), cls_b("b")
|
||||
assert a < b
|
||||
assert b > a
|
||||
|
||||
# this is intentionally testing equivalence with none, due to the
|
||||
# other=sentinel logic in _forward_magic
|
||||
assert not a == None # noqa
|
||||
assert not b == None # noqa
|
||||
|
||||
|
||||
# upstream python3.8 bug: we should also test (pathlib.Path, trio.Path), but
|
||||
# __*div__ does not properly raise NotImplementedError like the other comparison
|
||||
# magic, so trio.Path's implementation does not get dispatched
|
||||
cls_pairs = [
|
||||
(trio.Path, pathlib.Path),
|
||||
(trio.Path, trio.Path),
|
||||
(trio.Path, str),
|
||||
(str, trio.Path),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls_a,cls_b", cls_pairs)
|
||||
async def test_div_magic(cls_a, cls_b):
|
||||
a, b = cls_a("a"), cls_b("b")
|
||||
|
||||
result = a / b
|
||||
assert isinstance(result, trio.Path)
|
||||
assert str(result) == os.path.join("a", "b")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cls_a,cls_b", [(trio.Path, pathlib.Path), (trio.Path, trio.Path)]
|
||||
)
|
||||
@pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"])
|
||||
async def test_hash_magic(cls_a, cls_b, path):
|
||||
a, b = cls_a(path), cls_b(path)
|
||||
assert hash(a) == hash(b)
|
||||
|
||||
|
||||
async def test_forwarded_properties(path):
|
||||
# use `name` as a representative of forwarded properties
|
||||
|
||||
assert "name" in dir(path)
|
||||
assert path.name == "test"
|
||||
|
||||
|
||||
async def test_async_method_signature(path):
|
||||
# use `resolve` as a representative of wrapped methods
|
||||
|
||||
assert path.resolve.__name__ == "resolve"
|
||||
assert path.resolve.__qualname__ == "Path.resolve"
|
||||
|
||||
assert "pathlib.Path.resolve" in path.resolve.__doc__
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["is_dir", "is_file"])
|
||||
async def test_compare_async_stat_methods(method_name):
|
||||
|
||||
method, async_method = method_pair(".", method_name)
|
||||
|
||||
result = method()
|
||||
async_result = await async_method()
|
||||
|
||||
assert result == async_result
|
||||
|
||||
|
||||
async def test_invalid_name_not_wrapped(path):
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(path, "invalid_fake_attr")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["absolute", "resolve"])
|
||||
async def test_async_methods_rewrap(method_name):
|
||||
|
||||
method, async_method = method_pair(".", method_name)
|
||||
|
||||
result = method()
|
||||
async_result = await async_method()
|
||||
|
||||
assert isinstance(async_result, trio.Path)
|
||||
assert str(result) == str(async_result)
|
||||
|
||||
|
||||
async def test_forward_methods_rewrap(path, tmpdir):
|
||||
with_name = path.with_name("foo")
|
||||
with_suffix = path.with_suffix(".py")
|
||||
|
||||
assert isinstance(with_name, trio.Path)
|
||||
assert with_name == tmpdir.join("foo")
|
||||
assert isinstance(with_suffix, trio.Path)
|
||||
assert with_suffix == tmpdir.join("test.py")
|
||||
|
||||
|
||||
async def test_forward_properties_rewrap(path):
|
||||
assert isinstance(path.parent, trio.Path)
|
||||
|
||||
|
||||
async def test_forward_methods_without_rewrap(path, tmpdir):
|
||||
path = await path.parent.resolve()
|
||||
|
||||
assert path.as_uri().startswith("file:///")
|
||||
|
||||
|
||||
async def test_repr():
|
||||
path = trio.Path(".")
|
||||
|
||||
assert repr(path) == "trio.Path('.')"
|
||||
|
||||
|
||||
class MockWrapped:
|
||||
unsupported = "unsupported"
|
||||
_private = "private"
|
||||
|
||||
|
||||
class MockWrapper:
|
||||
_forwards = MockWrapped
|
||||
_wraps = MockWrapped
|
||||
|
||||
|
||||
async def test_type_forwards_unsupported():
|
||||
with pytest.raises(TypeError):
|
||||
Type.generate_forwards(MockWrapper, {})
|
||||
|
||||
|
||||
async def test_type_wraps_unsupported():
|
||||
with pytest.raises(TypeError):
|
||||
Type.generate_wraps(MockWrapper, {})
|
||||
|
||||
|
||||
async def test_type_forwards_private():
|
||||
Type.generate_forwards(MockWrapper, {"unsupported": None})
|
||||
|
||||
assert not hasattr(MockWrapper, "_private")
|
||||
|
||||
|
||||
async def test_type_wraps_private():
|
||||
Type.generate_wraps(MockWrapper, {"unsupported": None})
|
||||
|
||||
assert not hasattr(MockWrapper, "_private")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath])
|
||||
async def test_path_wraps_path(path, meth):
|
||||
wrapped = await path.absolute()
|
||||
result = meth(path, wrapped)
|
||||
if result is None:
|
||||
result = path
|
||||
|
||||
assert wrapped == result
|
||||
|
||||
|
||||
async def test_path_nonpath():
|
||||
with pytest.raises(TypeError):
|
||||
trio.Path(1)
|
||||
|
||||
|
||||
async def test_open_file_can_open_path(path):
|
||||
async with await trio.open_file(path, "w") as f:
|
||||
assert f.name == os.fspath(path)
|
||||
|
||||
|
||||
async def test_globmethods(path):
|
||||
# Populate a directory tree
|
||||
await path.mkdir()
|
||||
await (path / "foo").mkdir()
|
||||
await (path / "foo" / "_bar.txt").write_bytes(b"")
|
||||
await (path / "bar.txt").write_bytes(b"")
|
||||
await (path / "bar.dat").write_bytes(b"")
|
||||
|
||||
# Path.glob
|
||||
for _pattern, _results in {
|
||||
"*.txt": {"bar.txt"},
|
||||
"**/*.txt": {"_bar.txt", "bar.txt"},
|
||||
}.items():
|
||||
entries = set()
|
||||
for entry in await path.glob(_pattern):
|
||||
assert isinstance(entry, trio.Path)
|
||||
entries.add(entry.name)
|
||||
|
||||
assert entries == _results
|
||||
|
||||
# Path.rglob
|
||||
entries = set()
|
||||
for entry in await path.rglob("*.txt"):
|
||||
assert isinstance(entry, trio.Path)
|
||||
entries.add(entry.name)
|
||||
|
||||
assert entries == {"_bar.txt", "bar.txt"}
|
||||
|
||||
|
||||
async def test_iterdir(path):
|
||||
# Populate a directory
|
||||
await path.mkdir()
|
||||
await (path / "foo").mkdir()
|
||||
await (path / "bar.txt").write_bytes(b"")
|
||||
|
||||
entries = set()
|
||||
for entry in await path.iterdir():
|
||||
assert isinstance(entry, trio.Path)
|
||||
entries.add(entry.name)
|
||||
|
||||
assert entries == {"bar.txt", "foo"}
|
||||
|
||||
|
||||
async def test_classmethods():
|
||||
assert isinstance(await trio.Path.home(), trio.Path)
|
||||
|
||||
# pathlib.Path has only two classmethods
|
||||
assert str(await trio.Path.home()) == os.path.expanduser("~")
|
||||
assert str(await trio.Path.cwd()) == os.getcwd()
|
||||
|
||||
# Wrapped method has docstring
|
||||
assert trio.Path.home.__doc__
|
||||
@@ -0,0 +1,40 @@
|
||||
import trio
|
||||
|
||||
|
||||
async def scheduler_trace():
|
||||
"""Returns a scheduler-dependent value we can use to check determinism."""
|
||||
trace = []
|
||||
|
||||
async def tracer(name):
|
||||
for i in range(50):
|
||||
trace.append((name, i))
|
||||
await trio.sleep(0)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(5):
|
||||
nursery.start_soon(tracer, i)
|
||||
|
||||
return tuple(trace)
|
||||
|
||||
|
||||
def test_the_trio_scheduler_is_not_deterministic():
|
||||
# At least, not yet. See https://github.com/python-trio/trio/issues/32
|
||||
traces = []
|
||||
for _ in range(10):
|
||||
traces.append(trio.run(scheduler_trace))
|
||||
assert len(set(traces)) == len(traces)
|
||||
|
||||
|
||||
def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch):
|
||||
monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True)
|
||||
traces = []
|
||||
for _ in range(10):
|
||||
state = trio._core._run._r.getstate()
|
||||
try:
|
||||
trio._core._run._r.seed(0)
|
||||
traces.append(trio.run(scheduler_trace))
|
||||
finally:
|
||||
trio._core._run._r.setstate(state)
|
||||
|
||||
assert len(traces) == 10
|
||||
assert len(set(traces)) == 1
|
||||
177
venv/Lib/site-packages/trio/tests/test_signals.py
Normal file
177
venv/Lib/site-packages/trio/tests/test_signals.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import signal
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from .. import _core
|
||||
from .._util import signal_raise
|
||||
from .._signals import open_signal_receiver, _signal_handler
|
||||
|
||||
|
||||
async def test_open_signal_receiver():
|
||||
orig = signal.getsignal(signal.SIGILL)
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
# Raise it a few times, to exercise signal coalescing, both at the
|
||||
# call_soon level and at the SignalQueue level
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGILL)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
signal_raise(signal.SIGILL)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
async for signum in receiver: # pragma: no branch
|
||||
assert signum == signal.SIGILL
|
||||
break
|
||||
assert receiver._pending_signal_count() == 0
|
||||
signal_raise(signal.SIGILL)
|
||||
async for signum in receiver: # pragma: no branch
|
||||
assert signum == signal.SIGILL
|
||||
break
|
||||
assert receiver._pending_signal_count() == 0
|
||||
with pytest.raises(RuntimeError):
|
||||
await receiver.__anext__()
|
||||
assert signal.getsignal(signal.SIGILL) is orig
|
||||
|
||||
|
||||
async def test_open_signal_receiver_restore_handler_after_one_bad_signal():
|
||||
orig = signal.getsignal(signal.SIGILL)
|
||||
with pytest.raises(ValueError):
|
||||
with open_signal_receiver(signal.SIGILL, 1234567):
|
||||
pass # pragma: no cover
|
||||
# Still restored even if we errored out
|
||||
assert signal.getsignal(signal.SIGILL) is orig
|
||||
|
||||
|
||||
async def test_open_signal_receiver_empty_fail():
|
||||
with pytest.raises(TypeError, match="No signals were provided"):
|
||||
with open_signal_receiver():
|
||||
pass
|
||||
|
||||
|
||||
async def test_open_signal_receiver_restore_handler_after_duplicate_signal():
|
||||
orig = signal.getsignal(signal.SIGILL)
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGILL):
|
||||
pass
|
||||
# Still restored correctly
|
||||
assert signal.getsignal(signal.SIGILL) is orig
|
||||
|
||||
|
||||
async def test_catch_signals_wrong_thread():
|
||||
async def naughty():
|
||||
with open_signal_receiver(signal.SIGINT):
|
||||
pass # pragma: no cover
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await trio.to_thread.run_sync(trio.run, naughty)
|
||||
|
||||
|
||||
async def test_open_signal_receiver_conflict():
|
||||
with pytest.raises(trio.BusyResourceError):
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver.__anext__)
|
||||
nursery.start_soon(receiver.__anext__)
|
||||
|
||||
|
||||
# Blocks until all previous calls to run_sync_soon(idempotent=True) have been
|
||||
# processed.
|
||||
async def wait_run_sync_soon_idempotent_queue_barrier():
|
||||
ev = trio.Event()
|
||||
token = _core.current_trio_token()
|
||||
token.run_sync_soon(ev.set, idempotent=True)
|
||||
await ev.wait()
|
||||
|
||||
|
||||
async def test_open_signal_receiver_no_starvation():
|
||||
# Set up a situation where there are always 2 pending signals available to
|
||||
# report, and make sure that instead of getting the same signal reported
|
||||
# over and over, it alternates between reporting both of them.
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
try:
|
||||
print(signal.getsignal(signal.SIGILL))
|
||||
previous = None
|
||||
for _ in range(10):
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
if previous is None:
|
||||
previous = await receiver.__anext__()
|
||||
else:
|
||||
got = await receiver.__anext__()
|
||||
assert got in [signal.SIGILL, signal.SIGFPE]
|
||||
assert got != previous
|
||||
previous = got
|
||||
# Clear out the last signal so it doesn't get redelivered
|
||||
while receiver._pending_signal_count() != 0:
|
||||
await receiver.__anext__()
|
||||
except: # pragma: no cover
|
||||
# If there's an unhandled exception above, then exiting the
|
||||
# open_signal_receiver block might cause the signal to be
|
||||
# redelivered and give us a core dump instead of a traceback...
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def test_catch_signals_race_condition_on_exit():
|
||||
delivered_directly = set()
|
||||
|
||||
def direct_handler(signo, frame):
|
||||
delivered_directly.add(signo)
|
||||
|
||||
print(1)
|
||||
# Test the version where the call_soon *doesn't* have a chance to run
|
||||
# before we exit the with block:
|
||||
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
|
||||
delivered_directly.clear()
|
||||
|
||||
print(2)
|
||||
# Test the version where the call_soon *does* have a chance to run before
|
||||
# we exit the with block:
|
||||
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert receiver._pending_signal_count() == 2
|
||||
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
|
||||
delivered_directly.clear()
|
||||
|
||||
# Again, but with a SIG_IGN signal:
|
||||
|
||||
print(3)
|
||||
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
# test passes if the process reaches this point without dying
|
||||
|
||||
print(4)
|
||||
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert receiver._pending_signal_count() == 1
|
||||
# test passes if the process reaches this point without dying
|
||||
|
||||
# Check exception chaining if there are multiple exception-raising
|
||||
# handlers
|
||||
def raise_handler(signum, _):
|
||||
raise RuntimeError(signum)
|
||||
|
||||
with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler):
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert receiver._pending_signal_count() == 2
|
||||
exc = excinfo.value
|
||||
signums = {exc.args[0]}
|
||||
assert isinstance(exc.__context__, RuntimeError)
|
||||
signums.add(exc.__context__.args[0])
|
||||
assert signums == {signal.SIGILL, signal.SIGFPE}
|
||||
1015
venv/Lib/site-packages/trio/tests/test_socket.py
Normal file
1015
venv/Lib/site-packages/trio/tests/test_socket.py
Normal file
File diff suppressed because it is too large
1298
venv/Lib/site-packages/trio/tests/test_ssl.py
Normal file
1298
venv/Lib/site-packages/trio/tests/test_ssl.py
Normal file
File diff suppressed because it is too large
602
venv/Lib/site-packages/trio/tests/test_subprocess.py
Normal file
602
venv/Lib/site-packages/trio/tests/test_subprocess.py
Normal file
@@ -0,0 +1,602 @@
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from functools import partial
|
||||
from pathlib import Path as SyncPath
|
||||
|
||||
import pytest
|
||||
from async_generator import asynccontextmanager
|
||||
|
||||
from .. import (
|
||||
ClosedResourceError,
|
||||
Event,
|
||||
Process,
|
||||
_core,
|
||||
fail_after,
|
||||
move_on_after,
|
||||
run_process,
|
||||
sleep,
|
||||
sleep_forever,
|
||||
)
|
||||
from .._core.tests.tutil import skip_if_fbsd_pipes_broken, slow
|
||||
from ..lowlevel import open_process
|
||||
from ..testing import assert_no_checkpoints, wait_all_tasks_blocked
|
||||
|
||||
posix = os.name == "posix"
|
||||
if posix:
|
||||
from signal import SIGKILL, SIGTERM, SIGUSR1
|
||||
else:
|
||||
SIGKILL, SIGTERM, SIGUSR1 = None, None, None
|
||||
|
||||
|
||||
# Since Windows has very few command-line utilities generally available,
|
||||
# all of our subprocesses are Python processes running short bits of
|
||||
# (mostly) cross-platform code.
|
||||
def python(code):
|
||||
return [sys.executable, "-u", "-c", "import sys; " + code]
|
||||
|
||||
|
||||
EXIT_TRUE = python("sys.exit(0)")
|
||||
EXIT_FALSE = python("sys.exit(1)")
|
||||
CAT = python("sys.stdout.buffer.write(sys.stdin.buffer.read())")
|
||||
|
||||
if posix:
|
||||
SLEEP = lambda seconds: ["/bin/sleep", str(seconds)]
|
||||
else:
|
||||
SLEEP = lambda seconds: python("import time; time.sleep({})".format(seconds))
|
||||
|
||||
|
||||
def got_signal(proc, sig):
|
||||
if posix:
|
||||
return proc.returncode == -sig
|
||||
else:
|
||||
return proc.returncode != 0
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def open_process_then_kill(*args, **kwargs):
|
||||
proc = await open_process(*args, **kwargs)
|
||||
try:
|
||||
yield proc
|
||||
finally:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def run_process_in_nursery(*args, **kwargs):
|
||||
async with _core.open_nursery() as nursery:
|
||||
kwargs.setdefault("check", False)
|
||||
proc = await nursery.start(partial(run_process, *args, **kwargs))
|
||||
yield proc
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
background_process_param = pytest.mark.parametrize(
|
||||
"background_process",
|
||||
[open_process_then_kill, run_process_in_nursery],
|
||||
ids=["open_process", "run_process in nursery"],
|
||||
)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_basic(background_process):
|
||||
async with background_process(EXIT_TRUE) as proc:
|
||||
await proc.wait()
|
||||
assert isinstance(proc, Process)
|
||||
assert proc._pidfd is None
|
||||
assert proc.returncode == 0
|
||||
assert repr(proc) == f"<trio.Process {EXIT_TRUE}: exited with status 0>"
|
||||
|
||||
async with background_process(EXIT_FALSE) as proc:
|
||||
await proc.wait()
|
||||
assert proc.returncode == 1
|
||||
assert repr(proc) == "<trio.Process {!r}: {}>".format(
|
||||
EXIT_FALSE, "exited with status 1"
|
||||
)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_auto_update_returncode(background_process):
|
||||
async with background_process(SLEEP(9999)) as p:
|
||||
assert p.returncode is None
|
||||
assert "running" in repr(p)
|
||||
p.kill()
|
||||
p._proc.wait()
|
||||
assert p.returncode is not None
|
||||
assert "exited" in repr(p)
|
||||
assert p._pidfd is None
|
||||
assert p.returncode is not None
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_multi_wait(background_process):
|
||||
async with background_process(SLEEP(10)) as proc:
|
||||
# Check that wait (including multi-wait) tolerates being cancelled
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# Now try waiting for real
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
proc.kill()
|
||||
|
||||
|
||||
# Test for deprecated 'async with process:' semantics
|
||||
async def test_async_with_basics_deprecated(recwarn):
|
||||
async with await open_process(
|
||||
CAT, stdin=subprocess.PIPE, stdout=subprocess.PIPE
|
||||
) as proc:
|
||||
pass
|
||||
assert proc.returncode is not None
|
||||
with pytest.raises(ClosedResourceError):
|
||||
await proc.stdin.send_all(b"x")
|
||||
with pytest.raises(ClosedResourceError):
|
||||
await proc.stdout.receive_some()
|
||||
|
||||
|
||||
# Test for deprecated 'async with process:' semantics
|
||||
async def test_kill_when_context_cancelled(recwarn):
|
||||
with move_on_after(100) as scope:
|
||||
async with await open_process(SLEEP(10)) as proc:
|
||||
assert proc.poll() is None
|
||||
scope.cancel()
|
||||
await sleep_forever()
|
||||
assert scope.cancelled_caught
|
||||
assert got_signal(proc, SIGKILL)
|
||||
assert repr(proc) == "<trio.Process {!r}: {}>".format(
|
||||
SLEEP(10), "exited with signal 9" if posix else "exited with status 1"
|
||||
)
|
||||
|
||||
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR = python(
|
||||
"data = sys.stdin.buffer.read(); "
|
||||
"sys.stdout.buffer.write(data); "
|
||||
"sys.stderr.buffer.write(data[::-1])"
|
||||
)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_pipes(background_process):
|
||||
async with background_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
) as proc:
|
||||
msg = b"the quick brown fox jumps over the lazy dog"
|
||||
|
||||
async def feed_input():
|
||||
await proc.stdin.send_all(msg)
|
||||
await proc.stdin.aclose()
|
||||
|
||||
async def check_output(stream, expected):
|
||||
seen = bytearray()
|
||||
async for chunk in stream:
|
||||
seen += chunk
|
||||
assert seen == expected
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
# fail eventually if something is broken
|
||||
nursery.cancel_scope.deadline = _core.current_time() + 30.0
|
||||
nursery.start_soon(feed_input)
|
||||
nursery.start_soon(check_output, proc.stdout, msg)
|
||||
nursery.start_soon(check_output, proc.stderr, msg[::-1])
|
||||
|
||||
assert not nursery.cancel_scope.cancelled_caught
|
||||
assert 0 == await proc.wait()
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_interactive(background_process):
|
||||
# Test some back-and-forth with a subprocess. This one works like so:
|
||||
# in: 32\n
|
||||
# out: 0000...0000\n (32 zeroes)
|
||||
# err: 1111...1111\n (64 ones)
|
||||
# in: 10\n
|
||||
# out: 2222222222\n (10 twos)
|
||||
# err: 3333....3333\n (20 threes)
|
||||
# in: EOF
|
||||
# out: EOF
|
||||
# err: EOF
|
||||
|
||||
async with background_process(
|
||||
python(
|
||||
"idx = 0\n"
|
||||
"while True:\n"
|
||||
" line = sys.stdin.readline()\n"
|
||||
" if line == '': break\n"
|
||||
" request = int(line.strip())\n"
|
||||
" print(str(idx * 2) * request)\n"
|
||||
" print(str(idx * 2 + 1) * request * 2, file=sys.stderr)\n"
|
||||
" idx += 1\n"
|
||||
),
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
) as proc:
|
||||
|
||||
newline = b"\n" if posix else b"\r\n"
|
||||
|
||||
async def expect(idx, request):
|
||||
async with _core.open_nursery() as nursery:
|
||||
|
||||
async def drain_one(stream, count, digit):
|
||||
while count > 0:
|
||||
result = await stream.receive_some(count)
|
||||
assert result == (
|
||||
"{}".format(digit).encode("utf-8") * len(result)
|
||||
)
|
||||
count -= len(result)
|
||||
assert count == 0
|
||||
assert await stream.receive_some(len(newline)) == newline
|
||||
|
||||
nursery.start_soon(drain_one, proc.stdout, request, idx * 2)
|
||||
nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1)
|
||||
|
||||
with fail_after(5):
|
||||
await proc.stdin.send_all(b"12")
|
||||
await sleep(0.1)
|
||||
await proc.stdin.send_all(b"345" + newline)
|
||||
await expect(0, 12345)
|
||||
await proc.stdin.send_all(b"100" + newline + b"200" + newline)
|
||||
await expect(1, 100)
|
||||
await expect(2, 200)
|
||||
await proc.stdin.send_all(b"0" + newline)
|
||||
await expect(3, 0)
|
||||
await proc.stdin.send_all(b"999999")
|
||||
with move_on_after(0.1) as scope:
|
||||
await expect(4, 0)
|
||||
assert scope.cancelled_caught
|
||||
await proc.stdin.send_all(newline)
|
||||
await expect(4, 999999)
|
||||
await proc.stdin.aclose()
|
||||
assert await proc.stdout.receive_some(1) == b""
|
||||
assert await proc.stderr.receive_some(1) == b""
|
||||
await proc.wait()
|
||||
|
||||
assert proc.returncode == 0
|
||||
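# Hedged sketch (not from the original file): the request/response protocol the
# child in test_interactive above implements, written as a pure helper. The real
# test additionally normalises the newline (b"\r\n" on Windows); this model just
# uses "\n". The helper name is illustrative only.
def _expected_reply(idx, request):
    out = str(idx * 2) * request + "\n"            # what arrives on stdout
    err = str(idx * 2 + 1) * (request * 2) + "\n"  # what arrives on stderr
    return out, err


assert _expected_reply(0, 3) == ("000\n", "111111\n")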
|
||||
|
||||
async def test_run():
|
||||
data = bytes(random.randint(0, 255) for _ in range(2**18))
|
||||
|
||||
result = await run_process(
|
||||
CAT, stdin=data, capture_stdout=True, capture_stderr=True
|
||||
)
|
||||
assert result.args == CAT
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == data
|
||||
assert result.stderr == b""
|
||||
|
||||
result = await run_process(CAT, capture_stdout=True)
|
||||
assert result.args == CAT
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == b""
|
||||
assert result.stderr is None
|
||||
|
||||
result = await run_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=data,
|
||||
capture_stdout=True,
|
||||
capture_stderr=True,
|
||||
)
|
||||
assert result.args == COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == data
|
||||
assert result.stderr == data[::-1]
|
||||
|
||||
# invalid combinations
|
||||
with pytest.raises(UnicodeError):
|
||||
await run_process(CAT, stdin="oh no, it's text")
|
||||
with pytest.raises(ValueError):
|
||||
await run_process(CAT, stdin=subprocess.PIPE)
|
||||
with pytest.raises(ValueError):
|
||||
await run_process(CAT, stdout=subprocess.PIPE)
|
||||
with pytest.raises(ValueError):
|
||||
await run_process(CAT, stderr=subprocess.PIPE)
|
||||
with pytest.raises(ValueError):
|
||||
await run_process(CAT, capture_stdout=True, stdout=subprocess.DEVNULL)
|
||||
with pytest.raises(ValueError):
|
||||
await run_process(CAT, capture_stderr=True, stderr=None)
|
||||
|
||||
|
||||
async def test_run_check():
|
||||
cmd = python("sys.stderr.buffer.write(b'test\\n'); sys.exit(1)")
|
||||
with pytest.raises(subprocess.CalledProcessError) as excinfo:
|
||||
await run_process(cmd, stdin=subprocess.DEVNULL, capture_stderr=True)
|
||||
assert excinfo.value.cmd == cmd
|
||||
assert excinfo.value.returncode == 1
|
||||
assert excinfo.value.stderr == b"test\n"
|
||||
assert excinfo.value.stdout is None
|
||||
|
||||
result = await run_process(
|
||||
cmd, capture_stdout=True, capture_stderr=True, check=False
|
||||
)
|
||||
assert result.args == cmd
|
||||
assert result.stdout == b""
|
||||
assert result.stderr == b"test\n"
|
||||
assert result.returncode == 1
|
||||
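# Hedged illustration (an assumption mirroring the check= contract exercised in
# test_run_check above, not a test from the original diff; reuses this file's
# python() helper): check defaults to True, so a nonzero exit raises
# CalledProcessError, while check=False hands back the result object instead.
async def _check_flag_sketch():
    failing = python("import sys; sys.exit(2)")
    with pytest.raises(subprocess.CalledProcessError):
        await run_process(failing)
    result = await run_process(failing, check=False)
    assert result.returncode == 2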
|
||||
|
||||
@skip_if_fbsd_pipes_broken
|
||||
async def test_run_with_broken_pipe():
|
||||
result = await run_process(
|
||||
[sys.executable, "-c", "import sys; sys.stdin.close()"], stdin=b"x" * 131072
|
||||
)
|
||||
assert result.returncode == 0
|
||||
assert result.stdout is result.stderr is None
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_stderr_stdout(background_process):
|
||||
async with background_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
) as proc:
|
||||
assert proc.stdout is not None
|
||||
assert proc.stderr is None
|
||||
await proc.stdio.send_all(b"1234")
|
||||
await proc.stdio.send_eof()
|
||||
|
||||
output = []
|
||||
while True:
|
||||
chunk = await proc.stdio.receive_some(16)
|
||||
if chunk == b"":
|
||||
break
|
||||
output.append(chunk)
|
||||
assert b"".join(output) == b"12344321"
|
||||
assert proc.returncode == 0
|
||||
|
||||
# equivalent test with run_process()
|
||||
result = await run_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=b"1234",
|
||||
capture_stdout=True,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == b"12344321"
|
||||
assert result.stderr is None
|
||||
|
||||
# this one hits the branch where stderr=STDOUT but stdout
|
||||
# is not redirected
|
||||
async with background_process(
|
||||
CAT, stdin=subprocess.PIPE, stderr=subprocess.STDOUT
|
||||
) as proc:
|
||||
assert proc.stdout is None
|
||||
assert proc.stderr is None
|
||||
await proc.stdin.aclose()
|
||||
await proc.wait()
|
||||
assert proc.returncode == 0
|
||||
|
||||
if posix:
|
||||
try:
|
||||
r, w = os.pipe()
|
||||
|
||||
async with background_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=w,
|
||||
stderr=subprocess.STDOUT,
|
||||
) as proc:
|
||||
os.close(w)
|
||||
assert proc.stdio is None
|
||||
assert proc.stdout is None
|
||||
assert proc.stderr is None
|
||||
await proc.stdin.send_all(b"1234")
|
||||
await proc.stdin.aclose()
|
||||
assert await proc.wait() == 0
|
||||
assert os.read(r, 4096) == b"12344321"
|
||||
assert os.read(r, 4096) == b""
|
||||
finally:
|
||||
os.close(r)
|
||||
|
||||
|
||||
async def test_errors():
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
await open_process(["ls"], encoding="utf-8")
|
||||
assert "unbuffered byte streams" in str(excinfo.value)
|
||||
assert "the 'encoding' option is not supported" in str(excinfo.value)
|
||||
|
||||
if posix:
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
await open_process(["ls"], shell=True)
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
await open_process("ls", shell=False)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_signals(background_process):
|
||||
async def test_one_signal(send_it, signum):
|
||||
with move_on_after(1.0) as scope:
|
||||
async with background_process(SLEEP(3600)) as proc:
|
||||
send_it(proc)
|
||||
await proc.wait()
|
||||
assert not scope.cancelled_caught
|
||||
if posix:
|
||||
assert proc.returncode == -signum
|
||||
else:
|
||||
assert proc.returncode != 0
|
||||
|
||||
await test_one_signal(Process.kill, SIGKILL)
|
||||
await test_one_signal(Process.terminate, SIGTERM)
|
||||
# Test that we can send arbitrary signals.
|
||||
#
|
||||
# We used to use SIGINT here, but it turns out that the Python interpreter
|
||||
# has race conditions that can cause it to explode in weird ways if it
|
||||
# tries to handle SIGINT during startup. SIGUSR1's default disposition is
|
||||
# to terminate the target process, and Python doesn't try to do anything
|
||||
# clever to handle it.
|
||||
if posix:
|
||||
await test_one_signal(lambda proc: proc.send_signal(SIGUSR1), SIGUSR1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not posix, reason="POSIX specific")
|
||||
@background_process_param
|
||||
async def test_wait_reapable_fails(background_process):
|
||||
old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
|
||||
try:
|
||||
# With SIGCHLD disabled, the wait() syscall will wait for the
|
||||
# process to exit but then fail with ECHILD. Make sure we
|
||||
# support this case as the stdlib subprocess module does.
|
||||
async with background_process(SLEEP(3600)) as proc:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
proc.kill()
|
||||
nursery.cancel_scope.deadline = _core.current_time() + 1.0
|
||||
assert not nursery.cancel_scope.cancelled_caught
|
||||
assert proc.returncode == 0 # exit status unknowable, so...
|
||||
finally:
|
||||
signal.signal(signal.SIGCHLD, old_sigchld)
|
||||
|
||||
|
||||
@slow
|
||||
def test_waitid_eintr():
|
||||
# This only matters on PyPy (where we're coding EINTR handling
|
||||
# ourselves) but the test works on all waitid platforms.
|
||||
from .._subprocess_platform import wait_child_exiting
|
||||
|
||||
if not wait_child_exiting.__module__.endswith("waitid"):
|
||||
pytest.skip("waitid only")
|
||||
from .._subprocess_platform.waitid import sync_wait_reapable
|
||||
|
||||
got_alarm = False
|
||||
sleeper = subprocess.Popen(["sleep", "3600"])
|
||||
|
||||
def on_alarm(sig, frame):
|
||||
nonlocal got_alarm
|
||||
got_alarm = True
|
||||
sleeper.kill()
|
||||
|
||||
old_sigalrm = signal.signal(signal.SIGALRM, on_alarm)
|
||||
try:
|
||||
signal.alarm(1)
|
||||
sync_wait_reapable(sleeper.pid)
|
||||
assert sleeper.wait(timeout=1) == -9
|
||||
finally:
|
||||
if sleeper.returncode is None: # pragma: no cover
|
||||
# We only get here if something fails in the above;
|
||||
# if the test passes, wait() will reap the process
|
||||
sleeper.kill()
|
||||
sleeper.wait()
|
||||
signal.signal(signal.SIGALRM, old_sigalrm)
|
||||
|
||||
|
||||
async def test_custom_deliver_cancel():
|
||||
custom_deliver_cancel_called = False
|
||||
|
||||
async def custom_deliver_cancel(proc):
|
||||
nonlocal custom_deliver_cancel_called
|
||||
custom_deliver_cancel_called = True
|
||||
proc.terminate()
|
||||
# Make sure this does get cancelled when the process exits, and that
|
||||
# the process really exited.
|
||||
try:
|
||||
await sleep_forever()
|
||||
finally:
|
||||
assert proc.returncode is not None
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel)
|
||||
)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
assert custom_deliver_cancel_called
|
||||
|
||||
|
||||
async def test_warn_on_failed_cancel_terminate(monkeypatch):
|
||||
original_terminate = Process.terminate
|
||||
|
||||
def broken_terminate(self):
|
||||
original_terminate(self)
|
||||
raise OSError("whoops")
|
||||
|
||||
monkeypatch.setattr(Process, "terminate", broken_terminate)
|
||||
|
||||
with pytest.warns(RuntimeWarning, match=".*whoops.*"):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(run_process, SLEEP(9999))
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not posix, reason="posix only")
|
||||
async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch):
|
||||
monkeypatch.setattr(Process, "terminate", lambda *args: None)
|
||||
|
||||
with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(run_process, SLEEP(9999))
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
# the background_process_param exercises a lot of run_process cases, but it uses
|
||||
# check=False, so let's have a test that uses check=True as well
|
||||
async def test_run_process_background_fail():
|
||||
with pytest.raises(subprocess.CalledProcessError):
|
||||
async with _core.open_nursery() as nursery:
|
||||
proc = await nursery.start(run_process, EXIT_FALSE)
|
||||
assert proc.returncode == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not SyncPath("/dev/fd").exists(),
|
||||
reason="requires a way to iterate through open files",
|
||||
)
|
||||
async def test_for_leaking_fds():
|
||||
starting_fds = set(SyncPath("/dev/fd").iterdir())
|
||||
await run_process(EXIT_TRUE)
|
||||
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
|
||||
|
||||
with pytest.raises(subprocess.CalledProcessError):
|
||||
await run_process(EXIT_FALSE)
|
||||
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
|
||||
|
||||
with pytest.raises(PermissionError):
|
||||
await run_process(["/dev/fd/0"])
|
||||
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
|
||||
|
||||
|
||||
# regression test for #2209
|
||||
async def test_subprocess_pidfd_unnotified():
|
||||
noticed_exit = None
|
||||
|
||||
async def wait_and_tell(proc) -> None:
|
||||
nonlocal noticed_exit
|
||||
noticed_exit = Event()
|
||||
await proc.wait()
|
||||
noticed_exit.set()
|
||||
|
||||
proc = await open_process(SLEEP(9999))
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(wait_and_tell, proc)
|
||||
await wait_all_tasks_blocked()
|
||||
assert isinstance(noticed_exit, Event)
|
||||
proc.terminate()
|
||||
# without giving trio a chance to do so,
|
||||
with assert_no_checkpoints():
|
||||
# wait until the process has actually exited;
|
||||
proc._proc.wait()
|
||||
# force a call to poll (that closes the pidfd on linux)
|
||||
proc.poll()
|
||||
with move_on_after(5):
|
||||
# Some platforms use threads to wait for exit, so it might take a bit
|
||||
# for everything to notice
|
||||
await noticed_exit.wait()
|
||||
assert noticed_exit.is_set(), "child task wasn't woken after poll, DEADLOCK"
|
||||
570
venv/Lib/site-packages/trio/tests/test_sync.py
Normal file
@@ -0,0 +1,570 @@
|
||||
import pytest
|
||||
|
||||
import weakref
|
||||
|
||||
from ..testing import wait_all_tasks_blocked, assert_checkpoints
|
||||
|
||||
from .. import _core
|
||||
from .. import _timeouts
|
||||
from .._timeouts import sleep_forever, move_on_after
|
||||
from .._sync import *
|
||||
|
||||
|
||||
async def test_Event():
|
||||
e = Event()
|
||||
assert not e.is_set()
|
||||
assert e.statistics().tasks_waiting == 0
|
||||
|
||||
e.set()
|
||||
assert e.is_set()
|
||||
with assert_checkpoints():
|
||||
await e.wait()
|
||||
|
||||
e = Event()
|
||||
|
||||
record = []
|
||||
|
||||
async def child():
|
||||
record.append("sleeping")
|
||||
await e.wait()
|
||||
record.append("woken")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child)
|
||||
nursery.start_soon(child)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["sleeping", "sleeping"]
|
||||
assert e.statistics().tasks_waiting == 2
|
||||
e.set()
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["sleeping", "sleeping", "woken", "woken"]
|
||||
|
||||
|
||||
async def test_CapacityLimiter():
|
||||
with pytest.raises(TypeError):
|
||||
CapacityLimiter(1.0)
|
||||
with pytest.raises(ValueError):
|
||||
CapacityLimiter(-1)
|
||||
c = CapacityLimiter(2)
|
||||
repr(c) # smoke test
|
||||
assert c.total_tokens == 2
|
||||
assert c.borrowed_tokens == 0
|
||||
assert c.available_tokens == 2
|
||||
with pytest.raises(RuntimeError):
|
||||
c.release()
|
||||
assert c.borrowed_tokens == 0
|
||||
c.acquire_nowait()
|
||||
assert c.borrowed_tokens == 1
|
||||
assert c.available_tokens == 1
|
||||
|
||||
stats = c.statistics()
|
||||
assert stats.borrowed_tokens == 1
|
||||
assert stats.total_tokens == 2
|
||||
assert stats.borrowers == [_core.current_task()]
|
||||
assert stats.tasks_waiting == 0
|
||||
|
||||
# Can't re-acquire when we already have it
|
||||
with pytest.raises(RuntimeError):
|
||||
c.acquire_nowait()
|
||||
assert c.borrowed_tokens == 1
|
||||
with pytest.raises(RuntimeError):
|
||||
await c.acquire()
|
||||
assert c.borrowed_tokens == 1
|
||||
|
||||
# We can acquire on behalf of someone else though
|
||||
with assert_checkpoints():
|
||||
await c.acquire_on_behalf_of("someone")
|
||||
|
||||
# But then we've run out of capacity
|
||||
assert c.borrowed_tokens == 2
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
c.acquire_on_behalf_of_nowait("third party")
|
||||
|
||||
assert set(c.statistics().borrowers) == {_core.current_task(), "someone"}
|
||||
|
||||
# Until we release one
|
||||
c.release_on_behalf_of(_core.current_task())
|
||||
assert c.statistics().borrowers == ["someone"]
|
||||
|
||||
c.release_on_behalf_of("someone")
|
||||
assert c.borrowed_tokens == 0
|
||||
with assert_checkpoints():
|
||||
async with c:
|
||||
assert c.borrowed_tokens == 1
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
await c.acquire_on_behalf_of("value 1")
|
||||
await c.acquire_on_behalf_of("value 2")
|
||||
nursery.start_soon(c.acquire_on_behalf_of, "value 3")
|
||||
await wait_all_tasks_blocked()
|
||||
assert c.borrowed_tokens == 2
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
c.release_on_behalf_of("value 2")
|
||||
# Fairness:
|
||||
assert c.borrowed_tokens == 2
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
c.acquire_nowait()
|
||||
|
||||
c.release_on_behalf_of("value 3")
|
||||
c.release_on_behalf_of("value 1")
|
||||
|
||||
|
||||
async def test_CapacityLimiter_inf():
|
||||
from math import inf
|
||||
|
||||
c = CapacityLimiter(inf)
|
||||
repr(c) # smoke test
|
||||
assert c.total_tokens == inf
|
||||
assert c.borrowed_tokens == 0
|
||||
assert c.available_tokens == inf
|
||||
with pytest.raises(RuntimeError):
|
||||
c.release()
|
||||
assert c.borrowed_tokens == 0
|
||||
c.acquire_nowait()
|
||||
assert c.borrowed_tokens == 1
|
||||
assert c.available_tokens == inf
|
||||
|
||||
|
||||
async def test_CapacityLimiter_change_total_tokens():
|
||||
c = CapacityLimiter(2)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
c.total_tokens = 1.0
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
c.total_tokens = 0
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
c.total_tokens = -10
|
||||
|
||||
assert c.total_tokens == 2
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
for i in range(5):
|
||||
nursery.start_soon(c.acquire_on_behalf_of, i)
|
||||
await wait_all_tasks_blocked()
|
||||
assert set(c.statistics().borrowers) == {0, 1}
|
||||
assert c.statistics().tasks_waiting == 3
|
||||
c.total_tokens += 2
|
||||
assert set(c.statistics().borrowers) == {0, 1, 2, 3}
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
c.total_tokens -= 3
|
||||
assert c.borrowed_tokens == 4
|
||||
assert c.total_tokens == 1
|
||||
c.release_on_behalf_of(0)
|
||||
c.release_on_behalf_of(1)
|
||||
c.release_on_behalf_of(2)
|
||||
assert set(c.statistics().borrowers) == {3}
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
c.release_on_behalf_of(3)
|
||||
assert set(c.statistics().borrowers) == {4}
|
||||
assert c.statistics().tasks_waiting == 0
|
||||
|
||||
|
||||
# regression test for issue #548
|
||||
async def test_CapacityLimiter_memleak_548():
|
||||
limiter = CapacityLimiter(total_tokens=1)
|
||||
await limiter.acquire()
|
||||
|
||||
async with _core.open_nursery() as n:
|
||||
n.start_soon(limiter.acquire)
|
||||
await wait_all_tasks_blocked() # give it a chance to run the task
|
||||
n.cancel_scope.cancel()
|
||||
|
||||
# if this is 1, the acquire call (despite being killed) is still there in the task, and will
|
||||
# leak memory for as long as the limiter is active
|
||||
assert len(limiter._pending_borrowers) == 0
|
||||
|
||||
|
||||
async def test_Semaphore():
|
||||
with pytest.raises(TypeError):
|
||||
Semaphore(1.0)
|
||||
with pytest.raises(ValueError):
|
||||
Semaphore(-1)
|
||||
s = Semaphore(1)
|
||||
repr(s) # smoke test
|
||||
assert s.value == 1
|
||||
assert s.max_value is None
|
||||
s.release()
|
||||
assert s.value == 2
|
||||
assert s.statistics().tasks_waiting == 0
|
||||
s.acquire_nowait()
|
||||
assert s.value == 1
|
||||
with assert_checkpoints():
|
||||
await s.acquire()
|
||||
assert s.value == 0
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
s.acquire_nowait()
|
||||
|
||||
s.release()
|
||||
assert s.value == 1
|
||||
with assert_checkpoints():
|
||||
async with s:
|
||||
assert s.value == 0
|
||||
assert s.value == 1
|
||||
s.acquire_nowait()
|
||||
|
||||
record = []
|
||||
|
||||
async def do_acquire(s):
|
||||
record.append("started")
|
||||
await s.acquire()
|
||||
record.append("finished")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_acquire, s)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["started"]
|
||||
assert s.value == 0
|
||||
s.release()
|
||||
# Fairness:
|
||||
assert s.value == 0
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
s.acquire_nowait()
|
||||
assert record == ["started", "finished"]
|
||||
|
||||
|
||||
async def test_Semaphore_bounded():
|
||||
with pytest.raises(TypeError):
|
||||
Semaphore(1, max_value=1.0)
|
||||
with pytest.raises(ValueError):
|
||||
Semaphore(2, max_value=1)
|
||||
bs = Semaphore(1, max_value=1)
|
||||
assert bs.max_value == 1
|
||||
repr(bs) # smoke test
|
||||
with pytest.raises(ValueError):
|
||||
bs.release()
|
||||
assert bs.value == 1
|
||||
bs.acquire_nowait()
|
||||
assert bs.value == 0
|
||||
bs.release()
|
||||
assert bs.value == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__)
|
||||
async def test_Lock_and_StrictFIFOLock(lockcls):
|
||||
l = lockcls() # noqa
|
||||
assert not l.locked()
|
||||
|
||||
# make sure locks can be weakref'ed (gh-331)
|
||||
r = weakref.ref(l)
|
||||
assert r() is l
|
||||
|
||||
repr(l) # smoke test
|
||||
# make sure repr uses the right name for subclasses
|
||||
assert lockcls.__name__ in repr(l)
|
||||
with assert_checkpoints():
|
||||
async with l:
|
||||
assert l.locked()
|
||||
repr(l) # smoke test (repr branches on locked/unlocked)
|
||||
assert not l.locked()
|
||||
l.acquire_nowait()
|
||||
assert l.locked()
|
||||
l.release()
|
||||
assert not l.locked()
|
||||
with assert_checkpoints():
|
||||
await l.acquire()
|
||||
assert l.locked()
|
||||
l.release()
|
||||
assert not l.locked()
|
||||
|
||||
l.acquire_nowait()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Error out if we already own the lock
|
||||
l.acquire_nowait()
|
||||
l.release()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Error out if we don't own the lock
|
||||
l.release()
|
||||
|
||||
holder_task = None
|
||||
|
||||
async def holder():
|
||||
nonlocal holder_task
|
||||
holder_task = _core.current_task()
|
||||
async with l:
|
||||
await sleep_forever()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
assert not l.locked()
|
||||
nursery.start_soon(holder)
|
||||
await wait_all_tasks_blocked()
|
||||
assert l.locked()
|
||||
# WouldBlock if someone else holds the lock
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
l.acquire_nowait()
|
||||
# Can't release a lock someone else holds
|
||||
with pytest.raises(RuntimeError):
|
||||
l.release()
|
||||
|
||||
statistics = l.statistics()
|
||||
print(statistics)
|
||||
assert statistics.locked
|
||||
assert statistics.owner is holder_task
|
||||
assert statistics.tasks_waiting == 0
|
||||
|
||||
nursery.start_soon(holder)
|
||||
await wait_all_tasks_blocked()
|
||||
statistics = l.statistics()
|
||||
print(statistics)
|
||||
assert statistics.tasks_waiting == 1
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
statistics = l.statistics()
|
||||
assert not statistics.locked
|
||||
assert statistics.owner is None
|
||||
assert statistics.tasks_waiting == 0
|
||||
|
||||
|
||||
async def test_Condition():
|
||||
with pytest.raises(TypeError):
|
||||
Condition(Semaphore(1))
|
||||
with pytest.raises(TypeError):
|
||||
Condition(StrictFIFOLock)
|
||||
l = Lock() # noqa
|
||||
c = Condition(l)
|
||||
assert not l.locked()
|
||||
assert not c.locked()
|
||||
with assert_checkpoints():
|
||||
await c.acquire()
|
||||
assert l.locked()
|
||||
assert c.locked()
|
||||
|
||||
c = Condition()
|
||||
assert not c.locked()
|
||||
c.acquire_nowait()
|
||||
assert c.locked()
|
||||
with pytest.raises(RuntimeError):
|
||||
c.acquire_nowait()
|
||||
c.release()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
# Can't wait without holding the lock
|
||||
await c.wait()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Can't notify without holding the lock
|
||||
c.notify()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Can't notify without holding the lock
|
||||
c.notify_all()
|
||||
|
||||
finished_waiters = set()
|
||||
|
||||
async def waiter(i):
|
||||
async with c:
|
||||
await c.wait()
|
||||
finished_waiters.add(i)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
for i in range(3):
|
||||
nursery.start_soon(waiter, i)
|
||||
await wait_all_tasks_blocked()
|
||||
async with c:
|
||||
c.notify()
|
||||
assert c.locked()
|
||||
await wait_all_tasks_blocked()
|
||||
assert finished_waiters == {0}
|
||||
async with c:
|
||||
c.notify_all()
|
||||
await wait_all_tasks_blocked()
|
||||
assert finished_waiters == {0, 1, 2}
|
||||
|
||||
finished_waiters = set()
|
||||
async with _core.open_nursery() as nursery:
|
||||
for i in range(3):
|
||||
nursery.start_soon(waiter, i)
|
||||
await wait_all_tasks_blocked()
|
||||
async with c:
|
||||
c.notify(2)
|
||||
statistics = c.statistics()
|
||||
print(statistics)
|
||||
assert statistics.tasks_waiting == 1
|
||||
assert statistics.lock_statistics.tasks_waiting == 2
|
||||
# exiting the context manager hands off the lock to the first task
|
||||
assert c.statistics().lock_statistics.tasks_waiting == 1
|
||||
|
||||
await wait_all_tasks_blocked()
|
||||
assert finished_waiters == {0, 1}
|
||||
|
||||
async with c:
|
||||
c.notify_all()
|
||||
|
||||
# After being cancelled still hold the lock (!)
|
||||
# (Note that c.__aexit__ checks that we hold the lock as well)
|
||||
with _core.CancelScope() as scope:
|
||||
async with c:
|
||||
scope.cancel()
|
||||
try:
|
||||
await c.wait()
|
||||
finally:
|
||||
assert c.locked()
|
||||
|
||||
|
||||
from .._sync import async_cm
|
||||
from .._channel import open_memory_channel
|
||||
|
||||
# Three ways of implementing a Lock in terms of a channel. Used to let us put
|
||||
# the channel through the generic lock tests.
|
||||
|
||||
|
||||
@async_cm
|
||||
class ChannelLock1:
|
||||
def __init__(self, capacity):
|
||||
self.s, self.r = open_memory_channel(capacity)
|
||||
for _ in range(capacity - 1):
|
||||
self.s.send_nowait(None)
|
||||
|
||||
def acquire_nowait(self):
|
||||
self.s.send_nowait(None)
|
||||
|
||||
async def acquire(self):
|
||||
await self.s.send(None)
|
||||
|
||||
def release(self):
|
||||
self.r.receive_nowait()
|
||||
|
||||
|
||||
@async_cm
|
||||
class ChannelLock2:
|
||||
def __init__(self):
|
||||
self.s, self.r = open_memory_channel(10)
|
||||
self.s.send_nowait(None)
|
||||
|
||||
def acquire_nowait(self):
|
||||
self.r.receive_nowait()
|
||||
|
||||
async def acquire(self):
|
||||
await self.r.receive()
|
||||
|
||||
def release(self):
|
||||
self.s.send_nowait(None)
|
||||
|
||||
|
||||
@async_cm
|
||||
class ChannelLock3:
|
||||
def __init__(self):
|
||||
self.s, self.r = open_memory_channel(0)
|
||||
# self.acquired is true when one task acquires the lock and
|
||||
# only becomes false when it's released and no tasks are
|
||||
# waiting to acquire.
|
||||
self.acquired = False
|
||||
|
||||
def acquire_nowait(self):
|
||||
assert not self.acquired
|
||||
self.acquired = True
|
||||
|
||||
async def acquire(self):
|
||||
if self.acquired:
|
||||
await self.s.send(None)
|
||||
else:
|
||||
self.acquired = True
|
||||
await _core.checkpoint()
|
||||
|
||||
def release(self):
|
||||
try:
|
||||
self.r.receive_nowait()
|
||||
except _core.WouldBlock:
|
||||
assert self.acquired
|
||||
self.acquired = False
|
||||
|
||||
|
||||
lock_factories = [
|
||||
lambda: CapacityLimiter(1),
|
||||
lambda: Semaphore(1),
|
||||
Lock,
|
||||
StrictFIFOLock,
|
||||
lambda: ChannelLock1(10),
|
||||
lambda: ChannelLock1(1),
|
||||
ChannelLock2,
|
||||
ChannelLock3,
|
||||
]
|
||||
lock_factory_names = [
|
||||
"CapacityLimiter(1)",
|
||||
"Semaphore(1)",
|
||||
"Lock",
|
||||
"StrictFIFOLock",
|
||||
"ChannelLock1(10)",
|
||||
"ChannelLock1(1)",
|
||||
"ChannelLock2",
|
||||
"ChannelLock3",
|
||||
]
|
||||
|
||||
generic_lock_test = pytest.mark.parametrize(
|
||||
"lock_factory", lock_factories, ids=lock_factory_names
|
||||
)
|
||||
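# Hedged sketch (an extra illustration, not from the original file): every
# factory in lock_factories can be driven through the same async-with and
# nowait protocol, which is all the generic tests below rely on.
@generic_lock_test
async def test_generic_lock_smoke_sketch(lock_factory):
    lock_like = lock_factory()
    async with lock_like:       # acquire/release via the async context manager
        pass
    lock_like.acquire_nowait()  # the sync path also works when uncontended
    lock_like.release()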
|
||||
|
||||
# Spawn a bunch of workers that take a lock and then yield; make sure that
|
||||
# only one worker is ever in the critical section at a time.
|
||||
@generic_lock_test
|
||||
async def test_generic_lock_exclusion(lock_factory):
|
||||
LOOPS = 10
|
||||
WORKERS = 5
|
||||
in_critical_section = False
|
||||
acquires = 0
|
||||
|
||||
async def worker(lock_like):
|
||||
nonlocal in_critical_section, acquires
|
||||
for _ in range(LOOPS):
|
||||
async with lock_like:
|
||||
acquires += 1
|
||||
assert not in_critical_section
|
||||
in_critical_section = True
|
||||
await _core.checkpoint()
|
||||
await _core.checkpoint()
|
||||
assert in_critical_section
|
||||
in_critical_section = False
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
lock_like = lock_factory()
|
||||
for _ in range(WORKERS):
|
||||
nursery.start_soon(worker, lock_like)
|
||||
assert not in_critical_section
|
||||
assert acquires == LOOPS * WORKERS
|
||||
|
||||
|
||||
# Several workers queue on the same lock; make sure they each get it, in
|
||||
# order.
|
||||
@generic_lock_test
|
||||
async def test_generic_lock_fifo_fairness(lock_factory):
|
||||
initial_order = []
|
||||
record = []
|
||||
LOOPS = 5
|
||||
|
||||
async def loopy(name, lock_like):
|
||||
# Record the order each task was initially scheduled in
|
||||
initial_order.append(name)
|
||||
for _ in range(LOOPS):
|
||||
async with lock_like:
|
||||
record.append(name)
|
||||
|
||||
lock_like = lock_factory()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(loopy, 1, lock_like)
|
||||
nursery.start_soon(loopy, 2, lock_like)
|
||||
nursery.start_soon(loopy, 3, lock_like)
|
||||
# The first three could be in any order due to scheduling randomness,
|
||||
# but after that they should repeat in the same order
|
||||
for i in range(LOOPS):
|
||||
assert record[3 * i : 3 * (i + 1)] == initial_order
|
||||
|
||||
|
||||
@generic_lock_test
|
||||
async def test_generic_lock_acquire_nowait_blocks_acquire(lock_factory):
|
||||
lock_like = lock_factory()
|
||||
|
||||
record = []
|
||||
|
||||
async def lock_taker():
|
||||
record.append("started")
|
||||
async with lock_like:
|
||||
pass
|
||||
record.append("finished")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
lock_like.acquire_nowait()
|
||||
nursery.start_soon(lock_taker)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["started"]
|
||||
lock_like.release()
|
||||
657
venv/Lib/site-packages/trio/tests/test_testing.py
Normal file
@@ -0,0 +1,657 @@
|
||||
# XX this should get broken up, like testing.py did
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from .._core.tests.tutil import can_bind_ipv6
|
||||
from .. import sleep
|
||||
from .. import _core
|
||||
from .._highlevel_generic import aclose_forcefully
|
||||
from ..testing import *
|
||||
from ..testing._check_streams import _assert_raises
|
||||
from ..testing._memory_streams import _UnboundedByteQueue
|
||||
from .. import socket as tsocket
|
||||
from .._highlevel_socket import SocketListener
|
||||
|
||||
|
||||
async def test_wait_all_tasks_blocked():
|
||||
record = []
|
||||
|
||||
async def busy_bee():
|
||||
for _ in range(10):
|
||||
await _core.checkpoint()
|
||||
record.append("busy bee exhausted")
|
||||
|
||||
async def waiting_for_bee_to_leave():
|
||||
await wait_all_tasks_blocked()
|
||||
record.append("quiet at last!")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(busy_bee)
|
||||
nursery.start_soon(waiting_for_bee_to_leave)
|
||||
nursery.start_soon(waiting_for_bee_to_leave)
|
||||
|
||||
# check cancellation
|
||||
record = []
|
||||
|
||||
async def cancelled_while_waiting():
|
||||
try:
|
||||
await wait_all_tasks_blocked()
|
||||
except _core.Cancelled:
|
||||
record.append("ok")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(cancelled_while_waiting)
|
||||
nursery.cancel_scope.cancel()
|
||||
assert record == ["ok"]
|
||||
|
||||
|
||||
async def test_wait_all_tasks_blocked_with_timeouts(mock_clock):
|
||||
record = []
|
||||
|
||||
async def timeout_task():
|
||||
record.append("tt start")
|
||||
await sleep(5)
|
||||
record.append("tt finished")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(timeout_task)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["tt start"]
|
||||
mock_clock.jump(10)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["tt start", "tt finished"]
|
||||
|
||||
|
||||
async def test_wait_all_tasks_blocked_with_cushion():
|
||||
record = []
|
||||
|
||||
async def blink():
|
||||
record.append("blink start")
|
||||
await sleep(0.01)
|
||||
await sleep(0.01)
|
||||
await sleep(0.01)
|
||||
record.append("blink end")
|
||||
|
||||
async def wait_no_cushion():
|
||||
await wait_all_tasks_blocked()
|
||||
record.append("wait_no_cushion end")
|
||||
|
||||
async def wait_small_cushion():
|
||||
await wait_all_tasks_blocked(0.02)
|
||||
record.append("wait_small_cushion end")
|
||||
|
||||
async def wait_big_cushion():
|
||||
await wait_all_tasks_blocked(0.03)
|
||||
record.append("wait_big_cushion end")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(blink)
|
||||
nursery.start_soon(wait_no_cushion)
|
||||
nursery.start_soon(wait_small_cushion)
|
||||
nursery.start_soon(wait_small_cushion)
|
||||
nursery.start_soon(wait_big_cushion)
|
||||
|
||||
assert record == [
|
||||
"blink start",
|
||||
"wait_no_cushion end",
|
||||
"blink end",
|
||||
"wait_small_cushion end",
|
||||
"wait_small_cushion end",
|
||||
"wait_big_cushion end",
|
||||
]
|
||||
|
||||
|
||||
################################################################
|
||||
|
||||
|
||||
async def test_assert_checkpoints(recwarn):
|
||||
with assert_checkpoints():
|
||||
await _core.checkpoint()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_checkpoints():
|
||||
1 + 1
|
||||
|
||||
# partial yield cases
|
||||
# if you have a schedule point but not a cancel point, or vice-versa, then
|
||||
# that's not a checkpoint.
|
||||
for partial_yield in [
|
||||
_core.checkpoint_if_cancelled,
|
||||
_core.cancel_shielded_checkpoint,
|
||||
]:
|
||||
print(partial_yield)
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_checkpoints():
|
||||
await partial_yield()
|
||||
|
||||
# But both together count as a checkpoint
|
||||
with assert_checkpoints():
|
||||
await _core.checkpoint_if_cancelled()
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
|
||||
|
||||
async def test_assert_no_checkpoints(recwarn):
|
||||
with assert_no_checkpoints():
|
||||
1 + 1
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_no_checkpoints():
|
||||
await _core.checkpoint()
|
||||
|
||||
# partial yield cases
|
||||
# if you have a schedule point but not a cancel point, or vice-versa, then
|
||||
# that doesn't make *either* version of assert_{no_,}yields happy.
|
||||
for partial_yield in [
|
||||
_core.checkpoint_if_cancelled,
|
||||
_core.cancel_shielded_checkpoint,
|
||||
]:
|
||||
print(partial_yield)
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_no_checkpoints():
|
||||
await partial_yield()
|
||||
|
||||
# And both together also count as a checkpoint
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_no_checkpoints():
|
||||
await _core.checkpoint_if_cancelled()
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
|
||||
|
||||
################################################################
|
||||
|
||||
|
||||
async def test_Sequencer():
|
||||
record = []
|
||||
|
||||
def t(val):
|
||||
print(val)
|
||||
record.append(val)
|
||||
|
||||
async def f1(seq):
|
||||
async with seq(1):
|
||||
t(("f1", 1))
|
||||
async with seq(3):
|
||||
t(("f1", 3))
|
||||
async with seq(4):
|
||||
t(("f1", 4))
|
||||
|
||||
async def f2(seq):
|
||||
async with seq(0):
|
||||
t(("f2", 0))
|
||||
async with seq(2):
|
||||
t(("f2", 2))
|
||||
|
||||
seq = Sequencer()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(f1, seq)
|
||||
nursery.start_soon(f2, seq)
|
||||
async with seq(5):
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == [("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4)]
|
||||
|
||||
seq = Sequencer()
|
||||
# Catches us if we try to re-use a sequence point:
|
||||
async with seq(0):
|
||||
pass
|
||||
with pytest.raises(RuntimeError):
|
||||
async with seq(0):
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
async def test_Sequencer_cancel():
|
||||
# Killing a blocked task makes everything blow up
|
||||
record = []
|
||||
seq = Sequencer()
|
||||
|
||||
async def child(i):
|
||||
with _core.CancelScope() as scope:
|
||||
if i == 1:
|
||||
scope.cancel()
|
||||
try:
|
||||
async with seq(i):
|
||||
pass # pragma: no cover
|
||||
except RuntimeError:
|
||||
record.append("seq({}) RuntimeError".format(i))
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child, 1)
|
||||
nursery.start_soon(child, 2)
|
||||
async with seq(0):
|
||||
pass # pragma: no cover
|
||||
|
||||
assert record == ["seq(1) RuntimeError", "seq(2) RuntimeError"]
|
||||
|
||||
# Late arrivals also get errors
|
||||
with pytest.raises(RuntimeError):
|
||||
async with seq(3):
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
################################################################
|
||||
|
||||
|
||||
async def test__assert_raises():
|
||||
with pytest.raises(AssertionError):
|
||||
with _assert_raises(RuntimeError):
|
||||
1 + 1
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
with _assert_raises(RuntimeError):
|
||||
"foo" + 1
|
||||
|
||||
with _assert_raises(RuntimeError):
|
||||
raise RuntimeError
|
||||
|
||||
|
||||
# This is a private implementation detail, but it's complex enough to be worth
|
||||
# testing directly
|
||||
async def test__UnboundedByteQueue():
|
||||
ubq = _UnboundedByteQueue()
|
||||
|
||||
ubq.put(b"123")
|
||||
ubq.put(b"456")
|
||||
assert ubq.get_nowait(1) == b"1"
|
||||
assert ubq.get_nowait(10) == b"23456"
|
||||
ubq.put(b"789")
|
||||
assert ubq.get_nowait() == b"789"
|
||||
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
ubq.get_nowait(10)
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
ubq.get_nowait()
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
ubq.put("string")
|
||||
|
||||
ubq.put(b"abc")
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get(10) == b"abc"
|
||||
ubq.put(b"def")
|
||||
ubq.put(b"ghi")
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get(1) == b"d"
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get() == b"efghi"
|
||||
|
||||
async def putter(data):
|
||||
await wait_all_tasks_blocked()
|
||||
ubq.put(data)
|
||||
|
||||
async def getter(expect):
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get() == expect
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(getter, b"xyz")
|
||||
nursery.start_soon(putter, b"xyz")
|
||||
|
||||
# Two gets at the same time -> BusyResourceError
|
||||
with pytest.raises(_core.BusyResourceError):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(getter, b"asdf")
|
||||
nursery.start_soon(getter, b"asdf")
|
||||
|
||||
# Closing
|
||||
|
||||
ubq.close()
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
ubq.put(b"---")
|
||||
|
||||
assert ubq.get_nowait(10) == b""
|
||||
assert ubq.get_nowait() == b""
|
||||
assert await ubq.get(10) == b""
|
||||
assert await ubq.get() == b""
|
||||
|
||||
# close is idempotent
|
||||
ubq.close()
|
||||
|
||||
# close wakes up blocked getters
|
||||
ubq2 = _UnboundedByteQueue()
|
||||
|
||||
async def closer():
|
||||
await wait_all_tasks_blocked()
|
||||
ubq2.close()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(getter, b"")
|
||||
nursery.start_soon(closer)
|
||||
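# Hedged sketch (an assumption about the semantics only, not trio's actual
# implementation): the byte-FIFO behaviour exercised above, reduced to a tiny
# synchronous model. get(None) drains everything; get(n) returns at most n
# bytes from the front of the buffer.
class _TinyByteFIFO:
    def __init__(self):
        self._buf = bytearray()

    def put(self, data):
        self._buf += bytes(data)

    def get(self, max_bytes=None):
        if max_bytes is None:
            max_bytes = len(self._buf)
        chunk = bytes(self._buf[:max_bytes])
        del self._buf[:max_bytes]
        return chunk


_fifo = _TinyByteFIFO()
_fifo.put(b"123")
_fifo.put(b"456")
assert _fifo.get(1) == b"1"
assert _fifo.get() == b"23456"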
|
||||
|
||||
async def test_MemorySendStream():
|
||||
mss = MemorySendStream()
|
||||
|
||||
async def do_send_all(data):
|
||||
with assert_checkpoints():
|
||||
await mss.send_all(data)
|
||||
|
||||
await do_send_all(b"123")
|
||||
assert mss.get_data_nowait(1) == b"1"
|
||||
assert mss.get_data_nowait() == b"23"
|
||||
|
||||
with assert_checkpoints():
|
||||
await mss.wait_send_all_might_not_block()
|
||||
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
mss.get_data_nowait()
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
mss.get_data_nowait(10)
|
||||
|
||||
await do_send_all(b"456")
|
||||
with assert_checkpoints():
|
||||
assert await mss.get_data() == b"456"
|
||||
|
||||
# Call send_all twice at once; one should get BusyResourceError and one
|
||||
# should succeed. But we can't let the error propagate, because it might
|
||||
# cause the other to be cancelled before it can finish doing its thing,
|
||||
# and we don't know which one will get the error.
|
||||
resource_busy_count = 0
|
||||
|
||||
async def do_send_all_count_resourcebusy():
|
||||
nonlocal resource_busy_count
|
||||
try:
|
||||
await do_send_all(b"xxx")
|
||||
except _core.BusyResourceError:
|
||||
resource_busy_count += 1
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_send_all_count_resourcebusy)
|
||||
nursery.start_soon(do_send_all_count_resourcebusy)
|
||||
|
||||
assert resource_busy_count == 1
|
||||
|
||||
with assert_checkpoints():
|
||||
await mss.aclose()
|
||||
|
||||
assert await mss.get_data() == b"xxx"
|
||||
assert await mss.get_data() == b""
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await do_send_all(b"---")
|
||||
|
||||
# hooks
|
||||
|
||||
assert mss.send_all_hook is None
|
||||
assert mss.wait_send_all_might_not_block_hook is None
|
||||
assert mss.close_hook is None
|
||||
|
||||
record = []
|
||||
|
||||
async def send_all_hook():
|
||||
# hook runs after send_all does its work (can pull data out)
|
||||
assert mss2.get_data_nowait() == b"abc"
|
||||
record.append("send_all_hook")
|
||||
|
||||
async def wait_send_all_might_not_block_hook():
|
||||
record.append("wait_send_all_might_not_block_hook")
|
||||
|
||||
def close_hook():
|
||||
record.append("close_hook")
|
||||
|
||||
mss2 = MemorySendStream(
|
||||
send_all_hook, wait_send_all_might_not_block_hook, close_hook
|
||||
)
|
||||
|
||||
assert mss2.send_all_hook is send_all_hook
|
||||
assert mss2.wait_send_all_might_not_block_hook is wait_send_all_might_not_block_hook
|
||||
assert mss2.close_hook is close_hook
|
||||
|
||||
await mss2.send_all(b"abc")
|
||||
await mss2.wait_send_all_might_not_block()
|
||||
await aclose_forcefully(mss2)
|
||||
mss2.close()
|
||||
|
||||
assert record == [
|
||||
"send_all_hook",
|
||||
"wait_send_all_might_not_block_hook",
|
||||
"close_hook",
|
||||
"close_hook",
|
||||
]
|
||||
|
||||
|
||||
async def test_MemoryReceiveStream():
|
||||
mrs = MemoryReceiveStream()
|
||||
|
||||
async def do_receive_some(max_bytes):
|
||||
with assert_checkpoints():
|
||||
return await mrs.receive_some(max_bytes)
|
||||
|
||||
mrs.put_data(b"abc")
|
||||
assert await do_receive_some(1) == b"a"
|
||||
assert await do_receive_some(10) == b"bc"
|
||||
mrs.put_data(b"abc")
|
||||
assert await do_receive_some(None) == b"abc"
|
||||
|
||||
with pytest.raises(_core.BusyResourceError):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_receive_some, 10)
|
||||
nursery.start_soon(do_receive_some, 10)
|
||||
|
||||
assert mrs.receive_some_hook is None
|
||||
|
||||
mrs.put_data(b"def")
|
||||
mrs.put_eof()
|
||||
mrs.put_eof()
|
||||
|
||||
assert await do_receive_some(10) == b"def"
|
||||
assert await do_receive_some(10) == b""
|
||||
assert await do_receive_some(10) == b""
|
||||
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
mrs.put_data(b"---")
|
||||
|
||||
async def receive_some_hook():
|
||||
mrs2.put_data(b"xxx")
|
||||
|
||||
record = []
|
||||
|
||||
def close_hook():
|
||||
record.append("closed")
|
||||
|
||||
mrs2 = MemoryReceiveStream(receive_some_hook, close_hook)
|
||||
assert mrs2.receive_some_hook is receive_some_hook
|
||||
assert mrs2.close_hook is close_hook
|
||||
|
||||
mrs2.put_data(b"yyy")
|
||||
assert await mrs2.receive_some(10) == b"yyyxxx"
|
||||
assert await mrs2.receive_some(10) == b"xxx"
|
||||
assert await mrs2.receive_some(10) == b"xxx"
|
||||
|
||||
mrs2.put_data(b"zzz")
|
||||
mrs2.receive_some_hook = None
|
||||
assert await mrs2.receive_some(10) == b"zzz"
|
||||
|
||||
mrs2.put_data(b"lost on close")
|
||||
with assert_checkpoints():
|
||||
await mrs2.aclose()
|
||||
assert record == ["closed"]
|
||||
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await mrs2.receive_some(10)
|
||||
|
||||
|
||||
async def test_MemoryRecvStream_closing():
|
||||
mrs = MemoryReceiveStream()
|
||||
# close with no pending data
|
||||
mrs.close()
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
assert await mrs.receive_some(10) == b""
|
||||
# repeated closes ok
|
||||
mrs.close()
|
||||
# put_data now fails
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
mrs.put_data(b"123")
|
||||
|
||||
mrs2 = MemoryReceiveStream()
|
||||
# close with pending data
|
||||
mrs2.put_data(b"xyz")
|
||||
mrs2.close()
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await mrs2.receive_some(10)
|
||||
|
||||
|
||||
async def test_memory_stream_pump():
|
||||
mss = MemorySendStream()
|
||||
mrs = MemoryReceiveStream()
|
||||
|
||||
# no-op if no data present
|
||||
memory_stream_pump(mss, mrs)
|
||||
|
||||
await mss.send_all(b"123")
|
||||
memory_stream_pump(mss, mrs)
|
||||
assert await mrs.receive_some(10) == b"123"
|
||||
|
||||
await mss.send_all(b"456")
|
||||
assert memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert await mrs.receive_some(10) == b"4"
|
||||
assert memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert not memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert await mrs.receive_some(10) == b"56"
|
||||
|
||||
mss.close()
|
||||
memory_stream_pump(mss, mrs)
|
||||
assert await mrs.receive_some(10) == b""
|
||||
|
||||
|
||||
async def test_memory_stream_one_way_pair():
|
||||
s, r = memory_stream_one_way_pair()
|
||||
assert s.send_all_hook is not None
|
||||
assert s.wait_send_all_might_not_block_hook is None
|
||||
assert s.close_hook is not None
|
||||
assert r.receive_some_hook is None
|
||||
await s.send_all(b"123")
|
||||
assert await r.receive_some(10) == b"123"
|
||||
|
||||
async def receiver(expected):
|
||||
assert await r.receive_some(10) == expected
|
||||
|
||||
# This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver, b"abc")
|
||||
await wait_all_tasks_blocked()
|
||||
await s.send_all(b"abc")
|
||||
|
||||
# And this fails if we don't pump from close_hook
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver, b"")
|
||||
await wait_all_tasks_blocked()
|
||||
await s.aclose()
|
||||
|
||||
s, r = memory_stream_one_way_pair()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver, b"")
|
||||
await wait_all_tasks_blocked()
|
||||
s.close()
|
||||
|
||||
s, r = memory_stream_one_way_pair()
|
||||
|
||||
old = s.send_all_hook
|
||||
s.send_all_hook = None
|
||||
await s.send_all(b"456")
|
||||
|
||||
async def cancel_after_idle(nursery):
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
async def check_for_cancel():
|
||||
with pytest.raises(_core.Cancelled):
|
||||
# This should block forever... or until cancelled. Even though we
|
||||
# sent some data on the send stream.
|
||||
await r.receive_some(10)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(cancel_after_idle, nursery)
|
||||
nursery.start_soon(check_for_cancel)
|
||||
|
||||
s.send_all_hook = old
|
||||
await s.send_all(b"789")
|
||||
assert await r.receive_some(10) == b"456789"
|
||||
|
||||
|
||||
async def test_memory_stream_pair():
|
||||
a, b = memory_stream_pair()
|
||||
await a.send_all(b"123")
|
||||
await b.send_all(b"abc")
|
||||
assert await b.receive_some(10) == b"123"
|
||||
assert await a.receive_some(10) == b"abc"
|
||||
|
||||
await a.send_eof()
|
||||
assert await b.receive_some(10) == b""
|
||||
|
||||
async def sender():
|
||||
await wait_all_tasks_blocked()
|
||||
await b.send_all(b"xyz")
|
||||
|
||||
async def receiver():
|
||||
assert await a.receive_some(10) == b"xyz"
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver)
|
||||
nursery.start_soon(sender)
|
||||
|
||||
|
||||
async def test_memory_streams_with_generic_tests():
|
||||
async def one_way_stream_maker():
|
||||
return memory_stream_one_way_pair()
|
||||
|
||||
await check_one_way_stream(one_way_stream_maker, None)
|
||||
|
||||
async def half_closeable_stream_maker():
|
||||
return memory_stream_pair()
|
||||
|
||||
await check_half_closeable_stream(half_closeable_stream_maker, None)
|
||||
|
||||
|
||||
async def test_lockstep_streams_with_generic_tests():
|
||||
async def one_way_stream_maker():
|
||||
return lockstep_stream_one_way_pair()
|
||||
|
||||
await check_one_way_stream(one_way_stream_maker, one_way_stream_maker)
|
||||
|
||||
async def two_way_stream_maker():
|
||||
return lockstep_stream_pair()
|
||||
|
||||
await check_two_way_stream(two_way_stream_maker, two_way_stream_maker)
|
||||
|
||||
|
||||
async def test_open_stream_to_socket_listener():
|
||||
async def check(listener):
|
||||
async with listener:
|
||||
client_stream = await open_stream_to_socket_listener(listener)
|
||||
async with client_stream:
|
||||
server_stream = await listener.accept()
|
||||
async with server_stream:
|
||||
await client_stream.send_all(b"x")
|
||||
assert await server_stream.receive_some(1) == b"x"
|
||||
|
||||
# Listener bound to localhost
|
||||
sock = tsocket.socket()
|
||||
await sock.bind(("127.0.0.1", 0))
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
# Listener bound to IPv4 wildcard (needs special handling)
|
||||
sock = tsocket.socket()
|
||||
await sock.bind(("0.0.0.0", 0))
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
if can_bind_ipv6:
|
||||
# Listener bound to IPv6 wildcard (needs special handling)
|
||||
sock = tsocket.socket(family=tsocket.AF_INET6)
|
||||
await sock.bind(("::", 0))
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
if hasattr(tsocket, "AF_UNIX"):
|
||||
# Listener bound to Unix-domain socket
|
||||
sock = tsocket.socket(family=tsocket.AF_UNIX)
|
||||
# can't use pytest's tmpdir; if we try then macOS says "OSError:
|
||||
# AF_UNIX path too long"
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = "{}/sock".format(tmpdir)
|
||||
await sock.bind(path)
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
740
venv/Lib/site-packages/trio/tests/test_threads.py
Normal file
@@ -0,0 +1,740 @@
|
||||
import contextvars
|
||||
import threading
|
||||
import queue as stdlib_queue
|
||||
import time
|
||||
import weakref
|
||||
|
||||
import pytest
|
||||
from sniffio import current_async_library_cvar
|
||||
from trio._core import TrioToken, current_trio_token
|
||||
|
||||
from .. import _core
|
||||
from .. import Event, CapacityLimiter, sleep
|
||||
from ..testing import wait_all_tasks_blocked
|
||||
from .._core.tests.tutil import buggy_pypy_asyncgens
|
||||
from .._threads import (
|
||||
to_thread_run_sync,
|
||||
current_default_thread_limiter,
|
||||
from_thread_run,
|
||||
from_thread_run_sync,
|
||||
)
|
||||
|
||||
from .._core.tests.test_ki import ki_self
|
||||
|
||||
|
||||
async def test_do_in_trio_thread():
|
||||
trio_thread = threading.current_thread()
|
||||
|
||||
async def check_case(do_in_trio_thread, fn, expected, trio_token=None):
|
||||
record = []
|
||||
|
||||
def threadfn():
|
||||
try:
|
||||
record.append(("start", threading.current_thread()))
|
||||
x = do_in_trio_thread(fn, record, trio_token=trio_token)
|
||||
record.append(("got", x))
|
||||
except BaseException as exc:
|
||||
print(exc)
|
||||
record.append(("error", type(exc)))
|
||||
|
||||
child_thread = threading.Thread(target=threadfn, daemon=True)
|
||||
child_thread.start()
|
||||
while child_thread.is_alive():
|
||||
print("yawn")
|
||||
await sleep(0.01)
|
||||
assert record == [("start", child_thread), ("f", trio_thread), expected]
|
||||
|
||||
token = _core.current_trio_token()
|
||||
|
||||
def f(record):
|
||||
assert not _core.currently_ki_protected()
|
||||
record.append(("f", threading.current_thread()))
|
||||
return 2
|
||||
|
||||
await check_case(from_thread_run_sync, f, ("got", 2), trio_token=token)
|
||||
|
||||
def f(record):
|
||||
assert not _core.currently_ki_protected()
|
||||
record.append(("f", threading.current_thread()))
|
||||
raise ValueError
|
||||
|
||||
await check_case(from_thread_run_sync, f, ("error", ValueError), trio_token=token)
|
||||
|
||||
async def f(record):
|
||||
assert not _core.currently_ki_protected()
|
||||
await _core.checkpoint()
|
||||
record.append(("f", threading.current_thread()))
|
||||
return 3
|
||||
|
||||
await check_case(from_thread_run, f, ("got", 3), trio_token=token)
|
||||
|
||||
async def f(record):
|
||||
assert not _core.currently_ki_protected()
|
||||
await _core.checkpoint()
|
||||
record.append(("f", threading.current_thread()))
|
||||
raise KeyError
|
||||
|
||||
await check_case(from_thread_run, f, ("error", KeyError), trio_token=token)
|
||||
|
||||
|
||||
async def test_do_in_trio_thread_from_trio_thread():
|
||||
with pytest.raises(RuntimeError):
|
||||
from_thread_run_sync(lambda: None) # pragma: no branch
|
||||
|
||||
async def foo(): # pragma: no cover
|
||||
pass
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
from_thread_run(foo)
|
||||
|
||||
|
||||
def test_run_in_trio_thread_ki():
|
||||
# if we get a control-C during a run_in_trio_thread, then it propagates
|
||||
# back to the caller (slick!)
|
||||
record = set()
|
||||
|
||||
async def check_run_in_trio_thread():
|
||||
token = _core.current_trio_token()
|
||||
|
||||
def trio_thread_fn():
|
||||
print("in Trio thread")
|
||||
assert not _core.currently_ki_protected()
|
||||
print("ki_self")
|
||||
try:
|
||||
ki_self()
|
||||
finally:
|
||||
import sys
|
||||
|
||||
print("finally", sys.exc_info())
|
||||
|
||||
async def trio_thread_afn():
|
||||
trio_thread_fn()
|
||||
|
||||
def external_thread_fn():
|
||||
try:
|
||||
print("running")
|
||||
from_thread_run_sync(trio_thread_fn, trio_token=token)
|
||||
except KeyboardInterrupt:
|
||||
print("ok1")
|
||||
record.add("ok1")
|
||||
try:
|
||||
from_thread_run(trio_thread_afn, trio_token=token)
|
||||
except KeyboardInterrupt:
|
||||
print("ok2")
|
||||
record.add("ok2")
|
||||
|
||||
thread = threading.Thread(target=external_thread_fn)
|
||||
thread.start()
|
||||
print("waiting")
|
||||
while thread.is_alive():
|
||||
await sleep(0.01)
|
||||
print("waited, joining")
|
||||
thread.join()
|
||||
print("done")
|
||||
|
||||
_core.run(check_run_in_trio_thread)
|
||||
assert record == {"ok1", "ok2"}
|
||||
|
||||
|
||||
def test_await_in_trio_thread_while_main_exits():
|
||||
record = []
|
||||
ev = Event()
|
||||
|
||||
async def trio_fn():
|
||||
record.append("sleeping")
|
||||
ev.set()
|
||||
await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)
|
||||
|
||||
def thread_fn(token):
|
||||
try:
|
||||
from_thread_run(trio_fn, trio_token=token)
|
||||
except _core.Cancelled:
|
||||
record.append("cancelled")
|
||||
|
||||
async def main():
|
||||
token = _core.current_trio_token()
|
||||
thread = threading.Thread(target=thread_fn, args=(token,))
|
||||
thread.start()
|
||||
await ev.wait()
|
||||
assert record == ["sleeping"]
|
||||
return thread
|
||||
|
||||
thread = _core.run(main)
|
||||
thread.join()
|
||||
assert record == ["sleeping", "cancelled"]
|
||||
|
||||
|
||||
async def test_run_in_worker_thread():
|
||||
trio_thread = threading.current_thread()
|
||||
|
||||
def f(x):
|
||||
return (x, threading.current_thread())
|
||||
|
||||
x, child_thread = await to_thread_run_sync(f, 1)
|
||||
assert x == 1
|
||||
assert child_thread != trio_thread
|
||||
|
||||
def g():
|
||||
raise ValueError(threading.current_thread())
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
await to_thread_run_sync(g)
|
||||
print(excinfo.value.args)
|
||||
assert excinfo.value.args[0] != trio_thread
|
||||
|
||||
|
||||
async def test_run_in_worker_thread_cancellation():
|
||||
register = [None]
|
||||
|
||||
def f(q):
|
||||
# Make the thread block for a controlled amount of time
|
||||
register[0] = "blocking"
|
||||
q.get()
|
||||
register[0] = "finished"
|
||||
|
||||
async def child(q, cancellable):
|
||||
record.append("start")
|
||||
try:
|
||||
return await to_thread_run_sync(f, q, cancellable=cancellable)
|
||||
finally:
|
||||
record.append("exit")
|
||||
|
||||
record = []
|
||||
q = stdlib_queue.Queue()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child, q, True)
|
||||
# Give it a chance to get started. (This is important because
|
||||
# to_thread_run_sync does a checkpoint_if_cancelled before
|
||||
# blocking on the thread, and we don't want to trigger this.)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["start"]
|
||||
# Then cancel it.
|
||||
nursery.cancel_scope.cancel()
|
||||
# The task exited, but the thread didn't:
|
||||
assert register[0] != "finished"
|
||||
# Put the thread out of its misery:
|
||||
q.put(None)
|
||||
while register[0] != "finished":
|
||||
time.sleep(0.01)
|
||||
|
||||
# This one can't be cancelled
|
||||
record = []
|
||||
register[0] = None
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child, q, False)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
with _core.CancelScope(shield=True):
|
||||
for _ in range(10):
|
||||
await _core.checkpoint()
|
||||
# It's still running
|
||||
assert record == ["start"]
|
||||
q.put(None)
|
||||
# Now it exits
|
||||
|
||||
# But if we cancel *before* it enters, the entry is itself a cancellation
|
||||
# point
|
||||
with _core.CancelScope() as scope:
|
||||
scope.cancel()
|
||||
await child(q, False)
|
||||
assert scope.cancelled_caught
|
||||
|
||||
|
||||
# Make sure that if trio.run exits, and then the thread finishes, then that's
|
||||
# handled gracefully. (Requires that the thread result machinery be prepared
|
||||
# for call_soon to raise RunFinishedError.)
|
||||
def test_run_in_worker_thread_abandoned(capfd, monkeypatch):
|
||||
monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01)
|
||||
|
||||
q1 = stdlib_queue.Queue()
|
||||
q2 = stdlib_queue.Queue()
|
||||
|
||||
def thread_fn():
|
||||
q1.get()
|
||||
q2.put(threading.current_thread())
|
||||
|
||||
async def main():
|
||||
async def child():
|
||||
await to_thread_run_sync(thread_fn, cancellable=True)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
_core.run(main)
|
||||
|
||||
q1.put(None)
|
||||
# This makes sure:
|
||||
# - the thread actually ran
|
||||
# - that thread has finished before we check for its output
|
||||
thread = q2.get()
|
||||
while thread.is_alive():
|
||||
time.sleep(0.01) # pragma: no cover
|
||||
|
||||
# Make sure we don't have an "Exception in thread ..." dump to the console:
|
||||
out, err = capfd.readouterr()
|
||||
assert "Exception in thread" not in out
|
||||
assert "Exception in thread" not in err
|
||||
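
# --- Illustrative sketch, not part of the original trio test suite ---
# The "abandoned thread" test above relies on one fact: once trio.run() has
# returned, trying to deliver anything back through the old TrioToken raises
# RunFinishedError, and the delivering side has to tolerate that. A minimal,
# hypothetical reproduction of just that behavior (names invented here):
def _example_run_finished_error():
    import trio

    async def _grab_token():
        return trio.lowlevel.current_trio_token()

    token = trio.run(_grab_token)
    try:
        # The run is already over, so this callback can no longer be delivered.
        token.run_sync_soon(lambda: None)
    except trio.RunFinishedError:
        return "run already finished"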
|
||||
|
||||
@pytest.mark.parametrize("MAX", [3, 5, 10])
|
||||
@pytest.mark.parametrize("cancel", [False, True])
|
||||
@pytest.mark.parametrize("use_default_limiter", [False, True])
|
||||
async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter):
|
||||
# This test is a bit tricky. The goal is to make sure that if we set
|
||||
# limiter=CapacityLimiter(MAX), then in fact only MAX threads are ever
|
||||
# running at a time, even if there are more concurrent calls to
|
||||
# to_thread_run_sync, and even if some of those are cancelled. And
|
||||
# also to make sure that the default limiter actually limits.
|
||||
COUNT = 2 * MAX
|
||||
gate = threading.Event()
|
||||
lock = threading.Lock()
|
||||
if use_default_limiter:
|
||||
c = current_default_thread_limiter()
|
||||
orig_total_tokens = c.total_tokens
|
||||
c.total_tokens = MAX
|
||||
limiter_arg = None
|
||||
else:
|
||||
c = CapacityLimiter(MAX)
|
||||
orig_total_tokens = MAX
|
||||
limiter_arg = c
|
||||
try:
|
||||
# We used to use regular variables and 'nonlocal' here, but it turns
|
||||
# out that it's not safe to assign to closed-over variables that are
|
||||
# visible in multiple threads, at least as of CPython 3.10 and PyPy
|
||||
# 7.3:
|
||||
#
|
||||
# https://bugs.python.org/issue30744
|
||||
# https://bitbucket.org/pypy/pypy/issues/2591/
|
||||
#
|
||||
# Mutating them in-place is OK though (as long as you use proper
|
||||
# locking etc.).
|
||||
class state:
|
||||
pass
|
||||
|
||||
state.ran = 0
|
||||
state.high_water = 0
|
||||
state.running = 0
|
||||
state.parked = 0
|
||||
|
||||
token = _core.current_trio_token()
|
||||
|
||||
def thread_fn(cancel_scope):
|
||||
print("thread_fn start")
|
||||
from_thread_run_sync(cancel_scope.cancel, trio_token=token)
|
||||
with lock:
|
||||
state.ran += 1
|
||||
state.running += 1
|
||||
state.high_water = max(state.high_water, state.running)
|
||||
# The Trio thread below watches this value and uses it as a
|
||||
# signal that all the stats calculations have finished.
|
||||
state.parked += 1
|
||||
gate.wait()
|
||||
with lock:
|
||||
state.parked -= 1
|
||||
state.running -= 1
|
||||
print("thread_fn exiting")
|
||||
|
||||
async def run_thread(event):
|
||||
with _core.CancelScope() as cancel_scope:
|
||||
await to_thread_run_sync(
|
||||
thread_fn, cancel_scope, limiter=limiter_arg, cancellable=cancel
|
||||
)
|
||||
print("run_thread finished, cancelled:", cancel_scope.cancelled_caught)
|
||||
event.set()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
print("spawning")
|
||||
events = []
|
||||
for i in range(COUNT):
|
||||
events.append(Event())
|
||||
nursery.start_soon(run_thread, events[-1])
|
||||
await wait_all_tasks_blocked()
|
||||
# In the cancel case, we in particular want to make sure that the
|
||||
# cancelled tasks don't release the semaphore. So let's wait until
|
||||
# at least one of them has exited, and that everything has had a
|
||||
# chance to settle down from this, before we check that everyone
|
||||
# who's supposed to be waiting is waiting:
|
||||
if cancel:
|
||||
print("waiting for first cancellation to clear")
|
||||
await events[0].wait()
|
||||
await wait_all_tasks_blocked()
|
||||
# Then wait until the first MAX threads are parked in gate.wait(),
|
||||
# and the next MAX threads are parked on the semaphore, to make
|
||||
# sure no-one is sneaking past, and to make sure the high_water
|
||||
# check below won't fail due to scheduling issues. (It could still
|
||||
# fail if too many threads are let through here.)
|
||||
while state.parked != MAX or c.statistics().tasks_waiting != MAX:
|
||||
await sleep(0.01) # pragma: no cover
|
||||
# Then release the threads
|
||||
gate.set()
|
||||
|
||||
assert state.high_water == MAX
|
||||
|
||||
if cancel:
|
||||
# Some threads might still be running; need to wait for them to
|
||||
# finish before checking that all threads ran. We can do this
|
||||
# using the CapacityLimiter.
|
||||
while c.borrowed_tokens > 0:
|
||||
await sleep(0.01) # pragma: no cover
|
||||
|
||||
assert state.ran == COUNT
|
||||
assert state.running == 0
|
||||
finally:
|
||||
c.total_tokens = orig_total_tokens
|
||||
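
# --- Illustrative sketch, not part of the original trio test suite ---
# The comment inside test_run_in_worker_thread_limiter above warns against
# rebinding closed-over variables from multiple threads (bpo-30744). A
# hypothetical, minimal demonstration of the safe alternative -- mutate one
# shared namespace object in place, under a lock -- could look like this:
def _example_shared_namespace():
    import threading

    class shared:
        counter = 0

    lock = threading.Lock()

    def worker():
        # In-place mutation of an attribute is fine; rebinding a closed-over
        # local from several threads is what the comment above warns about.
        with lock:
            shared.counter += 1

    workers = [threading.Thread(target=worker) for _ in range(8)]
    for t in workers:
        t.start()
    for t in workers:
        t.join()
    assert shared.counter == 8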
|
||||
|
||||
async def test_run_in_worker_thread_custom_limiter():
|
||||
# Basically just checking that we only call acquire_on_behalf_of and
|
||||
# release_on_behalf_of, since that's part of our documented API.
|
||||
record = []
|
||||
|
||||
class CustomLimiter:
|
||||
async def acquire_on_behalf_of(self, borrower):
|
||||
record.append("acquire")
|
||||
self._borrower = borrower
|
||||
|
||||
def release_on_behalf_of(self, borrower):
|
||||
record.append("release")
|
||||
assert borrower == self._borrower
|
||||
|
||||
await to_thread_run_sync(lambda: None, limiter=CustomLimiter())
|
||||
assert record == ["acquire", "release"]
|
||||
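
# --- Illustrative sketch, not part of the original trio test suite ---
# The test above depends only on to_thread_run_sync calling
# acquire_on_behalf_of / release_on_behalf_of on whatever limiter it is
# handed. A hypothetical, minimal counting limiter honouring that same
# two-method interface could look like this (illustration only, not tested
# under real contention):
class _ExampleCountingLimiter:
    def __init__(self, total_tokens):
        self._total_tokens = total_tokens
        self._borrowers = set()

    async def acquire_on_behalf_of(self, borrower):
        # A real limiter would block until capacity frees up; this sketch
        # simply refuses when its (hypothetical) capacity is exhausted.
        if len(self._borrowers) >= self._total_tokens:
            raise RuntimeError("over capacity")
        self._borrowers.add(borrower)

    def release_on_behalf_of(self, borrower):
        self._borrowers.discard(borrower)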
|
||||
|
||||
async def test_run_in_worker_thread_limiter_error():
|
||||
record = []
|
||||
|
||||
class BadCapacityLimiter:
|
||||
async def acquire_on_behalf_of(self, borrower):
|
||||
record.append("acquire")
|
||||
|
||||
def release_on_behalf_of(self, borrower):
|
||||
record.append("release")
|
||||
raise ValueError
|
||||
|
||||
bs = BadCapacityLimiter()
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
await to_thread_run_sync(lambda: None, limiter=bs)
|
||||
assert excinfo.value.__context__ is None
|
||||
assert record == ["acquire", "release"]
|
||||
record = []
|
||||
|
||||
# If the original function raised an error, then the semaphore error
|
||||
# chains with it
|
||||
d = {}
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
await to_thread_run_sync(lambda: d["x"], limiter=bs)
|
||||
assert isinstance(excinfo.value.__context__, KeyError)
|
||||
assert record == ["acquire", "release"]
|
||||
|
||||
|
||||
async def test_run_in_worker_thread_fail_to_spawn(monkeypatch):
|
||||
# Test the unlikely but possible case where trying to spawn a thread fails
|
||||
def bad_start(self, *args):
|
||||
raise RuntimeError("the engines canna take it captain")
|
||||
|
||||
monkeypatch.setattr(_core._thread_cache.ThreadCache, "start_thread_soon", bad_start)
|
||||
|
||||
limiter = current_default_thread_limiter()
|
||||
assert limiter.borrowed_tokens == 0
|
||||
|
||||
# We get an appropriate error, and the limiter is cleanly released
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
await to_thread_run_sync(lambda: None) # pragma: no cover
|
||||
assert "engines" in str(excinfo.value)
|
||||
|
||||
assert limiter.borrowed_tokens == 0
|
||||
|
||||
|
||||
async def test_trio_to_thread_run_sync_token():
|
||||
# Test that to_thread_run_sync automatically injects the current trio token
|
||||
# into a spawned thread
|
||||
def thread_fn():
|
||||
callee_token = from_thread_run_sync(_core.current_trio_token)
|
||||
return callee_token
|
||||
|
||||
caller_token = _core.current_trio_token()
|
||||
callee_token = await to_thread_run_sync(thread_fn)
|
||||
assert callee_token == caller_token
|
||||
|
||||
|
||||
async def test_trio_to_thread_run_sync_expected_error():
|
||||
# Test correct error when passed async function
|
||||
async def async_fn(): # pragma: no cover
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError, match="expected a sync function"):
|
||||
await to_thread_run_sync(async_fn)
|
||||
|
||||
|
||||
trio_test_contextvar = contextvars.ContextVar("trio_test_contextvar")
|
||||
|
||||
|
||||
async def test_trio_to_thread_run_sync_contextvars():
|
||||
trio_thread = threading.current_thread()
|
||||
trio_test_contextvar.set("main")
|
||||
|
||||
def f():
|
||||
value = trio_test_contextvar.get()
|
||||
sniffio_cvar_value = current_async_library_cvar.get()
|
||||
return (value, sniffio_cvar_value, threading.current_thread())
|
||||
|
||||
value, sniffio_cvar_value, child_thread = await to_thread_run_sync(f)
|
||||
assert value == "main"
|
||||
assert sniffio_cvar_value is None
|
||||
assert child_thread != trio_thread
|
||||
|
||||
def g():
|
||||
parent_value = trio_test_contextvar.get()
|
||||
trio_test_contextvar.set("worker")
|
||||
inner_value = trio_test_contextvar.get()
|
||||
sniffio_cvar_value = current_async_library_cvar.get()
|
||||
return (
|
||||
parent_value,
|
||||
inner_value,
|
||||
sniffio_cvar_value,
|
||||
threading.current_thread(),
|
||||
)
|
||||
|
||||
(
|
||||
parent_value,
|
||||
inner_value,
|
||||
sniffio_cvar_value,
|
||||
child_thread,
|
||||
) = await to_thread_run_sync(g)
|
||||
current_value = trio_test_contextvar.get()
|
||||
sniffio_outer_value = current_async_library_cvar.get()
|
||||
assert parent_value == "main"
|
||||
assert inner_value == "worker"
|
||||
assert (
|
||||
current_value == "main"
|
||||
), "The contextvar value set on the worker would not propagate back to the main thread"
|
||||
assert sniffio_cvar_value is None
|
||||
assert sniffio_outer_value == "trio"
|
||||
|
||||
|
||||
async def test_trio_from_thread_run_sync():
|
||||
# Test that to_thread_run_sync correctly "hands off" the trio token to
|
||||
# trio.from_thread.run_sync()
|
||||
def thread_fn():
|
||||
trio_time = from_thread_run_sync(_core.current_time)
|
||||
return trio_time
|
||||
|
||||
trio_time = await to_thread_run_sync(thread_fn)
|
||||
assert isinstance(trio_time, float)
|
||||
|
||||
# Test correct error when passed async function
|
||||
async def async_fn(): # pragma: no cover
|
||||
pass
|
||||
|
||||
def thread_fn():
|
||||
from_thread_run_sync(async_fn)
|
||||
|
||||
with pytest.raises(TypeError, match="expected a sync function"):
|
||||
await to_thread_run_sync(thread_fn)
|
||||
|
||||
|
||||
async def test_trio_from_thread_run():
|
||||
# Test that to_thread_run_sync correctly "hands off" the trio token to
|
||||
# trio.from_thread.run()
|
||||
record = []
|
||||
|
||||
async def back_in_trio_fn():
|
||||
_core.current_time() # implicitly checks that we're in trio
|
||||
record.append("back in trio")
|
||||
|
||||
def thread_fn():
|
||||
record.append("in thread")
|
||||
from_thread_run(back_in_trio_fn)
|
||||
|
||||
await to_thread_run_sync(thread_fn)
|
||||
assert record == ["in thread", "back in trio"]
|
||||
|
||||
# Test correct error when passed sync function
|
||||
def sync_fn(): # pragma: no cover
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError, match="appears to be synchronous"):
|
||||
await to_thread_run_sync(from_thread_run, sync_fn)
|
||||
|
||||
|
||||
async def test_trio_from_thread_token():
|
||||
# Test that to_thread_run_sync and spawned trio.from_thread.run_sync()
|
||||
# share the same Trio token
|
||||
def thread_fn():
|
||||
callee_token = from_thread_run_sync(_core.current_trio_token)
|
||||
return callee_token
|
||||
|
||||
caller_token = _core.current_trio_token()
|
||||
callee_token = await to_thread_run_sync(thread_fn)
|
||||
assert callee_token == caller_token
|
||||
|
||||
|
||||
async def test_trio_from_thread_token_kwarg():
|
||||
# Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can
|
||||
# use an explicitly defined token
|
||||
def thread_fn(token):
|
||||
callee_token = from_thread_run_sync(_core.current_trio_token, trio_token=token)
|
||||
return callee_token
|
||||
|
||||
caller_token = _core.current_trio_token()
|
||||
callee_token = await to_thread_run_sync(thread_fn, caller_token)
|
||||
assert callee_token == caller_token
|
||||
|
||||
|
||||
async def test_from_thread_no_token():
|
||||
# Test that a "raw call" to trio.from_thread.run() fails because no token
|
||||
# has been provided
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
from_thread_run_sync(_core.current_time)
|
||||
|
||||
|
||||
async def test_trio_from_thread_run_sync_contextvars():
|
||||
trio_test_contextvar.set("main")
|
||||
|
||||
def thread_fn():
|
||||
thread_parent_value = trio_test_contextvar.get()
|
||||
trio_test_contextvar.set("worker")
|
||||
thread_current_value = trio_test_contextvar.get()
|
||||
sniffio_cvar_thread_pre_value = current_async_library_cvar.get()
|
||||
|
||||
def back_in_main():
|
||||
back_parent_value = trio_test_contextvar.get()
|
||||
trio_test_contextvar.set("back_in_main")
|
||||
back_current_value = trio_test_contextvar.get()
|
||||
sniffio_cvar_back_value = current_async_library_cvar.get()
|
||||
return back_parent_value, back_current_value, sniffio_cvar_back_value
|
||||
|
||||
(
|
||||
back_parent_value,
|
||||
back_current_value,
|
||||
sniffio_cvar_back_value,
|
||||
) = from_thread_run_sync(back_in_main)
|
||||
thread_after_value = trio_test_contextvar.get()
|
||||
sniffio_cvar_thread_after_value = current_async_library_cvar.get()
|
||||
return (
|
||||
thread_parent_value,
|
||||
thread_current_value,
|
||||
thread_after_value,
|
||||
sniffio_cvar_thread_pre_value,
|
||||
sniffio_cvar_thread_after_value,
|
||||
back_parent_value,
|
||||
back_current_value,
|
||||
sniffio_cvar_back_value,
|
||||
)
|
||||
|
||||
(
|
||||
thread_parent_value,
|
||||
thread_current_value,
|
||||
thread_after_value,
|
||||
sniffio_cvar_thread_pre_value,
|
||||
sniffio_cvar_thread_after_value,
|
||||
back_parent_value,
|
||||
back_current_value,
|
||||
sniffio_cvar_back_value,
|
||||
) = await to_thread_run_sync(thread_fn)
|
||||
current_value = trio_test_contextvar.get()
|
||||
sniffio_cvar_out_value = current_async_library_cvar.get()
|
||||
assert current_value == thread_parent_value == "main"
|
||||
assert thread_current_value == back_parent_value == thread_after_value == "worker"
|
||||
assert back_current_value == "back_in_main"
|
||||
assert sniffio_cvar_out_value == sniffio_cvar_back_value == "trio"
|
||||
assert sniffio_cvar_thread_pre_value is None and sniffio_cvar_thread_after_value is None
|
||||
|
||||
|
||||
async def test_trio_from_thread_run_contextvars():
|
||||
trio_test_contextvar.set("main")
|
||||
|
||||
def thread_fn():
|
||||
thread_parent_value = trio_test_contextvar.get()
|
||||
trio_test_contextvar.set("worker")
|
||||
thread_current_value = trio_test_contextvar.get()
|
||||
sniffio_cvar_thread_pre_value = current_async_library_cvar.get()
|
||||
|
||||
async def async_back_in_main():
|
||||
back_parent_value = trio_test_contextvar.get()
|
||||
trio_test_contextvar.set("back_in_main")
|
||||
back_current_value = trio_test_contextvar.get()
|
||||
sniffio_cvar_back_value = current_async_library_cvar.get()
|
||||
return back_parent_value, back_current_value, sniffio_cvar_back_value
|
||||
|
||||
(
|
||||
back_parent_value,
|
||||
back_current_value,
|
||||
sniffio_cvar_back_value,
|
||||
) = from_thread_run(async_back_in_main)
|
||||
thread_after_value = trio_test_contextvar.get()
|
||||
sniffio_cvar_thread_after_value = current_async_library_cvar.get()
|
||||
return (
|
||||
thread_parent_value,
|
||||
thread_current_value,
|
||||
thread_after_value,
|
||||
sniffio_cvar_thread_pre_value,
|
||||
sniffio_cvar_thread_after_value,
|
||||
back_parent_value,
|
||||
back_current_value,
|
||||
sniffio_cvar_back_value,
|
||||
)
|
||||
|
||||
(
|
||||
thread_parent_value,
|
||||
thread_current_value,
|
||||
thread_after_value,
|
||||
sniffio_cvar_thread_pre_value,
|
||||
sniffio_cvar_thread_after_value,
|
||||
back_parent_value,
|
||||
back_current_value,
|
||||
sniffio_cvar_back_value,
|
||||
) = await to_thread_run_sync(thread_fn)
|
||||
current_value = trio_test_contextvar.get()
|
||||
assert current_value == thread_parent_value == "main"
|
||||
assert thread_current_value == back_parent_value == thread_after_value == "worker"
|
||||
assert back_current_value == "back_in_main"
|
||||
assert sniffio_cvar_thread_pre_value is None and sniffio_cvar_thread_after_value is None
|
||||
assert sniffio_cvar_back_value == "trio"
|
||||
|
||||
|
||||
def test_run_fn_as_system_task_catched_badly_typed_token():
|
||||
with pytest.raises(RuntimeError):
|
||||
from_thread_run_sync(_core.current_time, trio_token="Not TrioTokentype")
|
||||
|
||||
|
||||
async def test_from_thread_inside_trio_thread():
|
||||
def not_called(): # pragma: no cover
|
||||
assert False
|
||||
|
||||
trio_token = _core.current_trio_token()
|
||||
with pytest.raises(RuntimeError):
|
||||
from_thread_run_sync(not_called, trio_token=trio_token)
|
||||
|
||||
|
||||
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy")
|
||||
def test_from_thread_run_during_shutdown():
|
||||
save = []
|
||||
record = []
|
||||
|
||||
async def agen():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
with pytest.raises(_core.RunFinishedError), _core.CancelScope(shield=True):
|
||||
await to_thread_run_sync(from_thread_run, sleep, 0)
|
||||
record.append("ok")
|
||||
|
||||
async def main():
|
||||
save.append(agen())
|
||||
await save[-1].asend(None)
|
||||
|
||||
_core.run(main)
|
||||
assert record == ["ok"]
|
||||
|
||||
|
||||
async def test_trio_token_weak_referenceable():
|
||||
token = current_trio_token()
|
||||
assert isinstance(token, TrioToken)
|
||||
weak_reference = weakref.ref(token)
|
||||
assert token is weak_reference()
|
||||
104
venv/Lib/site-packages/trio/tests/test_timeouts.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import outcome
|
||||
import pytest
|
||||
import time
|
||||
|
||||
from .._core.tests.tutil import slow
|
||||
from .. import _core
|
||||
from ..testing import assert_checkpoints
|
||||
from .._timeouts import *
|
||||
|
||||
|
||||
async def check_takes_about(f, expected_dur):
|
||||
start = time.perf_counter()
|
||||
result = await outcome.acapture(f)
|
||||
dur = time.perf_counter() - start
|
||||
print(dur / expected_dur)
|
||||
# 1.5 is an arbitrary fudge factor because there's always some delay
|
||||
# between when we become eligible to wake up and when we actually do. We
|
||||
# used to sleep for 0.05, and regularly observed overruns of 1.6x on
|
||||
# Appveyor, and then started seeing overruns of 2.3x on Travis's macOS, so
|
||||
# now we bumped up the sleep to 1 second, marked the tests as slow, and
|
||||
# hopefully now the proportional error will be less huge.
|
||||
#
|
||||
# We also allow for durations that are a hair shorter than expected. For
|
||||
# example, here's a run on Windows where a 1.0 second sleep was measured
|
||||
# to take 0.9999999999999858 seconds:
|
||||
# https://ci.appveyor.com/project/njsmith/trio/build/1.0.768/job/3lbdyxl63q3h9s21
|
||||
# I believe that what happened here is that Windows's low clock resolution
|
||||
# meant that our calls to time.monotonic() returned exactly the same
|
||||
# values as the calls inside the actual run loop, but the two subtractions
|
||||
# returned slightly different values because the run loop's clock adds a
|
||||
# random floating point offset to both times, which should cancel out, but
|
||||
# lol floating point we got slightly different rounding errors. (That
|
||||
# value above is exactly 128 ULPs below 1.0, which would make sense if it
|
||||
# started as a 1 ULP error at a different dynamic range.)
|
||||
assert (1 - 1e-8) <= (dur / expected_dur) < 1.5
|
||||
return result.unwrap()
|
||||
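
# Worked check of the tolerance above (illustrative only): the Windows run
# mentioned in the comment measured 0.9999999999999858 s for a 1.0 s sleep.
# That ratio sits comfortably inside the accepted window:
#
#     ratio = 0.9999999999999858 / 1.0   # roughly 1 - 1.4e-14
#     assert (1 - 1e-8) <= ratio < 1.5   # passes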
|
||||
|
||||
# How long to (attempt to) sleep for when testing. Smaller numbers make the
|
||||
# test suite go faster.
|
||||
TARGET = 1.0
|
||||
|
||||
|
||||
@slow
|
||||
async def test_sleep():
|
||||
async def sleep_1():
|
||||
await sleep_until(_core.current_time() + TARGET)
|
||||
|
||||
await check_takes_about(sleep_1, TARGET)
|
||||
|
||||
async def sleep_2():
|
||||
await sleep(TARGET)
|
||||
|
||||
await check_takes_about(sleep_2, TARGET)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await sleep(-1)
|
||||
|
||||
with assert_checkpoints():
|
||||
await sleep(0)
|
||||
# This also serves as a test of the trivial move_on_at
|
||||
with move_on_at(_core.current_time()):
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await sleep(0)
|
||||
|
||||
|
||||
@slow
|
||||
async def test_move_on_after():
|
||||
with pytest.raises(ValueError):
|
||||
with move_on_after(-1):
|
||||
pass # pragma: no cover
|
||||
|
||||
async def sleep_3():
|
||||
with move_on_after(TARGET):
|
||||
await sleep(100)
|
||||
|
||||
await check_takes_about(sleep_3, TARGET)
|
||||
|
||||
|
||||
@slow
|
||||
async def test_fail():
|
||||
async def sleep_4():
|
||||
with fail_at(_core.current_time() + TARGET):
|
||||
await sleep(100)
|
||||
|
||||
with pytest.raises(TooSlowError):
|
||||
await check_takes_about(sleep_4, TARGET)
|
||||
|
||||
with fail_at(_core.current_time() + 100):
|
||||
await sleep(0)
|
||||
|
||||
async def sleep_5():
|
||||
with fail_after(TARGET):
|
||||
await sleep(100)
|
||||
|
||||
with pytest.raises(TooSlowError):
|
||||
await check_takes_about(sleep_5, TARGET)
|
||||
|
||||
with fail_after(100):
|
||||
await sleep(0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with fail_after(-1):
|
||||
pass # pragma: no cover
|
||||
276
venv/Lib/site-packages/trio/tests/test_unix_pipes.py
Normal file
@@ -0,0 +1,276 @@
|
||||
import errno
|
||||
import select
|
||||
import os
|
||||
import tempfile
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from .._core.tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken
|
||||
from .. import _core, move_on_after
|
||||
from ..testing import wait_all_tasks_blocked, check_one_way_stream
|
||||
|
||||
posix = os.name == "posix"
|
||||
pytestmark = pytest.mark.skipif(not posix, reason="posix only")
|
||||
if posix:
|
||||
from .._unix_pipes import FdStream
|
||||
else:
|
||||
with pytest.raises(ImportError):
|
||||
from .._unix_pipes import FdStream
|
||||
|
||||
|
||||
# Have to use quoted types so import doesn't crash on windows
|
||||
async def make_pipe() -> "Tuple[FdStream, FdStream]":
|
||||
"""Makes a new pair of pipes."""
|
||||
(r, w) = os.pipe()
|
||||
return FdStream(w), FdStream(r)
|
||||
|
||||
|
||||
async def make_clogged_pipe():
|
||||
s, r = await make_pipe()
|
||||
try:
|
||||
while True:
|
||||
# We want to totally fill up the pipe buffer.
|
||||
# This requires working around a weird feature that POSIX pipes
|
||||
# have.
|
||||
# If you do a write of <= PIPE_BUF bytes, then it's guaranteed
|
||||
# to either complete entirely, or not at all. So if we tried to
|
||||
# write PIPE_BUF bytes, and the buffer's free space is only
|
||||
# PIPE_BUF/2, then the write will raise BlockingIOError... even
|
||||
# though a smaller write could still succeed! To avoid this,
|
||||
# make sure to write >PIPE_BUF bytes each time, which disables
|
||||
# the special behavior.
|
||||
# For details, search for PIPE_BUF here:
|
||||
# http://pubs.opengroup.org/onlinepubs/9699919799/functions/write.html
|
||||
|
||||
# for the getattr:
|
||||
# https://bitbucket.org/pypy/pypy/issues/2876/selectpipe_buf-is-missing-on-pypy3
|
||||
buf_size = getattr(select, "PIPE_BUF", 8192)
|
||||
os.write(s.fileno(), b"x" * buf_size * 2)
|
||||
except BlockingIOError:
|
||||
pass
|
||||
return s, r
|
||||
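
# --- Illustrative sketch, not part of the original trio test suite ---
# The PIPE_BUF note above can be seen directly with os.write: on a
# non-blocking pipe, a write of <= PIPE_BUF bytes is all-or-nothing, while a
# larger write is allowed to be *partially* accepted instead of failing. This
# hypothetical helper just reports how many bytes one oversized write queued.
def _example_partial_pipe_write():
    buf_size = getattr(select, "PIPE_BUF", 8192)
    r, w = os.pipe()
    os.set_blocking(w, False)
    try:
        # Larger than PIPE_BUF, so the kernel may accept only part of it
        # rather than raising BlockingIOError outright.
        return os.write(w, b"x" * buf_size * 4)
    finally:
        os.close(r)
        os.close(w)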
|
||||
|
||||
async def test_send_pipe():
|
||||
r, w = os.pipe()
|
||||
async with FdStream(w) as send:
|
||||
assert send.fileno() == w
|
||||
await send.send_all(b"123")
|
||||
assert (os.read(r, 8)) == b"123"
|
||||
|
||||
os.close(r)
|
||||
|
||||
|
||||
async def test_receive_pipe():
|
||||
r, w = os.pipe()
|
||||
async with FdStream(r) as recv:
|
||||
assert (recv.fileno()) == r
|
||||
os.write(w, b"123")
|
||||
assert (await recv.receive_some(8)) == b"123"
|
||||
|
||||
os.close(w)
|
||||
|
||||
|
||||
async def test_pipes_combined():
|
||||
write, read = await make_pipe()
|
||||
count = 2**20
|
||||
|
||||
async def sender():
|
||||
big = bytearray(count)
|
||||
await write.send_all(big)
|
||||
|
||||
async def reader():
|
||||
await wait_all_tasks_blocked()
|
||||
received = 0
|
||||
while received < count:
|
||||
received += len(await read.receive_some(4096))
|
||||
|
||||
assert received == count
|
||||
|
||||
async with _core.open_nursery() as n:
|
||||
n.start_soon(sender)
|
||||
n.start_soon(reader)
|
||||
|
||||
await read.aclose()
|
||||
await write.aclose()
|
||||
|
||||
|
||||
async def test_pipe_errors():
|
||||
with pytest.raises(TypeError):
|
||||
FdStream(None)
|
||||
|
||||
r, w = os.pipe()
|
||||
os.close(w)
|
||||
async with FdStream(r) as s:
|
||||
with pytest.raises(ValueError):
|
||||
await s.receive_some(0)
|
||||
|
||||
|
||||
async def test_del():
|
||||
w, r = await make_pipe()
|
||||
f1, f2 = w.fileno(), r.fileno()
|
||||
del w, r
|
||||
gc_collect_harder()
|
||||
|
||||
with pytest.raises(OSError) as excinfo:
|
||||
os.close(f1)
|
||||
assert excinfo.value.errno == errno.EBADF
|
||||
|
||||
with pytest.raises(OSError) as excinfo:
|
||||
os.close(f2)
|
||||
assert excinfo.value.errno == errno.EBADF
|
||||
|
||||
|
||||
async def test_async_with():
|
||||
w, r = await make_pipe()
|
||||
async with w, r:
|
||||
pass
|
||||
|
||||
assert w.fileno() == -1
|
||||
assert r.fileno() == -1
|
||||
|
||||
with pytest.raises(OSError) as excinfo:
|
||||
os.close(w.fileno())
|
||||
assert excinfo.value.errno == errno.EBADF
|
||||
|
||||
with pytest.raises(OSError) as excinfo:
|
||||
os.close(r.fileno())
|
||||
assert excinfo.value.errno == errno.EBADF
|
||||
|
||||
|
||||
async def test_misdirected_aclose_regression():
|
||||
# https://github.com/python-trio/trio/issues/661#issuecomment-456582356
|
||||
w, r = await make_pipe()
|
||||
old_r_fd = r.fileno()
|
||||
|
||||
# Close the original objects
|
||||
await w.aclose()
|
||||
await r.aclose()
|
||||
|
||||
# Do a little dance to get a new pipe whose receive handle matches the old
|
||||
# receive handle.
|
||||
r2_fd, w2_fd = os.pipe()
|
||||
if r2_fd != old_r_fd: # pragma: no cover
|
||||
os.dup2(r2_fd, old_r_fd)
|
||||
os.close(r2_fd)
|
||||
async with FdStream(old_r_fd) as r2:
|
||||
assert r2.fileno() == old_r_fd
|
||||
|
||||
# And now set up a background task that's working on the new receive
|
||||
# handle
|
||||
async def expect_eof():
|
||||
assert await r2.receive_some(10) == b""
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(expect_eof)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# Here's the key test: does calling aclose() again on the *old*
|
||||
# handle, cause the task blocked on the *new* handle to raise
|
||||
# ClosedResourceError?
|
||||
await r.aclose()
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# Guess we survived! Close the new write handle so that the task
|
||||
# gets an EOF and can exit cleanly.
|
||||
os.close(w2_fd)
|
||||
|
||||
|
||||
async def test_close_at_bad_time_for_receive_some(monkeypatch):
|
||||
# We used to have race conditions where if one task was using the pipe,
|
||||
# and another closed it at *just* the wrong moment, it would give an
|
||||
# unexpected error instead of ClosedResourceError:
|
||||
# https://github.com/python-trio/trio/issues/661
|
||||
#
|
||||
# This tests what happens if the pipe gets closed in the moment *between*
|
||||
# when receive_some wakes up, and when it tries to call os.read
|
||||
async def expect_closedresourceerror():
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await r.receive_some(10)
|
||||
|
||||
orig_wait_readable = _core._run.TheIOManager.wait_readable
|
||||
|
||||
async def patched_wait_readable(*args, **kwargs):
|
||||
await orig_wait_readable(*args, **kwargs)
|
||||
await r.aclose()
|
||||
|
||||
monkeypatch.setattr(_core._run.TheIOManager, "wait_readable", patched_wait_readable)
|
||||
s, r = await make_pipe()
|
||||
async with s, r:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(expect_closedresourceerror)
|
||||
await wait_all_tasks_blocked()
|
||||
# Trigger everything by waking up the receiver
|
||||
await s.send_all(b"x")
|
||||
|
||||
|
||||
async def test_close_at_bad_time_for_send_all(monkeypatch):
|
||||
# We used to have race conditions where if one task was using the pipe,
|
||||
# and another closed it at *just* the wrong moment, it would give an
|
||||
# unexpected error instead of ClosedResourceError:
|
||||
# https://github.com/python-trio/trio/issues/661
|
||||
#
|
||||
# This tests what happens if the pipe gets closed in the moment *between*
|
||||
# when send_all wakes up, and when it tries to call os.write
|
||||
async def expect_closedresourceerror():
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await s.send_all(b"x" * 100)
|
||||
|
||||
orig_wait_writable = _core._run.TheIOManager.wait_writable
|
||||
|
||||
async def patched_wait_writable(*args, **kwargs):
|
||||
await orig_wait_writable(*args, **kwargs)
|
||||
await s.aclose()
|
||||
|
||||
monkeypatch.setattr(_core._run.TheIOManager, "wait_writable", patched_wait_writable)
|
||||
s, r = await make_clogged_pipe()
|
||||
async with s, r:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(expect_closedresourceerror)
|
||||
await wait_all_tasks_blocked()
|
||||
# Trigger everything by waking up the sender. On ppc64el, PIPE_BUF
|
||||
# is 8192 but make_clogged_pipe() ends up writing a total of
|
||||
# 1048576 bytes before the pipe is full, and then a subsequent
|
||||
# receive_some(10000) isn't sufficient for orig_wait_writable() to
|
||||
# return for our subsequent aclose() call. It's necessary to empty
|
||||
# the pipe further before this happens. So we loop here until the
|
||||
# pipe is empty to make sure that the sender wakes up even in this
|
||||
# case. Otherwise patched_wait_writable() never gets to the
|
||||
# aclose(), so expect_closedresourceerror() never returns, the
|
||||
# nursery never finishes all tasks and this test hangs.
|
||||
received_data = await r.receive_some(10000)
|
||||
while received_data:
|
||||
received_data = await r.receive_some(10000)
|
||||
|
||||
|
||||
# On FreeBSD, directories are readable, and we haven't found any other trick
|
||||
# for making an unreadable fd, so there's no way to run this test. Fortunately
|
||||
# the logic this is testing doesn't depend on the platform, so testing on
|
||||
# other platforms is probably good enough.
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("freebsd"),
|
||||
reason="no way to make read() return a bizarro error on FreeBSD",
|
||||
)
|
||||
async def test_bizarro_OSError_from_receive():
|
||||
# Make sure that if the read syscall returns some bizarro error, then we
|
||||
# get a BrokenResourceError. This is incredibly unlikely; there's almost
|
||||
# no way to trigger a failure here intentionally (except for EBADF, but we
|
||||
# exploit that to detect file closure, so it takes a different path). So
|
||||
# we set up a strange scenario where the pipe fd somehow transmutes into a
|
||||
# directory fd, causing os.read to raise IsADirectoryError (yes, that's a
|
||||
# real built-in exception type).
|
||||
s, r = await make_pipe()
|
||||
async with s, r:
|
||||
dir_fd = os.open("/", os.O_DIRECTORY, 0)
|
||||
try:
|
||||
os.dup2(dir_fd, r.fileno())
|
||||
with pytest.raises(_core.BrokenResourceError):
|
||||
await r.receive_some(10)
|
||||
finally:
|
||||
os.close(dir_fd)
|
||||
|
||||
|
||||
@skip_if_fbsd_pipes_broken
|
||||
async def test_pipe_fully():
|
||||
await check_one_way_stream(make_pipe, make_clogged_pipe)
|
||||
189
venv/Lib/site-packages/trio/tests/test_util.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import signal
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from .. import _core
|
||||
from .._core.tests.tutil import (
|
||||
ignore_coroutine_never_awaited_warnings,
|
||||
create_asyncio_future_in_new_loop,
|
||||
)
|
||||
from .._util import (
|
||||
signal_raise,
|
||||
ConflictDetector,
|
||||
is_main_thread,
|
||||
coroutine_or_error,
|
||||
generic_function,
|
||||
Final,
|
||||
NoPublicConstructor,
|
||||
)
|
||||
from ..testing import wait_all_tasks_blocked
|
||||
|
||||
|
||||
def test_signal_raise():
|
||||
record = []
|
||||
|
||||
def handler(signum, _):
|
||||
record.append(signum)
|
||||
|
||||
old = signal.signal(signal.SIGFPE, handler)
|
||||
try:
|
||||
signal_raise(signal.SIGFPE)
|
||||
finally:
|
||||
signal.signal(signal.SIGFPE, old)
|
||||
assert record == [signal.SIGFPE]
|
||||
|
||||
|
||||
async def test_ConflictDetector():
|
||||
ul1 = ConflictDetector("ul1")
|
||||
ul2 = ConflictDetector("ul2")
|
||||
|
||||
with ul1:
|
||||
with ul2:
|
||||
print("ok")
|
||||
|
||||
with pytest.raises(_core.BusyResourceError) as excinfo:
|
||||
with ul1:
|
||||
with ul1:
|
||||
pass # pragma: no cover
|
||||
assert "ul1" in str(excinfo.value)
|
||||
|
||||
async def wait_with_ul1():
|
||||
with ul1:
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
with pytest.raises(_core.BusyResourceError) as excinfo:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(wait_with_ul1)
|
||||
nursery.start_soon(wait_with_ul1)
|
||||
assert "ul1" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_module_metadata_is_fixed_up():
|
||||
import trio
|
||||
import trio.testing
|
||||
|
||||
assert trio.Cancelled.__module__ == "trio"
|
||||
assert trio.open_nursery.__module__ == "trio"
|
||||
assert trio.abc.Stream.__module__ == "trio.abc"
|
||||
assert trio.lowlevel.wait_task_rescheduled.__module__ == "trio.lowlevel"
|
||||
assert trio.testing.trio_test.__module__ == "trio.testing"
|
||||
|
||||
# Also check methods
|
||||
assert trio.lowlevel.ParkingLot.__init__.__module__ == "trio.lowlevel"
|
||||
assert trio.abc.Stream.send_all.__module__ == "trio.abc"
|
||||
|
||||
# And names
|
||||
assert trio.Cancelled.__name__ == "Cancelled"
|
||||
assert trio.Cancelled.__qualname__ == "Cancelled"
|
||||
assert trio.abc.SendStream.send_all.__name__ == "send_all"
|
||||
assert trio.abc.SendStream.send_all.__qualname__ == "SendStream.send_all"
|
||||
assert trio.to_thread.__name__ == "trio.to_thread"
|
||||
assert trio.to_thread.run_sync.__name__ == "run_sync"
|
||||
assert trio.to_thread.run_sync.__qualname__ == "run_sync"
|
||||
|
||||
|
||||
async def test_is_main_thread():
|
||||
assert is_main_thread()
|
||||
|
||||
def not_main_thread():
|
||||
assert not is_main_thread()
|
||||
|
||||
await trio.to_thread.run_sync(not_main_thread)
|
||||
|
||||
|
||||
# @coroutine is deprecated since python 3.8, which is fine with us.
|
||||
@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning")
|
||||
def test_coroutine_or_error():
|
||||
class Deferred:
|
||||
"Just kidding"
|
||||
|
||||
with ignore_coroutine_never_awaited_warnings():
|
||||
|
||||
async def f(): # pragma: no cover
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(f())
|
||||
assert "expecting an async function" in str(excinfo.value)
|
||||
|
||||
import asyncio
|
||||
|
||||
@asyncio.coroutine
|
||||
def generator_based_coro(): # pragma: no cover
|
||||
yield from asyncio.sleep(1)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(generator_based_coro())
|
||||
assert "asyncio" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(create_asyncio_future_in_new_loop())
|
||||
assert "asyncio" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(create_asyncio_future_in_new_loop)
|
||||
assert "asyncio" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(Deferred())
|
||||
assert "twisted" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(lambda: Deferred())
|
||||
assert "twisted" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(len, [[1, 2, 3]])
|
||||
|
||||
assert "appears to be synchronous" in str(excinfo.value)
|
||||
|
||||
async def async_gen(arg): # pragma: no cover
|
||||
yield
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(async_gen, [0])
|
||||
msg = "expected an async function but got an async generator"
|
||||
assert msg in str(excinfo.value)
|
||||
|
||||
# Make sure no references are kept around to keep anything alive
|
||||
del excinfo
|
||||
|
||||
|
||||
def test_generic_function():
|
||||
@generic_function
|
||||
def test_func(arg):
|
||||
"""Look, a docstring!"""
|
||||
return arg
|
||||
|
||||
assert test_func is test_func[int] is test_func[int, str]
|
||||
assert test_func(42) == test_func[int](42) == 42
|
||||
assert test_func.__doc__ == "Look, a docstring!"
|
||||
assert test_func.__qualname__ == "test_generic_function.<locals>.test_func"
|
||||
assert test_func.__name__ == "test_func"
|
||||
assert test_func.__module__ == __name__
|
||||
|
||||
|
||||
def test_final_metaclass():
|
||||
class FinalClass(metaclass=Final):
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class SubClass(FinalClass):
|
||||
pass
|
||||
|
||||
|
||||
def test_no_public_constructor_metaclass():
|
||||
class SpecialClass(metaclass=NoPublicConstructor):
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
SpecialClass()
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class SubClass(SpecialClass):
|
||||
pass
|
||||
|
||||
# Private constructor should not raise
|
||||
assert isinstance(SpecialClass._create(), SpecialClass)
|
||||
220
venv/Lib/site-packages/trio/tests/test_wait_for_object.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
on_windows = os.name == "nt"
|
||||
# Mark all the tests in this file as being windows-only
|
||||
pytestmark = pytest.mark.skipif(not on_windows, reason="windows only")
|
||||
|
||||
from .._core.tests.tutil import slow
|
||||
import trio
|
||||
from .. import _core
|
||||
from .. import _timeouts
|
||||
|
||||
if on_windows:
|
||||
from .._core._windows_cffi import ffi, kernel32
|
||||
from .._wait_for_object import (
|
||||
WaitForSingleObject,
|
||||
WaitForMultipleObjects_sync,
|
||||
)
|
||||
|
||||
|
||||
async def test_WaitForMultipleObjects_sync():
|
||||
# This does a series of tests where we set/close the handle before
|
||||
# initiating the wait for it.
|
||||
#
|
||||
# Note that closing the handle (not signaling) will cause the
|
||||
# *initiation* of a wait to return immediately. But closing a handle
|
||||
# that is already being waited on will not stop whatever is waiting
|
||||
# for it.
|
||||
|
||||
# One handle
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.SetEvent(handle1)
|
||||
WaitForMultipleObjects_sync(handle1)
|
||||
kernel32.CloseHandle(handle1)
|
||||
print("test_WaitForMultipleObjects_sync one OK")
|
||||
|
||||
# Two handles, signal first
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.SetEvent(handle1)
|
||||
WaitForMultipleObjects_sync(handle1, handle2)
|
||||
kernel32.CloseHandle(handle1)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync set first OK")
|
||||
|
||||
# Two handles, signal second
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.SetEvent(handle2)
|
||||
WaitForMultipleObjects_sync(handle1, handle2)
|
||||
kernel32.CloseHandle(handle1)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync set second OK")
|
||||
|
||||
# Two handles, close first
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.CloseHandle(handle1)
|
||||
with pytest.raises(OSError):
|
||||
WaitForMultipleObjects_sync(handle1, handle2)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync close first OK")
|
||||
|
||||
# Two handles, close second
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.CloseHandle(handle2)
|
||||
with pytest.raises(OSError):
|
||||
WaitForMultipleObjects_sync(handle1, handle2)
|
||||
kernel32.CloseHandle(handle1)
|
||||
print("test_WaitForMultipleObjects_sync close second OK")
|
||||
|
||||
|
||||
@slow
|
||||
async def test_WaitForMultipleObjects_sync_slow():
|
||||
# This does a series of tests in which the main thread sync-waits for
|
||||
# handles, while we spawn a thread to set the handles after a short while.
|
||||
|
||||
TIMEOUT = 0.3
|
||||
|
||||
# One handle
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1
|
||||
)
|
||||
await _timeouts.sleep(TIMEOUT)
|
||||
# If we commented out the line below, the thread above would be stuck,
|
||||
# and Trio would never exit this scope
|
||||
kernel32.SetEvent(handle1)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
kernel32.CloseHandle(handle1)
|
||||
print("test_WaitForMultipleObjects_sync_slow one OK")
|
||||
|
||||
# Two handles, signal first
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2
|
||||
)
|
||||
await _timeouts.sleep(TIMEOUT)
|
||||
kernel32.SetEvent(handle1)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
kernel32.CloseHandle(handle1)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync_slow thread-set first OK")
|
||||
|
||||
# Two handles, signal second
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2
|
||||
)
|
||||
await _timeouts.sleep(TIMEOUT)
|
||||
kernel32.SetEvent(handle2)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
kernel32.CloseHandle(handle1)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync_slow thread-set second OK")
|
||||
|
||||
|
||||
async def test_WaitForSingleObject():
|
||||
# This does a series of tests for setting/closing the handle before
|
||||
# initiating the wait.
|
||||
|
||||
# Test already set
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.SetEvent(handle)
|
||||
await WaitForSingleObject(handle) # should return at once
|
||||
kernel32.CloseHandle(handle)
|
||||
print("test_WaitForSingleObject already set OK")
|
||||
|
||||
# Test already set, as int
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle_int = int(ffi.cast("intptr_t", handle))
|
||||
kernel32.SetEvent(handle)
|
||||
await WaitForSingleObject(handle_int) # should return at once
|
||||
kernel32.CloseHandle(handle)
|
||||
print("test_WaitForSingleObject already set OK")
|
||||
|
||||
# Test already closed
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.CloseHandle(handle)
|
||||
with pytest.raises(OSError):
|
||||
await WaitForSingleObject(handle) # should return at once
|
||||
print("test_WaitForSingleObject already closed OK")
|
||||
|
||||
# Not a handle
|
||||
with pytest.raises(TypeError):
|
||||
await WaitForSingleObject("not a handle") # Wrong type
|
||||
# with pytest.raises(OSError):
|
||||
# await WaitForSingleObject(99) # If you're unlucky, it actually IS a handle :(
|
||||
print("test_WaitForSingleObject not a handle OK")
|
||||
|
||||
|
||||
@slow
|
||||
async def test_WaitForSingleObject_slow():
|
||||
# This does a series of tests for setting the handle in another task,
|
||||
# and cancelling the wait task.
|
||||
|
||||
# Set the timeout used in the tests. We test the waiting time against
|
||||
# the timeout with a certain margin.
|
||||
TIMEOUT = 0.3
|
||||
|
||||
async def signal_soon_async(handle):
|
||||
await _timeouts.sleep(TIMEOUT)
|
||||
kernel32.SetEvent(handle)
|
||||
|
||||
# Test handle is SET after TIMEOUT in separate coroutine
|
||||
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(WaitForSingleObject, handle)
|
||||
nursery.start_soon(signal_soon_async, handle)
|
||||
|
||||
kernel32.CloseHandle(handle)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
print("test_WaitForSingleObject_slow set from task OK")
|
||||
|
||||
# Test handle is SET after TIMEOUT in separate coroutine, as int
|
||||
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle_int = int(ffi.cast("intptr_t", handle))
|
||||
t0 = _core.current_time()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(WaitForSingleObject, handle_int)
|
||||
nursery.start_soon(signal_soon_async, handle)
|
||||
|
||||
kernel32.CloseHandle(handle)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
print("test_WaitForSingleObject_slow set from task as int OK")
|
||||
|
||||
# Test handle is CLOSED after 1 sec - NOPE see comment above
|
||||
|
||||
# Test cancellation
|
||||
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
|
||||
with _timeouts.move_on_after(TIMEOUT):
|
||||
await WaitForSingleObject(handle)
|
||||
|
||||
kernel32.CloseHandle(handle)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
print("test_WaitForSingleObject_slow cancellation OK")
|
||||
110
venv/Lib/site-packages/trio/tests/test_windows_pipes.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import errno
|
||||
import select
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
from .._core.tests.tutil import gc_collect_harder
|
||||
from .. import _core, move_on_after
|
||||
from ..testing import wait_all_tasks_blocked, check_one_way_stream
|
||||
|
||||
if sys.platform == "win32":
|
||||
from .._windows_pipes import PipeSendStream, PipeReceiveStream
|
||||
from .._core._windows_cffi import _handle, kernel32
|
||||
from asyncio.windows_utils import pipe
|
||||
else:
|
||||
pytestmark = pytest.mark.skip(reason="windows only")
|
||||
pipe = None # type: Any
|
||||
PipeSendStream = None # type: Any
|
||||
PipeReceiveStream = None # type: Any
|
||||
|
||||
|
||||
async def make_pipe() -> "Tuple[PipeSendStream, PipeReceiveStream]":
|
||||
"""Makes a new pair of pipes."""
|
||||
(r, w) = pipe()
|
||||
return PipeSendStream(w), PipeReceiveStream(r)
|
||||
|
||||
|
||||
async def test_pipe_typecheck():
|
||||
with pytest.raises(TypeError):
|
||||
PipeSendStream(1.0)
|
||||
with pytest.raises(TypeError):
|
||||
PipeReceiveStream(None)
|
||||
|
||||
|
||||
async def test_pipe_error_on_close():
|
||||
# Make sure we correctly handle a failure from kernel32.CloseHandle
|
||||
r, w = pipe()
|
||||
|
||||
send_stream = PipeSendStream(w)
|
||||
receive_stream = PipeReceiveStream(r)
|
||||
|
||||
assert kernel32.CloseHandle(_handle(r))
|
||||
assert kernel32.CloseHandle(_handle(w))
|
||||
|
||||
with pytest.raises(OSError):
|
||||
await send_stream.aclose()
|
||||
with pytest.raises(OSError):
|
||||
await receive_stream.aclose()
|
||||
|
||||
|
||||
async def test_pipes_combined():
|
||||
write, read = await make_pipe()
|
||||
count = 2**20
|
||||
replicas = 3
|
||||
|
||||
async def sender():
|
||||
async with write:
|
||||
big = bytearray(count)
|
||||
for _ in range(replicas):
|
||||
await write.send_all(big)
|
||||
|
||||
async def reader():
|
||||
async with read:
|
||||
await wait_all_tasks_blocked()
|
||||
total_received = 0
|
||||
while True:
|
||||
# 5000 is chosen because it doesn't evenly divide 2**20
|
||||
received = len(await read.receive_some(5000))
|
||||
if not received:
|
||||
break
|
||||
total_received += received
|
||||
|
||||
assert total_received == count * replicas
|
||||
|
||||
async with _core.open_nursery() as n:
|
||||
n.start_soon(sender)
|
||||
n.start_soon(reader)
|
||||
|
||||
|
||||
async def test_async_with():
|
||||
w, r = await make_pipe()
|
||||
async with w, r:
|
||||
pass
|
||||
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await w.send_all(b"")
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await r.receive_some(10)
|
||||
|
||||
|
||||
async def test_close_during_write():
|
||||
w, r = await make_pipe()
|
||||
async with _core.open_nursery() as nursery:
|
||||
|
||||
async def write_forever():
|
||||
with pytest.raises(_core.ClosedResourceError) as excinfo:
|
||||
while True:
|
||||
await w.send_all(b"x" * 4096)
|
||||
assert "another task" in str(excinfo.value)
|
||||
|
||||
nursery.start_soon(write_forever)
|
||||
await wait_all_tasks_blocked(0.1)
|
||||
await w.aclose()
|
||||
|
||||
|
||||
async def test_pipe_fully():
|
||||
# passing make_clogged_pipe tests wait_send_all_might_not_block, and we
|
||||
# can't implement that on Windows
|
||||
await check_one_way_stream(make_pipe, None)
|
||||
0
venv/Lib/site-packages/trio/tests/tools/__init__.py
Normal file
72
venv/Lib/site-packages/trio/tests/tools/test_gen_exports.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import ast
|
||||
import astor
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
|
||||
from shutil import copyfile
|
||||
from trio._tools.gen_exports import (
|
||||
get_public_methods,
|
||||
create_passthrough_args,
|
||||
process,
|
||||
)
|
||||
|
||||
SOURCE = '''from _run import _public
|
||||
|
||||
class Test:
|
||||
@_public
|
||||
def public_func(self):
|
||||
"""With doc string"""
|
||||
|
||||
@ignore_this
|
||||
@_public
|
||||
@another_decorator
|
||||
async def public_async_func(self):
|
||||
pass # no doc string
|
||||
|
||||
def not_public(self):
|
||||
pass
|
||||
|
||||
async def not_public_async(self):
|
||||
pass
|
||||
'''
|
||||
|
||||
|
||||
def test_get_public_methods():
|
||||
methods = list(get_public_methods(ast.parse(SOURCE)))
|
||||
assert {m.name for m in methods} == {"public_func", "public_async_func"}
|
||||
|
||||
|
||||
def test_create_pass_through_args():
|
||||
testcases = [
|
||||
("def f()", "()"),
|
||||
("def f(one)", "(one)"),
|
||||
("def f(one, two)", "(one, two)"),
|
||||
("def f(one, *args)", "(one, *args)"),
|
||||
(
|
||||
"def f(one, *args, kw1, kw2=None, **kwargs)",
|
||||
"(one, *args, kw1=kw1, kw2=kw2, **kwargs)",
|
||||
),
|
||||
]
|
||||
|
||||
for (funcdef, expected) in testcases:
|
||||
func_node = ast.parse(funcdef + ":\n pass").body[0]
|
||||
assert isinstance(func_node, ast.FunctionDef)
|
||||
assert create_passthrough_args(func_node) == expected
|
||||
|
||||
|
||||
def test_process(tmp_path):
|
||||
modpath = tmp_path / "_module.py"
|
||||
genpath = tmp_path / "_generated_module.py"
|
||||
modpath.write_text(SOURCE, encoding="utf-8")
|
||||
assert not genpath.exists()
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
process([(str(modpath), "runner")], do_test=True)
|
||||
assert excinfo.value.code == 1
|
||||
process([(str(modpath), "runner")], do_test=False)
|
||||
assert genpath.exists()
|
||||
process([(str(modpath), "runner")], do_test=True)
|
||||
# But if we change the lookup path it notices
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
process([(str(modpath), "runner.io_manager")], do_test=True)
|
||||
assert excinfo.value.code == 1
|
||||