|
2 | 2 |
|
3 | 3 | from collections.abc import Sequence |
4 | 4 | from types import ModuleType |
5 | | -from typing import Literal |
| 5 | +from typing import Literal, cast |
6 | 6 |
|
7 | 7 | from ._lib import _funcs |
8 | 8 | from ._lib._utils._compat import ( |
|
20 | 20 |
|
21 | 21 | __all__ = [ |
22 | 22 | "atleast_nd", |
| 23 | + "broadcast_shapes", |
23 | 24 | "cov", |
24 | 25 | "create_diagonal", |
25 | 26 | "expand_dims", |
@@ -81,6 +82,64 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array |
81 | 82 | return _funcs.atleast_nd(x, ndim=ndim, xp=xp) |
82 | 83 |
|
83 | 84 |
|
| 85 | +def _numpy_broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...] | None: |
| 86 | + try: |
| 87 | + import numpy as np |
| 88 | + except ImportError: |
| 89 | + return None |
| 90 | + |
| 91 | + return np.broadcast_shapes(*shapes) |
| 92 | + |
| 93 | + |
| 94 | +def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...]: |
| 95 | + """ |
| 96 | + Compute the shape of the broadcasted arrays. |
| 97 | +
|
| 98 | + Duplicates :func:`numpy.broadcast_shapes`, with additional support for |
| 99 | + None and NaN sizes. |
| 100 | +
|
| 101 | + This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape`` |
| 102 | + without needing to worry about the backend potentially deep copying |
| 103 | + the arrays. |
| 104 | +
|
| 105 | + Parameters |
| 106 | + ---------- |
| 107 | + *shapes : tuple[int | None, ...] |
| 108 | + Shapes of the arrays to broadcast. |
| 109 | +
|
| 110 | + Returns |
| 111 | + ------- |
| 112 | + tuple[int | None, ...] |
| 113 | + The shape of the broadcasted arrays. |
| 114 | +
|
| 115 | + See Also |
| 116 | + -------- |
| 117 | + numpy.broadcast_shapes : Equivalent NumPy function. |
| 118 | + array_api.broadcast_arrays : Function to broadcast actual arrays. |
| 119 | +
|
| 120 | + Notes |
| 121 | + ----- |
| 122 | + This function accepts the Array API's ``None`` for unknown sizes, |
| 123 | + as well as Dask's non-standard ``math.nan``. |
| 124 | + Regardless of input, the output always contains ``None`` for unknown sizes. |
| 125 | +
|
| 126 | + Examples |
| 127 | + -------- |
| 128 | + >>> import array_api_extra as xpx |
| 129 | + >>> xpx.broadcast_shapes((2, 3), (2, 1)) |
| 130 | + (2, 3) |
| 131 | + >>> xpx.broadcast_shapes((4, 2, 3), (2, 1), (1, 3)) |
| 132 | + (4, 2, 3) |
| 133 | + """ |
| 134 | + if all(isinstance(size, int) for shape in shapes for size in shape): |
| 135 | + int_shapes = cast(tuple[tuple[int, ...], ...], shapes) |
| 136 | + out = _numpy_broadcast_shapes(*int_shapes) |
| 137 | + if out is not None: |
| 138 | + return cast(tuple[int | None, ...], out) |
| 139 | + |
| 140 | + return _funcs.broadcast_shapes(*shapes) |
| 141 | + |
| 142 | + |
84 | 143 | def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: |
85 | 144 | """ |
86 | 145 | Estimate a covariance matrix (or a stack of covariance matrices). |
|
0 commit comments