Merge pull request #43636 from ROCmSoftwarePlatform:google_upstream_rocm_version_defines
PiperOrigin-RevId: 341361119 Change-Id: I7c52235336cd81aeab3e20e60b1e9356c520b36b
This commit is contained in:
commit
cd24d4b345
36
third_party/gpus/compress_find_rocm_config.py
vendored
Normal file
36
third_party/gpus/compress_find_rocm_config.py
vendored
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Compresses the contents of 'find_rocm_config.py'.
|
||||
|
||||
The compressed file is what is actually being used. It works around remote
|
||||
config not being able to upload files yet.
|
||||
"""
|
||||
import base64
|
||||
import zlib
|
||||
|
||||
|
||||
def main():
|
||||
with open('find_rocm_config.py', 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
compressed = zlib.compress(data)
|
||||
b64encoded = base64.b64encode(compressed)
|
||||
|
||||
with open('find_rocm_config.py.gz.base64', 'wb') as f:
|
||||
f.write(b64encoded)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
276
third_party/gpus/find_rocm_config.py
vendored
Normal file
276
third_party/gpus/find_rocm_config.py
vendored
Normal file
@ -0,0 +1,276 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Prints ROCm library and header directories and versions found on the system.
|
||||
|
||||
The script searches for ROCm library and header files on the system, inspects
|
||||
them to determine their version and prints the configuration to stdout.
|
||||
The path to inspect is specified through an environment variable (ROCM_PATH).
|
||||
If no valid configuration is found, the script prints to stderr and
|
||||
returns an error code.
|
||||
|
||||
The script takes the directory specified by the ROCM_PATH environment variable.
|
||||
The script looks for headers and library files in a hard-coded set of
|
||||
subdirectories from base path of the specified directory. If ROCM_PATH is not
|
||||
specified, then "/opt/rocm" is used as it default value
|
||||
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
|
||||
class ConfigError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _get_default_rocm_path():
|
||||
return "/opt/rocm"
|
||||
|
||||
|
||||
def _get_rocm_install_path():
|
||||
"""Determines and returns the ROCm installation path."""
|
||||
rocm_install_path = _get_default_rocm_path()
|
||||
if "ROCM_PATH" in os.environ:
|
||||
rocm_install_path = os.environ["ROCM_PATH"]
|
||||
# rocm_install_path = os.path.realpath(rocm_install_path)
|
||||
return rocm_install_path
|
||||
|
||||
|
||||
def _get_composite_version_number(major, minor, patch):
|
||||
return 10000 * major + 100 * minor + patch
|
||||
|
||||
|
||||
def _get_header_version(path, name):
|
||||
"""Returns preprocessor defines in C header file."""
|
||||
for line in io.open(path, "r", encoding="utf-8"):
|
||||
match = re.match(r"#define %s +(\d+)" % name, line)
|
||||
if match:
|
||||
value = match.group(1)
|
||||
return int(value)
|
||||
|
||||
raise ConfigError('#define "{}" is either\n'.format(name) +
|
||||
" not present in file {} OR\n".format(path) +
|
||||
" its value is not an integer literal")
|
||||
|
||||
|
||||
def _find_rocm_config(rocm_install_path):
|
||||
|
||||
def rocm_version_numbers(path):
|
||||
version_file = os.path.join(path, ".info/version-dev")
|
||||
if not os.path.exists(version_file):
|
||||
raise ConfigError('ROCm version file "{}" not found'.format(version_file))
|
||||
version_numbers = []
|
||||
with open(version_file) as f:
|
||||
version_string = f.read().strip()
|
||||
version_numbers = version_string.split(".")
|
||||
major = int(version_numbers[0])
|
||||
minor = int(version_numbers[1])
|
||||
patch = int(version_numbers[2].split("-")[0])
|
||||
return major, minor, patch
|
||||
|
||||
major, minor, patch = rocm_version_numbers(rocm_install_path)
|
||||
|
||||
rocm_config = {
|
||||
"rocm_version_number": _get_composite_version_number(major, minor, patch)
|
||||
}
|
||||
|
||||
return rocm_config
|
||||
|
||||
|
||||
def _find_hipruntime_config(rocm_install_path):
|
||||
|
||||
def hipruntime_version_number(path):
|
||||
version_file = os.path.join(path, "hip/include/hip/hip_version.h")
|
||||
if not os.path.exists(version_file):
|
||||
raise ConfigError(
|
||||
'HIP Runtime version file "{}" not found'.format(version_file))
|
||||
# This header file has an explicit #define for HIP_VERSION, whose value
|
||||
# is (HIP_VERSION_MAJOR * 100 + HIP_VERSION_MINOR)
|
||||
# Retreive the major + minor and re-calculate here, since we do not
|
||||
# want get into the business of parsing arith exprs
|
||||
major = _get_header_version(version_file, "HIP_VERSION_MAJOR")
|
||||
minor = _get_header_version(version_file, "HIP_VERSION_MINOR")
|
||||
return 100 * major + minor
|
||||
|
||||
hipruntime_config = {
|
||||
"hipruntime_version_number": hipruntime_version_number(rocm_install_path)
|
||||
}
|
||||
|
||||
return hipruntime_config
|
||||
|
||||
|
||||
def _find_miopen_config(rocm_install_path):
|
||||
|
||||
def miopen_version_numbers(path):
|
||||
version_file = os.path.join(path, "miopen/include/miopen/version.h")
|
||||
if not os.path.exists(version_file):
|
||||
raise ConfigError(
|
||||
'MIOpen version file "{}" not found'.format(version_file))
|
||||
major = _get_header_version(version_file, "MIOPEN_VERSION_MAJOR")
|
||||
minor = _get_header_version(version_file, "MIOPEN_VERSION_MINOR")
|
||||
patch = _get_header_version(version_file, "MIOPEN_VERSION_PATCH")
|
||||
return major, minor, patch
|
||||
|
||||
major, minor, patch = miopen_version_numbers(rocm_install_path)
|
||||
|
||||
miopen_config = {
|
||||
"miopen_version_number":
|
||||
_get_composite_version_number(major, minor, patch)
|
||||
}
|
||||
|
||||
return miopen_config
|
||||
|
||||
|
||||
def _find_rocblas_config(rocm_install_path):
|
||||
|
||||
def rocblas_version_numbers(path):
|
||||
version_file = os.path.join(path, "rocblas/include/rocblas-version.h")
|
||||
if not os.path.exists(version_file):
|
||||
raise ConfigError(
|
||||
'rocblas version file "{}" not found'.format(version_file))
|
||||
major = _get_header_version(version_file, "ROCBLAS_VERSION_MAJOR")
|
||||
minor = _get_header_version(version_file, "ROCBLAS_VERSION_MINOR")
|
||||
patch = _get_header_version(version_file, "ROCBLAS_VERSION_PATCH")
|
||||
return major, minor, patch
|
||||
|
||||
major, minor, patch = rocblas_version_numbers(rocm_install_path)
|
||||
|
||||
rocblas_config = {
|
||||
"rocblas_version_number":
|
||||
_get_composite_version_number(major, minor, patch)
|
||||
}
|
||||
|
||||
return rocblas_config
|
||||
|
||||
|
||||
def _find_rocrand_config(rocm_install_path):
|
||||
|
||||
def rocrand_version_number(path):
|
||||
version_file = os.path.join(path, "rocrand/include/rocrand_version.h")
|
||||
if not os.path.exists(version_file):
|
||||
raise ConfigError(
|
||||
'rocblas version file "{}" not found'.format(version_file))
|
||||
version_number = _get_header_version(version_file, "ROCRAND_VERSION")
|
||||
return version_number
|
||||
|
||||
rocrand_config = {
|
||||
"rocrand_version_number": rocrand_version_number(rocm_install_path)
|
||||
}
|
||||
|
||||
return rocrand_config
|
||||
|
||||
|
||||
def _find_rocfft_config(rocm_install_path):
|
||||
|
||||
def rocfft_version_numbers(path):
|
||||
version_file = os.path.join(path, "rocfft/include/rocfft-version.h")
|
||||
if not os.path.exists(version_file):
|
||||
raise ConfigError(
|
||||
'rocfft version file "{}" not found'.format(version_file))
|
||||
major = _get_header_version(version_file, "rocfft_version_major")
|
||||
minor = _get_header_version(version_file, "rocfft_version_minor")
|
||||
patch = _get_header_version(version_file, "rocfft_version_patch")
|
||||
return major, minor, patch
|
||||
|
||||
major, minor, patch = rocfft_version_numbers(rocm_install_path)
|
||||
|
||||
rocfft_config = {
|
||||
"rocfft_version_number":
|
||||
_get_composite_version_number(major, minor, patch)
|
||||
}
|
||||
|
||||
return rocfft_config
|
||||
|
||||
|
||||
def _find_roctracer_config(rocm_install_path):
|
||||
|
||||
def roctracer_version_numbers(path):
|
||||
version_file = os.path.join(path, "roctracer/include/roctracer.h")
|
||||
if not os.path.exists(version_file):
|
||||
raise ConfigError(
|
||||
'roctracer version file "{}" not found'.format(version_file))
|
||||
major = _get_header_version(version_file, "ROCTRACER_VERSION_MAJOR")
|
||||
minor = _get_header_version(version_file, "ROCTRACER_VERSION_MINOR")
|
||||
# roctracer header does not have a patch version number
|
||||
patch = 0
|
||||
return major, minor, patch
|
||||
|
||||
major, minor, patch = roctracer_version_numbers(rocm_install_path)
|
||||
|
||||
roctracer_config = {
|
||||
"roctracer_version_number":
|
||||
_get_composite_version_number(major, minor, patch)
|
||||
}
|
||||
|
||||
return roctracer_config
|
||||
|
||||
|
||||
def _find_hipsparse_config(rocm_install_path):
|
||||
|
||||
def hipsparse_version_numbers(path):
|
||||
version_file = os.path.join(path, "hipsparse/include/hipsparse-version.h")
|
||||
if not os.path.exists(version_file):
|
||||
raise ConfigError(
|
||||
'hipsparse version file "{}" not found'.format(version_file))
|
||||
major = _get_header_version(version_file, "hipsparseVersionMajor")
|
||||
minor = _get_header_version(version_file, "hipsparseVersionMinor")
|
||||
patch = _get_header_version(version_file, "hipsparseVersionPatch")
|
||||
return major, minor, patch
|
||||
|
||||
major, minor, patch = hipsparse_version_numbers(rocm_install_path)
|
||||
|
||||
hipsparse_config = {
|
||||
"hipsparse_version_number":
|
||||
_get_composite_version_number(major, minor, patch)
|
||||
}
|
||||
|
||||
return hipsparse_config
|
||||
|
||||
|
||||
def find_rocm_config():
|
||||
"""Returns a dictionary of ROCm components config info."""
|
||||
rocm_install_path = _get_rocm_install_path()
|
||||
if not os.path.exists(rocm_install_path):
|
||||
raise ConfigError(
|
||||
'Specified ROCM_PATH "{}" does not exist'.format(rocm_install_path))
|
||||
|
||||
result = {}
|
||||
|
||||
result["rocm_toolkit_path"] = rocm_install_path
|
||||
result.update(_find_rocm_config(rocm_install_path))
|
||||
result.update(_find_hipruntime_config(rocm_install_path))
|
||||
result.update(_find_miopen_config(rocm_install_path))
|
||||
# result.update(_find_rocblas_config(rocm_install_path))
|
||||
# result.update(_find_rocrand_config(rocm_install_path))
|
||||
# result.update(_find_rocfft_config(rocm_install_path))
|
||||
# result.update(_find_roctracer_config(rocm_install_path))
|
||||
# result.update(_find_hipsparse_config(rocm_install_path))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
for key, value in sorted(find_rocm_config().items()):
|
||||
print("%s: %s" % (key, value))
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\nERROR: {}\n\n".format(str(e)))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
third_party/gpus/find_rocm_config.py.gz.base64
vendored
Normal file
1
third_party/gpus/find_rocm_config.py.gz.base64
vendored
Normal file
@ -0,0 +1 @@
|
||||
eJzNWlFv2zgSftevIBQUlTaOkvbpkEMevGkW9V2bBHZ2F4smMGiJtrmVRS1JxTWK/vebISlZkuXEiZPiDASxxJlvhjPfDCnRB+Rc5CvJZ3NN3p+8PyE3c0ZuWKaE/C0VS9Iv9FxIFZF+mpIhiikyZIrJe5ZE3oF3QD7xGMRZQoosYZJo0O/nNIZ/bqRH/mBScZGR99EJCVDAd0N++G9AWImCLOiKZEKTQjGA4IpMecoI+xazXBOekVgs8pTTLGZkyfXcmHEg4Ab5y0GIiaYgTUE+h6tpXY5QbRzGz1zr/PT4eLlcRtQ4Gwk5O06toDr+NDi/uBxdHIHDRuX3LGVKEcn+KbiEqU5WhObgT0wn4GVKl0RIQmeSwZgW6O9Scs2zWY8oMdVLKhmgJFxpySeFbgSr9A7mXBeAcNGM+P0RGYx88mt/NBj1AOPPwc3Hq99vyJ/94bB/eTO4GJGrITm/uvwwuBlcXcLVb6R/+Rf57+DyQ48wCBWYYd9yif6DkxzDaFJHRow1HJgK65DKWcynPIZ5ZbOCzhiZiXsmM5gOyZlccIXJVOBeAigpX3BNtbmzMSk0c/aiH8/3/WvJM6Th1fkCzE8klSt0hswZRfsJpCjWQnJmfCT3ln1AKQEOYmDNLFdKs0XkeUh4FUsOPFOMSuCCMqHYBo/EVE2UHmQco6aVBzcXSIGEaQxVZkLMZemEAcqt/6gfi2zKZ4U0AUQ9pRNR6Mh4lVMkuijBkSEuN0izuRTFbI4kYdk9lyJbsEyTeyq5IWUA/n8eX/dvPoaRN5hCccFYypOWSe7C0rPTsXEoHTTuMClNqiXThTRpJ3ALAhSLhDXjp+lXZudV5mBV8xiKBocqvzr9jup4qRBfbTJs7G0+y5zYRJhqn1OZHKE/CeRQQ917qpjUeTCVYkEmVLmgusaw9q3yNyIQq7WLEB7oSl4laMIEZXkscn0sRbzwUaTA9kfBFw15n9IixfmkBfOQrZ4HNSckpE+U34Qqv0FfcN+ASZ7nxSmFOj03KbrAKAcXpgVCqsJTj4D3CsXAChnPmB47c2N0ZYxTC4yYzVXdzbqSEQZSaZqmNSXw9UPJWhvpMuUubQvilCx1UDPCCYK9NiI52+ofiPMp8asY+5hCoSLHBvSkG3At86WmfQfyB9vkjYuS0dSY3hAK16HaGKsHDBcfobhmY1fG46xYTJgMFvRvIXsEIob/QC2e1+P/7gQ+5BdixMghXuMVSsOVEa+bsSQvbQToRY9kdMHK9AxdPqCb5+AwtHQAAnWTMIjieb1HudRg9aTYh2Cci0jkrET2pQ8rRAZlA539zC/09Ohffmjjv0DfIIaSReZrIP0Da4i8UeQwuE0OQ5+8Md71DH5o9CCzRt6iEFsGgGNuRjNoWXnwLnSDLkrQawIjF3oYOsqhSusV8La07H//YcrNrmq32dsIJgfIgYkROXSwzY9PzL4CF0DsMxAFs7P4/gPWzdvMLyEMHbZDcGiHdi62I2AXBL/ZjGF0oWho6odlLsHZxDLe9toO4p3iVFHYDDVZpQInYgLohozTa07/LXiVxohnU3HsBI8Sdu9XuUBPSxX2DXYXKqgDhmWaOoJu6r1ctYx1E35ENAtGFfwGYNhw2s0H/P5yZwbM1s1QsKGFvXNaUcaN4FYIdhxnZIoVnARhhHfyIGzJrY00NSMFOx0d+JELhy3CM0u3pu6XkzsnY0qzW+adk8ldaXTJvL8rrR75YYXqeN7RLJAFHbex8Lp40dHAyuZrmQaK3114/A4E//QZ/QzwfnitRmmtNfg+57ksMs0XbAfW14RbHjyN+oBzzLM4LRJ2jN/hrwSM5i9RBrV+8Pbj4JoMrdfPLYwDeLyCBlLr0rB1sVuqb/g4ARuIstlh3waL4z8uhiPY2PfIci7AObuvsFiAFNRExp/7/4FngV/MMnNIGiODy6th6QKsIpLxe7v5L5cmy3y76B/FNI0LWOXBOyahvytunrtgWyfMdsjiLCn0U6ATVoIwYJNC4VKkcH+VU5g91C/s66Dm8QlENcqwa8WrRwyyuzE3v1mlT4XAIPiNknzXWJ4NLlJ0g8v1utrKXaiu7bzu3HvU62rDZqO6Fhy75g6V5QT3W1EsSFVZ7vLVCuvz4Arwn1tTTyAUGLq+uNyXU22UGq3K7v10FNjLnn/0n79ebMl794rRYFOd250o/mktWXuvHw3b7R3TBB5/dts0Gcn9WO5QKpq766NX47kz8BOIDpu3Xz/1R/syfQPmeVRvw+zL9W3p37o9qtGqtUPqwHlZujett/kuYcHdje9Gcq+9kgOp070O+n9H9+Zkd2basH/5oWRak2NNQEeNWgZa1OiIOKzwW1Lx6PLeNNXmwXSqd6MBCu7d9QCkzgK4fNWeB/g/oeW1gmM0n97x2iio+fSG10Ixmvv1u668b213aza1KL2J8uLNbm27zXEtaQwh24nmTnZvplucOtntnVfhuYX+Oav7zbB/fjF8gfW9DVRb4Q/WmaiOVgSzr7/mFJ4fqWNoOWPXWOvFcrIP6bewYCvvGwxrUb8T68XZ3/Cg/XpE4SPxjm9HnOx+BVDh1N+R2Duv1/ArEz+hECpb7oD58/Oa/gbM87p+G+Z6z7a/nQbdFdCmWOulRSfWy1ZA2wNXARvvw9vnGZQkPMYzJTzWE1N71mQcyRgeRLr54Hvuxw6cOs62vG3s7irBRwj+dlQdGa4PCQ23q9ZowCuCb9qw3YopPCWEDP1YX36x72y1EOlXro20f1e+C26cTZUaUZEnVLNglyOHcIvWLi9ut+k+9loK9Q62+frww/6Dqg8/Nz2o+uBW+0HNxzYw25V3aP5hfSExEK54FhQauiGmlivLT3w1/JWteuWJVEaUkJolwWahRVDICxWEVfM2R/uB/0adkjcKT/CCNZKZgfvJTY39eDrDrL5aqcj+JCDCH7iwwL/NLobDq+EpEPk2q52nKS0DAAwrNSgLjUd/ngfFOB7jkd14TM7OiD8e4xzHY9OL7HS9/wFaLmEH
|
4
third_party/gpus/rocm/rocm_config.h.tpl
vendored
4
third_party/gpus/rocm/rocm_config.h.tpl
vendored
@ -18,4 +18,8 @@ limitations under the License.
|
||||
|
||||
#define TF_ROCM_TOOLKIT_PATH "%{rocm_toolkit_path}"
|
||||
|
||||
#define TF_ROCM_VERSION %{rocm_version_number}
|
||||
#define TF_MIOPEN_VERSION %{miopen_version_number}
|
||||
#define TF_HIPRUNTIME_VERSION %{hipruntime_version_number}
|
||||
|
||||
#endif // ROCM_ROCM_CONFIG_H_
|
||||
|
83
third_party/gpus/rocm_configure.bzl
vendored
83
third_party/gpus/rocm_configure.bzl
vendored
@ -4,11 +4,7 @@
|
||||
|
||||
* `TF_NEED_ROCM`: Whether to enable building with ROCm.
|
||||
* `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
|
||||
* `ROCM_TOOLKIT_PATH`: The path to the ROCm toolkit. Default is
|
||||
`/opt/rocm`.
|
||||
* `TF_ROCM_VERSION`: The version of the ROCm toolkit. If this is blank, then
|
||||
use the system default.
|
||||
* `TF_MIOPEN_VERSION`: The version of the MIOpen library.
|
||||
* `ROCM_PATH`: The path to the ROCm toolkit. Default is `/opt/rocm`.
|
||||
* `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets.
|
||||
"""
|
||||
|
||||
@ -27,6 +23,7 @@ load(
|
||||
"get_bash_bin",
|
||||
"get_cpu_value",
|
||||
"get_host_environ",
|
||||
"get_python_bin",
|
||||
"raw_exec",
|
||||
"realpath",
|
||||
"which",
|
||||
@ -35,13 +32,9 @@ load(
|
||||
_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
|
||||
_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
|
||||
_ROCM_TOOLKIT_PATH = "ROCM_PATH"
|
||||
_TF_ROCM_VERSION = "TF_ROCM_VERSION"
|
||||
_TF_MIOPEN_VERSION = "TF_MIOPEN_VERSION"
|
||||
_TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS"
|
||||
_TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO"
|
||||
|
||||
_DEFAULT_ROCM_VERSION = ""
|
||||
_DEFAULT_MIOPEN_VERSION = ""
|
||||
_DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm"
|
||||
|
||||
def verify_build_defines(params):
|
||||
@ -212,20 +205,6 @@ def _enable_rocm(repository_ctx):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _rocm_toolkit_path(repository_ctx, bash_bin):
|
||||
"""Finds the rocm toolkit directory.
|
||||
|
||||
Args:
|
||||
repository_ctx: The repository context.
|
||||
|
||||
Returns:
|
||||
A speculative real path of the rocm toolkit install directory.
|
||||
"""
|
||||
rocm_toolkit_path = get_host_environ(repository_ctx, _ROCM_TOOLKIT_PATH, _DEFAULT_ROCM_TOOLKIT_PATH)
|
||||
if files_exist(repository_ctx, [rocm_toolkit_path], bash_bin) != [True]:
|
||||
auto_configure_fail("Cannot find rocm toolkit path.")
|
||||
return rocm_toolkit_path
|
||||
|
||||
def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin):
|
||||
"""Returns a list of strings representing AMDGPU targets."""
|
||||
amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS)
|
||||
@ -402,7 +381,40 @@ def _find_libs(repository_ctx, rocm_config, bash_bin):
|
||||
|
||||
return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin)
|
||||
|
||||
def _get_rocm_config(repository_ctx, bash_bin):
|
||||
def _exec_find_rocm_config(repository_ctx, script_path):
|
||||
python_bin = get_python_bin(repository_ctx)
|
||||
|
||||
# If used with remote execution then repository_ctx.execute() can't
|
||||
# access files from the source tree. A trick is to read the contents
|
||||
# of the file in Starlark and embed them as part of the command. In
|
||||
# this case the trick is not sufficient as the find_cuda_config.py
|
||||
# script has more than 8192 characters. 8192 is the command length
|
||||
# limit of cmd.exe on Windows. Thus we additionally need to compress
|
||||
# the contents locally and decompress them as part of the execute().
|
||||
compressed_contents = repository_ctx.read(script_path)
|
||||
decompress_and_execute_cmd = (
|
||||
"from zlib import decompress;" +
|
||||
"from base64 import b64decode;" +
|
||||
"from os import system;" +
|
||||
"script = decompress(b64decode('%s'));" % compressed_contents +
|
||||
"f = open('script.py', 'wb');" +
|
||||
"f.write(script);" +
|
||||
"f.close();" +
|
||||
"system('\"%s\" script.py');" % (python_bin)
|
||||
)
|
||||
|
||||
return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])
|
||||
|
||||
def find_rocm_config(repository_ctx, script_path):
|
||||
"""Returns ROCm config dictionary from running find_rocm_config.py"""
|
||||
exec_result = _exec_find_rocm_config(repository_ctx, script_path)
|
||||
if exec_result.return_code:
|
||||
auto_configure_fail("Failed to run find_rocm_config.py: %s" % err_out(exec_result))
|
||||
|
||||
# Parse the dict from stdout.
|
||||
return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()])
|
||||
|
||||
def _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script):
|
||||
"""Detects and returns information about the ROCm installation on the system.
|
||||
|
||||
Args:
|
||||
@ -413,11 +425,21 @@ def _get_rocm_config(repository_ctx, bash_bin):
|
||||
A struct containing the following fields:
|
||||
rocm_toolkit_path: The ROCm toolkit installation directory.
|
||||
amdgpu_targets: A list of the system's AMDGPU targets.
|
||||
rocm_version_number: The version of ROCm on the system.
|
||||
miopen_version_number: The version of MIOpen on the system.
|
||||
hipruntime_version_number: The version of HIP Runtime on the system.
|
||||
"""
|
||||
rocm_toolkit_path = _rocm_toolkit_path(repository_ctx, bash_bin)
|
||||
config = find_rocm_config(repository_ctx, find_rocm_config_script)
|
||||
rocm_toolkit_path = config["rocm_toolkit_path"]
|
||||
rocm_version_number = config["rocm_version_number"]
|
||||
miopen_version_number = config["miopen_version_number"]
|
||||
hipruntime_version_number = config["hipruntime_version_number"]
|
||||
return struct(
|
||||
rocm_toolkit_path = rocm_toolkit_path,
|
||||
amdgpu_targets = _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin),
|
||||
rocm_toolkit_path = rocm_toolkit_path,
|
||||
rocm_version_number = rocm_version_number,
|
||||
miopen_version_number = miopen_version_number,
|
||||
hipruntime_version_number = hipruntime_version_number,
|
||||
)
|
||||
|
||||
def _tpl_path(repository_ctx, labelname):
|
||||
@ -550,8 +572,10 @@ def _create_local_rocm_repository(repository_ctx):
|
||||
"rocm:rocm_config.h",
|
||||
]}
|
||||
|
||||
find_rocm_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_rocm_config.py.gz.base64"))
|
||||
|
||||
bash_bin = get_bash_bin(repository_ctx)
|
||||
rocm_config = _get_rocm_config(repository_ctx, bash_bin)
|
||||
rocm_config = _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script)
|
||||
|
||||
# Copy header and library files to execroot.
|
||||
# rocm_toolkit_path
|
||||
@ -749,6 +773,9 @@ def _create_local_rocm_repository(repository_ctx):
|
||||
["\"%s\"" % c for c in rocm_config.amdgpu_targets],
|
||||
),
|
||||
"%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path,
|
||||
"%{rocm_version_number}": rocm_config.rocm_version_number,
|
||||
"%{miopen_version_number}": rocm_config.miopen_version_number,
|
||||
"%{hipruntime_version_number}": rocm_config.hipruntime_version_number,
|
||||
},
|
||||
)
|
||||
|
||||
@ -813,8 +840,6 @@ _ENVIRONS = [
|
||||
_GCC_HOST_COMPILER_PREFIX,
|
||||
"TF_NEED_ROCM",
|
||||
_ROCM_TOOLKIT_PATH,
|
||||
_TF_ROCM_VERSION,
|
||||
_TF_MIOPEN_VERSION,
|
||||
_TF_ROCM_AMDGPU_TARGETS,
|
||||
]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user