Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 5dead87

Browse files
committed
adding unary ops
1 parent aceb19a commit 5dead87

8 files changed

Lines changed: 508 additions & 228 deletions

File tree

ddptensor/__init__.py

Lines changed: 2 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,74 +6,12 @@
66
__impl_str = getenv("DDPNP_ARRAY", 'numpy')
77
exec(f"import {__impl_str} as __impl")
88

9-
ew_binary_ops = [
10-
"add", # (x1, x2, /)
11-
"atan2", # (x1, x2, /)
12-
"bitwise_and", # (x1, x2, /)
13-
"bitwise_left_shift", # (x1, x2, /)
14-
"bitwise_or", # (x1, x2, /)
15-
"bitwise_right_shift", # (x1, x2, /)
16-
"bitwise_xor", # (x1, x2, /)
17-
"divide", # (x1, x2, /)
18-
"equal", # (x1, x2, /)
19-
"floor_divide", # (x1, x2, /)
20-
"greater", # (x1, x2, /)
21-
"greater_equal", # (x1, x2, /)
22-
"less_equal", # (x1, x2, /)
23-
"logaddexp", # (x1, x2)
24-
"logical_and", # (x1, x2, /)
25-
"logical_or", # (x1, x2, /)
26-
"logical_xor", # (x1, x2, /)
27-
"multiply", # (x1, x2, /)
28-
"less", # (x1, x2, /)
29-
"not_equal", # (x1, x2, /)
30-
"pow", # (x1, x2, /)
31-
"remainder", # (x1, x2, /)
32-
"subtract", # (x1, x2, /)
33-
]
34-
35-
for op in ew_binary_ops:
9+
for op in api.ew_binary_ops:
3610
exec(
3711
f"{op} = lambda this, other: dtensor(_cdt.ew_binary_op(this._t, '{op}', other._t if isinstance(other, ddptensor) else other, False))"
3812
)
3913

40-
ew_unary_ops = [
41-
"abs", # (x, /)
42-
"acos", # (x, /)
43-
"acosh", # (x, /)
44-
"asin", # (x, /)
45-
"asinh", # (x, /)
46-
"atan", # (x, /)
47-
"atanh", # (x, /)
48-
"bitwise_invert", # (x, /)
49-
"ceil", # (x, /)
50-
"cos", # (x, /)
51-
"cosh", # (x, /)
52-
"exp", # (x, /)
53-
"expm1", # (x, /)
54-
"floor", # (x, /)
55-
"isfinite", # (x, /)
56-
"isinf", # (x, /)
57-
"isnan", # (x, /)
58-
"logical_not", # (x, /)
59-
"log", # (x, /)
60-
"log1p", # (x, /)
61-
"log2", # (x, /)
62-
"log10", # (x, /)
63-
"negative", # (x, /)
64-
"positive", # (x, /)
65-
"round", # (x, /)
66-
"sign", # (x, /)
67-
"sin", # (x, /)
68-
"sinh", # (x, /)
69-
"square", # (x, /)
70-
"sqrt", # (x, /)
71-
"tan", # (x, /)
72-
"tanh", # (x, /)
73-
"trunc", # (x, /)
74-
]
75-
76-
for op in ew_unary_ops:
14+
for op in api.ew_unary_ops:
7715
exec(
7816
f"{op} = lambda this: dtensor(_cdt.ew_unary_op(this._t, '{op}', False))"
7917
)

ddptensor/array_api.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,49 @@
1515
"zeros_like", # (x, /, *, dtype=None, device=None)
1616
]
1717

18+
ew_unary_methods = [
19+
"__abs__", # (self, /)
20+
"__invert__", # (self, /)
21+
"__neg__", # (self, /)
22+
"__pos__", # (self, /)
23+
]
24+
25+
ew_unary_ops = [
26+
"abs", # (x, /)
27+
"acos", # (x, /)
28+
"acosh", # (x, /)
29+
"asin", # (x, /)
30+
"asinh", # (x, /)
31+
"atan", # (x, /)
32+
"atanh", # (x, /)
33+
"bitwise_invert", # (x, /)
34+
"ceil", # (x, /)
35+
"cos", # (x, /)
36+
"cosh", # (x, /)
37+
"exp", # (x, /)
38+
"expm1", # (x, /)
39+
"floor", # (x, /)
40+
"isfinite", # (x, /)
41+
"isinf", # (x, /)
42+
"isnan", # (x, /)
43+
"logical_not", # (x, /)
44+
"log", # (x, /)
45+
"log1p", # (x, /)
46+
"log2", # (x, /)
47+
"log10", # (x, /)
48+
"negative", # (x, /)
49+
"positive", # (x, /)
50+
"round", # (x, /)
51+
"sign", # (x, /)
52+
"sin", # (x, /)
53+
"sinh", # (x, /)
54+
"square", # (x, /)
55+
"sqrt", # (x, /)
56+
"tan", # (x, /)
57+
"tanh", # (x, /)
58+
"trunc", # (x, /)
59+
]
60+
1861
ew_binary_methods_inplace = [
1962
# inplace operators
2063
"__iadd__",
@@ -65,3 +108,29 @@
65108
"__rtruediv__",
66109
"__rxor__",
67110
]
111+
112+
ew_binary_ops = [
113+
"add", # (x1, x2, /)
114+
"atan2", # (x1, x2, /)
115+
"bitwise_and", # (x1, x2, /)
116+
"bitwise_left_shift", # (x1, x2, /)
117+
"bitwise_or", # (x1, x2, /)
118+
"bitwise_right_shift", # (x1, x2, /)
119+
"bitwise_xor", # (x1, x2, /)
120+
"divide", # (x1, x2, /)
121+
"equal", # (x1, x2, /)
122+
"floor_divide", # (x1, x2, /)
123+
"greater", # (x1, x2, /)
124+
"greater_equal", # (x1, x2, /)
125+
"less_equal", # (x1, x2, /)
126+
"logaddexp", # (x1, x2)
127+
"logical_and", # (x1, x2, /)
128+
"logical_or", # (x1, x2, /)
129+
"logical_xor", # (x1, x2, /)
130+
"multiply", # (x1, x2, /)
131+
"less", # (x1, x2, /)
132+
"not_equal", # (x1, x2, /)
133+
"pow", # (x1, x2, /)
134+
"remainder", # (x1, x2, /)
135+
"subtract", # (x1, x2, /)
136+
]

