Created
October 1, 2022 18:57
-
-
Save musq/9e60f6dc9e042f002df7bb866f70d90c to your computer and use it in GitHub Desktop.
Manual connection pooling using Peewee - with an eject callback
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import logging | |
| import sys | |
| import time | |
| from collections import OrderedDict | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from contextlib import contextmanager | |
| from functools import wraps | |
| from typing import Callable | |
| from peewee import Database | |
| from playhouse.pool import PooledPostgresqlExtDatabase, PostgresqlExtDatabase | |
# --- Logging: route everything (DEBUG and up) to stdout --------------------
formatter = logging.Formatter("[%(levelname)s] %(asctime)s - %(name)s: %(message)s")
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
root = logging.getLogger()
root.addHandler(handler)
root.setLevel(logging.DEBUG)

# --- Pooled database -------------------------------------------------------
# Two-step construction: create uninitialized (database=None), then init()
# with the real name and connection parameters.
pool_db = PooledPostgresqlExtDatabase(
    None,
    autoconnect=False,  # https://docs.peewee-orm.com/en/latest/peewee/database.html#using-autoconnect
)
pool_db.init(
    "discover_candidate_pool",
    host="localhost",
    port=11310,
    user="postgres",
    password="postgres",
    stale_timeout=2,  # presumably seconds before an idle pooled connection is recycled — confirm against peewee docs
    # TODO: Verify that lock_timeout works
    options="-c lock_timeout=30000",  # timeout in ms
)
# Scratch experiments probing connect/close timing against stale_timeout:
# pool_db.connect()
# pool_db.close()
# time.sleep(3)
# pool_db.connect()
# pool_db.close()
# time.sleep(1)
# pool_db.connect()
# pool_db.close()
# pool_db.connect()
def open_close_connection(i: int):
    """Worker task: check a connection out of pool_db, hold it ~0.5 s, return it.

    ``i`` is only a log marker so interleaving across threads is visible in the
    output (logged once before connect and once before close).
    """
    logging.info(i)
    pool_db.connect()
    time.sleep(0.5)
    logging.info(i)
    pool_db.close()
# Exercise the pool from several threads: each worker borrows a connection,
# holds it briefly, and hands it back.
futures = {}
with ThreadPoolExecutor(max_workers=8) as executor:
    for i in range(3):
        future = executor.submit(open_close_connection, i)
        futures[future] = i
    for future in as_completed(futures):
        # result() re-raises any exception raised inside the worker thread.
        result = future.result()
# breakpoint()
# Earlier single-connection (non-pooled) experiment, kept for reference:
# db = PostgresqlExtDatabase(
#     None,
#     autoconnect=False,  # https://docs.peewee-orm.com/en/latest/peewee/database.html#using-autoconnect
# )
# db.init(
#     "discover_candidate_pool",
#     host="localhost",
#     port=11310,
#     user="postgres",
#     password="postgres",
#     # TODO: Verify that lock_timeout works
#     options="-c lock_timeout=30000",  # timeout in ms
# )
# NOTE(review): exit() stops the script here — everything below this point is
# defined but never executed when the file runs top-to-bottom.
exit()
class LRUCacheWithEjectCallback:
    """Callable LRU cache that invokes ``eject_callback`` on every evicted value.

    Handy when evicted values need an explicit tear-down step, e.g. closing a
    database connection or a file handle.

    Adapted from Raymond Hettinger's minimal LRU cache:
    https://pastebin.com/LDwMwtp8
    """

    def __init__(self, func, maxsize: int, eject_callback):
        # Insertion order of the OrderedDict doubles as recency order:
        # oldest entry at the front, most recently used at the back.
        self.cache: OrderedDict = OrderedDict()
        self.func = func
        self.maxsize = maxsize
        self.eject_callback = eject_callback

    def __call__(self, *args):
        # NOTE: positional, hashable arguments only — they form the cache key.
        try:
            hit = self.cache[args]
        except KeyError:
            pass
        else:
            self.cache.move_to_end(args)
            return hit
        value = self.func(*args)
        self.cache[args] = value
        if len(self.cache) > self.maxsize:
            _evicted_key, evicted_value = self.cache.popitem(last=False)
            self.eject_callback(evicted_value)
        return value
def lru_cache_with_eject_callback(
    maxsize: int, eject_callback: Callable
) -> Callable[[Callable], Callable]:
    """Decorator factory: LRU-cache a function and run *eject_callback* on each
    value that falls out of the cache.

    Useful when evicted values need a tear-down step (closing a DB connection,
    closing a file, etc.).

    Based on Raymond Hettinger's basic LRU cache implementation:
    https://pastebin.com/LDwMwtp8

    Args:
        maxsize: Maximum number of cached results before the least recently
            used entry is ejected.
        eject_callback: Called with the ejected *value* (not the key).

    Returns:
        A decorator for functions taking only positional, hashable arguments.
    """

    def decorator(fn: Callable) -> Callable:
        cache: OrderedDict = OrderedDict()

        @wraps(fn)
        def new_func(*args):
            """
            CAUTION: The decorated function must only take input as positional
            arguments, and never as keyword arguments.
            """
            if args in cache:
                cache.move_to_end(args)
                return cache[args]
            result = fn(*args)
            cache[args] = result
            if len(cache) > maxsize:
                ejected_key, ejected_item = cache.popitem(last=False)
                eject_callback(ejected_item)
                # Fix: the original emitted two back-to-back, near-identical
                # debug records here; collapsed into a single lazily-formatted
                # call that keeps the structured payload.
                logging.debug(
                    "Item ejected from LRU cache function=%s, args=%s",
                    fn.__name__,
                    ejected_key,
                    extra={"body": {"function": fn.__name__, "args": ejected_key}},
                )
            return result

        return new_func

    return decorator
