diff --git a/src/pyrecest/_backend/__init__.py b/src/pyrecest/_backend/__init__.py index a52e577fd..acdfbf22a 100644 --- a/src/pyrecest/_backend/__init__.py +++ b/src/pyrecest/_backend/__init__.py @@ -349,6 +349,30 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=False): return mean +def _is_empty_assignment_index(indices): + """Return whether ``indices`` selects no elements for assignment helpers.""" + if isinstance(indices, list): + return len(indices) == 0 + if isinstance(indices, tuple): + return False + + ndim = getattr(indices, "ndim", None) + shape = getattr(indices, "shape", None) + return ndim is not None and ndim > 0 and shape is not None and shape[0] == 0 + + +def _assignment_with_empty_indices_noop(assignment_func, copy_func): + """Return an assignment wrapper that treats empty indices as a no-op.""" + + @wraps(assignment_func) + def assignment(x, values, indices, axis=0): + if _is_empty_assignment_index(indices): + return copy_func(x) + return assignment_func(x, values, indices, axis=axis) + + return assignment + + class BackendImporter(importlib.abc.MetaPathFinder, importlib.abc.Loader): """ Meta path finder and loader for dynamically creating backend modules. @@ -442,6 +466,14 @@ def _create_backend_module(self, backend_name: str): attribute, getattr(backend, "asarray"), ) + if ( + module_name == "" + and attribute_name in {"assignment", "assignment_by_sum"} + ): + attribute = _assignment_with_empty_indices_noop( + attribute, + getattr(backend, "copy"), + ) setattr(new_submodule, attribute_name, attribute) for attribute_name in OPTIONAL_BACKEND_ATTRIBUTES.get(module_name, []): diff --git a/tests/test_backend_contract.py b/tests/test_backend_contract.py index c1e9db036..32dc8604d 100644 --- a/tests/test_backend_contract.py +++ b/tests/test_backend_contract.py @@ -27,6 +27,15 @@ def test_convert_to_wider_dtype_preserves_matching_boolean_dtype(self): self.assertEqual(to_numpy(first).dtype, np.dtype("bool")) self.assertEqual(to_numpy(second).dtype, np.dtype("bool")) + def test_assignment_with_empty_indices_is_a_noop(self): + original = array([1.0, 2.0, 3.0]) + + assigned = backend.assignment(original, 99.0, []) + added = backend.assignment_by_sum(original, 99.0, []) + + npt.assert_allclose(to_numpy(assigned), [1.0, 2.0, 3.0]) + npt.assert_allclose(to_numpy(added), [1.0, 2.0, 3.0]) + def test_choice_supports_numpy_like_size_replace_and_probabilities(self): values = array([0, 1, 2, 3]) weights = array([0.1, 0.2, 0.3, 0.4])