When converting a numpy float64 to an EagerTensor, always ensure that it
becomes a float64 tensor. Earlier py_seq_tensor would fall back to a float32 if not explicitly requesting a float64 (which would not happen if we had no other information). PiperOrigin-RevId: 197977260
This commit is contained in:
parent
b9b93e90eb
commit
9532bbbf99
@ -292,6 +292,11 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
def testUnicode(self):
|
||||
self.assertEqual(constant_op.constant(u"asdf").numpy(), b"asdf")
|
||||
|
||||
def testFloatTensor(self):
|
||||
self.assertEqual(dtypes.float64, _create_tensor(np.float64()).dtype)
|
||||
self.assertEqual(dtypes.float32, _create_tensor(np.float32()).dtype)
|
||||
self.assertEqual(dtypes.float32, _create_tensor(0.0).dtype)
|
||||
|
||||
def testSliceDimOutOfRange(self):
|
||||
t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32)
|
||||
t2 = _create_tensor([1, 2], dtype=dtypes.int32)
|
||||
|
@ -51,6 +51,10 @@ bool IsPyInt(PyObject* obj) {
|
||||
#endif
|
||||
}
|
||||
|
||||
bool IsPyDouble(PyObject* obj) {
|
||||
return PyIsInstance(obj, &PyDoubleArrType_Type); // NumPy double type.
|
||||
}
|
||||
|
||||
bool IsPyFloat(PyObject* obj) {
|
||||
return PyFloat_Check(obj) ||
|
||||
PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types
|
||||
@ -113,8 +117,10 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
|
||||
"Attempted to convert an invalid sequence to a Tensor.");
|
||||
}
|
||||
}
|
||||
} else if (IsPyFloat(obj)) {
|
||||
} else if (IsPyDouble(obj)) {
|
||||
*dtype = DT_DOUBLE;
|
||||
} else if (IsPyFloat(obj)) {
|
||||
*dtype = DT_FLOAT;
|
||||
} else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) {
|
||||
// Have to test for bool before int, since IsInt(True/False) == true.
|
||||
*dtype = DT_BOOL;
|
||||
@ -433,7 +439,7 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
|
||||
break;
|
||||
}
|
||||
switch (infer_dtype) {
|
||||
case DT_DOUBLE:
|
||||
case DT_FLOAT:
|
||||
// TODO(josh11b): Handle mixed floats and complex numbers?
|
||||
if (requested_dtype == DT_INVALID) {
|
||||
// TensorFlow uses float32s to represent floating point numbers
|
||||
@ -446,7 +452,8 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
|
||||
// final type.
|
||||
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
|
||||
}
|
||||
|
||||
case DT_DOUBLE:
|
||||
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
|
||||
case DT_INT64:
|
||||
if (requested_dtype == DT_INVALID) {
|
||||
const char* error = ConvertInt32(obj, shape, ret);
|
||||
|
Loading…
Reference in New Issue
Block a user