[TF] Add TensorRT version to tf.sysconfig.get_build_info.

Add is_tensorrt_build and tensorrt_version to build_info.

PiperOrigin-RevId: 341663396
Change-Id: I239d7d52c388767932716abd501061d31324ff10
This commit is contained in:
Bixia Zheng 2020-11-10 11:38:31 -08:00 committed by TensorFlower Gardener
parent 9779f86969
commit fa595eb8fa
8 changed files with 67 additions and 2 deletions

View File

@ -201,6 +201,7 @@ tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/LICENSE
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
tensorflow/third_party/tensorrt/tensorrt/tensorrt_config.py.tpl
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
tensorflow/third_party/termcolor.BUILD
tensorflow/third_party/tf_toolchains.BUILD

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled
from tensorflow.python.platform import build_info
from tensorflow.python.platform import test
@ -29,12 +30,15 @@ class BuildInfoTest(test.TestCase):
test.is_built_with_rocm())
self.assertEqual(build_info.build_info['is_cuda_build'],
test.is_built_with_cuda())
self.assertEqual(build_info.build_info['is_tensorrt_build'],
is_tensorrt_enabled())
def testDeterministicOrder(self):
# The dict may contain other keys depending on the platform, but the ones
# it always contains should be in order.
self.assertContainsSubsequence(build_info.build_info.keys(),
('is_cuda_build', 'is_rocm_build'))
self.assertContainsSubsequence(
build_info.build_info.keys(),
('is_cuda_build', 'is_rocm_build', 'is_tensorrt_build'))
if __name__ == '__main__':

View File

@ -2457,6 +2457,7 @@ def tf_py_build_info_genrule(name, out):
" --key_value" +
" is_rocm_build=" + if_rocm("True", "False") +
" is_cuda_build=" + if_cuda("True", "False") +
" is_tensorrt_build=" + if_tensorrt("True", "False") +
if_windows(_dict_to_kv({
"msvcp_dll_names": "msvcp140.dll,msvcp140_1.dll",
}), "") + if_windows_cuda(_dict_to_kv({

View File

@ -15,6 +15,7 @@ py_binary(
tags = ["no-remote-exec"],
deps = [
"@local_config_cuda//cuda:cuda_config_py",
"@local_config_tensorrt//:tensorrt_config_py",
"@six_archive//:six",
],
)

View File

@ -28,6 +28,12 @@ try:
except ImportError:
cuda_config = None
# tensorrt.tensorrt is only valid in OSS
try:
from tensorrt.tensorrt import tensorrt_config # pylint: disable=g-import-not-at-top
except ImportError:
tensorrt_config = None
def write_build_info(filename, key_value_list):
"""Writes a Python that describes the build.
@ -43,6 +49,9 @@ def write_build_info(filename, key_value_list):
if cuda_config:
build_info.update(cuda_config.config)
if tensorrt_config:
build_info.update(tensorrt_config.config)
for arg in key_value_list:
key, value = six.ensure_str(arg).split("=")
if value.lower() == "true":

View File

@ -40,4 +40,9 @@ bzl_library(
],
)
py_library(
name = "tensorrt_config_py",
srcs = ["tensorrt/tensorrt_config.py"]
)
%{copy_rules}

View File

@ -0,0 +1,17 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
config = %{tensorrt_config}

View File

@ -79,6 +79,14 @@ def _create_dummy_repository(repository_ctx):
{},
)
# Set up tensorrt_config.py, which is used by gen_build_info to provide
# build environment info to the API
_tpl(
repository_ctx,
"tensorrt/tensorrt_config.py",
_py_tmpl_dict({}),
)
def enable_tensorrt(repository_ctx):
"""Returns whether to build with TensorRT support."""
return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))
@ -93,6 +101,7 @@ def _create_local_tensorrt_repository(repository_ctx):
"build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
"BUILD": _tpl_path(repository_ctx, "BUILD"),
"tensorrt/include/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt/include/tensorrt_config.h"),
"tensorrt/tensorrt_config.py": _tpl_path(repository_ctx, "tensorrt/tensorrt_config.py"),
}
config = find_cuda_config(repository_ctx, find_cuda_config_path, ["tensorrt"])
@ -148,6 +157,19 @@ def _create_local_tensorrt_repository(repository_ctx):
{"%{tensorrt_version}": trt_version},
)
# Set up tensorrt_config.py, which is used by gen_build_info to provide
# build environment info to the API
repository_ctx.template(
"tensorrt/tensorrt_config.py",
tpl_paths["tensorrt/tensorrt_config.py"],
_py_tmpl_dict({
"tensorrt_version": trt_version,
}),
)
def _py_tmpl_dict(d):
return {"%{tensorrt_config}": str(d)}
def _tensorrt_configure_impl(repository_ctx):
"""Implementation of the tensorrt_configure repository rule."""
@ -165,6 +187,11 @@ def _tensorrt_configure_impl(repository_ctx):
config_repo_label(remote_config_repo, ":tensorrt/include/tensorrt_config.h"),
{},
)
repository_ctx.template(
"tensorrt/tensorrt_config.py",
config_repo_label(remote_config_repo, ":tensorrt/tensorrt_config.py"),
{},
)
repository_ctx.template(
"LICENSE",
config_repo_label(remote_config_repo, ":LICENSE"),