cuda_configure: make find_cuda_config() compatible with remote execution

repository_ctx.execute() does not support uploading of files from the source tree. I initially tried constructing a command that simply embeds the file's contents. However that did not work on Windows because the file is larger than 8192 characters. So my best idea was to compress it locally and embed the compressed contents in the command and to uncompress it remotely. This works but comes with the drawback that we need to compress it first. This can't be done as part of the repository_rule either because within one repository_rule every execute() runs either locally or remotely. I thus decided to check in the compressed version in the source tree. It's very much a temporary measure as I'll add the ability to upload files to a future version of Bazel.

PiperOrigin-RevId: 295787408
Change-Id: I1545dd86cdec7e4b20cba43d6a134ad6d1a08109
This commit is contained in:
Jakob Buchgraber 2020-02-18 11:50:58 -08:00 committed by TensorFlower Gardener
parent f60fc7a072
commit b7796f3c85
6 changed files with 67 additions and 7 deletions

View File

@ -95,6 +95,7 @@ tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/find_cuda_config.py
tensorflow/third_party/gpus/find_cuda_config.py.gz.base64
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl

View File

@ -0,0 +1,37 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compresses the contents of find_cuda_config.py.oss.
The compressed file is what is actually being used. It works around remote
config not being able to upload files yet.
"""
import base64
import zlib
def main():
with open('find_cuda_config.py.oss', 'rb') as f:
data = f.read()
compressed = zlib.compress(data)
b64encoded = base64.b64encode(compressed)
with open('find_cuda_config.py.gz.base64.oss', 'wb') as f:
f.write(b64encoded)
if __name__ == '__main__':
main()

View File

@ -579,14 +579,35 @@ def _cudart_static_linkopt(cpu_value):
"""Returns additional platform-specific linkopts for cudart."""
return "" if cpu_value == "Darwin" else "\"-lrt\","
def _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries):
python_bin = get_python_bin(repository_ctx)
# If used with remote execution then repository_ctx.execute() can't
# access files from the source tree. A trick is to read the contents
# of the file in Starlark and embed them as part of the command. In
# this case the trick is not sufficient as the find_cuda_config.py
# script has more than 8192 characters. 8192 is the command length
# limit of cmd.exe on Windows. Thus we additionally need to compress
# the contents locally and decompress them as part of the execute().
compressed_contents = repository_ctx.read(script_path)
decompress_and_execute_cmd = (
"from zlib import decompress;" +
"from base64 import b64decode;" +
"from os import system;" +
"script = decompress(b64decode('%s'));" % compressed_contents +
"f = open('script.py', 'wb');" +
"f.write(script);" +
"f.close();" +
"system('%s script.py %s');" % (python_bin, " ".join(cuda_libraries))
)
return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])
# TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl,
# and nccl_configure.bzl.
def find_cuda_config(repository_ctx, script_path, cuda_libraries):
"""Returns CUDA config dictionary from running find_cuda_config.py"""
exec_result = raw_exec(repository_ctx, [
get_python_bin(repository_ctx),
script_path,
] + cuda_libraries)
exec_result = _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries)
if exec_result.return_code:
auto_configure_fail("Failed to run find_cuda_config.py: %s" % err_out(exec_result))
@ -858,7 +879,7 @@ def _create_local_cuda_repository(repository_ctx):
"cuda:cuda_config.h",
]}
tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD")
find_cuda_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
find_cuda_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
cuda_config = _get_cuda_config(repository_ctx, find_cuda_config_script)

File diff suppressed because one or more lines are too long

View File

@ -75,7 +75,7 @@ def _nccl_configure_impl(repository_ctx):
# function to be restarted with all previous state being lost. This
# can easily lead to a O(n^2) runtime in the number of labels.
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "")
if nccl_version:

View File

@ -114,7 +114,7 @@ def _tensorrt_configure_impl(repository_ctx):
# function to be restarted with all previous state being lost. This
# can easily lead to a O(n^2) runtime in the number of labels.
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
tpl_paths = {
"build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
"BUILD": _tpl_path(repository_ctx, "BUILD"),