|
import asyncio
import inspect

from rest_framework import exceptions
from rest_framework.permissions import BasePermission as DRFBasePermission
from rest_framework.request import Request as DRFRequest
from rest_framework.views import APIView as DRFAPIView
|
|
|
|
|
class BasePermission(DRFBasePermission):
    """Permission base class whose checks are coroutines.

    Both hooks grant access unconditionally; subclasses override them with
    their own ``async def`` implementations and return a falsy value to
    deny access.

    NOTE(review): the parent class's ``&`` / ``|`` / ``~`` composition
    helpers presumably call these hooks without awaiting them -- confirm
    before combining async permissions with those operators.
    """

    async def has_permission(self, request, view) -> bool:
        """Return whether ``request`` may reach ``view`` at all.

        The default implementation always allows the request.
        """
        return True

    async def has_object_permission(self, request, view, obj) -> bool:
        """Return whether ``request`` may act on ``obj`` through ``view``.

        The default implementation always allows the request.
        """
        return True
|
|
|
|
|
class Request(DRFRequest):
    """Request wrapper whose authentication step is a coroutine.

    ``_authenticate`` must be awaited explicitly before ``user`` is read,
    because the synchronous ``user`` property defined here cannot await it
    on demand.
    """

    @property
    def user(self):
        """Return the user stored by a previous authentication run.

        Raises:
            AttributeError: if no user has been assigned yet, i.e. neither
                ``_authenticate`` has completed nor the setter was used.
        """
        if not hasattr(self, "_user"):
            raise AttributeError(
                "'user' is not set on this request; "
                "await request._authenticate() before accessing it"
            )
        return self._user

    @user.setter
    def user(self, value):
        """Store ``value`` on this wrapper and on the wrapped request."""
        self._user = value
        # Keep the underlying request object in sync with the wrapper.
        self._request.user = value

    async def _authenticate(self):
        """Try each authenticator in order and record the first success.

        An authenticator's ``authenticate`` may be a coroutine function or
        a plain function: the result is awaited only when it is awaitable,
        so synchronous authenticators no longer fail with ``TypeError``.

        On the first non-``None`` result, the authenticator, user and auth
        are stored on the request and the method returns. If every
        authenticator returns ``None``, ``_not_authenticated()`` is called.

        Raises:
            exceptions.APIException: re-raised from an authenticator after
                ``_not_authenticated()`` has been called.
        """
        for authenticator in self.authenticators:
            try:
                user_auth_tuple = authenticator.authenticate(self)
                if inspect.isawaitable(user_auth_tuple):
                    user_auth_tuple = await user_auth_tuple
            except exceptions.APIException:
                self._not_authenticated()
                raise

            if user_auth_tuple is not None:
                self._authenticator = authenticator
                self.user, self.auth = user_auth_tuple
                return

        self._not_authenticated()
|
|
|
|
|
class APIView(DRFAPIView):
    """API view whose dispatch, authentication and permission checks are
    coroutines.

    Handlers (``get``, ``post``, ...) are awaited by ``dispatch``.
    Permission classes are instantiated per check and evaluated
    concurrently in an ``asyncio.TaskGroup``.
    """

    async def initial(self, request, *args, **kwargs):
        """Run the pre-handler steps for ``request``.

        Stores the format suffix, the negotiated renderer/media type and
        the API version on the view/request, then awaits authentication
        and the view-level permission checks.
        """
        self.format_kwarg = self.get_format_suffix(**kwargs)

        # Perform content negotiation and store the accepted info on the request
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg

        # Determine the API version, if versioning is in use.
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme

        # Ensure that the incoming request is permitted
        await self.perform_authentication(request)
        await self.check_permissions(request)
        # NOTE(review): throttle checks are not run on this path (the
        # ``self.check_throttles(request)`` call was commented out) --
        # confirm whether that is intended.

    async def dispatch(self, request, *args, **kwargs):
        """Wrap the incoming request, run ``initial`` and await the handler.

        The handler is the attribute named after the lower-cased HTTP
        method, provided that name is listed in ``http_method_names``;
        otherwise ``MethodNotAllowed`` is raised. Any ``Exception`` raised
        along the way is turned into a response by ``handle_exception``.

        Returns:
            The finalized response, also stored on ``self.response``.
        """
        self.args = args
        self.kwargs = kwargs
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?

        try:
            await self.initial(request, *args, **kwargs)
            method_name = request.method.lower()
            handler = (
                getattr(self, method_name, None)
                if method_name in self.http_method_names
                else None
            )
            if not handler:
                raise exceptions.MethodNotAllowed(request.method)
            response = await handler(request, *args, **kwargs)
        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response

    def initialize_request(self, request, *args, **kwargs):
        """Wrap ``request`` in this module's ``Request`` class.

        The wrapper receives the view's parsers, authenticators, content
        negotiator and parser context.
        """
        parser_context = self.get_parser_context(request)

        return Request(
            request,
            parsers=self.get_parsers(),
            authenticators=self.get_authenticators(),
            negotiator=self.get_content_negotiator(),
            parser_context=parser_context,
        )

    async def permission_handler(
        self, permission_class, method="has_permission", *args
    ):
        """Instantiate ``permission_class`` and evaluate one of its checks.

        Args:
            permission_class: callable returning a permission instance.
            method: name of the check to call on the instance.
            *args: extra positional arguments for the check (for example
                the object for ``has_object_permission``).

        The check may be a coroutine function or a plain function; its
        result is awaited only when awaitable. On a falsy result,
        ``permission_denied`` is called with the permission's ``detail``
        (falling back to ``message`` when ``detail`` is missing or
        ``None``) and its ``code``.
        """
        permission = permission_class()
        granted = getattr(permission, method)(self.request, self, *args)
        if inspect.isawaitable(granted):
            granted = await granted
        if granted:
            return

        message = getattr(permission, "detail", None)
        if message is None:
            # Also honour the ``message`` attribute on permission classes.
            message = getattr(permission, "message", None)
        self.permission_denied(
            self.request,
            message=message,
            code=getattr(permission, "code", None),
        )

    async def _run_permission_checks(self, method, *args):
        """Run ``method`` of every permission class concurrently.

        Re-raises the first exception collected by the task group so that
        callers see a plain exception rather than an ``ExceptionGroup``.
        """
        try:
            async with asyncio.TaskGroup() as tg:
                for permission in self.permission_classes:
                    tg.create_task(
                        self.permission_handler(permission, method, *args)
                    )
        except ExceptionGroup as e:
            raise e.exceptions[0]

    async def check_permissions(self, request: Request) -> None:
        """Check ``has_permission`` on every permission class.

        NOTE: the checks use ``self.request``; the ``request`` argument is
        accepted for interface compatibility but is not read.
        """
        await self._run_permission_checks("has_permission")

    async def check_object_permissions(self, request: Request, obj) -> None:
        """Check ``has_object_permission`` for ``obj`` on every permission
        class.

        NOTE: the checks use ``self.request``; the ``request`` argument is
        accepted for interface compatibility but is not read.
        """
        await self._run_permission_checks("has_object_permission", obj)

    async def perform_authentication(self, request: Request) -> None:
        """Await the request's authentication coroutine."""
        await request._authenticate()