"""Database connection, session and table management helpers.

A single module-level engine and session are created lazily by
``setup_db`` / ``get_session`` and then reused on every later call.
"""
import contextlib
import functools
from contextlib import contextmanager

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import (
    NullPool,
    QueuePool,
    StaticPool,
)
from sqlalchemy_utils import database_exists, create_database

from ..constants import (
    DB_ENGINE,
    DB_HOST,
    DB_NAME,
    DB_PASSWORD,
    DB_PORT,
    DB_USERNAME,
    SQLITE,
    POSTGRESQL,
)

# Module-level singletons, populated by setup_db() / get_session().
__session_factory = None
__session = None
__engine = None
__is_test = False


def _get_metadata():
    """Return the ``MetaData`` of the declarative ``Base`` in ``.mixins``.

    The import is done lazily inside the function, presumably to avoid a
    circular import with the models module (TODO confirm).
    """
    from .mixins import Base
    return Base.metadata


def setup_db(*, is_test=False, **db_config):
    """Main setup function for our database.

    This will perform the initial DB connection and also create tables as
    needed. Once an engine exists, further calls return immediately.

    :param is_test: when true, a PostgreSQL database name gets a ``test_``
        prefix and ``_drop_tables`` / ``_clear_tables`` are enabled.
    :param db_config: optional ``connection_kwargs`` (forwarded to
        ``create_engine``; ``poolclass`` is always forced to ``StaticPool``)
        and ``session_kwargs`` (forwarded to ``sessionmaker``). The dicts
        passed in are not modified.
    :raises: whatever engine creation, database creation, table creation or
        session creation raises; in that case the module state is reset so
        a later call retries the full setup.
    """
    global __engine
    global __is_test
    global __session
    global __session_factory

    if __engine is not None:
        return

    __is_test = is_test
    connection_string = get_connection_string()

    # Copy so the caller's dict is not mutated by the override below.
    connection_kwargs = dict(db_config.get('connection_kwargs') or {})
    # In a serverless environment use a static pool, which is a
    # single-connection pool per Lambda.
    connection_kwargs['poolclass'] = StaticPool
    session_kwargs = db_config.get('session_kwargs') or {}

    __engine = create_engine(connection_string, **connection_kwargs)
    # %r renders the URL with the password masked; on SQLAlchemy < 2.0 the
    # %s (str) form prints the password in clear text.
    print('Connected to: %r' % (__engine.url, ))

    try:
        if not database_exists(__engine.url):  # pragma: no cover
            print('Creating database: %r' % (__engine.url, ))
            create_database(__engine.url)

        create_tables()
        get_session(**session_kwargs)
    except Exception:
        # Without this reset, __engine would stay set and every later call
        # would return early, leaving tables/session silently missing.
        __engine.dispose()
        __engine = None
        __session = None
        __session_factory = None
        raise


def get_connection_string(**kwargs):
    """Return a connection string for sqlalchemy::

        dialect+driver://username:password@host:port/database

    :param kwargs: for SQLite only, ``filename`` is appended verbatim after
        ``sqlite://``; when it is omitted an in-memory database is used.
    :returns: the connection URL as a string.
    :raises ValueError: if ``DB_ENGINE`` is neither SQLite nor PostgreSQL.
    """
    # Rebinding DB_NAME here changes this module's copy for all later calls.
    global DB_NAME

    if DB_ENGINE not in (SQLITE, POSTGRESQL):
        raise ValueError(
            'Invalid database engine specified: %s. Only sqlite'
            ' and postgresql are supported' % (DB_ENGINE, ))

    if DB_ENGINE == SQLITE:
        # missing filename creates an in-memory db
        # NOTE(review): SQLAlchemy expects the path to start with '/'
        # ('sqlite:///rel.db' or 'sqlite:////abs.db'); a bare 'rel.db' would
        # be parsed as a host. Confirm what callers pass as ``filename``.
        return 'sqlite://%s' % kwargs.get('filename', '')

    if __is_test and not DB_NAME.startswith('test_'):
        DB_NAME = 'test_%s' % (DB_NAME, )

    # NOTE(review): credentials are interpolated unescaped; a password
    # containing '@', ':' or '/' would need URL-encoding - confirm how the
    # values in ..constants are stored before changing this.
    return 'postgresql://%s:%s@%s:%s/%s' % (
        DB_USERNAME,
        DB_PASSWORD,
        DB_HOST,
        DB_PORT,
        DB_NAME,
    )


def close_db():  # pragma: no cover
    """Commit the shared session, then close it.

    If the commit fails it is rolled back and the error is deliberately not
    re-raised (best-effort shutdown). Does nothing if no session exists.
    """
    if __session is None:
        return

    try:
        __session.commit()
    except Exception:
        __session.rollback()
    finally:
        __session.close()


def commit_session(_raise=False):  # pragma: no cover
    """Commit the shared session, rolling back if the commit fails.

    :param _raise: when true, re-raise the commit error after the rollback;
        by default the error is swallowed.
    """
    if __session is None:
        return

    try:
        __session.commit()
    except Exception:
        __session.rollback()
        if _raise:
            raise


def create_tables():
    """Create all tables known to the declarative metadata on the engine."""
    assert __engine
    meta = _get_metadata()
    meta.create_all(__engine)


def get_session(**kwargs):
    """Main API for connection to the DB via the SQLAlchemy session.

    Clients should use this for any DB interactions as it will connect and
    setup the database as needed. After initialization the global session
    will be returned so this is safe to call multiple times in a single
    thread.

    :param kwargs: forwarded to ``sessionmaker`` only when the session
        factory is first created; ignored afterwards.
    """
    global __session
    global __session_factory

    setup_db()
    assert __engine

    if __session is not None:
        return __session

    if __session_factory is None:  # pragma: no cover
        __session_factory = sessionmaker(bind=__engine, **kwargs)

    __session = __session_factory()
    return __session


def session_committer(func):
    """Decorator to commit the DB session after the wrapped call.

    Use this from high-level functions such as handler so that the session
    is always committed (or rolled back if the commit fails) when the call
    returns or raises. Commit errors are not re-raised.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        finally:
            commit_session()

    return wrapper


def session_getter(func):
    """Decorator to get a session and inject it as the first argument in a function"""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        session = get_session()
        return func(session, *args, **kwargs)

    return wrapper


@contextmanager
def dbtransaction():  # pragma: no cover
    """Use as a context manager to commit a transaction.

    Yields the shared session, commits when the ``with`` body completes and
    rolls back (then re-raises) if the body or the commit raises.
    """
    session = get_session()
    try:
        yield session
        session.commit()
    except BaseException:
        # Roll back on anything, including KeyboardInterrupt, then propagate.
        session.rollback()
        raise


def _drop_tables(*, force=False):
    """Drop all tables; only runs in test mode unless ``force`` is true."""
    if not __is_test and not force:
        return

    assert __engine
    meta = _get_metadata()
    meta.drop_all(__engine)


def _clear_tables(*, force=False):
    """Delete every row from every table in a single transaction.

    Only runs in test mode unless ``force`` is true. Tables are cleared in
    reverse dependency order so referencing rows go before referenced ones.
    A table whose DELETE fails is skipped (best-effort).
    """
    if not __is_test and not force:
        return

    assert __engine
    meta = _get_metadata()
    with contextlib.closing(__engine.connect()) as con:
        trans = con.begin()
        for table in reversed(meta.sorted_tables):
            # NOTE(review): on PostgreSQL a failed statement aborts the whole
            # transaction, so later deletes and the commit would not take
            # effect; a SAVEPOINT per table may be needed.
            with contextlib.suppress(Exception):
                con.execute(table.delete())
        trans.commit()