Export tftrt python symbols.

PiperOrigin-RevId: 279160795
Change-Id: I84a5fed8888d06a877587cd9dd8d6b3eddbdca7d
This commit is contained in:
Guangda Lai 2019-11-07 13:55:01 -08:00 committed by TensorFlower Gardener
parent 352c88a315
commit 7eee3d7db6
10 changed files with 121 additions and 16 deletions

View File

@ -4165,6 +4165,7 @@ py_library(
":util", ":util",
":variable_scope", ":variable_scope",
":variables", ":variables",
"//tensorflow/python/compiler",
"//tensorflow/python/eager:wrap_function", "//tensorflow/python/eager:wrap_function",
"//tensorflow/python/ops/distributions", "//tensorflow/python/ops/distributions",
"//tensorflow/python/ops/linalg", "//tensorflow/python/ops/linalg",

View File

@ -1,7 +1,7 @@
# Description: # Description:
# Python APIs for various Tensorflow backends. # Python APIs for various Tensorflow backends.
load("//tensorflow:tensorflow.bzl", "if_not_windows") load("//tensorflow:tensorflow.bzl", "if_windows")
package( package(
default_visibility = ["//visibility:public"], default_visibility = ["//visibility:public"],
@ -14,9 +14,10 @@ py_library(
name = "compiler", name = "compiler",
srcs = ["__init__.py"], srcs = ["__init__.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = if_not_windows([ deps = if_windows(
"//tensorflow/python/compiler/tensorrt:init_py", ["//tensorflow/python/compiler/tensorrt:trt_convert_windows"],
]) + [ otherwise = ["//tensorflow/python/compiler/tensorrt:init_py"],
) + [
"//tensorflow/python/compiler/mlir", "//tensorflow/python/compiler/mlir",
"//tensorflow/python/compiler/xla:compiler_py", "//tensorflow/python/compiler/xla:compiler_py",
], ],

View File

@ -3,16 +3,8 @@
# and provide TensorRT operators and converter package. # and provide TensorRT operators and converter package.
# APIs are meant to change over time. # APIs are meant to change over time.
load(
"//tensorflow:tensorflow.bzl",
"tf_copts",
)
load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_tests") load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
load(
"@local_config_tensorrt//:build_defs.bzl",
"if_tensorrt",
)
package( package(
default_visibility = ["//visibility:public"], default_visibility = ["//visibility:public"],
@ -58,6 +50,15 @@ py_library(
], ],
) )
py_library(
name = "trt_convert_windows",
srcs = ["trt_convert_windows.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
],
)
py_library( py_library(
name = "tf_trt_integration_test_base", name = "tf_trt_integration_test_base",
srcs = ["test/tf_trt_integration_test_base.py"], srcs = ["test/tf_trt_integration_test_base.py"],

View File

@ -52,6 +52,7 @@ from tensorflow.python.training import saver
from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
# Lazily load the op, since it's not available in cpu-only builds. Importing # Lazily load the op, since it's not available in cpu-only builds. Importing
# this at top will cause tests that imports TF-TRT fail when they're built # this at top will cause tests that imports TF-TRT fail when they're built
@ -389,7 +390,8 @@ class TrtGraphConverter(object):
RuntimeError: if this class is used in TF 2.0. RuntimeError: if this class is used in TF 2.0.
""" """
if context.executing_eagerly(): if context.executing_eagerly():
raise RuntimeError("Please use TrtGraphConverterV2 in TF 2.0.") raise RuntimeError(
"Please use tf.experimental.tensorrt.Converter in TF 2.0.")
if input_graph_def and input_saved_model_dir: if input_graph_def and input_saved_model_dir:
raise ValueError( raise ValueError(
@ -778,9 +780,12 @@ class _TRTEngineResource(tracking.TrackableResource):
max_cached_engines_count=self._maximum_cached_engines) max_cached_engines_count=self._maximum_cached_engines)
@tf_export("experimental.tensorrt.Converter", v1=[])
class TrtGraphConverterV2(object): class TrtGraphConverterV2(object):
"""An offline converter for TF-TRT transformation for TF 2.0 SavedModels. """An offline converter for TF-TRT transformation for TF 2.0 SavedModels.
Currently this is not available on Windows platform.
Note that in V2, is_dynamic_op=False is not supported, meaning TRT engines Note that in V2, is_dynamic_op=False is not supported, meaning TRT engines
will be built only when the corresponding TRTEngineOp is executed. But we will be built only when the corresponding TRTEngineOp is executed. But we
still provide a way to avoid the cost of building TRT engines during inference still provide a way to avoid the cost of building TRT engines during inference
@ -793,7 +798,7 @@ class TrtGraphConverterV2(object):
```python ```python
params = DEFAULT_TRT_CONVERSION_PARAMS._replace( params = DEFAULT_TRT_CONVERSION_PARAMS._replace(
precision_mode='FP16') precision_mode='FP16')
converter = TrtGraphConverterV2( converter = tf.experimental.tensorrt.Converter(
input_saved_model_dir="my_dir", conversion_params=params) input_saved_model_dir="my_dir", conversion_params=params)
converter.convert() converter.convert()
converter.save(output_saved_model_dir) converter.save(output_saved_model_dir)
@ -811,7 +816,7 @@ class TrtGraphConverterV2(object):
precision_mode='FP16', precision_mode='FP16',
# Set this to a large enough number so it can cache all the engines. # Set this to a large enough number so it can cache all the engines.
maximum_cached_engines=16) maximum_cached_engines=16)
converter = TrtGraphConverterV2( converter = tf.experimental.tensorrt.Converter(
input_saved_model_dir="my_dir", conversion_params=params) input_saved_model_dir="my_dir", conversion_params=params)
converter.convert() converter.convert()
@ -844,7 +849,7 @@ class TrtGraphConverterV2(object):
# Currently only one INT8 engine is supported in this mode. # Currently only one INT8 engine is supported in this mode.
maximum_cached_engines=1, maximum_cached_engines=1,
use_calibration=True) use_calibration=True)
converter = TrtGraphConverterV2( converter = tf.experimental.tensorrt.Converter(
input_saved_model_dir="my_dir", conversion_params=params) input_saved_model_dir="my_dir", conversion_params=params)
# Define a generator function that yields input data, and run INT8 # Define a generator function that yields input data, and run INT8

