Introduce TrtGraphConverterV2 for TF-TRT conversion in TF 2.0, and extend the corresponding V2 unit test.

PiperOrigin-RevId: 244267523
Guangda Lai 2019-04-18 15:09:57 -07:00 committed by TensorFlower Gardener
parent fb49f672cb
commit 96072813ec
7 changed files with 544 additions and 414 deletions


@@ -144,7 +144,7 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
         ).GetConversionParams(run_params)._replace(
             # Disable layout optimizer, since it'll add Transpose(Const, Const) to
             # the graph and breaks the conversion check.
-            rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
+            rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())
 class SimpleMultiEnginesTest2(trt_test.TfTrtIntegrationTestBase):


@@ -124,7 +124,7 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
         maximum_cached_engines=1,
         # Disable layout optimizer, since it will convert BiasAdd with NHWC
         # format to NCHW format under four dimentional input.
-        rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
+        rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())
   def ExpectedEnginesToBuild(self, run_params):
     """Return the expected engines to build."""


@@ -85,7 +85,7 @@ class DynamicInputShapesTest(trt_test.TfTrtIntegrationTestBase):
         maximum_cached_engines=10,
         # Disable layout optimizer, since it will convert BiasAdd with NHWC
         # format to NCHW format under four dimentional input.
-        rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
+        rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())
   def ExpectedEnginesToBuild(self, run_params):
     return ["TRTEngineOp_0"]


@@ -65,7 +65,7 @@ class ExcludeUnsupportedInt32Test(trt_test.TfTrtIntegrationTestBase):
         maximum_cached_engines=1,
         # Disable layout optimizer, since it will convert BiasAdd with NHWC
         # format to NCHW format under four dimentional input.
-        rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
+        rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())
   def ExpectedEnginesToBuild(self, run_params):
     """Return the expected engines to build."""


@@ -56,12 +56,6 @@ RunParams = namedtuple("RunParams", [
     "use_calibration"
 ])
-ConversionParams = namedtuple("ConversionParams", [
-    "max_batch_size", "max_workspace_size_bytes", "precision_mode",
-    "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
-    "cached_engine_batches", "rewriter_config", "use_calibration"
-])
 PRECISION_MODES = ["FP32", "FP16", "INT8"]
@@ -163,7 +157,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
     raise NotImplementedError()
   def GetConversionParams(self, run_params):
-    """Return a ConversionParams for test."""
+    """Return a TrtConversionParams for test."""
     batch_list = []
     for dims_list in self._GetParamsCached().input_dims:
       assert dims_list
@@ -171,19 +165,22 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
       input_batches = [dims[0] for dims in dims_list]
       assert max(input_batches) == min(input_batches)
       batch_list.append(input_batches[0])
-    return ConversionParams(
+    conversion_params = trt_convert.TrtConversionParams(
         # We use the minimum of all the batch sizes, so when multiple different
         # input shapes are provided it'll always create new engines in the
         # cache, and we can therefore test the cache behavior.
-        max_batch_size=min(batch_list),
+        rewriter_config_template=None,
         max_workspace_size_bytes=1 << 25,
         precision_mode=run_params.precision_mode,
         minimum_segment_size=2,
         is_dynamic_op=run_params.dynamic_engine,
         maximum_cached_engines=1,
-        cached_engine_batches=None,
-        rewriter_config=None,
-        use_calibration=run_params.use_calibration)
+        use_calibration=run_params.use_calibration,
+        use_function_backup=False,
+        max_batch_size=min(batch_list),
+        cached_engine_batches=None)
+    return conversion_params._replace(
+        use_function_backup=IsQuantizationWithCalibration(conversion_params))
   def ShouldRunTest(self, run_params):
     """Whether to run the test."""
@@ -218,24 +215,13 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
     """Get config proto based on specific settings."""
     conversion_params = self.GetConversionParams(run_params)
     if graph_state == GraphState.INFERENCE and run_params.use_optimizer:
-      rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
-          conversion_params.rewriter_config,
-          conversion_params.max_batch_size,
-          conversion_params.max_workspace_size_bytes,
-          conversion_params.precision_mode,
-          conversion_params.minimum_segment_size,
-          conversion_params.is_dynamic_op,
-          conversion_params.maximum_cached_engines,
-          conversion_params.cached_engine_batches,
-          conversion_params.use_calibration,
-          use_function_backup=IsQuantizationWithCalibration(conversion_params))
+      rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(conversion_params)
       graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
     else:
       graph_options = config_pb2.GraphOptions()
-      if conversion_params.rewriter_config is not None:
+      if conversion_params.rewriter_config_template is not None:
         graph_options.rewrite_options.CopyFrom(
-            conversion_params.rewriter_config)
+            conversion_params.rewriter_config_template)
     config = config_pb2.ConfigProto(
         gpu_options=self._GetGPUOptions(), graph_options=graph_options)
@@ -310,7 +296,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
         maximum_cached_engines=conversion_params.maximum_cached_engines,
         cached_engine_batches=conversion_params.cached_engine_batches,
         use_calibration=conversion_params.use_calibration,
-        use_function_backup=IsQuantizationWithCalibration(conversion_params))
+        use_function_backup=conversion_params.use_function_backup)
     return converter
   def _GetCalibratedInferGraph(self, run_params, gdef, inputs_data):


@ -18,7 +18,10 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections
import six as _six import six as _six
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_linked_tensorrt_version from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_linked_tensorrt_version
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_loaded_tensorrt_version from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_loaded_tensorrt_version
@ -80,7 +83,7 @@ class GraphConverter(object):
class MyGraphConverter(GraphConverter): class MyGraphConverter(GraphConverter):
... ...
def get_rewriter_config(self, rewriter_config_template=None): def get_rewriter_config(self):
my_rewriter_config = ... my_rewriter_config = ...
return my_rewriter_config return my_rewriter_config
``` ```
@ -129,7 +132,7 @@ class GraphConverter(object):
If set to None, the graph will be read from the SavedModel loaded from If set to None, the graph will be read from the SavedModel loaded from
input_saved_model_dir. input_saved_model_dir.
nodes_blacklist: list of node names to prevent the converter from nodes_blacklist: list of node names to prevent the converter from
touching. Only used when input_graph_def is not None. touching.
session_config: the ConfigProto used to create a Session. It's also used session_config: the ConfigProto used to create a Session. It's also used
as a template to create a RewriterConfig for conversion. If not as a template to create a RewriterConfig for conversion. If not
specified, a default ConfigProto will be used. specified, a default ConfigProto will be used.
@ -137,21 +140,15 @@ class GraphConverter(object):
Raises: Raises:
ValueError: if the combination of the parameters is invalid. ValueError: if the combination of the parameters is invalid.
""" """
if context.executing_eagerly(): if input_graph_def and input_saved_model_dir:
if input_graph_def or not input_saved_model_dir: raise ValueError(
raise ValueError( "Can only specify one of input_graph_def and input_saved_model_dir")
"TF 2.0 only supports conversion of SavedModel, please specify " if not input_graph_def and not input_saved_model_dir:
"input_saved_model_dir as input.") raise ValueError("Must specify one of input_graph_def and "
else: "input_saved_model_dir")
if input_graph_def and input_saved_model_dir:
raise ValueError(
"Can only specify one of input_graph_def and input_saved_model_dir")
if not input_graph_def and not input_saved_model_dir:
raise ValueError("Must specify one of input_graph_def and "
"input_saved_model_dir")
self._input_graph_def = input_graph_def self._input_graph_def = input_graph_def
self._nodes_blacklist = nodes_blacklist self._nodes_blacklist = nodes_blacklist
self._input_saved_model_dir = input_saved_model_dir self._input_saved_model_dir = input_saved_model_dir
self._converted = False self._converted = False
@ -169,14 +166,9 @@ class GraphConverter(object):
self._calibration_sess = None self._calibration_sess = None
self._calibration_data_collected = False self._calibration_data_collected = False
def get_rewriter_config(self, rewriter_config_template=None): def get_rewriter_config(self):
"""Returns a RewriterConfig proto for TRT transformation. """Returns a RewriterConfig proto for TRT transformation.
Args:
rewriter_config_template: a template RewriterConfig proto used to create a
RewriterConfig for the conversion. The implementation should not modify
the template. If None, it will use a default one.
Returns: Returns:
A RewriterConfig proto which will be used to run the conversion using A RewriterConfig proto which will be used to run the conversion using
Grappler. Grappler.
@ -188,11 +180,7 @@ class GraphConverter(object):
# Create custom ConfigProto for Grappler. # Create custom ConfigProto for Grappler.
grappler_session_config = config_pb2.ConfigProto() grappler_session_config = config_pb2.ConfigProto()
grappler_session_config.CopyFrom(self._session_config) grappler_session_config.CopyFrom(self._session_config)
rewriter_config = None custom_rewriter_config = self.get_rewriter_config()
if (grappler_session_config.HasField("graph_options") and
grappler_session_config.graph_options.HasField("rewrite_options")):
rewriter_config = grappler_session_config.graph_options.rewrite_options
custom_rewriter_config = self.get_rewriter_config(rewriter_config)
grappler_session_config.graph_options.rewrite_options.CopyFrom( grappler_session_config.graph_options.rewrite_options.CopyFrom(
custom_rewriter_config) custom_rewriter_config)
@ -285,33 +273,6 @@ class GraphConverter(object):
self._run_conversion() self._run_conversion()
# TODO(laigd): provide a utility function to optimize a ConcreteFunction and
# use it here (b/124792963).
def _convert_saved_model_v2(self):
"""Convert the input SavedModel in 2.0 format."""
assert context.executing_eagerly()
self._saved_model = load.load(self._input_saved_model_dir,
self._input_saved_model_tags)
func = self._saved_model.signatures[self._input_saved_model_signature_key]
frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
self._grappler_meta_graph_def = saver.export_meta_graph(
graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)
# Add a collection 'train_op' so that Grappler knows the outputs.
fetch_collection = meta_graph_pb2.CollectionDef()
for array in frozen_func.inputs + frozen_func.outputs:
fetch_collection.node_list.value.append(array.name)
self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
fetch_collection)
# Run TRT optimizer in Grappler to convert the graph.
self._run_conversion()
self._converted_func = wrap_function.function_from_graph_def(
self._converted_graph_def,
[tensor.name for tensor in frozen_func.inputs],
[tensor.name for tensor in frozen_func.outputs])
def convert(self): def convert(self):
"""Run the conversion. """Run the conversion.
@ -320,16 +281,11 @@ class GraphConverter(object):
2.0+. 2.0+.
""" """
assert not self._converted assert not self._converted
if self._input_graph_def:
if context.executing_eagerly(): self._convert_graph_def()
self._convert_saved_model_v2()
return self._converted_func
else: else:
if self._input_graph_def: self._convert_saved_model()
self._convert_graph_def() return self._converted_graph_def
else:
self._convert_saved_model()
return self._converted_graph_def
def calibrate(self, def calibrate(self,
fetch_names, fetch_names,
@ -408,80 +364,71 @@ class GraphConverter(object):
SavedModel. SavedModel.
""" """
assert self._converted assert self._converted
if self._input_graph_def:
raise ValueError(
"Not able to save to a SavedModel since input is a GraphDef")
if context.executing_eagerly(): def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
# Rewrite the signature map using the optimized ConcreteFunction. """Restores collections that we need to keep."""
signatures = { scope = ""
key: value for key, value in self._saved_model.signatures.items() for key in collection_keys:
} collection_def = src_meta_graph_def.collection_def[key]
signatures[self._input_saved_model_signature_key] = self._converted_func kind = collection_def.WhichOneof("kind")
save.save(self._saved_model, output_saved_model_dir, signatures) if kind is None:
else: tf_logging.error(
if self._input_graph_def: "Cannot identify data type for collection %s. Skipping.", key)
raise ValueError( continue
"Not able to save to a SavedModel since input is a GraphDef") from_proto = ops.get_from_proto_function(key)
if from_proto and kind == "bytes_list":
def _restore_collections(dest_graph, src_meta_graph_def, collections): proto_type = ops.get_collection_proto_type(key)
"""Restores collections that we need to keep.""" # It is assumed that there are no Variables Keys in collections
scope = "" for value in collection_def.bytes_list.value:
for key in collections: proto = proto_type()
collection_def = src_meta_graph_def.collection_def[key] proto.ParseFromString(value)
kind = collection_def.WhichOneof("kind") try:
if kind is None: new_value = from_proto(proto, import_scope=scope)
tf_logging.error( except:
"Cannot identify data type for collection %s. Skipping.", key) continue
continue dest_graph.add_to_collection(key, new_value)
from_proto = ops.get_from_proto_function(key) else:
if from_proto and kind == "bytes_list": field = getattr(collection_def, kind)
proto_type = ops.get_collection_proto_type(key) if kind == "node_list":
# It is assumed that there are no Variables Keys in collections for value in field.value:
for value in collection_def.bytes_list.value: name = ops.prepend_name_scope(value, scope)
proto = proto_type() # Since the graph has been optimized, the node may no longer
proto.ParseFromString(value) # exists
try: try:
new_value = from_proto(proto, import_scope=scope) col_op = dest_graph.as_graph_element(name)
except: except (TypeError, ValueError, KeyError) as e:
continue continue
dest_graph.add_to_collection(key, new_value) dest_graph.add_to_collection(key, col_op)
elif kind == "int64_list":
# NOTE(opensource): This force conversion is to work around the
# fact that Python2 distinguishes between int and long, while
# Python3 has only int.
for value in field.value:
dest_graph.add_to_collection(key, int(value))
else: else:
field = getattr(collection_def, kind) for value in field.value:
if kind == "node_list": dest_graph.add_to_collection(key,
for value in field.value: ops.prepend_name_scope(value, scope))
name = ops.prepend_name_scope(value, scope)
# Since the graph has been optimized, the node may no longer
# exists
try:
col_op = dest_graph.as_graph_element(name)
except (TypeError, ValueError, KeyError) as e:
continue
dest_graph.add_to_collection(key, col_op)
elif kind == "int64_list":
# NOTE(opensource): This force conversion is to work around the
# fact that Python2 distinguishes between int and long, while
# Python3 has only int.
for value in field.value:
dest_graph.add_to_collection(key, int(value))
else:
for value in field.value:
dest_graph.add_to_collection(
key, ops.prepend_name_scope(value, scope))
# Write the transformed graphdef as SavedModel. # Write the transformed graphdef as SavedModel.
saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir) saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
with ops.Graph().as_default(): with ops.Graph().as_default():
importer.import_graph_def(self._converted_graph_def, name="") importer.import_graph_def(self._converted_graph_def, name="")
_restore_collections( _restore_collections(
ops.get_default_graph(), self._grappler_meta_graph_def, ops.get_default_graph(), self._grappler_meta_graph_def,
self._collections_to_keep( self._collections_to_keep(
self._grappler_meta_graph_def.collection_def)) self._grappler_meta_graph_def.collection_def))
# We don't use any specific converter here. # We don't use any specific converter here.
with session.Session(config=self._session_config) as sess: with session.Session(config=self._session_config) as sess:
saved_model_builder.add_meta_graph_and_variables( saved_model_builder.add_meta_graph_and_variables(
sess, sess,
self._input_saved_model_tags, self._input_saved_model_tags,
signature_def_map=self._grappler_meta_graph_def.signature_def) signature_def_map=self._grappler_meta_graph_def.signature_def)
# Ignore other meta graphs from the input SavedModel. # Ignore other meta graphs from the input SavedModel.
saved_model_builder.save() saved_model_builder.save()
class TrtPrecisionMode(object): class TrtPrecisionMode(object):
@ -498,101 +445,202 @@ class TrtPrecisionMode(object):
# so it can produce reasonable performance results with the default. # so it can produce reasonable performance results with the default.
DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30 DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
# TrtConversionParams encapsulates the parameters that are used for TF-TRT
# conversion.
TrtConversionParams = collections.namedtuple(
"TrtConversionParams",
[
# A template RewriterConfig proto used to create a TRT-enabled
# RewriterConfig. If None, it will use a default one.
"rewriter_config_template",
# The maximum GPU temporary memory which the TRT engine can use at
# execution time. This corresponds to the 'workspaceSize' parameter of
# nvinfer1::IBuilder::setMaxWorkspaceSize().
"max_workspace_size_bytes",
# One of TrtPrecisionMode.supported_precision_modes().
"precision_mode",
# The minimum number of nodes required for a subgraph to be replaced by
# TRTEngineOp.
"minimum_segment_size",
# Whether to generate dynamic TRT ops which will build the TRT network
# and engine at run time.
"is_dynamic_op",
# Max number of cached TRT engines in dynamic TRT ops. If the number of
# cached engines is already at max but none of them can serve the input,
# the TRTEngineOp will fall back to run the TF function based on which
# the TRTEngineOp is created.
"maximum_cached_engines",
# This argument is ignored if precision_mode is not INT8. If set to
# True, a calibration graph will be created to calibrate the missing
# ranges. The calibration graph must be converted to an inference graph
# by running calibration with calibrate(). If set to False, quantization
# nodes will be expected for every tensor in the graph (excluding those
# which will be fused). If a range is missing, an error will occur.
# Please note that accuracy may be negatively affected if there is a
# mismatch between which tensors TRT quantizes and which tensors were
# trained with fake quantization.
"use_calibration",
# If set to True, it will create a FunctionDef for each subgraph that is
# converted to TRT op, and if TRT ops fail to execute at runtime, it'll
# invoke that function as a fallback.
"use_function_backup",
# Max size for the input batch.
# This option is deprecated in TF 2.0.
"max_batch_size",
# A list of batch sizes used to create cached engines, only used when
# is_dynamic_op is True. The length of the list should be <=
# maximum_cached_engines, and the dynamic TRT op will use this list to
# determine the batch sizes of the cached engines, instead of making the
# decision on the fly. This is useful when we know the most common batch
# size(s) the application is going to generate.
# This option is deprecated in TF 2.0.
"cached_engine_batches",
])
DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams(
rewriter_config_template=None,
max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
precision_mode=TrtPrecisionMode.FP32,
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
use_calibration=True,
use_function_backup=True,
max_batch_size=1,
cached_engine_batches=None)
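
Since `TrtConversionParams` is a plain namedtuple, the intended way to customize it (and the way the updated unit tests in this commit do it) is to start from `DEFAULT_TRT_CONVERSION_PARAMS` and override individual fields with `_replace`. A minimal sketch; the field values below are illustrative, not recommendations:

```python
# Sketch: overriding a few fields of the default conversion parameters.
# The chosen values are placeholders for illustration only.
from tensorflow.python.compiler.tensorrt import trt_convert

params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt_convert.TrtPrecisionMode.FP16,
    max_workspace_size_bytes=1 << 28,
    is_dynamic_op=True,
    maximum_cached_engines=2)
```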
_TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF-TRT-Calibration"
_TRT_ENGINE_CACHE_CONTAINER_NAME = "TF-TRT-Engine-Cache"
_TRT_ENGINE_OP_NAME = "TRTEngineOp"
def _check_conversion_params(conversion_params):
"""Validate the provided TrtConversionParams.
Args:
conversion_params: a TrtConversionParams instance.
Raises:
TypeError: if any of the parameters are of unexpected type.
ValueError: if any of the parameters are of unexpected value.
"""
supported_precision_modes = TrtPrecisionMode.supported_precision_modes()
if conversion_params.precision_mode not in supported_precision_modes:
raise ValueError(
("precision mode '{}' is not supported."
"It should be one of {}").format(conversion_params.precision_mode,
supported_precision_modes))
if conversion_params.cached_engine_batches:
if not isinstance(conversion_params.cached_engine_batches, list):
raise TypeError("cached_engine_batches should be a list.")
if len(conversion_params.cached_engine_batches
) > conversion_params.maximum_cached_engines:
raise ValueError("cached_engine_batches should not contain more than "
"maximum_cached_engines items.")
def _check_trt_version_compatibility():
"""Check compatibility of TensorRT version.
Raises:
RuntimeError: if the TensorRT library version is incompatible.
"""
compiled_version = get_linked_tensorrt_version()
loaded_version = get_loaded_tensorrt_version()
tf_logging.info("Linked TensorRT version: %s" % str(compiled_version))
tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version))
version_mismatch = False
if loaded_version[0] < compiled_version[0]:
tf_logging.error(
"TensorRT version mismatch. Tensorflow was compiled against " +
"TensorRT %s but library loaded from environment is TensorRT %s" %
(".".join([str(x) for x in compiled_version]),
".".join([str(x) for x in loaded_version])) +
". Please make sure that correct version of TensorRT " +
"is available in the system and added to ldconfig or LD_LIBRARY_PATH")
raise RuntimeError("Incompatible TensorRT library version")
for i in zip(loaded_version, compiled_version):
if i[0] != i[1]:
tf_logging.warn("TensorRT mismatch. Compiled against version " +
"%s, but loaded %s. Things may not work" %
(".".join([str(x) for x in compiled_version]),
".".join([str(x) for x in loaded_version])))
version_mismatch = True
break
if not version_mismatch:
tf_logging.info("Running against TensorRT version %s" %
".".join([str(x) for x in loaded_version]))
def get_tensorrt_rewriter_config(
conversion_params=DEFAULT_TRT_CONVERSION_PARAMS):
"""Returns a RewriterConfig proto for TRT transformation.
Args:
conversion_params: a TrtConversionParams instance.
Returns:
A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
Raises:
TypeError: if any of the parameters are of unexpected type.
ValueError: if any of the parameters are of unexpected value.
"""
if conversion_params.rewriter_config_template is not None and not isinstance(
conversion_params.rewriter_config_template,
rewriter_config_pb2.RewriterConfig):
raise TypeError(
"rewriter_config_template should be a RewriterConfig proto.")
_check_conversion_params(conversion_params)
rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
if conversion_params.rewriter_config_template is None:
# Layout optimizer may add Const nodes followed by Reshape nodes, thus we
# need to run constant folding again.
rewriter_config_with_trt.optimizers.extend(
["constfold", "layout", "constfold"])
rewriter_config_with_trt.meta_optimizer_iterations = (
rewriter_config_pb2.RewriterConfig.ONE)
else:
rewriter_config_with_trt.CopyFrom(
conversion_params.rewriter_config_template)
optimizer = rewriter_config_with_trt.custom_optimizers.add()
optimizer.name = "TensorRTOptimizer"
optimizer.parameter_map[
"minimum_segment_size"].i = conversion_params.minimum_segment_size
optimizer.parameter_map["max_batch_size"].i = conversion_params.max_batch_size
optimizer.parameter_map["is_dynamic_op"].b = conversion_params.is_dynamic_op
optimizer.parameter_map[
"max_workspace_size_bytes"].i = conversion_params.max_workspace_size_bytes
optimizer.parameter_map["precision_mode"].s = _to_bytes(
conversion_params.precision_mode)
optimizer.parameter_map[
"maximum_cached_engines"].i = conversion_params.maximum_cached_engines
if conversion_params.cached_engine_batches:
optimizer.parameter_map["cached_engine_batches"].list.i.extend(
conversion_params.cached_engine_batches)
optimizer.parameter_map[
"use_calibration"].b = conversion_params.use_calibration
optimizer.parameter_map[
"use_function_backup"].b = conversion_params.use_function_backup
return rewriter_config_with_trt
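
For the "online" conversion path (running the TensorRTOptimizer through Grappler at session creation time), this module-level helper replaces the old classmethod. A minimal sketch of how the updated integration-test base wires it into a `ConfigProto`; the parameter override is illustrative:

```python
# Sketch: building a session ConfigProto that triggers TF-TRT conversion via
# Grappler, mirroring TfTrtIntegrationTestBase.GetConfigProto() in this commit.
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert

conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    minimum_segment_size=2)  # illustrative override
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
    conversion_params=conversion_params)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
config = config_pb2.ConfigProto(graph_options=graph_options)
# `config` can then be passed when creating the tf.Session used for inference.
```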
class TrtGraphConverter(GraphConverter): class TrtGraphConverter(GraphConverter):
"""A GraphConverter for TRT transformation.""" """A GraphConverter for TRT transformation."""
_TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF-TRT-Calibration" # TODO(laigd): use TrtConversionParams here.
@classmethod
def get_tensorrt_rewriter_config(
cls,
rewriter_config_template=None,
max_batch_size=1,
max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
precision_mode=TrtPrecisionMode.FP32,
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
cached_engine_batches=None,
use_calibration=True,
use_function_backup=True):
"""Returns a RewriterConfig proto for TRT transformation.
Args:
rewriter_config_template: a template RewriterConfig proto used to create a
TRT-enabled RewriterConfig. If None, it will use a default one.
max_batch_size: max size for the input batch
max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
engine can use at execution time. This corresponds to the
'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
precision_mode: one of TrtPrecisionMode.supported_precision_modes().
minimum_segment_size: the minimum number of nodes required for a subgraph
to be replaced by TRTEngineOp.
is_dynamic_op: whether to generate dynamic TRT ops which will build the
TRT network and engine at run time.
maximum_cached_engines: max number of cached TRT engines in dynamic TRT
ops. If the number of cached engines is already at max but none of them
can serve the input, the TRTEngineOp will fall back to run the TF
function based on which the TRTEngineOp is created.
cached_engine_batches: a list of batch sizes used to create cached
engines, only used when is_dynamic_op is True. The length of the list
should be <= maximum_cached_engines, and the dynamic TRT op will use
this list to determine the batch sizes of the cached engines, instead of
making the decision on the fly. This is useful when we know the most
common batch size(s) the application is going to generate.
use_calibration: this argument is ignored if precision_mode is not INT8.
If set to True, a calibration graph will be created to calibrate the
missing ranges. The calibration graph must be converted to an inference
graph by running calibration with calibrate(). If set to False,
quantization nodes will be expected for every tensor in the graph
(exlcuding those which will be fused). If a range is missing, an error
will occur. Please note that accuracy may be negatively affected if
there is a mismatch between which tensors TRT quantizes and which
tensors were trained with fake quantization.
use_function_backup: if set to True, it will create a FunctionDef for each
subgraph that is converted to TRT op, and if TRT ops fail to execute at
runtime, it'll invoke that function as a fallback.
Returns:
A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
Raises:
TypeError: if any of the parameters are of unexpected type.
ValueError: if any of the parameters are of unexpected value.
"""
if rewriter_config_template is not None and not isinstance(
rewriter_config_template, rewriter_config_pb2.RewriterConfig):
raise TypeError(
"rewriter_config_template should be a RewriterConfig proto.")
rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
if rewriter_config_template is None:
# Layout optimizer may add Const nodes followed by Reshape nodes, thus we
# need to run constant folding again.
rewriter_config_with_trt.optimizers.extend(
["constfold", "layout", "constfold"])
rewriter_config_with_trt.meta_optimizer_iterations = (
rewriter_config_pb2.RewriterConfig.ONE)
else:
rewriter_config_with_trt.CopyFrom(rewriter_config_template)
optimizer = rewriter_config_with_trt.custom_optimizers.add()
optimizer.name = "TensorRTOptimizer"
optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
optimizer.parameter_map["max_batch_size"].i = max_batch_size
optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
optimizer.parameter_map[
"max_workspace_size_bytes"].i = max_workspace_size_bytes
optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
if cached_engine_batches:
optimizer.parameter_map["cached_engine_batches"].list.i.extend(
cached_engine_batches)
optimizer.parameter_map["use_calibration"].b = use_calibration
optimizer.parameter_map["use_function_backup"].b = use_function_backup
return rewriter_config_with_trt
def __init__(self, def __init__(self,
input_saved_model_dir=None, input_saved_model_dir=None,
input_saved_model_tags=None, input_saved_model_tags=None,
@ -621,7 +669,7 @@ class TrtGraphConverter(GraphConverter):
If set to None, the graph will be read from the SavedModel loaded from If set to None, the graph will be read from the SavedModel loaded from
input_saved_model_dir. input_saved_model_dir.
nodes_blacklist: list of node names to prevent the converter from nodes_blacklist: list of node names to prevent the converter from
touching. Only used when input_graph_def is not None. touching.
session_config: the ConfigProto used to create a Session. It's also used session_config: the ConfigProto used to create a Session. It's also used
as a template to create a TRT-enabled ConfigProto for conversion. If not as a template to create a TRT-enabled ConfigProto for conversion. If not
specified, a default ConfigProto will be used. specified, a default ConfigProto will be used.
@ -659,7 +707,6 @@ class TrtGraphConverter(GraphConverter):
Raises: Raises:
ValueError: if the combination of the parameters is invalid. ValueError: if the combination of the parameters is invalid.
RuntimeError: if the TensorRT library version is incompatible.
""" """
super(TrtGraphConverter, self).__init__( super(TrtGraphConverter, self).__init__(
input_saved_model_dir=input_saved_model_dir, input_saved_model_dir=input_saved_model_dir,
@ -668,54 +715,10 @@ class TrtGraphConverter(GraphConverter):
input_graph_def=input_graph_def, input_graph_def=input_graph_def,
nodes_blacklist=nodes_blacklist, nodes_blacklist=nodes_blacklist,
session_config=session_config) session_config=session_config)
_check_trt_version_compatibility()
# TODO(laigd): move all the validations below to
# get_tensorrt_rewriter_config().
# Check compatibility of TensorRT version.
compiled_version = get_linked_tensorrt_version()
loaded_version = get_loaded_tensorrt_version()
tf_logging.info("Linked TensorRT version: %s" % str(compiled_version))
tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version))
version_mismatch = False
if loaded_version[0] < compiled_version[0]:
tf_logging.error(
"TensorRT version mismatch. Tensorflow was compiled against " +
"TensorRT %s but library loaded from environment is TensorRT %s" %
(".".join([str(x) for x in compiled_version]),
".".join([str(x) for x in loaded_version])) +
". Please make sure that correct version of TensorRT " +
"is available in the system and added to ldconfig or LD_LIBRARY_PATH")
raise RuntimeError("Incompatible TensorRT library version")
for i in zip(loaded_version, compiled_version):
if i[0] != i[1]:
tf_logging.warn("TensorRT mismatch. Compiled against version " +
"%s, but loaded %s. Things may not work" %
(".".join([str(x) for x in compiled_version]),
".".join([str(x) for x in loaded_version])))
version_mismatch = True
break
if not version_mismatch:
tf_logging.info("Running against TensorRT version %s" %
".".join([str(x) for x in loaded_version]))
# Check input arguments.
supported_precision_modes = TrtPrecisionMode.supported_precision_modes()
if precision_mode not in supported_precision_modes:
raise ValueError(
("precision mode '{}' is not supported."
"It should be one of {}").format(precision_mode,
supported_precision_modes))
if cached_engine_batches:
if not isinstance(cached_engine_batches, list):
raise TypeError("cached_engine_batches should be a list.")
if len(cached_engine_batches) > maximum_cached_engines:
raise ValueError("cached_engine_batches should not contain more than "
"maximum_cached_engines items.")
self._need_calibration = ( self._need_calibration = (
precision_mode == TrtPrecisionMode.INT8 and use_calibration) precision_mode == TrtPrecisionMode.INT8 and use_calibration)
self._use_function_backup = use_function_backup
# TODO(laigd): consider provide a mechanism to remove the fallback path # TODO(laigd): consider provide a mechanism to remove the fallback path
# after calibration is done. # after calibration is done.
@ -724,31 +727,30 @@ class TrtGraphConverter(GraphConverter):
"Calibration requires enabling fallback to TF function execution.") "Calibration requires enabling fallback to TF function execution.")
# TODO(laigd): # TODO(laigd):
# - Get rid of is_dynamic_op option, it should always be True, and it should
# accept N shapes as input.
# - Verify in int8 mode that maximum_cached_engines and # - Verify in int8 mode that maximum_cached_engines and
# cached_engine_batches are set appropriately. # cached_engine_batches are set appropriately.
# - If it fails to build the int8 engine it should return error. # - If it fails to build the int8 engine it should return error.
self._max_batch_size = max_batch_size rewriter_config_template = None
self._max_workspace_size_bytes = max_workspace_size_bytes if (session_config and session_config.HasField("graph_options") and
self._precision_mode = precision_mode session_config.graph_options.HasField("rewrite_options")):
self._minimum_segment_size = minimum_segment_size rewriter_config_template = session_config.graph_options.rewrite_options
self._is_dynamic_op = is_dynamic_op
self._maximum_cached_engines = maximum_cached_engines
self._cached_engine_batches = cached_engine_batches
def get_rewriter_config(self, rewriter_config_template=None): self._conversion_params = TrtConversionParams(
return TrtGraphConverter.get_tensorrt_rewriter_config( rewriter_config_template=rewriter_config_template,
rewriter_config_template, max_workspace_size_bytes=max_workspace_size_bytes,
max_batch_size=self._max_batch_size, precision_mode=precision_mode,
max_workspace_size_bytes=self._max_workspace_size_bytes, minimum_segment_size=minimum_segment_size,
precision_mode=self._precision_mode, is_dynamic_op=is_dynamic_op,
minimum_segment_size=self._minimum_segment_size, maximum_cached_engines=maximum_cached_engines,
is_dynamic_op=self._is_dynamic_op, use_calibration=use_calibration,
maximum_cached_engines=self._maximum_cached_engines, use_function_backup=use_function_backup,
cached_engine_batches=self._cached_engine_batches, max_batch_size=max_batch_size,
use_calibration=self._need_calibration, cached_engine_batches=cached_engine_batches)
use_function_backup=self._use_function_backup) _check_conversion_params(self._conversion_params)
def get_rewriter_config(self):
return get_tensorrt_rewriter_config(
conversion_params=self._conversion_params)
def finalize_calibration(self): def finalize_calibration(self):
assert self._need_calibration assert self._need_calibration
@ -775,7 +777,7 @@ class TrtGraphConverter(GraphConverter):
resource_name_input = array_ops.placeholder(dtypes.string) resource_name_input = array_ops.placeholder(dtypes.string)
for node in self._converted_graph_def.node: for node in self._converted_graph_def.node:
if node.op == "TRTEngineOp": if node.op == _TRT_ENGINE_OP_NAME:
# Adds the get_serialized_resource_op for the device if not done # Adds the get_serialized_resource_op for the device if not done
# before. We only add one such op for each device. # before. We only add one such op for each device.
# TODO(laigd): What if the device is empty????? # TODO(laigd): What if the device is empty?????
@ -791,11 +793,8 @@ class TrtGraphConverter(GraphConverter):
calibration_result = self._calibration_sess.run( calibration_result = self._calibration_sess.run(
device_to_get_resource_op_map[node.device], device_to_get_resource_op_map[node.device],
feed_dict={ feed_dict={
container_input: container_input: _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME,
TrtGraphConverter resource_name_input: node.name
._TRT_CALIBRATION_RESOURCE_CONTAINER_NAME,
resource_name_input:
node.name
}) })
node.attr["calibration_data"].s = calibration_result node.attr["calibration_data"].s = calibration_result
@ -806,9 +805,106 @@ class TrtGraphConverter(GraphConverter):
"""Save the converted graph as a SavedModel.""" """Save the converted graph as a SavedModel."""
if self._need_calibration: if self._need_calibration:
assert self._calibration_data_collected assert self._calibration_data_collected
super(TrtGraphConverter, self).save(output_saved_model_dir) super(TrtGraphConverter, self).save(output_saved_model_dir)
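
For comparison with the new V2 class below, the TF 1.x `TrtGraphConverter` keeps its convert/(calibrate)/save flow. A condensed sketch based on the `_CreateConverter` helper in the updated unit test, with placeholder paths:

```python
# Sketch of the TF 1.x TrtGraphConverter flow exercised by the unit tests.
# Directory names are placeholders.
from tensorflow.python.compiler.tensorrt import trt_convert

converter = trt_convert.TrtGraphConverter(
    input_saved_model_dir="/tmp/saved_model_in",  # placeholder
    precision_mode=trt_convert.TrtPrecisionMode.FP32,
    is_dynamic_op=True,
    maximum_cached_engines=1)
output_graph_def = converter.convert()  # a GraphDef in the 1.x path
converter.save(output_saved_model_dir="/tmp/saved_model_out")  # placeholder
```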
class TrtGraphConverterV2(object):
"""A converter for TF-TRT transformation for SavedModel in TF 2.0."""
def __init__(self,
input_saved_model_dir=None,
input_saved_model_tags=None,
input_saved_model_signature_key=None,
conversion_params=DEFAULT_TRT_CONVERSION_PARAMS):
"""Initialize the converter.
Args:
input_saved_model_dir: the directory to load the SavedModel which contains
the input graph to transform.
input_saved_model_tags: list of tags to load the SavedModel.
input_saved_model_signature_key: the key of the signature to optimize the
graph for.
conversion_params: a TrtConversionParams instance.
"""
assert context.executing_eagerly()
_check_trt_version_compatibility()
self._input_saved_model_dir = input_saved_model_dir
self._input_saved_model_tags = (
input_saved_model_tags or [tag_constants.SERVING])
self._input_saved_model_signature_key = (
input_saved_model_signature_key or
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
self._need_calibration = (
conversion_params.precision_mode == TrtPrecisionMode.INT8 and
conversion_params.use_calibration)
self._conversion_params = conversion_params
_check_conversion_params(self._conversion_params)
self._converted = False
def _run_conversion(self, meta_graph_def):
"""Run Grappler's OptimizeGraph() tool to convert the graph.
Args:
meta_graph_def: the MetaGraphDef instance to run the optimizations on.
Returns:
The optimized GraphDef.
"""
rewriter_config = get_tensorrt_rewriter_config(
conversion_params=self._conversion_params)
grappler_session_config = config_pb2.ConfigProto()
grappler_session_config.graph_options.rewrite_options.CopyFrom(
rewriter_config)
return tf_optimizer.OptimizeGraph(
grappler_session_config, meta_graph_def, graph_id=b"tf_graph")
# TODO(laigd): provide a utility function to optimize a ConcreteFunction and
# use it here (b/124792963).
def convert(self):
"""Convert the input SavedModel in 2.0 format."""
assert not self._converted
self._saved_model = load.load(self._input_saved_model_dir,
self._input_saved_model_tags)
func = self._saved_model.signatures[self._input_saved_model_signature_key]
frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
grappler_meta_graph_def = saver.export_meta_graph(
graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)
# Add a collection 'train_op' so that Grappler knows the outputs.
fetch_collection = meta_graph_pb2.CollectionDef()
for array in frozen_func.inputs + frozen_func.outputs:
fetch_collection.node_list.value.append(array.name)
grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
fetch_collection)
# Run TRT optimizer in Grappler to convert the graph.
converted_graph_def = self._run_conversion(grappler_meta_graph_def)
self._converted_func = wrap_function.function_from_graph_def(
converted_graph_def, [tensor.name for tensor in frozen_func.inputs],
[tensor.name for tensor in frozen_func.outputs])
self._converted = True
return self._converted_func
def save(self, output_saved_model_dir):
"""Save the converted SavedModel.
Args:
output_saved_model_dir: directory to save the converted SavedModel.
"""
assert self._converted
# Rewrite the signature map using the optimized ConcreteFunction.
signatures = {
key: value for key, value in self._saved_model.signatures.items()
}
signatures[self._input_saved_model_signature_key] = self._converted_func
save.save(self._saved_model, output_saved_model_dir, signatures)
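
Putting the new pieces together, the V2 workflow is SavedModel-in, SavedModel-out. A condensed sketch following the V2 unit test added in this commit, with placeholder directories and illustrative parameter overrides:

```python
# Sketch of the new TrtGraphConverterV2 flow from the added V2 unit test.
# Paths and parameter overrides are placeholders.
from tensorflow.python.compiler.tensorrt import trt_convert

params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt_convert.TrtPrecisionMode.FP32,
    is_dynamic_op=True,
    use_function_backup=False)
converter = trt_convert.TrtGraphConverterV2(
    input_saved_model_dir="/tmp/saved_model_in",  # placeholder
    conversion_params=params)
converted_func = converter.convert()    # the converted ConcreteFunction
converter.save("/tmp/saved_model_out")  # writes the rewritten SavedModel
```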
# TODO(laigd): use TrtConversionParams here.
def create_inference_graph( def create_inference_graph(
input_graph_def, input_graph_def,
outputs, outputs,


@ -19,14 +19,17 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import tempfile
import numpy as np
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util from tensorflow.python.framework import graph_util
@ -35,7 +38,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder from tensorflow.python.saved_model import builder
@ -44,10 +46,11 @@ from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils from tensorflow.python.saved_model import utils
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.saved_model import load from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save from tensorflow.python.saved_model import save
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
_SAVED_MODEL_SIGNATURE_KEY = "mypredict" _SAVED_MODEL_SIGNATURE_KEY = "mypredict"
@ -63,8 +66,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
"""Test case for TrtGraphConverter.get_tensorrt_rewriter_config().""" """Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
if not is_tensorrt_enabled(): if not is_tensorrt_enabled():
return return
rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config( conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
rewriter_config_template=None,
max_batch_size=128, max_batch_size=128,
max_workspace_size_bytes=1234, max_workspace_size_bytes=1234,
precision_mode="INT8", precision_mode="INT8",
@ -72,6 +74,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
is_dynamic_op=True, is_dynamic_op=True,
maximum_cached_engines=2, maximum_cached_engines=2,
cached_engine_batches=[1, 128]) cached_engine_batches=[1, 128])
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
conversion_params=conversion_params)
self.assertEqual(["constfold", "layout", "constfold"], self.assertEqual(["constfold", "layout", "constfold"],
rewriter_cfg.optimizers) rewriter_cfg.optimizers)
self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE, self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE,
@ -106,7 +110,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
gpu_options=config_pb2.GPUOptions(allow_growth=True)) gpu_options=config_pb2.GPUOptions(allow_growth=True))
return config return config
def _GetGraph(self): @classmethod
def _GetGraph(cls, inp, var):
"""Get the graph for testing.""" """Get the graph for testing."""
# The graph computes (input+1)^2, it looks like: # The graph computes (input+1)^2, it looks like:
# #
@ -119,24 +124,42 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
# + # +
# | # |
# output (Identity) # output (Identity)
add = inp + var
mul = inp * add
add = mul + add
out = array_ops.identity(add, name="output")
return out
def _GetModelForV2(self):
class SimpleModel(tracking.AutoTrackable):
def __init__(self):
self.v = None
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)
])
def run(self, inp):
if self.v is None:
self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
return TrtConvertTest._GetGraph(inp, self.v)
return SimpleModel()
def _GetGraphForV1(self):
g = ops.Graph() g = ops.Graph()
with g.as_default(): with g.as_default():
with g.device("/GPU:0"): with g.device("/GPU:0"):
inp = array_ops.placeholder( inp = array_ops.placeholder(
dtype=dtypes.float32, shape=[None, 1, 1], name="input") dtype=dtypes.float32, shape=[None, 1, 1], name="input")
var = variables.VariableV1([[[1.0]]], var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
dtype=dtypes.float32, out = TrtConvertTest._GetGraph(inp, var)
name="v1", return g, var, inp, out
use_resource=False)
add = inp + var.value()
mul = inp * add
add = mul + add
out = array_ops.identity(add, name="output")
return g, var, inp, out
def _GetGraphDef(self): def _GetGraphDef(self):
"""Get the graph def for testing.""" """Get the graph def for testing."""
g, var, _, _ = self._GetGraph() g, var, _, _ = self._GetGraphForV1()
with self.session(graph=g, config=self._GetConfigProto()) as sess: with self.session(graph=g, config=self._GetConfigProto()) as sess:
sess.run(var.initializer) sess.run(var.initializer)
graph_def = graph_util.convert_variables_to_constants( graph_def = graph_util.convert_variables_to_constants(
@ -145,7 +168,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
self.assertEqual( self.assertEqual(
{ {
"v1": "Const", "v1": "Const",
"v1/read": "Identity", "add/ReadVariableOp": "Identity",
"input": "Placeholder", "input": "Placeholder",
"add": "Add", "add": "Add",
"mul": "Mul", "mul": "Mul",
@ -156,7 +179,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
def _WriteInputSavedModel(self, input_saved_model_dir): def _WriteInputSavedModel(self, input_saved_model_dir):
"""Write the saved model as an input for testing.""" """Write the saved model as an input for testing."""
g, var, inp, out = self._GetGraph() g, var, inp, out = self._GetGraphForV1()
signature_def = signature_def_utils.build_signature_def( signature_def = signature_def_utils.build_signature_def(
inputs={"myinput": utils.build_tensor_info(inp)}, inputs={"myinput": utils.build_tensor_info(inp)},
outputs={"myoutput": utils.build_tensor_info(out)}, outputs={"myoutput": utils.build_tensor_info(out)},
@ -183,7 +206,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
input_saved_model_dir=input_saved_model_dir, input_saved_model_dir=input_saved_model_dir,
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY, input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
input_graph_def=None if input_saved_model_dir else self._GetGraphDef(), input_graph_def=None if input_saved_model_dir else self._GetGraphDef(),
nodes_blacklist=["output"], nodes_blacklist=None if input_saved_model_dir else ["output"],
session_config=self._GetConfigProto(), session_config=self._GetConfigProto(),
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES, max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
@ -193,28 +216,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
is_dynamic_op=is_dynamic_op, is_dynamic_op=is_dynamic_op,
maximum_cached_engines=maximum_cached_engines, maximum_cached_engines=maximum_cached_engines,
use_function_backup=use_function_backup) use_function_backup=use_function_backup)
conversion_result = converter.convert() output_graph_def = converter.convert()
if context.executing_eagerly(): if need_calibration:
output_graph_def = conversion_result.graph.as_graph_def()
else:
output_graph_def = conversion_result
if need_calibration: class CalibrationData(object):
class CalibrationData(object): def __init__(self):
self._data = 0
def __init__(self): def next(self):
self._data = 0 self._data += 1
return {"input:0": [[[self._data]]]}
def next(self): output_graph_def = converter.calibrate(
self._data += 1 fetch_names=["output:0"],
return {"input:0": [[[self._data]]]} num_runs=10,
feed_dict_fn=CalibrationData().next)
output_graph_def = converter.calibrate(
fetch_names=["output:0"],
num_runs=10,
feed_dict_fn=CalibrationData().next)
if output_saved_model_dir is not None: if output_saved_model_dir is not None:
converter.save(output_saved_model_dir=output_saved_model_dir) converter.save(output_saved_model_dir=output_saved_model_dir)
@ -235,31 +253,19 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
graph_defs_to_verify = [output_graph_def] graph_defs_to_verify = [output_graph_def]
if output_saved_model_dir: if output_saved_model_dir:
if context.executing_eagerly(): saved_model_graph_def = saved_model_utils.get_meta_graph_def(
root = load.load(output_saved_model_dir) output_saved_model_dir, tag_constants.SERVING).graph_def
saved_model_graph_def = root.signatures[ self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
_SAVED_MODEL_SIGNATURE_KEY].graph.as_graph_def()
else:
saved_model_graph_def = saved_model_utils.get_meta_graph_def(
output_saved_model_dir, tag_constants.SERVING).graph_def
self.assertTrue(isinstance(saved_model_graph_def, graph_pb2.GraphDef))
graph_defs_to_verify.append(saved_model_graph_def) graph_defs_to_verify.append(saved_model_graph_def)
for graph_def in graph_defs_to_verify: for graph_def in graph_defs_to_verify:
node_name_to_op = {node.name: node.op for node in graph_def.node} node_name_to_op = {node.name: node.op for node in graph_def.node}
if context.executing_eagerly(): self.assertEqual(
# In V2 the actual graph could be inside a function. {
for func in graph_def.library.function: "input": "Placeholder",
node_name_to_op.update({node.name: node.op for node in func.node_def}) "TRTEngineOp_0": "TRTEngineOp",
self.assertIn("TRTEngineOp_0", node_name_to_op) "output": "Identity"
self.assertEqual("TRTEngineOp", node_name_to_op["TRTEngineOp_0"]) }, node_name_to_op)
else:
self.assertEqual(
{
"input": "Placeholder",
"TRTEngineOp_0": "TRTEngineOp",
"output": "Identity"
}, node_name_to_op)
if need_calibration: if need_calibration:
trt_engine_nodes = [ trt_engine_nodes = [
@ -306,39 +312,81 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
if not is_tensorrt_enabled(): if not is_tensorrt_enabled():
return return
# TODO(laigd): we need to use ops like conv2d so Grappler can infer the np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
# shapes (at least rank) of the tensors, so we're able to build an TRT
# engine in dynamic mode. Currently shape information is not propagate from
# ConcreteFunction to GraphDef, need to investigate and fix it.
class SimpleModel(tracking.AutoTrackable):
def __init__(self): # Create a model and save it.
self.v = None input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
root = self._GetModelForV2()
@def_function.function(input_signature=[ expected_output = root.run(np_input)
tensor_spec.TensorSpec(shape=[None, 24, 24, 2], dtype=dtypes.float32)
])
def run(self, inp):
if self.v is None:
self.v = variables.Variable([[[[1., 0.5, 4., 6., 0.5, 1.],
[1., 0.5, 1., 1., 0.5, 1.]]]])
conv = gen_nn_ops.conv2d(
input=inp, filter=self.v, strides=[1, 2, 2, 1], padding="SAME")
identity = array_ops.identity(conv)
return identity
tmp_dir = self.get_temp_dir()
input_saved_model_dir = os.path.join(tmp_dir, "in_dir1_v2")
root = SimpleModel()
save.save(root, input_saved_model_dir, save.save(root, input_saved_model_dir,
{_SAVED_MODEL_SIGNATURE_KEY: root.run}) {_SAVED_MODEL_SIGNATURE_KEY: root.run})
# Convert the SavedModel and verify the result. # Run TRT conversion.
output_saved_model_dir = os.path.join(tmp_dir, "out_dir1_v2") converter = trt_convert.TrtGraphConverterV2(
self._TestTrtGraphConverter(
input_saved_model_dir=input_saved_model_dir, input_saved_model_dir=input_saved_model_dir,
output_saved_model_dir=output_saved_model_dir, input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
is_dynamic_op=True) conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
precision_mode=trt_convert.TrtPrecisionMode.FP32,
is_dynamic_op=True,
maximum_cached_engines=2,
use_function_backup=False))
converted_concrete_func = converter.convert()
def _check_trt_ops(graph_def):
trt_op_names = [
node.name for node in graph_def.node if node.op == "TRTEngineOp"
]
for func in graph_def.library.function:
for node in func.node_def:
if node.op == "TRTEngineOp":
trt_op_names.append(node.name)
self.assertEqual(1, len(trt_op_names))
self.assertIn("TRTEngineOp_0", trt_op_names[0])
# Verify the converted GraphDef and ConcreteFunction.
self.assertIsInstance(converted_concrete_func, function.ConcreteFunction)
converted_graph_def = converted_concrete_func.graph.as_graph_def()
_check_trt_ops(converted_graph_def)
output_with_trt = converted_concrete_func(ops.convert_to_tensor(np_input))
self.assertEqual(1, len(output_with_trt))
self.assertAllClose(
expected_output, output_with_trt[0].numpy(), atol=1e-6, rtol=1e-6)
# Run the converted ConcreteFunction as a function and make sure it works.
@def_function.function
def wrapper_func(*args, **kwargs):
return nest.flatten(converted_concrete_func(*args, **kwargs))
_check_trt_ops(
wrapper_func.get_concrete_function(
tensor_spec.TensorSpec(shape=[None, 1, 1],
dtype=dtypes.float32)).graph.as_graph_def())
output_with_trt = wrapper_func(np_input)
self.assertEqual(1, len(output_with_trt))
self.assertAllClose(
expected_output, output_with_trt[0], atol=1e-6, rtol=1e-6)
# Save the converted model.
output_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
converter.save(output_saved_model_dir)
# Load and verify the converted model.
#
# TODO(laigd): the name of the new input_signature of the
# `root_with_trt.run` function is empty string (originally was None),
# investigate why.
root_with_trt = load.load(output_saved_model_dir)
# TODO(laigd): `root_with_trt.run` is still using the original graph without
# trt. Consider changing that.
# _check_trt_ops(
# root_with_trt.run.get_concrete_function().graph.as_graph_def())
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
_check_trt_ops(converted_signature.graph.as_graph_def())
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
# The output of running the converted signature is a dict due to
# compatibility reasons with V1 SavedModel signature mechanism.
output_with_trt = output_with_trt[output_with_trt.keys()[0]]
self.assertAllClose(expected_output, output_with_trt, atol=1e-6, rtol=1e-6)
def _TestRun(self, def _TestRun(self,
sess, sess,
@ -363,7 +411,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
node_name_to_op = {node.name: node.op for node in output_graph_def.node} node_name_to_op = {node.name: node.op for node in output_graph_def.node}
self.assertEqual( self.assertEqual(
{ {
"v1/read": "Const", "add/ReadVariableOp": "Const",
"input": "Placeholder", "input": "Placeholder",
"add": "Add", "add": "Add",
"mul": "Mul", "mul": "Mul",