Change so that error message received if attempting to do + between a python literal and tensor with incompatible dtype to come from the op
PiperOrigin-RevId: 228541386
This commit is contained in:
parent
1ee193a256
commit
2da8e8c603
@ -813,7 +813,8 @@ def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
|
||||
return func(x, y, name=name)
|
||||
elif not isinstance(y, sparse_tensor.SparseTensor):
|
||||
try:
|
||||
y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
|
||||
y = ops.convert_to_tensor_v2(y, dtype_hint=x.dtype.base_dtype,
|
||||
name="y")
|
||||
except TypeError:
|
||||
# If the RHS is not a tensor, it might be a tensor aware object
|
||||
# that can implement the operator with knowledge of itself
|
||||
|
@ -22,6 +22,7 @@ import numpy as np
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -664,5 +665,22 @@ class NextAfterTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual(math_ops.nextafter(one, two) - one, eps_const)
|
||||
|
||||
|
||||
class BinaryOpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testErrorReceivedIfDtypeMismatchFromOp(self):
|
||||
if context.executing_eagerly():
|
||||
error = errors_impl.InvalidArgumentError
|
||||
error_message = (
|
||||
r"cannot compute Add as input #0\(zero-based\) was expected to be a "
|
||||
r"float tensor but is a int32 tensor \[Op:Add\] name: add/")
|
||||
else:
|
||||
error = TypeError
|
||||
error_message = ("Input 'y' of 'Add' Op has type float32 that does not "
|
||||
"match type int32 of argument 'x'.")
|
||||
with self.assertRaisesRegexp(error, error_message):
|
||||
a = array_ops.ones([1], dtype=dtypes.int32) + 1.0
|
||||
self.evaluate(a)
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user