# Standard Library
from typing import (
    Any,
    Callable,
    Dict,
    FrozenSet,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    cast,
)

# Third Party Libraries
from fastapi import APIRouter, HTTPException, status
from fastapi.routing import APIRoute
from pydantic import BaseModel
from sqlalchemy import (
    Column,
    ForeignKeyConstraint,
    Table,
    UniqueConstraint,
    inspect,
)
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import Session
from sqlalchemy.sql.elements import UnaryExpression
from starlette import routing
from starlette.responses import Response
from starlette.types import ASGIApp

# App and Model Imports
from app.utils.oop import all_subclasses


class CRUDApiRouter(APIRouter):
    """Automatically generate a router with essential CRUD endpoints on it.

    Subclasses configure the router through class attributes. ``model`` and
    ``db_dep`` are required. Each endpoint is registered in ``__init__`` only
    when it is enabled:

    - list: ``list_schema`` is set
    - detail: ``detail_schema`` is set
    - create: ``create_schema_in`` is set (also needs ``create_schema_out``
      and ``model_base``)
    - update: ``update_schema_in`` is set (also needs ``model_base``)
    - delete: ``delete_view`` is True

    Detail, update and delete support only a model with a single,
    non-composite integer primary key.
    """

    # Dependency to get a DB session (used as the default of each view's
    # ``db`` parameter)
    db_dep: Session
    # Model this router is for
    model: DeclarativeMeta

    # Success status codes for which a view must not return a body
    _HTTP_2XX_NO_RETURN_CODES: FrozenSet[int] = frozenset(
        {
            status.HTTP_204_NO_CONTENT,
            status.HTTP_205_RESET_CONTENT,
        }
    )

    # Declarative base for the model; required for create and update, where
    # its subclasses are used to find the model behind each FK-referenced table
    model_base: Optional[DeclarativeMeta] = None

    # for listing rows
    list_schema: Optional[Type[BaseModel]] = None
    list_deps: Optional[List[Any]] = None
    list_sort: Optional[Sequence[UnaryExpression]] = None
    list_default_offset: Optional[int] = 0
    list_default_limit: Optional[int] = 100
    # hard cap on the page size a client may request
    list_max_limit: int = 1_000
    list_view_path: str = "/"

    # for detail page
    detail_schema: Optional[Type[BaseModel]] = None
    detail_deps: Optional[List[Any]] = None
    detail_view_path: str = "/{id}"
    detail_status_code_not_found: int = status.HTTP_404_NOT_FOUND

    # for create view
    create_schema_in: Optional[Type[BaseModel]] = None
    create_schema_out: Optional[Type[BaseModel]] = None
    create_status_code: int = status.HTTP_201_CREATED
    create_status_code_fk_error: int = status.HTTP_422_UNPROCESSABLE_ENTITY
    create_status_code_conflict: int = status.HTTP_409_CONFLICT
    create_view_path: str = "/"
    create_deps: Optional[List[Any]] = None

    # for update view
    update_schema_in: Optional[Type[BaseModel]] = None
    update_schema_out: Optional[Type[BaseModel]] = None
    update_status_code: int = status.HTTP_204_NO_CONTENT
    update_status_code_fk_error: int = status.HTTP_422_UNPROCESSABLE_ENTITY
    update_status_code_conflict: int = status.HTTP_409_CONFLICT
    update_status_code_not_found: int = status.HTTP_404_NOT_FOUND
    update_view_path: str = "/{id}"
    update_deps: Optional[List[Any]] = None

    # for delete view
    delete_view: bool = False
    delete_deps: Optional[List[Any]] = None
    delete_view_path: str = "/{id}"
    delete_status_code: int = status.HTTP_204_NO_CONTENT
    delete_schema_out: Optional[Type[BaseModel]] = None
    delete_status_code_not_found: int = status.HTTP_404_NOT_FOUND

    def __init__(
        self,
        routes: Optional[List[routing.BaseRoute]] = None,
        redirect_slashes: bool = True,
        default: Optional[ASGIApp] = None,
        dependency_overrides_provider: Optional[Any] = None,
        route_class: Type[APIRoute] = APIRoute,
        default_response_class: Optional[Type[Response]] = None,
        on_startup: Optional[Sequence[Callable]] = None,
        on_shutdown: Optional[Sequence[Callable]] = None,
    ) -> None:
        """Instantiate like a normal API router, then add the CRUD routes.

        Which routes are added depends on the class attributes described on
        the class.
        """
        assert self.model is not None
        super().__init__(
            routes=routes,
            redirect_slashes=redirect_slashes,
            default=default,
            dependency_overrides_provider=dependency_overrides_provider,
            route_class=route_class,
            default_response_class=default_response_class,
            on_startup=on_startup,
            on_shutdown=on_shutdown,
        )

        if self.list_schema is not None:
            endpoint = self._generate_list_view()
            self.get(
                self.list_view_path,
                response_model=List[self.list_schema],  # type: ignore
                dependencies=(self.list_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)

        if self.detail_schema is not None:
            endpoint = self._generate_detail_view()
            self.get(
                self.detail_view_path,
                response_model=self.detail_schema,  # type: ignore
                dependencies=(self.detail_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)

        if self.create_schema_in is not None:
            assert self.create_schema_out is not None
            assert self.model_base is not None
            endpoint = self._generate_create_view()
            self.post(
                self.create_view_path,
                response_model=self.create_schema_out,  # type: ignore
                status_code=self.create_status_code,
                dependencies=(self.create_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)

        if self.update_schema_in is not None:
            assert self.model_base is not None
            endpoint = self._generate_update_view()
            self.put(
                self.update_view_path,
                response_model=self.update_schema_out,  # type: ignore
                status_code=self.update_status_code,
                dependencies=(self.update_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)

        if self.delete_view is True:
            endpoint = self._generate_delete_view()
            self.delete(
                self.delete_view_path,
                response_model=self.delete_schema_out,  # type: ignore
                status_code=self.delete_status_code,
                dependencies=(self.delete_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)

    # ------------------------------------------------------------------
    # shared helpers
    # ------------------------------------------------------------------

    def _single_int_pk_column(self) -> Column:
        """Return the model's primary-key column.

        Only a single, non-composite primary key whose Python type is ``int``
        is supported; anything else fails an assertion.
        """
        cols = inspect(self.model).primary_key
        assert len(cols) == 1
        [pk_col] = cols
        assert pk_col.type.python_type == int
        return pk_col

    def _get_by_pk(
        self,
        db: Session,
        pk_col: Column,
        pk_value: int,
        status_code_not_found: int,
    ) -> Any:
        """Fetch the ``model`` row whose ``pk_col`` equals ``pk_value``.

        Raises:
            HTTPException: with ``status_code_not_found`` if no row matches.
        """
        obj = (
            db.query(self.model)
            .filter(getattr(self.model, pk_col.key) == pk_value)
            .scalar()
        )
        if obj is None:
            raise HTTPException(status_code=status_code_not_found)
        return obj

    def _schema_constraints(
        self, schema: Type[BaseModel]
    ) -> Tuple[Set[ForeignKeyConstraint], Set[FrozenSet[str]]]:
        """Find the constraints on ``model`` that ``schema`` fully covers.

        A constraint is returned only when every one of its columns is a
        field of ``schema``, so the incoming object carries all the values
        needed to check it.

        Returns:
            ``(fk_constraints, unique_column_sets)``. The unique column sets
            come from unique indexes and from ``UniqueConstraint``s, which
            is what ``Column(unique=True)`` creates when ``index`` is False.
        """
        schema_cols = set(schema.schema()["properties"])
        table = self.model.__table__  # type: ignore

        fk_constraints = {
            constraint
            for constraint in table.foreign_key_constraints
            # <= (not <): a constraint covering exactly the schema's fields
            # must still be checked
            if {c.key for c in constraint.columns} <= schema_cols
        }

        unique_column_sets = [
            frozenset(c.key for c in ix.columns)
            for ix in table.indexes
            if ix.unique
        ] + [
            frozenset(c.key for c in constraint.columns)
            for constraint in table.constraints
            if isinstance(constraint, UniqueConstraint)
        ]
        unique_indexes = {
            cols
            for cols in unique_column_sets
            # an empty column set (e.g. an expression-only index) would
            # produce an unfiltered query matching every row
            if cols and cols <= schema_cols
        }
        return fk_constraints, unique_indexes

    @property
    def _MODEL_MAP(self) -> Dict[Table, DeclarativeMeta]:
        """Map each table to the ``model_base`` subclass mapped onto it.

        Subclasses without a ``__table__`` (e.g. abstract mixins) are
        skipped.
        """
        return {
            cast(Table, m.__table__): cast(DeclarativeMeta, m)
            for m in all_subclasses(self.model_base)  # type: ignore
            if getattr(m, "__table__", None) is not None
        }

    def _check_fk_constraints(
        self,
        *,
        db: Session,
        model_map: Dict[Table, DeclarativeMeta],
        fk_constraints: Set[ForeignKeyConstraint],
        obj_in: BaseModel,
        status_code_fk_error: int,
    ) -> None:
        """Raise if ``obj_in`` references rows that do not exist.

        A constraint with any ``None`` value on ``obj_in`` is skipped. An FK
        containing a NULL is not enforced under SQL's default (MATCH SIMPLE)
        semantics, and looking it up would always fail.

        Raises:
            HTTPException: with ``status_code_fk_error`` naming the missing
                referenced values.
        """
        for constraint in fk_constraints:
            referred_model = model_map[constraint.referred_table]
            # map each local column key on self.model to the key of the
            # referred_model column its foreign key points at
            constraint_columns_map = {
                col.key: next(
                    col_fk.target_fullname.split(".")[-1]
                    for col_fk in col.foreign_keys
                    if col_fk.constraint.referred_table  # type: ignore
                    == constraint.referred_table
                )
                for col in constraint.columns
            }
            local_values = {
                local_key: getattr(obj_in, local_key)
                for local_key in constraint_columns_map
            }
            if any(value is None for value in local_values.values()):
                continue

            # now, we check a row with those values exists
            num_rows = (
                db.query(referred_model)
                .filter(
                    *(
                        getattr(referred_model, target_key)
                        == local_values[local_key]
                        for local_key, target_key in (
                            constraint_columns_map.items()
                        )
                    )
                )
                .count()
            )
            if num_rows == 0:
                error_row = (
                    f"{target_key}={local_values[local_key]}"
                    for local_key, target_key in constraint_columns_map.items()
                )
                raise HTTPException(
                    status_code=status_code_fk_error,
                    detail=(
                        "I could not find a value of "
                        f"{referred_model.__name__} with values "
                        f"{' '.join(error_row)}"
                    ),
                )

    def _check_unique_indexes(
        self,
        db: Session,
        unique_indexes: Set[FrozenSet[str]],
        obj_in: BaseModel,
        status_code_conflict: int,
        exclude_id: Optional[int] = None,
    ) -> None:
        """Raise if writing ``obj_in`` would violate a unique column set.

        Args:
            db: session used for the lookup.
            unique_indexes: column-key sets that must be unique together.
            obj_in: incoming object whose values are checked.
            status_code_conflict: status code of the raised error.
            exclude_id: primary key of a row to ignore. This is the row being
                updated, which must not conflict with itself.

        A column set containing a ``None`` value is skipped. Databases such as
        PostgreSQL, MySQL and SQLite do not treat NULLs as equal for
        uniqueness.

        Raises:
            HTTPException: with ``status_code_conflict``.
        """
        pk_attr = None
        if exclude_id is not None:
            pk_attr = getattr(self.model, self._single_int_pk_column().key)

        for ix in unique_indexes:
            values = {k: getattr(obj_in, k) for k in ix}
            if any(value is None for value in values.values()):
                continue
            criteria = [getattr(self.model, k) == v for k, v in values.items()]
            if pk_attr is not None:
                criteria.append(pk_attr != exclude_id)
            num_rows = db.query(self.model).filter(*criteria).count()
            if num_rows != 0:
                error_row = (f"{k}={v}" for k, v in values.items())
                raise HTTPException(
                    status_code=status_code_conflict,
                    detail=(
                        f"A row already exists in {self.model.__name__} "
                        f"with values {' '.join(error_row)}"
                    ),
                )

    # ------------------------------------------------------------------
    # view generators
    # ------------------------------------------------------------------

    def _generate_list_view(self) -> Callable:
        """Create a generic list view for ``model``.

        To sort, set ``list_sort = [SomeModel.some_property.asc()]``.

        Pagination is applied only when both ``list_default_offset`` and
        ``list_default_limit`` are set. Negative offset/limit values are
        clamped to 0, and the limit is capped at ``list_max_limit``.
        """

        def list_view(
            db: Session = self.db_dep,
            offset: Optional[int] = self.list_default_offset,
            limit: Optional[int] = self.list_default_limit,
        ) -> Any:
            query = db.query(self.model)
            if self.list_sort is not None:
                query = query.order_by(*self.list_sort)
            if (
                self.list_default_offset is not None
                and self.list_default_limit is not None
            ):
                # offset/limit come from the client: never negative, and
                # never more than list_max_limit rows in one page
                safe_offset = max(offset or 0, 0)
                safe_limit = (
                    self.list_max_limit
                    if limit is None
                    else min(max(limit, 0), self.list_max_limit)
                )
                query = query.offset(safe_offset).limit(safe_limit)
            return query.all()

        list_view._description = (  # type: ignore
            f"🗃️ List {self.model.__name__} sorted by "
            f"{', '.join(str(x) for x in self.list_sort or [])}"
        )
        return list_view

    def _generate_detail_view(self) -> Callable:
        """Create a generic detail view for ``model``.

        At the moment this supports _only_ a model with a single,
        non-composite integer PK, exposed as the ``id`` path parameter.
        A missing row yields ``detail_status_code_not_found``.
        """
        pk_col = self._single_int_pk_column()

        def detail_view(id: int, db: Session = self.db_dep,) -> Any:
            return self._get_by_pk(
                db, pk_col, id, self.detail_status_code_not_found
            )

        detail_view._description = (  # type: ignore
            f"""📁 Show {self.model.__name__} identified by {pk_col.key}"""
        )
        return detail_view

    def _generate_create_view(self) -> Callable:
        """Generate a generic create view for ``model``.

        Features:
        - reject payloads referencing missing FK rows
          (``create_status_code_fk_error``)
        - reject payloads conflicting on unique columns
          (``create_status_code_conflict``)
        - create and return the resource (``create_status_code``, 201 by
          default)
        """
        # only constraints fully covered by the input schema can be checked
        fk_constraints, unique_indexes = self._schema_constraints(
            cast(Type[BaseModel], self.create_schema_in)
        )
        # we'll need this to look up models based on tables
        model_map = self._MODEL_MAP

        def create_view(
            *,
            obj_in: self.create_schema_in,  # type: ignore
            db: Session = self.db_dep,
        ) -> Any:
            self._check_fk_constraints(
                db=db,
                model_map=model_map,
                fk_constraints=fk_constraints,
                obj_in=obj_in,
                status_code_fk_error=self.create_status_code_fk_error,
            )
            self._check_unique_indexes(
                db=db,
                unique_indexes=unique_indexes,
                obj_in=obj_in,
                status_code_conflict=self.create_status_code_conflict,
            )

            # make the insert
            instance = self.model()
            for k, v in obj_in.dict().items():
                setattr(instance, k, v)
            db.add(instance)
            db.commit()
            db.refresh(instance)
            return instance

        create_view._description = (  # type: ignore
            f"""💾 Create new {self.model.__name__}"""
        )
        return create_view

    def _generate_update_view(self) -> Callable:
        """Generate a generic update (PUT) view for ``model``.

        Every field of ``update_schema_in`` is written to the row. The row is
        returned unless ``update_status_code`` is a no-body status (204/205).
        Those statuses may not be combined with ``update_schema_out``.
        """
        assert not (
            self.update_schema_out is not None
            and self.update_status_code in self._HTTP_2XX_NO_RETURN_CODES
        )
        pk_col = self._single_int_pk_column()

        # only constraints fully covered by the input schema can be checked
        fk_constraints, unique_indexes = self._schema_constraints(
            cast(Type[BaseModel], self.update_schema_in)
        )
        # we'll need this to look up models based on tables
        model_map = self._MODEL_MAP

        def update_view(
            *,
            id: int,
            db: Session = self.db_dep,
            obj_in: self.update_schema_in,  # type: ignore
        ) -> Any:
            obj = self._get_by_pk(
                db, pk_col, id, self.update_status_code_not_found
            )
            self._check_fk_constraints(
                db=db,
                model_map=model_map,
                fk_constraints=fk_constraints,
                obj_in=obj_in,
                status_code_fk_error=self.update_status_code_fk_error,
            )
            # the row being updated must not conflict with itself
            self._check_unique_indexes(
                db=db,
                unique_indexes=unique_indexes,
                obj_in=obj_in,
                status_code_conflict=self.update_status_code_conflict,
                exclude_id=id,
            )

            for k, v in obj_in.dict().items():
                setattr(obj, k, v)
            db.add(obj)
            db.commit()
            db.refresh(obj)
            return (
                obj
                if self.update_status_code
                not in self._HTTP_2XX_NO_RETURN_CODES
                else None
            )

        update_view._description = (  # type: ignore
            f"""📝 Update {self.model.__name__} identified by {pk_col.key}"""
        )
        return update_view

    def _generate_delete_view(self) -> Callable:
        """Generate a generic delete view for ``model``.

        ``delete_schema_out`` must be set exactly when ``delete_status_code``
        returns a body. If it is set, the deleted object is returned.
        """
        assert not (
            self.delete_schema_out is not None
            and self.delete_status_code in self._HTTP_2XX_NO_RETURN_CODES
        )
        assert not (
            self.delete_schema_out is None
            and self.delete_status_code not in self._HTTP_2XX_NO_RETURN_CODES
        )
        pk_col = self._single_int_pk_column()

        def delete_view(*, id: int, db: Session = self.db_dep,) -> Any:
            obj = self._get_by_pk(
                db, pk_col, id, self.delete_status_code_not_found
            )
            db.delete(obj)
            db.commit()
            return obj if self.delete_schema_out is not None else None

        delete_view._description = (  # type: ignore
            f"""❌ Delete {self.model.__name__} identified by {pk_col.key}"""
        )
        return delete_view