ddptensor/ddptensor.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,6 @@
22
from ._ddptensor import float64, int64, fini
33
from . import array_api as api
44

5-
ew_unary_methods = [
6-
"__abs__", # (self, /)
7-
"__invert__", # (self, /)
8-
"__neg__", # (self, /)
9-
"__pos__", # (self, /)
10-
]
11-
125
unary_methods = [
136
# "__array_namespace__", # (self, /, *, api_version=None)
147
"__bool__", # (self, /)
@@ -45,7 +38,7 @@ def __repr__(self):
4538
f"{method} = lambda self, other: (self, _cdt.ew_binary_op_inplace(self._t, '{method}', other._t if isinstance(other, dtensor) else other))[0]"
4639
)
4740

48-
for method in ew_unary_methods:
41+
for method in api.ew_unary_methods:
4942
exec(
5043
f"{method} = lambda self: dtensor(_cdt.ew_unary_op(self._t, '{method}', True))"
5144
)

scripts/code_gen.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,25 @@
1313
print(" CREATOR_LAST")
1414
print("};\n")
1515

16+
uops = api.ew_unary_methods + api.ew_unary_ops
17+
print("enum EWUnyOpId : int {")
18+
for x in uops:
19+
x = x + " = CREATOR_LAST" if x == uops[0] else x
20+
print(f" {x.upper()},")
21+
print(" EWUNYOP_LAST")
22+
print("};\n")
23+
1624
print("enum IEWBinOpId : int {")
1725
for x in api.ew_binary_methods_inplace:
18-
x = x[2:-2] + " = CREATOR_LAST" if x == api.ew_binary_methods_inplace[0] else x[2:-2]
26+
x = x + " = EWUNYOP_LAST" if x == api.ew_binary_methods_inplace[0] else x
1927
print(f" {x.upper()},")
2028
print(" IEWBINOP_LAST")
2129
print("};\n")
2230

31+
bops = api.ew_binary_methods + api.ew_binary_ops
2332
print("enum EWBinOpId : int {")
24-
for x in api.ew_binary_methods:
25-
x = x[2:-2] + " = IEWBINOP_LAST" if x == api.ew_binary_methods[0] else x[2:-2]
33+
for x in bops:
34+
x = x + " = IEWBINOP_LAST" if x == bops[0] else x
2635
print(f" {x.upper()},")
2736
print(" EWBINOP_LAST")
2837
print("};\n")
@@ -34,14 +43,19 @@
3443
print(f' .value("{x.upper()}", {x.upper()})')
3544
print(" .export_values();\n")
3645

46+
print(' py::enum_<EWUnyOpId>(m, "EWUnyOpId")')
47+
for x in uops:
48+
print(f' .value("{x.upper()}", {x.upper()})')
49+
print(" .export_values();\n")
50+
3751
print(' py::enum_<IEWBinOpId>(m, "IEWBinOpId")')
3852
for x in api.ew_binary_methods_inplace:
39-
print(f' .value("{x[2:-2].upper()}", {x[2:-2].upper()})')
53+
print(f' .value("{x.upper()}", {x.upper()})')
4054
print(" .export_values();\n")
4155

4256
print(' py::enum_<EWBinOpId>(m, "EWBinOpId")')
43-
for x in api.ew_binary_methods:
44-
print(f' .value("{x[2:-2].upper()}", {x[2:-2].upper()})')
57+
for x in bops:
58+
print(f' .value("{x.upper()}", {x.upper()})')
4559
print(" .export_values();\n")
4660

4761
print("}")

src/ddptensor.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,14 @@ struct EWBinOp
236236
}
237237
};
238238

239+
struct EWUnyOp
240+
{
241+
static auto op(EWUnyOpId op, x::DPTensorBaseX::ptr_type a)
242+
{
243+
return TypeDispatch<x::EWUnyOp>(a->dtype(), op, a);
244+
}
245+
};
246+
239247
rank_type myrank()
240248
{
241249
return theTransceiver->rank();
@@ -262,6 +270,9 @@ PYBIND11_MODULE(_ddptensor, m) {
262270
.def("create_from_shape", &Creator::create_from_shape)
263271
.def("full", &Creator::full);
264272

273+
py::class_<EWUnyOp>(m, "EWUnyOp")
274+
.def("op", &EWUnyOp::op);
275+
265276
py::class_<IEWBinOp>(m, "IEWBinOp")
266277
.def("op", &IEWBinOp::op);
267278

0 commit comments

Comments
 (0)