Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit aceb19a

Browse files
committed
adding elementwise methods
1 parent 2c3e9be commit aceb19a

8 files changed

Lines changed: 256 additions & 39 deletions

File tree

ddptensor/array_api.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,38 @@
3030
"__itruediv__",
3131
"__ixor__",
3232
]
33+
34+
ew_binary_methods = [
35+
"__add__", # (self, other, /)
36+
"__and__", # (self, other, /)
37+
"__eq__", # (self, other, /)
38+
"__floordiv__", # (self, other, /)
39+
"__ge__", # (self, other, /)
40+
"__gt__", # (self, other, /)
41+
"__le__", # (self, other, /)
42+
"__lshift__", # (self, other, /)
43+
"__lt__", # (self, other, /)
44+
"__matmul__", # (self, other, /)
45+
"__mod__", # (self, other, /)
46+
"__mul__", # (self, other, /)
47+
"__ne__", # (self, other, /)
48+
"__or__", # (self, other, /)
49+
"__pow__", # (self, other, /)
50+
"__rshift__", # (self, other, /)
51+
"__sub__", # (self, other, /)
52+
"__truediv__", # (self, other, /)
53+
"__xor__", # (self, other, /)
54+
# reflected operators
55+
"__radd__",
56+
"__rand__",
57+
"__rfloordiv__",
58+
"__rlshift__",
59+
"__rmod__",
60+
"__rmul__",
61+
"__ror__",
62+
"__rpow__",
63+
"__rrshift__",
64+
"__rsub__",
65+
"__rtruediv__",
66+
"__rxor__",
67+
]

ddptensor/ddptensor.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,6 @@
22
from ._ddptensor import float64, int64, fini
33
from . import array_api as api
44

5-
ew_binary_methods = [
6-
"__add__", # (self, other, /)
7-
"__and__", # (self, other, /)
8-
"__eq__", # (self, other, /)
9-
"__floordiv__", # (self, other, /)
10-
"__ge__", # (self, other, /)
11-
"__gt__", # (self, other, /)
12-
"__le__", # (self, other, /)
13-
"__lshift__", # (self, other, /)
14-
"__lt__", # (self, other, /)
15-
"__matmul__", # (self, other, /)
16-
"__mod__", # (self, other, /)
17-
"__mul__", # (self, other, /)
18-
"__ne__", # (self, other, /)
19-
"__or__", # (self, other, /)
20-
"__pow__", # (self, other, /)
21-
"__rshift__", # (self, other, /)
22-
"__sub__", # (self, other, /)
23-
"__truediv__", # (self, other, /)
24-
"__xor__", # (self, other, /)
25-
# reflected operators
26-
"__radd__",
27-
"__rand__",
28-
"__rflowdiv__",
29-
"__rlshift__",
30-
"__rmod__",
31-
"__rmul__",
32-
"__ror__",
33-
"__rpow__",
34-
"__rrshift__",
35-
"__rsub__",
36-
"__rtruediv__",
37-
"__rxor__",
38-
]
39-
405
ew_unary_methods = [
416
"__abs__", # (self, /)
427
"__invert__", # (self, /)
@@ -70,7 +35,7 @@ def __repr__(self):
7035
return self._t.__repr__()
7136

7237

73-
for method in ew_binary_methods:
38+
for method in api.ew_binary_methods:
7439
exec(
7540
f"{method} = lambda self, other: dtensor(_cdt.ew_binary_op(self._t, '{method}', other._t if isinstance(other, dtensor) else other, True))"
7641
)

scripts/code_gen.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020
print(" IEWBINOP_LAST")
2121
print("};\n")
2222

23+
print("enum EWBinOpId : int {")
24+
for x in api.ew_binary_methods:
25+
x = x[2:-2] + " = IEWBINOP_LAST" if x == api.ew_binary_methods[0] else x[2:-2]
26+
print(f" {x.upper()},")
27+
print(" EWBINOP_LAST")
28+
print("};\n")
29+
2330
print("void def_enums(py::module_ & m)\n{")
2431

2532
print(' py::enum_<CreatorId>(m, "CreatorId")')
@@ -30,6 +37,11 @@
3037
print(' py::enum_<IEWBinOpId>(m, "IEWBinOpId")')
3138
for x in api.ew_binary_methods_inplace:
3239
print(f' .value("{x[2:-2].upper()}", {x[2:-2].upper()})')
33-
print(" .export_values();")
40+
print(" .export_values();\n")
41+
42+
print(' py::enum_<EWBinOpId>(m, "EWBinOpId")')
43+
for x in api.ew_binary_methods:
44+
print(f' .value("{x[2:-2].upper()}", {x[2:-2].upper()})')
45+
print(" .export_values();\n")
3446

3547
print("}")

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
glob("src/*.cpp"),
1717
include_dirs = xt_includes + [jp(mpiroot, "include"), jp("third_party", "bitsery", "include"), jp("src", "include"), ],
1818
extra_compile_args = ["-DUSE_MKL", "-DXTENSOR_USE_XSIMD=1", "-DXTENSOR_USE_OPENMP=1",
19-
"-std=c++17",
19+
"-std=c++17", "-fopenmp",
2020
"-Wno-unused-but-set-variable", "-Wno-sign-compare", "-Wno-unused-local-typedefs", "-Wno-reorder",
2121
"-march=native", "-O0", "-g"],
2222
libraries = ["mpi", "mkl_intel_lp64", "mkl_intel_thread", "mkl_core", "iomp5", "pthread", "rt", "dl", "m"],

src/ddptensor.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,14 @@ struct IEWBinOp
228228
}
229229
};
230230

