Last active
January 31, 2020 12:21
-
-
Save zeaphoo/2c1612e4a45f88c84750603fec9007a1 to your computer and use it in GitHub Desktop.
Revisions
-
zeaphoo revised this gist
Jan 31, 2020 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -164,7 +164,7 @@ async def run(): "port": "54320", "user": "postgres", "password": "", "database": "test-tortoise", }, } }, -
zeaphoo created this gist
Jan 31, 2020 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,208 @@ """ This example showcases postgres features """ from tortoise import Tortoise, fields, run_async from tortoise.models import Model from tortoise.backends.asyncpg.executor import AsyncpgExecutor from tortoise.backends.asyncpg.client import AsyncpgDBClient import logging import copy logging.basicConfig(level=logging.DEBUG) class MetaAsyncpgExecutor(AsyncpgExecutor): def get_update_sql(self, update_fields, condition_keys=[]) -> str: """ Generates the SQL for updating a model depending on provided update_fields. Result is cached for performance. """ key = ",".join(update_fields) if update_fields else "" if len(condition_keys) > 0: key = "{},{}".format(key, ','.join(condition_keys)) if key in self.update_cache: return self.update_cache[key] table = self.model._meta.basetable query = self.db.query_class.update(table) count = 0 for field in update_fields or self.model._meta.fields_db_projection.keys(): db_field = self.model._meta.fields_db_projection[field] field_object = self.model._meta.fields_map[field] if not field_object.pk: query = query.set(db_field, self.Parameter(count)) count += 1 query = query.where(table[self.model._meta.db_pk_field] == self.Parameter(count)) count += 1 for k in condition_keys: query = query.where(table[k] == self.Parameter(count)) count += 1 sql = self.update_cache[key] = query.get_sql() return sql async def execute_update(self, instance, update_fields, condition_fields=[]) -> int: values = [ self.column_map[field](getattr(instance, field), instance) for field in update_fields or self.model._meta.fields_db_projection.keys() if not self.model._meta.fields_map[field].pk ] values.append(self.model._meta.pk.to_db_value(instance.pk, instance)) valid_condition_keys = [] for k, v in condition_fields: if k not in (update_fields or []): valid_condition_keys.append(k) values.append(v) return (await self.db.execute_query(self.get_update_sql(update_fields, valid_condition_keys), values))[0] AsyncpgDBClient.executor_class = MetaAsyncpgExecutor class MetaverModel(Model): metaver = fields.IntField(default=0) created = fields.DatetimeField(auto_now_add=True) updated = fields.DatetimeField(auto_now=True) class Meta: abstract = True def __init__(self, **kwargs): super(MetaverModel, self).__init__(**kwargs) self._snapshot_data = {} def make_snapshot(self): new_data = dict() for key in self._meta.fields_db_projection.keys(): new_data[key] = copy.deepcopy(getattr(self, key)) self._snapshot_data = new_data @property def changed(self): now_data = dict() for key in self._meta.fields_db_projection.keys(): now_data[key] = getattr(self, key) diff = self.dict_diff(now_data, self._snapshot_data) return diff.keys() def dict_diff(self, first, second): """ Return a dict of keys that differ with another config object. If a value is not found in one fo the configs, it will be represented by None. @param first: Fist dictionary to diff. @param second: Second dicationary to diff. @return diff: Dict of Key => (first.val, second.val) """ diff = {} # Check all keys in first dict for key in first.keys(): if key not in second: diff[key] = (first[key], None) elif (first[key] != second[key]): diff[key] = (first[key], second[key]) return diff @classmethod async def create(cls, **kwargs): instance = await super(MetaverModel, cls).create(**kwargs) instance.make_snapshot() return instance @classmethod def _init_from_db(cls, **kwargs): instance = super(MetaverModel, cls)._init_from_db(**kwargs) instance.make_snapshot() return instance @classmethod async def bulk_create(cls, objects, using_db = None) -> None: await super(MetaverModel, cls).bulk_create(objects, using_db=using_db) for obj in objects: obj.make_snapshot() async def save(self, using_db = None, update_fields = None, force=False): changed = self.changed if len(changed) == 0: return old_metaver = self.metaver self.metaver += 1 fileds = list(set(update_fields or ()) | set(changed) | set(('updated', 'metaver'))) condition_fields = [] if not force: condition_fields.append(('metaver', old_metaver)) db = using_db or self._meta.db executor = db.executor_class(model=self.__class__, db=db) if self._saved_in_db: ret = await executor.execute_update(self, update_fields, condition_fields=condition_fields) if ret == 0: raise Exception('model save failed.') else: ret = await executor.execute_insert(self) if ret == 0: raise Exception('model insert failed.') self._saved_in_db = True self.make_snapshot() class Report(MetaverModel): id = fields.IntField(pk=True) content = fields.JSONField() def __str__(self): return str(self.id) async def run(): await Tortoise.init( { "connections": { "default": { "engine": "tortoise.backends.asyncpg", "credentials": { "host": "127.0.0.1", "port": "54320", "user": "postgres", "password": "", "database": "missionflow", }, } }, "apps": {"models": {"models": ["__main__"], "default_connection": "default"}}, }, _create_db=False, ) await Tortoise.generate_schemas() report_data = {"foo": "bar"} report = await Report.create(content=report_data) print(report.id) report.content['hello'] = 'world' print(report.content) await report.save() del report.content['foo'] await report.save() print('\n>>>> test not modified update....') report = await Report.filter(id=report.id).first() await report.save() print('>>>> end.\n') print('\n>>>> test concurrent update....') r1 = await Report.filter(id=report.id).first() r2 = await Report.filter(id=report.id).first() r1.content['r1'] = 'value1' await r1.save() r2.content['r2'] = 'value2' try: await r2.save() except: await r2.save(force=True) r3 = await Report.filter(id=report.id).first() assert r3.content['r2'] == 'value2' print(report.id) #await Tortoise._drop_databases() if __name__ == "__main__": run_async(run())