Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 6467b2b

Browse files
committed
adding reductions
1 parent 5dead87 commit 6467b2b

10 files changed

Lines changed: 130 additions & 47 deletions

File tree

ddptensor/__init__.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,7 @@
2222
f"{func} = lambda shape, *args, **kwargs: dtensor(_cdt.create(shape, '{func}', '{__impl_str}', *args, **kwargs))"
2323
)
2424

25-
statisticals = [
26-
"max", # (x, /, *, axis=None, keepdims=False)
27-
"mean", # (x, /, *, axis=None, keepdims=False)
28-
"min", # (x, /, *, axis=None, keepdims=False)
29-
"prod", # (x, /, *, axis=None, keepdims=False)
30-
"sum", # (x, /, *, axis=None, keepdims=False)
31-
"std", # (x, /, *, axis=None, correction=0.0, keepdims=False)
32-
"var", # (x, /, *, axis=None, correction=0.0, keepdims=False)
33-
]
34-
35-
for func in statisticals:
25+
for func in api.statisticals:
3626
exec(
3727
f"{func} = lambda this, **kwargs: dtensor(_cdt.reduce_op(this._t, '{func}', **kwargs))"
3828
)

ddptensor/array_api.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,13 @@
134134
"remainder", # (x1, x2, /)
135135
"subtract", # (x1, x2, /)
136136
]
137+
138+
statisticals = [
139+
"max", # (x, /, *, axis=None, keepdims=False)
140+
"mean", # (x, /, *, axis=None, keepdims=False)
141+
"min", # (x, /, *, axis=None, keepdims=False)
142+
"prod", # (x, /, *, axis=None, keepdims=False)
143+
"sum", # (x, /, *, axis=None, keepdims=False)
144+
"std", # (x, /, *, axis=None, correction=0.0, keepdims=False)
145+
"var", # (x, /, *, axis=None, correction=0.0, keepdims=False)
146+
]

scripts/code_gen.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,14 @@
3636
print(" EWBINOP_LAST")
3737
print("};\n")
3838

39-
print("void def_enums(py::module_ & m)\n{")
39+
print("enum ReduceOpId : int {")
40+
for x in api.statisticals:
41+
x = x + " = EWBINOP_LAST" if x == api.statisticals[0] else x
42+
print(f" {x.upper()},")
43+
print(" REDUCEOP_LAST")
44+
print("};\n")
45+
46+
print("static void def_enums(py::module_ & m)\n{")
4047

4148
print(' py::enum_<CreatorId>(m, "CreatorId")')
4249
for x in api.creators:
@@ -58,4 +65,9 @@
5865
print(f' .value("{x.upper()}", {x.upper()})')
5966
print(" .export_values();\n")
6067

68+
print(' py::enum_<ReduceOpId>(m, "ReduceOpId")')
69+
for x in api.statisticals:
70+
print(f' .value("{x.upper()}", {x.upper()})')
71+
print(" .export_values();\n")
72+
6173
print("}")

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
extra_compile_args = ["-DUSE_MKL", "-DXTENSOR_USE_XSIMD=1", "-DXTENSOR_USE_OPENMP=1",
1919
"-std=c++17", "-fopenmp",
2020
"-Wno-unused-but-set-variable", "-Wno-sign-compare", "-Wno-unused-local-typedefs", "-Wno-reorder",
21-
"-march=native", "-O0", "-g"],
21+
"-march=native",], # "-O0", "-g"],
2222
libraries = ["mpi", "mkl_intel_lp64", "mkl_intel_thread", "mkl_core", "iomp5", "pthread", "rt", "dl", "m"],
2323
library_dirs = [jp(mpiroot, "lib")],
2424
language = 'c++'

