Refactor TF-TRT API:

- Add a new GraphConverter class to provide general support for offline
  GraphDef/SavedModel conversion. It also supports post-conversion calibration.
  This way we can make a backend-specific converter by inheriting from this
  class and overriding the get_rewriter_config() method to provide
  backend-specific rewriter options.
- Add a new TrtGraphConverter class (inherited from GraphConverter) for TRT
  conversion (see the usage sketch below).

PiperOrigin-RevId: 232514321
Guangda Lai 2019-02-05 10:39:23 -08:00 committed by TensorFlower Gardener
parent f722aee786
commit cc79252c0b
3 changed files with 465 additions and 228 deletions
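A minimal usage sketch of the new offline conversion API (class, method, and argument names are taken from the diff below; the import path follows the `tensorflow.python.compiler.tensorrt` package visible in the imports, and the model directories are placeholders):

```python
from tensorflow.python.compiler.tensorrt import trt_convert

# Convert a SavedModel offline with the new converter class.
converter = trt_convert.TrtGraphConverter(
    input_saved_model_dir="my_input_dir",  # placeholder directory
    precision_mode=trt_convert.TrtPrecisionMode.FP16,
    is_dynamic_op=True)
converted_graph_def = converter.convert()  # runs Grappler's TensorRTOptimizer
converter.save("my_output_dir")  # optional; requires a SavedModel input
```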


@@ -259,7 +259,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Get config proto based on specific settings."""
conversion_params = self.GetConversionParams(run_params)
if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
conversion_params.rewriter_config, conversion_params.max_batch_size,
conversion_params.max_workspace_size_bytes,
conversion_params.precision_mode,

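The hunk above migrates the test utility from the removed module-level `get_tensorrt_rewriter_config` function (the first of the two `rewriter_cfg = ...` lines) to the new classmethod on `TrtGraphConverter` (the second). For any caller the change is mechanical; a sketch showing only the parameters visible in the hunk:

```python
# Before this commit: module-level function.
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
    conversion_params.rewriter_config, conversion_params.max_batch_size,
    conversion_params.max_workspace_size_bytes,
    conversion_params.precision_mode)

# After this commit: classmethod on TrtGraphConverter; the first
# parameter is now named rewriter_config_template.
rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
    conversion_params.rewriter_config, conversion_params.max_batch_size,
    conversion_params.max_workspace_size_bytes,
    conversion_params.precision_mode)
```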

@@ -20,6 +20,7 @@ from __future__ import print_function
import six as _six
# pylint: disable=unused-import,line-too-long
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
from tensorflow.python.compiler.tensorrt.wrap_conversion import add_test_value
from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
@@ -41,7 +42,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver
@@ -60,6 +61,228 @@ def _to_string(s):
return s
class GraphConverter(object):
"""Base class for offline converters to optimize SavedModels/GraphDefs.
A `GraphConverter` object encapsulates the environment to convert (optimize) a
TensorFlow SavedModel or GraphDef.
To create a custom GraphConverter:
```python
class MyGraphConverter(GraphConverter):
...
def get_rewriter_config(self, rewriter_config_template=None):
my_rewriter_config = ...
return my_rewriter_config
```
Then to run the conversion without quantization calibration:
```python
my_converter = MyGraphConverter(input_saved_model_dir="my_dir")
converted_graph_def = my_converter.convert()
my_converter.save(output_saved_model_dir) # Optional
```
TODO(laigd): add calibration support.
"""
def __init__(self,
input_saved_model_dir=None,
input_saved_model_tags=None,
input_graph_def=None,
nodes_blacklist=None,
session_config=None):
"""Initialize the converter.
Args:
input_saved_model_dir: the directory to load the SavedModel which contains
the input graph to transform. Used only when input_graph_def is None.
input_saved_model_tags: list of tags to load the SavedModel.
input_graph_def: a GraphDef object containing a model to be transformed.
If set to None, the graph will be read from the SavedModel loaded from
input_saved_model_dir.
nodes_blacklist: list of node names to prevent the converter from
touching. Only used when input_graph_def is not None.
session_config: the ConfigProto used to create a Session. It's also used
as a template to create a RewriterConfig for conversion. If not
specified, a default ConfigProto will be used.
Raises:
ValueError: if the combination of the parameters is invalid.
"""
if input_graph_def and input_saved_model_dir:
raise ValueError(
"Can only specify one of input_graph_def and input_saved_model_dir")
if not input_graph_def and not input_saved_model_dir:
raise ValueError("Must specify one of input_graph_def and "
"input_saved_model_dir")
self._input_graph_def = input_graph_def
self._nodes_blacklist = nodes_blacklist
self._input_saved_model_dir = input_saved_model_dir
self._converted = False
self._grappler_meta_graph_def = None
self._input_saved_model_tags = (
input_saved_model_tags or [tag_constants.SERVING])
self._session_config = session_config or config_pb2.ConfigProto()
def get_rewriter_config(self, rewriter_config_template=None):
"""Returns a RewriterConfig proto for TRT transformation.
Args:
rewriter_config_template: a template RewriterConfig proto used to create a
RewriterConfig for the conversion. The implementation should not modify
the template. If None, it will use a default one.
Returns:
A RewriterConfig proto which will be used to run the conversion using
Grappler.
"""
raise NotImplementedError("get_rewriter_config")
def _run_conversion(self):
"""Run Grappler's OptimizeGraph() tool to convert the graph."""
# Create custom ConfigProto for Grappler.
grappler_session_config = config_pb2.ConfigProto()
grappler_session_config.CopyFrom(self._session_config)
rewriter_config = None
if (grappler_session_config.HasField("graph_options") and
grappler_session_config.graph_options.HasField("rewrite_options")):
rewriter_config = grappler_session_config.graph_options.rewrite_options
custom_rewriter_config = self.get_rewriter_config(rewriter_config)
grappler_session_config.graph_options.rewrite_options.CopyFrom(
custom_rewriter_config)
# Run Grappler.
self._converted_graph_def = tf_optimizer.OptimizeGraph(
grappler_session_config,
self._grappler_meta_graph_def,
graph_id=b"tf_graph")
self._converted = True
def _convert_graph_def(self):
"""Convert the input GraphDef."""
graph = ops.Graph()
with graph.as_default():
importer.import_graph_def(self._input_graph_def, name="")
self._grappler_meta_graph_def = saver.export_meta_graph(
graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
if self._nodes_blacklist:
output_collection = meta_graph_pb2.CollectionDef()
output_list = output_collection.node_list.value
for i in self._nodes_blacklist:
if isinstance(i, ops.Tensor):
output_list.append(_to_bytes(i.name))
else:
output_list.append(_to_bytes(i))
# TODO(laigd): use another key as the self._nodes_blacklist are really
# not train_op.
self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
output_collection)
self._run_conversion()
def _convert_saved_model(self):
"""Convert the input SavedModel."""
graph = ops.Graph()
with session.Session(graph=graph, config=self._session_config) as sess:
input_meta_graph_def = loader.load(sess, self._input_saved_model_tags,
self._input_saved_model_dir)
def _gather_names(tensor_info):
"""Get the node names from a TensorInfo."""
return set([tensor_info[key].name.split(":")[0] for key in tensor_info])
# Get input and outputs from all SignatureDef.
output_node_names = set()
for key in input_meta_graph_def.signature_def:
signature_def = input_meta_graph_def.signature_def[key]
output_node_names.update(_gather_names(signature_def.inputs))
output_node_names.update(_gather_names(signature_def.outputs))
# Freeze the variables in the SavedModel graph and copy the frozen
# graph over.
frozen_graph_def = graph_util.convert_variables_to_constants(
sess, sess.graph.as_graph_def(add_shapes=True),
list(output_node_names))
self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
# Copy the collections that are not variables.
for key in input_meta_graph_def.collection_def:
# TODO(laigd): currently we use the collection key to filter out
# collections that depend on variable ops, but this may miss some
# other user-defined collections. A better way would be to use
# CollectionDef::NodeList for the filtering.
if key not in [
"variables", "local_variables", "model_variables",
"trainable_variables", "train_op", "table_initializer"
]:
self._grappler_meta_graph_def.collection_def[key].CopyFrom(
input_meta_graph_def.collection_def[key])
# Copy other information.
self._grappler_meta_graph_def.meta_info_def.CopyFrom(
input_meta_graph_def.meta_info_def)
for key in input_meta_graph_def.signature_def:
self._grappler_meta_graph_def.signature_def[key].CopyFrom(
input_meta_graph_def.signature_def[key])
# TODO(laigd): maybe add back AssetFileDef.
self._run_conversion()
def convert(self):
"""Run the conversion.
Returns:
The converted GraphDef.
"""
assert not self._converted
if self._input_graph_def:
self._convert_graph_def()
else:
self._convert_saved_model()
return self._converted_graph_def
def save(self, output_saved_model_dir):
"""Save the converted graph as a SavedModel.
Args:
output_saved_model_dir: the directory to which a SavedModel constructed
from the converted GraphDef will be saved. This option only works when
the input graph is loaded from a SavedModel, i.e. when
input_saved_model_dir is specified and input_graph_def is None in
__init__().
Raises:
ValueError: if the input to the converter is a GraphDef instead of a
SavedModel.
"""
assert self._converted
if self._input_graph_def:
raise ValueError(
"Not able to save to a SavedModel since input is a GraphDef")
# Write the transformed graphdef as SavedModel.
saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
with ops.Graph().as_default():
importer.import_graph_def(self._converted_graph_def, name="")
# We don't use any specific converter here.
with session.Session(config=self._session_config) as sess:
saved_model_builder.add_meta_graph_and_variables(
sess,
self._input_saved_model_tags,
signature_def_map=self._grappler_meta_graph_def.signature_def)
# Ignore other meta graphs from the input SavedModel.
saved_model_builder.save()
class TrtPrecisionMode(object):
FP32 = "FP32"
FP16 = "FP16"
@@ -75,97 +298,231 @@ class TrtPrecisionMode(object):
DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
def get_tensorrt_rewriter_config(
rewriter_config=None,
max_batch_size=1,
max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
precision_mode=TrtPrecisionMode.FP32,
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
cached_engine_batches=None,
use_calibration=True):
"""Returns a RewriterConfig proto for TRT transformation.
class TrtGraphConverter(GraphConverter):
"""A GraphConverter for TRT transformation."""
Args:
rewriter_config: a template RewriterConfig proto used to create a
TRT-enabled RewriterConfig. If None, it will use a default one.
max_batch_size: max size for the input batch
max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
engine can use at execution time. This corresponds to the 'workspaceSize'
parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
precision_mode: one of TrtPrecisionMode.supported_precision_modes().
minimum_segment_size: the minimum number of nodes required for a subgraph to
be replaced by TRTEngineOp.
is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
network and engine at run time.
maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
If the number of cached engines is already at max but none of them can
serve the input, the TRTEngineOp will fall back to run the TF function
based on which the TRTEngineOp is created.
cached_engine_batches: a list of batch sizes used to create cached engines,
only used when is_dynamic_op is True. The length of the list should be <=
maximum_cached_engines, and the dynamic TRT op will use this list to
determine the batch sizes of the cached engines, instead of making the
decision on the fly. This is useful when we know the most common batch
size(s) the application is going to generate.
use_calibration: this argument is ignored if precision_mode is not INT8. If
set to True, a calibration graph will be created to calibrate the missing
ranges. The calibration graph must be converted to an inference graph
using calib_graph_to_infer_graph() after running calibration. If set to
False, quantization nodes will be expected for every tensor in the graph
(excluding those which will be fused). If a range is missing, an error
will occur. Please note that accuracy may be negatively affected if there
is a mismatch between which tensors TRT quantizes and which tensors were
trained with fake quantization.
_TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF_TRT_Calibration"
Returns:
A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
@classmethod
def get_tensorrt_rewriter_config(
cls,
rewriter_config_template=None,
max_batch_size=1,
max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
precision_mode=TrtPrecisionMode.FP32,
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
cached_engine_batches=None,
use_calibration=True):
"""Returns a RewriterConfig proto for TRT transformation.
Raises:
TypeError: if any of the parameters are of unexpected type.
ValueError: if any of the parameters are of unexpected value.
"""
if rewriter_config is not None and not isinstance(
rewriter_config, rewriter_config_pb2.RewriterConfig):
raise TypeError("rewriter_config should be a RewriterConfig proto.")
Args:
rewriter_config_template: a template RewriterConfig proto used to create a
TRT-enabled RewriterConfig. If None, it will use a default one.
max_batch_size: max size for the input batch
max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
engine can use at execution time. This corresponds to the
'workspaceSize'
parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
precision_mode: one of TrtPrecisionMode.supported_precision_modes().
minimum_segment_size: the minimum number of nodes required for a subgraph
to be replaced by TRTEngineOp.
is_dynamic_op: whether to generate dynamic TRT ops which will build the
TRT network and engine at run time.
maximum_cached_engines: max number of cached TRT engines in dynamic TRT
ops. If the number of cached engines is already at max but none of them
can serve the input, the TRTEngineOp will fall back to run the TF
function based on which the TRTEngineOp is created.
cached_engine_batches: a list of batch sizes used to create cached
engines, only used when is_dynamic_op is True. The length of the list
should be <= maximum_cached_engines, and the dynamic TRT op will use
this list to determine the batch sizes of the cached engines, instead of
making the decision on the fly. This is useful when we know the most
common batch size(s) the application is going to generate.
use_calibration: this argument is ignored if precision_mode is not INT8.
If set to True, a calibration graph will be created to calibrate the
missing ranges. The calibration graph must be converted to an inference
graph using calib_graph_to_infer_graph() after running calibration. If
set to False, quantization nodes will be expected for every tensor in
the graph (excluding those which will be fused). If a range is missing,
an error will occur. Please note that accuracy may be negatively
affected if there is a mismatch between which tensors TRT quantizes and
which tensors were trained with fake quantization.
rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
if rewriter_config is None:
# Layout optimizer may add Const nodes followed by Reshape nodes, thus we
# need to run constant folding again.
rewriter_config_with_trt.optimizers.extend(
["constfold", "layout", "constfold"])
rewriter_config_with_trt.meta_optimizer_iterations = (
rewriter_config_pb2.RewriterConfig.ONE)
else:
rewriter_config_with_trt.CopyFrom(rewriter_config)
Returns:
A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes():
raise ValueError(("precision mode '{}' is not supported."
"It should be one of {}").format(
precision_mode,
TrtPrecisionMode.supported_precision_modes))
Raises:
TypeError: if any of the parameters are of unexpected type.
ValueError: if any of the parameters are of unexpected value.
"""
if rewriter_config_template is not None and not isinstance(
rewriter_config_template, rewriter_config_pb2.RewriterConfig):
raise TypeError(
"rewriter_config_template should be a RewriterConfig proto.")
optimizer = rewriter_config_with_trt.custom_optimizers.add()
optimizer.name = "TensorRTOptimizer"
optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
optimizer.parameter_map["max_batch_size"].i = max_batch_size
optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
optimizer.parameter_map[
"max_workspace_size_bytes"].i = max_workspace_size_bytes
optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
if cached_engine_batches:
if not isinstance(cached_engine_batches, list):
raise TypeError("cached_engine_batches should be a list.")
if len(cached_engine_batches) > maximum_cached_engines:
raise ValueError("cached_engine_batches should not contain more than "
"maximum_cached_engines items.")
optimizer.parameter_map["cached_engine_batches"].list.i.extend(
cached_engine_batches)
optimizer.parameter_map["use_calibration"].b = use_calibration
return rewriter_config_with_trt
rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
if rewriter_config_template is None:
# Layout optimizer may add Const nodes followed by Reshape nodes, thus we
# need to run constant folding again.
rewriter_config_with_trt.optimizers.extend(
["constfold", "layout", "constfold"])
rewriter_config_with_trt.meta_optimizer_iterations = (
rewriter_config_pb2.RewriterConfig.ONE)
else:
rewriter_config_with_trt.CopyFrom(rewriter_config_template)
optimizer = rewriter_config_with_trt.custom_optimizers.add()
optimizer.name = "TensorRTOptimizer"
optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
optimizer.parameter_map["max_batch_size"].i = max_batch_size
optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
optimizer.parameter_map[
"max_workspace_size_bytes"].i = max_workspace_size_bytes
optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
if cached_engine_batches:
optimizer.parameter_map["cached_engine_batches"].list.i.extend(
cached_engine_batches)
optimizer.parameter_map["use_calibration"].b = use_calibration
return rewriter_config_with_trt
def __init__(self,
input_saved_model_dir=None,
input_saved_model_tags=None,
input_graph_def=None,
nodes_blacklist=None,
session_config=None,
max_batch_size=1,
max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
precision_mode=TrtPrecisionMode.FP32,
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
cached_engine_batches=None,
use_calibration=True):
"""Initialize the converter.
Args:
input_saved_model_dir: the directory to load the SavedModel which contains
the input graph to transform. Used only when input_graph_def is None.
input_saved_model_tags: list of tags to load the SavedModel.
input_graph_def: a GraphDef object containing a model to be transformed.
If set to None, the graph will be read from the SavedModel loaded from
input_saved_model_dir.
nodes_blacklist: list of node names to prevent the converter from
touching. Only used when input_graph_def is not None.
session_config: the ConfigProto used to create a Session. It's also used
as a template to create a TRT-enabled ConfigProto for conversion. If not
specified, a default ConfigProto will be used.
max_batch_size: max size for the input batch.
max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
engine can use at execution time. This corresponds to the
'workspaceSize'
parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
precision_mode: one of TrtPrecisionMode.supported_precision_modes().
minimum_segment_size: the minimum number of nodes required for a subgraph
to be replaced by TRTEngineOp.
is_dynamic_op: whether to generate dynamic TRT ops which will build the
TRT network and engine at run time.
maximum_cached_engines: max number of cached TRT engines in dynamic TRT
ops. If the number of cached engines is already at max but none of them
can serve the input, the TRTEngineOp will fall back to run the TF
function based on which the TRTEngineOp is created.
cached_engine_batches: a list of batch sizes used to create cached
engines, only used when is_dynamic_op is True. The length of the list
should be <= maximum_cached_engines, and the dynamic TRT op will use
this list to determine the batch sizes of the cached engines, instead of
making the decision on the fly. This is useful when we know the most
common batch size(s) the application is going to generate.
use_calibration: this argument is ignored if precision_mode is not INT8.
If set to True, a calibration graph will be created to calibrate the
missing ranges. The calibration graph must be converted to an inference
graph using calib_graph_to_infer_graph() after running calibration. If
set to False, quantization nodes will be expected for every tensor in
the graph (excluding those which will be fused). If a range is missing,
an error will occur. Please note that accuracy may be negatively
affected if there is a mismatch between which tensors TRT quantizes and
which tensors were trained with fake quantization.
Raises:
ValueError: if the combination of the parameters is invalid.
RuntimeError: if the TensorRT library version is incompatible.
"""
super(TrtGraphConverter, self).__init__(
input_saved_model_dir=input_saved_model_dir,
input_saved_model_tags=input_saved_model_tags,
input_graph_def=input_graph_def,
nodes_blacklist=nodes_blacklist,
session_config=session_config)
# Check compatibility of TensorRT version.
compiled_version = get_linked_tensorrt_version()
loaded_version = get_loaded_tensorrt_version()
version_mismatch = False
if loaded_version[0] < compiled_version[0]:
tf_logging.error(
"TensorRT version mismatch. Tensorflow was compiled against " +
"TensorRT %s but library loaded from environment is TensorRT %s" %
(".".join([str(x) for x in compiled_version]),
".".join([str(x) for x in loaded_version])) +
". Please make sure that the correct version of TensorRT " +
"is available in the system and added to ldconfig or LD_LIBRARY_PATH")
raise RuntimeError("Incompatible TensorRT library version")
for i in zip(loaded_version, compiled_version):
if i[0] != i[1]:
tf_logging.warn("TensorRT mismatch. Compiled against version " +
"%s, but loaded %s. Things may not work" %
(".".join([str(x) for x in compiled_version]),
".".join([str(x) for x in loaded_version])))
version_mismatch = True
break
if not version_mismatch:
tf_logging.info("Running against TensorRT version %s" % ".".join(
[str(x) for x in loaded_version]))
# Check input arguments.
if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes(
):
raise ValueError(("precision mode '{}' is not supported."
"It should be one of {}").format(
precision_mode,
TrtPrecisionMode.supported_precision_modes))
if cached_engine_batches:
if not isinstance(cached_engine_batches, list):
raise TypeError("cached_engine_batches should be a list.")
if len(cached_engine_batches) > maximum_cached_engines:
raise ValueError("cached_engine_batches should not contain more than "
"maximum_cached_engines items.")
# TODO(laigd):
# - Get rid of is_dynamic_op option, it should always be True, and it should
# accept N shapes as input.
# - Verify in int8 mode that maximum_cached_engines and
# cached_engine_batches are set appropriately.
# - If it fails to build the int8 engine it should return error.
self._max_batch_size = max_batch_size
self._max_workspace_size_bytes = max_workspace_size_bytes
self._precision_mode = precision_mode
self._minimum_segment_size = minimum_segment_size
self._is_dynamic_op = is_dynamic_op
self._maximum_cached_engines = maximum_cached_engines
self._cached_engine_batches = cached_engine_batches
self._use_calibration = use_calibration
def get_rewriter_config(self, rewriter_config_template=None):
return TrtGraphConverter.get_tensorrt_rewriter_config(
rewriter_config_template,
max_batch_size=self._max_batch_size,
max_workspace_size_bytes=self._max_workspace_size_bytes,
precision_mode=self._precision_mode,
minimum_segment_size=self._minimum_segment_size,
is_dynamic_op=self._is_dynamic_op,
maximum_cached_engines=self._maximum_cached_engines,
cached_engine_batches=self._cached_engine_batches,
use_calibration=self._use_calibration)
def create_inference_graph(
@@ -253,144 +610,24 @@ def create_inference_graph(
ValueError: if the combination of the parameters is invalid.
RuntimeError: if the TensorRT library version is incompatible.
"""
compiled_version = get_linked_tensorrt_version()
loaded_version = get_loaded_tensorrt_version()
version_mismatch = False
if loaded_version[0] < compiled_version[0]:
tf_logging.error(
"TensorRT version mismatch. Tensorflow was compiled against " +
"TensorRT %s but library loaded from environment is TensorRT %s" %
(".".join([str(x) for x in compiled_version]),
".".join([str(x) for x in loaded_version])) +
". Please make sure that the correct version of TensorRT " +
"is available in the system and added to ldconfig or LD_LIBRARY_PATH")
raise RuntimeError("Incompatible TensorRT library version")
for i in zip(loaded_version, compiled_version):
if i[0] != i[1]:
tf_logging.warn("TensorRT mismatch. Compiled against version " +
"%s, but loaded %s. Things may not work" %
(".".join([str(x) for x in compiled_version]),
".".join([str(x) for x in loaded_version])))
version_mismatch = True
break
if not version_mismatch:
tf_logging.info("Running against TensorRT version %s" % ".".join(
[str(x) for x in loaded_version]))
if session_config is None:
session_config = config_pb2.ConfigProto()
if input_saved_model_tags is None:
input_saved_model_tags = [tag_constants.SERVING]
saved_model_loader = None
grappler_meta_graph_def = None
if input_graph_def is None:
# Read from SavedModel and freeze the graph if necessary.
if input_saved_model_dir is None:
raise ValueError("input_graph_def and input_saved_model_dir cannot be "
"both None")
with ops.Graph().as_default():
with session.Session(config=session_config) as sess:
saved_model_loader = loader_impl.SavedModelLoader(input_saved_model_dir)
input_meta_graph_def = saved_model_loader.load(sess,
input_saved_model_tags)
output_node_names = set()
def _gather_names(tensor_info):
"""Get the node names from a TensorInfo."""
return set(
[tensor_info[key].name.split(":")[0] for key in tensor_info])
# Get input and outputs from all SignatureDef.
for key in input_meta_graph_def.signature_def:
signature_def = input_meta_graph_def.signature_def[key]
output_node_names.update(_gather_names(signature_def.inputs))
output_node_names.update(_gather_names(signature_def.outputs))
# Freeze the variables in the SavedModel graph and copy the frozen
# graph over.
frozen_graph_def = graph_util.convert_variables_to_constants(
sess, sess.graph.as_graph_def(add_shapes=True),
list(output_node_names))
grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
# Copy the collections that are not variables.
for key in input_meta_graph_def.collection_def:
# TODO(laigd): currently we use the collection key to filter out
# collections that depend on variable ops, but this may miss some
# other user-defined collections. A better way would be to use
# CollectionDef::NodeList for the filtering.
if key not in [
"variables", "local_variables", "model_variables",
"trainable_variables", "train_op", "table_initializer"
]:
grappler_meta_graph_def.collection_def[key].CopyFrom(
input_meta_graph_def.collection_def[key])
# Copy other information.
grappler_meta_graph_def.meta_info_def.CopyFrom(
input_meta_graph_def.meta_info_def)
for key in input_meta_graph_def.signature_def:
grappler_meta_graph_def.signature_def[key].CopyFrom(
input_meta_graph_def.signature_def[key])
# TODO(laigd): maybe add back AssetFileDef.
else:
if output_saved_model_dir is not None:
raise ValueError("output_saved_model_dir cannot be set when "
"input_graph_def is set")
# Create MetaGraphDef from input graph.
graph = ops.Graph()
with graph.as_default():
importer.import_graph_def(input_graph_def, name="")
grappler_meta_graph_def = saver.export_meta_graph(
graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
if outputs:
output_collection = meta_graph_pb2.CollectionDef()
output_list = output_collection.node_list.value
for i in outputs:
if isinstance(i, ops.Tensor):
output_list.append(_to_bytes(i.name))
else:
output_list.append(_to_bytes(i))
# TODO(laigd): use another key as the outputs are really not train_op.
grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
output_collection)
# Create TRT-enabled ConfigProto.
session_config_with_trt = config_pb2.ConfigProto()
session_config_with_trt.CopyFrom(session_config)
rewriter_config = None
if (session_config_with_trt.HasField("graph_options") and
session_config_with_trt.graph_options.HasField("rewrite_options")):
rewriter_config = session_config_with_trt.graph_options.rewrite_options
rewriter_config_with_trt = get_tensorrt_rewriter_config(
rewriter_config, max_batch_size, max_workspace_size_bytes, precision_mode,
minimum_segment_size, is_dynamic_op, maximum_cached_engines,
cached_engine_batches, use_calibration)
session_config_with_trt.graph_options.rewrite_options.CopyFrom(
rewriter_config_with_trt)
# Run Grappler.
transformed_graph_def = tf_optimizer.OptimizeGraph(
session_config_with_trt, grappler_meta_graph_def, graph_id=b"tf_graph")
# Optionally write the transformed graphdef as SavedModel.
if output_saved_model_dir is not None:
saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
with ops.Graph().as_default():
importer.import_graph_def(transformed_graph_def, name="")
# We don't use TRT here.
with session.Session(config=session_config) as sess:
saved_model_builder.add_meta_graph_and_variables(
sess,
input_saved_model_tags,
signature_def_map=grappler_meta_graph_def.signature_def)
# Ignore other meta graphs from the input SavedModel.
saved_model_builder.save()
return transformed_graph_def
trt_converter = TrtGraphConverter(
input_saved_model_dir=input_saved_model_dir,
input_saved_model_tags=input_saved_model_tags,
input_graph_def=input_graph_def,
nodes_blacklist=outputs,
session_config=session_config,
max_batch_size=max_batch_size,
max_workspace_size_bytes=max_workspace_size_bytes,
precision_mode=precision_mode,
minimum_segment_size=minimum_segment_size,
is_dynamic_op=is_dynamic_op,
maximum_cached_engines=maximum_cached_engines,
cached_engine_batches=cached_engine_batches,
use_calibration=use_calibration)
converted_graph_def = trt_converter.convert()
if output_saved_model_dir:
trt_converter.save(output_saved_model_dir)
return converted_graph_def
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):


@@ -52,9 +52,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
_TRT_MAX_WORKSPACE_SIZE_BYTES = 2 << 20
def testGetTensorrtRewriterConfig(self):
"""Test case for trt_convert.get_tensorrt_rewriter_config()."""
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
rewriter_config=None,
"""Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
rewriter_config_template=None,
max_batch_size=128,
max_workspace_size_bytes=1234,
precision_mode="INT8",