231+
struct EWBinOp
232+
{
233+
static auto op(EWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b)
234+
{
235+
return TypeDispatch<x::EWBinOp>(a->dtype(), op, a, b);
236+
}
237+
};
238+
231239
rank_type myrank()
232240
{
233241
return theTransceiver->rank();
@@ -257,6 +265,9 @@ PYBIND11_MODULE(_ddptensor, m) {
257265
py::class_<IEWBinOp>(m, "IEWBinOp")
258266
.def("op", &IEWBinOp::op);
259267

268+
py::class_<EWBinOp>(m, "EWBinOp")
269+
.def("op", &EWBinOp::op);
270+
260271
py::class_<x::DPTensorBaseX, x::DPTensorBaseX::ptr_type>(m, "DPTensorX")
261272
.def("__repr__", &x::DPTensorBaseX::__repr__);
262273

src/include/ddptensor/p2c_ids.hpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,41 @@ enum IEWBinOpId : int {
3838
IEWBINOP_LAST
3939
};
4040

41+
enum EWBinOpId : int {
42+
ADD = IEWBINOP_LAST,
43+
AND,
44+
EQ,
45+
FLOORDIV,
46+
GE,
47+
GT,
48+
LE,
49+
LSHIFT,
50+
LT,
51+
MATMUL,
52+
MOD,
53+
MUL,
54+
NE,
55+
OR,
56+
POW,
57+
RSHIFT,
58+
SUB,
59+
TRUEDIV,
60+
XOR,
61+
RADD,
62+
RAND,
63+
RFLOORDIV,
64+
RLSHIFT,
65+
RMOD,
66+
RMUL,
67+
ROR,
68+
RPOW,
69+
RRSHIFT,
70+
RSUB,
71+
RTRUEDIV,
72+
RXOR,
73+
EWBINOP_LAST
74+
};
75+
4176
void def_enums(py::module_ & m)
4277
{
4378
py::enum_<CreatorId>(m, "CreatorId")
@@ -71,4 +106,39 @@ void def_enums(py::module_ & m)
71106
.value("ITRUEDIV", ITRUEDIV)
72107
.value("IXOR", IXOR)
73108
.export_values();
109+
110+
py::enum_<EWBinOpId>(m, "EWBinOpId")
111+
.value("ADD", ADD)
112+
.value("AND", AND)
113+
.value("EQ", EQ)
114+
.value("FLOORDIV", FLOORDIV)
115+
.value("GE", GE)
116+
.value("GT", GT)
117+
.value("LE", LE)
118+
.value("LSHIFT", LSHIFT)
119+
.value("LT", LT)
120+
.value("MATMUL", MATMUL)
121+
.value("MOD", MOD)
122+
.value("MUL", MUL)
123+
.value("NE", NE)
124+
.value("OR", OR)
125+
.value("POW", POW)
126+
.value("RSHIFT", RSHIFT)
127+
.value("SUB", SUB)
128+
.value("TRUEDIV", TRUEDIV)
129+
.value("XOR", XOR)
130+
.value("RADD", RADD)
131+
.value("RAND", RAND)
132+
.value("RFLOORDIV", RFLOORDIV)
133+
.value("RLSHIFT", RLSHIFT)
134+
.value("RMOD", RMOD)
135+
.value("RMUL", RMUL)
136+
.value("ROR", ROR)
137+
.value("RPOW", RPOW)
138+
.value("RRSHIFT", RRSHIFT)
139+
.value("RSUB", RSUB)
140+
.value("RTRUEDIV", RTRUEDIV)
141+
.value("RXOR", RXOR)
142+
.export_values();
143+
74144
}

src/include/ddptensor/x.hpp

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace x
3333
virtual ~DPTensorBaseX() {};
3434
virtual std::string __repr__() const = 0;
3535
virtual DType dtype() const = 0;
36+
virtual shape_type shape() const = 0;
3637
};
3738

3839
template<typename T>
@@ -65,6 +66,16 @@ namespace x
6566
{
6667
return _xarray;
6768
}
69+
70+
const PVSlice & slice() const
71+
{
72+
return _slice;
73+
}
74+
75+
virtual shape_type shape() const
76+
{
77+
return _slice.shape();
78+
}
6879
};
6980

