Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 2c3e9be

Browse files
committed
first cut supporting creation and inplace elementwise ops through xtensor
1 parent baec5fa commit 2c3e9be

14 files changed

Lines changed: 435 additions & 54 deletions

File tree

.gitmodules

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
11
[submodule "third_party/bitsery"]
22
path = third_party/bitsery
33
url = https://github.com/fraillt/bitsery
4+
[submodule "third_party/xtensor"]
5+
path = third_party/xtensor
6+
url = https://github.com/xtensor-stack/xtensor
7+
[submodule "third_party/xtensor-blas"]
8+
path = third_party/xtensor-blas
9+
url = https://github.com/xtensor-stack/xtensor-blas
10+
[submodule "third_party/xsimd"]
11+
path = third_party/xsimd
12+
url = https://github.com/xtensor-stack/xsimd
13+
[submodule "third_party/xtl"]
14+
path = third_party/xtl
15+
url = https://github.com/xtensor-stack/xtl

ddptensor/__init__.py

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from . import _ddptensor as _cdt
22
from .ddptensor import float64, int64, fini, dtensor
33
from os import getenv
4+
from . import array_api as api
45

56
__impl_str = getenv("DDPNP_ARRAY", 'numpy')
67
exec(f"import {__impl_str} as __impl")
@@ -77,17 +78,11 @@
7778
f"{op} = lambda this: dtensor(_cdt.ew_unary_op(this._t, '{op}', False))"
7879
)
7980

80-
creators_with_shape = [
81-
"empty", # (shape, *, dtype=None, device=None)
82-
"full", # (shape, fill_value, *, dtype=None, device=None)
83-
"ones", # (shape, *, dtype=None, device=None)
84-
"zeros", # (shape, *, dtype=None, device=None)
85-
]
86-
87-
for func in creators_with_shape:
88-
exec(
89-
f"{func} = lambda shape, *args, **kwargs: dtensor(_cdt.create(shape, '{func}', '{__impl_str}', *args, **kwargs))"
90-
)
81+
for func in api.creators:
82+
if func in ["empty", "full", "ones", "zeros",]:
83+
exec(
84+
f"{func} = lambda shape, *args, **kwargs: dtensor(_cdt.create(shape, '{func}', '{__impl_str}', *args, **kwargs))"
85+
)
9186

