diff --git a/third_party/tensorrt/tensorrt_configure.bzl b/third_party/tensorrt/tensorrt_configure.bzl index c3f67879225..4855ea0e50e 100644 --- a/third_party/tensorrt/tensorrt_configure.bzl +++ b/third_party/tensorrt/tensorrt_configure.bzl @@ -45,10 +45,13 @@ def _get_tensorrt_headers(tensorrt_version): return _TF_TENSORRT_HEADERS_V6 return _TF_TENSORRT_HEADERS +def _tpl_path(repository_ctx, filename): + return repository_ctx.path(Label("//third_party/tensorrt:%s.tpl" % filename)) + def _tpl(repository_ctx, tpl, substitutions): repository_ctx.template( tpl, - Label("//third_party/tensorrt:%s.tpl" % tpl), + _tpl_path(repository_ctx, tpl), substitutions, ) @@ -71,12 +74,6 @@ def enable_tensorrt(repository_ctx): def _tensorrt_configure_impl(repository_ctx): """Implementation of the tensorrt_configure repository rule.""" - # Resolve all labels before doing any real work. Resolving causes the - # function to be restarted with all previous state being lost. This - # can easily lead to a O(n^2) runtime in the number of labels. - # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778 - find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py")) - if _TF_TENSORRT_CONFIG_REPO in repository_ctx.os.environ: # Forward to the pre-configured remote repository. remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO] @@ -109,6 +106,17 @@ def _tensorrt_configure_impl(repository_ctx): _create_dummy_repository(repository_ctx) return + # Resolve all labels before doing any real work. Resolving causes the + # function to be restarted with all previous state being lost. This + # can easily lead to a O(n^2) runtime in the number of labels. + # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778 + find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py")) + tpl_paths = { + "build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"), + "BUILD": _tpl_path(repository_ctx, "BUILD"), + "tensorrt/include/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt/include/tensorrt_config.h"), + } + config = find_cuda_config(repository_ctx, find_cuda_config_path, ["tensorrt"]) trt_version = config["tensorrt_version"] cpu_value = get_cpu_value(repository_ctx) @@ -134,18 +142,26 @@ def _tensorrt_configure_impl(repository_ctx): ] # Set up config file. - _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_true"}) + repository_ctx.template( + "build_defs.bzl", + tpl_paths["build_defs.bzl"], + {"%{if_tensorrt}": "if_true"}, + ) # Set up BUILD file. - _tpl(repository_ctx, "BUILD", { - "%{copy_rules}": "\n".join(copy_rules), - }) + repository_ctx.template( + "BUILD", + tpl_paths["BUILD"], + {"%{copy_rules}": "\n".join(copy_rules)}, + ) # Set up tensorrt_config.h, which is used by # tensorflow/stream_executor/dso_loader.cc. - _tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", { - "%{tensorrt_version}": trt_version, - }) + repository_ctx.template( + "tensorrt/include/tensorrt_config.h", + tpl_paths["tensorrt/include/tensorrt_config.h"], + {"%{tensorrt_version}": trt_version}, + ) tensorrt_configure = repository_rule( implementation = _tensorrt_configure_impl,