Export ConversionParams an API

This commit is contained in:
Pooya Davoodi 2020-01-15 15:57:38 -08:00
parent e8565092ff
commit 3a378dbfd5
3 changed files with 69 additions and 4 deletions

View File

@ -111,6 +111,7 @@ class TrtPrecisionMode(object):
DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
@tf_export("experimental.tensorrt.ConversionParams", v1=[])
class TrtConversionParams(object):
""" A class to encapsulate parameters that are used for TF-TRT conversion."""
@ -888,7 +889,7 @@ class TrtGraphConverterV2(object):
1. FP32/FP16 precision
```python
params = DEFAULT_TRT_CONVERSION_PARAMS._replace(
params = tf.experimental.tensorrt.ConversionParams(
precision_mode='FP16')
converter = tf.experimental.tensorrt.Converter(
input_saved_model_dir="my_dir", conversion_params=params)
@ -904,7 +905,7 @@ class TrtGraphConverterV2(object):
2. FP32/FP16 precision with pre-built engines
```python
params = DEFAULT_TRT_CONVERSION_PARAMS._replace(
params = tf.experimental.tensorrt.ConversionParams(
precision_mode='FP16',
# Set this to a large enough number so it can cache all the engines.
maximum_cached_engines=16)
@ -936,7 +937,7 @@ class TrtGraphConverterV2(object):
3. INT8 precision and calibration with pre-built engines
```python
params = DEFAULT_TRT_CONVERSION_PARAMS._replace(
params = tf.experimental.tensorrt.ConversionParams(
precision_mode='INT8',
# Currently only one INT8 engine is supported in this mode.
maximum_cached_engines=1,
@ -974,7 +975,7 @@ class TrtGraphConverterV2(object):
input_saved_model_dir=None,
input_saved_model_tags=None,
input_saved_model_signature_key=None,
conversion_params=DEFAULT_TRT_CONVERSION_PARAMS):
conversion_params=TrtConversionParams()):
"""Initialize the converter.
Args:

View File

@ -26,6 +26,66 @@ if platform.system() != "Windows":
raise RuntimeError(
"This module is expected to be loaded only on Windows platform.")
@tf_export("experimental.tensorrt.ConversionParams", v1=[])
class TrtConversionParams(object):
""" A class to encapsulate parameters that are used for TF-TRT conversion."""
def __init__(self,
rewriter_config_template=None,
max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
precision_mode=TrtPrecisionMode.FP32,
minimum_segment_size=3,
is_dynamic_op=True,
maximum_cached_engines=1,
use_calibration=True,
max_batch_size=1,
allow_build_at_runtime=True):
"""Initialize TrtConversionParams.
Args:
rewriter_config_template: a template RewriterConfig proto used to create a
TRT-enabled RewriterConfig. If None, it will use a default one.
max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
engine can use at execution time. This corresponds to the
'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
precision_mode: one of TrtPrecisionMode.supported_precision_modes().
minimum_segment_size: the minimum number of nodes required for a subgraph
to be replaced by TRTEngineOp.
is_dynamic_op: whether to generate dynamic TRT ops which will build the
TRT network and engine at run time. i.e. Since TensorRT version < 6.0
does not support dynamic dimensions other than the batch dimension,
when the TensorFlow graph has a non-batch dimension of dynamic size,
we would need to enable this option. This option should be set to True
in TF 2.0.
maximum_cached_engines: max number of cached TRT engines for dynamic TRT
ops. Created TRT engines for a dynamic dimension are cached. This is
the maximum number of engines that can be cached. If the number of
cached engines is already at max but none of them supports the input
shapes, the TRTEngineOp will fall back to run the original TF subgraph
that corresponds to the TRTEngineOp.
use_calibration: this argument is ignored if precision_mode is not INT8.
If set to True, a calibration graph will be created to calibrate the
missing ranges. The calibration graph must be converted to an inference
graph by running calibration with calibrate(). If set to False,
quantization nodes will be expected for every tensor in the graph
(exlcuding those which will be fused). If a range is missing, an error
will occur. Please note that accuracy may be negatively affected if
there is a mismatch between which tensors TRT quantizes and which
tensors were trained with fake quantization.
max_batch_size: max size for the input batch. This parameter is only
effective when is_dynamic_op=False which is not supported in TF 2.0.
allow_build_at_runtime: whether to build TensorRT engines during runtime.
If no TensorRT engine can be found in cache that can handle the given
inputs during runtime, then a new TensorRT engine is built at runtime
if allow_build_at_runtime=True, and otherwise native TF is used. This
argument is only effective if is_dynamic_op=True.
Raises:
NotImplementedError: TRT is not supported on Windows.
"""
raise NotImplementedError(
"TensorRT integration is not available on Windows.")
@tf_export("experimental.tensorrt.Converter", v1=[])
class TrtConverterWindows(object):

View File

@ -1,5 +1,9 @@
path: "tensorflow.experimental.tensorrt"
tf_module {
member {
name: "ConversionParams"
mtype: "<type \'type\'>"
}
member {
name: "Converter"
mtype: "<type \'type\'>"