from __future__ import annotations import decimal import logging import os import signal import sys import threading import time from contextlib import contextmanager from dataclasses import ( dataclass, field, ) from datetime import datetime from functools import lru_cache from typing import ( Any, Dict, Iterable, Generator, List, Optional, Tuple, ) import pydantic import typer import uvicorn as uvicorn import yaml from fastapi import ( BackgroundTasks, Depends, FastAPI, HTTPException, Response, status, ) from fastapi.responses import PlainTextResponse from gunicorn.glogging import Logger from loguru import logger from pydantic import BaseModel from pydantic_sqlalchemy import sqlalchemy_to_pydantic from sqlalchemy import ( ARRAY, DECIMAL, TEXT, TIMESTAMP, BigInteger, Boolean, CheckConstraint, Column, Date, DateTime, Enum, Float, ForeignKey, Index, Integer, Numeric, PrimaryKeyConstraint, String, Table, Text, UniqueConstraint, and_, create_engine, engine, event, func, or_, ) from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import ( Session, registry, relationship, sessionmaker, ) from sqlalchemy.schema import Index from starlette.middleware.cors import CORSMiddleware from uvicorn.workers import UvicornWorker import settings # from cmath import log class ReloaderThread(threading.Thread): def __init__(self, worker: UvicornWorker, sleep_interval: float = 1.0): super().__init__() self.setDaemon(True) self._worker = worker self._interval = sleep_interval def run(self) -> None: while True: if not self._worker.alive: os.kill(os.getpid(), signal.SIGINT) time.sleep(self._interval) class RestartableUvicornWorker(UvicornWorker): CONFIG_KWARGS = { "loop": "uvloop", "http": "httptools", # "log_config": yaml.safe_load(open(os.path.join(os.path.dirname(__file__), "logging.yaml"), "r") } def __init__(self, *args: List[Any], **kwargs: Dict[str, Any]): super().__init__(*args, **kwargs) self._reloader_thread = ReloaderThread(self) def run(self) -> None: if self.cfg.reload: self._reloader_thread.start() super().run() class InterceptHandler(logging.Handler): """ Default handler from examples in loguru documentaion. See https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging """ def emit(self, record: logging.LogRecord): # Get corresponding Loguru level if it exists try: level = logger.level(record.levelname).name except ValueError: level = record.levelno # Find caller from where originated the logged message frame, depth = logging.currentframe(), 1 # while frame.f_code.co_filename == logging.__file__: # frame = frame.f_back # depth += 1 logger.opt(depth=depth, exception=record.exc_info).log( level, record.getMessage() ) class GunicornLogger(Logger): def setup(self, cfg) -> None: handler = InterceptHandler() # handler = logging.StreamHandler(sys.stdout) handler.setFormatter( logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s") ) # Add log handler to logger and set log level self.error_log.addHandler(handler) self.error_log.setLevel(settings.LOG_LEVEL) self.access_log.addHandler(handler) self.access_log.setLevel(settings.LOG_LEVEL) # Configure logger before gunicorn starts logging logger.configure(handlers=[{"sink": sys.stdout, "level": settings.LOG_LEVEL}]) @lru_cache() def get_engine() -> engine.Engine: return create_engine( settings.SQLALCHEMY_DATABASE_URL, # connect_args={"check_same_thread": False}, echo=True, pool_pre_ping=True, ) def get_db() -> Generator[Session, None, None]: # Explicit type because sessionmaker.__call__ stub is Any session: Session = sessionmaker( autocommit=False, autoflush=False,expire_on_commit=False, bind=get_engine() )() # session = SessionLocal() try: yield session session.commit() except: session.rollback() raise finally: session.close() mapper_registry = registry() @dataclass class SurrogatePK: __sa_dataclass_metadata_key__ = "sa" id: int = field( init=False, default=None, metadata={"sa": Column(Integer, primary_key=True, autoincrement=True)}, ) @dataclass class TimeStampMixin: __sa_dataclass_metadata_key__ = "sa" created_at: datetime = field( default_factory=datetime.now, metadata={"sa": Column(DateTime, default=datetime.now)}, ) updated_at: datetime = field( default_factory=datetime.now, metadata={ "sa": Column(DateTime, default=datetime.now, onupdate=datetime.utcnow) }, ) @mapper_registry.mapped @dataclass class User(SurrogatePK, TimeStampMixin): __tablename__ = "user" __sa_dataclass_metadata_key__ = "sa" title: str = field(default=None, metadata={"sa": Column(String(50))}) description: str = field(default=None, metadata={"sa": Column(String(50))}) UserPyd = sqlalchemy_to_pydantic(User) mapper_registry.metadata.create_all(bind=get_engine()) # Create the app, database, and stocks table app = FastAPI(debug=True) @app.exception_handler(Exception) async def validation_exception_handler(request, exc): logger.debug(str(exc)) return PlainTextResponse("Something went wrong", status_code=500) cli = typer.Typer() @cli.command() def db_init_models(): Base = mapper_registry.generate_base() Base.metadata.drop_all(bind=get_engine()) Base.metadata.create_all(bind=get_engine()) print("Done") @cli.command() def nothing(name: str): print("Done") @app.get("/items", response_model=List[UserPyd]) def read_items(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): items = db.query(User).offset(skip).limit(limit).all() return items @app.get("/users/", response_model=UserPyd, status_code=status.HTTP_201_CREATED) def create_user(email: str = None, db: Session = Depends(get_db)): u = User(title="sss") db.add(u) db.commit() # return {"data":new_post} return u if __name__ == "__main__": cli()