|
1 | 1 | from . import _ddptensor as _cdt |
2 | 2 | from .ddptensor import float64, int64, fini, dtensor |
3 | 3 | from os import getenv |
| 4 | +from . import array_api as api |
4 | 5 |
|
5 | 6 | __impl_str = getenv("DDPNP_ARRAY", 'numpy') |
6 | 7 | exec(f"import {__impl_str} as __impl") |
|
77 | 78 | f"{op} = lambda this: dtensor(_cdt.ew_unary_op(this._t, '{op}', False))" |
78 | 79 | ) |
79 | 80 |
|
80 | | -creators_with_shape = [ |
81 | | - "empty", # (shape, *, dtype=None, device=None) |
82 | | - "full", # (shape, fill_value, *, dtype=None, device=None) |
83 | | - "ones", # (shape, *, dtype=None, device=None) |
84 | | - "zeros", # (shape, *, dtype=None, device=None) |
85 | | -] |
86 | | - |
87 | | -for func in creators_with_shape: |
88 | | - exec( |
89 | | - f"{func} = lambda shape, *args, **kwargs: dtensor(_cdt.create(shape, '{func}', '{__impl_str}', *args, **kwargs))" |
90 | | - ) |
| 81 | +for func in api.creators: |
| 82 | + if func in ["empty", "full", "ones", "zeros",]: |
| 83 | + exec( |
| 84 | + f"{func} = lambda shape, *args, **kwargs: dtensor(_cdt.create(shape, '{func}', '{__impl_str}', *args, **kwargs))" |
| 85 | + ) |
91 | 86 |
|
92 | 87 | statisticals = [ |
93 | 88 | "max", # (x, /, *, axis=None, keepdims=False) |
|
103 | 98 | exec( |
104 | 99 | f"{func} = lambda this, **kwargs: dtensor(_cdt.reduce_op(this._t, '{func}', **kwargs))" |
105 | 100 | ) |
106 | | - |
107 | | - |
108 | | -creators = [ |
109 | | - "arange", # (start, /, stop=None, step=1, *, dtype=None, device=None) |
110 | | - "asarray", # (obj, /, *, dtype=None, device=None, copy=None) |
111 | | - "empty_like", # (x, /, *, dtype=None, device=None) |
112 | | - "eye", # (n_rows, n_cols=None, /, *, k=0, dtype=None, device=None) |
113 | | - "from_dlpack", # (x, /) |
114 | | - "full_like", # (x, /, fill_value, *, dtype=None, device=None) |
115 | | - "linspace", # (start, stop, /, num, *, dtype=None, device=None, endpoint=True) |
116 | | - "meshgrid", # (*arrays, indexing=’xy’) |
117 | | - "ones_like", # (x, /, *, dtype=None, device=None) |
118 | | - "zeros_like", # (x, /, *, dtype=None, device=None) |
119 | | -] |
0 commit comments