"""Benchmark SQLAlchemy strategies for bulk-creating Airflow TaskInstances.

For each batch size, a fresh manual DagRun of the serialized DAG ``fake_dag``
is committed, then up to ``number_of_tis`` TaskInstance rows are inserted for
it using the strategy selected by ``MODE``. Timings are appended to
``./log.txt``.
"""
import logging
import time

from airflow.models import DagRun, TaskInstance
from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils import timezone
from airflow.utils.db import create_session
from airflow.utils.types import DagRunType

logger = logging.getLogger(__name__)
out_hdlr = logging.FileHandler('./log.txt')
out_hdlr.setFormatter(logging.Formatter('%(asctime)s %(message)s'))
out_hdlr.setLevel(logging.INFO)
logger.addHandler(out_hdlr)
logger.setLevel(logging.INFO)

# Insertion strategy under test; must be one of _MODES.
MODE = 'bulk_insert_mappings'

# The strategies handled by create_tis_in_new_dag_run.
_MODES = ('unit-of-work', 'bulk_save_objects', 'bulk_insert_mappings')

# Batch sizes benchmarked by perf_tis_creation by default.
DEFAULT_TI_COUNTS = (1000, 3000, 5000, 10000, 15000, 20000, 25000)


def _ti_mapping(task, run_id):
    """Return the TaskInstance column dict used by the bulk_insert_mappings strategy."""
    return {
        'dag_id': task.dag_id,
        'task_id': task.task_id,
        'run_id': run_id,
        'pool': task.pool,
        'queue': task.queue,
        'pool_slots': task.pool_slots,
        'priority_weight': task.priority_weight_total,
        'run_as_user': task.run_as_user,
        'max_tries': task.retries,
        'executor_config': task.executor_config,
        'operator': task.task_type,
    }


def create_tis_in_new_dag_run(dag, run_id, number_of_tis, mode=None):
    """Insert TaskInstances for the first ``number_of_tis`` tasks of ``dag`` and time it.

    :param dag: DAG whose tasks (in ``dag.task_dict`` order) each get one TaskInstance.
    :param run_id: run_id of a DagRun of ``dag`` that is already committed.
    :param number_of_tis: maximum number of TaskInstances to create; fewer are
        created when the DAG has fewer tasks.
    :param mode: insertion strategy, one of ``_MODES``; defaults to ``MODE``.
    :return: ``(elapsed_seconds, success)``. ``success`` is False if the insert
        raised; the error is logged with its traceback.
    :raises ValueError: if ``mode`` is not a known strategy.
    """
    if mode is None:
        mode = MODE
    # Validate before timing so a typo fails loudly instead of "successfully"
    # inserting nothing.
    if mode not in _MODES:
        raise ValueError(f'Unknown mode {mode!r}; expected one of {_MODES}')

    tasks = list(dag.task_dict.values())[:number_of_tis]
    success = True
    t1 = time.monotonic()
    try:
        # NOTE(review): create_session presumably commits on normal exit, so the
        # measured time includes the commit — confirm against the Airflow version.
        with create_session() as session:
            if mode == 'unit-of-work':
                for task in tasks:
                    session.add(TaskInstance(task, run_id=run_id))
            elif mode == 'bulk_save_objects':
                session.bulk_save_objects(
                    [TaskInstance(task, run_id=run_id) for task in tasks]
                )
            else:  # 'bulk_insert_mappings'
                session.bulk_insert_mappings(
                    TaskInstance, [_ti_mapping(task, run_id) for task in tasks]
                )
            session.flush()
    except Exception:
        # Record the failure (with traceback) instead of re-raising, so one
        # failed batch is reported in the results rather than aborting the run.
        logger.exception('Failed to create %s tis for run %s', len(tasks), run_id)
        success = False
    t2 = time.monotonic()
    logger.info('Created %s tis. success?: %s, perf: %s', len(tasks), success, t2 - t1)
    return t2 - t1, success


def perf_tis_creation(dag, ti_counts=DEFAULT_TI_COUNTS, pause=5):
    """Benchmark TaskInstance creation for several batch sizes, one new DagRun each.

    :param dag: DAG to create runs and TaskInstances for.
    :param ti_counts: iterable of batch sizes to benchmark.
    :param pause: seconds to sleep after each batch.
    :return: dict mapping batch size -> ``(elapsed_seconds, success)``.
    """
    perf = {}
    for number_of_tis in ti_counts:
        run_id = DagRun.generate_run_id(DagRunType.MANUAL, timezone.utcnow())
        # Commit the DagRun in its own session *before* inserting TaskInstances
        # in another session: a merely session.add()-ed run is not visible to
        # that second session, and TIs are tied to the run via run_id.
        with create_session() as session:
            session.add(
                DagRun(dag_id=dag.dag_id, run_type=DagRunType.MANUAL, run_id=run_id)
            )
        perf[number_of_tis] = create_tis_in_new_dag_run(dag, run_id, number_of_tis)
        time.sleep(pause)
    return perf


if __name__ == '__main__':
    dag_id = 'fake_dag'
    dm = SerializedDagModel.get(dag_id)
    if dm is None:
        raise SystemExit(f'Serialized DAG {dag_id!r} not found in the metadata DB.')
    dag = dm.dag
    logger.info('%s', MODE)
    results = perf_tis_creation(dag)
    logger.info('Summary (n -> (seconds, success)): %s', results)