Export tftrt python symbols.
PiperOrigin-RevId: 279160795 Change-Id: I84a5fed8888d06a877587cd9dd8d6b3eddbdca7d
This commit is contained in:
parent
352c88a315
commit
7eee3d7db6
@ -4165,6 +4165,7 @@ py_library(
|
||||
":util",
|
||||
":variable_scope",
|
||||
":variables",
|
||||
"//tensorflow/python/compiler",
|
||||
"//tensorflow/python/eager:wrap_function",
|
||||
"//tensorflow/python/ops/distributions",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Description:
|
||||
# Python APIs for various Tensorflow backends.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "if_not_windows")
|
||||
load("//tensorflow:tensorflow.bzl", "if_windows")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
@ -14,9 +14,10 @@ py_library(
|
||||
name = "compiler",
|
||||
srcs = ["__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = if_not_windows([
|
||||
"//tensorflow/python/compiler/tensorrt:init_py",
|
||||
]) + [
|
||||
deps = if_windows(
|
||||
["//tensorflow/python/compiler/tensorrt:trt_convert_windows"],
|
||||
otherwise = ["//tensorflow/python/compiler/tensorrt:init_py"],
|
||||
) + [
|
||||
"//tensorflow/python/compiler/mlir",
|
||||
"//tensorflow/python/compiler/xla:compiler_py",
|
||||
],
|
||||
|
@ -3,16 +3,8 @@
|
||||
# and provide TensorRT operators and converter package.
|
||||
# APIs are meant to change over time.
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_copts",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
|
||||
load(
|
||||
"@local_config_tensorrt//:build_defs.bzl",
|
||||
"if_tensorrt",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
@ -58,6 +50,15 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "trt_convert_windows",
|
||||
srcs = ["trt_convert_windows.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tf_trt_integration_test_base",
|
||||
srcs = ["test/tf_trt_integration_test_base.py"],
|
||||
|
@ -52,6 +52,7 @@ from tensorflow.python.training import saver
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
# Lazily load the op, since it's not available in cpu-only builds. Importing
|
||||
# this at top will cause tests that imports TF-TRT fail when they're built
|
||||
@ -389,7 +390,8 @@ class TrtGraphConverter(object):
|
||||
RuntimeError: if this class is used in TF 2.0.
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError("Please use TrtGraphConverterV2 in TF 2.0.")
|
||||
raise RuntimeError(
|
||||
"Please use tf.experimental.tensorrt.Converter in TF 2.0.")
|
||||
|
||||
if input_graph_def and input_saved_model_dir:
|
||||
raise ValueError(
|
||||
@ -778,9 +780,12 @@ class _TRTEngineResource(tracking.TrackableResource):
|
||||
max_cached_engines_count=self._maximum_cached_engines)
|
||||
|
||||
|
||||
@tf_export("experimental.tensorrt.Converter", v1=[])
|
||||
class TrtGraphConverterV2(object):
|
||||
"""An offline converter for TF-TRT transformation for TF 2.0 SavedModels.
|
||||
|
||||
Currently this is not available on Windows platform.
|
||||
|
||||
Note that in V2, is_dynamic_op=False is not supported, meaning TRT engines
|
||||
will be built only when the corresponding TRTEngineOp is executed. But we
|
||||
still provide a way to avoid the cost of building TRT engines during inference
|
||||
@ -793,7 +798,7 @@ class TrtGraphConverterV2(object):
|
||||
```python
|
||||
params = DEFAULT_TRT_CONVERSION_PARAMS._replace(
|
||||
precision_mode='FP16')
|
||||
converter = TrtGraphConverterV2(
|
||||
converter = tf.experimental.tensorrt.Converter(
|
||||
input_saved_model_dir="my_dir", conversion_params=params)
|
||||
converter.convert()
|
||||
converter.save(output_saved_model_dir)
|
||||
@ -811,7 +816,7 @@ class TrtGraphConverterV2(object):
|
||||
precision_mode='FP16',
|
||||
# Set this to a large enough number so it can cache all the engines.
|
||||
maximum_cached_engines=16)
|
||||
converter = TrtGraphConverterV2(
|
||||
converter = tf.experimental.tensorrt.Converter(
|
||||
input_saved_model_dir="my_dir", conversion_params=params)
|
||||
converter.convert()
|
||||
|
||||
@ -844,7 +849,7 @@ class TrtGraphConverterV2(object):
|
||||
# Currently only one INT8 engine is supported in this mode.
|
||||
maximum_cached_engines=1,
|
||||
use_calibration=True)
|
||||
converter = TrtGraphConverterV2(
|
||||
converter = tf.experimental.tensorrt.Converter(
|
||||
input_saved_model_dir="my_dir", conversion_params=params)
|
||||
|
||||
# Define a generator function that yields input data, and run INT8
|
||||
|
56
tensorflow/python/compiler/tensorrt/trt_convert_windows.py
Normal file
56
tensorflow/python/compiler/tensorrt/trt_convert_windows.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Exposes the TRT conversion for Windows platform."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import platform
|
||||
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
if platform.system() != "Windows":
|
||||
raise RuntimeError(
|
||||
"This module is expected to be loaded only on Windows platform.")
|
||||
|
||||
|
||||
@tf_export("experimental.tensorrt.Converter", v1=[])
|
||||
class TrtConverterWindows(object):
|
||||
"""An offline converter for TF-TRT transformation for TF 2.0 SavedModels.
|
||||
|
||||
Currently this is not available on Windows platform.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_saved_model_dir=None,
|
||||
input_saved_model_tags=None,
|
||||
input_saved_model_signature_key=None,
|
||||
conversion_params=None):
|
||||
"""Initialize the converter.
|
||||
|
||||
Args:
|
||||
input_saved_model_dir: the directory to load the SavedModel which contains
|
||||
the input graph to transforms. Used only when input_graph_def is None.
|
||||
input_saved_model_tags: list of tags to load the SavedModel.
|
||||
input_saved_model_signature_key: the key of the signature to optimize the
|
||||
graph for.
|
||||
conversion_params: a TrtConversionParams instance.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: TRT is not supported on Windows.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"TensorRT integration is not available on Windows.")
|
@ -20,6 +20,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import platform as _platform
|
||||
import sys as _sys
|
||||
|
||||
from tensorflow.python import autograph
|
||||
@ -109,6 +110,13 @@ from tensorflow.python.ops.variable_scope import *
|
||||
from tensorflow.python.ops.variables import *
|
||||
from tensorflow.python.ops.parallel_for.control_flow_ops import vectorized_map
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
if _platform.system() == "Windows":
|
||||
from tensorflow.python.compiler.tensorrt import trt_convert_windows as trt
|
||||
else:
|
||||
from tensorflow.python.compiler.tensorrt import trt_convert as trt
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
# pylint: enable=wildcard-import
|
||||
# pylint: enable=g-bad-import-order
|
||||
|
||||
|
@ -24,6 +24,7 @@ TENSORFLOW_API_INIT_FILES = [
|
||||
"dtypes/__init__.py",
|
||||
"errors/__init__.py",
|
||||
"experimental/__init__.py",
|
||||
"experimental/tensorrt/__init__.py",
|
||||
"feature_column/__init__.py",
|
||||
"io/gfile/__init__.py",
|
||||
"graph_util/__init__.py",
|
||||
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.experimental"
|
||||
tf_module {
|
||||
member {
|
||||
name: "tensorrt"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "function_executor_type"
|
||||
argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -0,0 +1,21 @@
|
||||
path: "tensorflow.experimental.tensorrt.Converter"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.compiler.tensorrt.trt_convert.TrtGraphConverterV2\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'input_saved_model_dir\', \'input_saved_model_tags\', \'input_saved_model_signature_key\', \'conversion_params\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"TrtConversionParams(rewriter_config_template=None, max_workspace_size_bytes=1073741824, precision_mode=\'FP32\', minimum_segment_size=3, is_dynamic_op=True, maximum_cached_engines=1, use_calibration=True, max_batch_size=1)\"], "
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'input_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "convert"
|
||||
argspec: "args=[\'self\', \'calibration_input_fn\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'output_saved_model_dir\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
path: "tensorflow.experimental.tensorrt"
|
||||
tf_module {
|
||||
member {
|
||||
name: "Converter"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user