tensorrt_configure: fix quadratic runtime due to label resolution
PiperOrigin-RevId: 293788006 Change-Id: Icdbcb244fdbf51736ef9a0343124196a52f3b1f3
This commit is contained in:
parent
543a87593e
commit
a3fddc818e
44
third_party/tensorrt/tensorrt_configure.bzl
vendored
44
third_party/tensorrt/tensorrt_configure.bzl
vendored
@ -45,10 +45,13 @@ def _get_tensorrt_headers(tensorrt_version):
|
||||
return _TF_TENSORRT_HEADERS_V6
|
||||
return _TF_TENSORRT_HEADERS
|
||||
|
||||
def _tpl_path(repository_ctx, filename):
|
||||
return repository_ctx.path(Label("//third_party/tensorrt:%s.tpl" % filename))
|
||||
|
||||
def _tpl(repository_ctx, tpl, substitutions):
|
||||
repository_ctx.template(
|
||||
tpl,
|
||||
Label("//third_party/tensorrt:%s.tpl" % tpl),
|
||||
_tpl_path(repository_ctx, tpl),
|
||||
substitutions,
|
||||
)
|
||||
|
||||
@ -71,12 +74,6 @@ def enable_tensorrt(repository_ctx):
|
||||
def _tensorrt_configure_impl(repository_ctx):
|
||||
"""Implementation of the tensorrt_configure repository rule."""
|
||||
|
||||
# Resolve all labels before doing any real work. Resolving causes the
|
||||
# function to be restarted with all previous state being lost. This
|
||||
# can easily lead to a O(n^2) runtime in the number of labels.
|
||||
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
|
||||
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
|
||||
|
||||
if _TF_TENSORRT_CONFIG_REPO in repository_ctx.os.environ:
|
||||
# Forward to the pre-configured remote repository.
|
||||
remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO]
|
||||
@ -109,6 +106,17 @@ def _tensorrt_configure_impl(repository_ctx):
|
||||
_create_dummy_repository(repository_ctx)
|
||||
return
|
||||
|
||||
# Resolve all labels before doing any real work. Resolving causes the
|
||||
# function to be restarted with all previous state being lost. This
|
||||
# can easily lead to a O(n^2) runtime in the number of labels.
|
||||
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
|
||||
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
|
||||
tpl_paths = {
|
||||
"build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
|
||||
"BUILD": _tpl_path(repository_ctx, "BUILD"),
|
||||
"tensorrt/include/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt/include/tensorrt_config.h"),
|
||||
}
|
||||
|
||||
config = find_cuda_config(repository_ctx, find_cuda_config_path, ["tensorrt"])
|
||||
trt_version = config["tensorrt_version"]
|
||||
cpu_value = get_cpu_value(repository_ctx)
|
||||
@ -134,18 +142,26 @@ def _tensorrt_configure_impl(repository_ctx):
|
||||
]
|
||||
|
||||
# Set up config file.
|
||||
_tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_true"})
|
||||
repository_ctx.template(
|
||||
"build_defs.bzl",
|
||||
tpl_paths["build_defs.bzl"],
|
||||
{"%{if_tensorrt}": "if_true"},
|
||||
)
|
||||
|
||||
# Set up BUILD file.
|
||||
_tpl(repository_ctx, "BUILD", {
|
||||
"%{copy_rules}": "\n".join(copy_rules),
|
||||
})
|
||||
repository_ctx.template(
|
||||
"BUILD",
|
||||
tpl_paths["BUILD"],
|
||||
{"%{copy_rules}": "\n".join(copy_rules)},
|
||||
)
|
||||
|
||||
# Set up tensorrt_config.h, which is used by
|
||||
# tensorflow/stream_executor/dso_loader.cc.
|
||||
_tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", {
|
||||
"%{tensorrt_version}": trt_version,
|
||||
})
|
||||
repository_ctx.template(
|
||||
"tensorrt/include/tensorrt_config.h",
|
||||
tpl_paths["tensorrt/include/tensorrt_config.h"],
|
||||
{"%{tensorrt_version}": trt_version},
|
||||
)
|
||||
|
||||
tensorrt_configure = repository_rule(
|
||||
implementation = _tensorrt_configure_impl,
|
||||
|
Loading…
Reference in New Issue
Block a user