Skip to content

Commit 5939f4b

Browse files
committed
ENH: delegate broadcast_shapes
1 parent bc126fa commit 5939f4b

4 files changed

Lines changed: 102 additions & 42 deletions

File tree

src/array_api_extra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._delegation import (
44
argpartition,
55
atleast_nd,
6+
broadcast_shapes,
67
cov,
78
create_diagonal,
89
expand_dims,
@@ -20,7 +21,6 @@
2021
from ._lib._at import at
2122
from ._lib._funcs import (
2223
apply_where,
23-
broadcast_shapes,
2424
default_dtype,
2525
kron,
2626
nunique,

src/array_api_extra/_delegation.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Sequence
44
from types import ModuleType
5-
from typing import Literal
5+
from typing import Literal, cast
66

77
from ._lib import _funcs
88
from ._lib._utils._compat import (
@@ -20,6 +20,7 @@
2020

2121
__all__ = [
2222
"atleast_nd",
23+
"broadcast_shapes",
2324
"cov",
2425
"create_diagonal",
2526
"expand_dims",
@@ -81,6 +82,64 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
8182
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)
8283

8384

85+
def _numpy_broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...] | None:
86+
try:
87+
import numpy as np
88+
except ImportError:
89+
return None
90+
91+
return np.broadcast_shapes(*shapes)
92+
93+
94+
def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...]:
95+
"""
96+
Compute the shape of the broadcasted arrays.
97+
98+
Duplicates :func:`numpy.broadcast_shapes`, with additional support for
99+
None and NaN sizes.
100+
101+
This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape``
102+
without needing to worry about the backend potentially deep copying
103+
the arrays.
104+
105+
Parameters
106+
----------
107+
*shapes : tuple[int | None, ...]
108+
Shapes of the arrays to broadcast.
109+
110+
Returns
111+
-------
112+
tuple[int | None, ...]
113+
The shape of the broadcasted arrays.
114+
115+
See Also
116+
--------
117+
numpy.broadcast_shapes : Equivalent NumPy function.
118+
array_api.broadcast_arrays : Function to broadcast actual arrays.
119+
120+
Notes
121+
-----
122+
This function accepts the Array API's ``None`` for unknown sizes,
123+
as well as Dask's non-standard ``math.nan``.
124+
Regardless of input, the output always contains ``None`` for unknown sizes.
125+
126+
Examples
127+
--------
128+
>>> import array_api_extra as xpx
129+
>>> xpx.broadcast_shapes((2, 3), (2, 1))
130+
(2, 3)
131+
>>> xpx.broadcast_shapes((4, 2, 3), (2, 1), (1, 3))
132+
(4, 2, 3)
133+
"""
134+
if all(isinstance(size, int) for shape in shapes for size in shape):
135+
int_shapes = cast(tuple[tuple[int, ...], ...], shapes)
136+
out = _numpy_broadcast_shapes(*int_shapes)
137+
if out is not None:
138+
return cast(tuple[int | None, ...], out)
139+
140+
return _funcs.broadcast_shapes(*shapes)
141+
142+
84143
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
85144
"""
86145
Estimate a covariance matrix (or a stack of covariance matrices).

src/array_api_extra/_lib/_funcs.py

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -220,46 +220,10 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
220220

221221
# `float` in signature to accept `math.nan` for Dask.
222222
# `int`s are still accepted as `float` is a superclass of `int` in typing
223-
def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...]:
224-
"""
225-
Compute the shape of the broadcasted arrays.
226-
227-
Duplicates :func:`numpy.broadcast_shapes`, with additional support for
228-
None and NaN sizes.
229-
230-
This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape``
231-
without needing to worry about the backend potentially deep copying
232-
the arrays.
233-
234-
Parameters
235-
----------
236-
*shapes : tuple[int | None, ...]
237-
Shapes of the arrays to broadcast.
238-
239-
Returns
240-
-------
241-
tuple[int | None, ...]
242-
The shape of the broadcasted arrays.
243-
244-
See Also
245-
--------
246-
numpy.broadcast_shapes : Equivalent NumPy function.
247-
array_api.broadcast_arrays : Function to broadcast actual arrays.
248-
249-
Notes
250-
-----
251-
This function accepts the Array API's ``None`` for unknown sizes,
252-
as well as Dask's non-standard ``math.nan``.
253-
Regardless of input, the output always contains ``None`` for unknown sizes.
254-
255-
Examples
256-
--------
257-
>>> import array_api_extra as xpx
258-
>>> xpx.broadcast_shapes((2, 3), (2, 1))
259-
(2, 3)
260-
>>> xpx.broadcast_shapes((4, 2, 3), (2, 1), (1, 3))
261-
(4, 2, 3)
262-
"""
223+
def broadcast_shapes( # numpydoc ignore=PR01,RT01
224+
*shapes: tuple[float | None, ...],
225+
) -> tuple[int | None, ...]:
226+
"""See docstring in array_api_extra._delegation."""
263227
if not shapes:
264228
return () # Match NumPy output
265229

tests/test_funcs.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from hypothesis import strategies as st
1313
from typing_extensions import override
1414

15+
import array_api_extra._delegation as delegation
1516
from array_api_extra import (
1617
apply_where,
1718
argpartition,
@@ -489,6 +490,42 @@ def test_5D_values(self, xp: ModuleType):
489490

490491

491492
class TestBroadcastShapes:
493+
def test_delegates_known_integer_shapes(self, monkeypatch: pytest.MonkeyPatch):
494+
calls = []
495+
496+
def mock_broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]:
497+
calls.append(shapes)
498+
return (99,)
499+
500+
monkeypatch.setattr(
501+
delegation, "_numpy_broadcast_shapes", mock_broadcast_shapes
502+
)
503+
504+
assert broadcast_shapes((2,), (1,)) == (99,)
505+
assert calls == [((2,), (1,))]
506+
507+
def test_fallback_for_unknown_sizes(self, monkeypatch: pytest.MonkeyPatch):
508+
def mock_broadcast_shapes(*_shapes: tuple[int, ...]) -> tuple[int, ...]:
509+
msg = "NumPy delegation should not handle unknown sizes"
510+
raise AssertionError(msg)
511+
512+
monkeypatch.setattr(
513+
delegation, "_numpy_broadcast_shapes", mock_broadcast_shapes
514+
)
515+
516+
assert broadcast_shapes((None,), (1,)) == (None,)
517+
assert broadcast_shapes((math.nan,), (1,)) == (None,)
518+
519+
def test_fallback_without_numpy(self, monkeypatch: pytest.MonkeyPatch):
520+
def mock_broadcast_shapes(*_shapes: tuple[int, ...]) -> tuple[int, ...] | None:
521+
return None
522+
523+
monkeypatch.setattr(
524+
delegation, "_numpy_broadcast_shapes", mock_broadcast_shapes
525+
)
526+
527+
assert broadcast_shapes((2,), (1,)) == (2,)
528+
492529
@pytest.mark.parametrize(
493530
"args",
494531
[

0 commit comments

Comments
 (0)