Skip to content

Instantly share code, notes, and snippets.

@mahmoud
Last active March 10, 2021 08:51
Show Gist options
  • Save mahmoud/10f6b6b0a9c5860030693357124131df to your computer and use it in GitHub Desktop.
Save mahmoud/10f6b6b0a9c5860030693357124131df to your computer and use it in GitHub Desktop.

Revisions

  1. mahmoud revised this gist Mar 10, 2021. No changes.
  2. mahmoud revised this gist Mar 10, 2021. No changes.
  3. mahmoud created this gist Mar 10, 2021.
    268 changes: 268 additions & 0 deletions conftest.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,268 @@
    # -*- coding: utf-8 -*-
    """
    Common test fixtures for pytest
    """
    from __future__ import print_function, unicode_literals

    import os
    from itertools import groupby

    import pytest
    from django.test import TransactionTestCase, TestCase
    from django.test.testcases import connections_support_transactions
    from boltons.dictutils import OMD
    from boltons.fileutils import atomic_save
    from six import PY3

    from pytest_django.plugin import _blocking_manager, validate_django_db
    from django.db.backends.base.base import BaseDatabaseWrapper

    # Side effects on import. Can't be helped, we need to unblock all DB
    # accesses
    _blocking_manager.unblock()
    _blocking_manager._blocking_wrapper = BaseDatabaseWrapper.ensure_connection

    IS_TR_ENV = bool(PY3)

    TEST_PARTITION_OPTION = "--test_partition"
    TOTAL_TESTS_OPTION = "--total_test_partitions"


    def _patchedTearDownClass(cls):
    """
    Overrides tearDownClass from TestCase
    """
    if connections_support_transactions():
    cls._rollback_atomics(cls.cls_atomics)
    super(TransactionTestCase, cls).tearDownClass()


    TestCase.tearDownClass = classmethod(_patchedTearDownClass)


    @pytest.fixture(autouse=True)
    def enable_db_access_for_all_tests(db):
    """
    Force the test to provide DB access to all tests.
    We probably shouldn't do this, but we have no idea what tests use the DB
    and what doesn't at this time
    This also forces the default TestCase environment to be djangos TestCase
    with atomic transaction support:
    https://github.com/pytest-dev/pytest-django/blob/ce5d5bc0b29748ed411b9d683a33e1b13d98e17f/pytest_django/fixtures.py#L149
    """
    pass


    @pytest.fixture(scope='session')
    def _transaction_wrap(django_db_setup):
    # wraps the following client data setup in an exterior transaction separate
    # from the transaction around the individual tests themselves
    TestCase._enter_atomics()


    def pytest_addoption(parser):
    parser.addoption(TEST_PARTITION_OPTION, type=int, default=os.getenv('TEST_PARTITION') or 0)
    parser.addoption(TOTAL_TESTS_OPTION, type=int, default=os.getenv('TOTAL_TEST_PARTITIONS', 1))
    parser.addoption('--tr-regen-state', action="store_true", default=False,
    help="regenerate list of skipped/failing tests into a new skipfile, for tech refresh")
    parser.addoption('--tr-no-autoskip', action='store_true', default=False,
    help='do not autoskip known failing tests based on tr_test_state.txt')
    parser.addoption('--tr-recheck', action='store_true', default=False,
    help='update TR state file for any passing tests. meant to be used with'
    ' test pattern filtering and implies --tr-no-autoskip')
    return


    def pytest_collection_modifyitems(config, items):
    def get_marker_transaction(test):
    marker = test.get_closest_marker('django_db')
    if marker:
    transaction, _ = validate_django_db(marker)
    return transaction

    return None

    def has_fixture(test, fixture):
    funcargnames = getattr(test, 'funcargnames', None)
    return funcargnames and fixture in funcargnames

    def run_transaction_test_cases_after_all_other_tests(test):
    """
    Detect if a test case is marked as a transaction test case, and
    if so, make sure to run it last since transaction test cases
    truncate the database (and thus leave no data for "non-transaction"
    test cases to act on.)
    Part of the teardown for djangos TransactionTestCase does this:
    https://github.com/django/django/blob/b61ea56789a5825bd2961a335cb82f65e09f1614/django/test/testcases.py#L1000
    """
    is_test_case_subclass = getattr(
    test, 'cls', None) and issubclass(test.cls, TestCase)
    is_transaction_test_case_subclass = getattr(
    test, 'cls', None) and issubclass(test.cls, TransactionTestCase)

    if is_test_case_subclass or get_marker_transaction(test) is False:
    return 0
    elif is_transaction_test_case_subclass or get_marker_transaction(test) is True:
    return 1
    elif has_fixture(test, 'transactional_db') or has_fixture(test, 'live_server'):
    # live_server uses transactional_db. So same truncation.
    return 1
    elif has_fixture(test, 'db'):
    return 0
    return 0

    def sort_by_app_and_test_folder_name(test):
    if test.cls:
    return test.cls.__module__
    elif test.function:
    return test.function.__module__

    def group_by_test_case(test):
    if test.cls:
    return test.cls.__name__
    else:
    return ""

    key_funcs = (run_transaction_test_cases_after_all_other_tests,
    sort_by_app_and_test_folder_name,
    group_by_test_case,)

    for key_func in key_funcs:
    items.sort(key=key_func)

    TEST_NUMBER = config.getoption(TEST_PARTITION_OPTION)
    TOTAL_TESTS = config.getoption(TOTAL_TESTS_OPTION)

    # assign each group of test cases a number
    temp_new_list = []
    for index, (group_by_name, tests) in enumerate(groupby(items, group_by_test_case)):
    if group_by_name == "":
    for inner_index, test in enumerate(tests):
    if inner_index % TOTAL_TESTS == TEST_NUMBER:
    temp_new_list.append(test)

    else:
    if index % TOTAL_TESTS == TEST_NUMBER:
    for test in tests:
    temp_new_list.append(test)

    items[:] = temp_new_list

    if not config.getoption("--needs-isolation"):
    skip_needs_isolation = pytest.mark.skip(
    reason="need --needs-isolation option to run")
    for item in items:
    if "needs_isolation" in item.keywords:
    item.add_marker(skip_needs_isolation)

    skip_tr_tests = not (config.option.tr_recheck or config.option.tr_no_autoskip or config.option.tr_regen_state)
    if IS_TR_ENV and skip_tr_tests:
    trf_path = config.rootdir + '/tr_test_state.txt'
    try:
    trf = TestResultFile.from_path(trf_path)
    except OSError:
    pass
    else:
    skip_tr_failure = pytest.mark.xfail(
    reason='known failure related to tech refresh')
    fns = set(trf.get_failing_nodeids())
    for item in items:
    if item.nodeid in fns:
    item.add_marker(skip_tr_failure)
    return

    _all_reports = OMD()

    @pytest.hookimpl(tryfirst=True)
    def pytest_runtest_logreport(report):
    # this function runs 3x for each test: setup, call, teardown.
    # this approach ensures that if any phase fails, this test stays marked as failed
    if _all_reports.get(report.nodeid) != 'failed':
    _all_reports.add(report.nodeid, report.outcome)
    return


    @pytest.fixture(scope="session", autouse=True)
    def tr_state_save(request):
    session = request.node
    config = session.config
    if config.option.tr_regen_state:
    if config.option.tr_recheck:
    raise SystemExit('--tr-regen-state is mutually exclusive with --tr-recheck')
    args = [arg for arg in config.args if arg]
    if config.option.keyword:
    raise SystemExit('refusing to regenerate TR state while running a subset of tests')

    yield

    if not IS_TR_ENV or not _all_reports:
    return # not tech refreshing / process aborted

    results = _all_reports.items()
    path = config.rootdir + '/tr_test_state.txt'

    if config.option.tr_regen_state:
    if len(_all_reports) != session.testscollected:
    raise SystemExit('refusing to regenerate TR state with incomplete test run')
    new_trf = TestResultFile(path, results)
    new_trf.save()
    if config.option.tr_recheck:
    try:
    trf = TestResultFile.from_path(path)
    except OSError:
    print('no existing test result file at %r, nothing to recheck against' % path)

    trf.update(results)
    trf.save()
    return


    class TestResultFile(object):
    def __init__(self, path, results, intro_lines=()):
    self.results = OMD(sorted(results))
    self.path = path
    self.intro_lines = intro_lines

    def get_failing_nodeids(self):
    return [nodeid for nodeid, res in self.results.items() if res == 'failed']

    def update(self, new_results):
    for nodeid, outcome in new_results:
    self.results.add(nodeid, outcome)
    self.results = self.results.sorted()

    @classmethod
    def from_path(cls, path):
    with open(path) as f:
    contents = f.read()
    contents_lines = contents.splitlines()
    intro_lines = []
    results = []
    intro_done = False
    for line in contents_lines:
    line = line.strip()
    if not line:
    continue
    if not intro_done and line.startswith('#'):
    intro_lines.append(line[2:] if line.startswith('# ') else line[1:])
    else:
    intro_done = True
    result, _, nodeid = line.partition(' - ')
    results.append((nodeid, result))
    return cls(path, results, intro_lines=intro_lines)

    def save(self):
    lines = []
    for line in self.intro_lines:
    lines.append('# %s\n' % line)
    for nodeid, result in self.results.items():
    lines.append('%s - %s\n' % (result, nodeid))

    with atomic_save(self.path) as f:
    f.writelines([line.encode('utf8') for line in lines])

    return