src/MPITransceiver.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,16 @@ static MPI_Datatype to_mpi(DType T)
4242
static MPI_Op to_mpi(RedOpType o)
4343
{
4444
switch(o) {
45-
case OP_MAX: return MPI_MAX;
46-
case OP_MIN: return MPI_MIN;
47-
case OP_SUM: return MPI_SUM;
48-
case OP_PROD: return MPI_PROD;
49-
case OP_LAND: return MPI_LAND;
50-
case OP_BAND: return MPI_BAND;
51-
case OP_LOR: return MPI_LOR;
52-
case OP_BOR: return MPI_BOR;
53-
case OP_LXOR: return MPI_LXOR;
54-
case OP_BXOR: return MPI_BXOR;
45+
case MAX: return MPI_MAX;
46+
case MIN: return MPI_MIN;
47+
case SUM: return MPI_SUM;
48+
case PROD: return MPI_PROD;
49+
// case OP_LAND: return MPI_LAND;
50+
// case OP_BAND: return MPI_BAND;
51+
// case OP_LOR: return MPI_LOR;
52+
// case OP_BOR: return MPI_BOR;
53+
// case OP_LXOR: return MPI_LXOR;
54+
// case OP_BXOR: return MPI_BXOR;
5555
default: throw std::logic_error("unsupported operation type");
5656
}
5757
}

src/ddptensor.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,14 @@ struct EWUnyOp
244244
}
245245
};
246246

