Simplify trt_convert_test
This commit is contained in:
parent
4961524653
commit
61de424923
@ -30,7 +30,6 @@ from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import wrap_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import graph_util
|
||||
@ -361,15 +360,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
converter.convert()
|
||||
|
||||
# Verify the converted GraphDef and ConcreteFunction.
|
||||
@def_function.function
|
||||
def wrapper_converted_func(*args, **kwargs):
|
||||
return converter._converted_func(*args, **kwargs)
|
||||
converted_func = wrapper_converted_func
|
||||
self.assertIsInstance(converted_func, def_function.Function)
|
||||
converted_concrete_func = converted_func.get_concrete_function(
|
||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
|
||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
||||
self._CheckTrtOps(converted_concrete_func)
|
||||
self._CheckTrtOps(converter._converted_func)
|
||||
|
||||
# Save the converted model without any TRT engine cache.
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
@ -443,15 +434,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
||||
|
||||
# Verify the converted GraphDef and ConcreteFunction.
|
||||
@def_function.function
|
||||
def wrapper_converted_func(*args, **kwargs):
|
||||
return converter._converted_func(*args, **kwargs)
|
||||
converted_func = wrapper_converted_func
|
||||
self.assertIsInstance(converted_func, def_function.Function)
|
||||
converted_concrete_func = converted_func.get_concrete_function(
|
||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32),
|
||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
||||
self._CheckTrtOps(converted_concrete_func, _CheckFn)
|
||||
self._CheckTrtOps(converter._converted_func, _CheckFn)
|
||||
|
||||
# Save the converted model with the statically-built engine inlined.
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
|
Loading…
Reference in New Issue
Block a user