FastAPI with Python 3.10 dataclasses - used to create both SQLAlchemy and Pydantic models from a single definition, and to set up SQLAlchemy the right way (without deadlocks or other problems). It also takes care of unified logging when running under gunicorn, and of running in restartable (auto-reload) mode.
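A minimal sketch of the core idea (illustrative names such as Item, ItemSQL and ItemPyd are not from the gist): one dataclass carries the SQLAlchemy column definitions in its field metadata, a mapped subclass registers it with the ORM, and pydantic.dataclasses.dataclass() turns the same definition into a Pydantic model usable as a FastAPI response_model.

from dataclasses import dataclass, field
from typing import Optional

import pydantic
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import registry

mapper_registry = registry()


@dataclass
class Item:
    __tablename__ = "item"
    __sa_dataclass_metadata_key__ = "sa"
    # column definitions live in the dataclass field metadata under the "sa" key
    id: int = field(init=False, metadata={"sa": Column(Integer, primary_key=True)})
    name: Optional[str] = field(default=None, metadata={"sa": Column(String(50))})


# SQLAlchemy view: a mapped subclass registered against the ORM registry
@mapper_registry.mapped
@dataclass
class ItemSQL(Item):
    pass


# Pydantic view of the same definition, usable as a FastAPI response_model
ItemPyd = pydantic.dataclasses.dataclass(Item)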

cmdline

poetry run gunicorn testpg:app -b 0.0.0.0:8080 --preload --reload --reload-engine inotify -w 10 -k uvicorn.workers.UvicornWorker --log-level debug --access-logfile - --error-logfile - --access-logformat "SSSS - %(h)s %(l)s %(u)s %(t)s \"%(r)s\" %(s)s %(b)s \"%(f)s\" \"%(a)s\""

How to quickly run postgres (using docker)

docker run --network="host" -it --rm --name some-postgres -e POSTGRES_PASSWORD=mysecretpassword -e PGDATA=/var/lib/postgresql/data/pgdata -v /tmp/pgdata2:/var/lib/postgresql/data -e POSTGRES_USER=test postgres

This command will quickly start Postgres on port 5432 and create a database named test with user test and password mysecretpassword.

If you want to connect using psql:

docker run --network="host" -it --rm postgres psql postgresql://test@localhost:5432/test

Note that the code below points SQLALCHEMY_DATABASE_URL at a database named sss owned by the postgres user, so either create that database or adjust the URL to match the container started above.

from datetime import datetime
from fastapi import BackgroundTasks, Depends, FastAPI
from pydantic import BaseModel
from sqlalchemy import (
    Column,
    create_engine,
    DateTime,
    TIMESTAMP,
    Boolean,
    Numeric,
    Integer,
    String,
    engine,
    Table,
    ForeignKey,
    ARRAY,
)
from sqlalchemy import (
    DECIMAL,
    TEXT,
    TIMESTAMP,
    BigInteger,
    Boolean,
    CheckConstraint,
    Column,
    Date,
    Enum,
    Float,
    ForeignKey,
    Index,
    Integer,
    Numeric,
    PrimaryKeyConstraint,
    String,
    Text,
    UniqueConstraint,
    and_,
    create_engine,
    event,
    func,
    or_,
)
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy import select
from sqlalchemy.ext.declarative import declared_attr
from starlette.middleware.cors import CORSMiddleware
import decimal
from sqlalchemy.schema import Index
from typing import Optional, Dict, List, Any, Tuple
from contextlib import asynccontextmanager
from functools import lru_cache
# from async_lru import alru_cache as async_lru_cache
from typing import List
from typing import Optional
from dataclasses import dataclass
from dataclasses import field, dataclass
from sqlalchemy.orm import registry
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
from sqlalchemy.ext.asyncio import create_async_engine
import pydantic
import asyncio
import typer
# Standard for SQLite
# SQLALCHEMY_DATABASE_URL = "sqlite:///test10.db"
SQLALCHEMY_DATABASE_URL = "postgresql+asyncpg://postgres@localhost:5432/sss"
mapper_registry = registry()

@lru_cache()
def get_engine() -> AsyncEngine:
    # lru_cache makes this a lazily-created, per-process singleton engine
    return create_async_engine(
        SQLALCHEMY_DATABASE_URL,
        # connect_args={"check_same_thread": False},
        pool_pre_ping=True,
    )

@asynccontextmanager
async def get_db() -> AsyncSession:
    # Explicit type because sessionmaker.__call__ stub is Any
    # e = await get_engine()
    session: AsyncSession = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=get_engine(),
        class_=AsyncSession,
        expire_on_commit=False,
    )()
    try:
        yield session
        await session.commit()
    except:
        await session.rollback()
        raise
    finally:
        await session.close()

@dataclass
class SurrogatePK:
    __sa_dataclass_metadata_key__ = "sa"
    id: int = field(
        init=False,
        metadata={"sa": Column(Integer, primary_key=True)},
    )


@dataclass
class TimeStampMixin:
    __sa_dataclass_metadata_key__ = "sa"
    created_at: datetime = field(
        init=False, metadata={"sa": Column(DateTime, default=datetime.utcnow)}
    )
    updated_at: datetime = field(
        init=False,
        metadata={
            "sa": Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
        },
    )

# @mapper_registry.mapped
@dataclass
class User(SurrogatePK, TimeStampMixin):
    __tablename__ = "user"
    identity: Optional[str] = field(
        default=None, metadata={"sa": Column(String(length=255), nullable=False)}
    )
    row_status: Optional[str] = field(
        default=None, metadata={"sa": Column(String(length=20), nullable=False)}
    )

    @declared_attr
    def __table_args__(cls):
        return (
            Index(
                "index_on_identity_v3_user_identity",
                "identity",
                "row_status",
                unique=True,
                postgresql_where=cls.row_status == "active",
            ),
        )

@mapper_registry.mapped
@dataclass
class UserSQL(User):
    pass


# Pydantic dataclass built from the same definition, used below as the response_model
UserPyd = pydantic.dataclasses.dataclass(User)

# Create the FastAPI app, the Typer CLI, and the declarative base
app = FastAPI()
cli = typer.Typer()
Base = mapper_registry.generate_base()

async def init_models():
    # e = await get_engine()
    async with get_engine().begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
        await conn.run_sync(Base.metadata.create_all)


@cli.command()
def db_init_models(name: str):
    asyncio.run(init_models())
    print("Done")
# init_models()

@app.get("/", response_model=List[UserPyd])
async def foo(context_session: AsyncSession = Depends(get_db)):
    async with context_session as db:
        # Query the user table, print the rows, and return them
        result = await db.execute(select(UserSQL))
        users = result.scalars().all()
        for d in users:
            print(
                f"""{d.identity}\t
                {d.row_status}\t
                {d.created_at}\t
                {d.updated_at}"""
            )
        return users

if __name__ == "__main__":
    cli()
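For completeness, a minimal sketch (not part of the gist, untested) of how a row could be inserted through the same async session helper; the identity and row_status values are placeholders:

async def add_user() -> None:
    # get_db() commits on a clean exit and rolls back on error (see above)
    async with get_db() as db:
        db.add(UserSQL(identity="alice", row_status="active"))

# asyncio.run(add_user())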
[tool.poetry]
name = "api"
version = "0.1.0"
description = ""
authors = ["sandeep srinivasa <[email protected]>"]
[tool.poetry.dependencies]
python = "^3.8"
pydantic = {extras = ["email"], version = "^1.8.1"}
fastapi = "^0.63.0"
uvicorn = {extras = ["standard"], version = "^0.13.4"}
gunicorn = "^20.0.4"
msgpack-asgi = "^1.0.0"
inotify = "^0.2.10"
hashids = "^1.3.1"
GeoAlchemy2 = "^0.8.4"
redis = "^3.5.3"
boto3 = "^1.17.29"
pendulum = "^2.1.2"
fuzzywuzzy = "^0.18.0"
pandas = "^1.2.3"
python-Levenshtein = "^0.12.2"
SQLAlchemy = "^1.4.2"
psycopg2-binary = "^2.8.6"
asyncpg = "^0.22.0"
typer = "^0.3.2"
[tool.poetry.dev-dependencies]
black = {version = "^20.8b1", allow-prereleases = true}
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
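The same pattern with the synchronous SQLAlchemy engine and a plain Session (no asyncio):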
from datetime import datetime
from fastapi import BackgroundTasks, Depends, FastAPI
from pydantic import BaseModel
from sqlalchemy import (
    Column,
    create_engine,
    DateTime,
    TIMESTAMP,
    Boolean,
    Numeric,
    Integer,
    String,
    engine,
    Table,
    ForeignKey,
    ARRAY,
)
from sqlalchemy import (
    DECIMAL,
    TEXT,
    TIMESTAMP,
    BigInteger,
    Boolean,
    CheckConstraint,
    Column,
    Date,
    Enum,
    Float,
    ForeignKey,
    Index,
    Integer,
    Numeric,
    PrimaryKeyConstraint,
    String,
    Text,
    UniqueConstraint,
    and_,
    create_engine,
    event,
    func,
    or_,
)
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.ext.declarative import declared_attr
from starlette.middleware.cors import CORSMiddleware
from sqlalchemy.orm import relationship
import decimal
from sqlalchemy.schema import Index
from typing import Optional, Dict, List, Any, Tuple
from contextlib import contextmanager
from functools import lru_cache
from typing import List
from typing import Optional
from dataclasses import dataclass
from dataclasses import field, dataclass
from sqlalchemy.orm import registry
import pydantic
# Standard for SQLite
# SQLALCHEMY_DATABASE_URL = "sqlite:///test10.db"
SQLALCHEMY_DATABASE_URL = "postgresql://postgres@localhost:5432/sss"
mapper_registry = registry()

@lru_cache()
def get_engine() -> engine.Engine:
    return create_engine(
        SQLALCHEMY_DATABASE_URL,
        # connect_args={"check_same_thread": False},
        pool_pre_ping=True,
    )

@contextmanager
def get_db():
    # Explicit type because sessionmaker.__call__ stub is Any
    session: Session = sessionmaker(
        autocommit=False, autoflush=False, bind=get_engine()
    )()
    try:
        yield session
        session.commit()
    except:
        session.rollback()
        raise
    finally:
        session.close()

@dataclass
class SurrogatePK:
    __sa_dataclass_metadata_key__ = "sa"
    id: int = field(
        init=False,
        metadata={"sa": Column(Integer, primary_key=True)},
    )


@dataclass
class TimeStampMixin:
    __sa_dataclass_metadata_key__ = "sa"
    created_at: datetime = field(
        init=False, metadata={"sa": Column(DateTime, default=datetime.utcnow)}
    )
    updated_at: datetime = field(
        init=False,
        metadata={
            "sa": Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
        },
    )

# @mapper_registry.mapped
@dataclass
class User(SurrogatePK, TimeStampMixin):
    __tablename__ = "user"
    identity: Optional[str] = field(
        default=None, metadata={"sa": Column(String(length=255), nullable=False)}
    )
    row_status: Optional[str] = field(
        default=None, metadata={"sa": Column(String(length=20), nullable=False)}
    )

    @declared_attr
    def __table_args__(cls):
        return (
            Index(
                "index_on_identity_v3_user_identity",
                "identity",
                "row_status",
                unique=True,
                postgresql_where=cls.row_status == "active",
            ),
        )

@mapper_registry.mapped
@dataclass
class UserSQL(User):
    pass


UserPyd = pydantic.dataclasses.dataclass(User)

# Create the app, the declarative base, and (re)create the user table
app = FastAPI()
Base = mapper_registry.generate_base()
Base.metadata.drop_all(bind=get_engine())
Base.metadata.create_all(bind=get_engine())

@app.get("/", response_model=List[UserPyd])
def foo(context_session: Session = Depends(get_db)):
    with context_session as db:
        # Query the user table, print the rows, and return them
        query = db.query(UserSQL).all()
        for d in query:
            print(
                f"""{d.identity}\t
                {d.row_status}\t
                {d.created_at}\t
                {d.updated_at}"""
            )
        return query
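As with the async version, a hypothetical snippet (placeholder values, not part of the gist) for inserting a row through the sync session helper; the commit happens when get_db() exits cleanly:

with get_db() as db:
    db.add(UserSQL(identity="alice", row_status="active"))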
@lovetoburnswhen

Should the type annotations for the session context managers be

@contextmanager
def get_db(db_conn=Depends(get_engine)) -> Generator[Session, None, None]:

@sandys (Author) commented Oct 20, 2021

Fair point - at that time, I couldn't figure it out. Does it work for you? Wondering if you tested it.

@lovetoburnswhen

Yep, pyright and mypy seem happy.

@pozsa commented Jan 11, 2022

I use

from typing import AsyncIterator
from sqlalchemy.ext.asyncio import AsyncSession

async def get_db() -> AsyncIterator[AsyncSession]:

mypy has been happy so far.
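Putting those suggestions together with the gist's code, a sketch (untested) of how the asynccontextmanager-based get_db could be annotated along those lines; it reuses the gist's get_engine():

from contextlib import asynccontextmanager
from typing import AsyncIterator

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import sessionmaker


@asynccontextmanager
async def get_db() -> AsyncIterator[AsyncSession]:
    # the undecorated function is an async generator, hence AsyncIterator[AsyncSession]
    session: AsyncSession = sessionmaker(
        bind=get_engine(), class_=AsyncSession, expire_on_commit=False
    )()
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()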
