diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 296722936a8..a0ce4305b16 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -101,6 +101,7 @@ tensorflow/third_party/gpus/cuda/cuda_config.h.tpl tensorflow/third_party/gpus/cuda/cuda_config.py.tpl tensorflow/third_party/gpus/cuda_configure.bzl tensorflow/third_party/gpus/find_cuda_config.py +tensorflow/third_party/gpus/find_cuda_config.py.gz.base64 tensorflow/third_party/gpus/rocm/BUILD tensorflow/third_party/gpus/rocm/BUILD.tpl tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl diff --git a/third_party/gpus/compress_find_cuda_config.py b/third_party/gpus/compress_find_cuda_config.py new file mode 100644 index 00000000000..606bbf2cdd5 --- /dev/null +++ b/third_party/gpus/compress_find_cuda_config.py @@ -0,0 +1,37 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Compresses the contents of 'find_cuda.py'. + +The compressed file is what is actually being used. It works around remote +config not being able to upload files yet. +""" +import base64 +import zlib + + +def main(): + with open('find_cuda.py', 'rb') as f: + data = f.read() + + compressed = zlib.compress(data) + b64encoded = base64.b64encode(compressed) + + with open('find_cuda.py.gz.base64', 'wb') as f: + f.write(b64encoded) + + +if __name__ == '__main__': + main() + diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index c09a22a73c0..70bb91159de 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -605,19 +605,42 @@ def _cudart_static_linkopt(cpu_value): """Returns additional platform-specific linkopts for cudart.""" return "" if cpu_value == "Darwin" else "\"-lrt\"," +def _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries): + python_bin = get_python_bin(repository_ctx) + + # If used with remote execution then repository_ctx.execute() can't + # access files from the source tree. A trick is to read the contents + # of the file in Starlark and embed them as part of the command. In + # this case the trick is not sufficient as the find_cuda_config.py + # script has more than 8192 characters. 8192 is the command length + # limit of cmd.exe on Windows. Thus we additionally need to compress + # the contents locally and decompress them as part of the execute(). + compressed_contents = repository_ctx.read(script_path) + decompress_and_execute_cmd = ( + "from zlib import decompress;" + + "from base64 import b64decode;" + + "from os import system;" + + "script = decompress(b64decode('%s'));" % compressed_contents + + "f = open('script.py', 'wb');" + + "f.write(script);" + + "f.close();" + + "system('\"%s\" script.py %s');" % (python_bin, " ".join(cuda_libraries)) + ) + + return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd]) + # TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl, # and nccl_configure.bzl. -def find_cuda_config(repository_ctx, cuda_libraries): +def find_cuda_config(repository_ctx, script_path, cuda_libraries): """Returns CUDA config dictionary from running find_cuda_config.py""" - python_bin = get_python_bin(repository_ctx) - exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_cuda_config] + cuda_libraries) + exec_result = _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries) if exec_result.return_code: auto_configure_fail("Failed to run find_cuda_config.py: %s" % err_out(exec_result)) # Parse the dict from stdout. return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()]) -def _get_cuda_config(repository_ctx): +def _get_cuda_config(repository_ctx, find_cuda_config_script): """Detects and returns information about the CUDA installation on the system. Args: @@ -632,7 +655,7 @@ def _get_cuda_config(repository_ctx): compute_capabilities: A list of the system's CUDA compute capabilities. cpu_value: The name of the host operating system. """ - config = find_cuda_config(repository_ctx, ["cuda", "cudnn"]) + config = find_cuda_config(repository_ctx, find_cuda_config_script, ["cuda", "cudnn"]) cpu_value = get_cpu_value(repository_ctx) toolkit_path = config["cuda_toolkit_path"] @@ -928,8 +951,9 @@ def _create_local_cuda_repository(repository_ctx): "cuda:cuda_config.py", ]} tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD") + find_cuda_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64")) - cuda_config = _get_cuda_config(repository_ctx) + cuda_config = _get_cuda_config(repository_ctx, find_cuda_config_script) cuda_include_path = cuda_config.config["cuda_include_dir"] cublas_include_path = cuda_config.config["cublas_include_dir"] @@ -1370,20 +1394,12 @@ remote_cuda_configure = repository_rule( remotable = True, attrs = { "environ": attr.string_dict(), - "_find_cuda_config": attr.label( - default = Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"), - ), }, ) cuda_configure = repository_rule( implementation = _cuda_autoconf_impl, environ = _ENVIRONS + [_TF_CUDA_CONFIG_REPO], - attrs = { - "_find_cuda_config": attr.label( - default = Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"), - ), - }, ) """Detects and configures the local CUDA toolchain. diff --git a/third_party/gpus/find_cuda_config.py.gz.base64 b/third_party/gpus/find_cuda_config.py.gz.base64 new file mode 100644 index 00000000000..981219bb10a --- /dev/null +++ b/third_party/gpus/find_cuda_config.py.gz.base64 @@ -0,0 +1 @@ +eJzdPGtT40iS3/0r6tRHINNGwOzGxJ5vmQsGum/Y5aAD3D23AV5vIZeNpmXJJ8kG78b+98vMqpKqSpINpnu2Y4iYHkuqzMrMyme93rDTdL7KoulDwb47PPoPNngQbCCSPM3ex+kjO1kUD2mWB+wkjtk1NsvZtchFthTjoPOm84ZdRCE0F2O2SMYiYwXAn8x5CP9TX3rsk8jyKE3Yd8Eh87GBpz553f8EDKt0wWZ8xZK0YItcAIooZ5MoFkw8hWJesChhYTqbxxFPQsEeo+KBulFIgAz2F4UivS84tObQfg5PE7Md4wURjH8PRTHvHxw8Pj4GnIgN0mx6EMuG+cHF+em7y5t3+0AwgXxMYpHnLBP/t4gyYPV+xfgc6An5PVAZ80eWZoxPMwHfihTpfcyiIkqmPZank+KRZwKwjKO8yKL7RWEJS1MHPJsNQFw8Yd7JDTu/8diPJzfnNz3A8fP54KerjwP288n19cnl4PzdDbu6ZqdXl2fng/OrS3h6z04u/8L+fH551mMCRAXdiKd5hvQDkRGKkYaO3QhhETBJJUH5XITRJAqBr2S64FPBpulSZAmww+Yim0U5DmYO5I0BSxzNooIX9KbGFHZz/EX/Op7nfciiBNTw9OPZCXR/n/FshcSwB8Gx/zEMUVikWSSIRraU2gcqlQKBKFjicpUXYhZ0OqjweZhFoGe54BnoQk6iaEOPipnbWHow4ii1Iu/AyxmqwFgUKKqERBxlmghCNJf0I3yYJpNoushIgAiXF+N0UQRE1ZwXD7nUJ8JOwAhV6mHJGiiYHjdUwYcsXUwfmEiWUZYmM5EUnSXPItRWMOXzCZgaW/I4GjsEREpIPcmclIoml4gTWUYDn4likZESMHgF4grTsVDSjEGN0fak8HAYAHYSAfGAv6KSI9nTBVIHRN0s5vM0Q82vwNBsaBj8KAnjxRhehYsfL05uuj34cXZ52WOXp6cXPRKMdFrXA3tIC/4ZEZU03XPQdFNDKnrAqJHrwfsR9jn6cDL46aZjiJBpESLl4I9mfD8Xcw6iA+BpnN5TJwEzeo/T9LPUJqk8eQcp1UolNYnc1QPPxvsowjEoIRGaL+5NMidZOkPygHpigXQj6MBQWvSihNGNllyBbNjVTWXTYzHhi7jAduBqx/1Oh4GxJounPnjFg0WeHcRpyOODcDHmPXohpSt1kajYjcdSa9j+fDcA+J9haNPHvM9KOliT2Hrgfzqs+jvt3919yNJpxmfsPYri7u7y0/nZ+Qn77w8fISzN5gt0oWyQpvHnqLi7Q/R3d3udznsQ6D0PP4NfHdNQAHXRfRRHxQod7kyYOhTnKUUVHoNFJtBwKUXYqYk3lc6ygXQYJQolK9vQAhAdqVVHDel+KedGLEracvRZGRahbaVb7GQ8jtAYeVxTVoDef+UfoCCLoj+tOp/eXd9A8NDjgq8GV1cXfz4f0GgCjDS7EgYfLKhmGLBQo5/LS7sj+er88mZwcnGhgdCeS+LwwSGOXpkw0gWMfjq71ii0IyAUA4jjV9fXAwNN+crquvNJ+9JmBQjB090LnVKASYMhPO1iSN19Cla76OLQLQvQB94p1UA5aIjEs3mxwtaLxPDTKeMhZTg8Wem2yn9BFAD9x/64ClzabYDpGlT0O8Vk9Ef17YeRQtJnQFTwd/ub9EEjUKg+C4LA/qh+VF8h0HY6kC6AU2ZRqn+luf6FHk//nse8QGL0M/ht9Qvc2DxLQ8g+yjervANpwXwVQ1jpY86DAj6e7svv++C+9nmxX6TzTpGt+jCgZJz5A3iDmCkkjw9R+NBR6eE5vXuHYahsjqkUAuRBPuePiYbDKDQSTyJcFNqTS1QVQSJppafTCWMOedQpuT/q0H9HNIDIu9j3HD5DM/CybBTlIFTwrD59kQGzlFQgMwe/C/kRpsPQzjMAH6VLfQaocr4m8IyH6XNAz3gG/ZSQM15g8qNVyOdhseDxqFRhnXXoN9QBaMnpgwg/oxwF+U/yvzrdmQmhMh0FLaN9R8UBLCpELF9qpXa7YYnKqsH6MI1Fs4wo+1JIbDpL3PhXw+U0xhY5hEMD5CWeVP8dMffvKLDfDbKFMAGC71wA5/k9xK0NEMHv1kLU/9w+iCg9Etk075fQrtz6lNRoodVSJgivWVAC2yK2QTEtDjFf0vmMjEaQ4BY8jikJ1YiuZYbZZz8rxcLWEnelXlJl7aQ4TUQgNRP+hbjtDDm4z0toIplV9iFlB2+gKoIKbA4s8tDIuspEGxE7+I5dBcQabu53Ow3qd1x7ZbUmWmrYeFbkWPP6NQPUlsuLUSx4XrzMdJUsj9kt+D1/2aU8dUn5qEMClIyF7wVed2gw1QRY584BNThkP1TS0IxMRaGjlGYFE88eS/hMaH+jFAOdgQot0D2AQ7FF2fSpWakFUg2QwhirMfgepUE6Fxqzl3kQnRPIvSHXPPYWxWT/D143yAAFAig/yqSq0QAG9NP33sg+2U7O3vp347ddj+0QoT3qqktgoH/UXNuWEgG9C6ZQps39I2PwvdIbhzDsIo94MgImx4uw8CdRlheQ3wrIvseuMMB6KEfHbPg+SlRBDi6VoCiFl4BKIKq/2zQPECz4JY0SfwLY5WhOUFASFB9JrhJ+aI5VPB7JUmBE9YHfRFUtzbbrh9JMsWxxYmZFJmqPBqO+YCQocPuefu11Mb3yDnJg/6B8CXAqkzo2cpEgxJg1kl/8WwszqMT+3Bt2KZoXWDPIUccyAxQKlHmPHf/A/GCv68mRwwiC6EVBdmzqmuxBmoGlTiq1MTVL9abUC1sHY4E1oe/xPIwirytVSmU9H5MIP55RkzL5MRFWEaGuhUhzwMdjXysADBPqrm9rZtfQzZzKcx+LXF8i6HZNZVCl5QhLR6UP9NP0OWqczfeSJvMNkO7t4ciJWLUHB0KTgJvArMe3AAZoZK+1pKrSrTJwgShU7h8AP74VS72ytvV6zoftytjlDjQEj2Ehs+RFXwy3ees59fn+To4uxwQC5XUaeeqVSbbTJkm8IUir2aL1CCu33Gzn9emLTMSy1C4nsXAuxJgKkbMfjjdSJHolsZ6a+Km9ULy5b/f2yX3sT5NF9VE8FRnPD04/fhicH6xFKNvQt9LP6cLoC3BuTfxsZB1af/9769F6aOYVPzyZYBbzezZrYFojqi1HNIfn43yD5LNXsqGfkWbloUgGVVs0V08GkFvvLtFKtruT76J6krdG7ulHVPqRCkF3KL2M2eFLkdrQQ8NtmcWaFsppuojH5FloWhKr7x05Ewe/jJm3VR+sNEkn9L/KWH1DGHVJGYyV1kN1J0K9SMYmgw0JQTsq5eEokuM6CsBjuR7gP74V8mXEMzvuOqkKfkNp8igX26uMJQhlBxawNjOZQLXVmtrsZGoip7BJQmBmuggQeTWLvKryi4YIYBAI+mZSQMEDLU5HIaeudmG9nXwvGK8QArQTE16tu/TqrYUcnO3aOs3+W5NT3/aPhhShBQT6RrLW0gGRIU+9DdRs6h4ltWc4shZVd93oGu1QHYmXGUzea+fC5aHHMNS9QLW4NKPigRe1mlOqm5ZOqW7Fai5MPwGJNfobsMRbzKB6rFjMYzGsjRp6vlvzefgFHYHWjCixutQW/zp/0VAZG2ImWXTLptG2E07VmJbeqed03O6tFPwaldIRAmxD1UTWEIKySwXAjE7zqYNSa31Oai3THmusVE7V7PK2UFKVWT1bTS2TXWN2TvJnE77Bla1ny5IQJbIy+3RUuqYH0CnCtcwXKEWp9BDnKdonFzxzEURXWarysKqNUmI4gWTWEN7OONgZk+vXnR4csKPDw8NeScUOPcv33U5HJfuGZEuxVnS3ao6HwMGDt0n6zxqRLXDUZdnt1Koxhx32BhclcAJDMH6fLkVQSsEMDiXbTVkC8Z0VXs8plQx9SJZh2KQNRlQ8BVgm63m98J7G0hcIXHK7G7+9C+CfHvvk69/0T9cr3amu8FsnFfRcwr6ixRt2G+YBGiaWysSyfQqgXtG3zyzZSgsvSECU5IAo8CEQT8KrpUiUVMgGnoZSk3GGhMvRWuM+Gmqb+yiRpUivIqfnVLLuYHYl8cuZrSbt3WHbA9CisVhGoVHx5VC0uQUk1k1Q+48jvk+ldSGrdQccyPXKN3vB0WFwH3rKmOdFZFpzO301b+oRMFhzt0T0EoOA9l6DkyyNC6cD5cIeYHKne8ph7WrhGguBje2V/Ltyuvx/L0501zLWKMGVUYkKrHnxxOVmnFJ6AYEPrs6u/DCPptNunxbP0KTuUwA1pw1xeQBkFmiOVB9lpajjuiY2STPcEvBgpy6OMDDCQxLb7T0L2BUNQR+UCLoy6atTd3s4ZP923PThSOV+MlUxq1TvPIEomEOWiGtdtEBiSrWPteoyZzT70+K76/3ReClP8A+t9abBeX3b/qw2apIEOYd2rlK4kaxrAxtiawM29d2BrkZMU2iMoWnqTj+1EauwooVt4sgx564Dvokn14gdpsyh0Wy5eoMQ/3QSpPuY589JkZz42NezoLWVImfyEDzakSebPyfDMnMsc750w0KO0RJDKcUisHvfq3aWjP7n5E9X1x4lZ9W780t815avmE0/nAxOf6rCpc7TdD1sOknGtk/CaDz4PNomFXt1HtachKFbVdt1FrlyyjP+S1rtQUyTeBWo6XNioC1lM2v9w2HHmmh4g+u6mZAbzkrUUTKBQmTJIbHC7RX3AsZXrfGiavXYL4tcTbrRBhZcoTOHoHkq3x2kl8VWY4wc+RhLwrTfEfWQq42l+psPOWu3RVruCKKEXhK5ER2lsibeFkdtNgGPYcur57Rb79xa3BpBrvdrNY/m+qc8jYGkr+qhjoLDf6mHurm6QAfj+Cjz7UYvZTT+lfxUOS6zGZj1t+SsStq2cUJf0Wu4Emul9rVOQCIkN2DjbnEEdqN1rkC13MoZKNhXuoOMJ+PftDO4Prk8c11B9W6jIyib/kpuAMfjG7N+UpFvzPaVmFpIfK3BIzoydxNvi7GbTdaZOrXbytAJ8pVmPpkUv2krf/9+4Bp5+UrauPnqV7JlEPo3ZsqoBt+YJUshNRP4WjsGbGTGBtYWKzZarDNibLaVDSPgK014nCRfZ+ml2l71XPtzrQ9PiRiWR4+V1eEjWdzFu0/vLuQ+03ajQ3+jfuPMHE1x63nx7UzxZebjeyRqNFwmf5YGAmr6QmSvNOXWNZyKqGeY8YuWbpJErtxUPbRNS1Yt5PRY9dyzW21lMAj42kp3zjNg7rcc9m4+nFzfvKtVusbbzZVu1fjXqnRpXL6xyKiU5RsLjqWoWsl8dW1LCGVta+Fuq22tRmtrW9lyu9pWwr7KASRhGH+dgFnZ7nYBk45DlvYqn3S4pCfLENeZ4bZGiLJ5uQF+lWhGw/Q1ghkiBoma+Bt12myAi1LGY89qs4UmE9yrtLigo7LZs6q3f4Emf6qO71YKbb5cF3+slpbOv2F/U0T8Daf+OTCUiIwX0Hl6/4sI8ZBNyh7pdgZoWLCokPur6BIBtaaxyKNkqtDBd3+2iIpojpsEo5nIoVWcPqp1g0oidHbCig1qmxGPY7+SszaIctsGdvJJfpbn+9UtDTP+WdT20TF9NwOd/wg2mnl5FGVbe79cnicTkX1DMVcdkTEW1ctFK304/Hs2S5dibJ7uo9WrbFbexaH4+qQT9OBLSKnE9o0ICxVAe4FnOsoX+cllhFyD5bq9NLpLtxG4NPdVr9Z2C9dZwm7rPtGSR3icbSSSpQ//qS1M6hDU8e3Q3bLp3hhCzoQOxDVdOFBtJNa46ThZeULJOsJUvb7VrYd6zHqecRZCUWcd4RNTHkoOn8UGlzt7rDPA4BkJSzMnOM71ux56TVc5NF7NQAdrihSw7MpzS3jE5Q/fj77/fXUAZpc90pGOeSaWUbrI45VcZKatRPVLfDpMXaDCzuhAMniBFK8IgQZKUtTVbk/eCfIYyeuR8L4RdAzO6Het48brxqt+iPSv/sHtXw/YcK/7lhi7e3xbcXXwX/8OhtM0umsPl97aewCNI2Sb1LZbnQwCPxhHf5eGXIV3QxHKJnIHck/ey6WllMc8f5D3JKkbCpRGK5ehLS0TPK666LSd1FBQCmQe81D43t0dHW4zFRwbKB5qW4rrmjyOQrppJaOrwtZd+KTURbJQXS5zzG55Nl0GEOvB3cujs/iCNoau8gB/3x71h8NObXOsc9jQc+5iQc6IMeuclTN+nnX3z8aA8qxTmrU+aSOrdQLB+AyjpUdSPAFtuRzHofTt6mDsP/4px5W2PtFBzlKCxinONh5rV8t41hEr5QfpPOtiPgbX6td3k1cdOJMxarrBGRuJ7tbeHjdULWmbhCa2IkQbJB3r8N2D8SYmI5p22R+Zf3TYY0dl1veG/VjbNKO2qzzyXN18IbeqlKfUcNOKsfdS36Dg0FoJweRkg0aaN/2UOukKvMrzG7apmVS4e13aRkOtPn9RKYMcD0sp17qoCae2zt8onmonx8sEZO+Tsamp7wRoExOt3X1NIVkd1ETkrIs2CkivcL9MPOa+AZMKd9W0TTC4IPI15WLir4nFXmVqlIpaKlwrlIbFVaNjZ+Gp1ZDkVNdXNSS7i7ohuZOKzYZUThS/0JCsaXibmvq0Y5OYVGBKkpbIlCRVaHITZq+e1zZEJ3dJpTn0V9ejbVQJY7HOINBdWSm5o7mqBuZo/qidt1p+3sCaM7/WwJl5h9sGxsw51Yo2Z5Kt5EqXcE2cleVdO3eNpUYDhw3FcQOX7jVzGzh1591sehtqZeIazfNzT9+0Q1ijQszKJBmk8hkoG8vbgjyqaelSFOstduEZk0yU6nweopCcrH9plujqoiyZW894lMhey8kjIk6sengocSGMc/71PFxTXRJBN4z63k7el4camF9h6jbP52DKo1ZIMNGWN5MGeOuu8PMi84Wa68OPkJkWeBaq08Higiqe0YhuQBuNkJPRyENMkqnO/wPiKjLf \ No newline at end of file diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl index 03642bcf04a..d59e861d70b 100644 --- a/third_party/nccl/nccl_configure.bzl +++ b/third_party/nccl/nccl_configure.bzl @@ -64,11 +64,17 @@ def _label(file): return Label("//third_party/nccl:{}".format(file)) def _create_local_nccl_repository(repository_ctx): + # Resolve all labels before doing any real work. Resolving causes the + # function to be restarted with all previous state being lost. This + # can easily lead to a O(n^2) runtime in the number of labels. + # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778 + find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64")) + nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "") if nccl_version: nccl_version = nccl_version.split(".")[0] - cuda_config = find_cuda_config(repository_ctx, ["cuda"]) + cuda_config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda"]) cuda_version = cuda_config["cuda_version"].split(".") cuda_major = cuda_version[0] cuda_minor = cuda_version[1] @@ -90,7 +96,7 @@ def _create_local_nccl_repository(repository_ctx): ) else: # Create target for locally installed NCCL. - config = find_cuda_config(repository_ctx, ["nccl"]) + config = find_cuda_config(repository_ctx, find_cuda_config_path, ["nccl"]) config_wrap = { "%{nccl_version}": config["nccl_version"], "%{nccl_header_dir}": config["nccl_include_dir"], @@ -139,20 +145,12 @@ remote_nccl_configure = repository_rule( remotable = True, attrs = { "environ": attr.string_dict(), - "_find_cuda_config": attr.label( - default = Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"), - ), }, ) nccl_configure = repository_rule( implementation = _nccl_autoconf_impl, environ = _ENVIRONS, - attrs = { - "_find_cuda_config": attr.label( - default = Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"), - ), - }, ) """Detects and configures the NCCL configuration. diff --git a/third_party/tensorrt/tensorrt_configure.bzl b/third_party/tensorrt/tensorrt_configure.bzl index d26fa2a34d4..9c980a92cf8 100644 --- a/third_party/tensorrt/tensorrt_configure.bzl +++ b/third_party/tensorrt/tensorrt_configure.bzl @@ -88,13 +88,14 @@ def _create_local_tensorrt_repository(repository_ctx): # function to be restarted with all previous state being lost. This # can easily lead to a O(n^2) runtime in the number of labels. # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778 + find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64")) tpl_paths = { "build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"), "BUILD": _tpl_path(repository_ctx, "BUILD"), "tensorrt/include/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt/include/tensorrt_config.h"), } - config = find_cuda_config(repository_ctx, ["tensorrt"]) + config = find_cuda_config(repository_ctx, find_cuda_config_path, ["tensorrt"]) trt_version = config["tensorrt_version"] cpu_value = get_cpu_value(repository_ctx) @@ -190,16 +191,12 @@ remote_tensorrt_configure = repository_rule( remotable = True, attrs = { "environ": attr.string_dict(), - "_find_cuda_config": attr.label(default = "@org_tensorflow//third_party/gpus:find_cuda_config.py"), }, ) tensorrt_configure = repository_rule( implementation = _tensorrt_configure_impl, environ = _ENVIRONS + [_TF_TENSORRT_CONFIG_REPO], - attrs = { - "_find_cuda_config": attr.label(default = "@org_tensorflow//third_party/gpus:find_cuda_config.py"), - }, ) """Detects and configures the local CUDA toolchain.