1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/python/lib/core/py_seq_tensor.h" |
17 | |
18 | #include "tensorflow/core/framework/tensor.h" |
19 | #include "tensorflow/core/framework/tensor_shape.h" |
20 | #include "tensorflow/core/framework/types.h" |
21 | #include "tensorflow/core/lib/core/errors.h" |
22 | #include "tensorflow/core/lib/core/stringpiece.h" |
23 | #include "tensorflow/core/lib/strings/str_util.h" |
24 | #include "tensorflow/core/platform/types.h" |
25 | #include "tensorflow/python/lib/core/numpy.h" |
26 | #include "tensorflow/python/lib/core/py_util.h" |
27 | #include "tensorflow/python/lib/core/safe_ptr.h" |
28 | |
29 | namespace tensorflow { |
30 | namespace { |
31 | |
32 | inline bool PyIsInstance(PyObject* obj, PyTypeObject* t) { |
33 | return PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(t)); |
34 | } |
35 | |
36 | inline PyObject* PyType(PyObject* obj) { |
37 | return reinterpret_cast<PyObject*>(obj->ob_type); |
38 | } |
39 | |
40 | bool IsPyString(PyObject* obj) { |
41 | return PyBytes_Check(obj) || PyUnicode_Check(obj); |
42 | } |
43 | |
44 | bool IsPyInt(PyObject* obj) { |
45 | #if PY_MAJOR_VERSION >= 3 |
46 | return PyLong_Check(obj) || |
47 | PyIsInstance(obj, &PyIntegerArrType_Type); // NumPy integers |
48 | #else |
49 | return PyInt_Check(obj) || PyLong_Check(obj) || |
50 | PyIsInstance(obj, &PyIntegerArrType_Type); // NumPy integers |
51 | #endif |
52 | } |
53 | |
54 | bool IsPyFloat(PyObject* obj) { |
55 | return PyFloat_Check(obj) || |
56 | PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types |
57 | } |
58 | |
59 | // Converts Python object `c` that should hold a Python string into a |
60 | // C++ string in *out. Returns nullptr on success, or a message on error. |
61 | // Defined below, but forward declared here for use in PyRepr. |
62 | const char* ConvertOneString(PyObject* v, string* out); |
63 | |
64 | string PyRepr(PyObject* obj) { |
65 | if (obj == nullptr) { |
66 | return "<null>" ; |
67 | } |
68 | Safe_PyObjectPtr repr_obj = make_safe(PyObject_Repr(obj)); |
69 | if (repr_obj) { |
70 | string repr_str; |
71 | if (ConvertOneString(repr_obj.get(), &repr_str) == nullptr) { |
72 | return repr_str; |
73 | } |
74 | } |
75 | return "<error computing repr()>" ; |
76 | } |
77 | |
78 | bool IsPyDimension(PyObject* obj) { |
79 | const char* tp_name = obj->ob_type->tp_name; |
80 | if (strcmp(tp_name, "Dimension" ) != 0) return false; |
81 | bool ret = str_util::EndsWith( |
82 | PyRepr(PyType(obj)), |
83 | "tensorflow.python.framework.tensor_shape.Dimension'>" ); |
84 | return ret; |
85 | } |
86 | |
87 | Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { |
88 | std::vector<Safe_PyObjectPtr> refs_to_clean; |
89 | while (true) { |
90 | // We test strings first, in case a string is considered a sequence. |
91 | if (IsPyString(obj)) { |
92 | *dtype = DT_STRING; |
93 | } else if (PySequence_Check(obj)) { |
94 | auto length = PySequence_Length(obj); |
95 | if (length > 0) { |
96 | shape->AddDim(length); |
97 | obj = PySequence_GetItem(obj, 0); |
98 | refs_to_clean.push_back(make_safe(obj)); |
99 | continue; |
100 | } else if (length == 0) { |
101 | shape->AddDim(length); |
102 | *dtype = DT_INVALID; // Invalid dtype for empty tensors. |
103 | } else { |
104 | // The sequence does not have a valid length (PySequence_Length < 0). |
105 | if (PyErr_Occurred()) { |
106 | // PySequence_Length failed and set an exception. Fetch the message |
107 | // and convert it to a failed status. |
108 | return errors::InvalidArgument(PyExceptionFetch()); |
109 | } else { |
110 | // This is almost certainly dead code: PySequence_Length failed but |
111 | // did not set an exception. |
112 | return errors::InvalidArgument( |
113 | "Attempted to convert an invalid sequence to a Tensor." ); |
114 | } |
115 | } |
116 | } else if (IsPyFloat(obj)) { |
117 | *dtype = DT_DOUBLE; |
118 | } else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) { |
119 | // Have to test for bool before int, since IsInt(True/False) == true. |
120 | *dtype = DT_BOOL; |
121 | } else if (IsPyInt(obj)) { |
122 | *dtype = DT_INT64; |
123 | } else if (IsPyDimension(obj)) { |
124 | *dtype = DT_INT64; |
125 | } else if (PyComplex_Check(obj) || |
126 | PyIsInstance(obj, &PyComplexFloatingArrType_Type)) { // NumPy |
127 | *dtype = DT_COMPLEX128; |
128 | } else { |
129 | return errors::InvalidArgument("Attempt to convert a value (" , |
130 | PyRepr(obj), |
131 | ") with an unsupported type (" , |
132 | PyRepr(PyType(obj)), ") to a Tensor." ); |
133 | } |
134 | return Status::OK(); |
135 | } |
136 | } |
137 | |
// Error messages returned by the conversion helpers below (nullptr means
// success). PySeqToTensor compares returned pointers against these arrays by
// address (e.g. `error == ErrorFoundInt64`), so each message must stay a
// distinct named object rather than an inline string literal.

