Adjust error message of tf.debugging.assert_type

This PR tries to address the issue raised in 45975 where
the error message of tf.debugging.assert_type could be
misleading when the tf_type is not passed with a DType.

This PR adds additional check so that tf_type arg can be guarded
if non-DType value (e.g., list, tuple etc) is passed.

This PR fixes 45975.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-12-28 13:53:11 +00:00
parent 8fe56f0807
commit 73c5ff37f1
2 changed files with 9 additions and 0 deletions

View File

@ -1557,6 +1557,14 @@ class AssertTypeTest(test.TestCase):
with self.assertRaisesRegexp(TypeError, "must be of type.*float32"):
check_ops.assert_type(sparse_float16, dtypes.float32)
def test_raise_when_tf_type_is_not_dtype(self):
# Test case for GitHub issue:
# https://github.com/tensorflow/tensorflow/issues/45975
value = constant_op.constant(0.0)
with self.assertRaisesRegexp(
TypeError, "Cannot convert.*to a TensorFlow DType"):
check_ops.assert_type(value, (dtypes.float32,))
class AssertShapesTest(test.TestCase):

View File

@ -1552,6 +1552,7 @@ def assert_type(tensor, tf_type, message=None, name=None):
A `no_op` that does nothing. Type can be determined statically.
"""
message = message or ''
tf_type = dtypes.as_dtype(tf_type)
with ops.name_scope(name, 'assert_type', [tensor]):
if not isinstance(tensor, sparse_tensor.SparseTensor):
tensor = ops.convert_to_tensor(tensor, name='tensor')