"""Async test helpers: a JSON-defaulting request factory/client and a task gatherer."""

import asyncio
import functools
import inspect
from contextlib import asynccontextmanager

from django.test.client import AsyncClient as DjangoAsyncClient
from django.test.client import AsyncRequestFactory as DjangoAsyncRequestFactory


@asynccontextmanager
async def gather_tasks_context():
    """
    An async context manager to gather and wait for tasks created within the context.

    While the context is active, the ``asyncio.create_task`` module attribute is
    replaced by a recording wrapper. Only tasks created through that attribute
    are captured; tasks created by other means (for example
    ``loop.create_task``) are not recorded here.

    On exit, including when the body raises, the original function is restored
    first and the recorded tasks are then awaited with ``asyncio.gather``, so an
    exception raised by a recorded task propagates to the caller.
    """
    original_create_task = asyncio.create_task
    tasks = []

    def create_task_wrapper(coro, **kwargs):
        # Forward keyword arguments (e.g. ``name=``) so that code calling
        # ``asyncio.create_task(coro, name=...)`` keeps working while patched.
        task = original_create_task(coro, **kwargs)
        tasks.append(task)
        return task

    asyncio.create_task = create_task_wrapper
    try:
        yield
    finally:
        # Undo the patch before awaiting so it never outlives the context,
        # even if one of the gathered tasks raises.
        asyncio.create_task = original_create_task
        if tasks:
            await asyncio.gather(*tasks)


class AsyncAPIRequestFactory(DjangoAsyncRequestFactory):
    """Async request factory that defaults POST bodies to JSON and applies stored credentials."""

    def __init__(self, **defaults):
        super().__init__(**defaults)
        # Extra request keyword arguments merged into every ``generic`` call.
        self._credentials = {}

    def post(self, path, data=None, content_type="application/json", **extra):
        """Construct a POST request.

        ``data`` defaults to an empty dict and is passed through the inherited
        ``_encode_json`` helper together with ``content_type`` before being
        handed to ``generic``.
        """
        data = self._encode_json({} if data is None else data, content_type)
        return self.generic("POST", path, data, content_type, **extra)

    def generic(self, *args, **extra):
        """Delegate to the parent ``generic`` with the stored credentials merged in."""
        # On a key clash the stored credentials take precedence over the
        # per-request extras (right-hand side of ``|=`` wins).
        extra |= self._credentials
        return super().generic(*args, **extra)


def _content_type_position(func):
    """Return the positional index of ``content_type`` in *func*, or ``None``.

    ``None`` means the parameter cannot be supplied positionally (it is
    keyword-only, absent, or the signature cannot be introspected).
    """
    try:
        parameters = inspect.signature(func).parameters.values()
    except (TypeError, ValueError):
        return None
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    for index, parameter in enumerate(parameters):
        if parameter.name == "content_type":
            return index if parameter.kind in positional_kinds else None
    return None


def override_content_type(methods, default_content_type="application/json"):
    """Class decorator factory that changes the default ``content_type`` of methods.

    Each attribute of the decorated class named in ``methods`` is replaced by a
    wrapper that passes ``content_type=default_content_type`` to the original
    method whenever the caller did not provide a content type (or passed
    ``content_type=None``). The wrapper returns whatever the original returns.

    Args:
        methods: Iterable of method names to wrap; each must exist on the class.
        default_content_type: Content type injected when the caller gives none.

    Returns:
        A decorator that mutates the class in place and returns it.
    """

    def make_wrapper(original_method):
        # Resolved once per method: lets the wrapper detect a content type
        # that the caller passed positionally.
        position = _content_type_position(original_method)

        @functools.wraps(original_method)
        def inner(*args, content_type=None, **kwargs):
            supplied_positionally = position is not None and len(args) > position
            if content_type is None and not supplied_positionally:
                content_type = default_content_type
            if content_type is not None:
                # When the value was passed positionally we must not inject
                # the keyword too, otherwise the call fails with
                # "got multiple values for argument 'content_type'".
                kwargs["content_type"] = content_type
            return original_method(*args, **kwargs)

        return inner

    def decorator(cls):
        for method_name in methods:
            setattr(cls, method_name, make_wrapper(getattr(cls, method_name)))
        return cls

    return decorator


@override_content_type(methods=["post", "put", "patch", "delete"])
class AsyncAPIClient(DjangoAsyncClient, AsyncAPIRequestFactory):
    """
    Async test client whose ``post``, ``put``, ``patch`` and ``delete`` methods
    default to the ``application/json`` content type, and which can attach
    stored credentials to every outgoing request.
    """

    def credentials(self, **kwargs):
        """
        Sets headers that will be used on every outgoing request.

        The given keyword arguments replace any previously stored credentials;
        calling this with no arguments clears them.
        """
        self._credentials = kwargs