Skip to content

Instantly share code, notes, and snippets.

@zeaphoo
Last active January 31, 2020 12:21
Show Gist options
  • Save zeaphoo/2c1612e4a45f88c84750603fec9007a1 to your computer and use it in GitHub Desktop.
Save zeaphoo/2c1612e4a45f88c84750603fec9007a1 to your computer and use it in GitHub Desktop.
Add optimistic locking and dirty-state tracking to a Tortoise ORM model
"""
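This example adds optimistic locking (a ``metaver`` version column) and snapshot-based dirty-state tracking to Tortoise ORM models on the asyncpg (PostgreSQL) backend.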
"""
from tortoise import Tortoise, fields, run_async
from tortoise.models import Model
from tortoise.backends.asyncpg.executor import AsyncpgExecutor
from tortoise.backends.asyncpg.client import AsyncpgDBClient
import logging
import copy
# Configure the root logger so DEBUG-and-above records are emitted.
logging.basicConfig(level="DEBUG")
class MetaAsyncpgExecutor(AsyncpgExecutor):
    """Asyncpg executor whose UPDATE statements accept extra WHERE conditions.

    ``MetaverModel.save`` uses this to implement optimistic locking: the UPDATE
    only matches when the row still holds the version the instance loaded.
    """

    def get_update_sql(self, update_fields, condition_keys=None) -> str:
        """
        Generate the UPDATE SQL for ``update_fields``; results are cached.

        :param update_fields: model field names to SET. Falsy means every field
            in ``fields_db_projection``. Primary-key fields are never SET.
        :param condition_keys: model field names ANDed into the WHERE clause
            after the primary-key match, in the given order.
        :return: SQL whose positional parameters are, in order: the SET values,
            the primary key, then one value per condition key.
        """
        condition_keys = list(condition_keys or ())
        meta = self.model._meta

        key = ",".join(update_fields) if update_fields else ""
        if condition_keys:
            # "|" never appears in a plain update-fields key. This keeps keys with
            # and without conditions distinct; previously fields "a,b" and
            # fields "a" + condition "b" both produced the key "a,b".
            key = "{}|{}".format(key, ",".join(condition_keys))
        if key in self.update_cache:
            return self.update_cache[key]

        table = meta.basetable
        query = self.db.query_class.update(table)
        count = 0
        for field in update_fields or meta.fields_db_projection.keys():
            db_field = meta.fields_db_projection[field]
            if meta.fields_map[field].pk:
                continue
            query = query.set(db_field, self.Parameter(count))
            count += 1
        query = query.where(table[meta.db_pk_field] == self.Parameter(count))
        count += 1
        for field in condition_keys:
            # Condition keys are model field names: translate them to the DB
            # column, falling back to the raw name if it is not a projected field.
            column = meta.fields_db_projection.get(field, field)
            query = query.where(table[column] == self.Parameter(count))
            count += 1

        sql = self.update_cache[key] = query.get_sql()
        return sql

    async def execute_update(self, instance, update_fields, condition_fields=None) -> int:
        """
        Run an UPDATE for ``instance`` restricted by ``condition_fields``.

        :param instance: model instance whose attribute values are written.
        :param update_fields: model field names to SET (falsy means all fields).
        :param condition_fields: iterable of ``(field_name, expected_value)``
            pairs. The row only matches when every column equals its expected
            value. Expected values are passed to the driver as-is (no
            ``to_db_value`` conversion).
        :return: the first element of ``db.execute_query``'s result, which
            callers treat as the affected-row count.
        """
        meta = self.model._meta
        values = [
            self.column_map[field](getattr(instance, field), instance)
            for field in update_fields or meta.fields_db_projection.keys()
            if not meta.fields_map[field].pk
        ]
        values.append(meta.pk.to_db_value(instance.pk, instance))

        condition_keys = []
        for field, expected in condition_fields or ():
            # A field may be both SET and checked (e.g. the version column). The
            # SET and WHERE clauses use separate parameters, so the condition
            # must be kept. Dropping it here would silently disable the lock.
            condition_keys.append(field)
            values.append(expected)

        sql = self.get_update_sql(update_fields, condition_keys)
        return (await self.db.execute_query(sql, values))[0]


# Make asyncpg connections build executors of the class above; MetaverModel.save
# obtains its executor through ``db.executor_class``.
AsyncpgDBClient.executor_class = MetaAsyncpgExecutor


class MetaverModel(Model):
    """Abstract model with an optimistic-lock version and dirty-state tracking.

    ``metaver`` is incremented on every save. An update only succeeds when the
    stored ``metaver`` still equals the value this instance last loaded or
    saved, unless ``force=True``. Changes are detected by diffing the current
    field values against a deep-copied snapshot.
    """

    metaver = fields.IntField(default=0)
    created = fields.DatetimeField(auto_now_add=True)
    updated = fields.DatetimeField(auto_now=True)

    class Meta:
        abstract = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Empty snapshot: every field counts as changed until the first save.
        self._snapshot_data = {}

    def make_snapshot(self):
        """Record a deep copy of every DB-backed field as the clean state.

        Deep copies ensure that in-place mutation (e.g. of a JSON dict) shows up
        as a change.
        """
        self._snapshot_data = {
            key: copy.deepcopy(getattr(self, key))
            for key in self._meta.fields_db_projection.keys()
        }

    @property
    def changed(self):
        """Keys view of the DB-backed field names that differ from the snapshot."""
        current = {key: getattr(self, key) for key in self._meta.fields_db_projection.keys()}
        # Instances that never got a snapshot are treated as fully changed,
        # matching the empty snapshot set in __init__.
        snapshot = getattr(self, "_snapshot_data", {})
        return self.dict_diff(current, snapshot).keys()

    def dict_diff(self, first, second):
        """
        Return the entries of ``first`` whose values differ from ``second``.

        Keys present only in ``second`` are ignored.

        :param first: dictionary with the current values.
        :param second: dictionary with the reference values.
        :return: dict mapping key -> (first value, second value). The second
            value is None when the key is missing from ``second``.
        """
        diff = {}
        for key, value in first.items():
            if key not in second:
                diff[key] = (value, None)
            elif value != second[key]:
                diff[key] = (value, second[key])
        return diff

    @classmethod
    async def create(cls, **kwargs):
        instance = await super().create(**kwargs)
        instance.make_snapshot()
        return instance

    @classmethod
    def _init_from_db(cls, **kwargs):
        instance = super()._init_from_db(**kwargs)
        # Rows loaded from the DB start out clean.
        instance.make_snapshot()
        return instance

    @classmethod
    async def bulk_create(cls, objects, using_db=None) -> None:
        await super().bulk_create(objects, using_db=using_db)
        for obj in objects:
            obj.make_snapshot()

    async def save(self, using_db=None, update_fields=None, force=False):
        """
        Persist the instance if any DB-backed field changed since the snapshot.

        :param using_db: connection to use; defaults to ``self._meta.db``.
        :param update_fields: extra field names to write on UPDATE, in addition
            to the changed fields. ``updated`` and ``metaver`` are always written.
        :param force: skip the version check and write unconditionally.
        :raises Exception: if the UPDATE matched no row (stale version or
            missing row), or the insert reported 0. ``metaver`` is restored to
            its previous value before the exception propagates.
        """
        changed = self.changed
        if not changed:
            return

        old_metaver = self.metaver
        self.metaver += 1
        # Sorted, so the same field set always maps to the same cached SQL.
        fields_to_save = sorted(
            set(update_fields or ()) | set(changed) | {"updated", "metaver"}
        )
        # NOTE(review): with force=True the new version is old_metaver + 1. That
        # may equal a version a concurrent writer already stored, so that
        # writer's stale instances would still pass the lock check afterwards.
        condition_fields = [] if force else [("metaver", old_metaver)]

        db = using_db or self._meta.db
        executor = db.executor_class(model=self.__class__, db=db)
        try:
            if self._saved_in_db:
                ret = await executor.execute_update(
                    self, fields_to_save, condition_fields=condition_fields
                )
                if ret == 0:
                    raise Exception('model save failed.')
            else:
                ret = await executor.execute_insert(self)
                if ret == 0:
                    raise Exception('model insert failed.')
        except BaseException:
            # Undo the in-memory bump. Otherwise a retried save would check
            # against old_metaver + 1 and could overwrite a concurrent writer
            # that already moved the row to exactly that version.
            self.metaver = old_metaver
            raise

        self._saved_in_db = True
        self.make_snapshot()
class Report(MetaverModel):
    """Example versioned model storing one JSON document per row."""

    id = fields.IntField(pk=True)
    content = fields.JSONField()

    def __str__(self):
        # Identify a report by its primary key.
        return "{}".format(self.id)
async def run():
    """Demo: create a Report, mutate it, then exercise a concurrent update.

    Connects to PostgreSQL at 127.0.0.1:54320 and expects the ``test-tortoise``
    database to exist already (``_create_db=False``).
    """
    await Tortoise.init(
        {
            "connections": {
                "default": {
                    "engine": "tortoise.backends.asyncpg",
                    "credentials": {
                        "host": "127.0.0.1",
                        "port": "54320",
                        "user": "postgres",
                        "password": "",
                        "database": "test-tortoise",
                    },
                }
            },
            "apps": {"models": {"models": ["__main__"], "default_connection": "default"}},
        },
        _create_db=False,
    )
    await Tortoise.generate_schemas()

    report_data = {"foo": "bar"}
    report = await Report.create(content=report_data)
    print(report.id)
    # In-place mutation of the JSON value is detected by the snapshot diff.
    report.content['hello'] = 'world'
    print(report.content)
    await report.save()
    del report.content['foo']
    await report.save()

    print('\n>>>> test not modified update....')
    report = await Report.filter(id=report.id).first()
    # Nothing changed since loading, so save() returns without writing.
    await report.save()
    print('>>>> end.\n')

    print('\n>>>> test concurrent update....')
    r1 = await Report.filter(id=report.id).first()
    r2 = await Report.filter(id=report.id).first()
    r1.content['r1'] = 'value1'
    await r1.save()
    r2.content['r2'] = 'value2'
    try:
        await r2.save()
    except Exception as exc:
        # r2 still carries the version r1 replaced, so the checked save is
        # rejected; overwrite deliberately with force=True.
        print('stale save rejected: {}'.format(exc))
        await r2.save(force=True)
    r3 = await Report.filter(id=report.id).first()
    assert r3.content['r2'] == 'value2'
    print(report.id)
    # To remove the test database afterwards: await Tortoise._drop_databases()


if __name__ == "__main__":
    run_async(run())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment