Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit b2a3ed4

Browse files
committed
adding basic __getitem__
1 parent 7635c3a commit b2a3ed4

8 files changed

Lines changed: 99 additions & 195 deletions

File tree

ddptensor/ddptensor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,8 @@ def __repr__(self):
5656
# f"{att} = property(lambda self: self._t.{att})"
5757
# )
5858

59-
# def __getitem__(self, *args):
60-
# x = self._t.__getitem__(*args)
61-
# return dtensor(x)
59+
def __getitem__(self, *args):
60+
return dtensor(self._t.__getitem__(*args))
6261

6362
# def __setitem__(self, key, value):
6463
# x = self._t.__setitem__(key, value._t if isinstance(value, dtensor) else value)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
extra_compile_args = ["-DUSE_MKL", "-DXTENSOR_USE_XSIMD=1", "-DXTENSOR_USE_OPENMP=1",
1919
"-std=c++17", "-fopenmp",
2020
"-Wno-unused-but-set-variable", "-Wno-sign-compare", "-Wno-unused-local-typedefs", "-Wno-reorder",
21-
"-march=native",], # "-O0", "-g"],
21+
"-march=native", "-O0", "-g"],
2222
libraries = ["mpi", "mkl_intel_lp64", "mkl_intel_thread", "mkl_core", "iomp5", "pthread", "rt", "dl", "m"],
2323
library_dirs = [jp(mpiroot, "lib")],
2424
language = 'c++'

src/ddptensor.cpp

Lines changed: 27 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -14,164 +14,10 @@
1414
namespace py = pybind11;
1515
using namespace pybind11::literals; // to bring _a
1616

17-
#include "ddptensor/ddptensor_impl.hpp"
1817
#include "ddptensor/MPITransceiver.hpp"
1918
#include "ddptensor/MPIMediator.hpp"
2019
#include "ddptensor/x.hpp"
2120

