Skip to content

Instantly share code, notes, and snippets.

@zeaphoo
Last active January 31, 2020 12:21
Show Gist options
  • Select an option

  • Save zeaphoo/2c1612e4a45f88c84750603fec9007a1 to your computer and use it in GitHub Desktop.

Select an option

Save zeaphoo/2c1612e4a45f88c84750603fec9007a1 to your computer and use it in GitHub Desktop.

Revisions

  1. zeaphoo revised this gist Jan 31, 2020. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion metaver_model.py
    Original file line number Diff line number Diff line change
    @@ -164,7 +164,7 @@ async def run():
    "port": "54320",
    "user": "postgres",
    "password": "",
    "database": "missionflow",
    "database": "test-tortoise",
    },
    }
    },
  2. zeaphoo created this gist Jan 31, 2020.
    208 changes: 208 additions & 0 deletions metaver_model.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,208 @@
    """
    This example showcases postgres features
    """
    from tortoise import Tortoise, fields, run_async
    from tortoise.models import Model
    from tortoise.backends.asyncpg.executor import AsyncpgExecutor
    from tortoise.backends.asyncpg.client import AsyncpgDBClient
    import logging
    import copy

    logging.basicConfig(level=logging.DEBUG)

    class MetaAsyncpgExecutor(AsyncpgExecutor):
    def get_update_sql(self, update_fields, condition_keys=[]) -> str:
    """
    Generates the SQL for updating a model depending on provided update_fields.
    Result is cached for performance.
    """
    key = ",".join(update_fields) if update_fields else ""
    if len(condition_keys) > 0:
    key = "{},{}".format(key, ','.join(condition_keys))
    if key in self.update_cache:
    return self.update_cache[key]

    table = self.model._meta.basetable
    query = self.db.query_class.update(table)
    count = 0
    for field in update_fields or self.model._meta.fields_db_projection.keys():
    db_field = self.model._meta.fields_db_projection[field]
    field_object = self.model._meta.fields_map[field]
    if not field_object.pk:
    query = query.set(db_field, self.Parameter(count))
    count += 1

    query = query.where(table[self.model._meta.db_pk_field] == self.Parameter(count))
    count += 1
    for k in condition_keys:
    query = query.where(table[k] == self.Parameter(count))
    count += 1

    sql = self.update_cache[key] = query.get_sql()
    return sql

    async def execute_update(self, instance, update_fields, condition_fields=[]) -> int:
    values = [
    self.column_map[field](getattr(instance, field), instance)
    for field in update_fields or self.model._meta.fields_db_projection.keys()
    if not self.model._meta.fields_map[field].pk
    ]
    values.append(self.model._meta.pk.to_db_value(instance.pk, instance))
    valid_condition_keys = []
    for k, v in condition_fields:
    if k not in (update_fields or []):
    valid_condition_keys.append(k)
    values.append(v)
    return (await self.db.execute_query(self.get_update_sql(update_fields, valid_condition_keys), values))[0]

    AsyncpgDBClient.executor_class = MetaAsyncpgExecutor

    class MetaverModel(Model):
    metaver = fields.IntField(default=0)
    created = fields.DatetimeField(auto_now_add=True)
    updated = fields.DatetimeField(auto_now=True)

    class Meta:
    abstract = True

    def __init__(self, **kwargs):
    super(MetaverModel, self).__init__(**kwargs)
    self._snapshot_data = {}


    def make_snapshot(self):
    new_data = dict()
    for key in self._meta.fields_db_projection.keys():
    new_data[key] = copy.deepcopy(getattr(self, key))
    self._snapshot_data = new_data

    @property
    def changed(self):
    now_data = dict()
    for key in self._meta.fields_db_projection.keys():
    now_data[key] = getattr(self, key)
    diff = self.dict_diff(now_data, self._snapshot_data)
    return diff.keys()

    def dict_diff(self, first, second):
    """ Return a dict of keys that differ with another config object. If a value is
    not found in one fo the configs, it will be represented by None.
    @param first: Fist dictionary to diff.
    @param second: Second dicationary to diff.
    @return diff: Dict of Key => (first.val, second.val)
    """
    diff = {}
    # Check all keys in first dict
    for key in first.keys():
    if key not in second:
    diff[key] = (first[key], None)
    elif (first[key] != second[key]):
    diff[key] = (first[key], second[key])
    return diff

    @classmethod
    async def create(cls, **kwargs):
    instance = await super(MetaverModel, cls).create(**kwargs)
    instance.make_snapshot()
    return instance

    @classmethod
    def _init_from_db(cls, **kwargs):
    instance = super(MetaverModel, cls)._init_from_db(**kwargs)
    instance.make_snapshot()
    return instance

    @classmethod
    async def bulk_create(cls, objects, using_db = None) -> None:
    await super(MetaverModel, cls).bulk_create(objects, using_db=using_db)
    for obj in objects:
    obj.make_snapshot()

    async def save(self, using_db = None, update_fields = None, force=False):
    changed = self.changed
    if len(changed) == 0:
    return

    old_metaver = self.metaver
    self.metaver += 1
    fileds = list(set(update_fields or ()) | set(changed) | set(('updated', 'metaver')))
    condition_fields = []
    if not force:
    condition_fields.append(('metaver', old_metaver))

    db = using_db or self._meta.db
    executor = db.executor_class(model=self.__class__, db=db)
    if self._saved_in_db:
    ret = await executor.execute_update(self, update_fields, condition_fields=condition_fields)
    if ret == 0:
    raise Exception('model save failed.')
    else:
    ret = await executor.execute_insert(self)
    if ret == 0:
    raise Exception('model insert failed.')
    self._saved_in_db = True
    self.make_snapshot()



    class Report(MetaverModel):
    id = fields.IntField(pk=True)
    content = fields.JSONField()

    def __str__(self):
    return str(self.id)


    async def run():
    await Tortoise.init(
    {
    "connections": {
    "default": {
    "engine": "tortoise.backends.asyncpg",
    "credentials": {
    "host": "127.0.0.1",
    "port": "54320",
    "user": "postgres",
    "password": "",
    "database": "missionflow",
    },
    }
    },
    "apps": {"models": {"models": ["__main__"], "default_connection": "default"}},
    },
    _create_db=False,
    )
    await Tortoise.generate_schemas()

    report_data = {"foo": "bar"}
    report = await Report.create(content=report_data)
    print(report.id)
    report.content['hello'] = 'world'
    print(report.content)
    await report.save()
    del report.content['foo']
    await report.save()
    print('\n>>>> test not modified update....')
    report = await Report.filter(id=report.id).first()
    await report.save()
    print('>>>> end.\n')

    print('\n>>>> test concurrent update....')
    r1 = await Report.filter(id=report.id).first()
    r2 = await Report.filter(id=report.id).first()
    r1.content['r1'] = 'value1'
    await r1.save()
    r2.content['r2'] = 'value2'
    try:
    await r2.save()
    except:
    await r2.save(force=True)
    r3 = await Report.filter(id=report.id).first()
    assert r3.content['r2'] == 'value2'
    print(report.id)

    #await Tortoise._drop_databases()


    if __name__ == "__main__":
    run_async(run())