constexpr char ErrorConverting[] =
    "Error while converting Python sequence to Tensor.";
constexpr char ErrorRectangular[] =
    "Can't convert non-rectangular Python sequence to Tensor.";
constexpr char ErrorMixedTypes[] =
    "Can't convert Python sequence with mixed types to Tensor.";
constexpr char ErrorOutOfRange[] =
    "Can't convert Python sequence with out-of-range integer to Tensor.";
constexpr char ErrorOutOfRangeDouble[] =
    "Can't convert Python sequence with a value out of range for a "
    "double-precision float.";
constexpr char ErrorConvertingUnicodeString[] =
    "Error converting unicode string while converting Python sequence to "
    "Tensor.";
constexpr char ErrorFoundInt64[] =
    "Can't convert Python sequence with out-of-range integer to int32 Tensor.";
constexpr char ErrorFoundFloat[] =
    "Can't convert Python sequence with floating point values to integer "
    "Tensor.";
159 | |
// Template for defining a function for recursively converting obj into
// an array of TYPE using the conversion function CONVERT.
// Note that the FUNCTION##Helper functions require shape.dims() >= 1.
//
// DEFINE_HELPER(FUNCTION, TYPE, TYPE_ENUM, CONVERT) defines:
//   const char* FUNCTION##Helper(PyObject* obj, const TensorShape& shape,
//                                TYPE** buf);
//     Walks `obj` according to `shape`, converting each leaf with CONVERT
//     into consecutive elements starting at *buf and advancing *buf past
//     every element written.
//   const char* FUNCTION(PyObject* obj, const TensorShape& shape,
//                        Tensor* dest);
//     Builds a TYPE_ENUM tensor of `shape` from `obj` (converting `obj`
//     directly when the shape is a scalar) and assigns it to *dest only if
//     every element converted successfully.
// Both return nullptr on success or one of the Error* messages. A Python
// exception raised by PySequence_Fast() is cleared, because the failure is
// reported through the returned message instead.

#define DEFINE_HELPER(FUNCTION, TYPE, TYPE_ENUM, CONVERT)                    \
  const char* FUNCTION##Helper(PyObject* obj, const TensorShape& shape,     \
                               TYPE** buf) {                                 \
    if (TF_PREDICT_FALSE(obj == nullptr)) {                                  \
      return ErrorConverting;                                                \
    }                                                                        \
    /* A string where a nested sequence is expected makes the input */       \
    /* non-rectangular; reject it instead of letting PySequence_Fast() */    \
    /* split it into characters. */                                          \
    if (TF_PREDICT_FALSE(IsPyString(obj))) {                                 \
      return ErrorRectangular;                                               \
    }                                                                        \
    Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, ""));              \
    if (TF_PREDICT_FALSE(seq == nullptr)) {                                  \
      /* `obj` is not iterable, so the input is not rectangular. */          \
      PyErr_Clear();                                                         \
      return ErrorRectangular;                                               \
    }                                                                        \
    const int64 s = shape.dim_size(0);                                       \
    if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) {        \
      return ErrorRectangular;                                               \
    }                                                                        \
    if (shape.dims() > 1) {                                                  \
      /* Iterate over outer dim, and recursively convert each element. */    \
      TensorShape rest = shape;                                              \
      rest.RemoveDim(0);                                                     \
      for (int64 i = 0; i < s; ++i) {                                        \
        const char* error = FUNCTION##Helper(                                \
            PySequence_Fast_GET_ITEM(seq.get(), i), rest, buf);              \
        if (TF_PREDICT_FALSE(error != nullptr)) return error;                \
      }                                                                      \
    } else {                                                                 \
      PyObject** items = PySequence_Fast_ITEMS(seq.get());                   \
      for (int64 i = 0; i < s; ++i) {                                        \
        const char* error = CONVERT(items[i], *buf);                         \
        if (TF_PREDICT_FALSE(error != nullptr)) return error;                \
        ++*buf;                                                              \
      }                                                                      \
    }                                                                        \
    return nullptr;                                                          \
  }                                                                          \
  const char* FUNCTION(PyObject* obj, const TensorShape& shape,              \
                       Tensor* dest) {                                       \
    /* TODO(josh11b): Allocator & attributes? */                             \
    Tensor result(TYPE_ENUM, shape);                                         \
    if (shape.dims() == 0) { /* Scalar case */                               \
      TYPE value;                                                            \
      const char* error = CONVERT(obj, &value);                              \
      if (error != nullptr) return error;                                    \
      result.scalar<TYPE>()() = value;                                       \
    } else {                                                                 \
      TYPE* buf = result.flat<TYPE>().data();                                \
      const char* error = FUNCTION##Helper(obj, shape, &buf);                \
      if (error != nullptr) return error;                                    \
    }                                                                        \
    *dest = result;                                                          \
    return nullptr;                                                          \
  }
