@@ -429,5 +429,57 @@ namespace x
429429#pragma GCC diagnostic pop
430430
431431 };
432+
433+
434+ template <typename T>
435+ class ReduceOp
436+ {
437+ public:
438+ using ptr_type = DPTensorBaseX::ptr_type;
439+
440+ #pragma GCC diagnostic ignored "-Wswitch"
441+
442+ template <typename X>
443+ static ptr_type dist_reduce (ReduceOpId rop, const PVSlice & slice, const dim_vec_type & dims, X && x)
444+ {
445+ xt::xarray<typename X::value_type> a = x;
446+ auto new_shape = reduce_shape (slice.shape (), dims);
447+ if (slice.need_reduce (dims)) {
448+ auto len = VPROD (new_shape);
449+ theTransceiver->reduce_all (a.data (), DTYPE <typename X::value_type>::value, len, rop);
450+ }
451+ return std::make_shared<DPTensorX<typename X::value_type>>(new_shape, a);
452+ }
453+
454+ static ptr_type op (ReduceOpId rop, const ptr_type & a_ptr, const dim_vec_type & dims)
455+ {
456+ auto const _a = dynamic_cast <DPTensorX<T>*>(a_ptr.get ());
457+ if (!_a )
458+ throw std::runtime_error (" Invalid array object: could not dynamically cast" );
459+ auto const & a = _a->xarray ();
460+
461+ switch (rop) {
462+ case MEAN :
463+ return dist_reduce (rop, _a->slice (), dims, xt::mean (a, dims));
464+ case PROD :
465+ return dist_reduce (rop, _a->slice (), dims, xt::prod (a, dims));
466+ case SUM :
467+ return dist_reduce (rop, _a->slice (), dims, xt::sum (a, dims));
468+ case STD :
469+ return dist_reduce (rop, _a->slice (), dims, xt::stddev (a, dims));
470+ case VAR :
471+ return dist_reduce (rop, _a->slice (), dims, xt::variance (a, dims));
472+ case MAX :
473+ case MIN :
474+ throw std::runtime_error (" Reduction operation not implemented" );
475+ default :
476+ throw std::runtime_error (" Unknown reduction operation" );
477+ }
478+ }
479+
480+ #pragma GCC diagnostic pop
481+
482+ };
483+
432484
433485} // namespace x
0 commit comments