Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit f54b702

Browse files
committed
adding __float/bool/int___
1 parent b2a3ed4 commit f54b702

6 files changed

Lines changed: 132 additions & 26 deletions

File tree

ddptensor/array_api.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,13 @@
144144
"std", # (x, /, *, axis=None, correction=0.0, keepdims=False)
145145
"var", # (x, /, *, axis=None, correction=0.0, keepdims=False)
146146
]
147+
148+
unary_methods = [
149+
"__array_namespace__", # (self, /, *, api_version=None)
150+
"__bool__", # (self, /)
151+
"__dlpack__", # (self, /, *, stream=None)
152+
"__dlpack_device__", # (self, /)
153+
"__float__", # (self, /)
154+
"__int__", # (self, /)
155+
"__len__", # (self, /)
156+
]

ddptensor/ddptensor.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,6 @@
22
from ._ddptensor import float64, int64, fini
33
from . import array_api as api
44

5-
unary_methods = [
6-
# "__array_namespace__", # (self, /, *, api_version=None)
7-
"__bool__", # (self, /)
8-
# "__dlpack__", # (self, /, *, stream=None)
9-
# "__dlpack_device__", # (self, /)
10-
"__float__", # (self, /)
11-
"__int__", # (self, /)
12-
"__len__", # (self, /)
13-
]
14-
155
t_attributes = ["dtype", "shape", ] #"device", "ndim", "size", "T"]
166

177
#def try_except(func, *args, **kwargs):
@@ -27,7 +17,6 @@ def __init__(self, t):
2717
def __repr__(self):
2818
return self._t.__repr__()
2919

30-
3120
for method in api.ew_binary_methods:
3221
METHOD = method.upper()
3322
exec(
@@ -45,11 +34,11 @@ def __repr__(self):
4534
exec(
4635
f"{method} = lambda self: dtensor(_cdt.EWUnyOp.op(_cdt.{METHOD}, self._t))"
4736
)
48-
49-
# for method in unary_methods:
50-
# exec(
51-
# f"{method} = lambda self: self._t.{method}()"
52-
# )
37+
38+
for method in api.unary_methods:
39+
exec(
40+
f"{method} = lambda self: _cdt.UnyOp.{method}(self._t)"
41+
)
5342

5443
# for att in t_attributes:
5544
# exec(

src/ddptensor.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,24 @@ struct EWUnyOp
9292
}
9393
};
9494

95+
struct UnyOp
96+
{
97+
static bool __bool__(x::DPTensorBaseX::ptr_type a)
98+
{
99+
return TypeDispatch<x::UnyOp>(a->dtype(), a, true);
100+
}
101+
102+
static double __float__(x::DPTensorBaseX::ptr_type a)
103+
{
104+
return TypeDispatch<x::UnyOp>(a->dtype(), a, double(1));
105+
}
106+
107+
static int64_t __int__(x::DPTensorBaseX::ptr_type a)
108+
{
109+
return TypeDispatch<x::UnyOp>(a->dtype(), a, int64_t(1));
110+
}
111+
};
112+
95113
struct ReduceOp
96114
{
97115
static auto op(ReduceOpId op, x::DPTensorBaseX::ptr_type a, const dim_vec_type & dim)
@@ -148,6 +166,11 @@ PYBIND11_MODULE(_ddptensor, m) {
148166
py::class_<EWUnyOp>(m, "EWUnyOp")
149167
.def("op", &EWUnyOp::op);
150168

169+
py::class_<UnyOp>(m, "UnyOp")
170+
.def("__bool__", &UnyOp::__bool__)
171+
.def("__float__", &UnyOp::__float__)
172+
.def("__int__", &UnyOp::__int__);
173+
151174
py::class_<IEWBinOp>(m, "IEWBinOp")
152175
.def("op", &IEWBinOp::op);
153176

src/include/ddptensor/PVSlice.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,11 @@ class PVSlice
155155
return _slice;
156156
}
157157

158+
uint64_t size() const
159+
{
160+
return slice().size();
161+
}
162+
158163
#if 0
159164
NDSlice normalized_slice() const
160165
{

src/include/ddptensor/x.hpp

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,26 +47,37 @@ namespace x
4747
template<typename T>
4848
class DPTensorX : public DPTensorBaseX
4949
{
50+
rank_type _owner;
5051
PVSlice _slice;
5152
xt::xstrided_slice_vector _lslice;
5253
std::shared_ptr<xt::xarray<T>> _xarray;
54+
T _replica = 0;
5355

5456
public:
5557
template<typename I>
56-
DPTensorX(PVSlice && slc, I && ax)
57-
: _slice(std::move(slc)),
58+
DPTensorX(PVSlice && slc, I && ax, rank_type owner=NOOWNER)
59+
: _owner(owner),
60+
_slice(std::move(slc)),
5861
_lslice(to_xt(_slice.local_slice_of_rank())),
5962
_xarray(std::make_shared<xt::xarray<T>>(std::forward<I>(ax)))
6063
{
6164
}
6265

6366
template<typename O>
64-
DPTensorX(const DPTensorX<O> & org, const NDSlice & slc)
65-
: _slice(org._slice, slc),
67+
DPTensorX(const DPTensorX<O> & org, const NDSlice & slc, rank_type owner=NOOWNER)
68+
: _owner(owner),
69+
_slice(org._slice, slc),
6670
_lslice(to_xt(_slice.local_slice_of_rank())),
6771
_xarray(org._xarray)
6872
{
69-
std::cerr << "slice: " << _slice.slice() << " lslice: " << _slice.local_slice_of_rank() << std::endl;
73+
if(owner == NOOWNER && slice().size() <= 1) {
74+
set_owner(org.slice().owner(slc));
75+
} else if(owner == REPLICATED) {
76+
_replica = *(xt::strided_view(xarray(), to_xt(slice().slice())).begin());
77+
}
78+
std::cerr << "slice: " << _slice.slice() << " sz " << _slice.size()
79+
<< " lslice: " << _slice.local_slice_of_rank() << " owner: " << _owner
80+
<< " val: " << _replica << std::endl;
7081
}
7182

7283
virtual std::string __repr__() const
@@ -82,6 +93,11 @@ namespace x
8293
return DTYPE<T>::value;
8394
}
8495

96+
virtual shape_type shape() const
97+
{
98+
return _slice.shape();
99+
}
100+
85101
xt::xarray<T> & xarray()
86102
{
87103
return *_xarray.get();
@@ -102,9 +118,41 @@ namespace x
102118
return _lslice;
103119
}
104120

105-
virtual shape_type shape() const
121+
bool has_owner() const
106122
{
107-
return _slice.shape();
123+
return _owner < _OWNER_END;
124+
}
125+
126+
void set_owner(rank_type o)
127+
{
128+
_owner = o;
129+
}
130+
131+
rank_type owner() const
132+
{
133+
return _owner;
134+
}
135+
136+
bool is_replicated() const
137+
{
138+
return _owner == REPLICATED;
139+
}
140+
141+
T replicate()
142+
{
143+
std::cerr << "is_replicated()=" << is_replicated() << " owner=" << owner() << " shape=" << to_string(shape()) << std::endl;
144+
if(is_replicated()) return _replica;
145+
if(has_owner() && _slice.size() == 1) {
146+
if(theTransceiver->rank() == owner()) {
147+
_replica = *(xt::strided_view(xarray(), lslice()).begin());
148+
std::cerr << "replica: " << _replica << std::endl;
149+
}
150+
theTransceiver->bcast(&_replica, sizeof(T), owner());
151+
set_owner(REPLICATED);
152+
} else {
153+
throw(std::runtime_error("Replication implemented for single element and single owner only."));
154+
}
155+
return _replica;
108156
}
109157
};
110158

@@ -460,6 +508,34 @@ namespace x
460508

461509
};
462510

