Merge pull request #46013 from yongtang:45975-tf.debugging.assert_type
PiperOrigin-RevId: 349874984 Change-Id: I6c7ad9ac9841e21d04b70afb7ed6fbb575086fd8
This commit is contained in:
commit
7172b537f1
@ -1557,6 +1557,14 @@ class AssertTypeTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(TypeError, "must be of type.*float32"):
|
||||
check_ops.assert_type(sparse_float16, dtypes.float32)
|
||||
|
||||
def test_raise_when_tf_type_is_not_dtype(self):
|
||||
# Test case for GitHub issue:
|
||||
# https://github.com/tensorflow/tensorflow/issues/45975
|
||||
value = constant_op.constant(0.0)
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
"Cannot convert.*to a TensorFlow DType"):
|
||||
check_ops.assert_type(value, (dtypes.float32,))
|
||||
|
||||
|
||||
class AssertShapesTest(test.TestCase):
|
||||
|
||||
|
@ -1552,6 +1552,7 @@ def assert_type(tensor, tf_type, message=None, name=None):
|
||||
A `no_op` that does nothing. Type can be determined statically.
|
||||
"""
|
||||
message = message or ''
|
||||
tf_type = dtypes.as_dtype(tf_type)
|
||||
with ops.name_scope(name, 'assert_type', [tensor]):
|
||||
if not isinstance(tensor, sparse_tensor.SparseTensor):
|
||||
tensor = ops.convert_to_tensor(tensor, name='tensor')
|
||||
|
Loading…
Reference in New Issue
Block a user