Last active
January 31, 2020 12:21
-
-
Save zeaphoo/2c1612e4a45f88c84750603fec9007a1 to your computer and use it in GitHub Desktop.
add opt lock and dirty state to tortoise-orm model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| This example showcases postgres features | |
| """ | |
| from tortoise import Tortoise, fields, run_async | |
| from tortoise.models import Model | |
| from tortoise.backends.asyncpg.executor import AsyncpgExecutor | |
| from tortoise.backends.asyncpg.client import AsyncpgDBClient | |
| import logging | |
| import copy | |
| logging.basicConfig(level=logging.DEBUG) | |
| class MetaAsyncpgExecutor(AsyncpgExecutor): | |
| def get_update_sql(self, update_fields, condition_keys=[]) -> str: | |
| """ | |
| Generates the SQL for updating a model depending on provided update_fields. | |
| Result is cached for performance. | |
| """ | |
| key = ",".join(update_fields) if update_fields else "" | |
| if len(condition_keys) > 0: | |
| key = "{},{}".format(key, ','.join(condition_keys)) | |
| if key in self.update_cache: | |
| return self.update_cache[key] | |
| table = self.model._meta.basetable | |
| query = self.db.query_class.update(table) | |
| count = 0 | |
| for field in update_fields or self.model._meta.fields_db_projection.keys(): | |
| db_field = self.model._meta.fields_db_projection[field] | |
| field_object = self.model._meta.fields_map[field] | |
| if not field_object.pk: | |
| query = query.set(db_field, self.Parameter(count)) | |
| count += 1 | |
| query = query.where(table[self.model._meta.db_pk_field] == self.Parameter(count)) | |
| count += 1 | |
| for k in condition_keys: | |
| query = query.where(table[k] == self.Parameter(count)) | |
| count += 1 | |
| sql = self.update_cache[key] = query.get_sql() | |
| return sql | |
| async def execute_update(self, instance, update_fields, condition_fields=[]) -> int: | |
| values = [ | |
| self.column_map[field](getattr(instance, field), instance) | |
| for field in update_fields or self.model._meta.fields_db_projection.keys() | |
| if not self.model._meta.fields_map[field].pk | |
| ] | |
| values.append(self.model._meta.pk.to_db_value(instance.pk, instance)) | |
| valid_condition_keys = [] | |
| for k, v in condition_fields: | |
| if k not in (update_fields or []): | |
| valid_condition_keys.append(k) | |
| values.append(v) | |
| return (await self.db.execute_query(self.get_update_sql(update_fields, valid_condition_keys), values))[0] | |
| AsyncpgDBClient.executor_class = MetaAsyncpgExecutor | |
| class MetaverModel(Model): | |
| metaver = fields.IntField(default=0) | |
| created = fields.DatetimeField(auto_now_add=True) | |
| updated = fields.DatetimeField(auto_now=True) | |
| class Meta: | |
| abstract = True | |
| def __init__(self, **kwargs): | |
| super(MetaverModel, self).__init__(**kwargs) | |
| self._snapshot_data = {} | |
| def make_snapshot(self): | |
| new_data = dict() | |
| for key in self._meta.fields_db_projection.keys(): | |
| new_data[key] = copy.deepcopy(getattr(self, key)) | |
| self._snapshot_data = new_data | |
| @property | |
| def changed(self): | |
| now_data = dict() | |
| for key in self._meta.fields_db_projection.keys(): | |
| now_data[key] = getattr(self, key) | |
| diff = self.dict_diff(now_data, self._snapshot_data) | |
| return diff.keys() | |
| def dict_diff(self, first, second): | |
| """ Return a dict of keys that differ with another config object. If a value is | |
| not found in one fo the configs, it will be represented by None. | |
| @param first: Fist dictionary to diff. | |
| @param second: Second dicationary to diff. | |
| @return diff: Dict of Key => (first.val, second.val) | |
| """ | |
| diff = {} | |
| # Check all keys in first dict | |
| for key in first.keys(): | |
| if key not in second: | |
| diff[key] = (first[key], None) | |
| elif (first[key] != second[key]): | |
| diff[key] = (first[key], second[key]) | |
| return diff | |
| @classmethod | |
| async def create(cls, **kwargs): | |
| instance = await super(MetaverModel, cls).create(**kwargs) | |
| instance.make_snapshot() | |
| return instance | |
| @classmethod | |
| def _init_from_db(cls, **kwargs): | |
| instance = super(MetaverModel, cls)._init_from_db(**kwargs) | |
| instance.make_snapshot() | |
| return instance | |
| @classmethod | |
| async def bulk_create(cls, objects, using_db = None) -> None: | |
| await super(MetaverModel, cls).bulk_create(objects, using_db=using_db) | |
| for obj in objects: | |
| obj.make_snapshot() | |
| async def save(self, using_db = None, update_fields = None, force=False): | |
| changed = self.changed | |
| if len(changed) == 0: | |
| return | |
| old_metaver = self.metaver | |
| self.metaver += 1 | |
| fileds = list(set(update_fields or ()) | set(changed) | set(('updated', 'metaver'))) | |
| condition_fields = [] | |
| if not force: | |
| condition_fields.append(('metaver', old_metaver)) | |
| db = using_db or self._meta.db | |
| executor = db.executor_class(model=self.__class__, db=db) | |
| if self._saved_in_db: | |
| ret = await executor.execute_update(self, update_fields, condition_fields=condition_fields) | |
| if ret == 0: | |
| raise Exception('model save failed.') | |
| else: | |
| ret = await executor.execute_insert(self) | |
| if ret == 0: | |
| raise Exception('model insert failed.') | |
| self._saved_in_db = True | |
| self.make_snapshot() | |
| class Report(MetaverModel): | |
| id = fields.IntField(pk=True) | |
| content = fields.JSONField() | |
| def __str__(self): | |
| return str(self.id) | |
| async def run(): | |
| await Tortoise.init( | |
| { | |
| "connections": { | |
| "default": { | |
| "engine": "tortoise.backends.asyncpg", | |
| "credentials": { | |
| "host": "127.0.0.1", | |
| "port": "54320", | |
| "user": "postgres", | |
| "password": "", | |
| "database": "test-tortoise", | |
| }, | |
| } | |
| }, | |
| "apps": {"models": {"models": ["__main__"], "default_connection": "default"}}, | |
| }, | |
| _create_db=False, | |
| ) | |
| await Tortoise.generate_schemas() | |
| report_data = {"foo": "bar"} | |
| report = await Report.create(content=report_data) | |
| print(report.id) | |
| report.content['hello'] = 'world' | |
| print(report.content) | |
| await report.save() | |
| del report.content['foo'] | |
| await report.save() | |
| print('\n>>>> test not modified update....') | |
| report = await Report.filter(id=report.id).first() | |
| await report.save() | |
| print('>>>> end.\n') | |
| print('\n>>>> test concurrent update....') | |
| r1 = await Report.filter(id=report.id).first() | |
| r2 = await Report.filter(id=report.id).first() | |
| r1.content['r1'] = 'value1' | |
| await r1.save() | |
| r2.content['r2'] = 'value2' | |
| try: | |
| await r2.save() | |
| except: | |
| await r2.save(force=True) | |
| r3 = await Report.filter(id=report.id).first() | |
| assert r3.content['r2'] == 'value2' | |
| print(report.id) | |
| #await Tortoise._drop_databases() | |
| if __name__ == "__main__": | |
| run_async(run()) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment