Dynamically load TensorRT library. To achieve that, we:
1. link all TF-TRT targets against :tensorrt_stub 2. link TF-TRT ops/kernels to core 3. remove the tftrt.so shared library and corresponding loader (trt_ops.py) PiperOrigin-RevId: 253644955
This commit is contained in:
parent
fbe16a982d
commit
408949d3e1
@ -24,6 +24,12 @@ load(
|
||||
load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
|
||||
# Placeholder for Google-internal load statements.
|
||||
|
||||
# NOTE: we always assume that if_static returns "otherwise" list in open source.
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
"if_static",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
@ -45,6 +51,15 @@ cc_library(
|
||||
]),
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "tensorrt_lib",
|
||||
actual = if_static(
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
":tensorrt_stub",
|
||||
),
|
||||
visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "tensorrt_test_cc",
|
||||
size = "small",
|
||||
@ -60,8 +75,8 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
] + if_tensorrt([
|
||||
":tensorrt_lib",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
)
|
||||
|
||||
@ -86,9 +101,7 @@ cc_library(
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:stream_executor_headers_lib",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]) + tf_custom_op_library_additional_deps(),
|
||||
] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
@ -96,7 +109,7 @@ cc_library(
|
||||
name = "trt_engine_resource_op_kernels",
|
||||
srcs = ["kernels/trt_engine_resource_ops.cc"],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:private"],
|
||||
visibility = ["//tensorflow/core:__subpackages__"],
|
||||
deps = [
|
||||
":trt_allocator",
|
||||
":trt_engine_instance_proto_cc",
|
||||
@ -110,9 +123,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]) + tf_custom_op_library_additional_deps(),
|
||||
] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
@ -144,21 +155,6 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_shared_object(
|
||||
name = "python/ops/libtftrt.so",
|
||||
copts = tf_copts(is_external = True),
|
||||
linkopts = ["-lm"],
|
||||
deps = [
|
||||
":trt_op_kernels",
|
||||
":trt_engine_resource_op_kernels",
|
||||
":trt_op_libs",
|
||||
":trt_engine_resource_ops_op_lib",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]) + tf_custom_op_library_additional_deps(),
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "trt_engine_op_test",
|
||||
size = "small",
|
||||
@ -211,9 +207,7 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
] + if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
@ -226,18 +220,6 @@ tf_gen_op_wrapper_py(
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "trt_ops_loader",
|
||||
srcs = ["python/ops/trt_ops.py"],
|
||||
dso = [
|
||||
"python/ops/libtftrt.so",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
kernels = [
|
||||
":trt_op_kernels",
|
||||
":trt_engine_resource_op_kernels",
|
||||
":trt_op_libs",
|
||||
":trt_engine_resource_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":trt_ops",
|
||||
@ -268,9 +250,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
] + if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
@ -281,9 +261,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
] + if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
@ -352,9 +330,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core/grappler/clusters:virtual_cluster",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/optimizers:meta_optimizer",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]) + tf_custom_op_library_additional_deps(),
|
||||
] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
@ -387,9 +363,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
] + if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
@ -423,8 +397,8 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
] + if_tensorrt([
|
||||
":tensorrt_lib",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
)
|
||||
|
||||
@ -477,9 +451,7 @@ tf_cuda_library(
|
||||
deps = [
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
] + if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -505,9 +477,7 @@ cc_library(
|
||||
srcs = ["utils/py_utils.cc"],
|
||||
hdrs = ["utils/py_utils.h"],
|
||||
copts = tf_copts(),
|
||||
deps = if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
deps = if_tensorrt([":tensorrt_lib"]),
|
||||
)
|
||||
|
||||
tf_py_wrap_cc(
|
||||
|
@ -1,75 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
"""Exposes the Python wrapper of TRTEngineOp."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
import platform
|
||||
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
_tf_trt_so = None
|
||||
_module_lock = threading.Lock()
|
||||
|
||||
|
||||
def load_trt_ops():
|
||||
"""Load TF-TRT op libraries so if it hasn't been loaded already."""
|
||||
global _tf_trt_so
|
||||
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
|
||||
if platform.system() == "Windows":
|
||||
raise RuntimeError("Windows platforms are not supported")
|
||||
|
||||
with _module_lock:
|
||||
if _tf_trt_so:
|
||||
return
|
||||
|
||||
try:
|
||||
# pylint: disable=g-import-not-at-top,unused-variable
|
||||
# This will call register_op_list() in
|
||||
# tensorflow/python/framework/op_def_registry.py, but it doesn't register
|
||||
# the op or the op kernel in C++ runtime.
|
||||
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op
|
||||
# pylint: enable=g-import-not-at-top,unused-variable
|
||||
except ImportError as e:
|
||||
print("**** Failed to import TF-TRT ops. This is because the binary was "
|
||||
"not built with CUDA or TensorRT enabled. ****")
|
||||
raise e
|
||||
|
||||
try:
|
||||
# pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.framework import load_library
|
||||
from tensorflow.python.platform import resource_loader
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
# Loading the shared object will cause registration of the op and the op
|
||||
# kernel if we link TF-TRT dynamically.
|
||||
_tf_trt_so = load_library.load_op_library(
|
||||
resource_loader.get_path_to_datafile("libtftrt.so"))
|
||||
except errors.NotFoundError as e:
|
||||
no_trt_message = (
|
||||
"**** Failed to initialize TensorRT. This is either because the "
|
||||
"TensorRT installation path is not in LD_LIBRARY_PATH, or because "
|
||||
"you do not have it installed. If not installed, please go to "
|
||||
"https://developer.nvidia.com/tensorrt to download and install "
|
||||
"TensorRT ****")
|
||||
print(no_trt_message)
|
||||
raise e
|
@ -170,6 +170,7 @@ load(
|
||||
"tf_cuda_tests_tags",
|
||||
)
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library")
|
||||
load(
|
||||
"//third_party/mkl:build_defs.bzl",
|
||||
@ -1438,6 +1439,9 @@ cc_library(
|
||||
] + if_mkl([
|
||||
":mkl_array_ops_op_lib",
|
||||
":mkl_nn_ops_op_lib",
|
||||
]) + if_tensorrt([
|
||||
"//tensorflow/compiler/tf2tensorrt:trt_engine_resource_ops_op_lib",
|
||||
"//tensorflow/compiler/tf2tensorrt:trt_op_libs",
|
||||
]) + tf_additional_cloud_op_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -1631,6 +1635,9 @@ cc_library(
|
||||
"//tensorflow/core/grappler/optimizers:gpu_swapping_ops",
|
||||
]) + if_nccl([
|
||||
"//tensorflow/core/kernels:nccl_kernels",
|
||||
]) + if_tensorrt([
|
||||
"//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels",
|
||||
"//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
|
||||
]),
|
||||
)
|
||||
|
||||
|
@ -15,6 +15,10 @@ load(
|
||||
"//third_party/mkl:build_defs.bzl",
|
||||
"if_mkl",
|
||||
)
|
||||
load(
|
||||
"@local_config_tensorrt//:build_defs.bzl",
|
||||
"if_tensorrt",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
@ -43,7 +47,7 @@ cc_library(
|
||||
name = "excluded_ops_lib",
|
||||
srcs = ["excluded_ops.cc"],
|
||||
hdrs = ["excluded_ops.h"],
|
||||
copts = if_mkl(["-DINTEL_MKL=1"]),
|
||||
copts = if_mkl(["-DINTEL_MKL=1"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -40,8 +40,14 @@ const std::unordered_set<std::string>* GetExcludedOps() {
|
||||
"QuantizedMatMulWithBias"
|
||||
"QuantizedMatMulWithBiasAndRelu"
|
||||
"QuantizedMatMulWithBiasAndReluAndRequantize",
|
||||
|
||||
#endif // INTEL_MKL
|
||||
#ifdef GOOGLE_TENSORRT
|
||||
"CreateTRTEngineCache",
|
||||
"PopulateTRTEngineCache",
|
||||
"DumpTRTEngineCache",
|
||||
"GetCalibrationDataOp",
|
||||
"TRTEngineOp",
|
||||
#endif // GOOGLE_TENSORRT
|
||||
});
|
||||
return excluded_ops;
|
||||
}
|
||||
|
@ -20,13 +20,12 @@ from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import platform
|
||||
import tempfile
|
||||
|
||||
import six as _six
|
||||
|
||||
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
|
||||
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_linked_tensorrt_version
|
||||
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_loaded_tensorrt_version
|
||||
from tensorflow.compiler.tf2tensorrt import wrap_py_utils
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
@ -53,15 +52,6 @@ from tensorflow.python.training import saver
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
|
||||
# Import TRT library. This is fine since we don't import TF-TRT in
|
||||
# tensorflow/python/compiler/__init__.py, and `import tensorflow` won't trigger
|
||||
# importing of TF-TRT. Note that TF-TRT is still included in GPU build since
|
||||
# tensorflow/python/BUILD depends on it.
|
||||
#
|
||||
# We need this import so that when users import this module, they can execute a
|
||||
# TRT-converted graph without calling any of the methods in this module.
|
||||
trt_ops.load_trt_ops()
|
||||
|
||||
# Lazily load the op, since it's not available in cpu-only builds. Importing
|
||||
# this at top will cause tests that imports TF-TRT fail when they're built
|
||||
# and run without CUDA/GPU.
|
||||
@ -69,6 +59,18 @@ gen_trt_ops = LazyLoader(
|
||||
"gen_trt_ops", globals(),
|
||||
"tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")
|
||||
|
||||
# Register TRT ops in python, so that when users import this module they can
|
||||
# execute a TRT-converted graph without calling any of the methods in this
|
||||
# module.
|
||||
if wrap_py_utils.is_tensorrt_enabled():
|
||||
if platform.system() == "Windows":
|
||||
raise RuntimeError("Windows platform is not supported")
|
||||
|
||||
# This will call register_op_list() in
|
||||
# tensorflow/python/framework/op_def_registry.py, but it doesn't register
|
||||
# the op or the op kernel in C++ runtime.
|
||||
gen_trt_ops.trt_engine_op # pylint: disable=pointless-statement
|
||||
|
||||
|
||||
def _to_bytes(s):
|
||||
"""Encode s if it is a sequence of chars."""
|
||||
@ -567,8 +569,8 @@ def _check_trt_version_compatibility():
|
||||
Raises:
|
||||
RuntimeError: if the TensorRT library version is incompatible.
|
||||
"""
|
||||
compiled_version = get_linked_tensorrt_version()
|
||||
loaded_version = get_loaded_tensorrt_version()
|
||||
compiled_version = wrap_py_utils.get_linked_tensorrt_version()
|
||||
loaded_version = wrap_py_utils.get_loaded_tensorrt_version()
|
||||
tf_logging.info("Linked TensorRT version: %s" % str(compiled_version))
|
||||
tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version))
|
||||
version_mismatch = False
|
||||
|
Loading…
x
Reference in New Issue
Block a user