Simplify trt_convert_test
This commit is contained in:
parent
4961524653
commit
61de424923
@ -30,7 +30,6 @@ from tensorflow.core.protobuf import config_pb2
|
|||||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
from tensorflow.python.compiler.tensorrt import trt_convert
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import wrap_function
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import graph_util
|
from tensorflow.python.framework import graph_util
|
||||||
@ -361,15 +360,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
converter.convert()
|
converter.convert()
|
||||||
|
|
||||||
# Verify the converted GraphDef and ConcreteFunction.
|
# Verify the converted GraphDef and ConcreteFunction.
|
||||||
@def_function.function
|
self._CheckTrtOps(converter._converted_func)
|
||||||
def wrapper_converted_func(*args, **kwargs):
|
|
||||||
return converter._converted_func(*args, **kwargs)
|
|
||||||
converted_func = wrapper_converted_func
|
|
||||||
self.assertIsInstance(converted_func, def_function.Function)
|
|
||||||
converted_concrete_func = converted_func.get_concrete_function(
|
|
||||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
|
|
||||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
|
||||||
self._CheckTrtOps(converted_concrete_func)
|
|
||||||
|
|
||||||
# Save the converted model without any TRT engine cache.
|
# Save the converted model without any TRT engine cache.
|
||||||
output_saved_model_dir = self.mkdtemp()
|
output_saved_model_dir = self.mkdtemp()
|
||||||
@ -443,15 +434,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
||||||
|
|
||||||
# Verify the converted GraphDef and ConcreteFunction.
|
# Verify the converted GraphDef and ConcreteFunction.
|
||||||
@def_function.function
|
self._CheckTrtOps(converter._converted_func, _CheckFn)
|
||||||
def wrapper_converted_func(*args, **kwargs):
|
|
||||||
return converter._converted_func(*args, **kwargs)
|
|
||||||
converted_func = wrapper_converted_func
|
|
||||||
self.assertIsInstance(converted_func, def_function.Function)
|
|
||||||
converted_concrete_func = converted_func.get_concrete_function(
|
|
||||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
|
|
||||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
|
||||||
self._CheckTrtOps(converted_concrete_func, _CheckFn)
|
|
||||||
|
|
||||||
# Save the converted model with the statically-built engine inlined.
|
# Save the converted model with the statically-built engine inlined.
|
||||||
output_saved_model_dir = self.mkdtemp()
|
output_saved_model_dir = self.mkdtemp()
|
||||||
|
Loading…
Reference in New Issue
Block a user