@@ -47,26 +47,37 @@ namespace x
4747 template <typename T>
4848 class DPTensorX : public DPTensorBaseX
4949 {
50+ rank_type _owner;
5051 PVSlice _slice;
5152 xt::xstrided_slice_vector _lslice;
5253 std::shared_ptr<xt::xarray<T>> _xarray;
54+ T _replica = 0 ;
5355
5456 public:
5557 template <typename I>
56- DPTensorX (PVSlice && slc, I && ax)
57- : _slice(std::move(slc)),
58+ DPTensorX (PVSlice && slc, I && ax, rank_type owner=NOOWNER )
59+ : _owner(owner),
60+ _slice (std::move(slc)),
5861 _lslice(to_xt(_slice.local_slice_of_rank())),
5962 _xarray(std::make_shared<xt::xarray<T>>(std::forward<I>(ax)))
6063 {
6164 }
6265
6366 template <typename O>
64- DPTensorX (const DPTensorX<O> & org, const NDSlice & slc)
65- : _slice(org._slice, slc),
67+ DPTensorX (const DPTensorX<O> & org, const NDSlice & slc, rank_type owner=NOOWNER )
68+ : _owner(owner),
69+ _slice(org._slice, slc),
6670 _lslice(to_xt(_slice.local_slice_of_rank())),
6771 _xarray(org._xarray)
6872 {
69- std::cerr << " slice: " << _slice.slice () << " lslice: " << _slice.local_slice_of_rank () << std::endl;
73+ if (owner == NOOWNER && slice ().size () <= 1 ) {
74+ set_owner (org.slice ().owner (slc));
75+ } else if (owner == REPLICATED ) {
76+ _replica = *(xt::strided_view (xarray (), to_xt (slice ().slice ())).begin ());
77+ }
78+ std::cerr << " slice: " << _slice.slice () << " sz " << _slice.size ()
79+ << " lslice: " << _slice.local_slice_of_rank () << " owner: " << _owner
80+ << " val: " << _replica << std::endl;
7081 }
7182
7283 virtual std::string __repr__ () const
@@ -82,6 +93,11 @@ namespace x
8293 return DTYPE <T>::value;
8394 }
8495
96+ virtual shape_type shape () const
97+ {
98+ return _slice.shape ();
99+ }
100+
85101 xt::xarray<T> & xarray ()
86102 {
87103 return *_xarray.get ();
@@ -102,9 +118,41 @@ namespace x
102118 return _lslice;
103119 }
104120
105- virtual shape_type shape () const
121+ bool has_owner () const
106122 {
107- return _slice.shape ();
123+ return _owner < _OWNER_END;
124+ }
125+
126+ void set_owner (rank_type o)
127+ {
128+ _owner = o;
129+ }
130+
131+ rank_type owner () const
132+ {
133+ return _owner;
134+ }
135+
136+ bool is_replicated () const
137+ {
138+ return _owner == REPLICATED ;
139+ }
140+
141+ T replicate ()
142+ {
143+ std::cerr << " is_replicated()=" << is_replicated () << " owner=" << owner () << " shape=" << to_string (shape ()) << std::endl;
144+ if (is_replicated ()) return _replica;
145+ if (has_owner () && _slice.size () == 1 ) {
146+ if (theTransceiver->rank () == owner ()) {
147+ _replica = *(xt::strided_view (xarray (), lslice ()).begin ());
148+ std::cerr << " replica: " << _replica << std::endl;
149+ }
150+ theTransceiver->bcast (&_replica, sizeof (T), owner ());
151+ set_owner (REPLICATED );
152+ } else {
153+ throw (std::runtime_error (" Replication implemented for single element and single owner only." ));
154+ }
155+ return _replica;
108156 }
109157 };
110158
@@ -460,6 +508,34 @@ namespace x
460508
461509 };
462510
511+ template <typename T>
512+ class UnyOp
513+ {
514+ public:
515+ using ptr_type = DPTensorBaseX::ptr_type;
516+
517+ template <typename N>
518+ static N __type__ (const ptr_type & a_ptr)
519+ {
520+ auto const _a = dynamic_cast <DPTensorX<T>*>(a_ptr.get ());
521+ if (!_a )
522+ throw std::runtime_error (" Invalid array object: could not dynamically cast" );
523+ T v = _a->replicate ();
524+ return static_cast <N>(v);
525+ }
526+ static bool op (const ptr_type & a_ptr, bool )
527+ {
528+ return __type__<bool >(a_ptr);
529+ }
530+ static double op (const ptr_type & a_ptr, double )
531+ {
532+ return __type__<double >(a_ptr);
533+ }
534+ static int64_t op (const ptr_type & a_ptr, int64_t )
535+ {
536+ return __type__<int64_t >(a_ptr);
537+ }
538+ };
463539
464540 template <typename T>
465541 class ReduceOp
@@ -474,11 +550,13 @@ namespace x
474550 {
475551 xt::xarray<typename X::value_type> a = x;
476552 auto new_shape = reduce_shape (slice.shape (), dims);
553+ rank_type owner = NOOWNER ;
477554 if (slice.need_reduce (dims)) {
478555 auto len = VPROD (new_shape);
479556 theTransceiver->reduce_all (a.data (), DTYPE <typename X::value_type>::value, len, rop);
557+ owner = REPLICATED ;
480558 }
481- return std::make_shared<DPTensorX<typename X::value_type>>(new_shape, a);
559+ return std::make_shared<DPTensorX<typename X::value_type>>(new_shape, a, owner );
482560 }
483561
484562 static ptr_type op (ReduceOpId rop, const ptr_type & a_ptr, const dim_vec_type & dims)
@@ -529,5 +607,5 @@ namespace x
529607 return std::make_shared<DPTensorX<T>>(*_a, slice);
530608 }
531609 };
532-
610+
533611} // namespace x
0 commit comments