diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py index bc8ca98cf3d..a1a55acb19d 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py @@ -30,7 +30,6 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.compiler.tensorrt import trt_convert from tensorflow.python.eager import def_function -from tensorflow.python.eager import wrap_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import graph_util @@ -361,15 +360,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): converter.convert() # Verify the converted GraphDef and ConcreteFunction. - @def_function.function - def wrapper_converted_func(*args, **kwargs): - return converter._converted_func(*args, **kwargs) - converted_func = wrapper_converted_func - self.assertIsInstance(converted_func, def_function.Function) - converted_concrete_func = converted_func.get_concrete_function( - tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)) - self._CheckTrtOps(converted_concrete_func) + self._CheckTrtOps(converter._converted_func) # Save the converted model without any TRT engine cache. output_saved_model_dir = self.mkdtemp() @@ -443,15 +434,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): self.assertTrue(len(node.attr["serialized_segment"].s), node.name) # Verify the converted GraphDef and ConcreteFunction. - @def_function.function - def wrapper_converted_func(*args, **kwargs): - return converter._converted_func(*args, **kwargs) - converted_func = wrapper_converted_func - self.assertIsInstance(converted_func, def_function.Function) - converted_concrete_func = converted_func.get_concrete_function( - tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)) - self._CheckTrtOps(converted_concrete_func, _CheckFn) + self._CheckTrtOps(converter._converted_func, _CheckFn) # Save the converted model with the statically-built engine inlined. output_saved_model_dir = self.mkdtemp()