Refactor TF-TRT API:

- Add a new GraphConverter class to provide general support for offline GraphDef/SavedModel conversion. It also supports post-conversion calibration. This way we can create a backend-specific converter by inheriting from this class and overriding the get_rewriter_config() method to provide backend-specific rewriter options.
- Add a new TrtGraphConverter class (inherited from GraphConverter) for TRT conversion.

PiperOrigin-RevId: 232514321
This commit is contained in:
parent f722aee786
commit cc79252c0b
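For readers of this change, a minimal sketch of how the new offline conversion API is meant to be used, pieced together from the docstrings and the rewritten create_inference_graph() hunk below; the SavedModel paths and parameter values are hypothetical:

```python
# Sketch based on this commit's docstrings; paths and values are hypothetical.
from tensorflow.python.compiler.tensorrt import trt_convert

converter = trt_convert.TrtGraphConverter(
    input_saved_model_dir="/tmp/my_saved_model",  # hypothetical input model
    max_batch_size=8,
    precision_mode="FP16")
converted_graph_def = converter.convert()   # runs Grappler with the TRT rewriter
converter.save("/tmp/my_trt_saved_model")   # optional; SavedModel input only
```

The old create_inference_graph() entry point keeps working: as the corresponding hunk shows, it now simply builds a TrtGraphConverter and delegates to convert() and save().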
@@ -259,7 +259,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
     """Get config proto based on specific settings."""
     conversion_params = self.GetConversionParams(run_params)
     if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
-      rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
+      rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
           conversion_params.rewriter_config, conversion_params.max_batch_size,
           conversion_params.max_workspace_size_bytes,
           conversion_params.precision_mode,
@@ -20,6 +20,7 @@ from __future__ import print_function
 import six as _six

 # pylint: disable=unused-import,line-too-long
+from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
 from tensorflow.python.compiler.tensorrt.wrap_conversion import add_test_value
 from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
 from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
@@ -41,7 +42,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.grappler import tf_optimizer
 from tensorflow.python.platform import tf_logging
 from tensorflow.python.saved_model import builder
-from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import loader
 from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.training import saver

@@ -60,6 +61,228 @@ def _to_string(s):
   return s


+class GraphConverter(object):
+  """Base class for offline converters to optimize SavedModels/GraphDefs.
+
+  A `GraphConverter` object encapsulates the environment to convert (optimize)
+  a TensorFlow SavedModel or GraphDef.
+
+  To create a custom GraphConverter:
+
+  ```python
+  class MyGraphConverter(GraphConverter):
+    ...
+
+    def get_rewriter_config(self, rewriter_config_template=None):
+      my_rewriter_config = ...
+      return my_rewriter_config
+  ```
+
+  Then to run the conversion without quantization calibration:
+
+  ```python
+  my_converter = MyGraphConverter(input_saved_model_dir="my_dir")
+  converted_graph_def = my_converter.convert()
+  my_converter.save(output_saved_model_dir)  # Optional
+  ```
+
+  TODO(laigd): add calibration support.
+  """
+
+  def __init__(self,
+               input_saved_model_dir=None,
+               input_saved_model_tags=None,
+               input_graph_def=None,
+               nodes_blacklist=None,
+               session_config=None):
+    """Initialize the converter.
+
+    Args:
+      input_saved_model_dir: the directory to load the SavedModel which
+        contains the input graph to transform. Used only when input_graph_def
+        is None.
+      input_saved_model_tags: list of tags to load the SavedModel.
+      input_graph_def: a GraphDef object containing a model to be transformed.
+        If set to None, the graph will be read from the SavedModel loaded from
+        input_saved_model_dir.
+      nodes_blacklist: list of node names to prevent the converter from
+        touching. Only used when input_graph_def is not None.
+      session_config: the ConfigProto used to create a Session. It's also used
+        as a template to create a RewriterConfig for conversion. If not
+        specified, a default ConfigProto will be used.
+
+    Raises:
+      ValueError: if the combination of the parameters is invalid.
+    """
+    if input_graph_def and input_saved_model_dir:
+      raise ValueError(
+          "Can only specify one of input_graph_def and input_saved_model_dir")
+    if not input_graph_def and not input_saved_model_dir:
+      raise ValueError("Must specify one of input_graph_def and "
+                       "input_saved_model_dir")
+
+    self._input_graph_def = input_graph_def
+    self._nodes_blacklist = nodes_blacklist
+    self._input_saved_model_dir = input_saved_model_dir
+    self._converted = False
+    self._grappler_meta_graph_def = None
+
+    self._input_saved_model_tags = (
+        input_saved_model_tags or [tag_constants.SERVING])
+    self._session_config = session_config or config_pb2.ConfigProto()
+
+  def get_rewriter_config(self, rewriter_config_template=None):
+    """Returns a RewriterConfig proto for TRT transformation.
+
+    Args:
+      rewriter_config_template: a template RewriterConfig proto used to create
+        a RewriterConfig for the conversion. The implementation should not
+        modify the template. If None, it will use a default one.
+
+    Returns:
+      A RewriterConfig proto which will be used to run the conversion using
+      Grappler.
+    """
+    raise NotImplementedError("get_rewriter_config")
+
+  def _run_conversion(self):
+    """Run Grappler's OptimizeGraph() tool to convert the graph."""
+    # Create custom ConfigProto for Grappler.
+    grappler_session_config = config_pb2.ConfigProto()
+    grappler_session_config.CopyFrom(self._session_config)
+    rewriter_config = None
+    if (grappler_session_config.HasField("graph_options") and
+        grappler_session_config.graph_options.HasField("rewrite_options")):
+      rewriter_config = grappler_session_config.graph_options.rewrite_options
+    custom_rewriter_config = self.get_rewriter_config(rewriter_config)
+    grappler_session_config.graph_options.rewrite_options.CopyFrom(
+        custom_rewriter_config)
+
+    # Run Grappler.
+    self._converted_graph_def = tf_optimizer.OptimizeGraph(
+        grappler_session_config,
+        self._grappler_meta_graph_def,
+        graph_id=b"tf_graph")
+    self._converted = True
+
+  def _convert_graph_def(self):
+    """Convert the input GraphDef."""
+    graph = ops.Graph()
+    with graph.as_default():
+      importer.import_graph_def(self._input_graph_def, name="")
+    self._grappler_meta_graph_def = saver.export_meta_graph(
+        graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
+    if self._nodes_blacklist:
+      output_collection = meta_graph_pb2.CollectionDef()
+      output_list = output_collection.node_list.value
+      for i in self._nodes_blacklist:
+        if isinstance(i, ops.Tensor):
+          output_list.append(_to_bytes(i.name))
+        else:
+          output_list.append(_to_bytes(i))
+      # TODO(laigd): use another key as the self._nodes_blacklist are really
+      # not train_op.
+      self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
+          output_collection)
+
+    self._run_conversion()
+
+  def _convert_saved_model(self):
+    """Convert the input SavedModel."""
+    graph = ops.Graph()
+    with session.Session(graph=graph, config=self._session_config) as sess:
+      input_meta_graph_def = loader.load(sess, self._input_saved_model_tags,
+                                         self._input_saved_model_dir)
+
+      def _gather_names(tensor_info):
+        """Get the node names from a TensorInfo."""
+        return set([tensor_info[key].name.split(":")[0] for key in tensor_info])
+
+      # Get input and outputs from all SignatureDef.
+      output_node_names = set()
+      for key in input_meta_graph_def.signature_def:
+        signature_def = input_meta_graph_def.signature_def[key]
+        output_node_names.update(_gather_names(signature_def.inputs))
+        output_node_names.update(_gather_names(signature_def.outputs))
+
+      # Freeze the variables in the SavedModel graph and copy the frozen
+      # graph over.
+      frozen_graph_def = graph_util.convert_variables_to_constants(
+          sess, sess.graph.as_graph_def(add_shapes=True),
+          list(output_node_names))
+      self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+      self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
+
+      # Copy the collections that are not variables.
+      for key in input_meta_graph_def.collection_def:
+        # TODO(laigd): currently we use the collection key to filter out
+        # collections that depend on variable ops, but this may miss some
+        # other user-defined collections. A better way would be to use
+        # CollectionDef::NodeList for the filtering.
+        if key not in [
+            "variables", "local_variables", "model_variables",
+            "trainable_variables", "train_op", "table_initializer"
+        ]:
+          self._grappler_meta_graph_def.collection_def[key].CopyFrom(
+              input_meta_graph_def.collection_def[key])
+
+      # Copy other information.
+      self._grappler_meta_graph_def.meta_info_def.CopyFrom(
+          input_meta_graph_def.meta_info_def)
+      for key in input_meta_graph_def.signature_def:
+        self._grappler_meta_graph_def.signature_def[key].CopyFrom(
+            input_meta_graph_def.signature_def[key])
+      # TODO(laigd): maybe add back AssetFileDef.
+
+    self._run_conversion()
+
+  def convert(self):
+    """Run the conversion.
+
+    Returns:
+      The converted GraphDef.
+    """
+    assert not self._converted
+
+    if self._input_graph_def:
+      self._convert_graph_def()
+    else:
+      self._convert_saved_model()
+    return self._converted_graph_def
+
+  def save(self, output_saved_model_dir):
+    """Save the converted graph as a SavedModel.
+
+    Args:
+      output_saved_model_dir: construct a SavedModel using the converted
+        GraphDef and save it to the specified directory. This option only
+        works when the input graph is loaded from a SavedModel, i.e. when
+        input_saved_model_dir is specified and input_graph_def is None in
+        __init__().
+
+    Raises:
+      ValueError: if the input to the converter is a GraphDef instead of a
+        SavedModel.
+    """
+    assert self._converted
+
+    if self._input_graph_def:
+      raise ValueError(
+          "Not able to save to a SavedModel since input is a GraphDef")
+
+    # Write the transformed graphdef as SavedModel.
+    saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
+    with ops.Graph().as_default():
+      importer.import_graph_def(self._converted_graph_def, name="")
+      # We don't use any specific converter here.
+      with session.Session(config=self._session_config) as sess:
+        saved_model_builder.add_meta_graph_and_variables(
+            sess,
+            self._input_saved_model_tags,
+            signature_def_map=self._grappler_meta_graph_def.signature_def)
+    # Ignore other meta graphs from the input SavedModel.
+    saved_model_builder.save()
+
+
 class TrtPrecisionMode(object):
   FP32 = "FP32"
   FP16 = "FP16"
@@ -75,8 +298,15 @@ class TrtPrecisionMode(object):
 DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30


+class TrtGraphConverter(GraphConverter):
+  """A GraphConverter for TRT transformation."""
+
+  _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF_TRT_Calibration"
+
+  @classmethod
   def get_tensorrt_rewriter_config(
-      rewriter_config=None,
+      cls,
+      rewriter_config_template=None,
       max_batch_size=1,
       max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
       precision_mode=TrtPrecisionMode.FP32,
@@ -88,36 +318,37 @@ def get_tensorrt_rewriter_config(
     """Returns a RewriterConfig proto for TRT transformation.

     Args:
-      rewriter_config: a template RewriterConfig proto used to create a
+      rewriter_config_template: a template RewriterConfig proto used to create a
         TRT-enabled RewriterConfig. If None, it will use a default one.
       max_batch_size: max size for the input batch
       max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
-        engine can use at execution time. This corresponds to the 'workspaceSize'
+        engine can use at execution time. This corresponds to the
+        'workspaceSize'
         parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
       precision_mode: one of TrtPrecisionMode.supported_precision_modes().
-      minimum_segment_size: the minimum number of nodes required for a subgraph to
-        be replaced by TRTEngineOp.
-      is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
-        network and engine at run time.
-      maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
-        If the number of cached engines is already at max but none of them can
-        serve the input, the TRTEngineOp will fall back to run the TF function
-        based on which the TRTEngineOp is created.
-      cached_engine_batches: a list of batch sizes used to create cached engines,
-        only used when is_dynamic_op is True. The length of the list should be <=
-        maximum_cached_engines, and the dynamic TRT op will use this list to
-        determine the batch sizes of the cached engines, instead of making the
-        decision on the fly. This is useful when we know the most common batch
-        size(s) the application is going to generate.
-      use_calibration: this argument is ignored if precision_mode is not INT8. If
-        set to True, a calibration graph will be created to calibrate the missing
-        ranges. The calibration graph must be converted to an inference graph
-        using calib_graph_to_infer_graph() after running calibration. if set to
-        False, quantization nodes will be expected for every tensor in the graph
-        (exlcuding those which will be fused). If a range is missing, an error
-        will occur. Please note that accuracy may be negatively affected if there
-        is a mismatch between which tensors TRT quantizes and which tensors were
-        trained with fake quantization.
+      minimum_segment_size: the minimum number of nodes required for a subgraph
+        to be replaced by TRTEngineOp.
+      is_dynamic_op: whether to generate dynamic TRT ops which will build the
+        TRT network and engine at run time.
+      maximum_cached_engines: max number of cached TRT engines in dynamic TRT
+        ops. If the number of cached engines is already at max but none of them
+        can serve the input, the TRTEngineOp will fall back to run the TF
+        function based on which the TRTEngineOp is created.
+      cached_engine_batches: a list of batch sizes used to create cached
+        engines, only used when is_dynamic_op is True. The length of the list
+        should be <= maximum_cached_engines, and the dynamic TRT op will use
+        this list to determine the batch sizes of the cached engines, instead
+        of making the decision on the fly. This is useful when we know the most
+        common batch size(s) the application is going to generate.
+      use_calibration: this argument is ignored if precision_mode is not INT8.
+        If set to True, a calibration graph will be created to calibrate the
+        missing ranges. The calibration graph must be converted to an inference
+        graph using calib_graph_to_infer_graph() after running calibration. If
+        set to False, quantization nodes will be expected for every tensor in
+        the graph (excluding those which will be fused). If a range is missing,
+        an error will occur. Please note that accuracy may be negatively
+        affected if there is a mismatch between which tensors TRT quantizes and
+        which tensors were trained with fake quantization.

     Returns:
       A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
@@ -126,12 +357,13 @@ def get_tensorrt_rewriter_config(
       TypeError: if any of the parameters are of unexpected type.
       ValueError: if any of the parameters are of unexpected value.
     """
-  if rewriter_config is not None and not isinstance(
-      rewriter_config, rewriter_config_pb2.RewriterConfig):
-    raise TypeError("rewriter_config should be a RewriterConfig proto.")
+    if rewriter_config_template is not None and not isinstance(
+        rewriter_config_template, rewriter_config_pb2.RewriterConfig):
+      raise TypeError(
+          "rewriter_config_template should be a RewriterConfig proto.")

     rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
-  if rewriter_config is None:
+    if rewriter_config_template is None:
       # Layout optimizer may add Const nodes followed by Reshape nodes, thus we
       # need to run constant folding again.
       rewriter_config_with_trt.optimizers.extend(
@@ -139,13 +371,7 @@ def get_tensorrt_rewriter_config(
       rewriter_config_with_trt.meta_optimizer_iterations = (
           rewriter_config_pb2.RewriterConfig.ONE)
     else:
-    rewriter_config_with_trt.CopyFrom(rewriter_config)
-
-  if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes():
-    raise ValueError(("precision mode '{}' is not supported."
-                      "It should be one of {}").format(
-                          precision_mode,
-                          TrtPrecisionMode.supported_precision_modes))
+      rewriter_config_with_trt.CopyFrom(rewriter_config_template)

     optimizer = rewriter_config_with_trt.custom_optimizers.add()
     optimizer.name = "TensorRTOptimizer"
@@ -156,16 +382,147 @@ def get_tensorrt_rewriter_config(
         "max_workspace_size_bytes"].i = max_workspace_size_bytes
     optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
     optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
+    if cached_engine_batches:
+      optimizer.parameter_map["cached_engine_batches"].list.i.extend(
+          cached_engine_batches)
+    optimizer.parameter_map["use_calibration"].b = use_calibration
+    return rewriter_config_with_trt
+
+  def __init__(self,
+               input_saved_model_dir=None,
+               input_saved_model_tags=None,
+               input_graph_def=None,
+               nodes_blacklist=None,
+               session_config=None,
+               max_batch_size=1,
+               max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
+               precision_mode=TrtPrecisionMode.FP32,
+               minimum_segment_size=3,
+               is_dynamic_op=False,
+               maximum_cached_engines=1,
+               cached_engine_batches=None,
+               use_calibration=True):
+    """Initialize the converter.
+
+    Args:
+      input_saved_model_dir: the directory to load the SavedModel which
+        contains the input graph to transform. Used only when input_graph_def
+        is None.
+      input_saved_model_tags: list of tags to load the SavedModel.
+      input_graph_def: a GraphDef object containing a model to be transformed.
+        If set to None, the graph will be read from the SavedModel loaded from
+        input_saved_model_dir.
+      nodes_blacklist: list of node names to prevent the converter from
+        touching. Only used when input_graph_def is not None.
+      session_config: the ConfigProto used to create a Session. It's also used
+        as a template to create a TRT-enabled ConfigProto for conversion. If
+        not specified, a default ConfigProto will be used.
+      max_batch_size: max size for the input batch.
+      max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+        engine can use at execution time. This corresponds to the
+        'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+      precision_mode: one of TrtPrecisionMode.supported_precision_modes().
+      minimum_segment_size: the minimum number of nodes required for a subgraph
+        to be replaced by TRTEngineOp.
+      is_dynamic_op: whether to generate dynamic TRT ops which will build the
+        TRT network and engine at run time.
+      maximum_cached_engines: max number of cached TRT engines in dynamic TRT
+        ops. If the number of cached engines is already at max but none of them
+        can serve the input, the TRTEngineOp will fall back to run the TF
+        function based on which the TRTEngineOp is created.
+      cached_engine_batches: a list of batch sizes used to create cached
+        engines, only used when is_dynamic_op is True. The length of the list
+        should be <= maximum_cached_engines, and the dynamic TRT op will use
+        this list to determine the batch sizes of the cached engines, instead
+        of making the decision on the fly. This is useful when we know the most
+        common batch size(s) the application is going to generate.
+      use_calibration: this argument is ignored if precision_mode is not INT8.
+        If set to True, a calibration graph will be created to calibrate the
+        missing ranges. The calibration graph must be converted to an inference
+        graph using calib_graph_to_infer_graph() after running calibration. If
+        set to False, quantization nodes will be expected for every tensor in
+        the graph (excluding those which will be fused). If a range is missing,
+        an error will occur. Please note that accuracy may be negatively
+        affected if there is a mismatch between which tensors TRT quantizes and
+        which tensors were trained with fake quantization.
+
+    Raises:
+      ValueError: if the combination of the parameters is invalid.
+      RuntimeError: if the TensorRT library version is incompatible.
+    """
+    super(TrtGraphConverter, self).__init__(
+        input_saved_model_dir=input_saved_model_dir,
+        input_saved_model_tags=input_saved_model_tags,
+        input_graph_def=input_graph_def,
+        nodes_blacklist=nodes_blacklist,
+        session_config=session_config)
+
+    # Check compatibility of TensorRT version.
+    compiled_version = get_linked_tensorrt_version()
+    loaded_version = get_loaded_tensorrt_version()
+    version_mismatch = False
+    if loaded_version[0] < compiled_version[0]:
+      tf_logging.error(
+          "TensorRT version mismatch. Tensorflow was compiled against " +
+          "TensorRT %s but library loaded from environment is TensorRT %s" %
+          (".".join([str(x) for x in compiled_version]),
+           ".".join([str(x) for x in loaded_version])) +
+          ". Please make sure that correct version of TensorRT " +
+          "is available in the system and added to ldconfig or LD_LIBRARY_PATH")
+      raise RuntimeError("Incompatible TensorRT library version")
+    for i in zip(loaded_version, compiled_version):
+      if i[0] != i[1]:
+        tf_logging.warn("TensorRT mismatch. Compiled against version " +
+                        "%s, but loaded %s. Things may not work" %
+                        (".".join([str(x) for x in compiled_version]),
+                         ".".join([str(x) for x in loaded_version])))
+        version_mismatch = True
+        break
+    if not version_mismatch:
+      tf_logging.info("Running against TensorRT version %s" % ".".join(
+          [str(x) for x in loaded_version]))
+
+    # Check input arguments.
+    if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes(
+    ):
+      raise ValueError(("precision mode '{}' is not supported."
+                        "It should be one of {}").format(
+                            precision_mode,
+                            TrtPrecisionMode.supported_precision_modes))
     if cached_engine_batches:
       if not isinstance(cached_engine_batches, list):
         raise TypeError("cached_engine_batches should be a list.")
       if len(cached_engine_batches) > maximum_cached_engines:
         raise ValueError("cached_engine_batches should not contain more than "
                          "maximum_cached_engines items.")
-    optimizer.parameter_map["cached_engine_batches"].list.i.extend(
-        cached_engine_batches)
-  optimizer.parameter_map["use_calibration"].b = use_calibration
-  return rewriter_config_with_trt
+
+    # TODO(laigd):
+    # - Get rid of is_dynamic_op option, it should always be True, and it
+    #   should accept N shapes as input.
+    # - Verify in int8 mode that maximum_cached_engines and
+    #   cached_engine_batches are set appropriately.
+    # - If it fails to build the int8 engine it should return error.
+    self._max_batch_size = max_batch_size
+    self._max_workspace_size_bytes = max_workspace_size_bytes
+    self._precision_mode = precision_mode
+    self._minimum_segment_size = minimum_segment_size
+    self._is_dynamic_op = is_dynamic_op
+    self._maximum_cached_engines = maximum_cached_engines
+    self._cached_engine_batches = cached_engine_batches
+    self._use_calibration = use_calibration
+
+  def get_rewriter_config(self, rewriter_config_template=None):
+    return TrtGraphConverter.get_tensorrt_rewriter_config(
+        rewriter_config_template,
+        max_batch_size=self._max_batch_size,
+        max_workspace_size_bytes=self._max_workspace_size_bytes,
+        precision_mode=self._precision_mode,
+        minimum_segment_size=self._minimum_segment_size,
+        is_dynamic_op=self._is_dynamic_op,
+        maximum_cached_engines=self._maximum_cached_engines,
+        cached_engine_batches=self._cached_engine_batches,
+        use_calibration=self._use_calibration)
+
+
 def create_inference_graph(
@@ -253,144 +610,24 @@ def create_inference_graph(
     ValueError: if the combination of the parameters is invalid.
     RuntimeError: if the TensorRT library version is incompatible.
   """
-  compiled_version = get_linked_tensorrt_version()
-  loaded_version = get_loaded_tensorrt_version()
-  version_mismatch = False
-  if loaded_version[0] < compiled_version[0]:
-    tf_logging.error(
-        "TensorRT version mismatch. Tensorflow was compiled against " +
-        "TensorRT %s but library loaded from environment is TensorRT %s" %
-        (".".join([str(x) for x in compiled_version]),
-         ".".join([str(x) for x in loaded_version])) +
-        ". Please make sure that correct version of TensorRT " +
-        "is available in the system and added to ldconfig or LD_LIBRARY_PATH")
-    raise RuntimeError("Incompatible TensorRT library version")
-  for i in zip(loaded_version, compiled_version):
-    if i[0] != i[1]:
-      tf_logging.warn("TensorRT mismatch. Compiled against version " +
-                      "%s, but loaded %s. Things may not work" %
-                      (".".join([str(x) for x in compiled_version]),
-                       ".".join([str(x) for x in loaded_version])))
-      version_mismatch = True
-      break
-  if not version_mismatch:
-    tf_logging.info("Running against TensorRT version %s" % ".".join(
-        [str(x) for x in loaded_version]))
-
-  if session_config is None:
-    session_config = config_pb2.ConfigProto()
-
-  if input_saved_model_tags is None:
-    input_saved_model_tags = [tag_constants.SERVING]
-  saved_model_loader = None
-  grappler_meta_graph_def = None
-
-  if input_graph_def is None:
-    # Read from SavedModel and freeze the graph if necessary.
-    if input_saved_model_dir is None:
-      raise ValueError("input_graph_def and input_saved_model_dir cannot be "
-                       "both None")
-    with ops.Graph().as_default():
-      with session.Session(config=session_config) as sess:
-        saved_model_loader = loader_impl.SavedModelLoader(input_saved_model_dir)
-        input_meta_graph_def = saved_model_loader.load(sess,
-                                                       input_saved_model_tags)
-        output_node_names = set()
-
-        def _gather_names(tensor_info):
-          """Get the node names from a TensorInfo."""
-          return set(
-              [tensor_info[key].name.split(":")[0] for key in tensor_info])
-
-        # Get input and outputs from all SignatureDef.
-        for key in input_meta_graph_def.signature_def:
-          signature_def = input_meta_graph_def.signature_def[key]
-          output_node_names.update(_gather_names(signature_def.inputs))
-          output_node_names.update(_gather_names(signature_def.outputs))
-
-        # Freeze the variables in the SavedModel graph and copy the frozen
-        # graph over.
-        frozen_graph_def = graph_util.convert_variables_to_constants(
-            sess, sess.graph.as_graph_def(add_shapes=True),
-            list(output_node_names))
-        grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
-        grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
-
-        # Copy the collections that are not variables.
-        for key in input_meta_graph_def.collection_def:
-          # TODO(laigd): currently we use the collection key to filter out
-          # collections that depend on variable ops, but this may miss some
-          # other user-defined collections. A better way would be to use
-          # CollectionDef::NodeList for the filtering.
-          if key not in [
-              "variables", "local_variables", "model_variables",
-              "trainable_variables", "train_op", "table_initializer"
-          ]:
-            grappler_meta_graph_def.collection_def[key].CopyFrom(
-                input_meta_graph_def.collection_def[key])
-
-        # Copy other information.
-        grappler_meta_graph_def.meta_info_def.CopyFrom(
-            input_meta_graph_def.meta_info_def)
-        for key in input_meta_graph_def.signature_def:
-          grappler_meta_graph_def.signature_def[key].CopyFrom(
-              input_meta_graph_def.signature_def[key])
-        # TODO(laigd): maybe add back AssetFileDef.
-  else:
-    if output_saved_model_dir is not None:
-      raise ValueError("output_saved_model_dir cannot be set when "
-                       "input_graph_def is set")
-    # Create MetaGraphDef from input graph.
-    graph = ops.Graph()
-    with graph.as_default():
-      importer.import_graph_def(input_graph_def, name="")
-    grappler_meta_graph_def = saver.export_meta_graph(
-        graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
-    if outputs:
-      output_collection = meta_graph_pb2.CollectionDef()
-      output_list = output_collection.node_list.value
-      for i in outputs:
-        if isinstance(i, ops.Tensor):
-          output_list.append(_to_bytes(i.name))
-        else:
-          output_list.append(_to_bytes(i))
-      # TODO(laigd): use another key as the outputs are really not train_op.
-      grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
-          output_collection)
-
-  # Create TRT-enabled ConfigProto.
-  session_config_with_trt = config_pb2.ConfigProto()
-  session_config_with_trt.CopyFrom(session_config)
-  rewriter_config = None
-  if (session_config_with_trt.HasField("graph_options") and
-      session_config_with_trt.graph_options.HasField("rewrite_options")):
-    rewriter_config = session_config_with_trt.graph_options.rewrite_options
-  rewriter_config_with_trt = get_tensorrt_rewriter_config(
-      rewriter_config, max_batch_size, max_workspace_size_bytes, precision_mode,
-      minimum_segment_size, is_dynamic_op, maximum_cached_engines,
-      cached_engine_batches, use_calibration)
-  session_config_with_trt.graph_options.rewrite_options.CopyFrom(
-      rewriter_config_with_trt)
-
-  # Run Grappler.
-  transformed_graph_def = tf_optimizer.OptimizeGraph(
-      session_config_with_trt, grappler_meta_graph_def, graph_id=b"tf_graph")
-
-  # Optionally write the transformed graphdef as SavedModel.
-  if output_saved_model_dir is not None:
-    saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
-    with ops.Graph().as_default():
-      importer.import_graph_def(transformed_graph_def, name="")
-      # We don't use TRT here.
-      with session.Session(config=session_config) as sess:
-        saved_model_builder.add_meta_graph_and_variables(
-            sess,
-            input_saved_model_tags,
-            signature_def_map=grappler_meta_graph_def.signature_def)
-    # Ignore other meta graphs from the input SavedModel.
-    saved_model_builder.save()
-
-  return transformed_graph_def
+  trt_converter = TrtGraphConverter(
+      input_saved_model_dir=input_saved_model_dir,
+      input_saved_model_tags=input_saved_model_tags,
+      input_graph_def=input_graph_def,
+      nodes_blacklist=outputs,
+      session_config=session_config,
+      max_batch_size=max_batch_size,
+      max_workspace_size_bytes=max_workspace_size_bytes,
+      precision_mode=precision_mode,
+      minimum_segment_size=minimum_segment_size,
+      is_dynamic_op=is_dynamic_op,
+      maximum_cached_engines=maximum_cached_engines,
+      cached_engine_batches=cached_engine_batches,
+      use_calibration=use_calibration)
+  converted_graph_def = trt_converter.convert()
+  if output_saved_model_dir:
+    trt_converter.save(output_saved_model_dir)
+  return converted_graph_def


 def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
@@ -52,9 +52,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
   _TRT_MAX_WORKSPACE_SIZE_BYTES = 2 << 20

   def testGetTensorrtRewriterConfig(self):
-    """Test case for trt_convert.get_tensorrt_rewriter_config()."""
-    rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
-        rewriter_config=None,
+    """Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
+    rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
+        rewriter_config_template=None,
         max_batch_size=128,
         max_workspace_size_bytes=1234,
         precision_mode="INT8",
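As the commit message notes, a backend-specific converter is made by subclassing GraphConverter and overriding get_rewriter_config(). A minimal sketch of such a subclass, relying only on the base-class contract shown in this diff; MyBackendConverter and MyBackendOptimizer are hypothetical names:

```python
# Hypothetical subclass; only the GraphConverter contract comes from this commit.
from tensorflow.core.protobuf import rewriter_config_pb2


class MyBackendConverter(GraphConverter):
  """Plugs a custom Grappler optimizer into the offline conversion flow."""

  def get_rewriter_config(self, rewriter_config_template=None):
    # Per the base-class docstring, the template must not be modified.
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    if rewriter_config_template is not None:
      rewriter_config.CopyFrom(rewriter_config_template)
    optimizer = rewriter_config.custom_optimizers.add()
    optimizer.name = "MyBackendOptimizer"  # hypothetical optimizer name
    return rewriter_config
```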