""" This example showcases postgres features """ from tortoise import Tortoise, fields, run_async from tortoise.models import Model from tortoise.backends.asyncpg.executor import AsyncpgExecutor from tortoise.backends.asyncpg.client import AsyncpgDBClient import logging import copy logging.basicConfig(level=logging.DEBUG) class MetaAsyncpgExecutor(AsyncpgExecutor): def get_update_sql(self, update_fields, condition_keys=[]) -> str: """ Generates the SQL for updating a model depending on provided update_fields. Result is cached for performance. """ key = ",".join(update_fields) if update_fields else "" if len(condition_keys) > 0: key = "{},{}".format(key, ','.join(condition_keys)) if key in self.update_cache: return self.update_cache[key] table = self.model._meta.basetable query = self.db.query_class.update(table) count = 0 for field in update_fields or self.model._meta.fields_db_projection.keys(): db_field = self.model._meta.fields_db_projection[field] field_object = self.model._meta.fields_map[field] if not field_object.pk: query = query.set(db_field, self.Parameter(count)) count += 1 query = query.where(table[self.model._meta.db_pk_field] == self.Parameter(count)) count += 1 for k in condition_keys: query = query.where(table[k] == self.Parameter(count)) count += 1 sql = self.update_cache[key] = query.get_sql() return sql async def execute_update(self, instance, update_fields, condition_fields=[]) -> int: values = [ self.column_map[field](getattr(instance, field), instance) for field in update_fields or self.model._meta.fields_db_projection.keys() if not self.model._meta.fields_map[field].pk ] values.append(self.model._meta.pk.to_db_value(instance.pk, instance)) valid_condition_keys = [] for k, v in condition_fields: if k not in (update_fields or []): valid_condition_keys.append(k) values.append(v) return (await self.db.execute_query(self.get_update_sql(update_fields, valid_condition_keys), values))[0] AsyncpgDBClient.executor_class = MetaAsyncpgExecutor class MetaverModel(Model): metaver = fields.IntField(default=0) created = fields.DatetimeField(auto_now_add=True) updated = fields.DatetimeField(auto_now=True) class Meta: abstract = True def __init__(self, **kwargs): super(MetaverModel, self).__init__(**kwargs) self._snapshot_data = {} def make_snapshot(self): new_data = dict() for key in self._meta.fields_db_projection.keys(): new_data[key] = copy.deepcopy(getattr(self, key)) self._snapshot_data = new_data @property def changed(self): now_data = dict() for key in self._meta.fields_db_projection.keys(): now_data[key] = getattr(self, key) diff = self.dict_diff(now_data, self._snapshot_data) return diff.keys() def dict_diff(self, first, second): """ Return a dict of keys that differ with another config object. If a value is not found in one fo the configs, it will be represented by None. @param first: Fist dictionary to diff. @param second: Second dicationary to diff. @return diff: Dict of Key => (first.val, second.val) """ diff = {} # Check all keys in first dict for key in first.keys(): if key not in second: diff[key] = (first[key], None) elif (first[key] != second[key]): diff[key] = (first[key], second[key]) return diff @classmethod async def create(cls, **kwargs): instance = await super(MetaverModel, cls).create(**kwargs) instance.make_snapshot() return instance @classmethod def _init_from_db(cls, **kwargs): instance = super(MetaverModel, cls)._init_from_db(**kwargs) instance.make_snapshot() return instance @classmethod async def bulk_create(cls, objects, using_db = None) -> None: await super(MetaverModel, cls).bulk_create(objects, using_db=using_db) for obj in objects: obj.make_snapshot() async def save(self, using_db = None, update_fields = None, force=False): changed = self.changed if len(changed) == 0: return old_metaver = self.metaver self.metaver += 1 fileds = list(set(update_fields or ()) | set(changed) | set(('updated', 'metaver'))) condition_fields = [] if not force: condition_fields.append(('metaver', old_metaver)) db = using_db or self._meta.db executor = db.executor_class(model=self.__class__, db=db) if self._saved_in_db: ret = await executor.execute_update(self, update_fields, condition_fields=condition_fields) if ret == 0: raise Exception('model save failed.') else: ret = await executor.execute_insert(self) if ret == 0: raise Exception('model insert failed.') self._saved_in_db = True self.make_snapshot() class Report(MetaverModel): id = fields.IntField(pk=True) content = fields.JSONField() def __str__(self): return str(self.id) async def run(): await Tortoise.init( { "connections": { "default": { "engine": "tortoise.backends.asyncpg", "credentials": { "host": "127.0.0.1", "port": "54320", "user": "postgres", "password": "", "database": "test-tortoise", }, } }, "apps": {"models": {"models": ["__main__"], "default_connection": "default"}}, }, _create_db=False, ) await Tortoise.generate_schemas() report_data = {"foo": "bar"} report = await Report.create(content=report_data) print(report.id) report.content['hello'] = 'world' print(report.content) await report.save() del report.content['foo'] await report.save() print('\n>>>> test not modified update....') report = await Report.filter(id=report.id).first() await report.save() print('>>>> end.\n') print('\n>>>> test concurrent update....') r1 = await Report.filter(id=report.id).first() r2 = await Report.filter(id=report.id).first() r1.content['r1'] = 'value1' await r1.save() r2.content['r2'] = 'value2' try: await r2.save() except: await r2.save(force=True) r3 = await Report.filter(id=report.id).first() assert r3.content['r2'] == 'value2' print(report.id) #await Tortoise._drop_databases() if __name__ == "__main__": run_async(run())