511+
template<typename T>
512+
class UnyOp
513+
{
514+
public:
515+
using ptr_type = DPTensorBaseX::ptr_type;
516+
517+
template<typename N>
518+
static N __type__(const ptr_type & a_ptr)
519+
{
520+
auto const _a = dynamic_cast<DPTensorX<T>*>(a_ptr.get());
521+
if(!_a )
522+
throw std::runtime_error("Invalid array object: could not dynamically cast");
523+
T v = _a->replicate();
524+
return static_cast<N>(v);
525+
}
526+
static bool op(const ptr_type & a_ptr, bool)
527+
{
528+
return __type__<bool>(a_ptr);
529+
}
530+
static double op(const ptr_type & a_ptr, double)
531+
{
532+
return __type__<double>(a_ptr);
533+
}
534+
static int64_t op(const ptr_type & a_ptr, int64_t)
535+
{
536+
return __type__<int64_t>(a_ptr);
537+
}
538+
};
463539

464540
template<typename T>
465541
class ReduceOp
@@ -474,11 +550,13 @@ namespace x
474550
{
475551
xt::xarray<typename X::value_type> a = x;
476552
auto new_shape = reduce_shape(slice.shape(), dims);
553+
rank_type owner = NOOWNER;
477554
if(slice.need_reduce(dims)) {
478555
auto len = VPROD(new_shape);
479556
theTransceiver->reduce_all(a.data(), DTYPE<typename X::value_type>::value, len, rop);
557+
owner = REPLICATED;
480558
}
481-
return std::make_shared<DPTensorX<typename X::value_type>>(new_shape, a);
559+
return std::make_shared<DPTensorX<typename X::value_type>>(new_shape, a, owner);
482560
}
483561

484562
static ptr_type op(ReduceOpId rop, const ptr_type & a_ptr, const dim_vec_type & dims)
@@ -529,5 +607,5 @@ namespace x
529607
return std::make_shared<DPTensorX<T>>(*_a, slice);
530608
}
531609
};
532-
610+
533611
} // namespace x

test/test_x.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
print(a == b)
77
print(dt.sqrt(a))
88
print(dt.sum(a, [1]))
9-
print(a[0:1,0:1])
9+
print(a[0:1,0:1], float(a[0:1,0:1]), bool(a[0:1,0:1]), int(a[0:1,0:1]))
10+
print(a[0:2,0:2])
1011
dt.fini()

0 commit comments

Comments
 (0)