Export tftrt python symbols.

PiperOrigin-RevId: 279160795
Change-Id: I84a5fed8888d06a877587cd9dd8d6b3eddbdca7d
This commit is contained in:
Guangda Lai 2019-11-07 13:55:01 -08:00 committed by TensorFlower Gardener
parent 352c88a315
commit 7eee3d7db6
10 changed files with 121 additions and 16 deletions

View File

@ -4165,6 +4165,7 @@ py_library(
":util",
":variable_scope",
":variables",
"//tensorflow/python/compiler",
"//tensorflow/python/eager:wrap_function",
"//tensorflow/python/ops/distributions",
"//tensorflow/python/ops/linalg",

View File

@ -1,7 +1,7 @@
# Description:
# Python APIs for various Tensorflow backends.
load("//tensorflow:tensorflow.bzl", "if_not_windows")
load("//tensorflow:tensorflow.bzl", "if_windows")
package(
default_visibility = ["//visibility:public"],
@ -14,9 +14,10 @@ py_library(
name = "compiler",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = if_not_windows([
"//tensorflow/python/compiler/tensorrt:init_py",
]) + [
deps = if_windows(
["//tensorflow/python/compiler/tensorrt:trt_convert_windows"],
otherwise = ["//tensorflow/python/compiler/tensorrt:init_py"],
) + [
"//tensorflow/python/compiler/mlir",
"//tensorflow/python/compiler/xla:compiler_py",
],

View File

@ -3,16 +3,8 @@
# and provide TensorRT operators and converter package.
# APIs are meant to change over time.
load(
"//tensorflow:tensorflow.bzl",
"tf_copts",
)
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
load(
"@local_config_tensorrt//:build_defs.bzl",
"if_tensorrt",
)
package(
default_visibility = ["//visibility:public"],
@ -58,6 +50,15 @@ py_library(
],
)
py_library(
name = "trt_convert_windows",
srcs = ["trt_convert_windows.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
],
)
py_library(
name = "tf_trt_integration_test_base",
srcs = ["test/tf_trt_integration_test_base.py"],

View File

@ -52,6 +52,7 @@ from tensorflow.python.training import saver
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
# Lazily load the op, since it's not available in cpu-only builds. Importing
# this at top will cause tests that imports TF-TRT fail when they're built
@ -389,7 +390,8 @@ class TrtGraphConverter(object):
RuntimeError: if this class is used in TF 2.0.
"""
if context.executing_eagerly():
raise RuntimeError("Please use TrtGraphConverterV2 in TF 2.0.")
raise RuntimeError(
"Please use tf.experimental.tensorrt.Converter in TF 2.0.")
if input_graph_def and input_saved_model_dir:
raise ValueError(
@ -778,9 +780,12 @@ class _TRTEngineResource(tracking.TrackableResource):
max_cached_engines_count=self._maximum_cached_engines)
@tf_export("experimental.tensorrt.Converter", v1=[])
class TrtGraphConverterV2(object):
"""An offline converter for TF-TRT transformation for TF 2.0 SavedModels.
Currently this is not available on Windows platform.
Note that in V2, is_dynamic_op=False is not supported, meaning TRT engines
will be built only when the corresponding TRTEngineOp is executed. But we
still provide a way to avoid the cost of building TRT engines during inference
@ -793,7 +798,7 @@ class TrtGraphConverterV2(object):
```python
params = DEFAULT_TRT_CONVERSION_PARAMS._replace(
precision_mode='FP16')
converter = TrtGraphConverterV2(
converter = tf.experimental.tensorrt.Converter(
input_saved_model_dir="my_dir", conversion_params=params)
converter.convert()
converter.save(output_saved_model_dir)
@ -811,7 +816,7 @@ class TrtGraphConverterV2(object):
precision_mode='FP16',
# Set this to a large enough number so it can cache all the engines.
maximum_cached_engines=16)
converter = TrtGraphConverterV2(
converter = tf.experimental.tensorrt.Converter(
input_saved_model_dir="my_dir", conversion_params=params)
converter.convert()
@ -844,7 +849,7 @@ class TrtGraphConverterV2(object):
# Currently only one INT8 engine is supported in this mode.
maximum_cached_engines=1,
use_calibration=True)
converter = TrtGraphConverterV2(
converter = tf.experimental.tensorrt.Converter(
input_saved_model_dir="my_dir", conversion_params=params)
# Define a generator function that yields input data, and run INT8

View File

@ -0,0 +1,56 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the TRT conversion for Windows platform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import platform
from tensorflow.python.util.tf_export import tf_export
if platform.system() != "Windows":
raise RuntimeError(
"This module is expected to be loaded only on Windows platform.")
@tf_export("experimental.tensorrt.Converter", v1=[])
class TrtConverterWindows(object):
"""An offline converter for TF-TRT transformation for TF 2.0 SavedModels.
Currently this is not available on Windows platform.
"""
def __init__(self,
input_saved_model_dir=None,
input_saved_model_tags=None,
input_saved_model_signature_key=None,
conversion_params=None):
"""Initialize the converter.
Args:
input_saved_model_dir: the directory to load the SavedModel which contains
the input graph to transforms. Used only when input_graph_def is None.
input_saved_model_tags: list of tags to load the SavedModel.
input_saved_model_signature_key: the key of the signature to optimize the
graph for.
conversion_params: a TrtConversionParams instance.
Raises:
NotImplementedError: TRT is not supported on Windows.
"""
raise NotImplementedError(
"TensorRT integration is not available on Windows.")

View File

@ -20,6 +20,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import platform as _platform
import sys as _sys
from tensorflow.python import autograph
@ -109,6 +110,13 @@ from tensorflow.python.ops.variable_scope import *
from tensorflow.python.ops.variables import *
from tensorflow.python.ops.parallel_for.control_flow_ops import vectorized_map
# pylint: disable=g-import-not-at-top
if _platform.system() == "Windows":
from tensorflow.python.compiler.tensorrt import trt_convert_windows as trt
else:
from tensorflow.python.compiler.tensorrt import trt_convert as trt
# pylint: enable=g-import-not-at-top
# pylint: enable=wildcard-import
# pylint: enable=g-bad-import-order

View File

@ -24,6 +24,7 @@ TENSORFLOW_API_INIT_FILES = [
"dtypes/__init__.py",
"errors/__init__.py",
"experimental/__init__.py",
"experimental/tensorrt/__init__.py",
"feature_column/__init__.py",
"io/gfile/__init__.py",
"graph_util/__init__.py",

View File

@ -1,5 +1,9 @@
path: "tensorflow.experimental"
tf_module {
member {
name: "tensorrt"
mtype: "<type \'module\'>"
}
member_method {
name: "function_executor_type"
argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None"

View File

@ -0,0 +1,21 @@
path: "tensorflow.experimental.tensorrt.Converter"
tf_class {
is_instance: "<class \'tensorflow.python.compiler.tensorrt.trt_convert.TrtGraphConverterV2\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'input_saved_model_dir\', \'input_saved_model_tags\', \'input_saved_model_signature_key\', \'conversion_params\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"TrtConversionParams(rewriter_config_template=None, max_workspace_size_bytes=1073741824, precision_mode=\'FP32\', minimum_segment_size=3, is_dynamic_op=True, maximum_cached_engines=1, use_calibration=True, max_batch_size=1)\"], "
}
member_method {
name: "build"
argspec: "args=[\'self\', \'input_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "convert"
argspec: "args=[\'self\', \'calibration_input_fn\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "save"
argspec: "args=[\'self\', \'output_saved_model_dir\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,7 @@
path: "tensorflow.experimental.tensorrt"
tf_module {
member {
name: "Converter"
mtype: "<type \'type\'>"
}
}