2020 FLOAT32 as float32 ,
2121 INT64 as int64 ,
2222 INT32 as int32 ,
23- INT16 as int16 ,
23+ INT16 as int16 ,
2424 INT8 as int8 ,
2525 UINT64 as uint64 ,
2626 UINT32 as uint32 ,
2929 BOOL as bool ,
3030 init as _init ,
3131 fini ,
32- sync
32+ sync ,
3333)
3434
3535from .ddptensor import dtensor
3636from os import getenv
3737from . import array_api as api
3838from . import spmd
3939
40- _ddpt_cw = _bool (int (getenv ('DDPT_CW' , True )))
40+ _ddpt_cw = _bool (int (getenv ("DDPT_CW" , True )))
41+
4142
4243def init (cw = None ):
4344 cw = _ddpt_cw if cw is None else cw
4445 _init (cw )
4546
47+
4648def to_numpy (a ):
4749 return _cdt .to_numpy (a ._t )
4850
51+
4952for op in api .api_categories ["EWBinOp" ]:
5053 if not op .startswith ("__" ):
5154 OP = op .upper ()
@@ -56,9 +59,7 @@ def to_numpy(a):
5659for op in api .api_categories ["EWUnyOp" ]:
5760 if not op .startswith ("__" ):
5861 OP = op .upper ()
59- exec (
60- f"{ op } = lambda this: dtensor(_cdt.EWUnyOp.op(_cdt.{ OP } , this._t))"
61- )
62+ exec (f"{ op } = lambda this: dtensor(_cdt.EWUnyOp.op(_cdt.{ OP } , this._t))" )
6263
6364for func in api .api_categories ["Creator" ]:
6465 FUNC = func .upper ()
@@ -98,7 +99,10 @@ def to_numpy(a):
9899
99100for func in api .api_categories ["LinAlgOp" ]:
100101 FUNC = func .upper ()
101- if func in ["tensordot" , "vecdot" ,]:
102+ if func in [
103+ "tensordot" ,
104+ "vecdot" ,
105+ ]:
102106 exec (
103107 f"{ func } = lambda this, other, axis: dtensor(_cdt.LinAlgOp.{ func } (this._t, other._t, axis))"
104108 )
@@ -107,9 +111,7 @@ def to_numpy(a):
107111 f"{ func } = lambda this, other: dtensor(_cdt.LinAlgOp.vecdot(this._t, other._t, 0))"
108112 )
109113 elif func == "matrix_transpose" :
110- exec (
111- f"{ func } = lambda this: dtensor(_cdt.LinAlgOp.{ func } (this._t))"
112- )
114+ exec (f"{ func } = lambda this: dtensor(_cdt.LinAlgOp.{ func } (this._t))" )
113115
114116for func in api .api_categories ["SortOp" ]:
115117 exec (
0 commit comments