Skip to content

Instantly share code, notes, and snippets.

@DrChai
Last active April 8, 2025 07:26
Show Gist options
  • Save DrChai/cedf18d0655e7784a8282ecd85f28e90 to your computer and use it in GitHub Desktop.
Save DrChai/cedf18d0655e7784a8282ecd85f28e90 to your computer and use it in GitHub Desktop.
A minimal Django REST framework async patcher

Overview

This gist provides a minimal implementation of an asynchronous APIView for Django REST Framework (DRF). It extends the default DRF APIView to support asynchronous operations, including authentication, permission checks, and request handling. It is designed to help migrate existing synchronous WSGI-based views to ASGI for improved concurrency and performance. Note that throttling is not implemented (`check_throttles` is never called), and permission checks run concurrently in an `asyncio.TaskGroup`, so Python 3.11+ is required.

Usage

import asyncio

from rest_framework.response import Response

from path.to.async_apiviews import APIView, Request, BasePermission

class IsAuthenticated(BasePermission):
    """Async counterpart of DRF's ``IsAuthenticated``.

    Built on the module's async ``BasePermission`` so the check is a
    coroutine, like the other permissions used with the async ``APIView``.
    """

    async def has_permission(self, request, view):
        return bool(request.user and request.user.is_authenticated)


# The event loop keeps only weak references to tasks; hold strong ones so a
# fire-and-forget task cannot be garbage-collected before it finishes.
background_tasks = set()


class AsyncExampleView(APIView):
    permission_classes = [IsAuthenticated]

    async def get(self, request: Request, *args, **kwargs):
        """
        Handle GET requests asynchronously.
        """
        # Perform an asynchronous operation (``query`` is a placeholder).
        data = await query.get_operation()
        # Fire-and-forget task (``task`` is a placeholder): it is not awaited,
        # but a reference is held until it completes.
        notification = asyncio.create_task(task.send_notification())
        background_tasks.add(notification)
        notification.add_done_callback(background_tasks.discard)
        return Response({"message": "Hello, async world!", "data": data})
import asyncio
from contextlib import asynccontextmanager
from django.test.client import AsyncClient as DjangoAsyncClient
from django.test.client import AsyncRequestFactory as DjangoAsyncRequestFactory
@asynccontextmanager
async def gather_tasks_context():
    """Track tasks created via ``asyncio.create_task`` and await them on exit.

    While the context is active, the module attribute ``asyncio.create_task``
    is replaced by a wrapper that records every task it creates. On exit, even
    if the body raised, the recorded tasks are awaited with ``asyncio.gather``
    and then the original ``create_task`` is restored.

    Tasks are drained in rounds while the wrapper is still installed. Tasks
    spawned by the gathered tasks are therefore tracked and awaited as well.

    Only calls that look up ``asyncio.create_task`` at call time are tracked.
    A ``from asyncio import create_task`` binding, or ``loop.create_task``,
    bypasses the wrapper.
    """
    original_create_task = asyncio.create_task
    tasks = []

    def create_task_wrapper(coro, **kwargs):
        # Forward keyword arguments (``name``, ``context``, ...) unchanged so
        # callers that pass them keep working while the patch is active.
        task = original_create_task(coro, **kwargs)
        tasks.append(task)
        return task

    asyncio.create_task = create_task_wrapper
    try:
        yield
    finally:
        try:
            # Each round awaits the tasks recorded so far; any tasks they
            # create are recorded for the next round.
            while tasks:
                pending = tasks.copy()
                tasks.clear()
                await asyncio.gather(*pending)
        finally:
            asyncio.create_task = original_create_task
class AsyncAPIRequestFactory(DjangoAsyncRequestFactory):
    """Async request factory with JSON POST bodies and sticky credentials.

    ``post`` defaults ``content_type`` to ``"application/json"``. Keys stored
    in ``self._credentials`` (set by ``AsyncAPIClient.credentials``) are
    merged into the ``extra`` of every request built through ``generic``.
    """

    def __init__(self, **defaults):
        super().__init__(**defaults)
        # Extra request keys (e.g. auth headers) merged into every request.
        self._credentials = {}

    def post(self, path, data=None, content_type="application/json", **extra):
        """Construct a POST request, defaulting to a JSON body.

        Delegates to Django's ``post`` so ``data`` goes through both
        ``_encode_json`` (JSON content types are serialized) and
        ``_encode_data``. Without the latter, a dict sent with a non-JSON
        content type such as ``MULTIPART_CONTENT`` would be stringified
        instead of encoded.
        """
        return super().post(path, data, content_type, **extra)

    def generic(self, *args, **extra):
        """Build a request after merging the stored credentials into ``extra``.

        Stored credentials win over same-named keys passed by the caller.
        """
        extra |= self._credentials
        return super().generic(*args, **extra)
