diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 681bfb0e..3b23ad91 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -122,38 +122,74 @@ def test_concat(dtypes, base_shape, data): raise -@pytest.mark.unvectorized -@given( - x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes()), - axis=shared_shapes().flatmap( - # Generate both valid and invalid axis - lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s)) - ), -) -def test_expand_dims(x, axis): - if axis < -x.ndim - 1 or axis > x.ndim: - with pytest.raises(IndexError): - xp.expand_dims(x, axis=axis) - return +class TestExpandDims: + @pytest.mark.unvectorized + @given( + x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes()), + axis=shared_shapes().flatmap( + # Generate both valid and invalid axis + lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s)) + ), + ) + def test_expand_dims(self, x, axis): + if axis < -x.ndim - 1 or axis > x.ndim: + with pytest.raises(IndexError): + xp.expand_dims(x, axis=axis) + return - repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axis!r})") - try: - out = xp.expand_dims(x, axis=axis) + repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axis!r})") + try: + out = xp.expand_dims(x, axis=axis) - ph.assert_dtype("expand_dims", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_dtype("expand_dims", in_dtype=x.dtype, out_dtype=out.dtype) - shape = [side for side in x.shape] - index = axis if axis >= 0 else x.ndim + axis + 1 - shape.insert(index, 1) - shape = tuple(shape) - ph.assert_result_shape("expand_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape) - - assert_array_ndindex( - "expand_dims", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape) + shape = [side for side in x.shape] + index = axis if axis >= 0 else x.ndim + axis + 1 + shape.insert(index, 1) + shape = tuple(shape) + ph.assert_result_shape("expand_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape) + + assert_array_ndindex( + "expand_dims", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape) + ) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise + + @given( + x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(max_dims=4)), + axes=shared_shapes().flatmap( + lambda s: st.lists( + st.integers(2*(-len(s)-1), 2*len(s)), + min_size=0 if len(s)==0 else 1, + max_size=len(s) + ).map(tuple) ) - except Exception as exc: - ph.add_note(exc, repro_snippet) - raise + ) + def test_expand_dims_tuples(self, x, axes): + # normalize the axes + y_ndim = x.ndim + len(axes) + n_axes = tuple(ax + y_ndim if ax < 0 else ax for ax in axes) + unique_axes = set(n_axes) + + if any(ax < 0 or ax >= y_ndim for ax in n_axes) or len(n_axes) != len(unique_axes): + with pytest.raises((IndexError, ValueError)): + xp.expand_dims(x, axis=axes) + return + + repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axes!r})") + try: + y = xp.expand_dims(x, axis=axes) + + ye = x + for ax in sorted(n_axes): + ye = xp.expand_dims(ye, axis=ax) + assert y.shape == ye.shape + # TODO value tests; check that y.shape is 1s and items from x.shape, in order + + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise @pytest.mark.min_version("2023.12")