Export ConversionParams as an API
parent e8565092ff
commit 3a378dbfd5
@@ -111,6 +111,7 @@ class TrtPrecisionMode(object):
 
 DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
 
 
+@tf_export("experimental.tensorrt.ConversionParams", v1=[])
 class TrtConversionParams(object):
   """ A class to encapsulate parameters that are used for TF-TRT conversion."""
@@ -888,7 +889,7 @@ class TrtGraphConverterV2(object):
    1. FP32/FP16 precision
 
       ```python
-      params = DEFAULT_TRT_CONVERSION_PARAMS._replace(
+      params = tf.experimental.tensorrt.ConversionParams(
          precision_mode='FP16')
       converter = tf.experimental.tensorrt.Converter(
          input_saved_model_dir="my_dir", conversion_params=params)
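For reference, a minimal end-to-end sketch of the FP16 flow this hunk documents; the output path "my_dir_trt" is an assumed placeholder, not part of the commit:

```python
import tensorflow as tf

# FP16 conversion with the newly exported ConversionParams API
# (a sketch; assumes a SavedModel exists at "my_dir").
params = tf.experimental.tensorrt.ConversionParams(precision_mode='FP16')
converter = tf.experimental.tensorrt.Converter(
    input_saved_model_dir="my_dir", conversion_params=params)
converter.convert()           # rewrites eligible subgraphs into TRTEngineOps
converter.save("my_dir_trt")  # writes the converted SavedModel
```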
@@ -904,7 +905,7 @@ class TrtGraphConverterV2(object):
    2. FP32/FP16 precision with pre-built engines
 
       ```python
-      params = DEFAULT_TRT_CONVERSION_PARAMS._replace(
+      params = tf.experimental.tensorrt.ConversionParams(
          precision_mode='FP16',
          # Set this to a large enough number so it can cache all the engines.
          maximum_cached_engines=16)
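A sketch of the pre-built-engines mode this hunk describes, using `Converter.build` to build engines ahead of time; the input shapes and counts below are assumptions for illustration:

```python
import numpy as np
import tensorflow as tf

params = tf.experimental.tensorrt.ConversionParams(
    precision_mode='FP16', maximum_cached_engines=16)
converter = tf.experimental.tensorrt.Converter(
    input_saved_model_dir="my_dir", conversion_params=params)
converter.convert()

def input_fn():
  # Yield one input tuple per shape to pre-build an engine for (shapes
  # here are assumed; use the shapes your model will actually serve).
  for batch_size in (1, 8):
    yield (np.zeros((batch_size, 224, 224, 3), dtype=np.float32),)

converter.build(input_fn=input_fn)  # builds and caches engines ahead of time
converter.save("my_dir_trt")
```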
@@ -936,7 +937,7 @@ class TrtGraphConverterV2(object):
    3. INT8 precision and calibration with pre-built engines
 
       ```python
-      params = DEFAULT_TRT_CONVERSION_PARAMS._replace(
+      params = tf.experimental.tensorrt.ConversionParams(
          precision_mode='INT8',
          # Currently only one INT8 engine is supported in this mode.
          maximum_cached_engines=1,
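A sketch of the INT8 calibration flow this hunk covers, passing a calibration input function to `convert`; the shapes and sample count are assumptions:

```python
import numpy as np
import tensorflow as tf

params = tf.experimental.tensorrt.ConversionParams(
    precision_mode='INT8',
    maximum_cached_engines=1,
    use_calibration=True)
converter = tf.experimental.tensorrt.Converter(
    input_saved_model_dir="my_dir", conversion_params=params)

def calibration_input_fn():
  # Representative inputs drive the INT8 range calibration (shapes and
  # number of samples are assumed for illustration).
  for _ in range(10):
    yield (np.random.uniform(size=(1, 224, 224, 3)).astype(np.float32),)

converter.convert(calibration_input_fn=calibration_input_fn)
converter.save("my_dir_trt")
```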
@@ -974,7 +975,7 @@ class TrtGraphConverterV2(object):
                input_saved_model_dir=None,
                input_saved_model_tags=None,
                input_saved_model_signature_key=None,
-               conversion_params=DEFAULT_TRT_CONVERSION_PARAMS):
+               conversion_params=TrtConversionParams()):
     """Initialize the converter.
 
     Args:
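One practical effect of the new default shown above: `conversion_params` can now be omitted entirely when the defaults suffice. A sketch, assuming a SavedModel at "my_dir":

```python
import tensorflow as tf

# conversion_params defaults to TrtConversionParams(), so an FP32
# conversion with default settings needs no explicit params object.
converter = tf.experimental.tensorrt.Converter(input_saved_model_dir="my_dir")
converter.convert()
converter.save("my_dir_trt")
```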
@@ -26,6 +26,66 @@ if platform.system() != "Windows":
   raise RuntimeError(
       "This module is expected to be loaded only on Windows platform.")
 
 
+@tf_export("experimental.tensorrt.ConversionParams", v1=[])
+class TrtConversionParams(object):
+  """ A class to encapsulate parameters that are used for TF-TRT conversion."""
+
+  def __init__(self,
+               rewriter_config_template=None,
+               max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
+               precision_mode=TrtPrecisionMode.FP32,
+               minimum_segment_size=3,
+               is_dynamic_op=True,
+               maximum_cached_engines=1,
+               use_calibration=True,
+               max_batch_size=1,
+               allow_build_at_runtime=True):
+    """Initialize TrtConversionParams.
+
+    Args:
+      rewriter_config_template: a template RewriterConfig proto used to create
+        a TRT-enabled RewriterConfig. If None, it will use a default one.
+      max_workspace_size_bytes: the maximum GPU temporary memory that the TRT
+        engine can use at execution time. This corresponds to the
+        'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+      precision_mode: one of TrtPrecisionMode.supported_precision_modes().
+      minimum_segment_size: the minimum number of nodes required for a subgraph
+        to be replaced by TRTEngineOp.
+      is_dynamic_op: whether to generate dynamic TRT ops which will build the
+        TRT network and engine at run time. Since TensorRT versions < 6.0 do
+        not support dynamic dimensions other than the batch dimension, this
+        option must be enabled when the TensorFlow graph has a non-batch
+        dimension of dynamic size. This option should be set to True in TF 2.0.
+      maximum_cached_engines: max number of cached TRT engines for dynamic TRT
+        ops. Created TRT engines for a dynamic dimension are cached. This is
+        the maximum number of engines that can be cached. If the number of
+        cached engines is already at max but none of them supports the input
+        shapes, the TRTEngineOp will fall back to run the original TF subgraph
+        that corresponds to the TRTEngineOp.
+      use_calibration: this argument is ignored if precision_mode is not INT8.
+        If set to True, a calibration graph will be created to calibrate the
+        missing ranges. The calibration graph must be converted to an inference
+        graph by running calibration with calibrate(). If set to False,
+        quantization nodes will be expected for every tensor in the graph
+        (excluding those which will be fused). If a range is missing, an error
+        will occur. Please note that accuracy may be negatively affected if
+        there is a mismatch between which tensors TRT quantizes and which
+        tensors were trained with fake quantization.
+      max_batch_size: max size for the input batch. This parameter is only
+        effective when is_dynamic_op=False, which is not supported in TF 2.0.
+      allow_build_at_runtime: whether to build TensorRT engines during runtime.
+        If no TensorRT engine can be found in cache that can handle the given
+        inputs during runtime, then a new TensorRT engine is built at runtime
+        if allow_build_at_runtime=True, and otherwise native TF is used. This
+        argument is only effective if is_dynamic_op=True.
+
+    Raises:
+      NotImplementedError: TRT is not supported on Windows.
+    """
+    raise NotImplementedError(
+        "TensorRT integration is not available on Windows.")
+
+
 @tf_export("experimental.tensorrt.Converter", v1=[])
 class TrtConverterWindows(object):
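For illustration of the parameters documented in the Args section above, a sketch that sets several of them explicitly; the values are assumptions for demonstration, not tuned recommendations:

```python
import tensorflow as tf

params = tf.experimental.tensorrt.ConversionParams(
    max_workspace_size_bytes=1 << 30,  # 1 GiB of TRT scratch memory
    precision_mode='FP16',
    minimum_segment_size=3,    # skip TRT conversion for tiny subgraphs
    maximum_cached_engines=4,  # cache up to four engines per TRTEngineOp
    allow_build_at_runtime=True)
```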
@@ -1,5 +1,9 @@
 path: "tensorflow.experimental.tensorrt"
 tf_module {
+  member {
+    name: "ConversionParams"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "Converter"
     mtype: "<type \'type\'>"