[TF] Add TensorRT version to tf.sysconfig.get_build_info.
Add is_tensorrt_build and tensorrt_version to build_info. PiperOrigin-RevId: 341663396 Change-Id: I239d7d52c388767932716abd501061d31324ff10
This commit is contained in:
parent
9779f86969
commit
fa595eb8fa
@ -201,6 +201,7 @@ tensorflow/third_party/tensorrt/BUILD.tpl
|
||||
tensorflow/third_party/tensorrt/LICENSE
|
||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||
tensorflow/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
|
||||
tensorflow/third_party/tensorrt/tensorrt/tensorrt_config.py.tpl
|
||||
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
|
||||
tensorflow/third_party/termcolor.BUILD
|
||||
tensorflow/third_party/tf_toolchains.BUILD
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
|
||||
from tensorflow.python.platform import build_info
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -29,12 +30,15 @@ class BuildInfoTest(test.TestCase):
|
||||
test.is_built_with_rocm())
|
||||
self.assertEqual(build_info.build_info['is_cuda_build'],
|
||||
test.is_built_with_cuda())
|
||||
self.assertEqual(build_info.build_info['is_tensorrt_build'],
|
||||
is_tensorrt_enabled())
|
||||
|
||||
def testDeterministicOrder(self):
|
||||
# The dict may contain other keys depending on the platform, but the ones
|
||||
# it always contains should be in order.
|
||||
self.assertContainsSubsequence(build_info.build_info.keys(),
|
||||
('is_cuda_build', 'is_rocm_build'))
|
||||
self.assertContainsSubsequence(
|
||||
build_info.build_info.keys(),
|
||||
('is_cuda_build', 'is_rocm_build', 'is_tensorrt_build'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -2457,6 +2457,7 @@ def tf_py_build_info_genrule(name, out):
|
||||
" --key_value" +
|
||||
" is_rocm_build=" + if_rocm("True", "False") +
|
||||
" is_cuda_build=" + if_cuda("True", "False") +
|
||||
" is_tensorrt_build=" + if_tensorrt("True", "False") +
|
||||
if_windows(_dict_to_kv({
|
||||
"msvcp_dll_names": "msvcp140.dll,msvcp140_1.dll",
|
||||
}), "") + if_windows_cuda(_dict_to_kv({
|
||||
|
@ -15,6 +15,7 @@ py_binary(
|
||||
tags = ["no-remote-exec"],
|
||||
deps = [
|
||||
"@local_config_cuda//cuda:cuda_config_py",
|
||||
"@local_config_tensorrt//:tensorrt_config_py",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
@ -28,6 +28,12 @@ try:
|
||||
except ImportError:
|
||||
cuda_config = None
|
||||
|
||||
# tensorrt.tensorrt is only valid in OSS
|
||||
try:
|
||||
from tensorrt.tensorrt import tensorrt_config # pylint: disable=g-import-not-at-top
|
||||
except ImportError:
|
||||
tensorrt_config = None
|
||||
|
||||
|
||||
def write_build_info(filename, key_value_list):
|
||||
"""Writes a Python that describes the build.
|
||||
@ -43,6 +49,9 @@ def write_build_info(filename, key_value_list):
|
||||
if cuda_config:
|
||||
build_info.update(cuda_config.config)
|
||||
|
||||
if tensorrt_config:
|
||||
build_info.update(tensorrt_config.config)
|
||||
|
||||
for arg in key_value_list:
|
||||
key, value = six.ensure_str(arg).split("=")
|
||||
if value.lower() == "true":
|
||||
|
5
third_party/tensorrt/BUILD.tpl
vendored
5
third_party/tensorrt/BUILD.tpl
vendored
@ -40,4 +40,9 @@ bzl_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tensorrt_config_py",
|
||||
srcs = ["tensorrt/tensorrt_config.py"]
|
||||
)
|
||||
|
||||
%{copy_rules}
|
||||
|
17
third_party/tensorrt/tensorrt/tensorrt_config.py.tpl
vendored
Normal file
17
third_party/tensorrt/tensorrt/tensorrt_config.py.tpl
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
config = %{tensorrt_config}
|
27
third_party/tensorrt/tensorrt_configure.bzl
vendored
27
third_party/tensorrt/tensorrt_configure.bzl
vendored
@ -79,6 +79,14 @@ def _create_dummy_repository(repository_ctx):
|
||||
{},
|
||||
)
|
||||
|
||||
# Set up tensorrt_config.py, which is used by gen_build_info to provide
|
||||
# build environment info to the API
|
||||
_tpl(
|
||||
repository_ctx,
|
||||
"tensorrt/tensorrt_config.py",
|
||||
_py_tmpl_dict({}),
|
||||
)
|
||||
|
||||
def enable_tensorrt(repository_ctx):
|
||||
"""Returns whether to build with TensorRT support."""
|
||||
return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))
|
||||
@ -93,6 +101,7 @@ def _create_local_tensorrt_repository(repository_ctx):
|
||||
"build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
|
||||
"BUILD": _tpl_path(repository_ctx, "BUILD"),
|
||||
"tensorrt/include/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt/include/tensorrt_config.h"),
|
||||
"tensorrt/tensorrt_config.py": _tpl_path(repository_ctx, "tensorrt/tensorrt_config.py"),
|
||||
}
|
||||
|
||||
config = find_cuda_config(repository_ctx, find_cuda_config_path, ["tensorrt"])
|
||||
@ -148,6 +157,19 @@ def _create_local_tensorrt_repository(repository_ctx):
|
||||
{"%{tensorrt_version}": trt_version},
|
||||
)
|
||||
|
||||
# Set up tensorrt_config.py, which is used by gen_build_info to provide
|
||||
# build environment info to the API
|
||||
repository_ctx.template(
|
||||
"tensorrt/tensorrt_config.py",
|
||||
tpl_paths["tensorrt/tensorrt_config.py"],
|
||||
_py_tmpl_dict({
|
||||
"tensorrt_version": trt_version,
|
||||
}),
|
||||
)
|
||||
|
||||
def _py_tmpl_dict(d):
|
||||
return {"%{tensorrt_config}": str(d)}
|
||||
|
||||
def _tensorrt_configure_impl(repository_ctx):
|
||||
"""Implementation of the tensorrt_configure repository rule."""
|
||||
|
||||
@ -165,6 +187,11 @@ def _tensorrt_configure_impl(repository_ctx):
|
||||
config_repo_label(remote_config_repo, ":tensorrt/include/tensorrt_config.h"),
|
||||
{},
|
||||
)
|
||||
repository_ctx.template(
|
||||
"tensorrt/tensorrt_config.py",
|
||||
config_repo_label(remote_config_repo, ":tensorrt/tensorrt_config.py"),
|
||||
{},
|
||||
)
|
||||
repository_ctx.template(
|
||||
"LICENSE",
|
||||
config_repo_label(remote_config_repo, ":LICENSE"),
|
||||
|
Loading…
Reference in New Issue
Block a user