cuda_configure: make find_cuda_config() compatible with remote execution
repository_ctx.execute() does not support uploading of files from the source tree. I initially tried constructing a command that simply embeds the file's contents. However that did not work on Windows because the file is larger than 8192 characters. So my best idea was to compress it locally and embed the compressed contents in the command and to uncompress it remotely. This works but comes with the drawback that we need to compress it first. This can't be done as part of the repository_rule either because within one repository_rule every execute() runs either locally or remotely. I thus decided to check in the compressed version in the source tree. It's very much a temporary measure as I'll add the ability to upload files to a future version of Bazel. PiperOrigin-RevId: 295787408 Change-Id: I1545dd86cdec7e4b20cba43d6a134ad6d1a08109
This commit is contained in:
parent
f60fc7a072
commit
b7796f3c85
|
@ -95,6 +95,7 @@ tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
|||
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/find_cuda_config.py
|
||||
tensorflow/third_party/gpus/find_cuda_config.py.gz.base64
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Compresses the contents of find_cuda_config.py.oss.
|
||||
|
||||
The compressed file is what is actually being used. It works around remote
|
||||
config not being able to upload files yet.
|
||||
"""
|
||||
import base64
|
||||
import zlib
|
||||
|
||||
|
||||
def main():
|
||||
with open('find_cuda_config.py.oss', 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
compressed = zlib.compress(data)
|
||||
b64encoded = base64.b64encode(compressed)
|
||||
|
||||
with open('find_cuda_config.py.gz.base64.oss', 'wb') as f:
|
||||
f.write(b64encoded)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -579,14 +579,35 @@ def _cudart_static_linkopt(cpu_value):
|
|||
"""Returns additional platform-specific linkopts for cudart."""
|
||||
return "" if cpu_value == "Darwin" else "\"-lrt\","
|
||||
|
||||
def _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries):
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
|
||||
# If used with remote execution then repository_ctx.execute() can't
|
||||
# access files from the source tree. A trick is to read the contents
|
||||
# of the file in Starlark and embed them as part of the command. In
|
||||
# this case the trick is not sufficient as the find_cuda_config.py
|
||||
# script has more than 8192 characters. 8192 is the command length
|
||||
# limit of cmd.exe on Windows. Thus we additionally need to compress
|
||||
# the contents locally and decompress them as part of the execute().
|
||||
compressed_contents = repository_ctx.read(script_path)
|
||||
decompress_and_execute_cmd = (
|
||||
"from zlib import decompress;" +
|
||||
"from base64 import b64decode;" +
|
||||
"from os import system;" +
|
||||
"script = decompress(b64decode('%s'));" % compressed_contents +
|
||||
"f = open('script.py', 'wb');" +
|
||||
"f.write(script);" +
|
||||
"f.close();" +
|
||||
"system('%s script.py %s');" % (python_bin, " ".join(cuda_libraries))
|
||||
)
|
||||
|
||||
return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])
|
||||
|
||||
# TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl,
|
||||
# and nccl_configure.bzl.
|
||||
def find_cuda_config(repository_ctx, script_path, cuda_libraries):
|
||||
"""Returns CUDA config dictionary from running find_cuda_config.py"""
|
||||
exec_result = raw_exec(repository_ctx, [
|
||||
get_python_bin(repository_ctx),
|
||||
script_path,
|
||||
] + cuda_libraries)
|
||||
exec_result = _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries)
|
||||
if exec_result.return_code:
|
||||
auto_configure_fail("Failed to run find_cuda_config.py: %s" % err_out(exec_result))
|
||||
|
||||
|
@ -858,7 +879,7 @@ def _create_local_cuda_repository(repository_ctx):
|
|||
"cuda:cuda_config.h",
|
||||
]}
|
||||
tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD")
|
||||
find_cuda_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
|
||||
find_cuda_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
|
||||
|
||||
cuda_config = _get_cuda_config(repository_ctx, find_cuda_config_script)
|
||||
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -75,7 +75,7 @@ def _nccl_configure_impl(repository_ctx):
|
|||
# function to be restarted with all previous state being lost. This
|
||||
# can easily lead to a O(n^2) runtime in the number of labels.
|
||||
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
|
||||
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
|
||||
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
|
||||
|
||||
nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "")
|
||||
if nccl_version:
|
||||
|
|
|
@ -114,7 +114,7 @@ def _tensorrt_configure_impl(repository_ctx):
|
|||
# function to be restarted with all previous state being lost. This
|
||||
# can easily lead to a O(n^2) runtime in the number of labels.
|
||||
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
|
||||
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
|
||||
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
|
||||
tpl_paths = {
|
||||
"build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
|
||||
"BUILD": _tpl_path(repository_ctx, "BUILD"),
|
||||
|
|
Loading…
Reference in New Issue