9287
statisticals = [
9388
"max", # (x, /, *, axis=None, keepdims=False)
@@ -103,17 +98,3 @@
10398
exec(
10499
f"{func} = lambda this, **kwargs: dtensor(_cdt.reduce_op(this._t, '{func}', **kwargs))"
105100
)
106-
107-
108-
creators = [
109-
"arange", # (start, /, stop=None, step=1, *, dtype=None, device=None)
110-
"asarray", # (obj, /, *, dtype=None, device=None, copy=None)
111-
"empty_like", # (x, /, *, dtype=None, device=None)
112-
"eye", # (n_rows, n_cols=None, /, *, k=0, dtype=None, device=None)
113-
"from_dlpack", # (x, /)
114-
"full_like", # (x, /, fill_value, *, dtype=None, device=None)
115-
"linspace", # (start, stop, /, num, *, dtype=None, device=None, endpoint=True)
116-
"meshgrid", # (*arrays, indexing=’xy’)
117-
"ones_like", # (x, /, *, dtype=None, device=None)
118-
"zeros_like", # (x, /, *, dtype=None, device=None)
119-
]

ddptensor/array_api.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
creators = [
2+
"arange", # (start, /, stop=None, step=1, *, dtype=None, device=None)
3+
"asarray", # (obj, /, *, dtype=None, device=None, copy=None)
4+
"empty",
5+
"empty_like", # (x, /, *, dtype=None, device=None)
6+
"eye", # (n_rows, n_cols=None, /, *, k=0, dtype=None, device=None)
7+
"from_dlpack", # (x, /)
8+
"full",
9+
"full_like", # (x, /, fill_value, *, dtype=None, device=None)
10+
"linspace", # (start, stop, /, num, *, dtype=None, device=None, endpoint=True)
11+
"meshgrid", # (*arrays, indexing=’xy’)
12+
"ones",
13+
"ones_like", # (x, /, *, dtype=None, device=None)
14+
"zeros",
15+
"zeros_like", # (x, /, *, dtype=None, device=None)
16+
]
17+
18+
ew_binary_methods_inplace = [
19+
# inplace operators
20+
"__iadd__",
21+
"__iand__",
22+
"__ifloordiv__",
23+
"__ilshift__",
24+
"__imod__",
25+
"__imul__",
26+
"__ior__",
27+
"__ipow__",
28+
"__irshift__",
29+
"__isub__",
30+
"__itruediv__",
31+
"__ixor__",
32+
]

ddptensor/ddptensor.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from . import _ddptensor as _cdt
22
from ._ddptensor import float64, int64, fini
3+
from . import array_api as api
34

45
ew_binary_methods = [
56
"__add__", # (self, other, /)
@@ -36,22 +37,6 @@
3637
"__rxor__",
3738
]
3839

39-
ew_binary_methods_inplace = [
40-
# inplace operators
41-
"__iadd__",
42-
"__iand__",
43-
"__iflowdiv__",
44-
"__ilshift__",
45-
"__imod__",
46-
"__imul__",
47-
"__ior__",
48-
"__ipow__",
49-
"__irshift__",
50-
"__isub__",
51-
"__itruediv__",
52-
"__ixor__",
53-
]
54-
5540
ew_unary_methods = [
5641
"__abs__", # (self, /)
5742
"__invert__", # (self, /)
@@ -90,7 +75,7 @@ def __repr__(self):
9075
f"{method} = lambda self, other: dtensor(_cdt.ew_binary_op(self._t, '{method}', other._t if isinstance(other, dtensor) else other, True))"
9176
)
9277

93-
for method in ew_binary_methods_inplace:
78+
for method in api.ew_binary_methods_inplace:
9479
exec(
9580
f"{method} = lambda self, other: (self, _cdt.ew_binary_op_inplace(self._t, '{method}', other._t if isinstance(other, dtensor) else other))[0]"
9681
)

scripts/code_gen.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import array_api as api
2+
3+
print("""// SPDX-License-Identifier: BSD-3-Clause
4+
#pragma once
5+
#include <pybind11/pybind11.h>
6+
#include <pybind11/stl.h>
7+
namespace py = pybind11;
8+
""")
9+
10+
print("enum CreatorId : int {")
11+
for x in api.creators:
12+
print(f" {x.upper()},")
13+
print(" CREATOR_LAST")
14+
print("};\n")
15+
16+
print("enum IEWBinOpId : int {")
17+
for x in api.ew_binary_methods_inplace:
18+
x = x[2:-2] + " = CREATOR_LAST" if x == api.ew_binary_methods_inplace[0] else x[2:-2]
19+
print(f" {x.upper()},")
20+
print(" IEWBINOP_LAST")
21+
print("};\n")
22+
23+
print("void def_enums(py::module_ & m)\n{")
24+
25+
print(' py::enum_<CreatorId>(m, "CreatorId")')
26+
for x in api.creators:
27+
print(f' .value("{x.upper()}", {x.upper()})')
28+
print(" .export_values();\n")
29+
30+
print(' py::enum_<IEWBinOpId>(m, "IEWBinOpId")')
31+
for x in api.ew_binary_methods_inplace:
32+
print(f' .value("{x[2:-2].upper()}", {x[2:-2].upper()})')
33+
print(" .export_values();")
34+
35+
print("}")

setup.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,23 @@
55
from pybind11.setup_helpers import Pybind11Extension
66

77
mpiroot = os.environ.get('MPIROOT')
8+
mklroot = os.environ.get('MKLROOT')
9+
xtroot = os.getenv('XTROOT', 'third_party')
10+
11+
xt_includes = [jp(xtroot, x, "include") for x in ("xtl", "xsimd", "xtensor-blas", "xtensor")]
812

913
ext_modules = [
1014
Pybind11Extension(
1115
"ddptensor._ddptensor",
1216
glob("src/*.cpp"),
13-
include_dirs=[jp(mpiroot, "include"), jp("third_party", "bitsery", "include"), jp("src", "include"), ],
14-
extra_compile_args=["-DUSE_MKL", "-std=c++17", "-Wno-unused-but-set-variable", "-Wno-sign-compare", "-Wno-unused-local-typedefs", "-Wno-reorder", "-O0", "-g"],
15-
libraries=["mpi", "rt", "pthread", "dl", "mkl_intel_lp64", "mkl_intel_thread", "mkl_core", "iomp5", "m"],
16-
library_dirs=[jp(mpiroot, "lib")],
17-
language='c++'
17+
include_dirs = xt_includes + [jp(mpiroot, "include"), jp("third_party", "bitsery", "include"), jp("src", "include"), ],
18+
extra_compile_args = ["-DUSE_MKL", "-DXTENSOR_USE_XSIMD=1", "-DXTENSOR_USE_OPENMP=1",
19+
"-std=c++17",
20+
"-Wno-unused-but-set-variable", "-Wno-sign-compare", "-Wno-unused-local-typedefs", "-Wno-reorder",
21+
"-march=native", "-O0", "-g"],
22+
libraries = ["mpi", "mkl_intel_lp64", "mkl_intel_thread", "mkl_core", "iomp5", "pthread", "rt", "dl", "m"],
23+
library_dirs = [jp(mpiroot, "lib")],
24+
language = 'c++'
1825
),
1926
]
2027

src/ddptensor.cpp

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
#include <pybind11/pybind11.h>
1313
#include <pybind11/stl.h>
1414
namespace py = pybind11;
15+
using namespace pybind11::literals; // to bring _a
1516

1617
#include "ddptensor/ddptensor_impl.hpp"
1718
#include "ddptensor/MPITransceiver.hpp"
1819
#include "ddptensor/MPIMediator.hpp"
20+
#include "ddptensor/x.hpp"
1921

2022
/// Thensor which is closely following the Python API
2123
class dtensor
@@ -170,6 +172,62 @@ dtensor reduce_op(const dtensor & a, const char * op, const py::kwargs & kwargs)
170172
return dtensor(a._tensor->_reduce_op(op, dims));
171173
}
172174

