Introduce TrtGraphConverterV2 for TF-TRT conversion in V2, and enhance the V2 unit test.
PiperOrigin-RevId: 244267523
commit 96072813ec (parent: fb49f672cb)
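At a glance, the new V2 API added by this change can be driven as in the sketch below. This is illustrative code written against the classes in this diff, not code from the commit itself; the paths are placeholders and the FP16 override is just an example.

```python
# Illustrative sketch of the TrtGraphConverterV2 workflow added by this commit.
# Requires TF 2.0 eager mode; "/tmp/model" and "/tmp/trt_model" are placeholders.
from tensorflow.python.compiler.tensorrt import trt_convert

params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt_convert.TrtPrecisionMode.FP16,  # example override
    is_dynamic_op=True)
converter = trt_convert.TrtGraphConverterV2(
    input_saved_model_dir="/tmp/model",
    conversion_params=params)
converter.convert()               # returns the TRT-optimized ConcreteFunction
converter.save("/tmp/trt_model")  # writes the converted SavedModel
```

convert() loads the SavedModel, freezes its variables, runs the TensorRTOptimizer through Grappler, and wraps the result back into a ConcreteFunction; save() rewrites the signature map to point at that function.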
@@ -144,7 +144,7 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
         ).GetConversionParams(run_params)._replace(
             # Disable layout optimizer, since it'll add Transpose(Const, Const) to
             # the graph and breaks the conversion check.
-            rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
+            rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())


 class SimpleMultiEnginesTest2(trt_test.TfTrtIntegrationTestBase):
@@ -124,7 +124,7 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
         maximum_cached_engines=1,
         # Disable layout optimizer, since it will convert BiasAdd with NHWC
         # format to NCHW format under four dimensional input.
-        rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
+        rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())

   def ExpectedEnginesToBuild(self, run_params):
     """Return the expected engines to build."""
@@ -85,7 +85,7 @@ class DynamicInputShapesTest(trt_test.TfTrtIntegrationTestBase):
         maximum_cached_engines=10,
         # Disable layout optimizer, since it will convert BiasAdd with NHWC
         # format to NCHW format under four dimensional input.
-        rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
+        rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())

   def ExpectedEnginesToBuild(self, run_params):
     return ["TRTEngineOp_0"]
@@ -65,7 +65,7 @@ class ExcludeUnsupportedInt32Test(trt_test.TfTrtIntegrationTestBase):
         maximum_cached_engines=1,
         # Disable layout optimizer, since it will convert BiasAdd with NHWC
         # format to NCHW format under four dimensional input.
-        rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
+        rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())

   def ExpectedEnginesToBuild(self, run_params):
     """Return the expected engines to build."""
@@ -56,12 +56,6 @@ RunParams = namedtuple("RunParams", [
     "use_calibration"
 ])

-ConversionParams = namedtuple("ConversionParams", [
-    "max_batch_size", "max_workspace_size_bytes", "precision_mode",
-    "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
-    "cached_engine_batches", "rewriter_config", "use_calibration"
-])
-
 PRECISION_MODES = ["FP32", "FP16", "INT8"]

@@ -163,7 +157,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
     raise NotImplementedError()

   def GetConversionParams(self, run_params):
-    """Return a ConversionParams for test."""
+    """Return a TrtConversionParams for test."""
     batch_list = []
     for dims_list in self._GetParamsCached().input_dims:
       assert dims_list
@@ -171,19 +165,22 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
       input_batches = [dims[0] for dims in dims_list]
       assert max(input_batches) == min(input_batches)
       batch_list.append(input_batches[0])
-    return ConversionParams(
+    conversion_params = trt_convert.TrtConversionParams(
         # We use the minimum of all the batch sizes, so when multiple different
         # input shapes are provided it'll always create new engines in the
         # cache, and we can therefore test the cache behavior.
-        max_batch_size=min(batch_list),
+        rewriter_config_template=None,
         max_workspace_size_bytes=1 << 25,
         precision_mode=run_params.precision_mode,
         minimum_segment_size=2,
         is_dynamic_op=run_params.dynamic_engine,
         maximum_cached_engines=1,
-        cached_engine_batches=None,
-        rewriter_config=None,
-        use_calibration=run_params.use_calibration)
+        use_calibration=run_params.use_calibration,
+        use_function_backup=False,
+        max_batch_size=min(batch_list),
+        cached_engine_batches=None)
+    return conversion_params._replace(
+        use_function_backup=IsQuantizationWithCalibration(conversion_params))

   def ShouldRunTest(self, run_params):
     """Whether to run the test."""
@@ -218,24 +215,13 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
     """Get config proto based on specific settings."""
     conversion_params = self.GetConversionParams(run_params)
     if graph_state == GraphState.INFERENCE and run_params.use_optimizer:
-      rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
-          conversion_params.rewriter_config,
-          conversion_params.max_batch_size,
-          conversion_params.max_workspace_size_bytes,
-          conversion_params.precision_mode,
-          conversion_params.minimum_segment_size,
-          conversion_params.is_dynamic_op,
-          conversion_params.maximum_cached_engines,
-          conversion_params.cached_engine_batches,
-          conversion_params.use_calibration,
-          use_function_backup=IsQuantizationWithCalibration(conversion_params))
-
+      rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(conversion_params)
       graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
     else:
       graph_options = config_pb2.GraphOptions()
-      if conversion_params.rewriter_config is not None:
+      if conversion_params.rewriter_config_template is not None:
         graph_options.rewrite_options.CopyFrom(
-            conversion_params.rewriter_config)
+            conversion_params.rewriter_config_template)

     config = config_pb2.ConfigProto(
         gpu_options=self._GetGPUOptions(), graph_options=graph_options)
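For reference, the session-based path that this test change exercises reduces to roughly the following sketch. It is minimal and illustrative; in the test harness the params value comes from GetConversionParams rather than the default shown here.

```python
# Sketch: build a session ConfigProto that triggers TF-TRT via Grappler,
# mirroring the simplified rewriter_cfg call in the hunk above.
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert

params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS  # illustrative default
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(params)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
config = config_pb2.ConfigProto(graph_options=graph_options)
# A tf.Session created with this config runs the TensorRTOptimizer during
# graph optimization, which is how the integration tests drive conversion.
```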
@@ -310,7 +296,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
         maximum_cached_engines=conversion_params.maximum_cached_engines,
         cached_engine_batches=conversion_params.cached_engine_batches,
         use_calibration=conversion_params.use_calibration,
-        use_function_backup=IsQuantizationWithCalibration(conversion_params))
+        use_function_backup=conversion_params.use_function_backup)
     return converter

   def _GetCalibratedInferGraph(self, run_params, gdef, inputs_data):
@@ -18,7 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import collections
+
 import six as _six

 from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
 from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_linked_tensorrt_version
 from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_loaded_tensorrt_version
@@ -80,7 +83,7 @@ class GraphConverter(object):
     class MyGraphConverter(GraphConverter):
       ...

-      def get_rewriter_config(self, rewriter_config_template=None):
+      def get_rewriter_config(self):
         my_rewriter_config = ...
         return my_rewriter_config
   ```
@@ -129,7 +132,7 @@ class GraphConverter(object):
         If set to None, the graph will be read from the SavedModel loaded from
         input_saved_model_dir.
       nodes_blacklist: list of node names to prevent the converter from
-        touching. Only used when input_graph_def is not None.
+        touching.
       session_config: the ConfigProto used to create a Session. It's also used
         as a template to create a RewriterConfig for conversion. If not
         specified, a default ConfigProto will be used.
@@ -137,21 +140,15 @@ class GraphConverter(object):
     Raises:
       ValueError: if the combination of the parameters is invalid.
     """
-    if context.executing_eagerly():
-      if input_graph_def or not input_saved_model_dir:
-        raise ValueError(
-            "TF 2.0 only supports conversion of SavedModel, please specify "
-            "input_saved_model_dir as input.")
-    else:
-      if input_graph_def and input_saved_model_dir:
-        raise ValueError(
-            "Can only specify one of input_graph_def and input_saved_model_dir")
-      if not input_graph_def and not input_saved_model_dir:
-        raise ValueError("Must specify one of input_graph_def and "
-                         "input_saved_model_dir")
+    if input_graph_def and input_saved_model_dir:
+      raise ValueError(
+          "Can only specify one of input_graph_def and input_saved_model_dir")
+    if not input_graph_def and not input_saved_model_dir:
+      raise ValueError("Must specify one of input_graph_def and "
+                       "input_saved_model_dir")

     self._input_graph_def = input_graph_def
     self._nodes_blacklist = nodes_blacklist

     self._input_saved_model_dir = input_saved_model_dir
     self._converted = False
@@ -169,14 +166,9 @@ class GraphConverter(object):
     self._calibration_sess = None
     self._calibration_data_collected = False

-  def get_rewriter_config(self, rewriter_config_template=None):
+  def get_rewriter_config(self):
     """Returns a RewriterConfig proto for TRT transformation.

-    Args:
-      rewriter_config_template: a template RewriterConfig proto used to create a
-        RewriterConfig for the conversion. The implementation should not modify
-        the template. If None, it will use a default one.
-
     Returns:
       A RewriterConfig proto which will be used to run the conversion using
       Grappler.
@@ -188,11 +180,7 @@ class GraphConverter(object):
     # Create custom ConfigProto for Grappler.
     grappler_session_config = config_pb2.ConfigProto()
     grappler_session_config.CopyFrom(self._session_config)
-    rewriter_config = None
-    if (grappler_session_config.HasField("graph_options") and
-        grappler_session_config.graph_options.HasField("rewrite_options")):
-      rewriter_config = grappler_session_config.graph_options.rewrite_options
-    custom_rewriter_config = self.get_rewriter_config(rewriter_config)
+    custom_rewriter_config = self.get_rewriter_config()
     grappler_session_config.graph_options.rewrite_options.CopyFrom(
         custom_rewriter_config)

@@ -285,33 +273,6 @@ class GraphConverter(object):

     self._run_conversion()

-  # TODO(laigd): provide a utility function to optimize a ConcreteFunction and
-  # use it here (b/124792963).
-  def _convert_saved_model_v2(self):
-    """Convert the input SavedModel in 2.0 format."""
-    assert context.executing_eagerly()
-
-    self._saved_model = load.load(self._input_saved_model_dir,
-                                  self._input_saved_model_tags)
-    func = self._saved_model.signatures[self._input_saved_model_signature_key]
-    frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
-    self._grappler_meta_graph_def = saver.export_meta_graph(
-        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)
-
-    # Add a collection 'train_op' so that Grappler knows the outputs.
-    fetch_collection = meta_graph_pb2.CollectionDef()
-    for array in frozen_func.inputs + frozen_func.outputs:
-      fetch_collection.node_list.value.append(array.name)
-    self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
-        fetch_collection)
-
-    # Run TRT optimizer in Grappler to convert the graph.
-    self._run_conversion()
-    self._converted_func = wrap_function.function_from_graph_def(
-        self._converted_graph_def,
-        [tensor.name for tensor in frozen_func.inputs],
-        [tensor.name for tensor in frozen_func.outputs])
-
   def convert(self):
     """Run the conversion.
@@ -320,16 +281,11 @@ class GraphConverter(object):
     2.0+.
     """
     assert not self._converted
-    if context.executing_eagerly():
-      self._convert_saved_model_v2()
-      return self._converted_func
+    if self._input_graph_def:
+      self._convert_graph_def()
     else:
-      if self._input_graph_def:
-        self._convert_graph_def()
-      else:
-        self._convert_saved_model()
-      return self._converted_graph_def
+      self._convert_saved_model()
+    return self._converted_graph_def

   def calibrate(self,
                 fetch_names,
@@ -408,80 +364,71 @@ class GraphConverter(object):
       SavedModel.
     """
     assert self._converted
-
-    if context.executing_eagerly():
-      # Rewrite the signature map using the optimized ConcreteFunction.
-      signatures = {
-          key: value for key, value in self._saved_model.signatures.items()
-      }
-      signatures[self._input_saved_model_signature_key] = self._converted_func
-      save.save(self._saved_model, output_saved_model_dir, signatures)
-    else:
-      if self._input_graph_def:
-        raise ValueError(
-            "Not able to save to a SavedModel since input is a GraphDef")
-
-      def _restore_collections(dest_graph, src_meta_graph_def, collections):
-        """Restores collections that we need to keep."""
-        scope = ""
-        for key in collections:
-          collection_def = src_meta_graph_def.collection_def[key]
-          kind = collection_def.WhichOneof("kind")
-          if kind is None:
-            tf_logging.error(
-                "Cannot identify data type for collection %s. Skipping.", key)
-            continue
-          from_proto = ops.get_from_proto_function(key)
-          if from_proto and kind == "bytes_list":
-            proto_type = ops.get_collection_proto_type(key)
-            # It is assumed that there are no Variables Keys in collections
-            for value in collection_def.bytes_list.value:
-              proto = proto_type()
-              proto.ParseFromString(value)
-              try:
-                new_value = from_proto(proto, import_scope=scope)
-              except:
-                continue
-              dest_graph.add_to_collection(key, new_value)
-          else:
-            field = getattr(collection_def, kind)
-            if kind == "node_list":
-              for value in field.value:
-                name = ops.prepend_name_scope(value, scope)
-                # Since the graph has been optimized, the node may no longer
-                # exist.
-                try:
-                  col_op = dest_graph.as_graph_element(name)
-                except (TypeError, ValueError, KeyError) as e:
-                  continue
-                dest_graph.add_to_collection(key, col_op)
-            elif kind == "int64_list":
-              # NOTE(opensource): This force conversion is to work around the
-              # fact that Python2 distinguishes between int and long, while
-              # Python3 has only int.
-              for value in field.value:
-                dest_graph.add_to_collection(key, int(value))
-            else:
-              for value in field.value:
-                dest_graph.add_to_collection(
-                    key, ops.prepend_name_scope(value, scope))
+    if self._input_graph_def:
+      raise ValueError(
+          "Not able to save to a SavedModel since input is a GraphDef")
+
+    def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
+      """Restores collections that we need to keep."""
+      scope = ""
+      for key in collection_keys:
+        collection_def = src_meta_graph_def.collection_def[key]
+        kind = collection_def.WhichOneof("kind")
+        if kind is None:
+          tf_logging.error(
+              "Cannot identify data type for collection %s. Skipping.", key)
+          continue
+        from_proto = ops.get_from_proto_function(key)
+        if from_proto and kind == "bytes_list":
+          proto_type = ops.get_collection_proto_type(key)
+          # It is assumed that there are no Variables Keys in collections
+          for value in collection_def.bytes_list.value:
+            proto = proto_type()
+            proto.ParseFromString(value)
+            try:
+              new_value = from_proto(proto, import_scope=scope)
+            except:
+              continue
+            dest_graph.add_to_collection(key, new_value)
+        else:
+          field = getattr(collection_def, kind)
+          if kind == "node_list":
+            for value in field.value:
+              name = ops.prepend_name_scope(value, scope)
+              # Since the graph has been optimized, the node may no longer
+              # exist.
+              try:
+                col_op = dest_graph.as_graph_element(name)
+              except (TypeError, ValueError, KeyError) as e:
+                continue
+              dest_graph.add_to_collection(key, col_op)
+          elif kind == "int64_list":
+            # NOTE(opensource): This force conversion is to work around the
+            # fact that Python2 distinguishes between int and long, while
+            # Python3 has only int.
+            for value in field.value:
+              dest_graph.add_to_collection(key, int(value))
+          else:
+            for value in field.value:
+              dest_graph.add_to_collection(key,
+                                           ops.prepend_name_scope(value, scope))

     # Write the transformed graphdef as SavedModel.
     saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
     with ops.Graph().as_default():
       importer.import_graph_def(self._converted_graph_def, name="")
       _restore_collections(
           ops.get_default_graph(), self._grappler_meta_graph_def,
           self._collections_to_keep(
               self._grappler_meta_graph_def.collection_def))
       # We don't use any specific converter here.
       with session.Session(config=self._session_config) as sess:
         saved_model_builder.add_meta_graph_and_variables(
             sess,
             self._input_saved_model_tags,
             signature_def_map=self._grappler_meta_graph_def.signature_def)
     # Ignore other meta graphs from the input SavedModel.
     saved_model_builder.save()


 class TrtPrecisionMode(object):
@@ -498,101 +445,202 @@ class TrtPrecisionMode(object):
 # so it can produce reasonable performance results with the default.
 DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30

+# TrtConversionParams encapsulates the parameters that are used for TF-TRT
+# conversion.
+TrtConversionParams = collections.namedtuple(
+    "TrtConversionParams",
+    [
+
+        # A template RewriterConfig proto used to create a TRT-enabled
+        # RewriterConfig. If None, it will use a default one.
+        "rewriter_config_template",
+
+        # The maximum GPU temporary memory which the TRT engine can use at
+        # execution time. This corresponds to the 'workspaceSize' parameter of
+        # nvinfer1::IBuilder::setMaxWorkspaceSize().
+        "max_workspace_size_bytes",
+
+        # One of TrtPrecisionMode.supported_precision_modes().
+        "precision_mode",
+
+        # The minimum number of nodes required for a subgraph to be replaced by
+        # TRTEngineOp.
+        "minimum_segment_size",
+
+        # Whether to generate dynamic TRT ops which will build the TRT network
+        # and engine at run time.
+        "is_dynamic_op",
+
+        # Max number of cached TRT engines in dynamic TRT ops. If the number of
+        # cached engines is already at max but none of them can serve the input,
+        # the TRTEngineOp will fall back to run the TF function based on which
+        # the TRTEngineOp is created.
+        "maximum_cached_engines",
+
+        # This argument is ignored if precision_mode is not INT8. If set to
+        # True, a calibration graph will be created to calibrate the missing
+        # ranges. The calibration graph must be converted to an inference graph
+        # by running calibration with calibrate(). If set to False, quantization
+        # nodes will be expected for every tensor in the graph (excluding those
+        # which will be fused). If a range is missing, an error will occur.
+        # Please note that accuracy may be negatively affected if there is a
+        # mismatch between which tensors TRT quantizes and which tensors were
+        # trained with fake quantization.
+        "use_calibration",
+
+        # If set to True, it will create a FunctionDef for each subgraph that is
+        # converted to TRT op, and if TRT ops fail to execute at runtime, it'll
+        # invoke that function as a fallback.
+        "use_function_backup",
+
+        # Max size for the input batch.
+        # This option is deprecated in TF 2.0.
+        "max_batch_size",
+
+        # A list of batch sizes used to create cached engines, only used when
+        # is_dynamic_op is True. The length of the list should be <=
+        # maximum_cached_engines, and the dynamic TRT op will use this list to
+        # determine the batch sizes of the cached engines, instead of making the
+        # decision on the fly. This is useful when we know the most common batch
+        # size(s) the application is going to generate.
+        # This option is deprecated in TF 2.0.
+        "cached_engine_batches",
+    ])
+
+DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams(
+    rewriter_config_template=None,
+    max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
+    precision_mode=TrtPrecisionMode.FP32,
+    minimum_segment_size=3,
+    is_dynamic_op=False,
+    maximum_cached_engines=1,
+    use_calibration=True,
+    use_function_backup=True,
+    max_batch_size=1,
+    cached_engine_batches=None)
+
+_TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF-TRT-Calibration"
+_TRT_ENGINE_CACHE_CONTAINER_NAME = "TF-TRT-Engine-Cache"
+_TRT_ENGINE_OP_NAME = "TRTEngineOp"
+
+
+def _check_conversion_params(conversion_params):
+  """Validate the provided TrtConversionParams.
+
+  Args:
+    conversion_params: a TrtConversionParams instance.
+
+  Raises:
+    TypeError: if any of the parameters are of unexpected type.
+    ValueError: if any of the parameters are of unexpected value.
+  """
+  supported_precision_modes = TrtPrecisionMode.supported_precision_modes()
+  if conversion_params.precision_mode not in supported_precision_modes:
+    raise ValueError(
+        ("precision mode '{}' is not supported."
+         "It should be one of {}").format(conversion_params.precision_mode,
+                                          supported_precision_modes))
+  if conversion_params.cached_engine_batches:
+    if not isinstance(conversion_params.cached_engine_batches, list):
+      raise TypeError("cached_engine_batches should be a list.")
+    if len(conversion_params.cached_engine_batches
+          ) > conversion_params.maximum_cached_engines:
+      raise ValueError("cached_engine_batches should not contain more than "
+                       "maximum_cached_engines items.")
+
+
+def _check_trt_version_compatibility():
+  """Check compatibility of TensorRT version.
+
+  Raises:
+    RuntimeError: if the TensorRT library version is incompatible.
+  """
+  compiled_version = get_linked_tensorrt_version()
+  loaded_version = get_loaded_tensorrt_version()
+  tf_logging.info("Linked TensorRT version: %s" % str(compiled_version))
+  tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version))
+  version_mismatch = False
+  if loaded_version[0] < compiled_version[0]:
+    tf_logging.error(
+        "TensorRT version mismatch. Tensorflow was compiled against " +
+        "TensorRT %s but library loaded from environment is TensorRT %s" %
+        (".".join([str(x) for x in compiled_version]),
+         ".".join([str(x) for x in loaded_version])) +
+        ". Please make sure that correct version of TensorRT " +
+        "is available in the system and added to ldconfig or LD_LIBRARY_PATH")
+    raise RuntimeError("Incompatible TensorRT library version")
+  for i in zip(loaded_version, compiled_version):
+    if i[0] != i[1]:
+      tf_logging.warn("TensorRT mismatch. Compiled against version " +
+                      "%s, but loaded %s. Things may not work" %
+                      (".".join([str(x) for x in compiled_version]),
+                       ".".join([str(x) for x in loaded_version])))
+      version_mismatch = True
+      break
+  if not version_mismatch:
+    tf_logging.info("Running against TensorRT version %s" %
+                    ".".join([str(x) for x in loaded_version]))
+
+
+def get_tensorrt_rewriter_config(
+    conversion_params=DEFAULT_TRT_CONVERSION_PARAMS):
+  """Returns a RewriterConfig proto for TRT transformation.
+
+  Args:
+    conversion_params: a TrtConversionParams instance.
+
+  Returns:
+    A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
+
+  Raises:
+    TypeError: if any of the parameters are of unexpected type.
+    ValueError: if any of the parameters are of unexpected value.
+  """
+  if conversion_params.rewriter_config_template is not None and not isinstance(
+      conversion_params.rewriter_config_template,
+      rewriter_config_pb2.RewriterConfig):
+    raise TypeError(
+        "rewriter_config_template should be a RewriterConfig proto.")
+  _check_conversion_params(conversion_params)
+
+  rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
+  if conversion_params.rewriter_config_template is None:
+    # Layout optimizer may add Const nodes followed by Reshape nodes, thus we
+    # need to run constant folding again.
+    rewriter_config_with_trt.optimizers.extend(
+        ["constfold", "layout", "constfold"])
+    rewriter_config_with_trt.meta_optimizer_iterations = (
+        rewriter_config_pb2.RewriterConfig.ONE)
+  else:
+    rewriter_config_with_trt.CopyFrom(
+        conversion_params.rewriter_config_template)
+
+  optimizer = rewriter_config_with_trt.custom_optimizers.add()
+  optimizer.name = "TensorRTOptimizer"
+  optimizer.parameter_map[
+      "minimum_segment_size"].i = conversion_params.minimum_segment_size
+  optimizer.parameter_map["max_batch_size"].i = conversion_params.max_batch_size
+  optimizer.parameter_map["is_dynamic_op"].b = conversion_params.is_dynamic_op
+  optimizer.parameter_map[
+      "max_workspace_size_bytes"].i = conversion_params.max_workspace_size_bytes
+  optimizer.parameter_map["precision_mode"].s = _to_bytes(
+      conversion_params.precision_mode)
+  optimizer.parameter_map[
+      "maximum_cached_engines"].i = conversion_params.maximum_cached_engines
+  if conversion_params.cached_engine_batches:
+    optimizer.parameter_map["cached_engine_batches"].list.i.extend(
+        conversion_params.cached_engine_batches)
+  optimizer.parameter_map[
+      "use_calibration"].b = conversion_params.use_calibration
+  optimizer.parameter_map[
+      "use_function_backup"].b = conversion_params.use_function_backup
+  return rewriter_config_with_trt
+
+
 class TrtGraphConverter(GraphConverter):
   """A GraphConverter for TRT transformation."""

-  _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF-TRT-Calibration"
-
-  @classmethod
-  def get_tensorrt_rewriter_config(
-      cls,
-      rewriter_config_template=None,
-      max_batch_size=1,
-      max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
-      precision_mode=TrtPrecisionMode.FP32,
-      minimum_segment_size=3,
-      is_dynamic_op=False,
-      maximum_cached_engines=1,
-      cached_engine_batches=None,
-      use_calibration=True,
-      use_function_backup=True):
-    """Returns a RewriterConfig proto for TRT transformation.
-
-    Args:
-      rewriter_config_template: a template RewriterConfig proto used to create a
-        TRT-enabled RewriterConfig. If None, it will use a default one.
-      max_batch_size: max size for the input batch.
-      max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
-        engine can use at execution time. This corresponds to the
-        'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
-      precision_mode: one of TrtPrecisionMode.supported_precision_modes().
-      minimum_segment_size: the minimum number of nodes required for a subgraph
-        to be replaced by TRTEngineOp.
-      is_dynamic_op: whether to generate dynamic TRT ops which will build the
-        TRT network and engine at run time.
-      maximum_cached_engines: max number of cached TRT engines in dynamic TRT
-        ops. If the number of cached engines is already at max but none of them
-        can serve the input, the TRTEngineOp will fall back to run the TF
-        function based on which the TRTEngineOp is created.
-      cached_engine_batches: a list of batch sizes used to create cached
-        engines, only used when is_dynamic_op is True. The length of the list
-        should be <= maximum_cached_engines, and the dynamic TRT op will use
-        this list to determine the batch sizes of the cached engines, instead of
-        making the decision on the fly. This is useful when we know the most
-        common batch size(s) the application is going to generate.
-      use_calibration: this argument is ignored if precision_mode is not INT8.
-        If set to True, a calibration graph will be created to calibrate the
-        missing ranges. The calibration graph must be converted to an inference
-        graph by running calibration with calibrate(). If set to False,
-        quantization nodes will be expected for every tensor in the graph
-        (excluding those which will be fused). If a range is missing, an error
-        will occur. Please note that accuracy may be negatively affected if
-        there is a mismatch between which tensors TRT quantizes and which
-        tensors were trained with fake quantization.
-      use_function_backup: if set to True, it will create a FunctionDef for each
-        subgraph that is converted to TRT op, and if TRT ops fail to execute at
-        runtime, it'll invoke that function as a fallback.
-
-    Returns:
-      A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
-
-    Raises:
-      TypeError: if any of the parameters are of unexpected type.
-      ValueError: if any of the parameters are of unexpected value.
-    """
-    if rewriter_config_template is not None and not isinstance(
-        rewriter_config_template, rewriter_config_pb2.RewriterConfig):
-      raise TypeError(
-          "rewriter_config_template should be a RewriterConfig proto.")
-
-    rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
-    if rewriter_config_template is None:
-      # Layout optimizer may add Const nodes followed by Reshape nodes, thus we
-      # need to run constant folding again.
-      rewriter_config_with_trt.optimizers.extend(
-          ["constfold", "layout", "constfold"])
-      rewriter_config_with_trt.meta_optimizer_iterations = (
-          rewriter_config_pb2.RewriterConfig.ONE)
-    else:
-      rewriter_config_with_trt.CopyFrom(rewriter_config_template)
-
-    optimizer = rewriter_config_with_trt.custom_optimizers.add()
-    optimizer.name = "TensorRTOptimizer"
-    optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
-    optimizer.parameter_map["max_batch_size"].i = max_batch_size
-    optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
-    optimizer.parameter_map[
-        "max_workspace_size_bytes"].i = max_workspace_size_bytes
-    optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
-    optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
-    if cached_engine_batches:
-      optimizer.parameter_map["cached_engine_batches"].list.i.extend(
-          cached_engine_batches)
-    optimizer.parameter_map["use_calibration"].b = use_calibration
-    optimizer.parameter_map["use_function_backup"].b = use_function_backup
-    return rewriter_config_with_trt
-
+  # TODO(laigd): use TrtConversionParams here.
   def __init__(self,
                input_saved_model_dir=None,
                input_saved_model_tags=None,
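The namedtuple added above is meant to be customized through _replace, as the updated unit test does. A hedged sketch with illustrative values:

```python
# Sketch of customizing TrtConversionParams; all values are illustrative.
from tensorflow.python.compiler.tensorrt import trt_convert

params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    max_workspace_size_bytes=1 << 28,
    precision_mode="INT8",
    maximum_cached_engines=2,
    cached_engine_batches=[1, 8])  # len must be <= maximum_cached_engines
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(params)
# rewriter_cfg now carries a "TensorRTOptimizer" custom optimizer whose
# parameter_map mirrors the fields above; _check_conversion_params would
# reject a bad precision_mode or an oversized cached_engine_batches list.
```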
@@ -621,7 +669,7 @@ class TrtGraphConverter(GraphConverter):
         If set to None, the graph will be read from the SavedModel loaded from
         input_saved_model_dir.
       nodes_blacklist: list of node names to prevent the converter from
-        touching. Only used when input_graph_def is not None.
+        touching.
       session_config: the ConfigProto used to create a Session. It's also used
         as a template to create a TRT-enabled ConfigProto for conversion. If not
         specified, a default ConfigProto will be used.
@@ -659,7 +707,6 @@ class TrtGraphConverter(GraphConverter):

     Raises:
       ValueError: if the combination of the parameters is invalid.
-      RuntimeError: if the TensorRT library version is incompatible.
     """
     super(TrtGraphConverter, self).__init__(
         input_saved_model_dir=input_saved_model_dir,
@@ -668,54 +715,10 @@ class TrtGraphConverter(GraphConverter):
         input_graph_def=input_graph_def,
         nodes_blacklist=nodes_blacklist,
         session_config=session_config)
-    # TODO(laigd): move all the validations below to
-    # get_tensorrt_rewriter_config().
-    # Check compatibility of TensorRT version.
-    compiled_version = get_linked_tensorrt_version()
-    loaded_version = get_loaded_tensorrt_version()
-    tf_logging.info("Linked TensorRT version: %s" % str(compiled_version))
-    tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version))
-    version_mismatch = False
-    if loaded_version[0] < compiled_version[0]:
-      tf_logging.error(
-          "TensorRT version mismatch. Tensorflow was compiled against " +
-          "TensorRT %s but library loaded from environment is TensorRT %s" %
-          (".".join([str(x) for x in compiled_version]),
-           ".".join([str(x) for x in loaded_version])) +
-          ". Please make sure that correct version of TensorRT " +
-          "is available in the system and added to ldconfig or LD_LIBRARY_PATH")
-      raise RuntimeError("Incompatible TensorRT library version")
-    for i in zip(loaded_version, compiled_version):
-      if i[0] != i[1]:
-        tf_logging.warn("TensorRT mismatch. Compiled against version " +
-                        "%s, but loaded %s. Things may not work" %
-                        (".".join([str(x) for x in compiled_version]),
-                         ".".join([str(x) for x in loaded_version])))
-        version_mismatch = True
-        break
-    if not version_mismatch:
-      tf_logging.info("Running against TensorRT version %s" %
-                      ".".join([str(x) for x in loaded_version]))
-
-    # Check input arguments.
-    supported_precision_modes = TrtPrecisionMode.supported_precision_modes()
-    if precision_mode not in supported_precision_modes:
-      raise ValueError(
-          ("precision mode '{}' is not supported."
-           "It should be one of {}").format(precision_mode,
-                                            supported_precision_modes))
-
-    if cached_engine_batches:
-      if not isinstance(cached_engine_batches, list):
-        raise TypeError("cached_engine_batches should be a list.")
-      if len(cached_engine_batches) > maximum_cached_engines:
-        raise ValueError("cached_engine_batches should not contain more than "
-                         "maximum_cached_engines items.")
-
+    _check_trt_version_compatibility()
     self._need_calibration = (
         precision_mode == TrtPrecisionMode.INT8 and use_calibration)
-    self._use_function_backup = use_function_backup

     # TODO(laigd): consider provide a mechanism to remove the fallback path
     # after calibration is done.
@@ -724,31 +727,30 @@ class TrtGraphConverter(GraphConverter):
         "Calibration requires enabling fallback to TF function execution.")

     # TODO(laigd):
-    # - Get rid of is_dynamic_op option, it should always be True, and it should
-    #   accept N shapes as input.
     # - Verify in int8 mode that maximum_cached_engines and
     #   cached_engine_batches are set appropriately.
     # - If it fails to build the int8 engine it should return error.
-    self._max_batch_size = max_batch_size
-    self._max_workspace_size_bytes = max_workspace_size_bytes
-    self._precision_mode = precision_mode
-    self._minimum_segment_size = minimum_segment_size
-    self._is_dynamic_op = is_dynamic_op
-    self._maximum_cached_engines = maximum_cached_engines
-    self._cached_engine_batches = cached_engine_batches
-
-  def get_rewriter_config(self, rewriter_config_template=None):
-    return TrtGraphConverter.get_tensorrt_rewriter_config(
-        rewriter_config_template,
-        max_batch_size=self._max_batch_size,
-        max_workspace_size_bytes=self._max_workspace_size_bytes,
-        precision_mode=self._precision_mode,
-        minimum_segment_size=self._minimum_segment_size,
-        is_dynamic_op=self._is_dynamic_op,
-        maximum_cached_engines=self._maximum_cached_engines,
-        cached_engine_batches=self._cached_engine_batches,
-        use_calibration=self._need_calibration,
-        use_function_backup=self._use_function_backup)
+    rewriter_config_template = None
+    if (session_config and session_config.HasField("graph_options") and
+        session_config.graph_options.HasField("rewrite_options")):
+      rewriter_config_template = session_config.graph_options.rewrite_options
+
+    self._conversion_params = TrtConversionParams(
+        rewriter_config_template=rewriter_config_template,
+        max_workspace_size_bytes=max_workspace_size_bytes,
+        precision_mode=precision_mode,
+        minimum_segment_size=minimum_segment_size,
+        is_dynamic_op=is_dynamic_op,
+        maximum_cached_engines=maximum_cached_engines,
+        use_calibration=use_calibration,
+        use_function_backup=use_function_backup,
+        max_batch_size=max_batch_size,
+        cached_engine_batches=cached_engine_batches)
+    _check_conversion_params(self._conversion_params)
+
+  def get_rewriter_config(self):
+    return get_tensorrt_rewriter_config(
+        conversion_params=self._conversion_params)

   def finalize_calibration(self):
     assert self._need_calibration
@@ -775,7 +777,7 @@ class TrtGraphConverter(GraphConverter):
     resource_name_input = array_ops.placeholder(dtypes.string)

     for node in self._converted_graph_def.node:
-      if node.op == "TRTEngineOp":
+      if node.op == _TRT_ENGINE_OP_NAME:
         # Adds the get_serialized_resource_op for the device if not done
         # before. We only add one such op for each device.
         # TODO(laigd): What if the device is empty?????
@@ -791,11 +793,8 @@ class TrtGraphConverter(GraphConverter):
         calibration_result = self._calibration_sess.run(
             device_to_get_resource_op_map[node.device],
             feed_dict={
-                container_input:
-                    TrtGraphConverter
-                    ._TRT_CALIBRATION_RESOURCE_CONTAINER_NAME,
-                resource_name_input:
-                    node.name
+                container_input: _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME,
+                resource_name_input: node.name
             })
         node.attr["calibration_data"].s = calibration_result
@@ -806,9 +805,106 @@ class TrtGraphConverter(GraphConverter):
     """Save the converted graph as a SavedModel."""
     if self._need_calibration:
       assert self._calibration_data_collected

     super(TrtGraphConverter, self).save(output_saved_model_dir)


+class TrtGraphConverterV2(object):
+  """A converter for TF-TRT transformation for SavedModel in TF 2.0."""
+
+  def __init__(self,
+               input_saved_model_dir=None,
+               input_saved_model_tags=None,
+               input_saved_model_signature_key=None,
+               conversion_params=DEFAULT_TRT_CONVERSION_PARAMS):
+    """Initialize the converter.
+
+    Args:
+      input_saved_model_dir: the directory to load the SavedModel which
+        contains the input graph to transform.
+      input_saved_model_tags: list of tags to load the SavedModel.
+      input_saved_model_signature_key: the key of the signature to optimize the
+        graph for.
+      conversion_params: a TrtConversionParams instance.
+    """
+    assert context.executing_eagerly()
+    _check_trt_version_compatibility()
+
+    self._input_saved_model_dir = input_saved_model_dir
+    self._input_saved_model_tags = (
+        input_saved_model_tags or [tag_constants.SERVING])
+    self._input_saved_model_signature_key = (
+        input_saved_model_signature_key or
+        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
+
+    self._need_calibration = (
+        conversion_params.precision_mode == TrtPrecisionMode.INT8 and
+        conversion_params.use_calibration)
+    self._conversion_params = conversion_params
+    _check_conversion_params(self._conversion_params)
+    self._converted = False
+
+  def _run_conversion(self, meta_graph_def):
+    """Run Grappler's OptimizeGraph() tool to convert the graph.
+
+    Args:
+      meta_graph_def: the MetaGraphDef instance to run the optimizations on.
+
+    Returns:
+      The optimized GraphDef.
+    """
+    rewriter_config = get_tensorrt_rewriter_config(
+        conversion_params=self._conversion_params)
+    grappler_session_config = config_pb2.ConfigProto()
+    grappler_session_config.graph_options.rewrite_options.CopyFrom(
+        rewriter_config)
+    return tf_optimizer.OptimizeGraph(
+        grappler_session_config, meta_graph_def, graph_id=b"tf_graph")
+
+  # TODO(laigd): provide a utility function to optimize a ConcreteFunction and
+  # use it here (b/124792963).
+  def convert(self):
+    """Convert the input SavedModel in 2.0 format."""
+    assert not self._converted
+    self._saved_model = load.load(self._input_saved_model_dir,
+                                  self._input_saved_model_tags)
+    func = self._saved_model.signatures[self._input_saved_model_signature_key]
+    frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
+    grappler_meta_graph_def = saver.export_meta_graph(
+        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)
+
+    # Add a collection 'train_op' so that Grappler knows the outputs.
+    fetch_collection = meta_graph_pb2.CollectionDef()
+    for array in frozen_func.inputs + frozen_func.outputs:
+      fetch_collection.node_list.value.append(array.name)
+    grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
+        fetch_collection)
+
+    # Run TRT optimizer in Grappler to convert the graph.
+    converted_graph_def = self._run_conversion(grappler_meta_graph_def)
+    self._converted_func = wrap_function.function_from_graph_def(
+        converted_graph_def, [tensor.name for tensor in frozen_func.inputs],
+        [tensor.name for tensor in frozen_func.outputs])
+
+    self._converted = True
+    return self._converted_func
+
+  def save(self, output_saved_model_dir):
+    """Save the converted SavedModel.
+
+    Args:
+      output_saved_model_dir: directory to save the converted SavedModel.
+    """
+    assert self._converted
+    # Rewrite the signature map using the optimized ConcreteFunction.
+    signatures = {
+        key: value for key, value in self._saved_model.signatures.items()
+    }
+    signatures[self._input_saved_model_signature_key] = self._converted_func
+    save.save(self._saved_model, output_saved_model_dir, signatures)
+
+
+# TODO(laigd): use TrtConversionParams here.
 def create_inference_graph(
     input_graph_def,
     outputs,
|
@ -19,14 +19,17 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled
|
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
from tensorflow.python.compiler.tensorrt import trt_convert
|
||||||
from tensorflow.python.eager import context
|
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.eager import function
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import graph_util
|
from tensorflow.python.framework import graph_util
|
||||||
@@ -35,7 +38,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import builder
@@ -44,10 +46,11 @@ from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.saved_model import signature_def_utils
 from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.saved_model import utils
-from tensorflow.python.tools import saved_model_utils
 from tensorflow.python.saved_model import load
 from tensorflow.python.saved_model import save
+from tensorflow.python.tools import saved_model_utils
 from tensorflow.python.training.tracking import tracking
+from tensorflow.python.util import nest

 _SAVED_MODEL_SIGNATURE_KEY = "mypredict"
@@ -63,8 +66,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
     """Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
     if not is_tensorrt_enabled():
       return
-    rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
-        rewriter_config_template=None,
+    conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
         max_batch_size=128,
         max_workspace_size_bytes=1234,
         precision_mode="INT8",
@ -72,6 +74,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
is_dynamic_op=True,
|
is_dynamic_op=True,
|
||||||
maximum_cached_engines=2,
|
maximum_cached_engines=2,
|
||||||
cached_engine_batches=[1, 128])
|
cached_engine_batches=[1, 128])
|
||||||
|
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
|
||||||
|
conversion_params=conversion_params)
|
||||||
self.assertEqual(["constfold", "layout", "constfold"],
|
self.assertEqual(["constfold", "layout", "constfold"],
|
||||||
rewriter_cfg.optimizers)
|
rewriter_cfg.optimizers)
|
||||||
self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE,
|
self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE,
|
||||||
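
Usage note (an illustrative sketch, not part of the diff): rewriter configs are now derived from a `TrtConversionParams` namedtuple rather than from loose keyword arguments on the converter class, via the module-level `trt_convert.get_tensorrt_rewriter_config()`. Assuming a TensorFlow build with TensorRT enabled:

    from tensorflow.python.compiler.tensorrt import trt_convert

    # Start from the defaults and override only what the deployment needs.
    conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        max_batch_size=128,
        max_workspace_size_bytes=1 << 30,  # hypothetical 1 GiB workspace
        precision_mode="INT8",
        is_dynamic_op=True,
        maximum_cached_engines=2)
    rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
        conversion_params=conversion_params)
    # The assertions above pin the returned RewriterConfig's optimizer
    # sequence to ["constfold", "layout", "constfold"].
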
@@ -106,7 +110,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
         gpu_options=config_pb2.GPUOptions(allow_growth=True))
     return config

-  def _GetGraph(self):
+  @classmethod
+  def _GetGraph(cls, inp, var):
     """Get the graph for testing."""
     # The graph computes (input+1)^2, it looks like:
     #
@@ -119,24 +124,42 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
     #     +
     #     |
     # output (Identity)
+    add = inp + var
+    mul = inp * add
+    add = mul + add
+    out = array_ops.identity(add, name="output")
+    return out
+
+  def _GetModelForV2(self):
+
+    class SimpleModel(tracking.AutoTrackable):
+
+      def __init__(self):
+        self.v = None
+
+      @def_function.function(input_signature=[
+          tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)
+      ])
+      def run(self, inp):
+        if self.v is None:
+          self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
+        return TrtConvertTest._GetGraph(inp, self.v)
+
+    return SimpleModel()
+
+  def _GetGraphForV1(self):
     g = ops.Graph()
     with g.as_default():
       with g.device("/GPU:0"):
         inp = array_ops.placeholder(
             dtype=dtypes.float32, shape=[None, 1, 1], name="input")
-        var = variables.VariableV1([[[1.0]]],
-                                   dtype=dtypes.float32,
-                                   name="v1",
-                                   use_resource=False)
-        add = inp + var.value()
-        mul = inp * add
-        add = mul + add
-        out = array_ops.identity(add, name="output")
-    return g, var, inp, out
+        var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
+        out = TrtConvertTest._GetGraph(inp, var)
+    return g, var, inp, out

   def _GetGraphDef(self):
     """Get the graph def for testing."""
-    g, var, _, _ = self._GetGraph()
+    g, var, _, _ = self._GetGraphForV1()
     with self.session(graph=g, config=self._GetConfigProto()) as sess:
       sess.run(var.initializer)
       graph_def = graph_util.convert_variables_to_constants(
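
For intuition (not in the diff): the shared helper returns `mul + add`, i.e. `inp*(inp+v) + (inp+v) == (inp+v)*(inp+1)`, which for `v == 1` is the `(input+1)^2` promised in the comment. A quick eager-mode (TF 2.x) check of the arithmetic:

    import numpy as np
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import variables

    v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
    inp = np.array([[[2.0]]], dtype=np.float32)
    add = inp + v       # inp + v
    mul = inp * add     # inp * (inp + v)
    out = mul + add     # (inp + v) * (inp + 1); 9.0 for inp == 2, v == 1
    print(out.numpy())  # [[[9.]]]
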
@@ -145,7 +168,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
     self.assertEqual(
         {
             "v1": "Const",
-            "v1/read": "Identity",
+            "add/ReadVariableOp": "Identity",
             "input": "Placeholder",
             "add": "Add",
             "mul": "Mul",
@@ -156,7 +179,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):

   def _WriteInputSavedModel(self, input_saved_model_dir):
     """Write the saved model as an input for testing."""
-    g, var, inp, out = self._GetGraph()
+    g, var, inp, out = self._GetGraphForV1()
     signature_def = signature_def_utils.build_signature_def(
         inputs={"myinput": utils.build_tensor_info(inp)},
         outputs={"myoutput": utils.build_tensor_info(out)},
@@ -183,7 +206,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
         input_saved_model_dir=input_saved_model_dir,
         input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
         input_graph_def=None if input_saved_model_dir else self._GetGraphDef(),
-        nodes_blacklist=["output"],
+        nodes_blacklist=None if input_saved_model_dir else ["output"],
         session_config=self._GetConfigProto(),
         max_batch_size=max_batch_size,
         max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
@@ -193,28 +216,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
         is_dynamic_op=is_dynamic_op,
         maximum_cached_engines=maximum_cached_engines,
         use_function_backup=use_function_backup)
-    conversion_result = converter.convert()
-
-    if context.executing_eagerly():
-      output_graph_def = conversion_result.graph.as_graph_def()
-    else:
-      output_graph_def = conversion_result
+    output_graph_def = converter.convert()

     if need_calibration:

       class CalibrationData(object):

         def __init__(self):
           self._data = 0

         def next(self):
           self._data += 1
           return {"input:0": [[[self._data]]]}

       output_graph_def = converter.calibrate(
           fetch_names=["output:0"],
           num_runs=10,
           feed_dict_fn=CalibrationData().next)

     if output_saved_model_dir is not None:
       converter.save(output_saved_model_dir=output_saved_model_dir)
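
For reference (a sketch, not new in this commit): INT8 calibration on the V1 converter keeps the same shape, with `calibrate()` pulling one feed dict per run from a user callback. Given a `TrtGraphConverter` built with `precision_mode="INT8"` and `use_calibration=True`, assumed here as `converter`:

    class CalibrationData(object):
      """Yields a fresh feed_dict for every calibration run."""

      def __init__(self):
        self._data = 0

      def next(self):
        self._data += 1
        return {"input:0": [[[self._data]]]}

    calibrated_graph_def = converter.calibrate(
        fetch_names=["output:0"],             # tensors to fetch per run
        num_runs=10,                          # calibration iterations
        feed_dict_fn=CalibrationData().next)  # called once per iteration
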
@@ -235,31 +253,19 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
     graph_defs_to_verify = [output_graph_def]

     if output_saved_model_dir:
-      if context.executing_eagerly():
-        root = load.load(output_saved_model_dir)
-        saved_model_graph_def = root.signatures[
-            _SAVED_MODEL_SIGNATURE_KEY].graph.as_graph_def()
-      else:
-        saved_model_graph_def = saved_model_utils.get_meta_graph_def(
-            output_saved_model_dir, tag_constants.SERVING).graph_def
-      self.assertTrue(isinstance(saved_model_graph_def, graph_pb2.GraphDef))
+      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
+          output_saved_model_dir, tag_constants.SERVING).graph_def
+      self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
       graph_defs_to_verify.append(saved_model_graph_def)

     for graph_def in graph_defs_to_verify:
       node_name_to_op = {node.name: node.op for node in graph_def.node}
-      if context.executing_eagerly():
-        # In V2 the actual graph could be inside a function.
-        for func in graph_def.library.function:
-          node_name_to_op.update({node.name: node.op for node in func.node_def})
-        self.assertIn("TRTEngineOp_0", node_name_to_op)
-        self.assertEqual("TRTEngineOp", node_name_to_op["TRTEngineOp_0"])
-      else:
-        self.assertEqual(
-            {
-                "input": "Placeholder",
-                "TRTEngineOp_0": "TRTEngineOp",
-                "output": "Identity"
-            }, node_name_to_op)
+      self.assertEqual(
+          {
+              "input": "Placeholder",
+              "TRTEngineOp_0": "TRTEngineOp",
+              "output": "Identity"
+          }, node_name_to_op)

     if need_calibration:
       trt_engine_nodes = [
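
Side note (illustrative): with the eager-specific branch removed, reading a converted SavedModel back for verification always goes through `saved_model_utils`, roughly:

    from tensorflow.python.saved_model import tag_constants
    from tensorflow.python.tools import saved_model_utils

    # `output_saved_model_dir` is a hypothetical converted-model path.
    graph_def = saved_model_utils.get_meta_graph_def(
        output_saved_model_dir, tag_constants.SERVING).graph_def
    trt_nodes = [n.name for n in graph_def.node if n.op == "TRTEngineOp"]
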
@@ -306,39 +312,81 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
     if not is_tensorrt_enabled():
       return

-    # TODO(laigd): we need to use ops like conv2d so Grappler can infer the
-    # shapes (at least rank) of the tensors, so we're able to build an TRT
-    # engine in dynamic mode. Currently shape information is not propagate from
-    # ConcreteFunction to GraphDef, need to investigate and fix it.
-    class SimpleModel(tracking.AutoTrackable):
-
-      def __init__(self):
-        self.v = None
-
-      @def_function.function(input_signature=[
-          tensor_spec.TensorSpec(shape=[None, 24, 24, 2], dtype=dtypes.float32)
-      ])
-      def run(self, inp):
-        if self.v is None:
-          self.v = variables.Variable([[[[1., 0.5, 4., 6., 0.5, 1.],
-                                         [1., 0.5, 1., 1., 0.5, 1.]]]])
-        conv = gen_nn_ops.conv2d(
-            input=inp, filter=self.v, strides=[1, 2, 2, 1], padding="SAME")
-        identity = array_ops.identity(conv)
-        return identity
-
-    tmp_dir = self.get_temp_dir()
-    input_saved_model_dir = os.path.join(tmp_dir, "in_dir1_v2")
-    root = SimpleModel()
+    np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
+
+    # Create a model and save it.
+    input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+    root = self._GetModelForV2()
+    expected_output = root.run(np_input)
     save.save(root, input_saved_model_dir,
               {_SAVED_MODEL_SIGNATURE_KEY: root.run})

-    # Convert the SavedModel and verify the result.
-    output_saved_model_dir = os.path.join(tmp_dir, "out_dir1_v2")
-    self._TestTrtGraphConverter(
+    # Run TRT conversion.
+    converter = trt_convert.TrtGraphConverterV2(
         input_saved_model_dir=input_saved_model_dir,
-        output_saved_model_dir=output_saved_model_dir,
-        is_dynamic_op=True)
+        input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
+        conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
+            precision_mode=trt_convert.TrtPrecisionMode.FP32,
+            is_dynamic_op=True,
+            maximum_cached_engines=2,
+            use_function_backup=False))
+    converted_concrete_func = converter.convert()
+
+    def _check_trt_ops(graph_def):
+      trt_op_names = [
+          node.name for node in graph_def.node if node.op == "TRTEngineOp"
+      ]
+      for func in graph_def.library.function:
+        for node in func.node_def:
+          if node.op == "TRTEngineOp":
+            trt_op_names.append(node.name)
+      self.assertEqual(1, len(trt_op_names))
+      self.assertIn("TRTEngineOp_0", trt_op_names[0])
+
+    # Verify the converted GraphDef and ConcreteFunction.
+    self.assertIsInstance(converted_concrete_func, function.ConcreteFunction)
+    converted_graph_def = converted_concrete_func.graph.as_graph_def()
+    _check_trt_ops(converted_graph_def)
+    output_with_trt = converted_concrete_func(ops.convert_to_tensor(np_input))
+    self.assertEqual(1, len(output_with_trt))
+    self.assertAllClose(
+        expected_output, output_with_trt[0].numpy(), atol=1e-6, rtol=1e-6)
+
+    # Run the converted ConcreteFunction as a function and make sure it works.
+    @def_function.function
+    def wrapper_func(*args, **kwargs):
+      return nest.flatten(converted_concrete_func(*args, **kwargs))
+
+    _check_trt_ops(
+        wrapper_func.get_concrete_function(
+            tensor_spec.TensorSpec(shape=[None, 1, 1],
+                                   dtype=dtypes.float32)).graph.as_graph_def())
+    output_with_trt = wrapper_func(np_input)
+    self.assertEqual(1, len(output_with_trt))
+    self.assertAllClose(
+        expected_output, output_with_trt[0], atol=1e-6, rtol=1e-6)
+
+    # Save the converted model.
+    output_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+    converter.save(output_saved_model_dir)
+
+    # Load and verify the converted model.
+    #
+    # TODO(laigd): the name of the new input_signature of the
+    # `root_with_trt.run` function is an empty string (originally it was
+    # None), investigate why.
+    root_with_trt = load.load(output_saved_model_dir)
+    # TODO(laigd): `root_with_trt.run` is still using the original graph
+    # without trt. Consider changing that.
+    # _check_trt_ops(
+    #     root_with_trt.run.get_concrete_function().graph.as_graph_def())
+    converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
+    _check_trt_ops(converted_signature.graph.as_graph_def())
+    output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
+    # The output of running the converted signature is a dict, for
+    # compatibility with the V1 SavedModel signature mechanism.
+    output_with_trt = output_with_trt[list(output_with_trt.keys())[0]]
+    self.assertAllClose(expected_output, output_with_trt, atol=1e-6, rtol=1e-6)

   def _TestRun(self,
                sess,
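
Stripped of the assertions, the V2 flow exercised by this test reduces to the sketch below; `input_saved_model_dir`, `output_saved_model_dir`, and `np_input` are stand-ins as above:

    # Minimal TrtGraphConverterV2 workflow introduced by this change.
    converter = trt_convert.TrtGraphConverterV2(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
        conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
            precision_mode=trt_convert.TrtPrecisionMode.FP32,
            is_dynamic_op=True))
    concrete_func = converter.convert()      # a ConcreteFunction in V2
    outputs = concrete_func(ops.convert_to_tensor(np_input))
    converter.save(output_saved_model_dir)   # writes the converted model
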
@@ -363,7 +411,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
     node_name_to_op = {node.name: node.op for node in output_graph_def.node}
     self.assertEqual(
         {
-            "v1/read": "Const",
+            "add/ReadVariableOp": "Const",
             "input": "Placeholder",
             "add": "Add",
             "mul": "Mul",