"""Generic CRUD helpers intended to be mixed into SQLAlchemy declarative models."""

from typing import Any, Generic, TypeVar

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Query, Session

from app.db.base import Base

ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)


def _dump(model: BaseModel, **kwargs: Any) -> dict[str, Any]:
    """Serialize a pydantic model to a plain dict on pydantic v1 or v2.

    Pydantic v2 deprecates ``.dict()`` in favour of ``.model_dump()``; both
    accept the ``exclude_unset`` keyword used by ``CRUDMixin.update``.
    """
    if hasattr(model, "model_dump"):
        return model.model_dump(**kwargs)
    return model.dict(**kwargs)


def _commit(db: Session) -> None:
    """Commit ``db``; on ``IntegrityError`` roll back, then re-raise the error."""
    try:
        db.commit()
    except IntegrityError:
        # A failed commit leaves the session's transaction inactive; roll back
        # so the caller can keep using the same session after handling it.
        db.rollback()
        raise


class CRUDMixin(Generic[CreateSchemaType, UpdateSchemaType]):
    """Class-level create/read/update/delete operations.

    ``cls`` is used both as the query entity (``db.query(cls)``) and as the
    model constructor (``cls(**fields)``), so this mixin is expected to be
    combined with a declarative model class. Every write method commits
    immediately; on ``IntegrityError`` the session is rolled back and the
    error is re-raised.
    """

    @classmethod
    def query(cls, db: Session) -> Query:
        """Return a query selecting instances of this model from ``db``."""
        return db.query(cls)

    @classmethod
    def get(cls, db: Session, **kwargs: Any) -> ModelType | None:
        """Return the first instance matching ``filter_by(**kwargs)``, or ``None``."""
        return cls.query(db=db).filter_by(**kwargs).first()

    @classmethod
    def get_all(cls, db: Session, **kwargs: Any) -> list[ModelType]:
        """Return every instance matching ``filter_by(**kwargs)`` (possibly empty)."""
        return cls.query(db=db).filter_by(**kwargs).all()

    @classmethod
    def create(cls, db: Session, *, data: CreateSchemaType) -> ModelType:
        """Insert a new row built from ``data`` and return the persisted object.

        Args:
            db: Session used for the insert; it is committed.
            data: Pydantic model whose fields become constructor keywords.

        Returns:
            The new instance, refreshed so server-generated values (such as
            an autoincrement primary key) are loaded.

        Raises:
            IntegrityError: If the insert violates a constraint; the session
                is rolled back before the error propagates.
        """
        db_obj: ModelType = cls(**_dump(data))
        db.add(db_obj)
        _commit(db)
        db.refresh(db_obj)
        return db_obj

    @classmethod
    def update(
        cls,
        db: Session,
        *,
        db_obj: ModelType,
        data: UpdateSchemaType | dict[str, Any],
    ) -> ModelType:
        """Apply ``data`` to ``db_obj``, commit, and return the refreshed object.

        Args:
            db: Session used for the update; it is committed.
            db_obj: Instance to modify in place.
            data: Either a plain dict of new values, or a pydantic model of
                which only explicitly set fields (``exclude_unset=True``) are
                applied. Keys that are not mapped attributes are ignored.

        Returns:
            ``db_obj`` after commit and refresh.

        Raises:
            IntegrityError: If the update violates a constraint; the session
                is rolled back before the error propagates.
        """
        updates: dict[str, Any] = (
            data if isinstance(data, dict) else _dump(data, exclude_unset=True)
        )
        # Use the mapper's attribute names rather than the instance __dict__:
        # attributes expired by a previous commit (or deferred columns) are
        # absent from __dict__ and would otherwise be silently skipped.
        for field in sa_inspect(db_obj).mapper.attrs.keys():
            if field in updates:
                setattr(db_obj, field, updates[field])
        db.add(db_obj)
        _commit(db)
        db.refresh(db_obj)
        return db_obj

    @classmethod
    def delete(cls, db: Session, *, id: Any) -> ModelType | None:
        """Delete the instance whose ``id`` attribute equals ``id``.

        Args:
            db: Session used for the delete; it is committed.
            id: Value matched against the model's ``id`` attribute.

        Returns:
            The deleted instance, or ``None`` if no matching row exists (in
            which case the session is left untouched).

        Raises:
            IntegrityError: If the delete violates a constraint (e.g. a
                referencing foreign key); the session is rolled back before
                the error propagates.
        """
        db_obj = cls.get(db=db, id=id)
        if db_obj is None:
            return None
        db.delete(db_obj)
        _commit(db)
        return db_obj