import asyncio
import decimal
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime
from functools import lru_cache
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple

import pydantic
import typer
from fastapi import BackgroundTasks, Depends, FastAPI
from pydantic import BaseModel
from sqlalchemy import (
    ARRAY,
    DECIMAL,
    TEXT,
    TIMESTAMP,
    BigInteger,
    Boolean,
    CheckConstraint,
    Column,
    Date,
    DateTime,
    Enum,
    Float,
    ForeignKey,
    Index,
    Integer,
    Numeric,
    PrimaryKeyConstraint,
    String,
    Table,
    Text,
    UniqueConstraint,
    and_,
    create_engine,
    engine,
    event,
    func,
    or_,
    select,
)
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import Session, registry, sessionmaker
from starlette.middleware.cors import CORSMiddleware

# NOTE: create_async_engine needs an async DBAPI driver. To use SQLite
# instead, point this at an async SQLite URL (the old sync URL was
# "sqlite:///test10.db", which create_async_engine cannot use).
SQLALCHEMY_DATABASE_URL = "postgresql+asyncpg://postgres@localhost:5432/sss"

# Registry that maps the dataclasses below and owns their MetaData.
mapper_registry = registry()


@lru_cache()
def get_engine() -> AsyncEngine:
    """Return the process-wide async engine, created lazily on first call.

    Cached with ``lru_cache`` so every caller shares one engine and pool.
    """
    return create_async_engine(
        SQLALCHEMY_DATABASE_URL,
        # Check pooled connections for liveness before handing them out.
        pool_pre_ping=True,
    )


@lru_cache()
def get_sessionmaker() -> sessionmaker:
    """Return the shared ``AsyncSession`` factory bound to ``get_engine()``.

    Cached so the factory is configured once rather than on every request.
    """
    return sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=get_engine(),
        class_=AsyncSession,
        expire_on_commit=False,
    )


@asynccontextmanager
async def get_db() -> AsyncIterator[AsyncSession]:
    """Open an ``AsyncSession`` as a single unit of work.

    Usage::

        async with get_db() as session:
            ...

    Commits when the ``async with`` body completes normally. On any
    exception it rolls back and re-raises. The session is always closed.
    """
    # Explicit type because sessionmaker.__call__ stub is Any.
    session: AsyncSession = get_sessionmaker()()
    try:
        yield session
        # Commit inside the try so that a failing commit is also rolled back.
        await session.commit()
    except BaseException:
        # BaseException (not Exception) so task cancellation also rolls back;
        # this matches the previous bare ``except:``.
        await session.rollback()
        raise
    finally:
        await session.close()


async def get_session() -> AsyncIterator[AsyncSession]:
    """FastAPI dependency that yields a session managed by ``get_db``.

    Declared as an async generator so FastAPI enters and exits it itself
    ("dependency with yield") and injects the ``AsyncSession`` directly.
    """
    async with get_db() as session:
        yield session


@dataclass
class SurrogatePK:
    """Mixin adding an auto-generated integer primary key column ``id``."""

    __sa_dataclass_metadata_key__ = "sa"

    id: int = field(
        init=False,
        metadata={"sa": Column(Integer, primary_key=True)},
    )


@dataclass
class TimeStampMixin:
    """Mixin adding ``created_at`` / ``updated_at`` timestamp columns.

    Both default to ``datetime.utcnow`` (naive datetimes holding UTC).
    ``updated_at`` is also set through ``onupdate`` when a row is updated.
    """

    __sa_dataclass_metadata_key__ = "sa"

    # NOTE(review): datetime.utcnow is deprecated since Python 3.12; switching
    # to aware datetimes would also require DateTime(timezone=True).
    created_at: datetime = field(
        init=False, metadata={"sa": Column(DateTime, default=datetime.utcnow)}
    )
    updated_at: datetime = field(
        init=False,
        metadata={
            "sa": Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
        },
    )


@dataclass
class User(SurrogatePK, TimeStampMixin):
    """Unmapped dataclass describing the ``user`` table.

    Mapped through ``UserSQL`` and reused to build the Pydantic response
    model ``UserPyd``.
    """

    __tablename__ = "user"

    # NOTE(review): these default to None in Python but the columns are
    # NOT NULL, so both must be set before flushing a new row.
    identity: Optional[str] = field(
        default=None, metadata={"sa": Column(String(length=255), nullable=False)}
    )
    row_status: Optional[str] = field(
        default=None, metadata={"sa": Column(String(length=20), nullable=False)}
    )

    @declared_attr
    def __table_args__(cls):
        # Partial unique index (PostgreSQL WHERE clause): the
        # (identity, row_status) pair must be unique among rows whose
        # row_status is 'active'.
        return (
            Index(
                "index_on_identity_v3_user_identity",
                "identity",
                "row_status",
                unique=True,
                postgresql_where=cls.row_status == "active",
            ),
        )


@mapper_registry.mapped
@dataclass
class UserSQL(User):
    """ORM-mapped version of ``User`` (table ``user``)."""


# Pydantic dataclass built from ``User``; used as the API response schema.
UserPyd = pydantic.dataclasses.dataclass(User)

# Web application and command-line interface.
app = FastAPI()
cli = typer.Typer()

# Declarative base sharing ``mapper_registry``'s MetaData, so
# ``Base.metadata`` holds the tables mapped above.
Base = mapper_registry.generate_base()


async def init_models() -> None:
    """Drop and recreate every table in ``Base.metadata``.

    Destructive: all existing rows in those tables are lost.
    """
    async with get_engine().begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
        await conn.run_sync(Base.metadata.create_all)


async def _init_models_then_dispose() -> None:
    """Run ``init_models``, then release the engine's pooled connections.

    ``asyncio.run`` closes its event loop on return. Disposing first avoids
    leaving pooled connections attached to a closed loop.
    """
    try:
        await init_models()
    finally:
        await get_engine().dispose()


@cli.command()
def db_init_models(
    name: Optional[str] = typer.Argument(
        None, help="Unused; still accepted for backward compatibility."
    ),
) -> None:
    """Drop and recreate all database tables."""
    asyncio.run(_init_models_then_dispose())
    print("Done")


@app.get("/", response_model=List[UserPyd])
async def foo(context_session: AsyncSession = Depends(get_session)):
    """Return every row of the ``user`` table, echoing each one to stdout."""
    result = await context_session.execute(select(UserSQL))
    # Iterating ``result`` directly would yield one-element Row tuples, not
    # UserSQL objects. A Result can also be consumed only once, so unwrap it
    # with .scalars() and materialize it before printing and returning.
    users = result.scalars().all()
    for user in users:
        print(
            f"{user.identity}\t {user.row_status}\t {user.created_at}\t {user.updated_at}"
        )
    return users


if __name__ == "__main__":
    cli()