Created
December 15, 2024 10:50
-
-
Save graingert/3c75a9bec75ea1535e22a2f7d938a344 to your computer and use it in GitHub Desktop.
Revisions
-
graingert created this gist
Dec 15, 2024 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,109 @@ from __future__ import annotations import dataclasses import math from collections.abc import Callable, Coroutine, Generator from typing import TYPE_CHECKING import trio.lowlevel from typing_extensions import ParamSpec, Self, TypeVar, overload if TYPE_CHECKING: from types import TracebackType _P = ParamSpec("_P") _YieldT_co = TypeVar("_YieldT_co", covariant=True) _SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None) _ReturnT_co = TypeVar("_ReturnT_co", covariant=True, default=None) _SendT_contra_nd = TypeVar("_SendT_contra_nd", contravariant=True) _ReturnT_co_nd = TypeVar("_ReturnT_co_nd", covariant=True) @dataclasses.dataclass class WrapCoro( Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], ): _current_task: trio.lowlevel.Task _cancel_scope: trio.CancelScope _coro: Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd] _was_cancelled: bool = False def __await__(self) -> Self: return self def send(self, value: _SendT_contra_nd) -> _YieldT_co: if self._was_cancelled: self._cancel_scope.shield = True return self._coro.send(value) cancelled = trio.current_effective_deadline() == -math.inf r = self._coro.send(value) if cancelled: self._was_cancelled = True return r @overload def throw( self, typ: type[BaseException], val: BaseException | object = None, tb: TracebackType | None = None, ) -> _YieldT_co: ... @overload def throw( self, typ: BaseException, val: None = None, tb: TracebackType | None = None, ) -> _YieldT_co: ... def throw( self, typ: type[BaseException] | BaseException, val: object = None, tb: TracebackType | None = None, ) -> _YieldT_co: if val is None and tb is None: return self._coro.throw(typ) return self._coro.throw(typ, val, tb) # type: ignore[arg-type] def close(self) -> None: pass def edge_cancel( fn: Callable[_P, Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]], ) -> Callable[_P, Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]]: async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _ReturnT_co_nd: with trio.CancelScope() as scope: return await WrapCoro( trio.lowlevel.current_task(), scope, fn(*args, **kwargs), ) return wrapper async def demo() -> None: @edge_cancel async def edge_cancelled() -> None: print("started") try: await trio.sleep(math.inf) except BaseException: print("cancelled!") await trio.sleep(0.1) print("slept!") await trio.sleep(0.1) print("slept!") raise async with trio.open_nursery() as nursery: nursery.start_soon(edge_cancelled) nursery.cancel_scope.cancel() trio.run(demo)