Skip to content

Instantly share code, notes, and snippets.

@graingert
Created December 15, 2024 10:50
Show Gist options
  • Save graingert/3c75a9bec75ea1535e22a2f7d938a344 to your computer and use it in GitHub Desktop.
Save graingert/3c75a9bec75ea1535e22a2f7d938a344 to your computer and use it in GitHub Desktop.

Revisions

  1. graingert created this gist Dec 15, 2024.
    109 changes: 109 additions & 0 deletions edge_cancelation.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,109 @@
    from __future__ import annotations

    import dataclasses
    import math
    from collections.abc import Callable, Coroutine, Generator
    from typing import TYPE_CHECKING

    import trio.lowlevel
    from typing_extensions import ParamSpec, Self, TypeVar, overload

    if TYPE_CHECKING:
    from types import TracebackType

    _P = ParamSpec("_P")

    _YieldT_co = TypeVar("_YieldT_co", covariant=True)
    _SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None)
    _ReturnT_co = TypeVar("_ReturnT_co", covariant=True, default=None)
    _SendT_contra_nd = TypeVar("_SendT_contra_nd", contravariant=True)
    _ReturnT_co_nd = TypeVar("_ReturnT_co_nd", covariant=True)


    @dataclasses.dataclass
    class WrapCoro(
    Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
    Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
    ):
    _current_task: trio.lowlevel.Task
    _cancel_scope: trio.CancelScope
    _coro: Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]
    _was_cancelled: bool = False

    def __await__(self) -> Self:
    return self

    def send(self, value: _SendT_contra_nd) -> _YieldT_co:
    if self._was_cancelled:
    self._cancel_scope.shield = True
    return self._coro.send(value)

    cancelled = trio.current_effective_deadline() == -math.inf
    r = self._coro.send(value)
    if cancelled:
    self._was_cancelled = True
    return r

    @overload
    def throw(
    self,
    typ: type[BaseException],
    val: BaseException | object = None,
    tb: TracebackType | None = None,
    ) -> _YieldT_co: ...
    @overload
    def throw(
    self,
    typ: BaseException,
    val: None = None,
    tb: TracebackType | None = None,
    ) -> _YieldT_co: ...

    def throw(
    self,
    typ: type[BaseException] | BaseException,
    val: object = None,
    tb: TracebackType | None = None,
    ) -> _YieldT_co:
    if val is None and tb is None:
    return self._coro.throw(typ)
    return self._coro.throw(typ, val, tb) # type: ignore[arg-type]

    def close(self) -> None:
    pass


    def edge_cancel(
    fn: Callable[_P, Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]],
    ) -> Callable[_P, Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]]:
    async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _ReturnT_co_nd:
    with trio.CancelScope() as scope:
    return await WrapCoro(
    trio.lowlevel.current_task(),
    scope,
    fn(*args, **kwargs),
    )

    return wrapper


    async def demo() -> None:

    @edge_cancel
    async def edge_cancelled() -> None:
    print("started")
    try:
    await trio.sleep(math.inf)
    except BaseException:
    print("cancelled!")
    await trio.sleep(0.1)
    print("slept!")
    await trio.sleep(0.1)
    print("slept!")
    raise

    async with trio.open_nursery() as nursery:
    nursery.start_soon(edge_cancelled)
    nursery.cancel_scope.cancel()

    trio.run(demo)