@@ -33,6 +33,7 @@ namespace x
3333 virtual ~DPTensorBaseX () {};
3434 virtual std::string __repr__ () const = 0;
3535 virtual DType dtype () const = 0;
36+ virtual shape_type shape () const = 0;
3637 };
3738
3839 template <typename T>
@@ -65,6 +66,16 @@ namespace x
6566 {
6667 return _xarray;
6768 }
69+
70+ const PVSlice & slice () const
71+ {
72+ return _slice;
73+ }
74+
75+ virtual shape_type shape () const
76+ {
77+ return _slice.shape ();
78+ }
6879 };
6980
7081 template <typename T>
@@ -147,7 +158,7 @@ namespace x
147158 static void op (IEWBinOpId iop, ptr_type a_ptr, const ptr_type & b_ptr)
148159 {
149160 auto _a = dynamic_cast <DPTensorX<T>*>(a_ptr.get ());
150- auto _b = dynamic_cast <DPTensorX<T>*>(b_ptr.get ());
161+ auto const _b = dynamic_cast <DPTensorX<T>*>(b_ptr.get ());
151162 if (!_a || !_b)
152163 throw std::runtime_error (" Invalid array object: could not dynamically cast" );
153164 auto & a = _a->xarray ();
@@ -178,6 +189,117 @@ namespace x
178189 integral_iop (iop, a, b);
179190 }
180191
192+ #pragma GCC diagnostic pop
193+
194+ };
195+
196+
197+ template <typename T>
198+ class EWBinOp
199+ {
200+ public:
201+ using ptr_type = DPTensorBaseX::ptr_type;
202+
203+ template <typename X>
204+ static ptr_type mk_tx (const DPTensorBaseX & tx, X && x)
205+ {
206+ return std::make_shared<DPTensorX<typename X::value_type>>(tx.shape (), x);
207+ }
208+
209+ #pragma GCC diagnostic ignored "-Wswitch"
210+
211+ template <typename A, typename B, typename U = T, std::enable_if_t <std::is_floating_point<U>::value, bool > = true >
212+ static ptr_type integral_op (EWBinOpId iop, const DPTensorX<T> & tx, A && a, B && b)
213+ {
214+ throw std::runtime_error (" Illegal or unknown inplace elementwise binary operation" );
215+ }
216+
217+ template <typename A, typename B, typename U = T, std::enable_if_t <std::is_integral<U>::value, bool > = true >
218+ static ptr_type integral_op (EWBinOpId iop, const DPTensorBaseX & tx, A && a, B && b)
219+ {
220+ switch (iop) {
221+ case AND :
222+ case RAND :
223+ return mk_tx (tx, a & b);
224+ case LSHIFT :
225+ return mk_tx (tx, a << b);
226+ case MOD :
227+ return mk_tx (tx, a % b);
228+ case OR :
229+ case ROR :
230+ return mk_tx (tx, a | b);
231+ case RSHIFT :
232+ return mk_tx (tx, a >> b);
233+ case XOR :
234+ case RXOR :
235+ return mk_tx (tx, a ^ b);
236+ case RLSHIFT :
237+ return mk_tx (tx, b << a);
238+ case RMOD :
239+ return mk_tx (tx, b % a);
240+ case RRSHIFT :
241+ return mk_tx (tx, b >> a);
242+ default :
243+ throw std::runtime_error (" Unknown elementwise binary operation" );
244+ }
245+ }
246+
247+ static ptr_type op (EWBinOpId bop, const ptr_type & a_ptr, const ptr_type & b_ptr)
248+ {
249+ auto _a = dynamic_cast <DPTensorX<T>*>(a_ptr.get ());
250+ auto const _b = dynamic_cast <DPTensorX<T>*>(b_ptr.get ());
251+ if (!_a || !_b)
252+ throw std::runtime_error (" Invalid array object: could not dynamically cast" );
253+ auto & a = _a->xarray ();
254+ auto const & b = _b->xarray ();
255+
256+ switch (bop) {
257+ case ADD :
258+ case RADD :
259+ return mk_tx (*_a, a + b);
260+ case EQ :
261+ return mk_tx (*_a, xt::equal (a, b));
262+ case FLOORDIV :
263+ return mk_tx (*_a, xt::floor (a / b));
264+ case GE :
265+ return mk_tx (*_a, a >= b);
266+ case GT :
267+ return mk_tx (*_a, a > b);
268+ case LE :
269+ return mk_tx (*_a, a <= b);
270+ case LT :
271+ return mk_tx (*_a, a < b);
272+ /* FIXME
273+ case MATMUL:
274+ return mk_tx(*_a, );
275+ */
276+ case MUL :
277+ case RMUL :
278+ return mk_tx (*_a, a * b);
279+ case NE :
280+ return mk_tx (*_a, xt::not_equal (a, b));
281+ /* FIXME
282+ case POW:
283+ return mk_tx(*_a, );
284+ */
285+ case SUB :
286+ return mk_tx (*_a, a - b);
287+ case TRUEDIV :
288+ return mk_tx (*_a, a / b);
289+ case RFLOORDIV :
290+ return mk_tx (*_a, xt::floor (b / a));
291+ /* FIXME
292+ case RPOW:
293+ return mk_tx(*_a, );
294+ */
295+ case RSUB :
296+ return mk_tx (*_a, b - a);
297+ case RTRUEDIV :
298+ return mk_tx (*_a, b / a);
299+ }
300+ return integral_op (bop, *_a, a, b);
301+ }
302+
181303#pragma GCC diagnostic pop
182304
183305 };
0 commit comments