22-
/// Thensor which is closely following the Python API
23-
class dtensor
24-
{
25-
public:
26-
typedef tensor_i::ptr_type ptr_type;
27-
28-
dtensor(dtensor &&) = default;
29-
30-
dtensor(const shape_type & shape, DType dt)
31-
: _tensor(create_dtensor(PVSlice(shape), shape, dt))
32-
{
33-
}
34-
35-
dtensor(const shape_type & shape, const char * create, const char * mod, py::args args, const py::kwargs & kwargs)
36-
: _tensor(create_dtensor(PVSlice(shape), shape, create, mod, args, kwargs))
37-
{
38-
}
39-
40-
dtensor(ptr_type && t)
41-
: _tensor(std::move(t))
42-
{
43-
}
44-
45-
shape_type shape() const
46-
{
47-
return _tensor->shape();
48-
}
49-
50-
DType dtype() const
51-
{
52-
return _tensor->dtype();
53-
}
54-
55-
py::object get_slice(const std::vector<py::slice> & v)
56-
{
57-
return _tensor->get_slice(NDSlice(v));
58-
}
59-
60-
dtensor __getitem__(const NDIndex & v)
61-
{
62-
return dtensor(_tensor->__getitem__(NDSlice(v)));
63-
}
64-
65-
dtensor __getitem__(int64_t i)
66-
{
67-
return __getitem__(NDIndex(1,i));
68-
}
69-
70-
dtensor __getitem__(const std::vector<py::slice> & v)
71-
{
72-
return dtensor(_tensor->__getitem__(NDSlice(v)));
73-
}
74-
75-
dtensor __getitem__(const py::slice & s)
76-
{
77-
return dtensor(_tensor->__getitem__(NDSlice(std::vector<py::slice>(1, s))));
78-
}
79-
80-
void __setitem__(const std::vector<py::slice> & v, const dtensor * ob)
81-
{
82-
//const dtensor * ob = b.cast<const dtensor*>();
83-
_tensor->__setitem__(NDSlice(v), ob->_tensor);
84-
}
85-
86-
std::string __repr__() const
87-
{
88-
return _tensor->__repr__();
89-
}
90-
91-
// "__array_namespace__", # (self, /, *, api_version=None)
92-
// "__dlpack__", # (self, /, *, stream=None)
93-
// "__dlpack_device__", # (self, /)
94-
95-
bool __bool__()
96-
{
97-
return _tensor->__bool__();
98-
}
99-
100-
double __float__()
101-
{
102-
return _tensor->__float__();
103-
}
104-
105-
int64_t __int__()
106-
{
107-
return _tensor->__int__();
108-
}
109-
110-
uint64_t __len__()
111-
{
112-
auto shp = _tensor->shape();
113-
return shp.empty() ? 1 : shp[0];
114-
}
115-
116-
ptr_type _tensor;
117-
};
118-
119-
dtensor create(const shape_type & shape, const char * op, const char * mod, py::args args, const py::kwargs& kwargs)
120-
{
121-
return dtensor(shape, op, mod, args, kwargs);
122-
}
123-
124-
dtensor ew_op(const dtensor & a, const char * op, const char * mod, py::args args, const py::kwargs& kwargs)
125-
{
126-
return dtensor(a._tensor->_ew_op(op, mod, args, kwargs));
127-
}
128-
129-
dtensor ew_unary_op(const dtensor & a, const char * op, bool is_method)
130-
{
131-
return dtensor(a._tensor->_ew_unary_op(op, is_method));
132-
}
133-
134-
dtensor ew_binary_op(const dtensor & a, const char * op, const py::object & b, bool is_method)
135-
{
136-
const dtensor * ob = nullptr;
137-
try {
138-
ob = b.cast<const dtensor*>();
139-
} catch(...) {
140-
return dtensor(a._tensor->_ew_binary_op(op, b, is_method));
141-
}
142-
return dtensor(a._tensor->_ew_binary_op(op, ob->_tensor, is_method));
143-
}
144-
145-
dtensor & ew_binary_op_inplace(dtensor & a, const char * op, const py::object & b)
146-
{
147-
const dtensor * ob = nullptr;
148-
try {
149-
ob = b.cast<const dtensor*>();
150-
} catch(...) {
151-
a._tensor->_ew_binary_op_inplace(op, b);
152-
}
153-
a._tensor->_ew_binary_op_inplace(op, ob->_tensor);
154-
return a;
155-
}
156-
157-
dtensor reduce_op(const dtensor & a, const char * op, const py::kwargs & kwargs)
158-
{
159-
dim_vec_type dims;
160-
if(kwargs.contains("axis")) {
161-
auto ax = kwargs["axis"];
162-
if(!ax.is_none()) {
163-
try {
164-
auto a = ax.cast<dim_vec_type::value_type>();
165-
dims.resize(1);
166-
dims[0] = a;
167-
} catch(...) {
168-
dims = ax.cast<dim_vec_type>();
169-
}
170-
}
171-
}
172-
return dtensor(a._tensor->_reduce_op(op, dims));
173-
}
174-
17521
// ###################################################################
17622
// ###################################################################
17723
// ###################################################################
@@ -182,10 +28,11 @@ auto TypeDispatch(DType dt, Ts&&... args)
18228
switch(dt) {
18329
case DT_FLOAT64:
18430
return OpDispatch<double>::op(std::forward<Ts>(args)...);
185-
case DT_FLOAT32:
186-
return OpDispatch<float>::op(std::forward<Ts>(args)...);
31+
#if 0
18732
case DT_INT64:
18833
return OpDispatch<int64_t>::op(std::forward<Ts>(args)...);
34+
case DT_FLOAT32:
35+
return OpDispatch<float>::op(std::forward<Ts>(args)...);
18936
case DT_INT32:
19037
return OpDispatch<int32_t>::op(std::forward<Ts>(args)...);
19138
case DT_INT16:
@@ -196,6 +43,7 @@ auto TypeDispatch(DType dt, Ts&&... args)
19643
return OpDispatch<uint32_t>::op(std::forward<Ts>(args)...);
19744
case DT_UINT16:
19845
return OpDispatch<uint16_t>::op(std::forward<Ts>(args)...);
46+
#endif
19947
/* FIXME
20048
case DT_BOOL:
20149
return OpDispatch<bool>::op(std::forward<Ts>(args)...);
@@ -252,6 +100,14 @@ struct ReduceOp
252100
}
253101
};
254102

103+
struct GetItem
104+
{
105+
static auto op(x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v)
106+
{
107+
return TypeDispatch<x::GetItem>(a->dtype(), a, NDSlice(v));
108+
}
109+
};
110+
255111
rank_type myrank()
256112
{
257113
return theTransceiver->rank();
@@ -274,6 +130,17 @@ PYBIND11_MODULE(_ddptensor, m) {
274130

275131
m.doc() = "A partitioned and distributed tensor";
276132

133+
def_enums(m);
134+
135+
py::enum_<DType>(m, "dtype")
136+
.value("float64", DT_FLOAT64)
137+
.value("int64", DT_INT64)
138+
.value("bool", DT_BOOL)
139+
.export_values();
140+
141+
m.def("fini", &fini);
142+
m.def("myrank", &myrank);
143+
277144
py::class_<Creator>(m, "Creator")
278145
.def("create_from_shape", &Creator::create_from_shape)
279146
.def("full", &Creator::full);
@@ -291,26 +158,10 @@ PYBIND11_MODULE(_ddptensor, m) {
291158
.def("op", &ReduceOp::op);
292159

293160
py::class_<x::DPTensorBaseX, x::DPTensorBaseX::ptr_type>(m, "DPTensorX")
294-
.def("__repr__", &x::DPTensorBaseX::__repr__);
295-
296-
def_enums(m);
297-
298-
py::enum_<DType>(m, "dtype")
299-
.value("float64", DT_FLOAT64)
300-
.value("int64", DT_INT64)
301-
.value("bool", DT_BOOL)
302-
.export_values();
303-
304-
m.def("fini", &fini);
305-
m.def("myrank", &myrank);
306-
307-
m.def("create", &create);
308-
m.def("ew_op", &ew_op);
309-
m.def("ew_unary_op", &ew_unary_op);
310-
m.def("ew_binary_op", &ew_binary_op);
311-
m.def("ew_binary_op_inplace", &ew_binary_op_inplace);
312-
m.def("reduce_op", &reduce_op);
161+
.def("__repr__", &x::DPTensorBaseX::__repr__)
162+
.def("__getitem__", &GetItem::op);
313163

164+
#if 0
314165
py::class_<dtensor>(m, "dtensor")
315166
.def(py::init<const shape_type &, DType>())
316167
.def_property_readonly("dtype", &dtensor::dtype)
@@ -326,7 +177,7 @@ PYBIND11_MODULE(_ddptensor, m) {
326177
.def("__getitem__", py::overload_cast<int64_t>(&dtensor::__getitem__))
327178
.def("__setitem__", &dtensor::__setitem__)
328179
.def("get_slice", &dtensor::get_slice);
329-
180+
#endif
330181
//py::class_<dpdlpack>(m, "dpdlpack")
331182
// .def("__dlpack__", &dpdlpack.__dlpack__);
332183
}

src/include/ddptensor/NDSlice.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ class NDSlice {
8989
return ret;
9090
}
9191

92+
std::vector<Slice> slices() const
93+
{
94+
return _slice_vec;
95+
}
96+
9297
///
9398
/// @return total number of elements represented by the nd-slice
9499
///
@@ -202,7 +207,7 @@ class NDSlice {
202207
NDSlice trim_shift(const NDSlice & t_slc, const NDSlice & s_slc) const
203208
{
204209
return _convert([&](uint64_t i) {
205-
return _slice_vec[i].trim(t_slc.dim(i)._start, t_slc.dim(i)._end, s_slc.dim(i)._start);
210+
return _slice_vec[i].trim(t_slc.dim(i)._start, t_slc.dim(i)._end).shift(s_slc.dim(i)._start);
206211
} );
207212
}
208213

src/include/ddptensor/PVSlice.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ class PVSlice
175175
return _slice.trim(_base->split_dim(), rank * _base->offset(), (rank+1) * _base->offset());
176176
}
177177

178-
NDSlice local_slice_of_rank(rank_type rank) const
178+
NDSlice local_slice_of_rank(rank_type rank = theTransceiver->rank()) const
179179
{
180180
if(_base->split_dim() == NOSPLIT) {
181181
return rank == theTransceiver->rank() ? slice() : NDSlice();

src/include/ddptensor/Slice.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ struct Slice
201201
/// @param s Start index to trim to
202202
/// @param e End index to trim to
203203
///
204-
Slice trim(value_type s, value_type e, value_type shift = 0) const
204+
Slice trim(value_type s, value_type e) const
205205
{
206206
assert(_step > 0);
207207
auto start = _start;
@@ -211,7 +211,7 @@ struct Slice
211211
else start = s;
212212
}
213213
auto end = std::min(e, _end);
214-
return {start-shift, end-shift, _step};
214+
return {std::min(start, end), end, _step};
215215
}
216216

217217
Slice map(const Slice & slc) const

0 commit comments

Comments
 (0)