7081
template<typename T>
@@ -147,7 +158,7 @@ namespace x
147158
static void op(IEWBinOpId iop, ptr_type a_ptr, const ptr_type & b_ptr)
148159
{
149160
auto _a = dynamic_cast<DPTensorX<T>*>(a_ptr.get());
150-
auto _b = dynamic_cast<DPTensorX<T>*>(b_ptr.get());
161+
auto const _b = dynamic_cast<DPTensorX<T>*>(b_ptr.get());
151162
if(!_a || !_b)
152163
throw std::runtime_error("Invalid array object: could not dynamically cast");
153164
auto & a = _a->xarray();
@@ -178,6 +189,117 @@ namespace x
178189
integral_iop(iop, a, b);
179190
}
180191

192+
#pragma GCC diagnostic pop
193+
194+
};
195+
196+
197+
template<typename T>
198+
class EWBinOp
199+
{
200+
public:
201+
using ptr_type = DPTensorBaseX::ptr_type;
202+
203+
template<typename X>
204+
static ptr_type mk_tx(const DPTensorBaseX & tx, X && x)
205+
{
206+
return std::make_shared<DPTensorX<typename X::value_type>>(tx.shape(), x);
207+
}
208+
209+
#pragma GCC diagnostic ignored "-Wswitch"
210+
211+
template<typename A, typename B, typename U = T, std::enable_if_t<std::is_floating_point<U>::value, bool> = true>
212+
static ptr_type integral_op(EWBinOpId iop, const DPTensorX<T> & tx, A && a, B && b)
213+
{
214+
throw std::runtime_error("Illegal or unknown inplace elementwise binary operation");
215+
}
216+
217+
template<typename A, typename B, typename U = T, std::enable_if_t<std::is_integral<U>::value, bool> = true>
218+
static ptr_type integral_op(EWBinOpId iop, const DPTensorBaseX & tx, A && a, B && b)
219+
{
220+
switch(iop) {
221+
case AND:
222+
case RAND:
223+
return mk_tx(tx, a & b);
224+
case LSHIFT:
225+
return mk_tx(tx, a << b);
226+
case MOD:
227+
return mk_tx(tx, a % b);
228+
case OR:
229+
case ROR:
230+
return mk_tx(tx, a | b);
231+
case RSHIFT:
232+
return mk_tx(tx, a >> b);
233+
case XOR:
234+
case RXOR:
235+
return mk_tx(tx, a ^ b);
236+
case RLSHIFT:
237+
return mk_tx(tx, b << a);
238+
case RMOD:
239+
return mk_tx(tx, b % a);
240+
case RRSHIFT:
241+
return mk_tx(tx, b >> a);
242+
default:
243+
throw std::runtime_error("Unknown elementwise binary operation");
244+
}
245+
}
246+
247+
static ptr_type op(EWBinOpId bop, const ptr_type & a_ptr, const ptr_type & b_ptr)
248+
{
249+
auto _a = dynamic_cast<DPTensorX<T>*>(a_ptr.get());
250+
auto const _b = dynamic_cast<DPTensorX<T>*>(b_ptr.get());
251+
if(!_a || !_b)
252+
throw std::runtime_error("Invalid array object: could not dynamically cast");
253+
auto & a = _a->xarray();
254+
auto const & b = _b->xarray();
255+
256+
switch(bop) {
257+
case ADD:
258+
case RADD:
259+
return mk_tx(*_a, a + b);
260+
case EQ:
261+
return mk_tx(*_a, xt::equal(a, b));
262+
case FLOORDIV:
263+
return mk_tx(*_a, xt::floor(a / b));
264+
case GE:
265+
return mk_tx(*_a, a >= b);
266+
case GT:
267+
return mk_tx(*_a, a > b);
268+
case LE:
269+
return mk_tx(*_a, a <= b);
270+
case LT:
271+
return mk_tx(*_a, a < b);
272+
/* FIXME
273+
case MATMUL:
274+
return mk_tx(*_a, );
275+
*/
276+
case MUL:
277+
case RMUL:
278+
return mk_tx(*_a, a * b);
279+
case NE:
280+
return mk_tx(*_a, xt::not_equal(a, b));
281+
/* FIXME
282+
case POW:
283+
return mk_tx(*_a, );
284+
*/
285+
case SUB:
286+
return mk_tx(*_a, a - b);
287+
case TRUEDIV:
288+
return mk_tx(*_a, a / b);
289+
case RFLOORDIV:
290+
return mk_tx(*_a, xt::floor(b / a));
291+
/* FIXME
292+
case RPOW:
293+
return mk_tx(*_a, );
294+
*/
295+
case RSUB:
296+
return mk_tx(*_a, b - a);
297+
case RTRUEDIV:
298+
return mk_tx(*_a, b / a);
299+
}
300+
return integral_op(bop, *_a, a, b);
301+
}
302+
181303
#pragma GCC diagnostic pop
182304

183305
};

test/test_x.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@
33
b = dt.Creator.create_from_shape(dt.ONES, [4,4], dt.float64)
44
dt.IEWBinOp.op(dt.IADD, a, b)
55
print(a)
6+
print(dt.EWBinOp.op(dt.EQ, a, b))
7+
dt.fini()

0 commit comments

Comments
 (0)