Skip to content

Instantly share code, notes, and snippets.

@ziplus4
Created December 16, 2015 06:42
Show Gist options
  • Save ziplus4/3bf8cc14541a16c65206 to your computer and use it in GitHub Desktop.
Save ziplus4/3bf8cc14541a16c65206 to your computer and use it in GitHub Desktop.

Revisions

  1. ziplus4 created this gist Dec 16, 2015.
    200 changes: 200 additions & 0 deletions db_bind_sharding.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,200 @@
    # -*- coding:utf8 -*-
    import re

    from flask import Flask

    from flask_sqlalchemy import SQLAlchemy as BaseSQLAlchemy
    from flask_sqlalchemy import _SignallingSession as BaseSignallingSession
    from flask_sqlalchemy import orm, partial, get_state

    from datetime import datetime

    class _BindingKeyPattern(object):
    def __init__(self, db, pattern):
    self.db = db
    self.raw_pattern = pattern
    self.compiled_pattern = re.compile(pattern)
    self._shard_keys = None

    def __repr__(self):
    return "%s<%s>" % (self.__class__.__name__, self.raw_pattern)

    def match(self, key):
    return self.compiled_pattern.match(key)

    def get_shard_key(self, hash_num):
    if self._shard_keys is None:
    self._shard_keys = [key for key, value in self.db.app.config['SQLALCHEMY_BINDS'].iteritems() if self.compiled_pattern.match(key)]
    self._shard_keys.sort()

    return self._shard_keys[hash_num % len(self._shard_keys)]


    class _BoundSection(object):
    def __init__(self, db_session_cls, name):
    self.db_session = db_session_cls()
    self.name = name

    def __enter__(self):
    self.db_session.push_binding(self.name)

    def __exit__(self, exc_type, exc_val, exc_tb):
    self.db_session.pop_binding()
    self.db_session.close()


    class _SignallingSession(BaseSignallingSession):
    def __init__(self, *args, **kwargs):
    BaseSignallingSession.__init__(self, *args, **kwargs)
    self._binding_keys = []
    self._binding_key = None

    def push_binding(self, key):
    self._binding_keys.append(self._binding_key)
    self._binding_key = key

    def pop_binding(self):
    self._binding_key = self._binding_keys.pop()

    def get_bind(self, mapper, clause=None):
    binding_key = self.__find_binding_key(mapper)
    if binding_key is None:
    return BaseSignallingSession.get_bind(self, mapper, clause)
    else:
    state = get_state(self.app)
    return state.db.get_engine(self.app, bind=binding_key)

    def __find_binding_key(self, mapper):
    if mapper is None: # 맵퍼 없음
    return self._binding_key
    else:
    mapper_info = getattr(mapper.mapped_table, 'info', {})
    mapped_binding_key = mapper_info.get('bind_key')
    if mapped_binding_key: # 맵핑된 바인딩 키 존재
    if type(mapped_binding_key) is str: # 정적 바인딩
    return mapped_binding_key
    else: # 동적 바인딩
    if mapped_binding_key.match(self._binding_key): # 현재 바인딩
    return self._binding_key
    else: # 푸쉬된 바인딩
    for pushed_binding_key in reversed(self._binding_keys):
    if pushed_binding_key and mapped_binding_key.match(pushed_binding_key):
    return pushed_binding_key
    else:
    raise Exception('NOT_FOUND_MAPPED_BINDING:%s CURRENT_BINDING:%s PUSHED_BINDINGS:%s' % (repr(mapped_binding_key), repr(self._binding_key), repr(self._binding_keys[1:])))
    else: # 맵핑된 바인딩 키가 없으면 디폴트 바인딩
    return self._binding_key


    class SQLAlchemy(BaseSQLAlchemy):
    def BindingKeyPattern(self, pattern):
    return _BindingKeyPattern(self, pattern)

    def binding(self, key):
    return _BoundSection(self.session, key)

    def create_scoped_session(self, options=None):
    if options is None:
    options = {}
    scopefunc=options.pop('scopefunc', None)
    return orm.scoped_session(
    partial(_SignallingSession, self, **options), scopefunc=scopefunc
    )

    def get_binds(self, app=None):
    retval = BaseSQLAlchemy.get_binds(self, app)

    bind = None
    engine = self.get_engine(app, bind)
    tables = self.get_tables_for_bind(bind)
    retval.update(dict((table, engine) for table in tables))
    return retval

    def get_tables_for_bind(self, bind=None):
    result = []
    for table in self.Model.metadata.tables.itervalues():
    table_bind_key = table.info.get('bind_key')
    if table_bind_key == bind:
    result.append(table)
    else:
    if bind:
    if type(table_bind_key) is _BindingKeyPattern and table_bind_key.match(bind):
    result.append(table)
    elif type(table_bind_key) is str and table_bind_key == bind:
    result.append(table)

    return result


    app = Flask(__name__)
    db = SQLAlchemy(app)

    class Notice(db.Model):
    __bind_key__ = 'global'

    id = db.Column(db.Integer, primary_key=True)
    msg = db.Column(db.String, nullable=False)
    ctime = db.Column(db.DateTime, default=datetime.now(), nullable=False)

    def __repr__(self):
    return "%s<id=%d,msg='%s'>" % (self.__class__.__name__, self.id, self.msg)

    class User(db.Model):
    __bind_key__ = db.BindingKeyPattern('[^_]+_user_\d\d')

    id = db.Column(db.Integer, primary_key=True)
    nickname = db.Column(db.String(80), unique=True)

    login_logs = db.relationship(lambda: LoginLog, backref='owner')

    def __repr__(self):
    return "%s<id=%d, nickname='%s'>" % (self.__class__.__name__, self.id, self.nickname)

    @classmethod
    def get_shard_key(cls, nickname):
    return cls.__bind_key__.get_shard_key(hash(nickname))

    class LoginLog(db.Model):
    __bind_key__ = db.BindingKeyPattern('[^_]+_log')

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey(User.id))
    ctime = db.Column(db.DateTime, default=datetime.now(), nullable=False)


    if __name__ == '__main__':
    app.config['SQLALCHEMY_ECHO'] = True
    app.config['SQLALCHEMY_BINDS'] = {
    'global': 'sqlite:///./global.db',
    'master_user_01': 'sqlite:///./master_user_01.db',
    'master_user_02': 'sqlite:///./master_user_02.db',
    'slave_user': 'sqlite:///./slave_user.db',
    'master_log': 'sqlite:///./master_log.db',
    'slave_log': 'sqlite:///./slave_log.db',
    }

    db.drop_all()
    db.create_all()

    notice = Notice(msg='NOTICE1')
    db.session.add(notice)
    db.session.commit()

    nickname = 'jaru'
    with db.binding(User.get_shard_key(nickname)):
    notice = Notice(msg='NOTICE2')
    db.session.add(notice)
    db.session.commit()

    user = User(nickname=nickname)
    db.session.add(user)
    db.session.commit()

    with db.binding('master_log'):
    notice = Notice(msg='NOTICE3')
    db.session.add(notice)
    db.session.commit()

    login_log = LoginLog(owner=user)
    db.session.add(login_log)
    db.session.commit()