|
| 1 | +#include "ddptensor/Operations.hpp" |
| 2 | +#include "ddptensor/x.hpp" |
| 3 | + |
| 4 | +namespace x { |
| 5 | + |
| 6 | + template<typename T> |
| 7 | + class EWBinOp |
| 8 | + { |
| 9 | + public: |
| 10 | + using ptr_type = DPTensorBaseX::ptr_type; |
| 11 | + |
| 12 | +#pragma GCC diagnostic ignored "-Wswitch" |
| 13 | + |
| 14 | + template<typename A, typename B, typename U = T, std::enable_if_t<std::is_floating_point<U>::value, bool> = true> |
| 15 | + static ptr_type integral_op(EWBinOpId iop, const DPTensorX<T> & tx, A && a, B && b) |
| 16 | + { |
| 17 | + throw std::runtime_error("Illegal or unknown inplace elementwise binary operation"); |
| 18 | + } |
| 19 | + |
| 20 | + template<typename A, typename B, typename U = T, std::enable_if_t<std::is_integral<U>::value, bool> = true> |
| 21 | + static ptr_type integral_op(EWBinOpId iop, const DPTensorX<T> & tx, A && a, B && b) |
| 22 | + { |
| 23 | + switch(iop) { |
| 24 | + case __AND__: |
| 25 | + case BITWISE_AND: |
| 26 | + return operatorx<T>::mk_tx_(tx, a & b); |
| 27 | + case __RAND__: |
| 28 | + return operatorx<T>::mk_tx_(tx, b & a); |
| 29 | + case __LSHIFT__: |
| 30 | + case BITWISE_LEFT_SHIFT: |
| 31 | + return operatorx<T>::mk_tx_(tx, a << b); |
| 32 | + case __MOD__: |
| 33 | + case REMAINDER: |
| 34 | + return operatorx<T>::mk_tx_(tx, a % b); |
| 35 | + case __OR__: |
| 36 | + case BITWISE_OR: |
| 37 | + return operatorx<T>::mk_tx_(tx, a | b); |
| 38 | + case __ROR__: |
| 39 | + return operatorx<T>::mk_tx_(tx, b | a); |
| 40 | + case __RSHIFT__: |
| 41 | + case BITWISE_RIGHT_SHIFT: |
| 42 | + return operatorx<T>::mk_tx_(tx, a >> b); |
| 43 | + case __XOR__: |
| 44 | + case BITWISE_XOR: |
| 45 | + return operatorx<T>::mk_tx_(tx, a ^ b); |
| 46 | + case __RXOR__: |
| 47 | + return operatorx<T>::mk_tx_(tx, b ^ a); |
| 48 | + case __RLSHIFT__: |
| 49 | + return operatorx<T>::mk_tx_(tx, b << a); |
| 50 | + case __RMOD__: |
| 51 | + return operatorx<T>::mk_tx_(tx, b % a); |
| 52 | + case __RRSHIFT__: |
| 53 | + return operatorx<T>::mk_tx_(tx, b >> a); |
| 54 | + default: |
| 55 | + throw std::runtime_error("Unknown elementwise binary operation"); |
| 56 | + } |
| 57 | + } |
| 58 | + |
| 59 | + static ptr_type op(EWBinOpId bop, const ptr_type & a_ptr, const ptr_type & b_ptr) |
| 60 | + { |
| 61 | + const auto _a = dynamic_cast<DPTensorX<T>*>(a_ptr.get()); |
| 62 | + const auto _b = dynamic_cast<DPTensorX<T>*>(b_ptr.get()); |
| 63 | + if(!_a || !_b) |
| 64 | + throw std::runtime_error("Invalid array object: could not dynamically cast"); |
| 65 | + const auto & a = xt::strided_view(_a->xarray(), _a->lslice()); |
| 66 | + const auto & b = xt::strided_view(_b->xarray(), _b->lslice()); |
| 67 | + |
| 68 | + switch(bop) { |
| 69 | + case __ADD__: |
| 70 | + case ADD: |
| 71 | + return operatorx<T>::mk_tx_(*_a, a + b); |
| 72 | + case __RADD__: |
| 73 | + return operatorx<T>::mk_tx_(*_a, b + a); |
| 74 | + case ATAN2: |
| 75 | + return operatorx<T>::mk_tx_(*_a, xt::atan2(a, b)); |
| 76 | + case __EQ__: |
| 77 | + case EQUAL: |
| 78 | + return operatorx<T>::mk_tx_(*_a, xt::equal(a, b)); |
| 79 | + case __FLOORDIV__: |
| 80 | + case FLOOR_DIVIDE: |
| 81 | + return operatorx<T>::mk_tx_(*_a, xt::floor(a / b)); |
| 82 | + case __GE__: |
| 83 | + case GREATER_EQUAL: |
| 84 | + return operatorx<T>::mk_tx_(*_a, a >= b); |
| 85 | + case __GT__: |
| 86 | + case GREATER: |
| 87 | + return operatorx<T>::mk_tx_(*_a, a > b); |
| 88 | + case __LE__: |
| 89 | + case LESS_EQUAL: |
| 90 | + return operatorx<T>::mk_tx_(*_a, a <= b); |
| 91 | + case __LT__: |
| 92 | + case LESS: |
| 93 | + return operatorx<T>::mk_tx_(*_a, a < b); |
| 94 | + case __MUL__: |
| 95 | + case MULTIPLY: |
| 96 | + return operatorx<T>::mk_tx_(*_a, a * b); |
| 97 | + case __RMUL__: |
| 98 | + return operatorx<T>::mk_tx_(*_a, b * a); |
| 99 | + case __NE__: |
| 100 | + case NOT_EQUAL: |
| 101 | + return operatorx<T>::mk_tx_(*_a, xt::not_equal(a, b)); |
| 102 | + case __SUB__: |
| 103 | + case SUBTRACT: |
| 104 | + return operatorx<T>::mk_tx_(*_a, a - b); |
| 105 | + case __TRUEDIV__: |
| 106 | + case DIVIDE: |
| 107 | + return operatorx<T>::mk_tx_(*_a, a / b); |
| 108 | + case __RFLOORDIV__: |
| 109 | + return operatorx<T>::mk_tx_(*_a, xt::floor(b / a)); |
| 110 | + case __RSUB__: |
| 111 | + return operatorx<T>::mk_tx_(*_a, b - a); |
| 112 | + case __RTRUEDIV__: |
| 113 | + return operatorx<T>::mk_tx_(*_a, b / a); |
| 114 | + case __MATMUL__: |
| 115 | + case __POW__: |
| 116 | + case POW: |
| 117 | + case __RPOW__: |
| 118 | + case LOGADDEXP: |
| 119 | + case LOGICAL_AND: |
| 120 | + case LOGICAL_OR: |
| 121 | + case LOGICAL_XOR: |
| 122 | + // FIXME |
| 123 | + throw std::runtime_error("Binary operation not implemented"); |
| 124 | + } |
| 125 | + return integral_op(bop, *_a, a, b); |
| 126 | + } |
| 127 | + |
| 128 | +#pragma GCC diagnostic pop |
| 129 | + |
| 130 | + }; |
| 131 | +} // namespace x |
| 132 | + |
| 133 | +tensor_i::ptr_type EWBinOp::op(EWBinOpId op, x::DPTensorBaseX::ptr_type a, x::DPTensorBaseX::ptr_type b) |
| 134 | +{ |
| 135 | + return TypeDispatch<x::EWBinOp>(a->dtype(), op, a, b); |
| 136 | +} |
0 commit comments