diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f2dbbdd..505b3ad 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,7 +44,8 @@ jobs: - uses: ./.github/actions/python-poetry-env with: python-version: ${{ matrix.python-version }} - - run: poetry run pytest + - run: poetry run coverage run -m pytest + - run: poetry run coverage report docs: runs-on: ubuntu-latest diff --git a/CHANGELOG.md b/CHANGELOG.md index 1391dff..1f1e137 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Changed +- `AlgorithmBase.__call__` now takes a single `durations: dict[Item, float]` + argument instead of separate `items` and `durations` arguments. Custom + algorithm subclasses must update their signature. Use the new public + `pytest_split.algorithms.compute_durations(items, cached_durations)` helper to + build the dict the same way the plugin does. +- Algorithms now own only group membership; the order of `selected` items in + the returned `TestGroup`s is implementation-defined. The plugin rebuilds the + chosen group's `selected` and `deselected` lists in pytest's collection + order before the test session executes, so end-to-end behaviour is + unchanged. 
+ ### Fixed - Fix malformed bullet points rendering in GitHub Pages documentation diff --git a/pyproject.toml b/pyproject.toml index be14d7d..164e854 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,16 +63,12 @@ pytest-split = "pytest_split.plugin" target-version = ["py310", "py311", "py312", "py313", "py314"] include = '\.pyi?$' -[tool.pytest.ini_options] -addopts = """\ - --cov pytest_split \ - --cov tests \ - --cov-report term-missing \ - --no-cov-on-fail \ -""" +[tool.coverage.run] +source = ["pytest_split"] [tool.coverage.report] -fail_under = 90 +fail_under = 95 +show_missing = true exclude_lines = [ 'if TYPE_CHECKING:', 'pragma: no cover' diff --git a/src/pytest_split/algorithms.py b/src/pytest_split/algorithms.py index 83223b3..27c9227 100644 --- a/src/pytest_split/algorithms.py +++ b/src/pytest_split/algorithms.py @@ -1,7 +1,6 @@ import enum import heapq from abc import ABC, abstractmethod -from operator import itemgetter from typing import TYPE_CHECKING, NamedTuple if TYPE_CHECKING: @@ -19,7 +18,7 @@ class AlgorithmBase(ABC): @abstractmethod def __call__( - self, splits: int, items: "list[nodes.Item]", durations: "dict[str, float]" + self, splits: int, durations: "dict[nodes.Item, float]" ) -> "list[TestGroup]": pass @@ -42,47 +41,42 @@ class LeastDurationAlgorithm(AlgorithmBase): maintaining the original order of items. It is therefore important that the order of items be identical on all nodes that use this plugin. Due to issue #25 this might not always be the case. + The order of ``selected`` items in each returned group is implementation-defined; the plugin reorders the chosen + group in pytest's collection order before execution. + :param splits: How many groups we're splitting in. - :param items: Test items passed down by Pytest. - :param durations: Our cached test runtimes. Assumes contains timings only of relevant tests + :param durations: Mapping from each test item to its duration. Build it with :func:`compute_durations`. 
:return: List of groups """ def __call__( - self, splits: int, items: "list[nodes.Item]", durations: "dict[str, float]" + self, splits: int, durations: "dict[nodes.Item, float]" ) -> "list[TestGroup]": - items_with_durations = _get_items_with_durations(items, durations) - - # add index of item in list - items_with_durations_indexed = [ - (*tup, i) for i, tup in enumerate(items_with_durations) - ] - # Sort by name to ensure it's always the same order - items_with_durations_indexed = sorted( - items_with_durations_indexed, key=lambda tup: str(tup[0]) + items_with_durations = sorted( + durations.items(), key=lambda tup: tup[0].nodeid ) # sort in ascending order sorted_items_with_durations = sorted( - items_with_durations_indexed, key=lambda tup: tup[1], reverse=True + items_with_durations, key=lambda tup: tup[1], reverse=True ) - selected: list[list[tuple[nodes.Item, int]]] = [[] for _ in range(splits)] + selected: list[list[nodes.Item]] = [[] for _ in range(splits)] deselected: list[list[nodes.Item]] = [[] for _ in range(splits)] duration: list[float] = [0 for _ in range(splits)] # create a heap of the form (summed_durations, group_index) heap: list[tuple[float, int]] = [(0, i) for i in range(splits)] heapq.heapify(heap) - for item, item_duration, original_index in sorted_items_with_durations: + for item, item_duration in sorted_items_with_durations: # get group with smallest sum summed_durations, group_idx = heapq.heappop(heap) new_group_durations = summed_durations + item_duration # store assignment - selected[group_idx].append((item, original_index)) + selected[group_idx].append(item) duration[group_idx] = new_group_durations for i in range(splits): if i != group_idx: @@ -91,19 +85,12 @@ def __call__( # store new duration - in case of ties it sorts by the group_idx heapq.heappush(heap, (new_group_durations, group_idx)) - groups = [] - for i in range(splits): - # sort the items by their original index to maintain relative ordering - # we don't care about the order 
of deselected items - s = [ - item - for item, original_index in sorted(selected[i], key=lambda tup: tup[1]) - ] - group = TestGroup( - selected=s, deselected=deselected[i], duration=duration[i] + return [ + TestGroup( + selected=selected[i], deselected=deselected[i], duration=duration[i] ) - groups.append(group) - return groups + for i in range(splits) + ] class DurationBasedChunksAlgorithm(AlgorithmBase): @@ -114,23 +101,21 @@ class DurationBasedChunksAlgorithm(AlgorithmBase): and creating group_1 = items[0:i_0], group_2 = items[i_0, i_1], group_3 = items[i_1, i_2], ... :param splits: How many groups we're splitting in. - :param items: Test items passed down by Pytest. - :param durations: Our cached test runtimes. Assumes contains timings only of relevant tests + :param durations: Mapping from each test item to its duration. Build it with :func:`compute_durations`. :return: List of TestGroup """ def __call__( - self, splits: int, items: "list[nodes.Item]", durations: "dict[str, float]" + self, splits: int, durations: "dict[nodes.Item, float]" ) -> "list[TestGroup]": - items_with_durations = _get_items_with_durations(items, durations) - time_per_group = sum(map(itemgetter(1), items_with_durations)) / splits + time_per_group = sum(durations.values()) / splits selected: list[list[nodes.Item]] = [[] for i in range(splits)] deselected: list[list[nodes.Item]] = [[] for i in range(splits)] duration: list[float] = [0 for i in range(splits)] group_idx = 0 - for item, item_duration in items_with_durations: + for item, item_duration in durations.items(): if duration[group_idx] >= time_per_group: group_idx += 1 @@ -148,33 +133,43 @@ def __call__( ] -def _get_items_with_durations( - items: "list[nodes.Item]", durations: "dict[str, float]" -) -> "list[tuple[nodes.Item, float]]": - durations = _remove_irrelevant_durations(items, durations) - avg_duration_per_test = _get_avg_duration_per_test(durations) - items_with_durations = [ - (item, durations.get(item.nodeid, 
avg_duration_per_test)) for item in items - ] - return items_with_durations - +def compute_durations( + items: "list[nodes.Item]", cached_durations: "dict[str, float]" +) -> "dict[nodes.Item, float]": + """ + Build the splitting input from collected items and their cached durations. -def _get_avg_duration_per_test(durations: "dict[str, float]") -> float: - if durations: - avg_duration_per_test = sum(durations.values()) / len(durations) + Items missing from ``cached_durations`` get the average duration of the + cached entries that are relevant to this suite; with no relevant cached + data (or no cache at all), every item gets ``1`` as a placeholder. + """ + # Filtering down durations to relevant ones ensures the avg isn't skewed by irrelevant data + relevant = { + item.nodeid: cached_durations[item.nodeid] + for item in items + if item.nodeid in cached_durations + } + if relevant: + avg = sum(relevant.values()) / len(relevant) else: # If there are no durations, give every test the same arbitrary value - avg_duration_per_test = 1 - return avg_duration_per_test + avg = 1 + return {item: relevant.get(item.nodeid, avg) for item in items} -def _remove_irrelevant_durations( - items: "list[nodes.Item]", durations: "dict[str, float]" -) -> "dict[str, float]": - # Filtering down durations to relevant ones ensures the avg isn't skewed by irrelevant data - test_ids = [item.nodeid for item in items] - durations = {name: durations[name] for name in test_ids if name in durations} - return durations +def select_in_collection_order( + group: TestGroup, items: "list[nodes.Item]" +) -> TestGroup: + """ + Rebuild ``group`` so that ``selected`` and ``deselected`` filter + ``items`` in their original collection order, keyed on nodeid. 
+ """ + selected_ids = {it.nodeid for it in group.selected} + return TestGroup( + selected=[it for it in items if it.nodeid in selected_ids], + deselected=[it for it in items if it.nodeid not in selected_ids], + duration=group.duration, + ) class Algorithms(enum.Enum): diff --git a/src/pytest_split/plugin.py b/src/pytest_split/plugin.py index 9140936..ace9abc 100644 --- a/src/pytest_split/plugin.py +++ b/src/pytest_split/plugin.py @@ -160,8 +160,9 @@ def pytest_collection_modifyitems( group_idx: int = config.option.group algo = algorithms.Algorithms[config.option.splitting_algorithm].value - groups = algo(splits, items, self.cached_durations) - group = groups[group_idx - 1] + durations = algorithms.compute_durations(items, self.cached_durations) + groups = algo(splits, durations) + group = algorithms.select_in_collection_order(groups[group_idx - 1], items) ensure_ipynb_compatibility(group, items) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index a02b6ed..20b30c3 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -10,6 +10,9 @@ from pytest_split.algorithms import ( AlgorithmBase, Algorithms, + TestGroup, + compute_durations, + select_in_collection_order, ) item = namedtuple("item", "nodeid") # noqa: PYI024 @@ -21,20 +24,27 @@ def test__split_test(self, algo_name): durations = {"a": 1, "b": 1, "c": 1} items = [item(x) for x in durations] algo = Algorithms[algo_name].value - first, second, third = algo(splits=3, items=items, durations=durations) - # each split should have one test - assert first.selected == [item("a")] - assert first.deselected == [item("b"), item("c")] - assert first.duration == 1 - - assert second.selected == [item("b")] - assert second.deselected == [item("a"), item("c")] - assert second.duration == 1 + first, second, third = algo( + splits=3, durations=compute_durations(items, durations) + ) - assert third.selected == [item("c")] - assert third.deselected == [item("a"), item("b")] - assert third.duration 
== 1 + # each split should have one test + assert first == TestGroup( + selected=[item("a")], + deselected=[item("b"), item("c")], + duration=1, + ) + assert second == TestGroup( + selected=[item("b")], + deselected=[item("a"), item("c")], + duration=1, + ) + assert third == TestGroup( + selected=[item("c")], + deselected=[item("a"), item("b")], + duration=1, + ) @pytest.mark.parametrize("algo_name", Algorithms.names()) def test__split_tests_handles_tests_in_durations_but_missing_from_items( @@ -43,39 +53,83 @@ def test__split_tests_handles_tests_in_durations_but_missing_from_items( durations = {"a": 1, "b": 1} items = [item(x) for x in ["a"]] algo = Algorithms[algo_name].value - splits = algo(splits=2, items=items, durations=durations) - first, second = splits - assert first.selected == [item("a")] - assert second.selected == [] + first, second = algo(splits=2, durations=compute_durations(items, durations)) + + assert first == TestGroup( + selected=[item("a")], deselected=[], duration=1 + ) + assert second == TestGroup( + selected=[], deselected=[item("a")], duration=0 + ) @pytest.mark.parametrize("algo_name", Algorithms.names()) def test__split_tests_handles_tests_with_missing_durations(self, algo_name): durations = {"a": 1} items = [item(x) for x in ["a", "b"]] algo = Algorithms[algo_name].value - splits = algo(splits=2, items=items, durations=durations) - first, second = splits - assert first.selected == [item("a")] - assert second.selected == [item("b")] + first, second = algo(splits=2, durations=compute_durations(items, durations)) + + assert first == TestGroup( + selected=[item("a")], deselected=[item("b")], duration=1 + ) + assert second == TestGroup( + selected=[item("b")], deselected=[item("a")], duration=1 + ) def test__split_test_handles_large_duration_at_end(self): """NOTE: only least_duration does this correctly""" durations = {"a": 1, "b": 1, "c": 1, "d": 3} items = [item(x) for x in ["a", "b", "c", "d"]] algo = Algorithms["least_duration"].value - 
splits = algo(splits=2, items=items, durations=durations) - first, second = splits - assert first.selected == [item("d")] - assert second.selected == [item(x) for x in ["a", "b", "c"]] + first, second = algo(splits=2, durations=compute_durations(items, durations)) + + assert first == TestGroup( + selected=[item("d")], + deselected=[item("a"), item("b"), item("c")], + duration=3, + ) + assert second == TestGroup( + selected=[item("a"), item("b"), item("c")], + deselected=[item("d")], + duration=3, + ) @pytest.mark.parametrize( ("algo_name", "expected"), [ - ("duration_based_chunks", [[item("a"), item("b")], [item("c"), item("d")]]), - ("least_duration", [[item("a"), item("c")], [item("b"), item("d")]]), + ( + "duration_based_chunks", + [ + TestGroup( + selected=[item("a"), item("b")], + deselected=[item("c"), item("d")], + duration=2, + ), + TestGroup( + selected=[item("c"), item("d")], + deselected=[item("a"), item("b")], + duration=2, + ), + ], + ), + ( + "least_duration", + [ + TestGroup( + selected=[item("a"), item("c")], + deselected=[item("b"), item("d")], + duration=2, + ), + TestGroup( + selected=[item("b"), item("d")], + deselected=[item("a"), item("c")], + duration=2, + ), + ], + ), ], ) def test__split_tests_calculates_avg_test_duration_only_on_present_tests( @@ -88,23 +142,45 @@ def test__split_tests_calculates_avg_test_duration_only_on_present_tests( durations = {"b": 1, "c": 1, "d": 1, "e": 10000} items = [item(x) for x in ["a", "b", "c", "d"]] algo = Algorithms[algo_name].value - splits = algo(splits=2, items=items, durations=durations) - first, second = splits - expected_first, expected_second = expected - assert first.selected == expected_first - assert second.selected == expected_second + groups = algo(splits=2, durations=compute_durations(items, durations)) + + assert groups == expected @pytest.mark.parametrize( ("algo_name", "expected"), [ ( "duration_based_chunks", - [[item("a"), item("b"), item("c"), item("d"), item("e")], []], + [ + TestGroup( 
+ selected=[item(x) for x in "abcde"], + deselected=[], + duration=10014, + ), + TestGroup( + selected=[], + deselected=[item(x) for x in "abcde"], + duration=0, + ), + ], ), ( "least_duration", - [[item("e")], [item("a"), item("b"), item("c"), item("d")]], + # selected/deselected are in heap-pop order (duration desc) + # since the algorithm no longer restores input order. + [ + TestGroup( + selected=[item("e")], + deselected=[item(x) for x in "dcba"], + duration=10000, + ), + TestGroup( + selected=[item(x) for x in "dcba"], + deselected=[item("e")], + duration=14, + ), + ], ), ], ) @@ -112,12 +188,10 @@ def test__split_tests_maintains_relative_order_of_tests(self, algo_name, expecte durations = {"a": 2, "b": 3, "c": 4, "d": 5, "e": 10000} items = [item(x) for x in ["a", "b", "c", "d", "e"]] algo = Algorithms[algo_name].value - splits = algo(splits=2, items=items, durations=durations) - first, second = splits - expected_first, expected_second = expected - assert first.selected == expected_first - assert second.selected == expected_second + groups = algo(splits=2, durations=compute_durations(items, durations)) + + assert groups == expected def test__split_tests_same_set_regardless_of_order(self): """NOTE: only least_duration does this correctly""" @@ -128,7 +202,10 @@ def test__split_tests_same_set_regardless_of_order(self): for n in (2, 3, 4): selected_each: list[set[Item]] = [set() for _ in range(n)] for order in itertools.permutations(items): - splits = algo(splits=n, items=order, durations=durations) + splits = algo( + splits=n, + durations=compute_durations(list(order), durations), + ) for i, group in enumerate(splits): if not selected_each[i]: selected_each[i] = set(group.selected) @@ -139,13 +216,53 @@ def test__algorithms_members_derived_correctly(self): assert issubclass(Algorithms[a].value.__class__, AlgorithmBase) +class TestComputeDurations: + def test_uses_real_durations_avg_fills_missing_ignores_irrelevant(self): + # "ghost" isn't in the suite so it's 
excluded from the avg, and "c" + # gets the avg of "a" and "b": (2.0 + 4.0) / 2 = 3.0. + items = [item("a"), item("b"), item("c")] + cached = {"a": 2.0, "b": 4.0, "ghost": 10000.0} + assert compute_durations(items, cached) == { + item("a"): 2.0, + item("b"): 4.0, + item("c"): 3.0, + } + + def test_falls_back_to_one_when_no_relevant_durations(self): + assert compute_durations([item("a"), item("b")], {}) == { + item("a"): 1, + item("b"): 1, + } + + def test_returned_dict_iterates_in_input_order(self): + items = [item("c"), item("a"), item("b")] + assert list(compute_durations(items, {"a": 1, "b": 2, "c": 3})) == items + + +class TestSelectInCollectionOrder: + def test_rebuilds_selected_and_deselected_in_input_order(self): + items = [item("a"), item("b"), item("c"), item("d")] + # Algorithm returned membership in some other order. + group = TestGroup( + selected=[item("c"), item("a")], + deselected=[item("d"), item("b")], + duration=5.0, + ) + + result = select_in_collection_order(group, items) + + assert result.selected == [item("a"), item("c")] + assert result.deselected == [item("b"), item("d")] + assert result.duration == 5.0 + + class MyAlgorithm(AlgorithmBase): - def __call__(self, a, b, c): + def __call__(self, a, b): """no-op""" class MyOtherAlgorithm(AlgorithmBase): - def __call__(self, a, b, c): + def __call__(self, a, b): """no-op""" diff --git a/tests/test_ipynb.py b/tests/test_ipynb.py index 9bc2000..c60aa11 100644 --- a/tests/test_ipynb.py +++ b/tests/test_ipynb.py @@ -1,7 +1,7 @@ from collections import namedtuple import pytest -from pytest_split.algorithms import Algorithms +from pytest_split.algorithms import Algorithms, compute_durations from pytest_split.ipynb_compatibility import ensure_ipynb_compatibility item = namedtuple("item", "nodeid") # noqa: PYI024 @@ -29,7 +29,7 @@ def test_ensure_ipynb_compatibility(self, algo_name): } items = [item(x) for x in durations] algo = Algorithms[algo_name].value - groups = algo(splits=3, items=items, 
durations=durations) + groups = algo(splits=3, durations=compute_durations(items, durations)) assert groups[0].selected == [ item(nodeid="temp/nbs/test_1.ipynb::Cell 0"),