247+
struct ReduceOp
248+
{
249+
static auto op(ReduceOpId op, x::DPTensorBaseX::ptr_type a, const dim_vec_type & dim)
250+
{
251+
return TypeDispatch<x::ReduceOp>(a->dtype(), op, a, dim);
252+
}
253+
};
254+
247255
rank_type myrank()
248256
{
249257
return theTransceiver->rank();
@@ -279,6 +287,9 @@ PYBIND11_MODULE(_ddptensor, m) {
279287
py::class_<EWBinOp>(m, "EWBinOp")
280288
.def("op", &EWBinOp::op);
281289

290+
py::class_<ReduceOp>(m, "ReduceOp")
291+
.def("op", &ReduceOp::op);
292+
282293
py::class_<x::DPTensorBaseX, x::DPTensorBaseX::ptr_type>(m, "DPTensorX")
283294
.def("__repr__", &x::DPTensorBaseX::__repr__);
284295

src/include/ddptensor/UtilsAndTypes.hpp

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <pybind11/pybind11.h>
1010
namespace py = pybind11;
11+
#include "p2c_ids.hpp"
1112

1213
using shape_type = std::vector<uint64_t>;
1314
using dim_vec_type = std::vector<int>;
@@ -81,32 +82,17 @@ inline const py::object & get_impl_dtype(const DType dt)
8182
return _dtypes[dt];
8283
}
8384

84-
// identifies reduction operation
85-
enum RedOpType {
86-
OP_MAX = 100,
87-
OP_MIN,
88-
OP_SUM,
89-
OP_PROD,
90-
OP_MEAN,
91-
OP_STD,
92-
OP_VAR,
93-
OP_LAND,
94-
OP_BAND,
95-
OP_LOR,
96-
OP_BOR,
97-
OP_LXOR,
98-
OP_BXOR
99-
};
85+
using RedOpType = ReduceOpId;
10086

10187
inline RedOpType red_op(const char * op)
10288
{
103-
if(!strcmp(op, "max")) return OP_MAX;
104-
if(!strcmp(op, "min")) return OP_MIN;
105-
if(!strcmp(op, "sum")) return OP_SUM;
106-
if(!strcmp(op, "prod")) return OP_PROD;
107-
if(!strcmp(op, "mean")) return OP_MEAN;
108-
if(!strcmp(op, "std")) return OP_STD;
109-
if(!strcmp(op, "var")) return OP_VAR;
89+
if(!strcmp(op, "max")) return MAX;
90+
if(!strcmp(op, "min")) return MIN;
91+
if(!strcmp(op, "sum")) return SUM;
92+
if(!strcmp(op, "prod")) return PROD;
93+
if(!strcmp(op, "mean")) return MEAN;
94+
if(!strcmp(op, "std")) return STD;
95+
if(!strcmp(op, "var")) return VAR;
11096
throw std::logic_error("unsupported reduction operation");
11197
}
11298

src/include/ddptensor/p2c_ids.hpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,18 @@ enum EWBinOpId : int {
137137
EWBINOP_LAST
138138
};
139139

140-
void def_enums(py::module_ & m)
140+
enum ReduceOpId : int {
141+
MAX = EWBINOP_LAST,
142+
MEAN,
143+
MIN,
144+
PROD,
145+
SUM,
146+
STD,
147+
VAR,
148+
REDUCEOP_LAST
149+
};
150+
151+
static void def_enums(py::module_ & m)
141152
{
142153
py::enum_<CreatorId>(m, "CreatorId")
143154
.value("ARANGE", ARANGE)
@@ -268,4 +279,14 @@ void def_enums(py::module_ & m)
268279
.value("SUBTRACT", SUBTRACT)
269280
.export_values();
270281

282+
py::enum_<ReduceOpId>(m, "ReduceOpId")
283+
.value("MAX", MAX)
284+
.value("MEAN", MEAN)
285+
.value("MIN", MIN)
286+
.value("PROD", PROD)
287+
.value("SUM", SUM)
288+
.value("STD", STD)
289+
.value("VAR", VAR)
290+
.export_values();
291+
271292
}

src/include/ddptensor/x.hpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,5 +429,57 @@ namespace x
429429
#pragma GCC diagnostic pop
430430

431431
};
432+
433+
434+
template<typename T>
435+
class ReduceOp
436+
{
437+
public:
438+
using ptr_type = DPTensorBaseX::ptr_type;
439+
440+
#pragma GCC diagnostic ignored "-Wswitch"
441+
442+
template<typename X>
443+
static ptr_type dist_reduce(ReduceOpId rop, const PVSlice & slice, const dim_vec_type & dims, X && x)
444+
{
445+
xt::xarray<typename X::value_type> a = x;
446+
auto new_shape = reduce_shape(slice.shape(), dims);
447+
if(slice.need_reduce(dims)) {
448+
auto len = VPROD(new_shape);
449+
theTransceiver->reduce_all(a.data(), DTYPE<typename X::value_type>::value, len, rop);
450+
}
451+
return std::make_shared<DPTensorX<typename X::value_type>>(new_shape, a);
452+
}
453+
454+
static ptr_type op(ReduceOpId rop, const ptr_type & a_ptr, const dim_vec_type & dims)
455+
{
456+
auto const _a = dynamic_cast<DPTensorX<T>*>(a_ptr.get());
457+
if(!_a )
458+
throw std::runtime_error("Invalid array object: could not dynamically cast");
459+
auto const & a = _a->xarray();
460+
461+
switch(rop) {
462+
case MEAN:
463+
return dist_reduce(rop, _a->slice(), dims, xt::mean(a, dims));
464+
case PROD:
465+
return dist_reduce(rop, _a->slice(), dims, xt::prod(a, dims));
466+
case SUM:
467+
return dist_reduce(rop, _a->slice(), dims, xt::sum(a, dims));
468+
case STD:
469+
return dist_reduce(rop, _a->slice(), dims, xt::stddev(a, dims));
470+
case VAR:
471+
return dist_reduce(rop, _a->slice(), dims, xt::variance(a, dims));
472+
case MAX:
473+
case MIN:
474+
throw std::runtime_error("Reduction operation not implemented");
475+
default:
476+
throw std::runtime_error("Unknown reduction operation");
477+
}
478+
}
479+
480+
#pragma GCC diagnostic pop
481+
482+
};
483+
432484

433485
} // namespace x

test/test_x.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
print(a)
66
print(dt.EWBinOp.op(dt.EQUAL, a, b))
77
print(dt.EWUnyOp.op(dt.SQRT, a))
8+
print(dt.ReduceOp.op(dt.SUM, a, [1]))
89
dt.fini()

0 commit comments

Comments
 (0)