View File

@ -0,0 +1,56 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the TRT conversion for Windows platform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import platform
from tensorflow.python.util.tf_export import tf_export
if platform.system() != "Windows":
raise RuntimeError(
"This module is expected to be loaded only on Windows platform.")
@tf_export("experimental.tensorrt.Converter", v1=[])
class TrtConverterWindows(object):
"""An offline converter for TF-TRT transformation for TF 2.0 SavedModels.
Currently this is not available on Windows platform.
"""
def __init__(self,
input_saved_model_dir=None,
input_saved_model_tags=None,
input_saved_model_signature_key=None,
conversion_params=None):
"""Initialize the converter.
Args:
input_saved_model_dir: the directory to load the SavedModel which contains
the input graph to transforms. Used only when input_graph_def is None.
input_saved_model_tags: list of tags to load the SavedModel.
input_saved_model_signature_key: the key of the signature to optimize the
graph for.
conversion_params: a TrtConversionParams instance.
Raises:
NotImplementedError: TRT is not supported on Windows.
"""
raise NotImplementedError(
"TensorRT integration is not available on Windows.")

View File

@ -20,6 +20,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import platform as _platform
import sys as _sys import sys as _sys
from tensorflow.python import autograph from tensorflow.python import autograph
@ -109,6 +110,13 @@ from tensorflow.python.ops.variable_scope import *
from tensorflow.python.ops.variables import * from tensorflow.python.ops.variables import *
from tensorflow.python.ops.parallel_for.control_flow_ops import vectorized_map from tensorflow.python.ops.parallel_for.control_flow_ops import vectorized_map
# pylint: disable=g-import-not-at-top
if _platform.system() == "Windows":
from tensorflow.python.compiler.tensorrt import trt_convert_windows as trt
else:
from tensorflow.python.compiler.tensorrt import trt_convert as trt
# pylint: enable=g-import-not-at-top
# pylint: enable=wildcard-import # pylint: enable=wildcard-import
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order

View File

@ -24,6 +24,7 @@ TENSORFLOW_API_INIT_FILES = [
"dtypes/__init__.py", "dtypes/__init__.py",
"errors/__init__.py", "errors/__init__.py",
"experimental/__init__.py", "experimental/__init__.py",
"experimental/tensorrt/__init__.py",
"feature_column/__init__.py", "feature_column/__init__.py",
"io/gfile/__init__.py", "io/gfile/__init__.py",
"graph_util/__init__.py", "graph_util/__init__.py",

View File

@ -1,5 +1,9 @@
path: "tensorflow.experimental" path: "tensorflow.experimental"
tf_module { tf_module {
member {
name: "tensorrt"
mtype: "<type \'module\'>"
}
member_method { member_method {
name: "function_executor_type" name: "function_executor_type"
argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None"

View File

@ -0,0 +1,21 @@
path: "tensorflow.experimental.tensorrt.Converter"
tf_class {
is_instance: "<class \'tensorflow.python.compiler.tensorrt.trt_convert.TrtGraphConverterV2\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'input_saved_model_dir\', \'input_saved_model_tags\', \'input_saved_model_signature_key\', \'conversion_params\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"TrtConversionParams(rewriter_config_template=None, max_workspace_size_bytes=1073741824, precision_mode=\'FP32\', minimum_segment_size=3, is_dynamic_op=True, maximum_cached_engines=1, use_calibration=True, max_batch_size=1)\"], "
}
member_method {
name: "build"
argspec: "args=[\'self\', \'input_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "convert"
argspec: "args=[\'self\', \'calibration_input_fn\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "save"
argspec: "args=[\'self\', \'output_saved_model_dir\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,7 @@
path: "tensorflow.experimental.tensorrt"
tf_module {
member {
name: "Converter"
mtype: "<type \'type\'>"
}
}