tensorrt_configure: fix quadratic runtime due to label resolution

PiperOrigin-RevId: 293788006
Change-Id: Icdbcb244fdbf51736ef9a0343124196a52f3b1f3
A. Unique TensorFlower 2020-02-07 04:22:02 -08:00 committed by TensorFlower Gardener
parent 543a87593e
commit a3fddc818e

third_party/tensorrt/tensorrt_configure.bzl

@@ -45,10 +45,13 @@ def _get_tensorrt_headers(tensorrt_version):
         return _TF_TENSORRT_HEADERS_V6
     return _TF_TENSORRT_HEADERS
 
+def _tpl_path(repository_ctx, filename):
+    return repository_ctx.path(Label("//third_party/tensorrt:%s.tpl" % filename))
+
 def _tpl(repository_ctx, tpl, substitutions):
     repository_ctx.template(
         tpl,
-        Label("//third_party/tensorrt:%s.tpl" % tpl),
+        _tpl_path(repository_ctx, tpl),
         substitutions,
     )
@@ -71,12 +74,6 @@ def enable_tensorrt(repository_ctx):
 def _tensorrt_configure_impl(repository_ctx):
     """Implementation of the tensorrt_configure repository rule."""
-
-    # Resolve all labels before doing any real work. Resolving causes the
-    # function to be restarted with all previous state being lost. This
-    # can easily lead to a O(n^2) runtime in the number of labels.
-    # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
-    find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
 
     if _TF_TENSORRT_CONFIG_REPO in repository_ctx.os.environ:
         # Forward to the pre-configured remote repository.
         remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO]
@@ -109,6 +106,17 @@ def _tensorrt_configure_impl(repository_ctx):
         _create_dummy_repository(repository_ctx)
         return
 
+    # Resolve all labels before doing any real work. Resolving causes the
+    # function to be restarted with all previous state being lost. This
+    # can easily lead to a O(n^2) runtime in the number of labels.
+    # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
+    find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
+    tpl_paths = {
+        "build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
+        "BUILD": _tpl_path(repository_ctx, "BUILD"),
+        "tensorrt/include/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt/include/tensorrt_config.h"),
+    }
+
     config = find_cuda_config(repository_ctx, find_cuda_config_path, ["tensorrt"])
     trt_version = config["tensorrt_version"]
     cpu_value = get_cpu_value(repository_ctx)
@@ -134,18 +142,26 @@ def _tensorrt_configure_impl(repository_ctx):
     ]
 
     # Set up config file.
-    _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_true"})
+    repository_ctx.template(
+        "build_defs.bzl",
+        tpl_paths["build_defs.bzl"],
+        {"%{if_tensorrt}": "if_true"},
+    )
 
     # Set up BUILD file.
-    _tpl(repository_ctx, "BUILD", {
-        "%{copy_rules}": "\n".join(copy_rules),
-    })
+    repository_ctx.template(
+        "BUILD",
+        tpl_paths["BUILD"],
+        {"%{copy_rules}": "\n".join(copy_rules)},
+    )
 
     # Set up tensorrt_config.h, which is used by
     # tensorflow/stream_executor/dso_loader.cc.
-    _tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", {
-        "%{tensorrt_version}": trt_version,
-    })
+    repository_ctx.template(
+        "tensorrt/include/tensorrt_config.h",
+        tpl_paths["tensorrt/include/tensorrt_config.h"],
+        {"%{tensorrt_version}": trt_version},
+    )
 
 tensorrt_configure = repository_rule(
     implementation = _tensorrt_configure_impl,
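
For context, the restart behavior this commit works around can be sketched with a small, hypothetical repository rule. The //tools labels, probe.py script, and config.bzl.tpl template below are invented for illustration and are not part of TensorFlow; the point is only the ordering: resolve every Label first, then do the expensive work.

# Hypothetical sketch, not TensorFlow code: a repository rule that resolves
# all of its labels before doing any real work, mirroring this commit.
def _example_configure_impl(repository_ctx):
    # Resolving a label Bazel has not seen yet restarts this function from
    # the top and discards all local state, so resolve everything first,
    # while a restart is still cheap.
    script_path = repository_ctx.path(Label("//tools:probe.py"))
    template_path = repository_ctx.path(Label("//tools:config.bzl.tpl"))

    # The expensive, restart-sensitive steps only begin once every label
    # above is resolved, so they run at most once.
    result = repository_ctx.execute(["python3", str(script_path)])
    repository_ctx.template(
        "config.bzl",
        template_path,
        {"%{probe_output}": result.stdout.strip()},
    )

example_configure = repository_rule(
    implementation = _example_configure_impl,
)

If the path() calls were instead interleaved with the execute() and template() calls, each newly resolved label would restart the function and redo all the work before it, which is how the runtime grows quadratically with the number of labels. Hoisting the label resolution (find_cuda_config_path and tpl_paths above) is exactly what this commit does for tensorrt_configure.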