Skip to content

Instantly share code, notes, and snippets.

@DrChai
Last active April 8, 2025 07:26
Show Gist options
  • Save DrChai/cedf18d0655e7784a8282ecd85f28e90 to your computer and use it in GitHub Desktop.
Save DrChai/cedf18d0655e7784a8282ecd85f28e90 to your computer and use it in GitHub Desktop.

Revisions

  1. DrChai renamed this gist Apr 8, 2025. 1 changed file with 0 additions and 0 deletions.
    File renamed without changes.
  2. DrChai created this gist Apr 8, 2025.
    22 changes: 22 additions & 0 deletions README.md
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,22 @@
    ## Overview
    This gist provides a minimal implementation of an asynchronous APIView for Django REST Framework (DRF). It extends the default DRF `APIView` to support asynchronous operations, including permission checks, authentication, and request handling.
    It is designed to help migrate existing synchronous WSGI-based views to ASGI for improved concurrency and performance.
    ## Usage

    ```python
    from path.to.async_apiviews import APIView, Request, BasePermission
    from rest_framework.response import Response

    class AsyncExampleView(APIView):
    permission_classes = [IsAuthenticated]

    async def get(self, request: Request, *args, **kwargs):
    """
    Handle GET requests asynchronously.
    """
    # Perform an asynchronous operation
    data = await query.get_operation()
    # Fire-and-forget task (non-awaitable)
    asyncio.create_task(task.send_notification())
    return Response({"message": "Hello, async world!", "data": data})
    ```
    76 changes: 76 additions & 0 deletions async_client.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,76 @@
    import asyncio
    from contextlib import asynccontextmanager

    from django.test.client import AsyncClient as DjangoAsyncClient
    from django.test.client import AsyncRequestFactory as DjangoAsyncRequestFactory


    @asynccontextmanager
    async def gather_tasks_context():
    """
    An async context manager to gather and wait for tasks created within the context.
    """
    original_create_task = asyncio.create_task
    tasks = []

    def create_task_wrapper(coro):
    task = original_create_task(coro)
    tasks.append(task)
    return task

    asyncio.create_task = create_task_wrapper

    try:
    yield
    finally:
    asyncio.create_task = original_create_task
    if tasks:
    await asyncio.gather(*tasks)


    class AsyncAPIRequestFactory(DjangoAsyncRequestFactory):

    def __init__(self, **defaults):
    super().__init__(**defaults)
    self._credentials = {}

    def post(self, path, data=None, content_type="application/json", **extra):
    """Construct a POST request."""
    data = self._encode_json({} if data is None else data, content_type)
    return self.generic("POST", path, data, content_type, **extra)

    def generic(self, *args, **extra):
    # Include the CONTENT_TYPE, regardless of whether or not data is empty.
    extra |= self._credentials
    return super().generic(*args, **extra)


    def override_content_type(methods, default_content_type="application/json"):
    def decorator(cls):
    for method_name in methods:
    original_method = getattr(cls, method_name)

    def wrapper(original_method):
    def inner(*args, content_type=None, **kwargs):
    if content_type is None:
    content_type = default_content_type
    return original_method(*args, content_type=content_type, **kwargs)

    return inner

    setattr(cls, method_name, wrapper(original_method))

    return cls

    return decorator


    @override_content_type(methods=["post", "put", "patch", "delete"])
    class AsyncAPIClient(DjangoAsyncClient, AsyncAPIRequestFactory):
    """ """

    def credentials(self, **kwargs):
    """
    Sets headers that will be used on every outgoing request.
    """
    self._credentials = kwargs
    137 changes: 137 additions & 0 deletions async_view.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,137 @@
    import asyncio

    from rest_framework import exceptions
    from rest_framework.permissions import BasePermission as DRFBasePermission
    from rest_framework.request import Request as DRFRequest
    from rest_framework.views import APIView as DRFAPIView


    class BasePermission(DRFBasePermission):
    async def has_permission(self, request, view):
    """ """
    return True

    async def has_object_permission(self, request, view, obj):
    """ """
    return True


    class Request(DRFRequest):

    @property
    def user(self):
    """ """
    if not hasattr(self, "_user"):
    raise AttributeError
    return self._user

    @user.setter
    def user(self, value):
    """ """
    self._user = value
    self._request.user = value

    async def _authenticate(self):

    for authenticator in self.authenticators:
    try:
    user_auth_tuple = await authenticator.authenticate(self)
    except exceptions.APIException:
    self._not_authenticated()
    raise

    if user_auth_tuple is not None:
    self._authenticator = authenticator
    self.user, self.auth = user_auth_tuple
    return

    self._not_authenticated()


    class APIView(DRFAPIView):

    async def initial(self, request, *args, **kwargs):
    """ """
    self.format_kwarg = self.get_format_suffix(**kwargs)

    # Perform content negotiation and store the accepted info on the request
    neg = self.perform_content_negotiation(request)
    request.accepted_renderer, request.accepted_media_type = neg

    # Determine the API version, if versioning is in use.
    version, scheme = self.determine_version(request, *args, **kwargs)
    request.version, request.versioning_scheme = version, scheme

    # Ensure that the incoming request is permitted
    await self.perform_authentication(request)
    await self.check_permissions(request)
    # self.check_throttles(request)

    async def dispatch(self, request, *args, **kwargs):
    """ """
    self.args = args
    self.kwargs = kwargs
    request = self.initialize_request(request, *args, **kwargs)
    self.request = request
    self.headers = self.default_response_headers # deprecate?

    try:
    await self.initial(request, *args, **kwargs)
    if request.method.lower() in self.http_method_names and (
    handler := getattr(self, request.method.lower(), None)
    ):
    response = await handler(request, *args, **kwargs)
    else:
    raise exceptions.MethodNotAllowed(request.method)
    except Exception as exc:
    response = self.handle_exception(exc)

    self.response = self.finalize_response(request, response, *args, **kwargs)
    return self.response

    def initialize_request(self, request, *args, **kwargs):
    """ """
    parser_context = self.get_parser_context(request)

    return Request(
    request,
    parsers=self.get_parsers(),
    authenticators=self.get_authenticators(),
    negotiator=self.get_content_negotiator(),
    parser_context=parser_context,
    )

    async def permission_handler(
    self, permission_class, method="has_permission", *args
    ):
    permission = permission_class()
    if not await getattr(permission, method)(self.request, self, *args):
    self.permission_denied(
    self.request,
    message=getattr(permission, "detail", None),
    code=getattr(permission, "code", None),
    )

    async def check_permissions(self, request: Request) -> None:
    try:
    async with asyncio.TaskGroup() as tg:
    for permission in self.permission_classes:
    tg.create_task(self.permission_handler(permission))
    except ExceptionGroup as e:
    raise e.exceptions[0]

    async def check_object_permissions(self, request: Request, obj) -> None:
    """ """
    try:
    async with asyncio.TaskGroup() as tg:
    for permission in self.permission_classes:
    tg.create_task(
    self.permission_handler(
    permission, "has_object_permission", obj
    )
    )
    except ExceptionGroup as e:
    raise e.exceptions[0]

    async def perform_authentication(self, request: Request) -> None:
    await request._authenticate()