@@ -29,6 +29,8 @@ class DDPTensorImpl : public tensor_i
2929 PVSlice _slice;
3030 void * _allocated;
3131 void * _aligned;
32+ intptr_t * _sizes;
33+ intptr_t * _strides;
3234 uint64_t _offset;
3335 DTypeId _dtype;
3436
@@ -42,13 +44,17 @@ class DDPTensorImpl : public tensor_i
4244 : _owner(owner),
4345 _slice (shape_type(rank ? rank : 1 , rank ? sizes[0 ] : 1 ), static_cast<int>(owner==REPLICATED ? NOSPLIT : 0 )),
4446 _allocated(allocated),
45- _aligned(nullptr ),
47+ _aligned(aligned),
48+ _sizes(new intptr_t [rank]),
49+ _strides(new intptr_t [rank]),
4650 _offset(offset),
4751 _dtype(dtype)
4852 {
4953 assert (rank <= 1 );
5054 assert (rank == 0 || strides[0 ] == 1 );
51- dispatch (_dtype, aligned, [this ](auto * ptr) { this ->_aligned = ptr + this ->_offset ; });
55+
56+ memcpy (_sizes, sizes, rank*sizeof (intptr_t ));
57+ memcpy (_strides, strides, rank*sizeof (intptr_t ));
5258 }
5359
5460 DDPTensorImpl (DTypeId dtype, const shape_type & shp, rank_type owner=NOOWNER )
@@ -60,23 +66,38 @@ class DDPTensorImpl : public tensor_i
6066 _dtype(dtype)
6167 {
6268 alloc ();
69+
70+ intptr_t stride = 1 ;
71+ auto rank = shp.size ();
72+ for (auto i=0 ; i<rank; ++i) {
73+ _sizes[i] = shp[i];
74+ _strides[rank-i-1 ] = stride;
75+ stride *= shp[i];
76+ }
6377 }
6478
6579 void alloc ()
6680 {
6781 auto esz = sizeof_dtype (_dtype);
6882 _allocated = new (std::align_val_t (esz)) char [esz*_slice.size ()];
6983 _aligned = _allocated;
84+ auto rank = _slice.ndims ();
85+ _sizes = new intptr_t [rank];
86+ _strides = new intptr_t [rank];
7087 _offset = 0 ;
7188 }
7289
7390 ~DDPTensorImpl ()
7491 {
92+ delete [] _sizes;
93+ delete [] _strides;
7594 }
7695
7796 void * data ()
7897 {
79- return _aligned;
98+ void * ret;
99+ dispatch (_dtype, _aligned, [this , &ret](auto * ptr) { ret = ptr + this ->_offset ; });
100+ return ret;
80101 }
81102
82103 bool is_sliced () const
@@ -90,7 +111,8 @@ class DDPTensorImpl : public tensor_i
90111 const auto sz = _slice.size ();
91112 std::ostringstream oss;
92113
93- dispatch (_dtype, _aligned, [sz, &oss](auto * ptr) {
114+ dispatch (_dtype, _aligned, [this , sz, &oss](auto * ptr) {
115+ ptr += this ->_offset ;
94116 for (auto i=0 ; i<sz; ++i) {
95117 oss << ptr[i] << " " ;
96118 }
@@ -127,7 +149,7 @@ class DDPTensorImpl : public tensor_i
127149 throw (std::runtime_error (" Cast to scalar bool: tensor is not replicated" ));
128150
129151 bool res;
130- dispatch (_dtype, _aligned, [&res](auto * ptr) { res = static_cast <bool >(* ptr); });
152+ dispatch (_dtype, _aligned, [this , &res](auto * ptr) { res = static_cast <bool >(ptr[ this -> _offset ] ); });
131153 return res;
132154 }
133155
@@ -137,7 +159,7 @@ class DDPTensorImpl : public tensor_i
137159 throw (std::runtime_error (" Cast to scalar float: tensor is not replicated" ));
138160
139161 double res;
140- dispatch (_dtype, _aligned, [&res](auto * ptr) { res = static_cast <double >(* ptr); });
162+ dispatch (_dtype, _aligned, [this , &res](auto * ptr) { res = static_cast <double >(ptr[ this -> _offset ] ); });
141163 return res;
142164 }
143165
@@ -147,7 +169,7 @@ class DDPTensorImpl : public tensor_i
147169 throw (std::runtime_error (" Cast to scalar int: tensor is not replicated" ));
148170
149171 float res;
150- dispatch (_dtype, _aligned, [&res](auto * ptr) { res = static_cast <float >(* ptr); });
172+ dispatch (_dtype, _aligned, [this , &res](auto * ptr) { res = static_cast <float >(ptr[ this -> _offset ] ); });
151173 return res;
152174 }
153175
@@ -198,7 +220,18 @@ class DDPTensorImpl : public tensor_i
198220 auto sz = _slice.size ()*item_size ();
199221 buff.resize (pos + sz);
200222 void * out = buff.data () + pos;
201- memcpy (out, _aligned, sz);
223+ dispatch (_dtype, _aligned, [this , sz, out](auto * ptr) { memcpy (out, ptr + this ->_offset , sz); });
224+ }
225+
226+ virtual uint64_t store_memref (intptr_t * buff, int rank)
227+ {
228+ assert (rank == _slice.ndims () || (_slice.ndims () == 1 && _slice.size () == 1 ));
229+ buff[0 ] = reinterpret_cast <intptr_t >(_allocated);
230+ buff[1 ] = reinterpret_cast <intptr_t >(_aligned);
231+ buff[2 ] = static_cast <intptr_t >(_offset);
232+ memcpy (buff+3 , _sizes, rank*sizeof (intptr_t ));
233+ memcpy (buff+3 +rank, _strides, rank*sizeof (intptr_t ));
234+ return 3 + 2 *rank;
202235 }
203236};
204237
0 commit comments