22
33/*
44 A Distributed Data-Parallel Tensor for Python, following the array API.
5- We have a 3-level hierachy
6- 1. tensor_i: the abstract interface, not bound to types (like numpy)
7- 2. dtensor_impl: a typed template layer with the actual functionality
8- 3. dtensor: the PYthon API delegating to untyped tensor_i
9- We use pybind11.
5+
6+ XTensor handles the actual functionality on each process.
7+ pybind11 handles the bridge to Python.
8+
9+ We bridge dynamic dtypes of the Python array through dynamic type dispatch (TypeDispatch).
10+ This means the compiler will instantiate the full functionality for all elements types.
11+ Within kernels we dispatch the operation type by enum values (see x.hpp).
12+ tensor_i is an abstract class to hide the element type which of the actual tensor.
13+ The concrete tensor implementation (DPTensorX, x.hpp) requires the element type
14+ as a template parameter.
1015 */
1116
1217#include < pybind11/pybind11.h>
@@ -18,17 +23,17 @@ using namespace pybind11::literals; // to bring _a
1823#include " ddptensor/MPIMediator.hpp"
1924#include " ddptensor/x.hpp"
2025
21- // ###################################################################
22- // ###################################################################
23- // ###################################################################
24-
26+ // Dependent on dt, dispatch arguments to a operation class.
27+ // The operation must
28+ // * be a template class accepting the element type as argument
29+ // * implement one or more "op" methods matching the given arguments (args)
30+ // All arguments other than dt are opaquely passed to the operation.
2531template <template <typename OD > class OpDispatch , typename ... Ts>
2632auto TypeDispatch (DType dt, Ts&&... args)
2733{
2834 switch (dt) {
2935 case DT_FLOAT64 :
3036 return OpDispatch<double >::op (std::forward<Ts>(args)...);
31- #if 0
3237 case DT_INT64 :
3338 return OpDispatch<int64_t >::op (std::forward<Ts>(args)...);
3439 case DT_FLOAT32 :
@@ -43,7 +48,6 @@ auto TypeDispatch(DType dt, Ts&&... args)
4348 return OpDispatch<uint32_t >::op (std::forward<Ts>(args)...);
4449 case DT_UINT16 :
4550 return OpDispatch<uint16_t >::op (std::forward<Ts>(args)...);
46- #endif
4751 /* FIXME
4852 case DT_BOOL:
4953 return OpDispatch<bool>::op(std::forward<Ts>(args)...);
@@ -53,6 +57,9 @@ auto TypeDispatch(DType dt, Ts&&... args)
5357 }
5458}
5559
60+ // #########################################################################
61+ // The following classes are wrappers bridging pybind11 defs to TypeDispatch
62+
5663struct Creator
5764{
5865
@@ -67,7 +74,7 @@ struct Creator
6774 return TypeDispatch<x::Creator>(dtype, op, std::forward<shape_type>(shape), std::forward<py::object>(val));
6875 }
6976};
70- # if 0
77+
7178struct IEWBinOp
7279{
7380 static auto op (IEWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b)
@@ -92,43 +99,29 @@ struct EWUnyOp
9299 }
93100};
94101
95- struct UnyOp
96- {
97- static bool __bool__(x::DPTensorBaseX::ptr_type a)
98- {
99- return TypeDispatch<x::UnyOp>(a->dtype(), a, true);
100- }
101-
102- static double __float__(x::DPTensorBaseX::ptr_type a)
103- {
104- return TypeDispatch<x::UnyOp>(a->dtype(), a, double(1));
105- }
106-
107- static int64_t __int__(x::DPTensorBaseX::ptr_type a)
108- {
109- return TypeDispatch<x::UnyOp>(a->dtype(), a, int64_t(1));
110- }
111- };
112-
113102struct ReduceOp
114103{
115104 static auto op (ReduceOpId op, x::DPTensorBaseX::ptr_type a, const dim_vec_type & dim)
116105 {
117106 return TypeDispatch<x::ReduceOp>(a->dtype (), op, a, dim);
118107 }
119108};
120- #endif
121109
122110struct GetItem
123111{
124- static auto op (x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v)
112+ static auto __getitem__ (x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v)
125113 {
126114 return TypeDispatch<x::GetItem>(a->dtype (), a, NDSlice (v));
127115 }
116+ static auto get_slice (x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v)
117+ {
118+ return TypeDispatch<x::SPMD >(a->dtype (), a, NDSlice (v));
119+ }
128120};
121+
129122struct SetItem
130123{
131- static auto op (x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v, x::DPTensorBaseX::ptr_type b)
124+ static auto __setitem__ (x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v, x::DPTensorBaseX::ptr_type b)
132125 {
133126 return TypeDispatch<x::SetItem>(a->dtype (), a, NDSlice (v), b);
134127 }
@@ -142,6 +135,7 @@ rank_type myrank()
142135Transceiver * theTransceiver = nullptr ;
143136Mediator * theMediator = nullptr ;
144137
138+ // users currently need to call fini to make MPI terminate gracefully
145139void fini ()
146140{
147141 delete theMediator;
@@ -150,6 +144,8 @@ void fini()
150144 theTransceiver = nullptr ;
151145}
152146
147+ // #########################################################################
148+ // Finally our Python module
153149PYBIND11_MODULE (_ddptensor, m) {
154150 theTransceiver = new MPITransceiver ();
155151 theMediator = new MPIMediator ();
@@ -164,21 +160,17 @@ PYBIND11_MODULE(_ddptensor, m) {
164160 .value (" bool" , DT_BOOL )
165161 .export_values ();
166162
167- m.def (" fini" , &fini);
168- m.def (" myrank" , &myrank);
163+ m.def (" fini" , &fini)
164+ .def (" myrank" , &myrank)
165+ .def (" _get_slice" , &GetItem::get_slice);
169166
170167 py::class_<Creator>(m, " Creator" )
171168 .def (" create_from_shape" , &Creator::create_from_shape)
172169 .def (" full" , &Creator::full);
173- # if 0
170+
174171 py::class_<EWUnyOp>(m, " EWUnyOp" )
175172 .def (" op" , &EWUnyOp::op);
176173
177- py::class_<UnyOp>(m, "UnyOp")
178- .def("__bool__", &UnyOp::__bool__)
179- .def("__float__", &UnyOp::__float__)
180- .def("__int__", &UnyOp::__int__);
181-
182174 py::class_<IEWBinOp>(m, " IEWBinOp" )
183175 .def (" op" , &IEWBinOp::op);
184176
@@ -187,12 +179,20 @@ PYBIND11_MODULE(_ddptensor, m) {
187179
188180 py::class_<ReduceOp>(m, " ReduceOp" )
189181 .def (" op" , &ReduceOp::op);
190- #endif
191182
192183 py::class_<x::DPTensorBaseX, x::DPTensorBaseX::ptr_type>(m, " DPTensorX" )
184+ .def_property_readonly (" dtype" , &x::DPTensorBaseX::dtype)
185+ .def_property_readonly (" shape" , &x::DPTensorBaseX::shape)
186+ .def_property_readonly (" size" , &x::DPTensorBaseX::size)
187+ .def_property_readonly (" ndim" , &x::DPTensorBaseX::ndim)
188+ .def (" __bool__" , &x::DPTensorBaseX::__bool__)
189+ .def (" __float__" , &x::DPTensorBaseX::__float__)
190+ .def (" __int__" , &x::DPTensorBaseX::__int__)
191+ .def (" __index__" , &x::DPTensorBaseX::__int__)
192+ .def (" __len__" , &x::DPTensorBaseX::__len__)
193193 .def (" __repr__" , &x::DPTensorBaseX::__repr__)
194- .def (" __getitem__" , &GetItem::op )
195- .def (" __setitem__" , &SetItem::op );
194+ .def (" __getitem__" , &GetItem::__getitem__ )
195+ .def (" __setitem__" , &SetItem::__setitem__ );
196196
197197#if 0
198198 py::class_<dtensor>(m, "dtensor")
0 commit comments