[TF] Add TensorRT version to tf.sysconfig.get_build_info.
Add is_tensorrt_build and tensorrt_version to build_info. PiperOrigin-RevId: 341663396 Change-Id: I239d7d52c388767932716abd501061d31324ff10
This commit is contained in:
parent
9779f86969
commit
fa595eb8fa
@ -201,6 +201,7 @@ tensorflow/third_party/tensorrt/BUILD.tpl
|
|||||||
tensorflow/third_party/tensorrt/LICENSE
|
tensorflow/third_party/tensorrt/LICENSE
|
||||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||||
tensorflow/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
|
tensorflow/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
|
||||||
|
tensorflow/third_party/tensorrt/tensorrt/tensorrt_config.py.tpl
|
||||||
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
|
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
|
||||||
tensorflow/third_party/termcolor.BUILD
|
tensorflow/third_party/termcolor.BUILD
|
||||||
tensorflow/third_party/tf_toolchains.BUILD
|
tensorflow/third_party/tf_toolchains.BUILD
|
||||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
|
||||||
from tensorflow.python.platform import build_info
|
from tensorflow.python.platform import build_info
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -29,12 +30,15 @@ class BuildInfoTest(test.TestCase):
|
|||||||
test.is_built_with_rocm())
|
test.is_built_with_rocm())
|
||||||
self.assertEqual(build_info.build_info['is_cuda_build'],
|
self.assertEqual(build_info.build_info['is_cuda_build'],
|
||||||
test.is_built_with_cuda())
|
test.is_built_with_cuda())
|
||||||
|
self.assertEqual(build_info.build_info['is_tensorrt_build'],
|
||||||
|
is_tensorrt_enabled())
|
||||||
|
|
||||||
def testDeterministicOrder(self):
|
def testDeterministicOrder(self):
|
||||||
# The dict may contain other keys depending on the platform, but the ones
|
# The dict may contain other keys depending on the platform, but the ones
|
||||||
# it always contains should be in order.
|
# it always contains should be in order.
|
||||||
self.assertContainsSubsequence(build_info.build_info.keys(),
|
self.assertContainsSubsequence(
|
||||||
('is_cuda_build', 'is_rocm_build'))
|
build_info.build_info.keys(),
|
||||||
|
('is_cuda_build', 'is_rocm_build', 'is_tensorrt_build'))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -2457,6 +2457,7 @@ def tf_py_build_info_genrule(name, out):
|
|||||||
" --key_value" +
|
" --key_value" +
|
||||||
" is_rocm_build=" + if_rocm("True", "False") +
|
" is_rocm_build=" + if_rocm("True", "False") +
|
||||||
" is_cuda_build=" + if_cuda("True", "False") +
|
" is_cuda_build=" + if_cuda("True", "False") +
|
||||||
|
" is_tensorrt_build=" + if_tensorrt("True", "False") +
|
||||||
if_windows(_dict_to_kv({
|
if_windows(_dict_to_kv({
|
||||||
"msvcp_dll_names": "msvcp140.dll,msvcp140_1.dll",
|
"msvcp_dll_names": "msvcp140.dll,msvcp140_1.dll",
|
||||||
}), "") + if_windows_cuda(_dict_to_kv({
|
}), "") + if_windows_cuda(_dict_to_kv({
|
||||||
|
@ -15,6 +15,7 @@ py_binary(
|
|||||||
tags = ["no-remote-exec"],
|
tags = ["no-remote-exec"],
|
||||||
deps = [
|
deps = [
|
||||||
"@local_config_cuda//cuda:cuda_config_py",
|
"@local_config_cuda//cuda:cuda_config_py",
|
||||||
|
"@local_config_tensorrt//:tensorrt_config_py",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -28,6 +28,12 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
cuda_config = None
|
cuda_config = None
|
||||||
|
|
||||||
|
# tensorrt.tensorrt is only valid in OSS
|
||||||
|
try:
|
||||||
|
from tensorrt.tensorrt import tensorrt_config # pylint: disable=g-import-not-at-top
|
||||||
|
except ImportError:
|
||||||
|
tensorrt_config = None
|
||||||
|
|
||||||
|
|
||||||
def write_build_info(filename, key_value_list):
|
def write_build_info(filename, key_value_list):
|
||||||
"""Writes a Python that describes the build.
|
"""Writes a Python that describes the build.
|
||||||
@ -43,6 +49,9 @@ def write_build_info(filename, key_value_list):
|
|||||||
if cuda_config:
|
if cuda_config:
|
||||||
build_info.update(cuda_config.config)
|
build_info.update(cuda_config.config)
|
||||||
|
|
||||||
|
if tensorrt_config:
|
||||||
|
build_info.update(tensorrt_config.config)
|
||||||
|
|
||||||
for arg in key_value_list:
|
for arg in key_value_list:
|
||||||
key, value = six.ensure_str(arg).split("=")
|
key, value = six.ensure_str(arg).split("=")
|
||||||
if value.lower() == "true":
|
if value.lower() == "true":
|
||||||
|
5
third_party/tensorrt/BUILD.tpl
vendored
5
third_party/tensorrt/BUILD.tpl
vendored
@ -40,4 +40,9 @@ bzl_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "tensorrt_config_py",
|
||||||
|
srcs = ["tensorrt/tensorrt_config.py"]
|
||||||
|
)
|
||||||
|
|
||||||
%{copy_rules}
|
%{copy_rules}
|
||||||
|
17
third_party/tensorrt/tensorrt/tensorrt_config.py.tpl
vendored
Normal file
17
third_party/tensorrt/tensorrt/tensorrt_config.py.tpl
vendored
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
config = %{tensorrt_config}
|
27
third_party/tensorrt/tensorrt_configure.bzl
vendored
27
third_party/tensorrt/tensorrt_configure.bzl
vendored
@ -79,6 +79,14 @@ def _create_dummy_repository(repository_ctx):
|
|||||||
{},
|
{},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set up tensorrt_config.py, which is used by gen_build_info to provide
|
||||||
|
# build environment info to the API
|
||||||
|
_tpl(
|
||||||
|
repository_ctx,
|
||||||
|
"tensorrt/tensorrt_config.py",
|
||||||
|
_py_tmpl_dict({}),
|
||||||
|
)
|
||||||
|
|
||||||
def enable_tensorrt(repository_ctx):
|
def enable_tensorrt(repository_ctx):
|
||||||
"""Returns whether to build with TensorRT support."""
|
"""Returns whether to build with TensorRT support."""
|
||||||
return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))
|
return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))
|
||||||
@ -93,6 +101,7 @@ def _create_local_tensorrt_repository(repository_ctx):
|
|||||||
"build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
|
"build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
|
||||||
"BUILD": _tpl_path(repository_ctx, "BUILD"),
|
"BUILD": _tpl_path(repository_ctx, "BUILD"),
|
||||||
"tensorrt/include/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt/include/tensorrt_config.h"),
|
"tensorrt/include/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt/include/tensorrt_config.h"),
|
||||||
|
"tensorrt/tensorrt_config.py": _tpl_path(repository_ctx, "tensorrt/tensorrt_config.py"),
|
||||||
}
|
}
|
||||||
|
|
||||||
config = find_cuda_config(repository_ctx, find_cuda_config_path, ["tensorrt"])
|
config = find_cuda_config(repository_ctx, find_cuda_config_path, ["tensorrt"])
|
||||||
@ -148,6 +157,19 @@ def _create_local_tensorrt_repository(repository_ctx):
|
|||||||
{"%{tensorrt_version}": trt_version},
|
{"%{tensorrt_version}": trt_version},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set up tensorrt_config.py, which is used by gen_build_info to provide
|
||||||
|
# build environment info to the API
|
||||||
|
repository_ctx.template(
|
||||||
|
"tensorrt/tensorrt_config.py",
|
||||||
|
tpl_paths["tensorrt/tensorrt_config.py"],
|
||||||
|
_py_tmpl_dict({
|
||||||
|
"tensorrt_version": trt_version,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _py_tmpl_dict(d):
|
||||||
|
return {"%{tensorrt_config}": str(d)}
|
||||||
|
|
||||||
def _tensorrt_configure_impl(repository_ctx):
|
def _tensorrt_configure_impl(repository_ctx):
|
||||||
"""Implementation of the tensorrt_configure repository rule."""
|
"""Implementation of the tensorrt_configure repository rule."""
|
||||||
|
|
||||||
@ -165,6 +187,11 @@ def _tensorrt_configure_impl(repository_ctx):
|
|||||||
config_repo_label(remote_config_repo, ":tensorrt/include/tensorrt_config.h"),
|
config_repo_label(remote_config_repo, ":tensorrt/include/tensorrt_config.h"),
|
||||||
{},
|
{},
|
||||||
)
|
)
|
||||||
|
repository_ctx.template(
|
||||||
|
"tensorrt/tensorrt_config.py",
|
||||||
|
config_repo_label(remote_config_repo, ":tensorrt/tensorrt_config.py"),
|
||||||
|
{},
|
||||||
|
)
|
||||||
repository_ctx.template(
|
repository_ctx.template(
|
||||||
"LICENSE",
|
"LICENSE",
|
||||||
config_repo_label(remote_config_repo, ":LICENSE"),
|
config_repo_label(remote_config_repo, ":LICENSE"),
|
||||||
|
Loading…
Reference in New Issue
Block a user