@musq
Created October 1, 2022 18:57
Manual connection pooling using Peewee - with an eject callback
import logging
import sys
import time
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from functools import wraps
from typing import Callable

from peewee import Database
from playhouse.pool import PooledPostgresqlExtDatabase, PostgresqlExtDatabase

formatter = logging.Formatter("[%(levelname)s] %(asctime)s - %(name)s: %(message)s")
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)

root = logging.getLogger()
root.addHandler(handler)
root.setLevel(logging.DEBUG)

pool_db = PooledPostgresqlExtDatabase(
    None,
    autoconnect=False,  # https://docs.peewee-orm.com/en/latest/peewee/database.html#using-autoconnect
)
pool_db.init(
    "discover_candidate_pool",
    host="localhost",
    port=11310,
    user="postgres",
    password="postgres",
    stale_timeout=2,
    # TODO: Verify that lock_timeout works
    options="-c lock_timeout=30000",  # timeout in ms
)
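
# Sketch (not in the original gist) of how the lock_timeout TODO above could be
# verified: open a pooled connection and ask Postgres for the effective setting.
# "SHOW lock_timeout" is standard Postgres, and Database.execute_sql() returns a cursor.
#
# pool_db.connect()
# cursor = pool_db.execute_sql("SHOW lock_timeout")
# logging.info("lock_timeout is %s", cursor.fetchone()[0])  # expect "30s" for 30000 ms
# pool_db.close()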
# pool_db.connect()
# pool_db.close()
# time.sleep(3)
# pool_db.connect()
# pool_db.close()
# time.sleep(1)
# pool_db.connect()
# pool_db.close()
# pool_db.connect()


def open_close_connection(i: int):
    logging.info(i)
    pool_db.connect()
    time.sleep(0.5)
    logging.info(i)
    pool_db.close()


futures = {}
with ThreadPoolExecutor(max_workers=8) as executor:
    for i in range(3):
        future = executor.submit(open_close_connection, i)
        futures[future] = i

    for future in as_completed(futures):
        result = future.result()
        # breakpoint()

# db = PostgresqlExtDatabase(
#     None,
#     autoconnect=False,  # https://docs.peewee-orm.com/en/latest/peewee/database.html#using-autoconnect
# )
# db.init(
#     "discover_candidate_pool",
#     host="localhost",
#     port=11310,
#     user="postgres",
#     password="postgres",
#     # TODO: Verify that lock_timeout works
#     options="-c lock_timeout=30000",  # timeout in ms
# )
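
# NOTE: the exit() below stops the script after the pooled-connection experiment
# above; the LRU-cache and DatabaseProxy code that follows does not run while it
# is in place.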
exit()


class LRUCacheWithEjectCallback:
    """
    LRU cache which executes eject_callback() on each ejected item. This can be
    useful when we need to run a teardown sequence on each ejected item, e.g.
    closing a DB connection, closing a file, etc.

    Built on Raymond Hettinger's basic LRU cache implementation:
    https://pastebin.com/LDwMwtp8
    """

    def __init__(self, func, maxsize: int, eject_callback):
        self.cache: OrderedDict = OrderedDict()
        self.func = func
        self.maxsize = maxsize
        self.eject_callback = eject_callback

    def __call__(self, *args):
        cache = self.cache
        if args in cache:
            cache.move_to_end(args)
            return cache[args]

        result = self.func(*args)
        cache[args] = result
        if len(cache) > self.maxsize:
            ejected_key, ejected_item = cache.popitem(last=False)
            self.eject_callback(ejected_item)
        return result
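

# Illustrative sketch (not part of the original gist): wrap a toy function with
# LRUCacheWithEjectCallback and watch the oldest entry get ejected once the cache
# grows past maxsize. The _make_resource helper is made up for this example.
def _make_resource(name: str) -> str:
    return f"resource:{name}"


_cached_make_resource = LRUCacheWithEjectCallback(
    _make_resource,
    maxsize=2,
    eject_callback=lambda item: logging.debug("ejected %s", item),
)
_cached_make_resource("a")
_cached_make_resource("b")
_cached_make_resource("a")  # cache hit: "a" becomes most recently used
_cached_make_resource("c")  # exceeds maxsize=2, so "resource:b" is ejected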


def lru_cache_with_eject_callback(
    maxsize: int, eject_callback: Callable
) -> Callable[[Callable], Callable]:
    def decorator(fn: Callable) -> Callable:
        """
        Based on Raymond Hettinger's basic LRU cache implementation:
        https://pastebin.com/LDwMwtp8
        """
        cache: OrderedDict = OrderedDict()

        @wraps(fn)
        def new_func(*args):
            """
            CAUTION: The decorated function must only take positional arguments,
            never keyword arguments.
            """
            if args in cache:
                cache.move_to_end(args)
                return cache[args]

            result = fn(*args)
            cache[args] = result
            if len(cache) > maxsize:
                ejected_key, ejected_item = cache.popitem(last=False)
                eject_callback(ejected_item)
                logging.debug(
                    "Item ejected from LRU cache",
                    extra={"body": {"function": fn.__name__, "args": ejected_key}},
                )
                logging.debug(
                    f"Item ejected from LRU cache function={fn.__name__}, args={ejected_key}"
                )
            return result

        return new_func

    return decorator


def close_db_connection(db: Database):
    try:
        db.close()
    except Exception as e:
        raise RuntimeError(
            "You probably need to reduce the number of agencies that are processed concurrently"
        ) from e


@lru_cache_with_eject_callback(maxsize=3, eject_callback=close_db_connection)
def get_db_connection_from_pool(db_name: str) -> Database:
    db = PostgresqlExtDatabase(
        db_name,
        host="localhost",
        port=11310,
        user="postgres",
        password="postgres",
        autoconnect=False,  # https://docs.peewee-orm.com/en/latest/peewee/database.html#using-autoconnect
        # TODO: Verify that lock_timeout works
        options="-c lock_timeout=30000",  # timeout in ms
    )
    db.connect()
    return db
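
# Note (not from the original gist): thanks to the decorator above, repeated calls
# with the same db_name reuse the already-open connection, and requesting a fourth
# distinct db_name closes the least recently used connection via close_db_connection().
#
# db_a = get_db_connection_from_pool("discover_candidate_pool2")
# db_b = get_db_connection_from_pool("discover_candidate_pool2")
# assert db_a is db_b  # cache hit: same Database object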
# db_connection = LRUCacheWithEjectCallback(
#     db_connection, maxsize=2, eject_callback=close_db_connection
# )
# postgres_db = db_connection("postgres")
# candidate_pool_db = db_connection("discover_candidate_pool")
# template1_db = db_connection("template1")

from peewee import CharField, DatabaseProxy, Model

database_proxy = DatabaseProxy()  # Create a proxy for our db.


class _BaseModel(Model):
    class Meta:
        database = database_proxy


class SubscribedAudienceFilterOrm(_BaseModel):
    class Meta:
        schema = "candidate_pool"
        table_name = "subscribed_audience_filters"

    filter_id = CharField(primary_key=True)
    agency_id = CharField()


@contextmanager
def agency_db(db_name: str):
    db = get_db_connection_from_pool(db_name)
    database_proxy.initialize(db)
    yield
    database_proxy.initialize(None)


@contextmanager
def db_transactions(db_names: list[str]):
    dbs = []
    for db_name in db_names:
        db = get_db_connection_from_pool(db_name)
        db.session_start()
        dbs.append(db)

    try:
        yield
    except Exception:
        for db in dbs:
            db.session_rollback()
        raise
    else:
        for db in dbs:
            db.session_commit()


db_names = [
    # "discover_candidate_pool",
    "discover_candidate_pool2",
    "discover_candidate_pool3",
    # "postgres",
]

# print(database_proxy.obj)
with db_transactions(db_names):
    # print(database_proxy.obj)
    for db_name in db_names:
        with agency_db(db_name):
            # print(database_proxy.obj)
            audience_filter_orms = SubscribedAudienceFilterOrm.select()
            print(list(audience_filter_orms))

            # SubscribedAudienceFilterOrm.create(agency_id="xxx", filter_id=db_name)
            # audience_filter_orms = SubscribedAudienceFilterOrm.select()
            # print(list(audience_filter_orms))
            print()

# print(database_proxy.obj)