1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/python/lib/core/py_seq_tensor.h"
17
18#include "tensorflow/core/framework/tensor.h"
19#include "tensorflow/core/framework/tensor_shape.h"
20#include "tensorflow/core/framework/types.h"
21#include "tensorflow/core/lib/core/errors.h"
22#include "tensorflow/core/lib/core/stringpiece.h"
23#include "tensorflow/core/lib/strings/str_util.h"
24#include "tensorflow/core/platform/types.h"
25#include "tensorflow/python/lib/core/numpy.h"
26#include "tensorflow/python/lib/core/py_util.h"
27#include "tensorflow/python/lib/core/safe_ptr.h"
28
29namespace tensorflow {
30namespace {
31
32inline bool PyIsInstance(PyObject* obj, PyTypeObject* t) {
33 return PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(t));
34}
35
36inline PyObject* PyType(PyObject* obj) {
37 return reinterpret_cast<PyObject*>(obj->ob_type);
38}
39
40bool IsPyString(PyObject* obj) {
41 return PyBytes_Check(obj) || PyUnicode_Check(obj);
42}
43
44bool IsPyInt(PyObject* obj) {
45#if PY_MAJOR_VERSION >= 3
46 return PyLong_Check(obj) ||
47 PyIsInstance(obj, &PyIntegerArrType_Type); // NumPy integers
48#else
49 return PyInt_Check(obj) || PyLong_Check(obj) ||
50 PyIsInstance(obj, &PyIntegerArrType_Type); // NumPy integers
51#endif
52}
53
54bool IsPyFloat(PyObject* obj) {
55 return PyFloat_Check(obj) ||
56 PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types
57}
58
// Converts Python object `v` that should hold a Python string into a
// C++ string in *out. Returns nullptr on success, or a message on error.
// Defined below, but forward declared here for use in PyRepr.
62const char* ConvertOneString(PyObject* v, string* out);
63
64string PyRepr(PyObject* obj) {
65 if (obj == nullptr) {
66 return "<null>";
67 }
68 Safe_PyObjectPtr repr_obj = make_safe(PyObject_Repr(obj));
69 if (repr_obj) {
70 string repr_str;
71 if (ConvertOneString(repr_obj.get(), &repr_str) == nullptr) {
72 return repr_str;
73 }
74 }
75 return "<error computing repr()>";
76}
77
78bool IsPyDimension(PyObject* obj) {
79 const char* tp_name = obj->ob_type->tp_name;
80 if (strcmp(tp_name, "Dimension") != 0) return false;
81 bool ret = str_util::EndsWith(
82 PyRepr(PyType(obj)),
83 "tensorflow.python.framework.tensor_shape.Dimension'>");
84 return ret;
85}
86
87Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
88 std::vector<Safe_PyObjectPtr> refs_to_clean;
89 while (true) {
90 // We test strings first, in case a string is considered a sequence.
91 if (IsPyString(obj)) {
92 *dtype = DT_STRING;
93 } else if (PySequence_Check(obj)) {
94 auto length = PySequence_Length(obj);
95 if (length > 0) {
96 shape->AddDim(length);
97 obj = PySequence_GetItem(obj, 0);
98 refs_to_clean.push_back(make_safe(obj));
99 continue;
100 } else if (length == 0) {
101 shape->AddDim(length);
102 *dtype = DT_INVALID; // Invalid dtype for empty tensors.
103 } else {
104 // The sequence does not have a valid length (PySequence_Length < 0).
105 if (PyErr_Occurred()) {
106 // PySequence_Length failed and set an exception. Fetch the message
107 // and convert it to a failed status.
108 return errors::InvalidArgument(PyExceptionFetch());
109 } else {
110 // This is almost certainly dead code: PySequence_Length failed but
111 // did not set an exception.
112 return errors::InvalidArgument(
113 "Attempted to convert an invalid sequence to a Tensor.");
114 }
115 }
116 } else if (IsPyFloat(obj)) {
117 *dtype = DT_DOUBLE;
118 } else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) {
119 // Have to test for bool before int, since IsInt(True/False) == true.
120 *dtype = DT_BOOL;
121 } else if (IsPyInt(obj)) {
122 *dtype = DT_INT64;
123 } else if (IsPyDimension(obj)) {
124 *dtype = DT_INT64;
125 } else if (PyComplex_Check(obj) ||
126 PyIsInstance(obj, &PyComplexFloatingArrType_Type)) { // NumPy
127 *dtype = DT_COMPLEX128;
128 } else {
129 return errors::InvalidArgument("Attempt to convert a value (",
130 PyRepr(obj),
131 ") with an unsupported type (",
132 PyRepr(PyType(obj)), ") to a Tensor.");
133 }
134 return Status::OK();
135 }
136}
137
// Error messages
//
// The conversion functions below return one of these arrays (or nullptr on
// success). PySeqToTensor compares the returned pointer against
// ErrorFoundInt64 and ErrorFoundFloat to decide which fallback conversion to
// try, so each message must remain a distinct object.

// Generic conversion failure (e.g. a null element was encountered).
const char ErrorConverting[] =
    "Error while converting Python sequence to Tensor.";
// A nested sequence's length does not match the inferred shape, or a value
// expected to be a sequence is not one.
const char ErrorRectangular[] =
    "Can't convert non-rectangular Python sequence to Tensor.";
// An element's Python type is not accepted by the converter in use.
const char ErrorMixedTypes[] =
    "Can't convert Python sequence with mixed types to Tensor.";
// An integer does not fit in 64 bits.
const char ErrorOutOfRange[] =
    "Can't convert Python sequence with out-of-range integer to Tensor.";
// A Python long could not be converted to a double.
const char ErrorOutOfRangeDouble[] =
    "Can't convert Python sequence with a value out of range for a "
    "double-precision float.";
// A unicode object could not be encoded as UTF-8.
const char ErrorConvertingUnicodeString[] =
    "Error converting unicode string while converting Python sequence to "
    "Tensor.";
// Returned by ConvertOneInt32 when the value fits in 64 bits but not in 32.
const char ErrorFoundInt64[] =
    "Can't convert Python sequence with out-of-range integer to int32 Tensor.";
// Returned by the integer converters when they encounter a float element.
const char ErrorFoundFloat[] =
    "Can't convert Python sequence with floating point values to integer "
    "Tensor.";
159
// Template for defining a pair of functions that recursively convert a
// (possibly nested) Python sequence `obj` into a Tensor whose elements have
// C++ type TYPE (dtype TYPE_ENUM), using CONVERT to convert each leaf element:
//
//   const char* FUNCTION##Helper(PyObject* obj, const TensorShape& shape,
//                                TYPE** buf);
//     Writes the converted elements of `obj` through *buf, advancing *buf past
//     every element written. Requires shape.dims() >= 1.
//
//   const char* FUNCTION(PyObject* obj, const TensorShape& shape,
//                        Tensor* dest);
//     Builds a Tensor of the given shape (a scalar when shape.dims() == 0) and
//     assigns it to *dest. *dest is only assigned on success.
//
// Both return nullptr on success, or one of the Error* messages on failure.

#define DEFINE_HELPER(FUNCTION, TYPE, TYPE_ENUM, CONVERT) \
  const char* FUNCTION##Helper(PyObject* obj, const TensorShape& shape, \
                               TYPE** buf) { \
    if (TF_PREDICT_FALSE(obj == nullptr)) { \
      return ErrorConverting; \
    } \
    /* PySequence_Fast returns NULL when obj is not a sequence. This must */ \
    /* be checked at every nesting level before the size macro below      */ \
    /* dereferences the result.                                           */ \
    Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, "")); \
    if (TF_PREDICT_FALSE(seq == nullptr)) return ErrorRectangular; \
    const int64 s = shape.dim_size(0); \
    if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) { \
      return ErrorRectangular; \
    } \
    if (shape.dims() > 1) { \
      /* Iterate over outer dim, and recursively convert each element. */ \
      TensorShape rest = shape; \
      rest.RemoveDim(0); \
      for (int64 i = 0; i < s; ++i) { \
        const char* error = FUNCTION##Helper( \
            PySequence_Fast_GET_ITEM(seq.get(), i), rest, buf); \
        if (TF_PREDICT_FALSE(error != nullptr)) return error; \
      } \
    } else { \
      /* Innermost dim: convert each leaf element directly into *buf. */ \
      PyObject** l = PySequence_Fast_ITEMS(seq.get()); \
      for (int64 i = 0; i < s; ++i) { \
        const char* error = CONVERT(l[i], *buf); \
        if (TF_PREDICT_FALSE(error != nullptr)) return error; \
        ++*buf; \
      } \
    } \
    return nullptr; \
  } \
  const char* FUNCTION(PyObject* obj, const TensorShape& shape, \
                       Tensor* dest) { \
    /* TODO(josh11b): Allocator & attributes? */ \
    Tensor result(TYPE_ENUM, shape); \
    if (shape.dims() == 0) { /* Scalar case */ \
      TYPE value; \
      const char* error = CONVERT(obj, &value); \
      if (error != nullptr) return error; \
      result.scalar<TYPE>()() = value; \
    } else { \
      TYPE* buf = result.flat<TYPE>().data(); \
      const char* error = FUNCTION##Helper(obj, shape, &buf); \
      if (error != nullptr) return error; \
    } \
    *dest = result; \
    return nullptr; \
  }
217
218// Int support
219
220const char* ConvertOneInt64(PyObject* v, int64* out) {
221#if PY_MAJOR_VERSION < 3
222 if (TF_PREDICT_TRUE(PyInt_Check(v))) {
223 *out = PyInt_AS_LONG(v);
224 return nullptr;
225 }
226#endif
227 if (TF_PREDICT_TRUE(PyLong_Check(v) || IsPyDimension(v))) {
228 int overflow = 0;
229 // Have to use LongLong for 64 bits, since long is 32 bits on Windows.
230 *out = PyLong_AsLongLongAndOverflow(v, &overflow);
231 if (TF_PREDICT_FALSE(overflow)) return ErrorOutOfRange;
232 return nullptr;
233 }
234 if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers
235#if PY_MAJOR_VERSION < 3
236 Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v));
237#else
238 Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
239#endif
240 return ConvertOneInt64(as_int.get(), out);
241 }
242 if (IsPyFloat(v)) return ErrorFoundFloat;
243 return ErrorMixedTypes;
244}
245
246DEFINE_HELPER(ConvertInt64, int64, DT_INT64, ConvertOneInt64);
247
248const char* ConvertOneInt32(PyObject* v, int32* out) {
249 int64 i;
250#if PY_MAJOR_VERSION < 3
251 if (TF_PREDICT_TRUE(PyInt_Check(v))) {
252 i = PyInt_AS_LONG(v);
253 } else
254#endif
255 if (PyLong_Check(v) || IsPyDimension(v)) {
256 int overflow = 0;
257 // Have to use LongLong for 64 bits, since long is 32 bits on Windows.
258 i = PyLong_AsLongLongAndOverflow(v, &overflow);
259 if (TF_PREDICT_FALSE(overflow)) return ErrorOutOfRange;
260 } else if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers
261#if PY_MAJOR_VERSION < 3
262 Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v));
263#else
264 Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
265#endif
266 return ConvertOneInt32(as_int.get(), out);
267 } else if (IsPyFloat(v)) {
268 return ErrorFoundFloat;
269 } else {
270 return ErrorMixedTypes;
271 }
272 *out = static_cast<uint32>(static_cast<uint64>(i));
273 // Check for 32-bit overflow.
274 if (TF_PREDICT_FALSE(i != *out)) return ErrorFoundInt64;
275 return nullptr;
276}
277
278DEFINE_HELPER(ConvertInt32, int32, DT_INT32, ConvertOneInt32);
279
280// Floating-point support
281
282template <class T>
283const char* ConvertOneFloat(PyObject* v, T* out) {
284 if (TF_PREDICT_TRUE(PyFloat_Check(v))) {
285 *out = PyFloat_AS_DOUBLE(v);
286 return nullptr;
287 }
288#if PY_MAJOR_VERSION < 3
289 if (PyInt_Check(v)) {
290 *out = PyInt_AS_LONG(v);
291 return nullptr;
292 }
293#endif
294 if (PyLong_Check(v)) {
295 *out = PyLong_AsDouble(v);
296 if (PyErr_Occurred()) return ErrorOutOfRangeDouble;
297 return nullptr;
298 }
299 if (PyIsInstance(v, &PyFloatingArrType_Type)) { // NumPy float types
300 Safe_PyObjectPtr as_float = make_safe(PyNumber_Float(v));
301 return ConvertOneFloat<T>(as_float.get(), out);
302 }
303 if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers
304#if PY_MAJOR_VERSION < 3
305 Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v));
306#else
307 Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
308#endif
309 return ConvertOneFloat<T>(as_int.get(), out);
310 }
311 return ErrorMixedTypes;
312}
313
314DEFINE_HELPER(ConvertDouble, double, DT_DOUBLE, ConvertOneFloat<double>);
315DEFINE_HELPER(ConvertFloat, float, DT_FLOAT, ConvertOneFloat<float>);
316
317// String support
318
319const char* ConvertOneString(PyObject* v, string* out) {
320 if (PyBytes_Check(v)) {
321 out->assign(PyBytes_AS_STRING(v), PyBytes_GET_SIZE(v));
322 return nullptr;
323 }
324 if (PyUnicode_Check(v)) {
325#if PY_MAJOR_VERSION >= 3
326 Py_ssize_t size;
327 const char* str = PyUnicode_AsUTF8AndSize(v, &size);
328 if (str == nullptr) return ErrorConvertingUnicodeString;
329 out->assign(str, size);
330 return nullptr;
331#else
332 PyObject* py_str = PyUnicode_AsUTF8String(v);
333 if (py_str == nullptr) return ErrorConvertingUnicodeString;
334 out->assign(PyBytes_AS_STRING(py_str), PyBytes_GET_SIZE(py_str));
335 Py_DECREF(py_str);
336 return nullptr;
337#endif
338 }
339 return ErrorMixedTypes;
340}
341
342DEFINE_HELPER(ConvertString, string, DT_STRING, ConvertOneString);
343
344// Complex support
345
346const char* ConvertOneComplex(PyObject* v, complex128* out) {
347 if (PyComplex_Check(v)) {
348 *out = complex128(PyComplex_RealAsDouble(v), PyComplex_ImagAsDouble(v));
349 return nullptr;
350 } else if (PyIsInstance(v, &PyComplexFloatingArrType_Type)) { // NumPy
351 auto as_complex = PyComplex_AsCComplex(v);
352 *out = complex128(as_complex.real, as_complex.imag);
353 return nullptr;
354 }
355 return ErrorMixedTypes;
356}
357
358DEFINE_HELPER(ConvertComplex, complex128, DT_COMPLEX128, ConvertOneComplex);
359
360// Bool support
361
362const char* ConvertOneBool(PyObject* v, bool* out) {
363 if (v == Py_True) {
364 *out = true;
365 } else if (v == Py_False) {
366 *out = false;
367 } else if (PyIsInstance(v, &PyBoolArrType_Type)) { // NumPy
368 *out = PyObject_IsTrue(v);
369 } else {
370 return ErrorMixedTypes;
371 }
372 return nullptr;
373}
374
375DEFINE_HELPER(ConvertBool, bool, DT_BOOL, ConvertOneBool);
376
377#undef DEFINE_HELPER
378
379} // namespace
380
// Evaluates the expression (which must yield a `const char*` error message or
// nullptr) exactly once and returns from the enclosing function:
// Status::OK() for nullptr, otherwise an InvalidArgument status carrying the
// message.
#define RETURN_STRING_AS_STATUS(...)                             \
  do {                                                           \
    const char* _error = (__VA_ARGS__);                          \
    if (TF_PREDICT_FALSE(_error != nullptr)) {                   \
      return errors::InvalidArgument(_error);                    \
    }                                                            \
    return Status::OK();                                         \
  } while (0)
387
388Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
389 DataType infer_dtype;
390 TensorShape shape;
391 TF_RETURN_IF_ERROR(InferShapeAndType(obj, &shape, &infer_dtype));
392 DataType requested_dtype = DT_INVALID;
393 if (dtype != Py_None) {
394 int32 dtype_as_int = -1;
395 if (ConvertOneInt32(dtype, &dtype_as_int) == nullptr) {
396 requested_dtype = static_cast<DataType>(dtype_as_int);
397 }
398 }
399 // NOTE(josh11b): If don't successfully convert to the requested type,
400 // we just try instead to create a tensor of the inferred type and
401 // let the caller convert it to the requested type using a cast
402 // operation.
403 switch (requested_dtype) {
404 case DT_FLOAT:
405 if (ConvertFloat(obj, shape, ret) == nullptr) return Status::OK();
406 break;
407
408 case DT_DOUBLE:
409 if (ConvertDouble(obj, shape, ret) == nullptr) return Status::OK();
410 break;
411
412 case DT_INT64:
413 if (ConvertInt64(obj, shape, ret) == nullptr) return Status::OK();
414 break;
415
416 case DT_INT32:
417 if (ConvertInt32(obj, shape, ret) == nullptr) return Status::OK();
418 break;
419
420 case DT_COMPLEX128:
421 if (ConvertComplex(obj, shape, ret) == nullptr) return Status::OK();
422 break;
423
424 case DT_STRING:
425 if (ConvertString(obj, shape, ret) == nullptr) return Status::OK();
426 break;
427
428 case DT_BOOL:
429 if (ConvertBool(obj, shape, ret) == nullptr) return Status::OK();
430 break;
431
432 default:
433 break;
434 }
435 switch (infer_dtype) {
436 case DT_DOUBLE:
437 // TODO(josh11b): Handle mixed floats and complex numbers?
438 if (requested_dtype == DT_INVALID) {
439 // TensorFlow uses float32s to represent floating point numbers
440 // by default (for space and speed over using doubles).
441 RETURN_STRING_AS_STATUS(ConvertFloat(obj, shape, ret));
442 } else {
443 // We are going to do a cast to the user's requested dtype
444 // after this. We use doubles for this intermediate result so
445 // we don't lose precision that might be representable in the
446 // final type.
447 RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
448 }
449
450 case DT_INT64:
451 if (requested_dtype == DT_INVALID) {
452 const char* error = ConvertInt32(obj, shape, ret);
453 if (error == ErrorFoundInt64) {
454 error = ConvertInt64(obj, shape, ret);
455 }
456 if (error == ErrorFoundFloat) {
457 error = ConvertFloat(obj, shape, ret);
458 }
459 // TODO(josh11b): May also want to fall back to using doubles if
460 // error == ErrorOutOfRange?
461 RETURN_STRING_AS_STATUS(error);
462 } else {
463 const char* error = ConvertInt64(obj, shape, ret);
464 if (error == ErrorFoundFloat) {
465 error = ConvertDouble(obj, shape, ret);
466 }
467 RETURN_STRING_AS_STATUS(error);
468 }
469
470 case DT_STRING:
471 RETURN_STRING_AS_STATUS(ConvertString(obj, shape, ret));
472
473 case DT_COMPLEX128:
474 RETURN_STRING_AS_STATUS(ConvertComplex(obj, shape, ret));
475
476 case DT_BOOL:
477 RETURN_STRING_AS_STATUS(ConvertBool(obj, shape, ret));
478
479 case DT_INVALID: // Only occurs for empty tensors.
480 *ret = Tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
481 shape);
482 return Status::OK();
483
484 default:
485 return errors::Unimplemented("Missing Python -> Tensor conversion for ",
486 DataTypeString(infer_dtype));
487 }
488
489 return Status::OK();
490}
491
492} // namespace tensorflow
493