@@ -23,17 +23,34 @@ TensorBase convert_numpy_to_tensor_base(pybind11::array_t<T> py_buf)
2323 return static_cast <unsigned int >(dim);
2424 }
2525 );
26- warp_type (warp_type (typeid (T)));
27- return TensorBase (typeid (T), shape_vec, info.ptr );
26+ return TensorBase (warp_type (warp_type (typeid (T))), shape_vec, info.ptr );
2827}
2928
3029pybind11::dtype get_py_type (const std::type_info& info)
3130{
31+ if (info == typeid (std::int8_t ))
32+ return pybind11::dtype::of<std::int8_t >();
33+ if (info == typeid (std::int16_t ))
34+ return pybind11::dtype::of<std::int16_t >();
35+ if (info == typeid (std::int32_t ))
36+ return pybind11::dtype::of<std::int32_t >();
37+ if (info == typeid (std::int64_t ))
38+ return pybind11::dtype::of<std::int64_t >();
39+ if (info == typeid (std::uint8_t ))
40+ return pybind11::dtype::of<std::uint8_t >();
41+ if (info == typeid (std::uint16_t ))
42+ return pybind11::dtype::of<std::uint16_t >();
43+ if (info == typeid (std::uint32_t ))
44+ return pybind11::dtype::of<std::uint32_t >();
45+ if (info == typeid (std::uint64_t ))
46+ return pybind11::dtype::of<std::uint64_t >();
3247 if (info == typeid (bool ))
3348 return pybind11::dtype::of<bool >();
3449 if (info == typeid (float ))
3550 return pybind11::dtype::of<float >();
36- throw std::exception ();
51+ if (info == typeid (double ))
52+ return pybind11::dtype::of<double >();
53+ throw std::runtime_error (" no dtype" );
3754}
3855
3956pybind11::array convert_tensor_to_numpy (const Tensor& self)
@@ -125,6 +142,11 @@ pybind11::tuple tensor_shape(const Tensor& self)
125142 return pybind11::cast (std::vector (self.get_buffer ().shape ()));
126143}
127144
145+ DataType tensor_type (const Tensor& self)
146+ {
147+ return warp_type (self.get_buffer ().type ());
148+ }
149+
128150Tensor tensor_copying (const Tensor& self)
129151{
130152 return self;
@@ -205,6 +227,7 @@ PYBIND11_MODULE(tensor2, m)
205227 .def (" condition" , &condition)
206228 .def (" numpy" , &convert_tensor_to_numpy)
207229 .def (" shape" , &tensor_shape)
230+ .def (" dtype" , &tensor_type)
208231 .def (" __getitem__" , &python_index)
209232 .def (" __getitem__" , &python_slice)
210233 .def (" __getitem__" , &python_tuple_slice)
0 commit comments