From 2da8e8c60343d9e833f1c591c54ee2a27c240842 Mon Sep 17 00:00:00 2001 From: Tamara Norman Date: Wed, 9 Jan 2019 10:23:54 -0800 Subject: [PATCH] Change so that error message received if attempting to do + between a python literal and tensor with incompatible dtype to come from the op PiperOrigin-RevId: 228541386 --- tensorflow/python/ops/math_ops.py | 3 ++- tensorflow/python/ops/math_ops_test.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 5bccf5493f3..248d0925384 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -813,7 +813,8 @@ def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor): return func(x, y, name=name) elif not isinstance(y, sparse_tensor.SparseTensor): try: - y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y") + y = ops.convert_to_tensor_v2(y, dtype_hint=x.dtype.base_dtype, + name="y") except TypeError: # If the RHS is not a tensor, it might be a tensor aware object # that can implement the operator with knowledge of itself diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index b27cf7208c3..b4832e09c08 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -664,5 +665,22 @@ class NextAfterTest(test_util.TensorFlowTestCase): self.assertAllEqual(math_ops.nextafter(one, two) - one, eps_const) +class BinaryOpsTest(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes + def testErrorReceivedIfDtypeMismatchFromOp(self): + if context.executing_eagerly(): + error = errors_impl.InvalidArgumentError + error_message = ( + r"cannot compute Add as input #0\(zero-based\) was expected to be a " + r"float tensor but is a int32 tensor \[Op:Add\] name: add/") + else: + error = TypeError + error_message = ("Input 'y' of 'Add' Op has type float32 that does not " + "match type int32 of argument 'x'.") + with self.assertRaisesRegexp(error, error_message): + a = array_ops.ones([1], dtype=dtypes.int32) + 1.0 + self.evaluate(a) + if __name__ == "__main__": googletest.main()