Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 3146a1e

Browse files
committed
adding spmd.get_slice, more attributes, comments, cleanup
1 parent 8100d4f commit 3146a1e

12 files changed

Lines changed: 282 additions & 745 deletions

File tree

ddptensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .ddptensor import float64, int64, fini, dtensor
33
from os import getenv
44
from . import array_api as api
5+
from . import spmd
56

67
#__impl_str = getenv("DDPNP_ARRAY", 'numpy')
78
#exec(f"import {__impl_str} as __impl")

ddptensor/array_api.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,5 +152,20 @@
152152
"__dlpack_device__", # (self, /)
153153
"__float__", # (self, /)
154154
"__int__", # (self, /)
155+
"__index__",
155156
"__len__", # (self, /)
156157
]
158+
159+
misc_methods = [
160+
"__getitem__",
161+
"__setitem__",
162+
]
163+
164+
attributes = [
165+
"dtype",
166+
"shape",
167+
"device",
168+
"ndim",
169+
"size",
170+
"T"
171+
]

ddptensor/ddptensor.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
from ._ddptensor import float64, int64, fini
33
from . import array_api as api
44

5-
t_attributes = ["dtype", "shape", ] #"device", "ndim", "size", "T"]
6-
75
#def try_except(func, *args, **kwargs):
86
# try:
97
# return func(*args, **kwargs)
@@ -37,19 +35,16 @@ def __repr__(self):
3735

3836
for method in api.unary_methods:
3937
exec(
40-
f"{method} = lambda self: _cdt.UnyOp.{method}(self._t)"
38+
f"{method} = lambda self: self._t.{method}()"
4139
)
4240

43-
# for att in t_attributes:
44-
# exec(
45-
# f"{att} = property(lambda self: self._t.{att})"
46-
# )
41+
for att in api.attributes:
42+
exec(
43+
f"{att} = property(lambda self: self._t.{att})"
44+
)
4745

4846
def __getitem__(self, *args):
4947
return dtensor(self._t.__getitem__(*args))
5048

5149
def __setitem__(self, key, value):
5250
x = self._t.__setitem__(key, value._t) # if isinstance(value, dtensor) else value)
53-
54-
# def get_slice(self, *args):
55-
# return self._t.get_slice(*args)

ddptensor/spmd.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from . import _ddptensor as _cdt
2+
3+
def get_slice(self, *args):
4+
return _cdt._get_slice(self._t, *args)

scripts/code_gen.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import array_api as api
22

