|
14 | 14 | # are simply forwarded as-is. |
15 | 15 |
|
16 | 16 | from . import _ddptensor as _cdt |
17 | | -from ._ddptensor import float64, float32, int64, int32, int16, uint64, uint32, uint16, fini |
| 17 | +from ._ddptensor import ( |
| 18 | + FLOAT64 as float64, |
| 19 | + FLOAT32 as float32, |
| 20 | + INT64 as int64, |
| 21 | + INT32 as int32, |
| 22 | + INT16 as int16, |
| 23 | + INT8 as int8, |
| 24 | + UINT64 as uint64, |
| 25 | + UINT32 as uint32, |
| 26 | + UINT16 as uint16, |
| 27 | + UINT8 as uint8, |
| 28 | + fini |
| 29 | +) |
18 | 30 | from .ddptensor import dtensor |
19 | 31 | from os import getenv |
20 | 32 | from . import array_api as api |
21 | 33 | from . import spmd |
22 | 34 |
|
23 | | -for op in api.ew_binary_ops: |
24 | | - OP = op.upper() |
25 | | - exec( |
26 | | - f"{op} = lambda this, other: dtensor(_cdt.EWBinOp.op(_cdt.{OP}, this._t, other._t if isinstance(other, ddptensor) else other))" |
27 | | - ) |
| 35 | +for op in api.api_categories["EWBinOp"]: |
| 36 | + if not op.startswith("__"): |
| 37 | + OP = op.upper() |
| 38 | + exec( |
| 39 | + f"{op} = lambda this, other: dtensor(_cdt.EWBinOp.op(_cdt.{OP}, this._t, other._t if isinstance(other, ddptensor) else other))" |
| 40 | + ) |
28 | 41 |
|
29 | | -for op in api.ew_unary_ops: |
30 | | - OP = op.upper() |
31 | | - exec( |
32 | | - f"{op} = lambda this: dtensor(_cdt.EWUnyOp.op(_cdt.{OP}, this._t))" |
33 | | - ) |
| 42 | +for op in api.api_categories["EWUnyOp"]: |
| 43 | + if not op.startswith("__"): |
| 44 | + OP = op.upper() |
| 45 | + exec( |
| 46 | + f"{op} = lambda this: dtensor(_cdt.EWUnyOp.op(_cdt.{OP}, this._t))" |
| 47 | + ) |
34 | 48 |
|
35 | | -for func in api.creators: |
| 49 | +for func in api.api_categories["Creator"]: |
36 | 50 | FUNC = func.upper() |
37 | 51 | if func in ["empty", "ones", "zeros",]: |
38 | 52 | exec( |
|
43 | 57 | f"{func} = lambda shape, val, dtype: dtensor(_cdt.Creator.full(_cdt.{FUNC}, shape, val, dtype))" |
44 | 58 | ) |
45 | 59 |
|
46 | | -for func in api.statisticals: |
| 60 | +for func in api.api_categories["ReduceOp"]: |
47 | 61 | FUNC = func.upper() |
48 | 62 | exec( |
49 | 63 | f"{func} = lambda this, dim: dtensor(_cdt.ReduceOp.op(_cdt.{FUNC}, this._t, dim))" |
|
0 commit comments