Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 8100d4f

Browse files
committed
adding __setitem__
1 parent f54b702 commit 8100d4f

8 files changed

Lines changed: 249 additions & 93 deletions

File tree

ddptensor/ddptensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def __repr__(self):
4848
def __getitem__(self, *args):
4949
return dtensor(self._t.__getitem__(*args))
5050

51-
# def __setitem__(self, key, value):
52-
# x = self._t.__setitem__(key, value._t if isinstance(value, dtensor) else value)
51+
def __setitem__(self, key, value):
52+
x = self._t.__setitem__(key, value._t) # if isinstance(value, dtensor) else value)
5353

5454
# def get_slice(self, *args):
5555
# return self._t.get_slice(*args)

src/MPIMediator.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,20 +53,20 @@ uint64_t MPIMediator::register_array(tensor_i::ptr_type ary)
5353
return s_last_id;
5454
}
5555

56-
void MPIMediator::pull(rank_type from, const tensor_i * ary, const NDSlice & slice, void * rbuff)
56+
void MPIMediator::pull(rank_type from, const tensor_i & ary, const NDSlice & slice, void * rbuff)
5757
{
5858
MPI_Comm comm = MPI_COMM_WORLD;
5959
MPI_Request request[2];
6060
MPI_Status status[2];
6161
Buffer buff;
6262

6363
bitsery::Serializer<OutputAdapter> ser{buff};
64-
uint64_t id = ary->id();
64+
uint64_t id = ary.id();
6565
ser.value8b(id);
6666
ser.object(slice);
6767
ser.adapter().flush();
6868

69-
auto sz = slice.size() * ary->item_size();
69+
auto sz = slice.size() * ary.item_size();
7070
std::cerr << "alsdkjf " << sz << " " << buff.size() << " " << rbuff << std::endl;
7171
MPI_Irecv(rbuff, sz, MPI_CHAR, from, PUSH_TAG, comm, &request[1]);
7272
MPI_Isend(buff.data(), buff.size(), MPI_CHAR, from, PULL_TAG, comm, &request[0]);

src/ddptensor.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ struct Creator
6767
return TypeDispatch<x::Creator>(dtype, op, std::forward<shape_type>(shape), std::forward<py::object>(val));
6868
}
6969
};
70-
70+
#if 0
7171
struct IEWBinOp
7272
{
7373
static auto op(IEWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b)
@@ -117,6 +117,7 @@ struct ReduceOp
117117
return TypeDispatch<x::ReduceOp>(a->dtype(), op, a, dim);
118118
}
119119
};
120+
#endif
120121

121122
struct GetItem
122123
{
@@ -125,6 +126,13 @@ struct GetItem
125126
return TypeDispatch<x::GetItem>(a->dtype(), a, NDSlice(v));
126127
}
127128
};
129+
struct SetItem
130+
{
131+
static auto op(x::DPTensorBaseX::ptr_type a, const std::vector<py::slice> & v, x::DPTensorBaseX::ptr_type b)
132+
{
133+
return TypeDispatch<x::SetItem>(a->dtype(), a, NDSlice(v), b);
134+
}
135+
};
128136

129137
rank_type myrank()
130138
{
@@ -162,7 +170,7 @@ PYBIND11_MODULE(_ddptensor, m) {
162170
py::class_<Creator>(m, "Creator")
163171
.def("create_from_shape", &Creator::create_from_shape)
164172
.def("full", &Creator::full);
165-
173+
#if 0
166174
py::class_<EWUnyOp>(m, "EWUnyOp")
167175
.def("op", &EWUnyOp::op);
168176

@@ -179,10 +187,12 @@ PYBIND11_MODULE(_ddptensor, m) {
179187

180188
py::class_<ReduceOp>(m, "ReduceOp")
181189
.def("op", &ReduceOp::op);
190+
#endif
182191

183192
py::class_<x::DPTensorBaseX, x::DPTensorBaseX::ptr_type>(m, "DPTensorX")
184193
.def("__repr__", &x::DPTensorBaseX::__repr__)
185-
.def("__getitem__", &GetItem::op);
194+
.def("__getitem__", &GetItem::op)
195+
.def("__setitem__", &SetItem::op);
186196

187197
#if 0
188198
py::class_<dtensor>(m, "dtensor")

src/include/ddptensor/MPIMediator.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class MPIMediator : public Mediator
1313
MPIMediator();
1414
virtual ~MPIMediator();
1515
virtual uint64_t register_array(tensor_i::ptr_type ary);
16-
virtual void pull(rank_type from, const tensor_i * ary, const NDSlice & slice, void * buffer);
16+
virtual void pull(rank_type from, const tensor_i & ary, const NDSlice & slice, void * buffer);
1717

1818
protected:
1919
void listen();

src/include/ddptensor/Mediator.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#pragma once
44

55
#include <vector>
6-
#include "UtilsAndTypes.hpp"
76
#include "tensor_i.hpp"
87

98
class NDSlice;
@@ -13,7 +12,7 @@ class Mediator
1312
public:
1413
virtual ~Mediator() {}
1514
virtual uint64_t register_array(tensor_i::ptr_type ary) = 0;
16-
virtual void pull(rank_type from, const tensor_i * ary, const NDSlice & slice, void * buffer) = 0;
15+
virtual void pull(rank_type from, const tensor_i & ary, const NDSlice & slice, void * buffer) = 0;
1716
};
1817

1918
extern Mediator * theMediator;

src/include/ddptensor/PVSlice.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ class PVSlice
172172
return _slice.map(slc);
173173
}
174174

175-
NDSlice slice_of_rank(rank_type rank) const
175+
NDSlice slice_of_rank(rank_type rank = theTransceiver->rank()) const
176176
{
177177
if(_base->split_dim() == NOSPLIT) {
178178
return rank == theTransceiver->rank() ? slice() : NDSlice();
@@ -279,4 +279,11 @@ class PVSlice
279279
{
280280
return iterator();
281281
}
282+
283+
friend std::ostream &operator<<(std::ostream &output, const PVSlice & slc) {
284+
output << "{slice=" << slc.slice()
285+
<< "base=" << to_string(slc._base->shape())
286+
<< "offset=" << slc._base->offset()<< "}";
287+
return output;
288+
}
282289
};

src/include/ddptensor/tensor_i.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@ class tensor_i
2121
public:
2222
typedef std::shared_ptr<tensor_i> ptr_type;
2323

24+
virtual ~tensor_i() {};
25+
virtual std::string __repr__() const = 0;
26+
virtual DType dtype() const = 0;
27+
virtual shape_type shape() const = 0;
28+
virtual void bufferize(const NDSlice & slice, Buffer & buff) = 0;
29+
virtual int item_size() const = 0;
30+
virtual uint64_t id() const = 0;
31+
};
32+
#if 0
33+
2434
virtual ~tensor_i(){}
2535

2636
virtual const PVSlice & pvslice() = 0;
@@ -47,3 +57,4 @@ class tensor_i
4757

4858
virtual py::object get_slice(const NDSlice & slice) const = 0;
4959
};
60+
#endif

0 commit comments

Comments
 (0)