Simplify trt_convert_test

This commit is contained in:
Guangda Lai 2019-08-16 23:57:33 -07:00
parent 4961524653
commit 61de424923

View File

@ -30,7 +30,6 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.eager import def_function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
@ -361,15 +360,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
converter.convert()
# Verify the converted GraphDef and ConcreteFunction.
@def_function.function
def wrapper_converted_func(*args, **kwargs):
return converter._converted_func(*args, **kwargs)
converted_func = wrapper_converted_func
self.assertIsInstance(converted_func, def_function.Function)
converted_concrete_func = converted_func.get_concrete_function(
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
self._CheckTrtOps(converted_concrete_func)
self._CheckTrtOps(converter._converted_func)
# Save the converted model without any TRT engine cache.
output_saved_model_dir = self.mkdtemp()
@ -443,15 +434,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
# Verify the converted GraphDef and ConcreteFunction.
@def_function.function
def wrapper_converted_func(*args, **kwargs):
return converter._converted_func(*args, **kwargs)
converted_func = wrapper_converted_func
self.assertIsInstance(converted_func, def_function.Function)
converted_concrete_func = converted_func.get_concrete_function(
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
self._CheckTrtOps(converted_concrete_func, _CheckFn)
self._CheckTrtOps(converter._converted_func, _CheckFn)
# Save the converted model with the statically-built engine inlined.
output_saved_model_dir = self.mkdtemp()