diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 9cb9fb490bc..99f98564a88 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -1557,6 +1557,14 @@ class AssertTypeTest(test.TestCase): with self.assertRaisesRegexp(TypeError, "must be of type.*float32"): check_ops.assert_type(sparse_float16, dtypes.float32) + def test_raise_when_tf_type_is_not_dtype(self): + # Test case for GitHub issue: + # https://github.com/tensorflow/tensorflow/issues/45975 + value = constant_op.constant(0.0) + with self.assertRaisesRegexp( + TypeError, "Cannot convert.*to a TensorFlow DType"): + check_ops.assert_type(value, (dtypes.float32,)) + class AssertShapesTest(test.TestCase): diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index f920092fd7f..9b5ff4d8d9b 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1552,6 +1552,7 @@ def assert_type(tensor, tf_type, message=None, name=None): A `no_op` that does nothing. Type can be determined statically. """ message = message or '' + tf_type = dtypes.as_dtype(tf_type) with ops.name_scope(name, 'assert_type', [tensor]): if not isinstance(tensor, sparse_tensor.SparseTensor): tensor = ops.convert_to_tensor(tensor, name='tensor')