Last active
March 10, 2021 08:51
-
-
Save mahmoud/10f6b6b0a9c5860030693357124131df to your computer and use it in GitHub Desktop.
Revisions
-
mahmoud revised this gist
Mar 10, 2021 . No changes.There are no files selected for viewing
-
mahmoud revised this gist
Mar 10, 2021 . No changes.There are no files selected for viewing
-
mahmoud created this gist
Mar 10, 2021 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,268 @@ # -*- coding: utf-8 -*- """ Common test fixtures for pytest """ from __future__ import print_function, unicode_literals import os from itertools import groupby import pytest from django.test import TransactionTestCase, TestCase from django.test.testcases import connections_support_transactions from boltons.dictutils import OMD from boltons.fileutils import atomic_save from six import PY3 from pytest_django.plugin import _blocking_manager, validate_django_db from django.db.backends.base.base import BaseDatabaseWrapper # Side effects on import. Can't be helped, we need to unblock all DB # accesses _blocking_manager.unblock() _blocking_manager._blocking_wrapper = BaseDatabaseWrapper.ensure_connection IS_TR_ENV = bool(PY3) TEST_PARTITION_OPTION = "--test_partition" TOTAL_TESTS_OPTION = "--total_test_partitions" def _patchedTearDownClass(cls): """ Overrides tearDownClass from TestCase """ if connections_support_transactions(): cls._rollback_atomics(cls.cls_atomics) super(TransactionTestCase, cls).tearDownClass() TestCase.tearDownClass = classmethod(_patchedTearDownClass) @pytest.fixture(autouse=True) def enable_db_access_for_all_tests(db): """ Force the test to provide DB access to all tests. We probably shouldn't do this, but we have no idea what tests use the DB and what doesn't at this time This also forces the default TestCase environment to be djangos TestCase with atomic transaction support: https://github.com/pytest-dev/pytest-django/blob/ce5d5bc0b29748ed411b9d683a33e1b13d98e17f/pytest_django/fixtures.py#L149 """ pass @pytest.fixture(scope='session') def _transaction_wrap(django_db_setup): # wraps the following client data setup in an exterior transaction separate # from the transaction around the individual tests themselves TestCase._enter_atomics() def pytest_addoption(parser): parser.addoption(TEST_PARTITION_OPTION, type=int, default=os.getenv('TEST_PARTITION') or 0) parser.addoption(TOTAL_TESTS_OPTION, type=int, default=os.getenv('TOTAL_TEST_PARTITIONS', 1)) parser.addoption('--tr-regen-state', action="store_true", default=False, help="regenerate list of skipped/failing tests into a new skipfile, for tech refresh") parser.addoption('--tr-no-autoskip', action='store_true', default=False, help='do not autoskip known failing tests based on tr_test_state.txt') parser.addoption('--tr-recheck', action='store_true', default=False, help='update TR state file for any passing tests. meant to be used with' ' test pattern filtering and implies --tr-no-autoskip') return def pytest_collection_modifyitems(config, items): def get_marker_transaction(test): marker = test.get_closest_marker('django_db') if marker: transaction, _ = validate_django_db(marker) return transaction return None def has_fixture(test, fixture): funcargnames = getattr(test, 'funcargnames', None) return funcargnames and fixture in funcargnames def run_transaction_test_cases_after_all_other_tests(test): """ Detect if a test case is marked as a transaction test case, and if so, make sure to run it last since transaction test cases truncate the database (and thus leave no data for "non-transaction" test cases to act on.) Part of the teardown for djangos TransactionTestCase does this: https://github.com/django/django/blob/b61ea56789a5825bd2961a335cb82f65e09f1614/django/test/testcases.py#L1000 """ is_test_case_subclass = getattr( test, 'cls', None) and issubclass(test.cls, TestCase) is_transaction_test_case_subclass = getattr( test, 'cls', None) and issubclass(test.cls, TransactionTestCase) if is_test_case_subclass or get_marker_transaction(test) is False: return 0 elif is_transaction_test_case_subclass or get_marker_transaction(test) is True: return 1 elif has_fixture(test, 'transactional_db') or has_fixture(test, 'live_server'): # live_server uses transactional_db. So same truncation. return 1 elif has_fixture(test, 'db'): return 0 return 0 def sort_by_app_and_test_folder_name(test): if test.cls: return test.cls.__module__ elif test.function: return test.function.__module__ def group_by_test_case(test): if test.cls: return test.cls.__name__ else: return "" key_funcs = (run_transaction_test_cases_after_all_other_tests, sort_by_app_and_test_folder_name, group_by_test_case,) for key_func in key_funcs: items.sort(key=key_func) TEST_NUMBER = config.getoption(TEST_PARTITION_OPTION) TOTAL_TESTS = config.getoption(TOTAL_TESTS_OPTION) # assign each group of test cases a number temp_new_list = [] for index, (group_by_name, tests) in enumerate(groupby(items, group_by_test_case)): if group_by_name == "": for inner_index, test in enumerate(tests): if inner_index % TOTAL_TESTS == TEST_NUMBER: temp_new_list.append(test) else: if index % TOTAL_TESTS == TEST_NUMBER: for test in tests: temp_new_list.append(test) items[:] = temp_new_list if not config.getoption("--needs-isolation"): skip_needs_isolation = pytest.mark.skip( reason="need --needs-isolation option to run") for item in items: if "needs_isolation" in item.keywords: item.add_marker(skip_needs_isolation) skip_tr_tests = not (config.option.tr_recheck or config.option.tr_no_autoskip or config.option.tr_regen_state) if IS_TR_ENV and skip_tr_tests: trf_path = config.rootdir + '/tr_test_state.txt' try: trf = TestResultFile.from_path(trf_path) except OSError: pass else: skip_tr_failure = pytest.mark.xfail( reason='known failure related to tech refresh') fns = set(trf.get_failing_nodeids()) for item in items: if item.nodeid in fns: item.add_marker(skip_tr_failure) return _all_reports = OMD() @pytest.hookimpl(tryfirst=True) def pytest_runtest_logreport(report): # this function runs 3x for each test: setup, call, teardown. # this approach ensures that if any phase fails, this test stays marked as failed if _all_reports.get(report.nodeid) != 'failed': _all_reports.add(report.nodeid, report.outcome) return @pytest.fixture(scope="session", autouse=True) def tr_state_save(request): session = request.node config = session.config if config.option.tr_regen_state: if config.option.tr_recheck: raise SystemExit('--tr-regen-state is mutually exclusive with --tr-recheck') args = [arg for arg in config.args if arg] if config.option.keyword: raise SystemExit('refusing to regenerate TR state while running a subset of tests') yield if not IS_TR_ENV or not _all_reports: return # not tech refreshing / process aborted results = _all_reports.items() path = config.rootdir + '/tr_test_state.txt' if config.option.tr_regen_state: if len(_all_reports) != session.testscollected: raise SystemExit('refusing to regenerate TR state with incomplete test run') new_trf = TestResultFile(path, results) new_trf.save() if config.option.tr_recheck: try: trf = TestResultFile.from_path(path) except OSError: print('no existing test result file at %r, nothing to recheck against' % path) trf.update(results) trf.save() return class TestResultFile(object): def __init__(self, path, results, intro_lines=()): self.results = OMD(sorted(results)) self.path = path self.intro_lines = intro_lines def get_failing_nodeids(self): return [nodeid for nodeid, res in self.results.items() if res == 'failed'] def update(self, new_results): for nodeid, outcome in new_results: self.results.add(nodeid, outcome) self.results = self.results.sorted() @classmethod def from_path(cls, path): with open(path) as f: contents = f.read() contents_lines = contents.splitlines() intro_lines = [] results = [] intro_done = False for line in contents_lines: line = line.strip() if not line: continue if not intro_done and line.startswith('#'): intro_lines.append(line[2:] if line.startswith('# ') else line[1:]) else: intro_done = True result, _, nodeid = line.partition(' - ') results.append((nodeid, result)) return cls(path, results, intro_lines=intro_lines) def save(self): lines = [] for line in self.intro_lines: lines.append('# %s\n' % line) for nodeid, result in self.results.items(): lines.append('%s - %s\n' % (result, nodeid)) with atomic_save(self.path) as f: f.writelines([line.encode('utf8') for line in lines]) return