diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 3af7b959..6954e5f9 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -156,6 +156,42 @@ def test_expand_dims(x, axis): raise +@given( + x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(max_dims=4)), + axes=shared_shapes().flatmap( + lambda s: st.lists( + st.integers(2*(-len(s)-1), 2*len(s)), + min_size=0 if len(s)==0 else 1, + max_size=len(s) + ).map(tuple) + ) +) +def test_expand_dims_tuples(x, axes): + # normalize the axes + y_ndim = x.ndim + len(axes) + n_axes = tuple(ax + y_ndim if ax < 0 else ax for ax in axes) + unique_axes = set(n_axes) + + if any(ax < 0 or ax >= y_ndim for ax in n_axes) or len(n_axes) != len(unique_axes): + with pytest.raises((IndexError, ValueError)): + xp.expand_dims(x, axis=axes) + return + + repro_snippet = ph.format_snippet(f"xp.expand_dims({x!r}, axis={axes!r})") + try: + y = xp.expand_dims(x, axis=axes) + + ye = x + for ax in sorted(n_axes): + ye = xp.expand_dims(ye, axis=ax) + assert y.shape == ye.shape + # TODO value tests; check that y.shape is 1s and items from x.shape, in order + + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise + + @pytest.mark.min_version("2023.12") @given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1)), data=st.data()) def test_moveaxis(x, data):