3-
print("""// SPDX-License-Identifier: BSD-3-Clause
3+
print("""// Auto-generated file
4+
// #######################################################
5+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
6+
// !! DO NOT EDIT, USE ../scripts/code_gen.py TO UPDATE !!
7+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
8+
// #######################################################
9+
// SPDX-License-Identifier: BSD-3-Clause
410
#pragma once
511
#include <pybind11/pybind11.h>
612
#include <pybind11/stl.h>

src/ddptensor.cpp

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@
22

33
/*
44
A Distributed Data-Parallel Tensor for Python, following the array API.
5-
We have a 3-level hierachy
6-
1. tensor_i: the abstract interface, not bound to types (like numpy)
7-
2. dtensor_impl: a typed template layer with the actual functionality
8-
3. dtensor: the PYthon API delegating to untyped tensor_i
9-
We use pybind11.
5+
6+
XTensor handles the actual functionality on each process.
7+
pybind11 handles the bridge to Python.
8+
9+
We bridge dynamic dtypes of the Python array through dynamic type dispatch (TypeDispatch).
10+
This means the compiler will instantiate the full functionality for all elements types.
11+
Within kernels we dispatch the operation type by enum values (see x.hpp).
12+
tensor_i is an abstract class to hide the element type which of the actual tensor.
13+
The concrete tensor implementation (DPTensorX, x.hpp) requires the element type
14+
as a template parameter.
1015
*/
1116

1217
#include <pybind11/pybind11.h>
@@ -18,17 +23,17 @@ using namespace pybind11::literals; // to bring _a
1823
#include "ddptensor/MPIMediator.hpp"
1924
#include "ddptensor/x.hpp"
2025

21-
// ###################################################################
22-
// ###################################################################
23-
// ###################################################################
24-
26+
// Dependent on dt, dispatch arguments to a operation class.
27+
// The operation must
28+
// * be a template class accepting the element type as argument
29+
// * implement one or more "op" methods matching the given arguments (args)
30+
// All arguments other than dt are opaquely passed to the operation.
2531
template<template<typename OD> class OpDispatch, typename... Ts>
2632
auto TypeDispatch(DType dt, Ts&&... args)
2733
{
2834
switch(dt) {
2935
case DT_FLOAT64:
3036
return OpDispatch<double>::op(std::forward<Ts>(args)...);
31-
#if 0
3237
case DT_INT64:
3338
return OpDispatch<int64_t>::op(std::forward<Ts>(args)...);
3439
case DT_FLOAT32:
@@ -43,7 +48,6 @@ auto TypeDispatch(DType dt, Ts&&... args)
4348
return OpDispatch<uint32_t>::op(std::forward<Ts>(args)...);
4449
case DT_UINT16:
4550
return OpDispatch<uint16_t>::op(std::forward<Ts>(args)...);
46-
#endif
4751
/* FIXME
4852
case DT_BOOL:
4953
return OpDispatch<bool>::op(std::forward<Ts>(args)...);
@@ -53,6 +57,9 @@ auto TypeDispatch(DType dt, Ts&&... args)
5357
}
5458
}
5559

60+
// #########################################################################
61+
// The following classes are wrappers bridging pybind11 defs to TypeDispatch
62+
5663
struct Creator
5764
{
5865

@@ -67,7 +74,7 @@ struct Creator
6774
return TypeDispatch<x::Creator>(dtype, op, std::forward<shape_type>(shape), std::forward<py::object>(val));
6875
}
6976
};
70-
#if 0
77+
7178
struct IEWBinOp
7279
{
7380
static auto op(IEWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b)
@@ -92,43 +99,29 @@ struct EWUnyOp
9299
}
93100
};
94101

95-
struct UnyOp
96-
{
97-
static bool __bool__(x::DPTensorBaseX::ptr_type a)
98-
{
99-
return TypeDispatch<x::UnyOp>(a->dtype(), a, true);
100-
}
101-
102-
static double __float__(x::DPTensorBaseX::ptr_type a)
103-
{
104-
return TypeDispatch<x::UnyOp>(a->dtype(), a, double(1));
105-
}
106-
107-
static int64_t __int__(x::DPTensorBaseX::ptr_type a)
108-
{
109-
return TypeDispatch<x::UnyOp>(a->dtype(), a, int64_t(1));
110-
}
111-
};
112-
113102
struct ReduceOp
114103
{
115104
static auto op(ReduceOpId op, x::DPTensorBaseX::ptr_type a, const dim_vec_type & dim)
116105
{
117106
return TypeDispatch<x::ReduceOp>(a->dtype(), op, a, dim);
118107
}
119108
};
120-
#endif
121109

122110
struct GetItem
123111
{
124-
static auto op(x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v)
112+
static auto __getitem__(x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v)
125113
{
126114
return TypeDispatch<x::GetItem>(a->dtype(), a, NDSlice(v));
127115
}
116+
static auto get_slice(x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v)
117+
{
118+
return TypeDispatch<x::SPMD>(a->dtype(), a, NDSlice(v));
119+
}
128120
};
121+
129122
struct SetItem
130123
{
131-
static auto op(x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v, x::DPTensorBaseX::ptr_type b)
124+
static auto __setitem__(x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v, x::DPTensorBaseX::ptr_type b)
132125
{
133126
return TypeDispatch<x::SetItem>(a->dtype(), a, NDSlice(v), b);
134127
}
@@ -142,6 +135,7 @@ rank_type myrank()
142135
Transceiver * theTransceiver = nullptr;
143136
Mediator * theMediator = nullptr;
144137

138+
// users currently need to call fini to make MPI terminate gracefully
145139
void fini()
146140
{
147141
delete theMediator;
@@ -150,6 +144,8 @@ void fini()
150144
theTransceiver = nullptr;
151145
}
152146

147+
// #########################################################################
148+
// Finally our Python module
153149
PYBIND11_MODULE(_ddptensor, m) {
154150
theTransceiver = new MPITransceiver();
155151
theMediator = new MPIMediator();
@@ -164,21 +160,17 @@ PYBIND11_MODULE(_ddptensor, m) {
164160
.value("bool", DT_BOOL)
165161
.export_values();
166162

167-
m.def("fini", &fini);
168-
m.def("myrank", &myrank);
163+
m.def("fini", &fini)
164+
.def("myrank", &myrank)
165+
.def("_get_slice", &GetItem::get_slice);
169166

170167
py::class_<Creator>(m, "Creator")
171168
.def("create_from_shape", &Creator::create_from_shape)
172169
.def("full", &Creator::full);
173-
#if 0
170+
174171
py::class_<EWUnyOp>(m, "EWUnyOp")
175172
.def("op", &EWUnyOp::op);
176173

177-
py::class_<UnyOp>(m, "UnyOp")
178-
.def("__bool__", &UnyOp::__bool__)
179-
.def("__float__", &UnyOp::__float__)
180-
.def("__int__", &UnyOp::__int__);
181-
182174
py::class_<IEWBinOp>(m, "IEWBinOp")
183175
.def("op", &IEWBinOp::op);
184176

@@ -187,12 +179,20 @@ PYBIND11_MODULE(_ddptensor, m) {
187179

188180
py::class_<ReduceOp>(m, "ReduceOp")
189181
.def("op", &ReduceOp::op);
190-
#endif
191182

192183
py::class_<x::DPTensorBaseX, x::DPTensorBaseX::ptr_type>(m, "DPTensorX")
184+
.def_property_readonly("dtype", &x::DPTensorBaseX::dtype)
185+
.def_property_readonly("shape", &x::DPTensorBaseX::shape)
186+
.def_property_readonly("size", &x::DPTensorBaseX::size)
187+
.def_property_readonly("ndim", &x::DPTensorBaseX::ndim)
188+
.def("__bool__", &x::DPTensorBaseX::__bool__)
189+
.def("__float__", &x::DPTensorBaseX::__float__)
190+
.def("__int__", &x::DPTensorBaseX::__int__)
191+
.def("__index__", &x::DPTensorBaseX::__int__)
192+
.def("__len__", &x::DPTensorBaseX::__len__)
193193
.def("__repr__", &x::DPTensorBaseX::__repr__)
194-
.def("__getitem__", &GetItem::op)
195-
.def("__setitem__", &SetItem::op);
194+
.def("__getitem__", &GetItem::__getitem__)
195+
.def("__setitem__", &SetItem::__setitem__);
196196

197197
#if 0
198198
py::class_<dtensor>(m, "dtensor")

0 commit comments

Comments
 (0)