175+
// ###################################################################
176+
// ###################################################################
177+
// ###################################################################
178+
179+
template<template<typename OD> class OpDispatch, typename... Ts>
180+
auto TypeDispatch(DType dt, Ts&&... args)
181+
{
182+
switch(dt) {
183+
case DT_FLOAT64:
184+
return OpDispatch<double>::op(std::forward<Ts>(args)...);
185+
case DT_FLOAT32:
186+
return OpDispatch<float>::op(std::forward<Ts>(args)...);
187+
case DT_INT64:
188+
return OpDispatch<int64_t>::op(std::forward<Ts>(args)...);
189+
case DT_INT32:
190+
return OpDispatch<int32_t>::op(std::forward<Ts>(args)...);
191+
case DT_INT16:
192+
return OpDispatch<int16_t>::op(std::forward<Ts>(args)...);
193+
case DT_UINT64:
194+
return OpDispatch<uint64_t>::op(std::forward<Ts>(args)...);
195+
case DT_UINT32:
196+
return OpDispatch<uint32_t>::op(std::forward<Ts>(args)...);
197+
case DT_UINT16:
198+
return OpDispatch<uint16_t>::op(std::forward<Ts>(args)...);
199+
/* FIXME
200+
case DT_BOOL:
201+
return OpDispatch<bool>::op(std::forward<Ts>(args)...);
202+
*/
203+
default:
204+
throw std::runtime_error("unknown dtype");
205+
}
206+
}
207+
208+
struct Creator
209+
{
210+
211+
static auto create_from_shape(CreatorId op, shape_type && shape, DType dtype=DT_FLOAT64)
212+
{
213+
return TypeDispatch<x::Creator>(dtype, op, std::forward<shape_type>(shape));
214+
}
215+
216+
static auto full(shape_type && shape, py::object && val, DType dtype=DT_FLOAT64)
217+
{
218+
auto op = FULL;
219+
return TypeDispatch<x::Creator>(dtype, op, std::forward<shape_type>(shape), std::forward<py::object>(val));
220+
}
221+
};
222+
223+
struct IEWBinOp
224+
{
225+
static auto op(IEWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b)
226+
{
227+
return TypeDispatch<x::IEWBinOp>(a->dtype(), op, a, b);
228+
}
229+
};
230+
173231
rank_type myrank()
174232
{
175233
return theTransceiver->rank();
@@ -192,14 +250,18 @@ PYBIND11_MODULE(_ddptensor, m) {
192250

193251
m.doc() = "A partitioned and distributed tensor";
194252

195-
/* static const DType _DT_FLOAT64 = DT_FLOAT64;
196-
static const DType _DT_INT64 = DT_INT64;
197-
static const DType _DT_BOOL = DT_BOOL;
253+
py::class_<Creator>(m, "Creator")
254+
.def("create_from_shape", &Creator::create_from_shape)
255+
.def("full", &Creator::full);
256+
257+
py::class_<IEWBinOp>(m, "IEWBinOp")
258+
.def("op", &IEWBinOp::op);
259+
260+
py::class_<x::DPTensorBaseX, x::DPTensorBaseX::ptr_type>(m, "DPTensorX")
261+
.def("__repr__", &x::DPTensorBaseX::__repr__);
262+
263+
def_enums(m);
198264

199-
m.def_readonly("float64", &_DT_FLOAT64);
200-
m.def_readonly("int64", &_DT_INT64);
201-
m.def_readonly("bool", &_DT_BOOL);
202-
*/
203265
py::enum_<DType>(m, "dtype")
204266
.value("float64", DT_FLOAT64)
205267
.value("int64", DT_INT64)

src/include/ddptensor/p2c_ids.hpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// SPDX-License-Identifier: BSD-3-Clause
2+
#pragma once
3+
#include <pybind11/pybind11.h>
4+
#include <pybind11/stl.h>
5+
namespace py = pybind11;
6+
7+
enum CreatorId : int {
8+
ARANGE,
9+
ASARRAY,
10+
EMPTY,
11+
EMPTY_LIKE,
12+
EYE,
13+
FROM_DLPACK,
14+
FULL,
15+
FULL_LIKE,
16+
LINSPACE,
17+
MESHGRID,
18+
ONES,
19+
ONES_LIKE,
20+
ZEROS,
21+
ZEROS_LIKE,
22+
CREATOR_LAST
23+
};
24+
25+
enum IEWBinOpId : int {
26+
IADD = CREATOR_LAST,
27+
IAND,
28+
IFLOORDIV,
29+
ILSHIFT,
30+
IMOD,
31+
IMUL,
32+
IOR,
33+
IPOW,
34+
IRSHIFT,
35+
ISUB,
36+
ITRUEDIV,
37+
IXOR,
38+
IEWBINOP_LAST
39+
};
40+
41+
void def_enums(py::module_ & m)
42+
{
43+
py::enum_<CreatorId>(m, "CreatorId")
44+
.value("ARANGE", ARANGE)
45+
.value("ASARRAY", ASARRAY)
46+
.value("EMPTY", EMPTY)
47+
.value("EMPTY_LIKE", EMPTY_LIKE)
48+
.value("EYE", EYE)
49+
.value("FROM_DLPACK", FROM_DLPACK)
50+
.value("FULL", FULL)
51+
.value("FULL_LIKE", FULL_LIKE)
52+
.value("LINSPACE", LINSPACE)
53+
.value("MESHGRID", MESHGRID)
54+
.value("ONES", ONES)
55+
.value("ONES_LIKE", ONES_LIKE)
56+
.value("ZEROS", ZEROS)
57+
.value("ZEROS_LIKE", ZEROS_LIKE)
58+
.export_values();
59+
60+
py::enum_<IEWBinOpId>(m, "IEWBinOpId")
61+
.value("IADD", IADD)
62+
.value("IAND", IAND)
63+
.value("IFLOORDIV", IFLOORDIV)
64+
.value("ILSHIFT", ILSHIFT)
65+
.value("IMOD", IMOD)
66+
.value("IMUL", IMUL)
67+
.value("IOR", IOR)
68+
.value("IPOW", IPOW)
69+
.value("IRSHIFT", IRSHIFT)
70+
.value("ISUB", ISUB)
71+
.value("ITRUEDIV", ITRUEDIV)
72+
.value("IXOR", IXOR)
73+
.export_values();
74+
}

0 commit comments

Comments
 (0)