1414namespace py = pybind11;
1515using namespace pybind11 ::literals; // to bring _a
1616
17- #include " ddptensor/ddptensor_impl.hpp"
1817#include " ddptensor/MPITransceiver.hpp"
1918#include " ddptensor/MPIMediator.hpp"
2019#include " ddptensor/x.hpp"
2120
22- // / Thensor which is closely following the Python API
23- class dtensor
24- {
25- public:
26- typedef tensor_i::ptr_type ptr_type;
27-
28- dtensor (dtensor &&) = default ;
29-
30- dtensor (const shape_type & shape, DType dt)
31- : _tensor(create_dtensor(PVSlice(shape), shape, dt))
32- {
33- }
34-
35- dtensor (const shape_type & shape, const char * create, const char * mod, py::args args, const py::kwargs & kwargs)
36- : _tensor(create_dtensor(PVSlice(shape), shape, create, mod, args, kwargs))
37- {
38- }
39-
40- dtensor (ptr_type && t)
41- : _tensor(std::move(t))
42- {
43- }
44-
45- shape_type shape () const
46- {
47- return _tensor->shape ();
48- }
49-
50- DType dtype () const
51- {
52- return _tensor->dtype ();
53- }
54-
55- py::object get_slice (const std::vector<py::slice> & v)
56- {
57- return _tensor->get_slice (NDSlice (v));
58- }
59-
60- dtensor __getitem__ (const NDIndex & v)
61- {
62- return dtensor (_tensor->__getitem__ (NDSlice (v)));
63- }
64-
65- dtensor __getitem__ (int64_t i)
66- {
67- return __getitem__ (NDIndex (1 ,i));
68- }
69-
70- dtensor __getitem__ (const std::vector<py::slice> & v)
71- {
72- return dtensor (_tensor->__getitem__ (NDSlice (v)));
73- }
74-
75- dtensor __getitem__ (const py::slice & s)
76- {
77- return dtensor (_tensor->__getitem__ (NDSlice (std::vector<py::slice>(1 , s))));
78- }
79-
80- void __setitem__ (const std::vector<py::slice> & v, const dtensor * ob)
81- {
82- // const dtensor * ob = b.cast<const dtensor*>();
83- _tensor->__setitem__ (NDSlice (v), ob->_tensor );
84- }
85-
86- std::string __repr__ () const
87- {
88- return _tensor->__repr__ ();
89- }
90-
91- // "__array_namespace__", # (self, /, *, api_version=None)
92- // "__dlpack__", # (self, /, *, stream=None)
93- // "__dlpack_device__", # (self, /)
94-
95- bool __bool__ ()
96- {
97- return _tensor->__bool__ ();
98- }
99-
100- double __float__ ()
101- {
102- return _tensor->__float__ ();
103- }
104-
105- int64_t __int__ ()
106- {
107- return _tensor->__int__ ();
108- }
109-
110- uint64_t __len__ ()
111- {
112- auto shp = _tensor->shape ();
113- return shp.empty () ? 1 : shp[0 ];
114- }
115-
116- ptr_type _tensor;
117- };
118-
119- dtensor create (const shape_type & shape, const char * op, const char * mod, py::args args, const py::kwargs& kwargs)
120- {
121- return dtensor (shape, op, mod, args, kwargs);
122- }
123-
124- dtensor ew_op (const dtensor & a, const char * op, const char * mod, py::args args, const py::kwargs& kwargs)
125- {
126- return dtensor (a._tensor ->_ew_op (op, mod, args, kwargs));
127- }
128-
129- dtensor ew_unary_op (const dtensor & a, const char * op, bool is_method)
130- {
131- return dtensor (a._tensor ->_ew_unary_op (op, is_method));
132- }
133-
134- dtensor ew_binary_op (const dtensor & a, const char * op, const py::object & b, bool is_method)
135- {
136- const dtensor * ob = nullptr ;
137- try {
138- ob = b.cast <const dtensor*>();
139- } catch (...) {
140- return dtensor (a._tensor ->_ew_binary_op (op, b, is_method));
141- }
142- return dtensor (a._tensor ->_ew_binary_op (op, ob->_tensor , is_method));
143- }
144-
145- dtensor & ew_binary_op_inplace (dtensor & a, const char * op, const py::object & b)
146- {
147- const dtensor * ob = nullptr ;
148- try {
149- ob = b.cast <const dtensor*>();
150- } catch (...) {
151- a._tensor ->_ew_binary_op_inplace (op, b);
152- }
153- a._tensor ->_ew_binary_op_inplace (op, ob->_tensor );
154- return a;
155- }
156-
157- dtensor reduce_op (const dtensor & a, const char * op, const py::kwargs & kwargs)
158- {
159- dim_vec_type dims;
160- if (kwargs.contains (" axis" )) {
161- auto ax = kwargs[" axis" ];
162- if (!ax.is_none ()) {
163- try {
164- auto a = ax.cast <dim_vec_type::value_type>();
165- dims.resize (1 );
166- dims[0 ] = a;
167- } catch (...) {
168- dims = ax.cast <dim_vec_type>();
169- }
170- }
171- }
172- return dtensor (a._tensor ->_reduce_op (op, dims));
173- }
174-
17521// ###################################################################
17622// ###################################################################
17723// ###################################################################
@@ -182,10 +28,11 @@ auto TypeDispatch(DType dt, Ts&&... args)
18228 switch (dt) {
18329 case DT_FLOAT64 :
18430 return OpDispatch<double >::op (std::forward<Ts>(args)...);
185- case DT_FLOAT32 :
186- return OpDispatch<float >::op (std::forward<Ts>(args)...);
31+ #if 0
18732 case DT_INT64:
18833 return OpDispatch<int64_t>::op(std::forward<Ts>(args)...);
34+ case DT_FLOAT32:
35+ return OpDispatch<float>::op(std::forward<Ts>(args)...);
18936 case DT_INT32:
19037 return OpDispatch<int32_t>::op(std::forward<Ts>(args)...);
19138 case DT_INT16:
@@ -196,6 +43,7 @@ auto TypeDispatch(DType dt, Ts&&... args)
19643 return OpDispatch<uint32_t>::op(std::forward<Ts>(args)...);
19744 case DT_UINT16:
19845 return OpDispatch<uint16_t>::op(std::forward<Ts>(args)...);
46+ #endif
19947 /* FIXME
20048 case DT_BOOL:
20149 return OpDispatch<bool>::op(std::forward<Ts>(args)...);
@@ -252,6 +100,14 @@ struct ReduceOp
252100 }
253101};
254102
103+ struct GetItem
104+ {
105+ static auto op (x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v)
106+ {
107+ return TypeDispatch<x::GetItem>(a->dtype (), a, NDSlice (v));
108+ }
109+ };
110+
255111rank_type myrank ()
256112{
257113 return theTransceiver->rank ();
@@ -274,6 +130,17 @@ PYBIND11_MODULE(_ddptensor, m) {
274130
275131 m.doc () = " A partitioned and distributed tensor" ;
276132
133+ def_enums (m);
134+
135+ py::enum_<DType>(m, " dtype" )
136+ .value (" float64" , DT_FLOAT64 )
137+ .value (" int64" , DT_INT64 )
138+ .value (" bool" , DT_BOOL )
139+ .export_values ();
140+
141+ m.def (" fini" , &fini);
142+ m.def (" myrank" , &myrank);
143+
277144 py::class_<Creator>(m, " Creator" )
278145 .def (" create_from_shape" , &Creator::create_from_shape)
279146 .def (" full" , &Creator::full);
@@ -291,26 +158,10 @@ PYBIND11_MODULE(_ddptensor, m) {
291158 .def (" op" , &ReduceOp::op);
292159
293160 py::class_<x::DPTensorBaseX, x::DPTensorBaseX::ptr_type>(m, " DPTensorX" )
294- .def (" __repr__" , &x::DPTensorBaseX::__repr__);
295-
296- def_enums (m);
297-
298- py::enum_<DType>(m, " dtype" )
299- .value (" float64" , DT_FLOAT64 )
300- .value (" int64" , DT_INT64 )
301- .value (" bool" , DT_BOOL )
302- .export_values ();
303-
304- m.def (" fini" , &fini);
305- m.def (" myrank" , &myrank);
306-
307- m.def (" create" , &create);
308- m.def (" ew_op" , &ew_op);
309- m.def (" ew_unary_op" , &ew_unary_op);
310- m.def (" ew_binary_op" , &ew_binary_op);
311- m.def (" ew_binary_op_inplace" , &ew_binary_op_inplace);
312- m.def (" reduce_op" , &reduce_op);
161+ .def (" __repr__" , &x::DPTensorBaseX::__repr__)
162+ .def (" __getitem__" , &GetItem::op);
313163
164+ #if 0
314165 py::class_<dtensor>(m, "dtensor")
315166 .def(py::init<const shape_type &, DType>())
316167 .def_property_readonly("dtype", &dtensor::dtype)
@@ -326,7 +177,7 @@ PYBIND11_MODULE(_ddptensor, m) {
326177 .def("__getitem__", py::overload_cast<int64_t>(&dtensor::__getitem__))
327178 .def("__setitem__", &dtensor::__setitem__)
328179 .def("get_slice", &dtensor::get_slice);
329-
180+ # endif
330181 // py::class_<dpdlpack>(m, "dpdlpack")
331182 // .def("__dlpack__", &dpdlpack.__dlpack__);
332183}
0 commit comments