Respect np scalars dtype when converting to tensors.
Before this change: Eager mode would always try to infer a dtype and convert it to int32 (since TF prefers that), but graph would use the numpy dtype directly. Eager would do this even if converting lists of scalars, but graph would downcast. After this change: Eager and graph behave the same. tf.convert_to_tensor(np.int64(1)).dtype == tf.int64 tf.convert_to_tensor([np.int64(1)]).dtype == tf.int32 PiperOrigin-RevId: 223823113
This commit is contained in:
parent
3360c72f32
commit
3b905b0e00
@ -220,6 +220,14 @@ TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
tensorflow::Safe_PyObjectPtr value_decrefer;
|
||||
if (PyArray_CheckAnyScalarExact(value)) {
|
||||
// Convert numpy scalars to numpy arrays.
|
||||
value = PyArray_FromScalar(value, nullptr);
|
||||
// The returned value needs to be DECREF'd, but the original value was
|
||||
// created in python code, and doesn't need to be DECREF'd.
|
||||
value_decrefer.reset(value);
|
||||
}
|
||||
if (PyArray_Check(value)) {
|
||||
int desired_np_dtype = -1;
|
||||
if (desired_dtype >= 0) {
|
||||
|
@ -95,6 +95,18 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
t = _create_tensor(values)
|
||||
self.assertAllEqual(values, t)
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testNumpyDtypeSurvivesThroughTensorConversion(self):
|
||||
scalar_creators = [np.int32, np.int64, np.float32, np.float64]
|
||||
conversion_functions = [ops.convert_to_tensor, constant_op.constant]
|
||||
|
||||
for scalar_creator in scalar_creators:
|
||||
for conversion_function in conversion_functions:
|
||||
np_val = scalar_creator(3)
|
||||
tensor_val = conversion_function(np_val)
|
||||
self.assertEqual(tensor_val.numpy().dtype, np_val.dtype)
|
||||
self.assertEqual(tensor_val.numpy(), np_val)
|
||||
|
||||
def testNumpyValueWithCast(self):
|
||||
values = np.array([3.0], dtype=np.float32)
|
||||
t = _create_tensor(values, dtype=dtypes.float64)
|
||||
|
Loading…
Reference in New Issue
Block a user