Introduce TrtGraphConverterV2 for TF-TRT conversion in V2, and enhance the V2 unit test.
PiperOrigin-RevId: 244267523
This commit is contained in:
parent
fb49f672cb
commit
96072813ec
@ -144,7 +144,7 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
|
||||
).GetConversionParams(run_params)._replace(
|
||||
# Disable layout optimizer, since it'll add Transpose(Const, Const) to
|
||||
# the graph and breaks the conversion check.
|
||||
rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
|
||||
rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())
|
||||
|
||||
|
||||
class SimpleMultiEnginesTest2(trt_test.TfTrtIntegrationTestBase):
|
||||
|
@ -124,7 +124,7 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
|
||||
maximum_cached_engines=1,
|
||||
# Disable layout optimizer, since it will convert BiasAdd with NHWC
|
||||
# format to NCHW format under four dimentional input.
|
||||
rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
|
||||
rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())
|
||||
|
||||
def ExpectedEnginesToBuild(self, run_params):
|
||||
"""Return the expected engines to build."""
|
||||
|
@ -85,7 +85,7 @@ class DynamicInputShapesTest(trt_test.TfTrtIntegrationTestBase):
|
||||
maximum_cached_engines=10,
|
||||
# Disable layout optimizer, since it will convert BiasAdd with NHWC
|
||||
# format to NCHW format under four dimentional input.
|
||||
rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
|
||||
rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())
|
||||
|
||||
def ExpectedEnginesToBuild(self, run_params):
|
||||
return ["TRTEngineOp_0"]
|
||||
|
@ -65,7 +65,7 @@ class ExcludeUnsupportedInt32Test(trt_test.TfTrtIntegrationTestBase):
|
||||
maximum_cached_engines=1,
|
||||
# Disable layout optimizer, since it will convert BiasAdd with NHWC
|
||||
# format to NCHW format under four dimentional input.
|
||||
rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
|
||||
rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())
|
||||
|
||||
def ExpectedEnginesToBuild(self, run_params):
|
||||
"""Return the expected engines to build."""
|
||||
|
@ -56,12 +56,6 @@ RunParams = namedtuple("RunParams", [
|
||||
"use_calibration"
|
||||
])
|
||||
|
||||
ConversionParams = namedtuple("ConversionParams", [
|
||||
"max_batch_size", "max_workspace_size_bytes", "precision_mode",
|
||||
"minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
|
||||
"cached_engine_batches", "rewriter_config", "use_calibration"
|
||||
])
|
||||
|
||||
PRECISION_MODES = ["FP32", "FP16", "INT8"]
|
||||
|
||||
|
||||
@ -163,7 +157,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
raise NotImplementedError()
|
||||
|
||||
def GetConversionParams(self, run_params):
|
||||
"""Return a ConversionParams for test."""
|
||||
"""Return a TrtConversionParams for test."""
|
||||
batch_list = []
|
||||
for dims_list in self._GetParamsCached().input_dims:
|
||||
assert dims_list
|
||||
@ -171,19 +165,22 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
input_batches = [dims[0] for dims in dims_list]
|
||||
assert max(input_batches) == min(input_batches)
|
||||
batch_list.append(input_batches[0])
|
||||
return ConversionParams(
|
||||
conversion_params = trt_convert.TrtConversionParams(
|
||||
# We use the minimum of all the batch sizes, so when multiple different
|
||||
# input shapes are provided it'll always create new engines in the
|
||||
# cache, and we can therefore test the cache behavior.
|
||||
max_batch_size=min(batch_list),
|
||||
rewriter_config_template=None,
|
||||
max_workspace_size_bytes=1 << 25,
|
||||
precision_mode=run_params.precision_mode,
|
||||
minimum_segment_size=2,
|
||||
is_dynamic_op=run_params.dynamic_engine,
|
||||
maximum_cached_engines=1,
|
||||
cached_engine_batches=None,
|
||||
rewriter_config=None,
|
||||
use_calibration=run_params.use_calibration)
|
||||
use_calibration=run_params.use_calibration,
|
||||
use_function_backup=False,
|
||||
max_batch_size=min(batch_list),
|
||||
cached_engine_batches=None)
|
||||
return conversion_params._replace(
|
||||
use_function_backup=IsQuantizationWithCalibration(conversion_params))
|
||||
|
||||
def ShouldRunTest(self, run_params):
|
||||
"""Whether to run the test."""
|
||||
@ -218,24 +215,13 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
"""Get config proto based on specific settings."""
|
||||
conversion_params = self.GetConversionParams(run_params)
|
||||
if graph_state == GraphState.INFERENCE and run_params.use_optimizer:
|
||||
rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
|
||||
conversion_params.rewriter_config,
|
||||
conversion_params.max_batch_size,
|
||||
conversion_params.max_workspace_size_bytes,
|
||||
conversion_params.precision_mode,
|
||||
conversion_params.minimum_segment_size,
|
||||
conversion_params.is_dynamic_op,
|
||||
conversion_params.maximum_cached_engines,
|
||||
conversion_params.cached_engine_batches,
|
||||
conversion_params.use_calibration,
|
||||
use_function_backup=IsQuantizationWithCalibration(conversion_params))
|
||||
|
||||
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(conversion_params)
|
||||
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
|
||||
else:
|
||||
graph_options = config_pb2.GraphOptions()
|
||||
if conversion_params.rewriter_config is not None:
|
||||
if conversion_params.rewriter_config_template is not None:
|
||||
graph_options.rewrite_options.CopyFrom(
|
||||
conversion_params.rewriter_config)
|
||||
conversion_params.rewriter_config_template)
|
||||
|
||||
config = config_pb2.ConfigProto(
|
||||
gpu_options=self._GetGPUOptions(), graph_options=graph_options)
|
||||
@ -310,7 +296,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
maximum_cached_engines=conversion_params.maximum_cached_engines,
|
||||
cached_engine_batches=conversion_params.cached_engine_batches,
|
||||
use_calibration=conversion_params.use_calibration,
|
||||
use_function_backup=IsQuantizationWithCalibration(conversion_params))
|
||||
use_function_backup=conversion_params.use_function_backup)
|
||||
return converter
|
||||
|
||||
def _GetCalibratedInferGraph(self, run_params, gdef, inputs_data):
|
||||
|
@ -18,7 +18,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
import six as _six
|
||||
|
||||
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
|
||||
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_linked_tensorrt_version
|
||||
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_loaded_tensorrt_version
|
||||
@ -80,7 +83,7 @@ class GraphConverter(object):
|
||||
class MyGraphConverter(GraphConverter):
|
||||
...
|
||||
|
||||
def get_rewriter_config(self, rewriter_config_template=None):
|
||||
def get_rewriter_config(self):
|
||||
my_rewriter_config = ...
|
||||
return my_rewriter_config
|
||||
```
|
||||
@ -129,7 +132,7 @@ class GraphConverter(object):
|
||||
If set to None, the graph will be read from the SavedModel loaded from
|
||||
input_saved_model_dir.
|
||||
nodes_blacklist: list of node names to prevent the converter from
|
||||
touching. Only used when input_graph_def is not None.
|
||||
touching.
|
||||
session_config: the ConfigProto used to create a Session. It's also used
|
||||
as a template to create a RewriterConfig for conversion. If not
|
||||
specified, a default ConfigProto will be used.
|
||||
@ -137,21 +140,15 @@ class GraphConverter(object):
|
||||
Raises:
|
||||
ValueError: if the combination of the parameters is invalid.
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
if input_graph_def or not input_saved_model_dir:
|
||||
raise ValueError(
|
||||
"TF 2.0 only supports conversion of SavedModel, please specify "
|
||||
"input_saved_model_dir as input.")
|
||||
else:
|
||||
if input_graph_def and input_saved_model_dir:
|
||||
raise ValueError(
|
||||
"Can only specify one of input_graph_def and input_saved_model_dir")
|
||||
if not input_graph_def and not input_saved_model_dir:
|
||||
raise ValueError("Must specify one of input_graph_def and "
|
||||
"input_saved_model_dir")
|
||||
if input_graph_def and input_saved_model_dir:
|
||||
raise ValueError(
|
||||
"Can only specify one of input_graph_def and input_saved_model_dir")
|
||||
if not input_graph_def and not input_saved_model_dir:
|
||||
raise ValueError("Must specify one of input_graph_def and "
|
||||
"input_saved_model_dir")
|
||||
|
||||
self._input_graph_def = input_graph_def
|
||||
self._nodes_blacklist = nodes_blacklist
|
||||
self._input_graph_def = input_graph_def
|
||||
self._nodes_blacklist = nodes_blacklist
|
||||
|
||||
self._input_saved_model_dir = input_saved_model_dir
|
||||
self._converted = False
|
||||
@ -169,14 +166,9 @@ class GraphConverter(object):
|
||||
self._calibration_sess = None
|
||||
self._calibration_data_collected = False
|
||||
|
||||
def get_rewriter_config(self, rewriter_config_template=None):
|
||||
def get_rewriter_config(self):
|
||||
"""Returns a RewriterConfig proto for TRT transformation.
|
||||
|
||||
Args:
|
||||
rewriter_config_template: a template RewriterConfig proto used to create a
|
||||
RewriterConfig for the conversion. The implementation should not modify
|
||||
the template. If None, it will use a default one.
|
||||
|
||||
Returns:
|
||||
A RewriterConfig proto which will be used to run the conversion using
|
||||
Grappler.
|
||||
@ -188,11 +180,7 @@ class GraphConverter(object):
|
||||
# Create custom ConfigProto for Grappler.
|
||||
grappler_session_config = config_pb2.ConfigProto()
|
||||
grappler_session_config.CopyFrom(self._session_config)
|
||||
rewriter_config = None
|
||||
if (grappler_session_config.HasField("graph_options") and
|
||||
grappler_session_config.graph_options.HasField("rewrite_options")):
|
||||
rewriter_config = grappler_session_config.graph_options.rewrite_options
|
||||
custom_rewriter_config = self.get_rewriter_config(rewriter_config)
|
||||
custom_rewriter_config = self.get_rewriter_config()
|
||||
grappler_session_config.graph_options.rewrite_options.CopyFrom(
|
||||
custom_rewriter_config)
|
||||
|
||||
@ -285,33 +273,6 @@ class GraphConverter(object):
|
||||
|
||||
self._run_conversion()
|
||||
|
||||
# TODO(laigd): provide a utility function to optimize a ConcreteFunction and
|
||||
# use it here (b/124792963).
|
||||
def _convert_saved_model_v2(self):
|
||||
"""Convert the input SavedModel in 2.0 format."""
|
||||
assert context.executing_eagerly()
|
||||
|
||||
self._saved_model = load.load(self._input_saved_model_dir,
|
||||
self._input_saved_model_tags)
|
||||
func = self._saved_model.signatures[self._input_saved_model_signature_key]
|
||||
frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
|
||||
self._grappler_meta_graph_def = saver.export_meta_graph(
|
||||
graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)
|
||||
|
||||
# Add a collection 'train_op' so that Grappler knows the outputs.
|
||||
fetch_collection = meta_graph_pb2.CollectionDef()
|
||||
for array in frozen_func.inputs + frozen_func.outputs:
|
||||
fetch_collection.node_list.value.append(array.name)
|
||||
self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
|
||||
fetch_collection)
|
||||
|
||||
# Run TRT optimizer in Grappler to convert the graph.
|
||||
self._run_conversion()
|
||||
self._converted_func = wrap_function.function_from_graph_def(
|
||||
self._converted_graph_def,
|
||||
[tensor.name for tensor in frozen_func.inputs],
|
||||
[tensor.name for tensor in frozen_func.outputs])
|
||||
|
||||
def convert(self):
|
||||
"""Run the conversion.
|
||||
|
||||
@ -320,16 +281,11 @@ class GraphConverter(object):
|
||||
2.0+.
|
||||
"""
|
||||
assert not self._converted
|
||||
|
||||
if context.executing_eagerly():
|
||||
self._convert_saved_model_v2()
|
||||
return self._converted_func
|
||||
if self._input_graph_def:
|
||||
self._convert_graph_def()
|
||||
else:
|
||||
if self._input_graph_def:
|
||||
self._convert_graph_def()
|
||||
else:
|
||||
self._convert_saved_model()
|
||||
return self._converted_graph_def
|
||||
self._convert_saved_model()
|
||||
return self._converted_graph_def
|
||||
|
||||
def calibrate(self,
|
||||
fetch_names,
|
||||
@ -408,80 +364,71 @@ class GraphConverter(object):
|
||||
SavedModel.
|
||||
"""
|
||||
assert self._converted
|
||||
if self._input_graph_def:
|
||||
raise ValueError(
|
||||
"Not able to save to a SavedModel since input is a GraphDef")
|
||||
|
||||
if context.executing_eagerly():
|
||||
# Rewrite the signature map using the optimized ConcreteFunction.
|
||||
signatures = {
|
||||
key: value for key, value in self._saved_model.signatures.items()
|
||||
}
|
||||
signatures[self._input_saved_model_signature_key] = self._converted_func
|
||||
save.save(self._saved_model, output_saved_model_dir, signatures)
|
||||
else:
|
||||
if self._input_graph_def:
|
||||
raise ValueError(
|
||||
"Not able to save to a SavedModel since input is a GraphDef")
|
||||
|
||||
def _restore_collections(dest_graph, src_meta_graph_def, collections):
|
||||
"""Restores collections that we need to keep."""
|
||||
scope = ""
|
||||
for key in collections:
|
||||
collection_def = src_meta_graph_def.collection_def[key]
|
||||
kind = collection_def.WhichOneof("kind")
|
||||
if kind is None:
|
||||
tf_logging.error(
|
||||
"Cannot identify data type for collection %s. Skipping.", key)
|
||||
continue
|
||||
from_proto = ops.get_from_proto_function(key)
|
||||
if from_proto and kind == "bytes_list":
|
||||
proto_type = ops.get_collection_proto_type(key)
|
||||
# It is assumed that there are no Variables Keys in collections
|
||||
for value in collection_def.bytes_list.value:
|
||||
proto = proto_type()
|
||||
proto.ParseFromString(value)
|
||||
def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
|
||||
"""Restores collections that we need to keep."""
|
||||
scope = ""
|
||||
for key in collection_keys:
|
||||
collection_def = src_meta_graph_def.collection_def[key]
|
||||
kind = collection_def.WhichOneof("kind")
|
||||
if kind is None:
|
||||
tf_logging.error(
|
||||
"Cannot identify data type for collection %s. Skipping.", key)
|
||||
continue
|
||||
from_proto = ops.get_from_proto_function(key)
|
||||
if from_proto and kind == "bytes_list":
|
||||
proto_type = ops.get_collection_proto_type(key)
|
||||
# It is assumed that there are no Variables Keys in collections
|
||||
for value in collection_def.bytes_list.value:
|
||||
proto = proto_type()
|
||||
proto.ParseFromString(value)
|
||||
try:
|
||||
new_value = from_proto(proto, import_scope=scope)
|
||||
except:
|
||||
continue
|
||||
dest_graph.add_to_collection(key, new_value)
|
||||
else:
|
||||
field = getattr(collection_def, kind)
|
||||
if kind == "node_list":
|
||||
for value in field.value:
|
||||
name = ops.prepend_name_scope(value, scope)
|
||||
# Since the graph has been optimized, the node may no longer
|
||||
# exists
|
||||
try:
|
||||
new_value = from_proto(proto, import_scope=scope)
|
||||
except:
|
||||
col_op = dest_graph.as_graph_element(name)
|
||||
except (TypeError, ValueError, KeyError) as e:
|
||||
continue
|
||||
dest_graph.add_to_collection(key, new_value)
|
||||
dest_graph.add_to_collection(key, col_op)
|
||||
elif kind == "int64_list":
|
||||
# NOTE(opensource): This force conversion is to work around the
|
||||
# fact that Python2 distinguishes between int and long, while
|
||||
# Python3 has only int.
|
||||
for value in field.value:
|
||||
dest_graph.add_to_collection(key, int(value))
|
||||
else:
|
||||
field = getattr(collection_def, kind)
|
||||
if kind == "node_list":
|
||||
for value in field.value:
|
||||
name = ops.prepend_name_scope(value, scope)
|
||||
# Since the graph has been optimized, the node may no longer
|
||||
# exists
|
||||
try:
|
||||
col_op = dest_graph.as_graph_element(name)
|
||||
except (TypeError, ValueError, KeyError) as e:
|
||||
continue
|
||||
dest_graph.add_to_collection(key, col_op)
|
||||
elif kind == "int64_list":
|
||||
# NOTE(opensource): This force conversion is to work around the
|
||||
# fact that Python2 distinguishes between int and long, while
|
||||
# Python3 has only int.
|
||||
for value in field.value:
|
||||
dest_graph.add_to_collection(key, int(value))
|
||||
else:
|
||||
for value in field.value:
|
||||
dest_graph.add_to_collection(
|
||||
key, ops.prepend_name_scope(value, scope))
|
||||
for value in field.value:
|
||||
dest_graph.add_to_collection(key,
|
||||
ops.prepend_name_scope(value, scope))
|
||||
|
||||
# Write the transformed graphdef as SavedModel.
|
||||
saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
|
||||
with ops.Graph().as_default():
|
||||
importer.import_graph_def(self._converted_graph_def, name="")
|
||||
_restore_collections(
|
||||
ops.get_default_graph(), self._grappler_meta_graph_def,
|
||||
self._collections_to_keep(
|
||||
self._grappler_meta_graph_def.collection_def))
|
||||
# We don't use any specific converter here.
|
||||
with session.Session(config=self._session_config) as sess:
|
||||
saved_model_builder.add_meta_graph_and_variables(
|
||||
sess,
|
||||
self._input_saved_model_tags,
|
||||
signature_def_map=self._grappler_meta_graph_def.signature_def)
|
||||
# Ignore other meta graphs from the input SavedModel.
|
||||
saved_model_builder.save()
|
||||
# Write the transformed graphdef as SavedModel.
|
||||
saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
|
||||
with ops.Graph().as_default():
|
||||
importer.import_graph_def(self._converted_graph_def, name="")
|
||||
_restore_collections(
|
||||
ops.get_default_graph(), self._grappler_meta_graph_def,
|
||||
self._collections_to_keep(
|
||||
self._grappler_meta_graph_def.collection_def))
|
||||
# We don't use any specific converter here.
|
||||
with session.Session(config=self._session_config) as sess:
|
||||
saved_model_builder.add_meta_graph_and_variables(
|
||||
sess,
|
||||
self._input_saved_model_tags,
|
||||
signature_def_map=self._grappler_meta_graph_def.signature_def)
|
||||
# Ignore other meta graphs from the input SavedModel.
|
||||
saved_model_builder.save()
|
||||
|
||||
|
||||
class TrtPrecisionMode(object):
|
||||
@ -498,101 +445,202 @@ class TrtPrecisionMode(object):
|
||||
# so it can produce reasonable performance results with the default.
|
||||
DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
|
||||
|
||||
# TrtConversionParams encapsulates the parameters that are used for TF-TRT
|
||||
# conversion.
|
||||
TrtConversionParams = collections.namedtuple(
|
||||
"TrtConversionParams",
|
||||
[
|
||||
|
||||
# A template RewriterConfig proto used to create a TRT-enabled
|
||||
# RewriterConfig. If None, it will use a default one.
|
||||
"rewriter_config_template",
|
||||
|
||||
# The maximum GPU temporary memory which the TRT engine can use at
|
||||
# execution time. This corresponds to the 'workspaceSize' parameter of
|
||||
# nvinfer1::IBuilder::setMaxWorkspaceSize().
|
||||
"max_workspace_size_bytes",
|
||||
|
||||
# One of TrtPrecisionMode.supported_precision_modes().
|
||||
"precision_mode",
|
||||
|
||||
# The minimum number of nodes required for a subgraph to be replaced by
|
||||
# TRTEngineOp.
|
||||
"minimum_segment_size",
|
||||
|
||||
# Whether to generate dynamic TRT ops which will build the TRT network
|
||||
# and engine at run time.
|
||||
"is_dynamic_op",
|
||||
|
||||
# Max number of cached TRT engines in dynamic TRT ops. If the number of
|
||||
# cached engines is already at max but none of them can serve the input,
|
||||
# the TRTEngineOp will fall back to run the TF function based on which
|
||||
# the TRTEngineOp is created.
|
||||
"maximum_cached_engines",
|
||||
|
||||
# This argument is ignored if precision_mode is not INT8. If set to
|
||||
# True, a calibration graph will be created to calibrate the missing
|
||||
# ranges. The calibration graph must be converted to an inference graph
|
||||
# by running calibration with calibrate(). If set to False, quantization
|
||||
# nodes will be expected for every tensor in the graph (exlcuding those
|
||||
# which will be fused). If a range is missing, an error will occur.
|
||||
# Please note that accuracy may be negatively affected if there is a
|
||||
# mismatch between which tensors TRT quantizes and which tensors were
|
||||
# trained with fake quantization.
|
||||
"use_calibration",
|
||||
|
||||
# If set to True, it will create a FunctionDef for each subgraph that is
|
||||
# converted to TRT op, and if TRT ops fail to execute at runtime, it'll
|
||||
# invoke that function as a fallback.
|
||||
"use_function_backup",
|
||||
|
||||
# Max size for the input batch.
|
||||
# This option is deprecated in TF 2.0.
|
||||
"max_batch_size",
|
||||
|
||||
# A list of batch sizes used to create cached engines, only used when
|
||||
# is_dynamic_op is True. The length of the list should be <=
|
||||
# maximum_cached_engines, and the dynamic TRT op will use this list to
|
||||
# determine the batch sizes of the cached engines, instead of making the
|
||||
# decision on the fly. This is useful when we know the most common batch
|
||||
# size(s) the application is going to generate.
|
||||
# This option is deprecated in TF 2.0.
|
||||
"cached_engine_batches",
|
||||
])
|
||||
|
||||
DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams(
|
||||
rewriter_config_template=None,
|
||||
max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
|
||||
precision_mode=TrtPrecisionMode.FP32,
|
||||
minimum_segment_size=3,
|
||||
is_dynamic_op=False,
|
||||
maximum_cached_engines=1,
|
||||
use_calibration=True,
|
||||
use_function_backup=True,
|
||||
max_batch_size=1,
|
||||
cached_engine_batches=None)
|
||||
|
||||
_TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF-TRT-Calibration"
|
||||
_TRT_ENGINE_CACHE_CONTAINER_NAME = "TF-TRT-Engine-Cache"
|
||||
_TRT_ENGINE_OP_NAME = "TRTEngineOp"
|
||||
|
||||
|
||||
def _check_conversion_params(conversion_params):
|
||||
"""Validate the provided TrtConversionParams.
|
||||
|
||||
Args:
|
||||
conversion_params: a TrtConversionParams instance.
|
||||
|
||||
Raises:
|
||||
TypeError: if any of the parameters are of unexpected type.
|
||||
ValueError: if any of the parameters are of unexpected value.
|
||||
"""
|
||||
supported_precision_modes = TrtPrecisionMode.supported_precision_modes()
|
||||
if conversion_params.precision_mode not in supported_precision_modes:
|
||||
raise ValueError(
|
||||
("precision mode '{}' is not supported."
|
||||
"It should be one of {}").format(conversion_params.precision_mode,
|
||||
supported_precision_modes))
|
||||
if conversion_params.cached_engine_batches:
|
||||
if not isinstance(conversion_params.cached_engine_batches, list):
|
||||
raise TypeError("cached_engine_batches should be a list.")
|
||||
if len(conversion_params.cached_engine_batches
|
||||
) > conversion_params.maximum_cached_engines:
|
||||
raise ValueError("cached_engine_batches should not contain more than "
|
||||
"maximum_cached_engines items.")
|
||||
|
||||
|
||||
def _check_trt_version_compatibility():
|
||||
"""Check compatibility of TensorRT version.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if the TensorRT library version is incompatible.
|
||||
"""
|
||||
compiled_version = get_linked_tensorrt_version()
|
||||
loaded_version = get_loaded_tensorrt_version()
|
||||
tf_logging.info("Linked TensorRT version: %s" % str(compiled_version))
|
||||
tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version))
|
||||
version_mismatch = False
|
||||
if loaded_version[0] < compiled_version[0]:
|
||||
tf_logging.error(
|
||||
"TensorRT version mismatch. Tensorflow was compiled against " +
|
||||
"TensorRT %s but library loaded from environment is TensorRT %s" %
|
||||
(".".join([str(x) for x in compiled_version]),
|
||||
".".join([str(x) for x in loaded_version])) +
|
||||
". Please make sure that correct version of TensorRT " +
|
||||
"is available in the system and added to ldconfig or LD_LIBRARY_PATH")
|
||||
raise RuntimeError("Incompatible TensorRT library version")
|
||||
for i in zip(loaded_version, compiled_version):
|
||||
if i[0] != i[1]:
|
||||
tf_logging.warn("TensorRT mismatch. Compiled against version " +
|
||||
"%s, but loaded %s. Things may not work" %
|
||||
(".".join([str(x) for x in compiled_version]),
|
||||
".".join([str(x) for x in loaded_version])))
|
||||
version_mismatch = True
|
||||
break
|
||||
if not version_mismatch:
|
||||
tf_logging.info("Running against TensorRT version %s" %
|
||||
".".join([str(x) for x in loaded_version]))
|
||||
|
||||
|
||||
def get_tensorrt_rewriter_config(
|
||||
conversion_params=DEFAULT_TRT_CONVERSION_PARAMS):
|
||||
"""Returns a RewriterConfig proto for TRT transformation.
|
||||
|
||||
Args:
|
||||
conversion_params: a TrtConversionParams instance.
|
||||
|
||||
Returns:
|
||||
A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
|
||||
|
||||
Raises:
|
||||
TypeError: if any of the parameters are of unexpected type.
|
||||
ValueError: if any of the parameters are of unexpected value.
|
||||
"""
|
||||
if conversion_params.rewriter_config_template is not None and not isinstance(
|
||||
conversion_params.rewriter_config_template,
|
||||
rewriter_config_pb2.RewriterConfig):
|
||||
raise TypeError(
|
||||
"rewriter_config_template should be a RewriterConfig proto.")
|
||||
_check_conversion_params(conversion_params)
|
||||
|
||||
rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
|
||||
if conversion_params.rewriter_config_template is None:
|
||||
# Layout optimizer may add Const nodes followed by Reshape nodes, thus we
|
||||
# need to run constant folding again.
|
||||
rewriter_config_with_trt.optimizers.extend(
|
||||
["constfold", "layout", "constfold"])
|
||||
rewriter_config_with_trt.meta_optimizer_iterations = (
|
||||
rewriter_config_pb2.RewriterConfig.ONE)
|
||||
else:
|
||||
rewriter_config_with_trt.CopyFrom(
|
||||
conversion_params.rewriter_config_template)
|
||||
|
||||
optimizer = rewriter_config_with_trt.custom_optimizers.add()
|
||||
optimizer.name = "TensorRTOptimizer"
|
||||
optimizer.parameter_map[
|
||||
"minimum_segment_size"].i = conversion_params.minimum_segment_size
|
||||
optimizer.parameter_map["max_batch_size"].i = conversion_params.max_batch_size
|
||||
optimizer.parameter_map["is_dynamic_op"].b = conversion_params.is_dynamic_op
|
||||
optimizer.parameter_map[
|
||||
"max_workspace_size_bytes"].i = conversion_params.max_workspace_size_bytes
|
||||
optimizer.parameter_map["precision_mode"].s = _to_bytes(
|
||||
conversion_params.precision_mode)
|
||||
optimizer.parameter_map[
|
||||
"maximum_cached_engines"].i = conversion_params.maximum_cached_engines
|
||||
if conversion_params.cached_engine_batches:
|
||||
optimizer.parameter_map["cached_engine_batches"].list.i.extend(
|
||||
conversion_params.cached_engine_batches)
|
||||
optimizer.parameter_map[
|
||||
"use_calibration"].b = conversion_params.use_calibration
|
||||
optimizer.parameter_map[
|
||||
"use_function_backup"].b = conversion_params.use_function_backup
|
||||
return rewriter_config_with_trt
|
||||
|
||||
|
||||
class TrtGraphConverter(GraphConverter):
|
||||
"""A GraphConverter for TRT transformation."""
|
||||
|
||||
_TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF-TRT-Calibration"
|
||||
|
||||
@classmethod
|
||||
def get_tensorrt_rewriter_config(
|
||||
cls,
|
||||
rewriter_config_template=None,
|
||||
max_batch_size=1,
|
||||
max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
|
||||
precision_mode=TrtPrecisionMode.FP32,
|
||||
minimum_segment_size=3,
|
||||
is_dynamic_op=False,
|
||||
maximum_cached_engines=1,
|
||||
cached_engine_batches=None,
|
||||
use_calibration=True,
|
||||
use_function_backup=True):
|
||||
"""Returns a RewriterConfig proto for TRT transformation.
|
||||
|
||||
Args:
|
||||
rewriter_config_template: a template RewriterConfig proto used to create a
|
||||
TRT-enabled RewriterConfig. If None, it will use a default one.
|
||||
max_batch_size: max size for the input batch
|
||||
max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
|
||||
engine can use at execution time. This corresponds to the
|
||||
'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
|
||||
precision_mode: one of TrtPrecisionMode.supported_precision_modes().
|
||||
minimum_segment_size: the minimum number of nodes required for a subgraph
|
||||
to be replaced by TRTEngineOp.
|
||||
is_dynamic_op: whether to generate dynamic TRT ops which will build the
|
||||
TRT network and engine at run time.
|
||||
maximum_cached_engines: max number of cached TRT engines in dynamic TRT
|
||||
ops. If the number of cached engines is already at max but none of them
|
||||
can serve the input, the TRTEngineOp will fall back to run the TF
|
||||
function based on which the TRTEngineOp is created.
|
||||
cached_engine_batches: a list of batch sizes used to create cached
|
||||
engines, only used when is_dynamic_op is True. The length of the list
|
||||
should be <= maximum_cached_engines, and the dynamic TRT op will use
|
||||
this list to determine the batch sizes of the cached engines, instead of
|
||||
making the decision on the fly. This is useful when we know the most
|
||||
common batch size(s) the application is going to generate.
|
||||
use_calibration: this argument is ignored if precision_mode is not INT8.
|
||||
If set to True, a calibration graph will be created to calibrate the
|
||||
missing ranges. The calibration graph must be converted to an inference
|
||||
graph by running calibration with calibrate(). If set to False,
|
||||
quantization nodes will be expected for every tensor in the graph
|
||||
(exlcuding those which will be fused). If a range is missing, an error
|
||||
will occur. Please note that accuracy may be negatively affected if
|
||||
there is a mismatch between which tensors TRT quantizes and which
|
||||
tensors were trained with fake quantization.
|
||||
use_function_backup: if set to True, it will create a FunctionDef for each
|
||||
subgraph that is converted to TRT op, and if TRT ops fail to execute at
|
||||
runtime, it'll invoke that function as a fallback.
|
||||
|
||||
Returns:
|
||||
A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
|
||||
|
||||
Raises:
|
||||
TypeError: if any of the parameters are of unexpected type.
|
||||
ValueError: if any of the parameters are of unexpected value.
|
||||
"""
|
||||
if rewriter_config_template is not None and not isinstance(
|
||||
rewriter_config_template, rewriter_config_pb2.RewriterConfig):
|
||||
raise TypeError(
|
||||
"rewriter_config_template should be a RewriterConfig proto.")
|
||||
|
||||
rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
|
||||
if rewriter_config_template is None:
|
||||
# Layout optimizer may add Const nodes followed by Reshape nodes, thus we
|
||||
# need to run constant folding again.
|
||||
rewriter_config_with_trt.optimizers.extend(
|
||||
["constfold", "layout", "constfold"])
|
||||
rewriter_config_with_trt.meta_optimizer_iterations = (
|
||||
rewriter_config_pb2.RewriterConfig.ONE)
|
||||
else:
|
||||
rewriter_config_with_trt.CopyFrom(rewriter_config_template)
|
||||
|
||||
optimizer = rewriter_config_with_trt.custom_optimizers.add()
|
||||
optimizer.name = "TensorRTOptimizer"
|
||||
optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
|
||||
optimizer.parameter_map["max_batch_size"].i = max_batch_size
|
||||
optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
|
||||
optimizer.parameter_map[
|
||||
"max_workspace_size_bytes"].i = max_workspace_size_bytes
|
||||
optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
|
||||
optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
|
||||
if cached_engine_batches:
|
||||
optimizer.parameter_map["cached_engine_batches"].list.i.extend(
|
||||
cached_engine_batches)
|
||||
optimizer.parameter_map["use_calibration"].b = use_calibration
|
||||
optimizer.parameter_map["use_function_backup"].b = use_function_backup
|
||||
return rewriter_config_with_trt
|
||||
|
||||
# TODO(laigd): use TrtConversionParams here.
|
||||
def __init__(self,
|
||||
input_saved_model_dir=None,
|
||||
input_saved_model_tags=None,
|
||||
@ -621,7 +669,7 @@ class TrtGraphConverter(GraphConverter):
|
||||
If set to None, the graph will be read from the SavedModel loaded from
|
||||
input_saved_model_dir.
|
||||
nodes_blacklist: list of node names to prevent the converter from
|
||||
touching. Only used when input_graph_def is not None.
|
||||
touching.
|
||||
session_config: the ConfigProto used to create a Session. It's also used
|
||||
as a template to create a TRT-enabled ConfigProto for conversion. If not
|
||||
specified, a default ConfigProto will be used.
|
||||
@ -659,7 +707,6 @@ class TrtGraphConverter(GraphConverter):
|
||||
|
||||
Raises:
|
||||
ValueError: if the combination of the parameters is invalid.
|
||||
RuntimeError: if the TensorRT library version is incompatible.
|
||||
"""
|
||||
super(TrtGraphConverter, self).__init__(
|
||||
input_saved_model_dir=input_saved_model_dir,
|
||||
@ -668,54 +715,10 @@ class TrtGraphConverter(GraphConverter):
|
||||
input_graph_def=input_graph_def,
|
||||
nodes_blacklist=nodes_blacklist,
|
||||
session_config=session_config)
|
||||
|
||||
# TODO(laigd): move all the validations below to
|
||||
# get_tensorrt_rewriter_config().
|
||||
# Check compatibility of TensorRT version.
|
||||
compiled_version = get_linked_tensorrt_version()
|
||||
loaded_version = get_loaded_tensorrt_version()
|
||||
tf_logging.info("Linked TensorRT version: %s" % str(compiled_version))
|
||||
tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version))
|
||||
version_mismatch = False
|
||||
if loaded_version[0] < compiled_version[0]:
|
||||
tf_logging.error(
|
||||
"TensorRT version mismatch. Tensorflow was compiled against " +
|
||||
"TensorRT %s but library loaded from environment is TensorRT %s" %
|
||||
(".".join([str(x) for x in compiled_version]),
|
||||
".".join([str(x) for x in loaded_version])) +
|
||||
". Please make sure that correct version of TensorRT " +
|
||||
"is available in the system and added to ldconfig or LD_LIBRARY_PATH")
|
||||
raise RuntimeError("Incompatible TensorRT library version")
|
||||
for i in zip(loaded_version, compiled_version):
|
||||
if i[0] != i[1]:
|
||||
tf_logging.warn("TensorRT mismatch. Compiled against version " +
|
||||
"%s, but loaded %s. Things may not work" %
|
||||
(".".join([str(x) for x in compiled_version]),
|
||||
".".join([str(x) for x in loaded_version])))
|
||||
version_mismatch = True
|
||||
break
|
||||
if not version_mismatch:
|
||||
tf_logging.info("Running against TensorRT version %s" %
|
||||
".".join([str(x) for x in loaded_version]))
|
||||
|
||||
# Check input arguments.
|
||||
supported_precision_modes = TrtPrecisionMode.supported_precision_modes()
|
||||
if precision_mode not in supported_precision_modes:
|
||||
raise ValueError(
|
||||
("precision mode '{}' is not supported."
|
||||
"It should be one of {}").format(precision_mode,
|
||||
supported_precision_modes))
|
||||
|
||||
if cached_engine_batches:
|
||||
if not isinstance(cached_engine_batches, list):
|
||||
raise TypeError("cached_engine_batches should be a list.")
|
||||
if len(cached_engine_batches) > maximum_cached_engines:
|
||||
raise ValueError("cached_engine_batches should not contain more than "
|
||||
"maximum_cached_engines items.")
|
||||
_check_trt_version_compatibility()
|
||||
|
||||
self._need_calibration = (
|
||||
precision_mode == TrtPrecisionMode.INT8 and use_calibration)
|
||||
self._use_function_backup = use_function_backup
|
||||
|
||||
# TODO(laigd): consider provide a mechanism to remove the fallback path
|
||||
# after calibration is done.
|
||||
@ -724,31 +727,30 @@ class TrtGraphConverter(GraphConverter):
|
||||
"Calibration requires enabling fallback to TF function execution.")
|
||||
|
||||
# TODO(laigd):
|
||||
# - Get rid of is_dynamic_op option, it should always be True, and it should
|
||||
# accept N shapes as input.
|
||||
# - Verify in int8 mode that maximum_cached_engines and
|
||||
# cached_engine_batches are set appropriately.
|
||||
# - If it fails to build the int8 engine it should return error.
|
||||
self._max_batch_size = max_batch_size
|
||||
self._max_workspace_size_bytes = max_workspace_size_bytes
|
||||
self._precision_mode = precision_mode
|
||||
self._minimum_segment_size = minimum_segment_size
|
||||
self._is_dynamic_op = is_dynamic_op
|
||||
self._maximum_cached_engines = maximum_cached_engines
|
||||
self._cached_engine_batches = cached_engine_batches
|
||||
rewriter_config_template = None
|
||||
if (session_config and session_config.HasField("graph_options") and
|
||||
session_config.graph_options.HasField("rewrite_options")):
|
||||
rewriter_config_template = session_config.graph_options.rewrite_options
|
||||
|
||||
def get_rewriter_config(self, rewriter_config_template=None):
|
||||
return TrtGraphConverter.get_tensorrt_rewriter_config(
|
||||
rewriter_config_template,
|
||||
max_batch_size=self._max_batch_size,
|
||||
max_workspace_size_bytes=self._max_workspace_size_bytes,
|
||||
precision_mode=self._precision_mode,
|
||||
minimum_segment_size=self._minimum_segment_size,
|
||||
is_dynamic_op=self._is_dynamic_op,
|
||||
maximum_cached_engines=self._maximum_cached_engines,
|
||||
cached_engine_batches=self._cached_engine_batches,
|
||||
use_calibration=self._need_calibration,
|
||||
use_function_backup=self._use_function_backup)
|
||||
self._conversion_params = TrtConversionParams(
|
||||
rewriter_config_template=rewriter_config_template,
|
||||
max_workspace_size_bytes=max_workspace_size_bytes,
|
||||
precision_mode=precision_mode,
|
||||
minimum_segment_size=minimum_segment_size,
|
||||
is_dynamic_op=is_dynamic_op,
|
||||
maximum_cached_engines=maximum_cached_engines,
|
||||
use_calibration=use_calibration,
|
||||
use_function_backup=use_function_backup,
|
||||
max_batch_size=max_batch_size,
|
||||
cached_engine_batches=cached_engine_batches)
|
||||
_check_conversion_params(self._conversion_params)
|
||||
|
||||
def get_rewriter_config(self):
|
||||
return get_tensorrt_rewriter_config(
|
||||
conversion_params=self._conversion_params)
|
||||
|
||||
def finalize_calibration(self):
|
||||
assert self._need_calibration
|
||||
@ -775,7 +777,7 @@ class TrtGraphConverter(GraphConverter):
|
||||
resource_name_input = array_ops.placeholder(dtypes.string)
|
||||
|
||||
for node in self._converted_graph_def.node:
|
||||
if node.op == "TRTEngineOp":
|
||||
if node.op == _TRT_ENGINE_OP_NAME:
|
||||
# Adds the get_serialized_resource_op for the device if not done
|
||||
# before. We only add one such op for each device.
|
||||
# TODO(laigd): What if the device is empty?????
|
||||
@ -791,11 +793,8 @@ class TrtGraphConverter(GraphConverter):
|
||||
calibration_result = self._calibration_sess.run(
|
||||
device_to_get_resource_op_map[node.device],
|
||||
feed_dict={
|
||||
container_input:
|
||||
TrtGraphConverter
|
||||
._TRT_CALIBRATION_RESOURCE_CONTAINER_NAME,
|
||||
resource_name_input:
|
||||
node.name
|
||||
container_input: _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME,
|
||||
resource_name_input: node.name
|
||||
})
|
||||
node.attr["calibration_data"].s = calibration_result
|
||||
|
||||
@ -806,9 +805,106 @@ class TrtGraphConverter(GraphConverter):
|
||||
"""Save the converted graph as a SavedModel."""
|
||||
if self._need_calibration:
|
||||
assert self._calibration_data_collected
|
||||
|
||||
super(TrtGraphConverter, self).save(output_saved_model_dir)
|
||||
|
||||
|
||||
class TrtGraphConverterV2(object):
|
||||
"""A converter for TF-TRT transformation for SavedModel in TF 2.0."""
|
||||
|
||||
def __init__(self,
|
||||
input_saved_model_dir=None,
|
||||
input_saved_model_tags=None,
|
||||
input_saved_model_signature_key=None,
|
||||
conversion_params=DEFAULT_TRT_CONVERSION_PARAMS):
|
||||
"""Initialize the converter.
|
||||
|
||||
Args:
|
||||
input_saved_model_dir: the directory to load the SavedModel which contains
|
||||
the input graph to transforms. Used only when input_graph_def is None.
|
||||
input_saved_model_tags: list of tags to load the SavedModel.
|
||||
input_saved_model_signature_key: the key of the signature to optimize the
|
||||
graph for.
|
||||
conversion_params: a TrtConversionParams instance.
|
||||
"""
|
||||
assert context.executing_eagerly()
|
||||
_check_trt_version_compatibility()
|
||||
|
||||
self._input_saved_model_dir = input_saved_model_dir
|
||||
self._input_saved_model_tags = (
|
||||
input_saved_model_tags or [tag_constants.SERVING])
|
||||
self._input_saved_model_signature_key = (
|
||||
input_saved_model_signature_key or
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
|
||||
|
||||
self._need_calibration = (
|
||||
conversion_params.precision_mode == TrtPrecisionMode.INT8 and
|
||||
conversion_params.use_calibration)
|
||||
self._conversion_params = conversion_params
|
||||
_check_conversion_params(self._conversion_params)
|
||||
self._converted = False
|
||||
|
||||
def _run_conversion(self, meta_graph_def):
|
||||
"""Run Grappler's OptimizeGraph() tool to convert the graph.
|
||||
|
||||
Args:
|
||||
meta_graph_def: the MetaGraphDef instance to run the optimizations on.
|
||||
|
||||
Returns:
|
||||
The optimized GraphDef.
|
||||
"""
|
||||
rewriter_config = get_tensorrt_rewriter_config(
|
||||
conversion_params=self._conversion_params)
|
||||
grappler_session_config = config_pb2.ConfigProto()
|
||||
grappler_session_config.graph_options.rewrite_options.CopyFrom(
|
||||
rewriter_config)
|
||||
return tf_optimizer.OptimizeGraph(
|
||||
grappler_session_config, meta_graph_def, graph_id=b"tf_graph")
|
||||
|
||||
# TODO(laigd): provide a utility function to optimize a ConcreteFunction and
|
||||
# use it here (b/124792963).
|
||||
def convert(self):
|
||||
"""Convert the input SavedModel in 2.0 format."""
|
||||
assert not self._converted
|
||||
self._saved_model = load.load(self._input_saved_model_dir,
|
||||
self._input_saved_model_tags)
|
||||
func = self._saved_model.signatures[self._input_saved_model_signature_key]
|
||||
frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
|
||||
grappler_meta_graph_def = saver.export_meta_graph(
|
||||
graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)
|
||||
|
||||
# Add a collection 'train_op' so that Grappler knows the outputs.
|
||||
fetch_collection = meta_graph_pb2.CollectionDef()
|
||||
for array in frozen_func.inputs + frozen_func.outputs:
|
||||
fetch_collection.node_list.value.append(array.name)
|
||||
grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
|
||||
fetch_collection)
|
||||
|
||||
# Run TRT optimizer in Grappler to convert the graph.
|
||||
converted_graph_def = self._run_conversion(grappler_meta_graph_def)
|
||||
self._converted_func = wrap_function.function_from_graph_def(
|
||||
converted_graph_def, [tensor.name for tensor in frozen_func.inputs],
|
||||
[tensor.name for tensor in frozen_func.outputs])
|
||||
|
||||
self._converted = True
|
||||
return self._converted_func
|
||||
|
||||
def save(self, output_saved_model_dir):
|
||||
"""Save the converted SavedModel.
|
||||
|
||||
Args:
|
||||
output_saved_model_dir: directory to saved the converted SavedModel.
|
||||
"""
|
||||
assert self._converted
|
||||
# Rewrite the signature map using the optimized ConcreteFunction.
|
||||
signatures = {
|
||||
key: value for key, value in self._saved_model.signatures.items()
|
||||
}
|
||||
signatures[self._input_saved_model_signature_key] = self._converted_func
|
||||
save.save(self._saved_model, output_saved_model_dir, signatures)
|
||||
|
||||
|
||||
# TODO(laigd): use TrtConversionParams here.
|
||||
def create_inference_graph(
|
||||
input_graph_def,
|
||||
outputs,
|
||||
|
@ -19,14 +19,17 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import graph_util
|
||||
@ -35,7 +38,6 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import builder
|
||||
@ -44,10 +46,11 @@ from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.saved_model import signature_def_utils
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.saved_model import utils
|
||||
from tensorflow.python.tools import saved_model_utils
|
||||
from tensorflow.python.saved_model import load
|
||||
from tensorflow.python.saved_model import save
|
||||
from tensorflow.python.tools import saved_model_utils
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
_SAVED_MODEL_SIGNATURE_KEY = "mypredict"
|
||||
|
||||
@ -63,8 +66,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
"""Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
|
||||
rewriter_config_template=None,
|
||||
conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
||||
max_batch_size=128,
|
||||
max_workspace_size_bytes=1234,
|
||||
precision_mode="INT8",
|
||||
@ -72,6 +74,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
is_dynamic_op=True,
|
||||
maximum_cached_engines=2,
|
||||
cached_engine_batches=[1, 128])
|
||||
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
|
||||
conversion_params=conversion_params)
|
||||
self.assertEqual(["constfold", "layout", "constfold"],
|
||||
rewriter_cfg.optimizers)
|
||||
self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE,
|
||||
@ -106,7 +110,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
gpu_options=config_pb2.GPUOptions(allow_growth=True))
|
||||
return config
|
||||
|
||||
def _GetGraph(self):
|
||||
@classmethod
|
||||
def _GetGraph(cls, inp, var):
|
||||
"""Get the graph for testing."""
|
||||
# The graph computes (input+1)^2, it looks like:
|
||||
#
|
||||
@ -119,24 +124,42 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
# +
|
||||
# |
|
||||
# output (Identity)
|
||||
add = inp + var
|
||||
mul = inp * add
|
||||
add = mul + add
|
||||
out = array_ops.identity(add, name="output")
|
||||
return out
|
||||
|
||||
def _GetModelForV2(self):
|
||||
|
||||
class SimpleModel(tracking.AutoTrackable):
|
||||
|
||||
def __init__(self):
|
||||
self.v = None
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32)
|
||||
])
|
||||
def run(self, inp):
|
||||
if self.v is None:
|
||||
self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32)
|
||||
return TrtConvertTest._GetGraph(inp, self.v)
|
||||
|
||||
return SimpleModel()
|
||||
|
||||
def _GetGraphForV1(self):
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
with g.device("/GPU:0"):
|
||||
inp = array_ops.placeholder(
|
||||
dtype=dtypes.float32, shape=[None, 1, 1], name="input")
|
||||
var = variables.VariableV1([[[1.0]]],
|
||||
dtype=dtypes.float32,
|
||||
name="v1",
|
||||
use_resource=False)
|
||||
add = inp + var.value()
|
||||
mul = inp * add
|
||||
add = mul + add
|
||||
out = array_ops.identity(add, name="output")
|
||||
return g, var, inp, out
|
||||
var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
|
||||
out = TrtConvertTest._GetGraph(inp, var)
|
||||
return g, var, inp, out
|
||||
|
||||
def _GetGraphDef(self):
|
||||
"""Get the graph def for testing."""
|
||||
g, var, _, _ = self._GetGraph()
|
||||
g, var, _, _ = self._GetGraphForV1()
|
||||
with self.session(graph=g, config=self._GetConfigProto()) as sess:
|
||||
sess.run(var.initializer)
|
||||
graph_def = graph_util.convert_variables_to_constants(
|
||||
@ -145,7 +168,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(
|
||||
{
|
||||
"v1": "Const",
|
||||
"v1/read": "Identity",
|
||||
"add/ReadVariableOp": "Identity",
|
||||
"input": "Placeholder",
|
||||
"add": "Add",
|
||||
"mul": "Mul",
|
||||
@ -156,7 +179,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def _WriteInputSavedModel(self, input_saved_model_dir):
|
||||
"""Write the saved model as an input for testing."""
|
||||
g, var, inp, out = self._GetGraph()
|
||||
g, var, inp, out = self._GetGraphForV1()
|
||||
signature_def = signature_def_utils.build_signature_def(
|
||||
inputs={"myinput": utils.build_tensor_info(inp)},
|
||||
outputs={"myoutput": utils.build_tensor_info(out)},
|
||||
@ -183,7 +206,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
input_saved_model_dir=input_saved_model_dir,
|
||||
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
|
||||
input_graph_def=None if input_saved_model_dir else self._GetGraphDef(),
|
||||
nodes_blacklist=["output"],
|
||||
nodes_blacklist=None if input_saved_model_dir else ["output"],
|
||||
session_config=self._GetConfigProto(),
|
||||
max_batch_size=max_batch_size,
|
||||
max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
|
||||
@ -193,28 +216,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
is_dynamic_op=is_dynamic_op,
|
||||
maximum_cached_engines=maximum_cached_engines,
|
||||
use_function_backup=use_function_backup)
|
||||
conversion_result = converter.convert()
|
||||
output_graph_def = converter.convert()
|
||||
|
||||
if context.executing_eagerly():
|
||||
output_graph_def = conversion_result.graph.as_graph_def()
|
||||
else:
|
||||
output_graph_def = conversion_result
|
||||
if need_calibration:
|
||||
|
||||
if need_calibration:
|
||||
class CalibrationData(object):
|
||||
|
||||
class CalibrationData(object):
|
||||
def __init__(self):
|
||||
self._data = 0
|
||||
|
||||
def __init__(self):
|
||||
self._data = 0
|
||||
def next(self):
|
||||
self._data += 1
|
||||
return {"input:0": [[[self._data]]]}
|
||||
|
||||
def next(self):
|
||||
self._data += 1
|
||||
return {"input:0": [[[self._data]]]}
|
||||
|
||||
output_graph_def = converter.calibrate(
|
||||
fetch_names=["output:0"],
|
||||
num_runs=10,
|
||||
feed_dict_fn=CalibrationData().next)
|
||||
output_graph_def = converter.calibrate(
|
||||
fetch_names=["output:0"],
|
||||
num_runs=10,
|
||||
feed_dict_fn=CalibrationData().next)
|
||||
|
||||
if output_saved_model_dir is not None:
|
||||
converter.save(output_saved_model_dir=output_saved_model_dir)
|
||||
@ -235,31 +253,19 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
graph_defs_to_verify = [output_graph_def]
|
||||
|
||||
if output_saved_model_dir:
|
||||
if context.executing_eagerly():
|
||||
root = load.load(output_saved_model_dir)
|
||||
saved_model_graph_def = root.signatures[
|
||||
_SAVED_MODEL_SIGNATURE_KEY].graph.as_graph_def()
|
||||
else:
|
||||
saved_model_graph_def = saved_model_utils.get_meta_graph_def(
|
||||
output_saved_model_dir, tag_constants.SERVING).graph_def
|
||||
self.assertTrue(isinstance(saved_model_graph_def, graph_pb2.GraphDef))
|
||||
saved_model_graph_def = saved_model_utils.get_meta_graph_def(
|
||||
output_saved_model_dir, tag_constants.SERVING).graph_def
|
||||
self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
|
||||
graph_defs_to_verify.append(saved_model_graph_def)
|
||||
|
||||
for graph_def in graph_defs_to_verify:
|
||||
node_name_to_op = {node.name: node.op for node in graph_def.node}
|
||||
if context.executing_eagerly():
|
||||
# In V2 the actual graph could be inside a function.
|
||||
for func in graph_def.library.function:
|
||||
node_name_to_op.update({node.name: node.op for node in func.node_def})
|
||||
self.assertIn("TRTEngineOp_0", node_name_to_op)
|
||||
self.assertEqual("TRTEngineOp", node_name_to_op["TRTEngineOp_0"])
|
||||
else:
|
||||
self.assertEqual(
|
||||
{
|
||||
"input": "Placeholder",
|
||||
"TRTEngineOp_0": "TRTEngineOp",
|
||||
"output": "Identity"
|
||||
}, node_name_to_op)
|
||||
self.assertEqual(
|
||||
{
|
||||
"input": "Placeholder",
|
||||
"TRTEngineOp_0": "TRTEngineOp",
|
||||
"output": "Identity"
|
||||
}, node_name_to_op)
|
||||
|
||||
if need_calibration:
|
||||
trt_engine_nodes = [
|
||||
@ -306,39 +312,81 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
|
||||
# TODO(laigd): we need to use ops like conv2d so Grappler can infer the
|
||||
# shapes (at least rank) of the tensors, so we're able to build an TRT
|
||||
# engine in dynamic mode. Currently shape information is not propagate from
|
||||
# ConcreteFunction to GraphDef, need to investigate and fix it.
|
||||
class SimpleModel(tracking.AutoTrackable):
|
||||
np_input = np.random.random_sample([4, 1, 1]).astype(np.float32)
|
||||
|
||||
def __init__(self):
|
||||
self.v = None
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
tensor_spec.TensorSpec(shape=[None, 24, 24, 2], dtype=dtypes.float32)
|
||||
])
|
||||
def run(self, inp):
|
||||
if self.v is None:
|
||||
self.v = variables.Variable([[[[1., 0.5, 4., 6., 0.5, 1.],
|
||||
[1., 0.5, 1., 1., 0.5, 1.]]]])
|
||||
conv = gen_nn_ops.conv2d(
|
||||
input=inp, filter=self.v, strides=[1, 2, 2, 1], padding="SAME")
|
||||
identity = array_ops.identity(conv)
|
||||
return identity
|
||||
|
||||
tmp_dir = self.get_temp_dir()
|
||||
input_saved_model_dir = os.path.join(tmp_dir, "in_dir1_v2")
|
||||
root = SimpleModel()
|
||||
# Create a model and save it.
|
||||
input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
|
||||
root = self._GetModelForV2()
|
||||
expected_output = root.run(np_input)
|
||||
save.save(root, input_saved_model_dir,
|
||||
{_SAVED_MODEL_SIGNATURE_KEY: root.run})
|
||||
|
||||
# Convert the SavedModel and verify the result.
|
||||
output_saved_model_dir = os.path.join(tmp_dir, "out_dir1_v2")
|
||||
self._TestTrtGraphConverter(
|
||||
# Run TRT conversion.
|
||||
converter = trt_convert.TrtGraphConverterV2(
|
||||
input_saved_model_dir=input_saved_model_dir,
|
||||
output_saved_model_dir=output_saved_model_dir,
|
||||
is_dynamic_op=True)
|
||||
input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY,
|
||||
conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
||||
precision_mode=trt_convert.TrtPrecisionMode.FP32,
|
||||
is_dynamic_op=True,
|
||||
maximum_cached_engines=2,
|
||||
use_function_backup=False))
|
||||
converted_concrete_func = converter.convert()
|
||||
|
||||
def _check_trt_ops(graph_def):
|
||||
trt_op_names = [
|
||||
node.name for node in graph_def.node if node.op == "TRTEngineOp"
|
||||
]
|
||||
for func in graph_def.library.function:
|
||||
for node in func.node_def:
|
||||
if node.op == "TRTEngineOp":
|
||||
trt_op_names.append(node.name)
|
||||
self.assertEqual(1, len(trt_op_names))
|
||||
self.assertIn("TRTEngineOp_0", trt_op_names[0])
|
||||
|
||||
# Verify the converted GraphDef and ConcreteFunction.
|
||||
self.assertIsInstance(converted_concrete_func, function.ConcreteFunction)
|
||||
converted_graph_def = converted_concrete_func.graph.as_graph_def()
|
||||
_check_trt_ops(converted_graph_def)
|
||||
output_with_trt = converted_concrete_func(ops.convert_to_tensor(np_input))
|
||||
self.assertEqual(1, len(output_with_trt))
|
||||
self.assertAllClose(
|
||||
expected_output, output_with_trt[0].numpy(), atol=1e-6, rtol=1e-6)
|
||||
|
||||
# Run the converted ConcreteFunction as a function and make sure it works.
|
||||
@def_function.function
|
||||
def wrapper_func(*args, **kwargs):
|
||||
return nest.flatten(converted_concrete_func(*args, **kwargs))
|
||||
|
||||
_check_trt_ops(
|
||||
wrapper_func.get_concrete_function(
|
||||
tensor_spec.TensorSpec(shape=[None, 1, 1],
|
||||
dtype=dtypes.float32)).graph.as_graph_def())
|
||||
output_with_trt = wrapper_func(np_input)
|
||||
self.assertEqual(1, len(output_with_trt))
|
||||
self.assertAllClose(
|
||||
expected_output, output_with_trt[0], atol=1e-6, rtol=1e-6)
|
||||
|
||||
# Save the converted model.
|
||||
output_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
|
||||
converter.save(output_saved_model_dir)
|
||||
|
||||
# Load and verify the converted model.
|
||||
#
|
||||
# TODO(laigd): the name of then new input_signature of the
|
||||
# `root_with_trt.run` function is empty string (originaly was None),
|
||||
# investigate why.
|
||||
root_with_trt = load.load(output_saved_model_dir)
|
||||
# TODO(laigd): `root_with_trt.run` is still using the original graph without
|
||||
# trt. Consider changing that.
|
||||
# _check_trt_ops(
|
||||
# root_with_trt.run.get_concrete_function().graph.as_graph_def())
|
||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||
_check_trt_ops(converted_signature.graph.as_graph_def())
|
||||
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
||||
# The output of running the converted signature is a dict due to
|
||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||
output_with_trt = output_with_trt[output_with_trt.keys()[0]]
|
||||
self.assertAllClose(expected_output, output_with_trt, atol=1e-6, rtol=1e-6)
|
||||
|
||||
def _TestRun(self,
|
||||
sess,
|
||||
@ -363,7 +411,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
node_name_to_op = {node.name: node.op for node in output_graph_def.node}
|
||||
self.assertEqual(
|
||||
{
|
||||
"v1/read": "Const",
|
||||
"add/ReadVariableOp": "Const",
|
||||
"input": "Placeholder",
|
||||
"add": "Add",
|
||||
"mul": "Mul",
|
||||
|
Loading…
Reference in New Issue
Block a user