217 | |
218 | // Int support |
219 | |
220 | const char* ConvertOneInt64(PyObject* v, int64* out) { |
221 | #if PY_MAJOR_VERSION < 3 |
222 | if (TF_PREDICT_TRUE(PyInt_Check(v))) { |
223 | *out = PyInt_AS_LONG(v); |
224 | return nullptr; |
225 | } |
226 | #endif |
227 | if (TF_PREDICT_TRUE(PyLong_Check(v) || IsPyDimension(v))) { |
228 | int overflow = 0; |
229 | // Have to use LongLong for 64 bits, since long is 32 bits on Windows. |
230 | *out = PyLong_AsLongLongAndOverflow(v, &overflow); |
231 | if (TF_PREDICT_FALSE(overflow)) return ErrorOutOfRange; |
232 | return nullptr; |
233 | } |
234 | if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers |
235 | #if PY_MAJOR_VERSION < 3 |
236 | Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v)); |
237 | #else |
238 | Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v)); |
239 | #endif |
240 | return ConvertOneInt64(as_int.get(), out); |
241 | } |
242 | if (IsPyFloat(v)) return ErrorFoundFloat; |
243 | return ErrorMixedTypes; |
244 | } |
245 | |
246 | DEFINE_HELPER(ConvertInt64, int64, DT_INT64, ConvertOneInt64); |
247 | |
248 | const char* ConvertOneInt32(PyObject* v, int32* out) { |
249 | int64 i; |
250 | #if PY_MAJOR_VERSION < 3 |
251 | if (TF_PREDICT_TRUE(PyInt_Check(v))) { |
252 | i = PyInt_AS_LONG(v); |
253 | } else |
254 | #endif |
255 | if (PyLong_Check(v) || IsPyDimension(v)) { |
256 | int overflow = 0; |
257 | // Have to use LongLong for 64 bits, since long is 32 bits on Windows. |
258 | i = PyLong_AsLongLongAndOverflow(v, &overflow); |
259 | if (TF_PREDICT_FALSE(overflow)) return ErrorOutOfRange; |
260 | } else if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers |
261 | #if PY_MAJOR_VERSION < 3 |
262 | Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v)); |
263 | #else |
264 | Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v)); |
265 | #endif |
266 | return ConvertOneInt32(as_int.get(), out); |
267 | } else if (IsPyFloat(v)) { |
268 | return ErrorFoundFloat; |
269 | } else { |
270 | return ErrorMixedTypes; |
271 | } |
272 | *out = static_cast<uint32>(static_cast<uint64>(i)); |
273 | // Check for 32-bit overflow. |
274 | if (TF_PREDICT_FALSE(i != *out)) return ErrorFoundInt64; |
275 | return nullptr; |
276 | } |
277 | |
278 | DEFINE_HELPER(ConvertInt32, int32, DT_INT32, ConvertOneInt32); |
279 | |
280 | // Floating-point support |
281 | |
282 | template <class T> |
283 | const char* ConvertOneFloat(PyObject* v, T* out) { |
284 | if (TF_PREDICT_TRUE(PyFloat_Check(v))) { |
285 | *out = PyFloat_AS_DOUBLE(v); |
286 | return nullptr; |
287 | } |
288 | #if PY_MAJOR_VERSION < 3 |
289 | if (PyInt_Check(v)) { |
290 | *out = PyInt_AS_LONG(v); |
291 | return nullptr; |
292 | } |
293 | #endif |
294 | if (PyLong_Check(v)) { |
295 | *out = PyLong_AsDouble(v); |
296 | if (PyErr_Occurred()) return ErrorOutOfRangeDouble; |
297 | return nullptr; |
298 | } |
299 | if (PyIsInstance(v, &PyFloatingArrType_Type)) { // NumPy float types |
300 | Safe_PyObjectPtr as_float = make_safe(PyNumber_Float(v)); |
301 | return ConvertOneFloat<T>(as_float.get(), out); |
302 | } |
303 | if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers |
304 | #if PY_MAJOR_VERSION < 3 |
305 | Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v)); |
306 | #else |
307 | Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v)); |
308 | #endif |
309 | return ConvertOneFloat<T>(as_int.get(), out); |
310 | } |
311 | return ErrorMixedTypes; |
312 | } |
313 | |
314 | DEFINE_HELPER(ConvertDouble, double, DT_DOUBLE, ConvertOneFloat<double>); |
315 | DEFINE_HELPER(ConvertFloat, float, DT_FLOAT, ConvertOneFloat<float>); |
316 | |
317 | // String support |
318 | |
319 | const char* ConvertOneString(PyObject* v, string* out) { |
320 | if (PyBytes_Check(v)) { |
321 | out->assign(PyBytes_AS_STRING(v), PyBytes_GET_SIZE(v)); |
322 | return nullptr; |
323 | } |
324 | if (PyUnicode_Check(v)) { |
325 | #if PY_MAJOR_VERSION >= 3 |
326 | Py_ssize_t size; |
327 | const char* str = PyUnicode_AsUTF8AndSize(v, &size); |
328 | if (str == nullptr) return ErrorConvertingUnicodeString; |
329 | out->assign(str, size); |
330 | return nullptr; |
331 | #else |
332 | PyObject* py_str = PyUnicode_AsUTF8String(v); |
333 | if (py_str == nullptr) return ErrorConvertingUnicodeString; |
334 | out->assign(PyBytes_AS_STRING(py_str), PyBytes_GET_SIZE(py_str)); |
335 | Py_DECREF(py_str); |
336 | return nullptr; |
337 | #endif |
338 | } |
339 | return ErrorMixedTypes; |
340 | } |
341 | |
342 | DEFINE_HELPER(ConvertString, string, DT_STRING, ConvertOneString); |
343 | |
344 | // Complex support |
345 | |
346 | const char* ConvertOneComplex(PyObject* v, complex128* out) { |
347 | if (PyComplex_Check(v)) { |
348 | *out = complex128(PyComplex_RealAsDouble(v), PyComplex_ImagAsDouble(v)); |
349 | return nullptr; |
350 | } else if (PyIsInstance(v, &PyComplexFloatingArrType_Type)) { // NumPy |
351 | auto as_complex = PyComplex_AsCComplex(v); |
352 | *out = complex128(as_complex.real, as_complex.imag); |
353 | return nullptr; |
354 | } |
355 | return ErrorMixedTypes; |
356 | } |
357 | |
358 | DEFINE_HELPER(ConvertComplex, complex128, DT_COMPLEX128, ConvertOneComplex); |
359 | |
360 | // Bool support |
361 | |
362 | const char* ConvertOneBool(PyObject* v, bool* out) { |
363 | if (v == Py_True) { |
364 | *out = true; |
365 | } else if (v == Py_False) { |
366 | *out = false; |
367 | } else if (PyIsInstance(v, &PyBoolArrType_Type)) { // NumPy |
368 | *out = PyObject_IsTrue(v); |
369 | } else { |
370 | return ErrorMixedTypes; |
371 | } |
372 | return nullptr; |
373 | } |
374 | |
375 | DEFINE_HELPER(ConvertBool, bool, DT_BOOL, ConvertOneBool); |
376 | |
377 | #undef DEFINE_HELPER |
378 | |
379 | } // namespace |
380 | |
// Evaluates an expression yielding a `const char*` error message (nullptr
// means success) and returns from the enclosing function with the matching
// Status: OK for nullptr, otherwise InvalidArgument carrying the message.
#define RETURN_STRING_AS_STATUS(...)                     \
  do {                                                   \
    const char* const status_msg_ = (__VA_ARGS__);       \
    if (TF_PREDICT_FALSE(status_msg_ != nullptr)) {      \
      return errors::InvalidArgument(status_msg_);       \
    }                                                    \
    return Status::OK();                                 \
  } while (0)
387 | |
388 | Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) { |
389 | DataType infer_dtype; |
390 | TensorShape shape; |
391 | TF_RETURN_IF_ERROR(InferShapeAndType(obj, &shape, &infer_dtype)); |
392 | DataType requested_dtype = DT_INVALID; |
393 | if (dtype != Py_None) { |
394 | int32 dtype_as_int = -1; |
395 | if (ConvertOneInt32(dtype, &dtype_as_int) == nullptr) { |
396 | requested_dtype = static_cast<DataType>(dtype_as_int); |
397 | } |
398 | } |
399 | // NOTE(josh11b): If don't successfully convert to the requested type, |
400 | // we just try instead to create a tensor of the inferred type and |
401 | // let the caller convert it to the requested type using a cast |
402 | // operation. |
403 | switch (requested_dtype) { |
404 | case DT_FLOAT: |
405 | if (ConvertFloat(obj, shape, ret) == nullptr) return Status::OK(); |
406 | break; |
407 | |
408 | case DT_DOUBLE: |
409 | if (ConvertDouble(obj, shape, ret) == nullptr) return Status::OK(); |
410 | break; |
411 | |
412 | case DT_INT64: |
413 | if (ConvertInt64(obj, shape, ret) == nullptr) return Status::OK(); |
414 | break; |
415 | |
416 | case DT_INT32: |
417 | if (ConvertInt32(obj, shape, ret) == nullptr) return Status::OK(); |
418 | break; |
419 | |
420 | case DT_COMPLEX128: |
421 | if (ConvertComplex(obj, shape, ret) == nullptr) return Status::OK(); |
422 | break; |
423 | |
424 | case DT_STRING: |
425 | if (ConvertString(obj, shape, ret) == nullptr) return Status::OK(); |
426 | break; |
427 | |
428 | case DT_BOOL: |
429 | if (ConvertBool(obj, shape, ret) == nullptr) return Status::OK(); |
430 | break; |
431 | |
432 | default: |
433 | break; |
434 | } |
435 | switch (infer_dtype) { |
436 | case DT_DOUBLE: |
437 | // TODO(josh11b): Handle mixed floats and complex numbers? |
438 | if (requested_dtype == DT_INVALID) { |
439 | // TensorFlow uses float32s to represent floating point numbers |
440 | // by default (for space and speed over using doubles). |
441 | RETURN_STRING_AS_STATUS(ConvertFloat(obj, shape, ret)); |
442 | } else { |
443 | // We are going to do a cast to the user's requested dtype |
444 | // after this. We use doubles for this intermediate result so |
445 | // we don't lose precision that might be representable in the |
446 | // final type. |
447 | RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret)); |
448 | } |
449 | |
450 | case DT_INT64: |
451 | if (requested_dtype == DT_INVALID) { |
452 | const char* error = ConvertInt32(obj, shape, ret); |
453 | if (error == ErrorFoundInt64) { |
454 | error = ConvertInt64(obj, shape, ret); |
455 | } |
456 | if (error == ErrorFoundFloat) { |
457 | error = ConvertFloat(obj, shape, ret); |
458 | } |
459 | // TODO(josh11b): May also want to fall back to using doubles if |
460 | // error == ErrorOutOfRange? |
461 | RETURN_STRING_AS_STATUS(error); |
462 | } else { |
463 | const char* error = ConvertInt64(obj, shape, ret); |
464 | if (error == ErrorFoundFloat) { |
465 | error = ConvertDouble(obj, shape, ret); |
466 | } |
467 | RETURN_STRING_AS_STATUS(error); |
468 | } |
469 | |
470 | case DT_STRING: |
471 | RETURN_STRING_AS_STATUS(ConvertString(obj, shape, ret)); |
472 | |
473 | case DT_COMPLEX128: |
474 | RETURN_STRING_AS_STATUS(ConvertComplex(obj, shape, ret)); |
475 | |
476 | case DT_BOOL: |
477 | RETURN_STRING_AS_STATUS(ConvertBool(obj, shape, ret)); |
478 | |
479 | case DT_INVALID: // Only occurs for empty tensors. |
480 | *ret = Tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype, |
481 | shape); |
482 | return Status::OK(); |
483 | |
484 | default: |
485 | return errors::Unimplemented("Missing Python -> Tensor conversion for " , |
486 | DataTypeString(infer_dtype)); |
487 | } |
488 | |
489 | return Status::OK(); |
490 | } |
491 | |
492 | } // namespace tensorflow |
493 | |