@@ -37,7 +37,7 @@ namespace pybind11
3737 {
3838 static auto get (handle src)
3939 {
40- return array_t <T, array::f_style>::ensure (src);
40+ return array_t <T, array::f_style | array::forcecast >::ensure (src);
4141 }
4242 };
4343
@@ -51,8 +51,7 @@ namespace pybind11
5151 {
5252 static auto get (handle src)
5353 {
54- auto buf = xtensor_get_buffer<T, L>::get (src);
55- return buf;
54+ return xtensor_get_buffer<T, L>::get (src);
5655 }
5756 };
5857
@@ -61,11 +60,7 @@ namespace pybind11
6160 {
6261 static auto get (handle src)
6362 {
64- auto buf = xtensor_get_buffer<T, L>::get (src);
65- if (buf.ndim () != N) {
66- return false ;
67- }
68- return buf;
63+ return xtensor_get_buffer<T, L>::get (src);
6964 }
7065 };
7166
@@ -98,6 +93,27 @@ namespace pybind11
9893 };
9994
10095
96+ template <class T >
97+ struct xtensor_verify
98+ {
99+ template <class B >
100+ static bool get (const B& buf)
101+ {
102+ return true ;
103+ }
104+ };
105+
106+ template <class T , std::size_t N, xt::layout_type L>
107+ struct xtensor_verify <xt::xtensor<T, N, L>>
108+ {
109+ template <class B >
110+ static bool get (const B& buf)
111+ {
112+ return buf.ndim () == N;
113+ }
114+ };
115+
116+
101117 // Casts a strided expression type to numpy array.If given a base,
102118 // the numpy array references the src data, otherwise it'll make a copy.
103119 // The writeable attributes lets you specify writeable flag for the array.
@@ -192,11 +208,14 @@ namespace pybind11
192208 if (!buf) {
193209 return false ;
194210 }
211+ if (!xtensor_verify<Type>::get (buf)) {
212+ return false ;
213+ }
195214
196215 std::vector<size_t > shape (buf.ndim ());
197216 std::copy (buf.shape (), buf.shape () + buf.ndim (), shape.begin ());
198- value = Type (shape);
199- std::copy (buf.data (), buf.data () + buf.size (), value.begin ());
217+ value = Type::from_shape (shape);
218+ std::copy (buf.data (), buf.data () + buf.size (), value.data ());
200219
201220 return true ;
202221 }
0 commit comments