Created
July 1, 2022 18:10
-
-
Save s-bose/259c6d51cc46f9c14879d18e93d12d25 to your computer and use it in GitHub Desktop.
A mixin class that adds CRUD functionality to SQLAlchemy ORM model classes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Generic, TypeVar

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Query, Session

from app.db.base import Base
# Pydantic schema types accepted by CRUDMixin.create (CreateSchemaType)
# and CRUDMixin.update (UpdateSchemaType).
CreateSchemaType = TypeVar('CreateSchemaType', bound=BaseModel)
UpdateSchemaType = TypeVar('UpdateSchemaType', bound=BaseModel)

# ORM model type, bound to the project's ``Base`` class.
ModelType = TypeVar('ModelType', bound=Base)
class CRUDMixin(Generic[CreateSchemaType, UpdateSchemaType]):
    """Mixin that adds basic CRUD classmethods to an SQLAlchemy ORM model.

    The class this is mixed into is used directly as the query entity
    (``db.query(cls)``) and as the constructor for new rows
    (``cls(**data)``), so it must be a mapped model class.

    Every write commits the session. If the commit fails, the session is
    rolled back and the original exception is re-raised unchanged.
    """

    @classmethod
    def query(cls, db: Session) -> Query:
        """Return a ``Query`` whose entity is this model class."""
        return db.query(cls)

    @classmethod
    def get(cls, db: Session, **kwargs) -> ModelType | None:
        """Return the first row matching ``filter_by(**kwargs)``, or ``None``."""
        return cls.query(db=db).filter_by(**kwargs).first()

    @classmethod
    def get_all(cls, db: Session, **kwargs) -> list[ModelType]:
        """Return every row matching ``filter_by(**kwargs)`` (possibly empty)."""
        return cls.query(db=db).filter_by(**kwargs).all()

    @classmethod
    def create(cls, db: Session, *, data: CreateSchemaType) -> ModelType:
        """Insert a new row built from ``data`` and return it.

        Args:
            db: Session used for the insert; committed on success.
            data: Pydantic schema whose ``.dict()`` supplies the model's
                constructor keyword arguments.

        Returns:
            The persisted instance, refreshed from the database so values
            generated on insert (e.g. an autoincrement primary key) are loaded.

        Raises:
            Whatever the commit raises (e.g. ``IntegrityError``), after the
            session has been rolled back.
        """
        db_obj: ModelType = cls(**data.dict())
        db.add(db_obj)
        cls._commit_or_rollback(db)
        # Reload so the returned object carries DB-generated values.
        db.refresh(db_obj)
        return db_obj

    @classmethod
    def update(
        cls,
        db: Session,
        *,
        db_obj: ModelType,
        data: UpdateSchemaType | dict[str, Any],
    ) -> ModelType:
        """Apply ``data`` to the mapped columns of ``db_obj`` and commit.

        Args:
            db: Session used for the update; committed on success.
            db_obj: Mapped instance to modify in place.
            data: Either a plain dict of new values, or a Pydantic schema. For
                a schema, only explicitly set fields are used
                (``exclude_unset=True``), so unset fields do not overwrite
                existing values.

        Keys that do not name a column attribute of ``db_obj``'s mapper are
        ignored.

        Returns:
            ``db_obj``, refreshed from the database after the commit.

        Raises:
            Whatever the commit raises (e.g. ``IntegrityError``), after the
            session has been rolled back.
        """
        if isinstance(data, dict):
            data_dict: dict[str, Any] = data
        else:
            data_dict = data.dict(exclude_unset=True)

        # Take the column names from the mapper, not from the instance's
        # currently-loaded state: expired or unloaded attributes are missing
        # from the instance's state, which would silently skip those updates.
        column_keys = {
            attr.key for attr in sa_inspect(db_obj).mapper.column_attrs
        }
        for field, value in data_dict.items():
            if field in column_keys:
                setattr(db_obj, field, value)

        db.add(db_obj)
        cls._commit_or_rollback(db)
        db.refresh(db_obj)
        return db_obj

    @classmethod
    def delete(cls, db: Session, *, id: Any) -> ModelType | None:  # noqa: A002
        """Delete the row whose ``id`` column equals ``id`` and return it.

        ``id`` shadows the builtin but is kept because callers pass it by
        keyword.

        Returns:
            The deleted instance, or ``None`` when no row matches. In that
            case the session is not touched, rather than passing ``None`` to
            ``Session.delete``.

        Raises:
            Whatever the commit raises (e.g. ``IntegrityError``), after the
            session has been rolled back.
        """
        db_obj = cls.get(db=db, id=id)
        if db_obj is None:
            return None
        db.delete(db_obj)
        cls._commit_or_rollback(db)
        return db_obj

    @classmethod
    def _commit_or_rollback(cls, db: Session) -> None:
        """Commit ``db``; on any failure roll back and re-raise.

        The session is rolled back on every error, not only
        ``IntegrityError``, so it stays usable after a failed flush. A bare
        ``raise`` propagates the original exception and traceback unchanged.
        """
        try:
            db.commit()
        except Exception:
            db.rollback()
            raise
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment