When converting a numpy float64 to an EagerTensor, always ensure that it

becomes a float64 tensor.

Earlier py_seq_tensor would fall back to a float32 if not explicitly requesting
a float64 (which would not happen if we had no other information).

PiperOrigin-RevId: 197977260
This commit is contained in:
Akshay Modi 2018-05-24 16:53:33 -07:00 committed by TensorFlower Gardener
parent b9b93e90eb
commit 9532bbbf99
2 changed files with 15 additions and 3 deletions

View File

@ -292,6 +292,11 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
def testUnicode(self): def testUnicode(self):
self.assertEqual(constant_op.constant(u"asdf").numpy(), b"asdf") self.assertEqual(constant_op.constant(u"asdf").numpy(), b"asdf")
def testFloatTensor(self):
self.assertEqual(dtypes.float64, _create_tensor(np.float64()).dtype)
self.assertEqual(dtypes.float32, _create_tensor(np.float32()).dtype)
self.assertEqual(dtypes.float32, _create_tensor(0.0).dtype)
def testSliceDimOutOfRange(self): def testSliceDimOutOfRange(self):
t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32) t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32)
t2 = _create_tensor([1, 2], dtype=dtypes.int32) t2 = _create_tensor([1, 2], dtype=dtypes.int32)

View File

@ -51,6 +51,10 @@ bool IsPyInt(PyObject* obj) {
#endif #endif
} }
bool IsPyDouble(PyObject* obj) {
return PyIsInstance(obj, &PyDoubleArrType_Type); // NumPy double type.
}
bool IsPyFloat(PyObject* obj) { bool IsPyFloat(PyObject* obj) {
return PyFloat_Check(obj) || return PyFloat_Check(obj) ||
PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types
@ -113,8 +117,10 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
"Attempted to convert an invalid sequence to a Tensor."); "Attempted to convert an invalid sequence to a Tensor.");
} }
} }
} else if (IsPyFloat(obj)) { } else if (IsPyDouble(obj)) {
*dtype = DT_DOUBLE; *dtype = DT_DOUBLE;
} else if (IsPyFloat(obj)) {
*dtype = DT_FLOAT;
} else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) { } else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) {
// Have to test for bool before int, since IsInt(True/False) == true. // Have to test for bool before int, since IsInt(True/False) == true.
*dtype = DT_BOOL; *dtype = DT_BOOL;
@ -433,7 +439,7 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
break; break;
} }
switch (infer_dtype) { switch (infer_dtype) {
case DT_DOUBLE: case DT_FLOAT:
// TODO(josh11b): Handle mixed floats and complex numbers? // TODO(josh11b): Handle mixed floats and complex numbers?
if (requested_dtype == DT_INVALID) { if (requested_dtype == DT_INVALID) {
// TensorFlow uses float32s to represent floating point numbers // TensorFlow uses float32s to represent floating point numbers
@ -446,7 +452,8 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
// final type. // final type.
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret)); RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
} }
case DT_DOUBLE:
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
case DT_INT64: case DT_INT64:
if (requested_dtype == DT_INVALID) { if (requested_dtype == DT_INVALID) {
const char* error = ConvertInt32(obj, shape, ret); const char* error = ConvertInt32(obj, shape, ret);