Merge pull request #46013 from yongtang:45975-tf.debugging.assert_type

PiperOrigin-RevId: 349874984
Change-Id: I6c7ad9ac9841e21d04b70afb7ed6fbb575086fd8
This commit is contained in:
TensorFlower Gardener 2021-01-02 15:52:22 -08:00
commit 7172b537f1
2 changed files with 9 additions and 0 deletions

View File

@ -1557,6 +1557,14 @@ class AssertTypeTest(test.TestCase):
with self.assertRaisesRegexp(TypeError, "must be of type.*float32"):
check_ops.assert_type(sparse_float16, dtypes.float32)
def test_raise_when_tf_type_is_not_dtype(self):
# Test case for GitHub issue:
# https://github.com/tensorflow/tensorflow/issues/45975
value = constant_op.constant(0.0)
with self.assertRaisesRegexp(TypeError,
"Cannot convert.*to a TensorFlow DType"):
check_ops.assert_type(value, (dtypes.float32,))
class AssertShapesTest(test.TestCase):

View File

@ -1552,6 +1552,7 @@ def assert_type(tensor, tf_type, message=None, name=None):
A `no_op` that does nothing. Type can be determined statically.
"""
message = message or ''
tf_type = dtypes.as_dtype(tf_type)
with ops.name_scope(name, 'assert_type', [tensor]):
if not isinstance(tensor, sparse_tensor.SparseTensor):
tensor = ops.convert_to_tensor(tensor, name='tensor')