def close_db_connection(db: Database):
    """Close *db*, wrapping any failure in a RuntimeError with a sizing hint.

    Used as the eject callback of the connection LRU cache: a close failure
    usually means too many connections/transactions are live at once.
    """
    try:
        db.close()
    except Exception as exc:
        hint = "You probably need to reduce the amount of agencies that are processed concurrently"
        raise RuntimeError(hint) from exc
@lru_cache_with_eject_callback(maxsize=3, eject_callback=close_db_connection)
def get_db_connection_from_pool(db_name: str) -> Database:
    """Return an open connection to *db_name*, LRU-cached (at most 3 live).

    When a fourth distinct database is requested, the least recently used
    connection is ejected from the cache and closed by close_db_connection().
    """
    connection = PostgresqlExtDatabase(
        db_name,
        host="localhost",
        port=11310,
        user="postgres",
        password="postgres",
        autoconnect=False,  # https://docs.peewee-orm.com/en/latest/peewee/database.html#using-autoconnect
        # TODO: Verify that lock_timeout works
        options="-c lock_timeout=30000",  # timeout in ms
    )
    connection.connect()
    return connection
# Earlier, class-based variant of the connection cache, kept for reference:
# db_connection = LRUCacheWithEjectCallback(
#     db_connection, maxsize=2, eject_callback=close_db_connection
# )
# postgres_db = db_connection("postgres")
# candidate_pool_db = db_connection("discover_candidate_pool")
# template1_db = db_connection("template1")
# NOTE(review): mid-file import — peewee model plumbing for the demo below.
from peewee import CharField, DatabaseProxy, Model

# Placeholder bound to a concrete Database at runtime (see agency_db()).
database_proxy = DatabaseProxy()  # Create a proxy for our db.
class _BaseModel(Model):
    """Base model binding every subclass to the module-level database proxy."""

    class Meta:
        # Subclasses inherit this; the proxy is re-pointed per database
        # by agency_db() before queries run.
        database = database_proxy
class SubscribedAudienceFilterOrm(_BaseModel):
    """ORM for the candidate_pool.subscribed_audience_filters table."""

    class Meta:
        schema = "candidate_pool"
        table_name = "subscribed_audience_filters"

    filter_id = CharField(primary_key=True)
    agency_id = CharField()
@contextmanager
def agency_db(db_name: str):
    """Bind the module-wide database proxy to *db_name* for the block's duration.

    The connection comes from the LRU-cached pool, so entering this context
    repeatedly for the same database reuses the same open connection.

    Fix: the proxy reset now runs in ``finally`` — an exception inside the
    ``with`` body previously left the proxy bound to a stale database.
    """
    db = get_db_connection_from_pool(db_name)
    database_proxy.initialize(db)
    try:
        yield
    finally:
        database_proxy.initialize(None)
@contextmanager
def db_transactions(db_names: list[str]):
    """Open one transaction per named database; commit all on success,
    roll back all on error.

    Connections come from the LRU-cached pool, so every database listed here
    must fit in the cache simultaneously (see get_db_connection_from_pool).

    Fix: the session_start() loop now runs inside the ``try`` — if starting a
    session fails partway through, the sessions already started are rolled
    back instead of being leaked open.
    """
    dbs = []
    try:
        for db_name in db_names:
            db = get_db_connection_from_pool(db_name)
            db.session_start()
            dbs.append(db)
        yield
    except Exception:
        for db in dbs:
            db.session_rollback()
        raise
    else:
        for db in dbs:
            db.session_commit()
db_names = [
    # "discover_candidate_pool",
    "discover_candidate_pool2",
    "discover_candidate_pool3",
    # "postgres",
]
# print(database_proxy.obj)
# Demo: open a transaction on every database, then bind the proxy to each
# database in turn and query it.
# NOTE(review): unreachable at runtime — exit() earlier in the file stops
# execution before these statements run.
with db_transactions(db_names):
    # print(database_proxy.obj)
    for db_name in db_names:
        with agency_db(db_name):
            # print(database_proxy.obj)
            audience_filter_orms = SubscribedAudienceFilterOrm.select()
            print(list(audience_filter_orms))
            # SubscribedAudienceFilterOrm.create(agency_id="xxx", filter_id=db_name)
            # audience_filter_orms = SubscribedAudienceFilterOrm.select()
            # print(list(audience_filter_orms))
            print()
# print(database_proxy.obj)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment