From 324c3394864b7d1aa37cf8ac9598f094d768d2be Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 6 Feb 2019 11:57:44 -0800 Subject: [PATCH] Allows tf.convert_to_tensor to work with zero-dimensional numpy arrays PiperOrigin-RevId: 232717709 --- tensorflow/python/eager/tensor_test.py | 19 +++++++++++++++- tensorflow/python/lib/core/py_seq_tensor.cc | 25 +++++++++++++++++++-- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 0ee2ff68c20..0d8845bd96f 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -339,6 +339,24 @@ class TFETensorTest(test_util.TensorFlowTestCase): def testConvertToTensorAllowsOverflow(self): _ = ops.convert_to_tensor(123456789, dtype=dtypes.uint8) + @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.run_in_graph_and_eager_modes + def testConvertToTensorNumpyZeroDim(self): + for np_type, dtype in [(np.int32, dtypes.int32), + (np.half, dtypes.half), + (np.float32, dtypes.float32)]: + x = ops.convert_to_tensor([np.array(65, dtype=np_type), + np.array(16, dtype=np_type)]) + self.assertEqual(x.dtype, dtype) + self.assertAllEqual(x, [65, 16]) + + @test_util.assert_no_new_pyobjects_executing_eagerly + @test_util.run_in_graph_and_eager_modes + def testConvertToTensorNumpyScalar(self): + x = ops.convert_to_tensor([np.asscalar(np.array(321, dtype=np.int)), + np.asscalar(np.array(16, dtype=np.int))]) + self.assertAllEqual(x, [321, 16]) + def testEagerTensorError(self): with self.assertRaisesRegexp( TypeError, @@ -347,7 +365,6 @@ class TFETensorTest(test_util.TensorFlowTestCase): _ = ops.convert_to_tensor(1., dtype=dtypes.int32) - class TFETensorUtilTest(test_util.TensorFlowTestCase): def testListOfThree(self): diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index f681cff6cff..6cdf7e7163a 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -64,6 +64,19 @@ bool IsPyFloat(PyObject* obj) { PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types } +// If the input is a zero dimensional PyArray return it converted to a scalar. +// Otherwise return the input and increment its reference count. +// Users must Py_DECREF the output of this method. +PyObject* ZeroDimArrayToScalar(PyObject* obj) { + if (PyArray_IsZeroDim(obj) && !PyArray_IsScalar(obj, Generic)) { + auto pyarray_obj = reinterpret_cast(obj); + obj = PyArray_ToScalar(PyArray_DATA(pyarray_obj), pyarray_obj); + } else { + Py_INCREF(obj); + } + return obj; +} + // Converts Python object `c` that should hold a Python string into a // C++ string in *out. Returns nullptr on success, or a message on error. // Defined below, but forward declared here for use in PyRepr. @@ -130,6 +143,10 @@ Status SampleElementFromSequence(PyObject* seq, PyObject** elem) { Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { std::vector refs_to_clean; while (true) { + // Convert any zero dimensional numpy arrays to scalars first of all. + // We also have to make sure a reference to the safe_obj is kept. + obj = ZeroDimArrayToScalar(obj); + refs_to_clean.push_back(make_safe(obj)); // We test strings first, in case a string is considered a sequence. if (IsPyString(obj)) { *dtype = DT_STRING; @@ -240,7 +257,9 @@ const char ErrorFoundFloat[] = } \ PyObject** l = PySequence_Fast_ITEMS(seq.get()); \ for (int64 i = 0; i < s; ++i) { \ - const char* error = CONVERT(l[i], *buf); \ + auto scalar = ZeroDimArrayToScalar(l[i]); \ + const char* error = CONVERT(scalar, *buf); \ + Py_DECREF(scalar); \ if (TF_PREDICT_FALSE(error != nullptr)) return error; \ ++*buf; \ } \ @@ -253,7 +272,9 @@ const char ErrorFoundFloat[] = Tensor result(TYPE_ENUM, shape); \ if (shape.dims() == 0) { /* Scalar case */ \ TYPE value; \ - const char* error = CONVERT(obj, &value); \ + auto scalar = ZeroDimArrayToScalar(obj); \ + const char* error = CONVERT(scalar, &value); \ + Py_DECREF(scalar); \ if (error != nullptr) return error; \ result.scalar()() = value; \ } else { \