def override_content_type(methods, default_content_type="application/json"):
    """Class decorator that changes the default ``content_type`` of some methods.

    Each method named in ``methods`` is wrapped. When the caller supplies no
    ``content_type`` (neither positionally nor by keyword), or passes
    ``content_type=None``, the method is called with
    ``content_type=default_content_type``. An explicitly supplied value is
    passed through untouched.

    Args:
        methods: names of methods on the decorated class to wrap.
        default_content_type: value injected when none is supplied.

    Returns:
        A decorator that patches the class in place and returns it.
    """
    import functools
    import inspect

    def _content_type_position(func):
        # Index of ``content_type`` among the positional parameters (``self``
        # included), or None if it cannot be given positionally.
        try:
            params = list(inspect.signature(func).parameters.values())
        except (TypeError, ValueError):
            return None
        for index, param in enumerate(params):
            if param.name == "content_type":
                if param.kind in (
                    inspect.Parameter.POSITIONAL_ONLY,
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                ):
                    return index
                return None
        return None

    def _wrap(original_method):
        position = _content_type_position(original_method)

        @functools.wraps(original_method)
        def inner(*args, **kwargs):
            # Injecting a keyword when content_type was already passed
            # positionally would raise "got multiple values".
            given_positionally = position is not None and len(args) > position
            if not given_positionally and kwargs.get("content_type") is None:
                kwargs["content_type"] = default_content_type
            return original_method(*args, **kwargs)

        return inner

    def decorator(cls):
        for method_name in methods:
            setattr(cls, method_name, _wrap(getattr(cls, method_name)))
        return cls

    return decorator
@override_content_type(methods=["post", "put", "patch", "delete"])
class AsyncAPIClient(DjangoAsyncClient, AsyncAPIRequestFactory):
    """Async Django test client for API tests.

    ``post``, ``put``, ``patch`` and ``delete`` default ``content_type`` to
    ``"application/json"`` through ``override_content_type``. Headers given
    to ``credentials()`` are stored in ``_credentials``, which
    ``AsyncAPIRequestFactory.generic`` merges into each request.
    """

    def credentials(self, **kwargs):
        """Replace the headers attached to every subsequent request.

        Calling it with no arguments clears previously set credentials.
        """
        self._credentials = dict(kwargs)
import asyncio
from rest_framework import exceptions
from rest_framework.permissions import BasePermission as DRFBasePermission
from rest_framework.request import Request as DRFRequest
from rest_framework.views import APIView as DRFAPIView
class BasePermission(DRFBasePermission):
    """Async base class for permissions used with this module's ``APIView``.

    Both checks are coroutines and, by default, grant access. Subclasses
    override one or both to restrict access.

    NOTE(review): DRF's ``&``, ``|`` and ``~`` permission operators combine
    results with synchronous boolean logic. With coroutine-returning checks
    the combined result is likely wrong (e.g. ``A & B`` may only evaluate
    ``B``). Avoid composing these permissions until that is verified.
    """

    async def has_permission(self, request, view):
        """View-level check; return a truthy value to allow (default: allow)."""
        return True

    async def has_object_permission(self, request, view, obj):
        """Object-level check for ``obj``; return truthy to allow (default: allow)."""
        return True
class Request(DRFRequest):
    """DRF ``Request`` whose authentication is run explicitly and asynchronously.

    DRF authenticates lazily on first access to ``user``/``auth``, which
    cannot await. Here ``user`` is only available after
    ``await request._authenticate()``; ``APIView.perform_authentication``
    does this during ``initial``.
    """

    @property
    def user(self):
        """Return the user set by ``_authenticate``.

        Raises:
            AttributeError: authentication has not run yet.
        """
        # NOTE(review): DRF's ``Request.__getattr__`` may catch this
        # AttributeError and fall back to the wrapped Django request's
        # ``user``; confirm which value callers see before authentication.
        if not hasattr(self, "_user"):
            raise AttributeError(
                "request.user is unavailable until "
                "`await request._authenticate()` has run"
            )
        return self._user

    @user.setter
    def user(self, value):
        """Set the user on this request and on the wrapped Django request."""
        self._user = value
        self._request.user = value

    async def _authenticate(self):
        """Try each authenticator in order; the first non-None result wins.

        Each ``authenticate`` result is awaited only when it is awaitable.
        Async authenticators and DRF's synchronous ones can therefore both be
        listed. Synchronous authenticators run inline on the event loop, so
        ones that touch the ORM will still fail in an async context (Django
        raises ``SynchronousOnlyOperation``).

        On success, ``user``/``auth`` are set from the ``(user, auth)`` tuple
        and the authenticator is remembered. If none succeed,
        ``_not_authenticated()`` is applied.

        Raises:
            exceptions.APIException: re-raised from an authenticator after
                the request has been marked unauthenticated.
        """
        import inspect

        for authenticator in self.authenticators:
            try:
                user_auth_tuple = authenticator.authenticate(self)
                if inspect.isawaitable(user_auth_tuple):
                    user_auth_tuple = await user_auth_tuple
            except exceptions.APIException:
                self._not_authenticated()
                raise
            if user_auth_tuple is not None:
                self._authenticator = authenticator
                self.user, self.auth = user_auth_tuple
                return
        self._not_authenticated()
