From 9532bbbf994df5de1fa6550f3cf9f4dc08fcd640 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Thu, 24 May 2018 16:53:33 -0700 Subject: [PATCH] When converting a numpy float64 to an EagerTensor, always ensure that it becomes a float64 tensor. Earlier py_seq_tensor would fall back to a float32 if not explicitly requesting a float64 (which would not happen if we had no other information). PiperOrigin-RevId: 197977260 --- tensorflow/python/eager/tensor_test.py | 5 +++++ tensorflow/python/lib/core/py_seq_tensor.cc | 13 ++++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index b044b302316..626a4eb1eee 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -292,6 +292,11 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): def testUnicode(self): self.assertEqual(constant_op.constant(u"asdf").numpy(), b"asdf") + def testFloatTensor(self): + self.assertEqual(dtypes.float64, _create_tensor(np.float64()).dtype) + self.assertEqual(dtypes.float32, _create_tensor(np.float32()).dtype) + self.assertEqual(dtypes.float32, _create_tensor(0.0).dtype) + def testSliceDimOutOfRange(self): t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32) t2 = _create_tensor([1, 2], dtype=dtypes.int32) diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 32ea737a990..386be35ba2f 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -51,6 +51,10 @@ bool IsPyInt(PyObject* obj) { #endif } +bool IsPyDouble(PyObject* obj) { + return PyIsInstance(obj, &PyDoubleArrType_Type); // NumPy double type. +} + bool IsPyFloat(PyObject* obj) { return PyFloat_Check(obj) || PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types @@ -113,8 +117,10 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { "Attempted to convert an invalid sequence to a Tensor."); } } - } else if (IsPyFloat(obj)) { + } else if (IsPyDouble(obj)) { *dtype = DT_DOUBLE; + } else if (IsPyFloat(obj)) { + *dtype = DT_FLOAT; } else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) { // Have to test for bool before int, since IsInt(True/False) == true. *dtype = DT_BOOL; @@ -433,7 +439,7 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) { break; } switch (infer_dtype) { - case DT_DOUBLE: + case DT_FLOAT: // TODO(josh11b): Handle mixed floats and complex numbers? if (requested_dtype == DT_INVALID) { // TensorFlow uses float32s to represent floating point numbers @@ -446,7 +452,8 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) { // final type. RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret)); } - + case DT_DOUBLE: + RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret)); case DT_INT64: if (requested_dtype == DT_INVALID) { const char* error = ConvertInt32(obj, shape, ret);