diff --git a/flixopt/comparison.py b/flixopt/comparison.py index 7e3e983e1..a8c2076c8 100644 --- a/flixopt/comparison.py +++ b/flixopt/comparison.py @@ -19,7 +19,7 @@ ) if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import ItemsView, Iterator, KeysView, ValuesView from .flow_system import FlowSystem @@ -158,10 +158,19 @@ class Comparison: """ def __init__(self, flow_systems: list[FlowSystem], names: list[str] | None = None) -> None: + from .flow_system import FlowSystem + + if not isinstance(flow_systems, list): + raise TypeError(f'flow_systems must be a list, got {type(flow_systems).__name__}') + + non_fs = [(i, type(fs).__name__) for i, fs in enumerate(flow_systems) if not isinstance(fs, FlowSystem)] + if non_fs: + raise TypeError(f'flow_systems must contain only FlowSystem instances; got {non_fs} (index, type)') + if len(flow_systems) < 2: raise ValueError('Comparison requires at least 2 FlowSystems') - self._systems = flow_systems + self._systems: list[FlowSystem] = flow_systems self._names = names or [fs.name or f'System {i}' for i, fs in enumerate(flow_systems)] if len(self._names) != len(self._systems): @@ -224,14 +233,30 @@ def __getitem__(self, key: int | str) -> FlowSystem: return self._systems[idx] raise KeyError(f"Case '{key}' not found. Available: {self._names}") - def __iter__(self) -> Iterator[tuple[str, FlowSystem]]: - """Iterate over (name, FlowSystem) pairs.""" - yield from zip(self._names, self._systems, strict=True) + def __iter__(self) -> Iterator[str]: + """Iterate over case names, matching the ``dict`` / ``Mapping`` protocol. + + Use :meth:`items` for ``(name, FlowSystem)`` pairs or :meth:`values` + for FlowSystems. + """ + return iter(self._names) def __contains__(self, key: str) -> bool: """Check if a case name exists.""" return key in self._names + def keys(self) -> KeysView[str]: + """Return a view of case names, like :meth:`dict.keys`.""" + return self.flow_systems.keys() + + def values(self) -> ValuesView[FlowSystem]: + """Return a view of FlowSystems, like :meth:`dict.values`.""" + return self.flow_systems.values() + + def items(self) -> ItemsView[str, FlowSystem]: + """Return a view of ``(name, FlowSystem)`` pairs, like :meth:`dict.items`.""" + return self.flow_systems.items() + @property def flow_systems(self) -> dict[str, FlowSystem]: """Access underlying FlowSystems as a dict mapping name → FlowSystem.""" diff --git a/pyproject.toml b/pyproject.toml index 642ad5765..51191d099 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,7 @@ full = [ # Development tools and testing dev = [ + "xarray<2026.3", # TODO: drop once linopy ships xarray 2026.3+ compat fix "tsam==3.3.0", # Time series aggregation for clustering "pytest==9.0.3", "pytest-xdist==3.8.0", diff --git a/tests/test_comparison.py b/tests/test_comparison.py index 7f7e7093e..94328da97 100644 --- a/tests/test_comparison.py +++ b/tests/test_comparison.py @@ -173,6 +173,16 @@ def test_comparison_rejects_unoptimized_system(self, base_flow_system, optimized with pytest.raises(RuntimeError, match='no solution'): _ = comp.solution + def test_comparison_rejects_non_list(self, optimized_base, optimized_with_chp): + """Comparison rejects non-list flow_systems input.""" + with pytest.raises(TypeError, match='must be a list'): + fx.Comparison((optimized_base, optimized_with_chp)) + + def test_comparison_rejects_non_flowsystem_items(self, optimized_base): + """Comparison rejects list items that are not FlowSystem instances.""" + with pytest.raises(TypeError, match='FlowSystem instances'): + fx.Comparison([optimized_base, 'not a flow system']) + # ============================================================================ # CONTAINER PROTOCOL TESTS @@ -212,11 +222,25 @@ def test_getitem_invalid_index_raises(self, optimized_base, optimized_with_chp): with pytest.raises(IndexError): _ = comp[99] - def test_iter(self, optimized_base, optimized_with_chp): - """Iteration yields (name, FlowSystem) pairs.""" + def test_iter_yields_names(self, optimized_base, optimized_with_chp): + """Iteration yields case names, matching the dict/Mapping protocol.""" + comp = fx.Comparison([optimized_base, optimized_with_chp]) + assert list(comp) == ['Base', 'WithCHP'] + + def test_keys(self, optimized_base, optimized_with_chp): + """keys() returns case names.""" + comp = fx.Comparison([optimized_base, optimized_with_chp]) + assert list(comp.keys()) == ['Base', 'WithCHP'] + + def test_values(self, optimized_base, optimized_with_chp): + """values() returns FlowSystems.""" + comp = fx.Comparison([optimized_base, optimized_with_chp]) + assert list(comp.values()) == [optimized_base, optimized_with_chp] + + def test_items(self, optimized_base, optimized_with_chp): + """items() returns (name, FlowSystem) pairs without warning.""" comp = fx.Comparison([optimized_base, optimized_with_chp]) - items = list(comp) - assert items == [('Base', optimized_base), ('WithCHP', optimized_with_chp)] + assert list(comp.items()) == [('Base', optimized_base), ('WithCHP', optimized_with_chp)] def test_contains(self, optimized_base, optimized_with_chp): """'in' operator checks for case name."""