class APIView(DRFAPIView):
    """DRF ``APIView`` with an ``async`` request lifecycle.

    ``dispatch`` is a coroutine. Authentication and permission checks are
    awaited, and method handlers (``get``, ``post``, ...) are awaited when
    they return an awaitable. Permission checks run concurrently in an
    ``asyncio.TaskGroup`` (Python 3.11+).

    Throttling is NOT performed (``check_throttles`` is never called), so
    any ``throttle_classes`` configured on the view are ignored.
    """

    async def initial(self, request, *args, **kwargs):
        """Run pre-handler steps: negotiation, versioning, auth, permissions.

        Mirrors DRF's ``initial`` except that authentication and permission
        checks are awaited and throttling is skipped.
        """
        self.format_kwarg = self.get_format_suffix(**kwargs)

        # Perform content negotiation and store the accepted info on the request.
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg

        # Determine the API version, if versioning is in use.
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme

        # Ensure that the incoming request is permitted.
        await self.perform_authentication(request)
        await self.check_permissions(request)
        # Throttling is intentionally not performed:
        # self.check_throttles(request)

    async def dispatch(self, request, *args, **kwargs):
        """Wrap the request, run ``initial`` and the method handler, finalize.

        Any exception raised by ``initial`` or the handler (including
        authentication and permission failures) is turned into a response by
        ``handle_exception``.
        """
        import inspect

        self.args = args
        self.kwargs = kwargs
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?

        try:
            await self.initial(request, *args, **kwargs)
            method = request.method.lower()
            if method in self.http_method_names and (
                handler := getattr(self, method, None)
            ):
                response = handler(request, *args, **kwargs)
                # Handlers may be synchronous (e.g. DRF's inherited
                # ``options``), so only await an actual awaitable.
                if inspect.isawaitable(response):
                    response = await response
            else:
                raise exceptions.MethodNotAllowed(request.method)
        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response

    def initialize_request(self, request, *args, **kwargs):
        """Wrap the incoming Django request in this module's async ``Request``."""
        parser_context = self.get_parser_context(request)
        return Request(
            request,
            parsers=self.get_parsers(),
            authenticators=self.get_authenticators(),
            negotiator=self.get_content_negotiator(),
            parser_context=parser_context,
        )

    async def permission_handler(
        self, permission_class, method="has_permission", *args
    ):
        """Instantiate ``permission_class`` and run one check on ``self.request``.

        Kept for backward compatibility. ``check_permissions`` and
        ``check_object_permissions`` evaluate the instances returned by
        ``get_permissions()`` instead.
        """
        await self._check_permission(permission_class(), method, self.request, *args)

    async def _check_permission(self, permission, method, request, *args):
        """Call ``permission.<method>(request, self, *args)``; deny if falsy.

        The result is awaited only when awaitable. Both async permissions and
        DRF's synchronous built-ins (e.g. the default ``AllowAny``) therefore
        work. The denial message comes from the permission's ``detail``
        attribute, falling back to DRF's conventional ``message`` attribute.

        Raises:
            exceptions.APIException: from ``permission_denied`` on denial.
        """
        import inspect

        allowed = getattr(permission, method)(request, self, *args)
        if inspect.isawaitable(allowed):
            allowed = await allowed
        if not allowed:
            message = getattr(permission, "detail", None)
            if message is None:
                message = getattr(permission, "message", None)
            self.permission_denied(
                request,
                message=message,
                code=getattr(permission, "code", None),
            )

    async def _run_permission_checks(self, request, method, *args):
        """Run ``method`` on every permission from ``get_permissions()`` concurrently.

        Uses ``asyncio.TaskGroup``, so the first failing check cancels the
        remaining ones. The first collected exception is re-raised unwrapped
        (not as an ``ExceptionGroup``) so ``handle_exception`` can map it to
        a response.
        """
        try:
            async with asyncio.TaskGroup() as tg:
                for permission in self.get_permissions():
                    tg.create_task(
                        self._check_permission(permission, method, request, *args)
                    )
        except ExceptionGroup as e:
            raise e.exceptions[0]

    async def check_permissions(self, request: Request) -> None:
        """Check view-level ``has_permission`` for all permissions.

        Honors ``get_permissions()`` overrides.
        """
        await self._run_permission_checks(request, "has_permission")

    async def check_object_permissions(self, request: Request, obj) -> None:
        """Check object-level ``has_object_permission`` on ``obj`` for all permissions.

        This is a coroutine and must be awaited. DRF code that calls it
        synchronously (e.g. ``GenericAPIView.get_object``) would get an
        un-awaited coroutine and skip the check.
        """
        await self._run_permission_checks(request, "has_object_permission", obj)

    async def perform_authentication(self, request: Request) -> None:
        """Authenticate eagerly by awaiting the async ``Request._authenticate``."""
        await request._authenticate()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment