Export ConversionParams an API
This commit is contained in:
parent
e8565092ff
commit
3a378dbfd5
@ -111,6 +111,7 @@ class TrtPrecisionMode(object):
|
||||
DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
|
||||
|
||||
|
||||
@tf_export("experimental.tensorrt.ConversionParams", v1=[])
|
||||
class TrtConversionParams(object):
|
||||
""" A class to encapsulate parameters that are used for TF-TRT conversion."""
|
||||
|
||||
@ -888,7 +889,7 @@ class TrtGraphConverterV2(object):
|
||||
1. FP32/FP16 precision
|
||||
|
||||
```python
|
||||
params = DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
||||
params = tf.experimental.tensorrt.ConversionParams(
|
||||
precision_mode='FP16')
|
||||
converter = tf.experimental.tensorrt.Converter(
|
||||
input_saved_model_dir="my_dir", conversion_params=params)
|
||||
@ -904,7 +905,7 @@ class TrtGraphConverterV2(object):
|
||||
2. FP32/FP16 precision with pre-built engines
|
||||
|
||||
```python
|
||||
params = DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
||||
params = tf.experimental.tensorrt.ConversionParams(
|
||||
precision_mode='FP16',
|
||||
# Set this to a large enough number so it can cache all the engines.
|
||||
maximum_cached_engines=16)
|
||||
@ -936,7 +937,7 @@ class TrtGraphConverterV2(object):
|
||||
3. INT8 precision and calibration with pre-built engines
|
||||
|
||||
```python
|
||||
params = DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
||||
params = tf.experimental.tensorrt.ConversionParams(
|
||||
precision_mode='INT8',
|
||||
# Currently only one INT8 engine is supported in this mode.
|
||||
maximum_cached_engines=1,
|
||||
@ -974,7 +975,7 @@ class TrtGraphConverterV2(object):
|
||||
input_saved_model_dir=None,
|
||||
input_saved_model_tags=None,
|
||||
input_saved_model_signature_key=None,
|
||||
conversion_params=DEFAULT_TRT_CONVERSION_PARAMS):
|
||||
conversion_params=TrtConversionParams()):
|
||||
"""Initialize the converter.
|
||||
|
||||
Args:
|
||||
|
@ -26,6 +26,66 @@ if platform.system() != "Windows":
|
||||
raise RuntimeError(
|
||||
"This module is expected to be loaded only on Windows platform.")
|
||||
|
||||
@tf_export("experimental.tensorrt.ConversionParams", v1=[])
|
||||
class TrtConversionParams(object):
|
||||
""" A class to encapsulate parameters that are used for TF-TRT conversion."""
|
||||
|
||||
def __init__(self,
|
||||
rewriter_config_template=None,
|
||||
max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
|
||||
precision_mode=TrtPrecisionMode.FP32,
|
||||
minimum_segment_size=3,
|
||||
is_dynamic_op=True,
|
||||
maximum_cached_engines=1,
|
||||
use_calibration=True,
|
||||
max_batch_size=1,
|
||||
allow_build_at_runtime=True):
|
||||
"""Initialize TrtConversionParams.
|
||||
|
||||
Args:
|
||||
rewriter_config_template: a template RewriterConfig proto used to create a
|
||||
TRT-enabled RewriterConfig. If None, it will use a default one.
|
||||
max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
|
||||
engine can use at execution time. This corresponds to the
|
||||
'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
|
||||
precision_mode: one of TrtPrecisionMode.supported_precision_modes().
|
||||
minimum_segment_size: the minimum number of nodes required for a subgraph
|
||||
to be replaced by TRTEngineOp.
|
||||
is_dynamic_op: whether to generate dynamic TRT ops which will build the
|
||||
TRT network and engine at run time. i.e. Since TensorRT version < 6.0
|
||||
does not support dynamic dimensions other than the batch dimension,
|
||||
when the TensorFlow graph has a non-batch dimension of dynamic size,
|
||||
we would need to enable this option. This option should be set to True
|
||||
in TF 2.0.
|
||||
maximum_cached_engines: max number of cached TRT engines for dynamic TRT
|
||||
ops. Created TRT engines for a dynamic dimension are cached. This is
|
||||
the maximum number of engines that can be cached. If the number of
|
||||
cached engines is already at max but none of them supports the input
|
||||
shapes, the TRTEngineOp will fall back to run the original TF subgraph
|
||||
that corresponds to the TRTEngineOp.
|
||||
use_calibration: this argument is ignored if precision_mode is not INT8.
|
||||
If set to True, a calibration graph will be created to calibrate the
|
||||
missing ranges. The calibration graph must be converted to an inference
|
||||
graph by running calibration with calibrate(). If set to False,
|
||||
quantization nodes will be expected for every tensor in the graph
|
||||
(exlcuding those which will be fused). If a range is missing, an error
|
||||
will occur. Please note that accuracy may be negatively affected if
|
||||
there is a mismatch between which tensors TRT quantizes and which
|
||||
tensors were trained with fake quantization.
|
||||
max_batch_size: max size for the input batch. This parameter is only
|
||||
effective when is_dynamic_op=False which is not supported in TF 2.0.
|
||||
allow_build_at_runtime: whether to build TensorRT engines during runtime.
|
||||
If no TensorRT engine can be found in cache that can handle the given
|
||||
inputs during runtime, then a new TensorRT engine is built at runtime
|
||||
if allow_build_at_runtime=True, and otherwise native TF is used. This
|
||||
argument is only effective if is_dynamic_op=True.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: TRT is not supported on Windows.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"TensorRT integration is not available on Windows.")
|
||||
|
||||
|
||||
@tf_export("experimental.tensorrt.Converter", v1=[])
|
||||
class TrtConverterWindows(object):
|
||||
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.experimental.tensorrt"
|
||||
tf_module {
|
||||
member {
|
||||
name: "ConversionParams"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Converter"
|
||||
mtype: "<type \'type\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user