Change so that error message received if attempting to do + between a python literal and tensor with incompatible dtype to come from the op

PiperOrigin-RevId: 228541386
This commit is contained in:
Tamara Norman 2019-01-09 10:23:54 -08:00 committed by TensorFlower Gardener
parent 1ee193a256
commit 2da8e8c603
2 changed files with 20 additions and 1 deletions

View File

@ -813,7 +813,8 @@ def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
return func(x, y, name=name)
elif not isinstance(y, sparse_tensor.SparseTensor):
try:
y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
y = ops.convert_to_tensor_v2(y, dtype_hint=x.dtype.base_dtype,
name="y")
except TypeError:
# If the RHS is not a tensor, it might be a tensor aware object
# that can implement the operator with knowledge of itself

View File

@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@ -664,5 +665,22 @@ class NextAfterTest(test_util.TensorFlowTestCase):
self.assertAllEqual(math_ops.nextafter(one, two) - one, eps_const)
class BinaryOpsTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes
def testErrorReceivedIfDtypeMismatchFromOp(self):
if context.executing_eagerly():
error = errors_impl.InvalidArgumentError
error_message = (
r"cannot compute Add as input #0\(zero-based\) was expected to be a "
r"float tensor but is a int32 tensor \[Op:Add\] name: add/")
else:
error = TypeError
error_message = ("Input 'y' of 'Add' Op has type float32 that does not "
"match type int32 of argument 'x'.")
with self.assertRaisesRegexp(error, error_message):
a = array_ops.ones([1], dtype=dtypes.int32) + 1.0
self.evaluate(a)
if __name__ == "__main__":
googletest.main()