diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 51583117706..452a7537c7a 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -67,6 +67,10 @@ limitations under the License. #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/util/env_var.h" +#if TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#endif + namespace xla { namespace gpu { namespace { @@ -560,11 +564,18 @@ namespace { static std::vector GetROCDLPaths(int amdgpu_version, const string& rocdl_dir_path) { // AMDGPU version-neutral bitcodes. +#if TF_ROCM_VERSION >= 30900 + static std::vector* rocdl_filenames = new std::vector( + {"hc.bc", "opencl.bc", "ocml.bc", "ockl.bc", "oclc_finite_only_off.bc", + "oclc_daz_opt_off.bc", "oclc_correctly_rounded_sqrt_on.bc", + "oclc_unsafe_math_off.bc", "oclc_wavefrontsize64_on.bc"}); +#else static std::vector* rocdl_filenames = new std::vector( {"hc.amdgcn.bc", "opencl.amdgcn.bc", "ocml.amdgcn.bc", "ockl.amdgcn.bc", "oclc_finite_only_off.amdgcn.bc", "oclc_daz_opt_off.amdgcn.bc", "oclc_correctly_rounded_sqrt_on.amdgcn.bc", "oclc_unsafe_math_off.amdgcn.bc", "oclc_wavefrontsize64_on.amdgcn.bc"}); +#endif // Construct full path to ROCDL bitcode libraries. std::vector result; @@ -575,7 +586,11 @@ static std::vector GetROCDLPaths(int amdgpu_version, // Add AMDGPU version-specific bitcodes. 
result.push_back(tensorflow::io::JoinPath( rocdl_dir_path, +#if TF_ROCM_VERSION >= 30900 + absl::StrCat("oclc_isa_version_", amdgpu_version, ".bc"))); +#else absl::StrCat("oclc_isa_version_", amdgpu_version, ".amdgcn.bc"))); +#endif return result; } diff --git a/tensorflow/core/kernels/conv_2d_gpu.h b/tensorflow/core/kernels/conv_2d_gpu.h index 1ed88ca753c..67126f31e27 100644 --- a/tensorflow/core/kernels/conv_2d_gpu.h +++ b/tensorflow/core/kernels/conv_2d_gpu.h @@ -287,7 +287,7 @@ __global__ void SwapDimension1And2InTensor3UsingTiles( // One extra line in the inner dimension to avoid share memory bank conflict. // This is to mimic the following, but no constructor of T can be invoked. // __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1]; -#if GOOGLE_CUDA // || TENSORFLOW_COMPILER_IS_HIP_CLANG +#if GOOGLE_CUDA __shared__ __align__( alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)]; typedef T(*SharedMemoryTile)[TileSizeJ + 1]; diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h index fc439a08df1..8a208610b07 100644 --- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h +++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h @@ -387,7 +387,7 @@ __global__ __launch_bounds__(1024) void ColumnReduceKernel( // - = // = const int numRowsThisBlock = - min(blockDim.y, num_rows - blockIdx.y * blockDim.y); + min(int(blockDim.y), num_rows - blockIdx.y * blockDim.y); for (int row = 1; row < numRowsThisBlock; ++row) { value_type t = partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1) + row]; diff --git a/tensorflow/core/kernels/scan_ops_gpu.h b/tensorflow/core/kernels/scan_ops_gpu.h index f99f8af3190..7914b7a1103 100644 --- a/tensorflow/core/kernels/scan_ops_gpu.h +++ b/tensorflow/core/kernels/scan_ops_gpu.h @@ -248,10 +248,8 @@ void LaunchScan(const GPUDevice& d, typename TTypes::ConstTensor in, int num_blocks = dimx * dimz; int ideal_block_size = dimy / items_per_thread; -#if 
TENSORFLOW_COMPILER_IS_HIP_CLANG const int rocm_threads_per_warp = 64; ideal_block_size = std::max(ideal_block_size, rocm_threads_per_warp); -#endif // There seems to be a bug when the type is not float and block_size 1024. // Launch on the smallest power of 2 block size that we can. diff --git a/tensorflow/core/platform/default/rocm_rocdl_path.cc b/tensorflow/core/platform/default/rocm_rocdl_path.cc index 9e9261d26c8..7e43286897c 100644 --- a/tensorflow/core/platform/default/rocm_rocdl_path.cc +++ b/tensorflow/core/platform/default/rocm_rocdl_path.cc @@ -36,10 +36,10 @@ string RocmRoot() { } string RocdlRoot() { -#if TENSORFLOW_COMPILER_IS_HIP_CLANG - return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "lib"); +#if TF_ROCM_VERSION >= 30900 + return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "amdgcn/bitcode"); #else - return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "hcc/lib"); + return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "lib"); #endif } diff --git a/tensorflow/core/platform/rocm_rocdl_path_test.cc b/tensorflow/core/platform/rocm_rocdl_path_test.cc index 166e99bb509..ae42ab4e6c6 100644 --- a/tensorflow/core/platform/rocm_rocdl_path_test.cc +++ b/tensorflow/core/platform/rocm_rocdl_path_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/core/platform/rocm_rocdl_path.h" +#include "rocm/rocm_config.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/path.h" @@ -27,7 +28,11 @@ TEST(RocmRocdlPathTest, ROCDLPath) { VLOG(2) << "ROCm-Device-Libs root = " << RocdlRoot(); std::vector rocdl_files; TF_EXPECT_OK(Env::Default()->GetMatchingPaths( +#if TF_ROCM_VERSION >= 30900 + io::JoinPath(RocdlRoot(), "*.bc"), &rocdl_files)); +#else io::JoinPath(RocdlRoot(), "*.amdgcn.bc"), &rocdl_files)); +#endif EXPECT_LT(0, rocdl_files.size()); } #endif diff --git a/tensorflow/core/util/gpu_launch_config.h b/tensorflow/core/util/gpu_launch_config.h index 4c2df39e1a2..0b943e917da 100644 --- a/tensorflow/core/util/gpu_launch_config.h +++ b/tensorflow/core/util/gpu_launch_config.h @@ -168,25 +168,10 @@ GpuLaunchConfig GetGpuLaunchConfig(int work_element_count, block_size_limit); CHECK_EQ(err, cudaSuccess); #elif TENSORFLOW_USE_ROCM -#if TENSORFLOW_COMPILER_IS_HIP_CLANG hipError_t err = hipOccupancyMaxPotentialBlockSize( &block_count, &thread_per_block, func, dynamic_shared_memory_size, block_size_limit); CHECK_EQ(err, hipSuccess); -#else - // Earlier versions of this HIP routine incorrectly returned void. - // TODO re-enable hipError_t error checking when HIP is fixed. 
- // ROCm interface uses unsigned int, convert after checking - uint32_t block_count_uint = 0; - uint32_t thread_per_block_uint = 0; - CHECK_GE(block_size_limit, 0); - uint32_t block_size_limit_uint = static_cast(block_size_limit); - hipOccupancyMaxPotentialBlockSize(&block_count_uint, &thread_per_block_uint, - func, dynamic_shared_memory_size, - block_size_limit_uint); - block_count = static_cast(block_count_uint); - thread_per_block = static_cast(thread_per_block_uint); -#endif #endif block_count = @@ -216,22 +201,9 @@ GpuLaunchConfig GetGpuLaunchConfigFixedBlockSize( &block_count, func, fixed_block_size, dynamic_shared_memory_size); CHECK_EQ(err, cudaSuccess); #elif TENSORFLOW_USE_ROCM -#if TENSORFLOW_COMPILER_IS_HIP_CLANG hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor( &block_count, func, fixed_block_size, dynamic_shared_memory_size); CHECK_EQ(err, hipSuccess); -#else - // Apply the heuristic in GetGpuLaunchConfig(int, const Eigen::GpuDevice&) - // that the kernel is quite simple and will largely be memory-limited. - const int physical_thread_count = std::min( - d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(), - work_element_count); - // Assume the kernel be simple enough that it is okay to use 1024 threads - // per workgroup. 
- int thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock()); - block_count = std::min(DivUp(physical_thread_count, thread_per_block), - d.getNumGpuMultiProcessors()); -#endif #endif block_count = std::min(block_count * d.getNumGpuMultiProcessors(), DivUp(work_element_count, fixed_block_size)); diff --git a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc index 2a85cb820ed..dbab0304d82 100644 --- a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc +++ b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc @@ -856,6 +856,11 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { float clock_rate_ghz = static_cast(prop.clockRate) / 1e6; builder.set_clock_rate_ghz(clock_rate_ghz); + + // mem_bandwidth = 2 * mem_bus_width_in_bytes * mem_clock_rate_in_hz + int64 memory_bandwidth = 2 * (int64(prop.memoryBusWidth) / 8) * + (int64(prop.memoryClockRate) * 1000); + builder.set_memory_bandwidth(memory_bandwidth); } { diff --git a/tensorflow/tools/ci_build/Dockerfile.rocm b/tensorflow/tools/ci_build/Dockerfile.rocm index a72915504be..89293c54e4a 100644 --- a/tensorflow/tools/ci_build/Dockerfile.rocm +++ b/tensorflow/tools/ci_build/Dockerfile.rocm @@ -3,10 +3,10 @@ FROM ubuntu:bionic MAINTAINER Jeff Poznanovic -ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/3.7/ +ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/3.9/ ARG ROCM_BUILD_NAME=xenial ARG ROCM_BUILD_NUM=main -ARG ROCM_PATH=/opt/rocm-3.7.0 +ARG ROCM_PATH=/opt/rocm-3.9.0 ENV DEBIAN_FRONTEND noninteractive ENV TF_NEED_ROCM 1 diff --git a/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh b/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh index 92d21cb133b..44f60b53070 100755 --- a/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh +++ b/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh @@ -28,7 +28,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS} echo "" # First positional argument (if any) specifies the 
ROCM_INSTALL_DIR -ROCM_INSTALL_DIR=/opt/rocm-3.7.0 +ROCM_INSTALL_DIR=/opt/rocm-3.9.0 if [[ -n $1 ]]; then ROCM_INSTALL_DIR=$1 fi diff --git a/tensorflow/tools/ci_build/linux/rocm/run_csb_tests.sh b/tensorflow/tools/ci_build/linux/rocm/run_csb_tests.sh index 80c0686e647..f6ed1bef84f 100755 --- a/tensorflow/tools/ci_build/linux/rocm/run_csb_tests.sh +++ b/tensorflow/tools/ci_build/linux/rocm/run_csb_tests.sh @@ -28,7 +28,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS} echo "" # First positional argument (if any) specifies the ROCM_INSTALL_DIR -ROCM_INSTALL_DIR=/opt/rocm-3.7.0 +ROCM_INSTALL_DIR=/opt/rocm-3.9.0 if [[ -n $1 ]]; then ROCM_INSTALL_DIR=$1 fi diff --git a/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh b/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh index 3a09081dd6a..586ec1520ad 100755 --- a/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh +++ b/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh @@ -28,7 +28,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS} echo "" # First positional argument (if any) specifies the ROCM_INSTALL_DIR -ROCM_INSTALL_DIR=/opt/rocm-3.7.0 +ROCM_INSTALL_DIR=/opt/rocm-3.9.0 if [[ -n $1 ]]; then ROCM_INSTALL_DIR=$1 fi diff --git a/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh b/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh index d623b77d533..dc9a8b50ee1 100755 --- a/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh +++ b/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh @@ -28,7 +28,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS} echo "" # First positional argument (if any) specifies the ROCM_INSTALL_DIR -ROCM_INSTALL_DIR=/opt/rocm-3.7.0 +ROCM_INSTALL_DIR=/opt/rocm-3.9.0 if [[ -n $1 ]]; then ROCM_INSTALL_DIR=$1 fi diff --git a/third_party/gpus/compress_find_rocm_config.py b/third_party/gpus/compress_find_rocm_config.py new file mode 100644 index 00000000000..90615d4b1ea --- /dev/null +++ 
b/third_party/gpus/compress_find_rocm_config.py @@ -0,0 +1,36 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Compresses the contents of 'find_rocm_config.py'. + +The compressed file is what is actually being used. It works around remote +config not being able to upload files yet. +""" +import base64 +import zlib + + +def main(): + with open('find_rocm_config.py', 'rb') as f: + data = f.read() + + compressed = zlib.compress(data) + b64encoded = base64.b64encode(compressed) + + with open('find_rocm_config.py.gz.base64', 'wb') as f: + f.write(b64encoded) + + +if __name__ == '__main__': + main() diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index d5bfe78c644..161bc7c8df4 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -26,12 +26,9 @@ import pipes # Template values set by rocm_configure.bzl. 
CPU_COMPILER = ('%{cpu_compiler}') -GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}') HIPCC_PATH = '%{hipcc_path}' -PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH) HIPCC_ENV = '%{hipcc_env}' -HIPCC_IS_HIPCLANG = '%{hipcc_is_hipclang}'=="True" HIP_RUNTIME_PATH = '%{hip_runtime_path}' HIP_RUNTIME_LIBRARY = '%{hip_runtime_library}' ROCR_RUNTIME_PATH = '%{rocr_runtime_path}' @@ -98,27 +95,6 @@ def GetHostCompilerOptions(argv): return opts -def GetHipccOptions(argv): - """Collect the -hipcc_options values from argv. - - Args: - argv: A list of strings, possibly the argv passed to main(). - - Returns: - The string that can be passed directly to hipcc. - """ - - parser = ArgumentParser() - parser.add_argument('-hipcc_options', nargs='*', action='append') - - args, _ = parser.parse_known_args(argv) - - if args.hipcc_options: - options = _update_options(sum(args.hipcc_options, [])) - return ' '.join(['--'+a for a in options]) - return '' - - def system(cmd): """Invokes cmd with os.system(). @@ -148,7 +124,6 @@ def InvokeHipcc(argv, log=False): """ host_compiler_options = GetHostCompilerOptions(argv) - hipcc_compiler_options = GetHipccOptions(argv) opt_option = GetOptionValue(argv, 'O') m_options = GetOptionValue(argv, 'm') m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']]) @@ -193,14 +168,13 @@ def InvokeHipcc(argv, log=False): # Otherwise, we get build error. # Also we need to retain warning about uninitialised shared variable as # warning only, even when -Werror option is specified. - if HIPCC_IS_HIPCLANG: - hipccopts += ' --include=hip/hip_runtime.h ' - hipccopts += ' ' + hipcc_compiler_options + hipccopts += ' --include=hip/hip_runtime.h ' # Use -fno-gpu-rdc by default for early GPU kernel finalization # This flag would trigger GPU kernels be generated at compile time, instead # of link time. This allows the default host compiler (gcc) be used as the # linker for TensorFlow on ROCm platform. 
hipccopts += ' -fno-gpu-rdc ' + hipccopts += ' -fcuda-flush-denormals-to-zero ' hipccopts += undefines hipccopts += defines hipccopts += std_options @@ -211,22 +185,19 @@ def InvokeHipcc(argv, log=False): depfile = depfiles[0] cmd = (HIPCC_PATH + ' ' + hipccopts + host_compiler_options + - ' ' + GCC_HOST_COMPILER_PATH + ' -I .' + includes + ' ' + srcs + ' -M -o ' + depfile) + cmd = HIPCC_ENV.replace(';', ' ') + ' ' + cmd if log: Log(cmd) + if VERBOSE: print(cmd) exit_status = os.system(cmd) if exit_status != 0: return exit_status cmd = (HIPCC_PATH + ' ' + hipccopts + host_compiler_options + ' -fPIC' + - ' ' + GCC_HOST_COMPILER_PATH + ' -I .' + opt + includes + ' -c ' + srcs + out) - # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'. - # Need to investigate and fix. - cmd = 'PATH=' + PREFIX_DIR + ':$PATH '\ - + HIPCC_ENV.replace(';', ' ') + ' '\ + cmd = HIPCC_ENV.replace(';', ' ') + ' '\ + cmd if log: Log(cmd) if VERBOSE: print(cmd) @@ -268,8 +239,7 @@ def main(): gpu_linker_flags.append('-L' + HIP_RUNTIME_PATH) gpu_linker_flags.append('-Wl,-rpath=' + HIP_RUNTIME_PATH) gpu_linker_flags.append('-l' + HIP_RUNTIME_LIBRARY) - if HIPCC_IS_HIPCLANG: - gpu_linker_flags.append("-lrt") + gpu_linker_flags.append("-lrt") if VERBOSE: print(' '.join([CPU_COMPILER] + gpu_linker_flags)) return subprocess.call([CPU_COMPILER] + gpu_linker_flags) diff --git a/third_party/gpus/find_rocm_config.py b/third_party/gpus/find_rocm_config.py new file mode 100644 index 00000000000..c1eb119612b --- /dev/null +++ b/third_party/gpus/find_rocm_config.py @@ -0,0 +1,286 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Prints ROCm library and header directories and versions found on the system. + +The script searches for ROCm library and header files on the system, inspects +them to determine their version and prints the configuration to stdout. +The path to inspect is specified through an environment variable (ROCM_PATH). +If no valid configuration is found, the script prints to stderr and +returns an error code. + +The script takes the directory specified by the ROCM_PATH environment variable. +The script looks for headers and library files in a hard-coded set of +subdirectories from base path of the specified directory. 
If ROCM_PATH is not +specified, then "/opt/rocm" is used as it default value + +""" + +import io +import os +import re +import sys + + +class ConfigError(Exception): + pass + + +def _get_default_rocm_path(): + return "/opt/rocm" + + +def _get_rocm_install_path(): + """Determines and returns the ROCm installation path.""" + rocm_install_path = _get_default_rocm_path() + if "ROCM_PATH" in os.environ: + rocm_install_path = os.environ["ROCM_PATH"] + # rocm_install_path = os.path.realpath(rocm_install_path) + return rocm_install_path + + +def _get_composite_version_number(major, minor, patch): + return 10000 * major + 100 * minor + patch + + +def _get_header_version(path, name): + """Returns preprocessor defines in C header file.""" + for line in io.open(path, "r", encoding="utf-8"): + match = re.match(r"#define %s +(\d+)" % name, line) + if match: + value = match.group(1) + return int(value) + + raise ConfigError('#define "{}" is either\n'.format(name) + + " not present in file {} OR\n".format(path) + + " its value is not an integer literal") + + +def _find_rocm_config(rocm_install_path): + + def rocm_version_numbers(path): + version_file = os.path.join(path, ".info/version-dev") + if not os.path.exists(version_file): + raise ConfigError('ROCm version file "{}" not found'.format(version_file)) + version_numbers = [] + with open(version_file) as f: + version_string = f.read().strip() + version_numbers = version_string.split(".") + major = int(version_numbers[0]) + minor = int(version_numbers[1]) + patch = int(version_numbers[2].split("-")[0]) + return major, minor, patch + + major, minor, patch = rocm_version_numbers(rocm_install_path) + + rocm_config = { + "rocm_version_number": _get_composite_version_number(major, minor, patch) + } + + return rocm_config + + +def _find_hipruntime_config(rocm_install_path): + + def hipruntime_version_number(path): + version_file = os.path.join(path, "hip/include/hip/hip_version.h") + if not os.path.exists(version_file): + raise 
ConfigError( + 'HIP Runtime version file "{}" not found'.format(version_file)) + # This header file has an explicit #define for HIP_VERSION, whose value + # is (HIP_VERSION_MAJOR * 100 + HIP_VERSION_MINOR) + # Retreive the major + minor and re-calculate here, since we do not + # want get into the business of parsing arith exprs + major = _get_header_version(version_file, "HIP_VERSION_MAJOR") + minor = _get_header_version(version_file, "HIP_VERSION_MINOR") + return 100 * major + minor + + hipruntime_config = { + "hipruntime_version_number": hipruntime_version_number(rocm_install_path) + } + + return hipruntime_config + + +def _find_miopen_config(rocm_install_path): + + def miopen_version_numbers(path): + version_file = os.path.join(path, "miopen/include/miopen/version.h") + if not os.path.exists(version_file): + raise ConfigError( + 'MIOpen version file "{}" not found'.format(version_file)) + major = _get_header_version(version_file, "MIOPEN_VERSION_MAJOR") + minor = _get_header_version(version_file, "MIOPEN_VERSION_MINOR") + patch = _get_header_version(version_file, "MIOPEN_VERSION_PATCH") + return major, minor, patch + + major, minor, patch = miopen_version_numbers(rocm_install_path) + + miopen_config = { + "miopen_version_number": + _get_composite_version_number(major, minor, patch) + } + + return miopen_config + + +def _find_rocblas_config(rocm_install_path): + + def rocblas_version_numbers(path): + possible_version_files = [ + "rocblas/include/rocblas-version.h", # ROCm 3.7 and prior + "rocblas/include/internal/rocblas-version.h", # ROCm 3.8 + ] + version_file = None + for f in possible_version_files: + version_file_path = os.path.join(path, f) + if os.path.exists(version_file_path): + version_file = version_file_path + break + if not version_file: + raise ConfigError( + "rocblas version file not found in {}".format( + possible_version_files)) + major = _get_header_version(version_file, "ROCBLAS_VERSION_MAJOR") + minor = _get_header_version(version_file, 
"ROCBLAS_VERSION_MINOR") + patch = _get_header_version(version_file, "ROCBLAS_VERSION_PATCH") + return major, minor, patch + + major, minor, patch = rocblas_version_numbers(rocm_install_path) + + rocblas_config = { + "rocblas_version_number": + _get_composite_version_number(major, minor, patch) + } + + return rocblas_config + + +def _find_rocrand_config(rocm_install_path): + + def rocrand_version_number(path): + version_file = os.path.join(path, "rocrand/include/rocrand_version.h") + if not os.path.exists(version_file): + raise ConfigError( + 'rocblas version file "{}" not found'.format(version_file)) + version_number = _get_header_version(version_file, "ROCRAND_VERSION") + return version_number + + rocrand_config = { + "rocrand_version_number": rocrand_version_number(rocm_install_path) + } + + return rocrand_config + + +def _find_rocfft_config(rocm_install_path): + + def rocfft_version_numbers(path): + version_file = os.path.join(path, "rocfft/include/rocfft-version.h") + if not os.path.exists(version_file): + raise ConfigError( + 'rocfft version file "{}" not found'.format(version_file)) + major = _get_header_version(version_file, "rocfft_version_major") + minor = _get_header_version(version_file, "rocfft_version_minor") + patch = _get_header_version(version_file, "rocfft_version_patch") + return major, minor, patch + + major, minor, patch = rocfft_version_numbers(rocm_install_path) + + rocfft_config = { + "rocfft_version_number": + _get_composite_version_number(major, minor, patch) + } + + return rocfft_config + + +def _find_roctracer_config(rocm_install_path): + + def roctracer_version_numbers(path): + version_file = os.path.join(path, "roctracer/include/roctracer.h") + if not os.path.exists(version_file): + raise ConfigError( + 'roctracer version file "{}" not found'.format(version_file)) + major = _get_header_version(version_file, "ROCTRACER_VERSION_MAJOR") + minor = _get_header_version(version_file, "ROCTRACER_VERSION_MINOR") + # roctracer header does not 
have a patch version number + patch = 0 + return major, minor, patch + + major, minor, patch = roctracer_version_numbers(rocm_install_path) + + roctracer_config = { + "roctracer_version_number": + _get_composite_version_number(major, minor, patch) + } + + return roctracer_config + + +def _find_hipsparse_config(rocm_install_path): + + def hipsparse_version_numbers(path): + version_file = os.path.join(path, "hipsparse/include/hipsparse-version.h") + if not os.path.exists(version_file): + raise ConfigError( + 'hipsparse version file "{}" not found'.format(version_file)) + major = _get_header_version(version_file, "hipsparseVersionMajor") + minor = _get_header_version(version_file, "hipsparseVersionMinor") + patch = _get_header_version(version_file, "hipsparseVersionPatch") + return major, minor, patch + + major, minor, patch = hipsparse_version_numbers(rocm_install_path) + + hipsparse_config = { + "hipsparse_version_number": + _get_composite_version_number(major, minor, patch) + } + + return hipsparse_config + + +def find_rocm_config(): + """Returns a dictionary of ROCm components config info.""" + rocm_install_path = _get_rocm_install_path() + if not os.path.exists(rocm_install_path): + raise ConfigError( + 'Specified ROCM_PATH "{}" does not exist'.format(rocm_install_path)) + + result = {} + + result["rocm_toolkit_path"] = rocm_install_path + result.update(_find_rocm_config(rocm_install_path)) + result.update(_find_hipruntime_config(rocm_install_path)) + result.update(_find_miopen_config(rocm_install_path)) + result.update(_find_rocblas_config(rocm_install_path)) + result.update(_find_rocrand_config(rocm_install_path)) + result.update(_find_rocfft_config(rocm_install_path)) + result.update(_find_roctracer_config(rocm_install_path)) + result.update(_find_hipsparse_config(rocm_install_path)) + + return result + + +def main(): + try: + for key, value in sorted(find_rocm_config().items()): + print("%s: %s" % (key, value)) + except ConfigError as e: + 
sys.stderr.write("\nERROR: {}\n\n".format(str(e))) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/third_party/gpus/find_rocm_config.py.gz.base64 b/third_party/gpus/find_rocm_config.py.gz.base64 new file mode 100644 index 00000000000..f38b64fda87 --- /dev/null +++ b/third_party/gpus/find_rocm_config.py.gz.base64 @@ -0,0 +1 @@ +eJy9WtFu27gSfddXEAqKyhtHSXsfdpFFHrxpFvXe1gns7C4WTWDQMm1zK4u6JJXUKPrvd4akZEqWEydOa6CoZQ0PhzNnDkdiDsi5yFeSzxeavD15e0KuF4xcs0wJ+Xsq7kmv0AshVUx6aUqGaKbIkCkm79g0Dg6CA/KBJ2DOpqTIpkwSDeN7OU3gP3enS/5iUnGRkbfxCYnQIHS3ws6vgLASBVnSFcmEJoViAMEVmfGUEfYlYbkmPCOJWOYpp1nCyD3XCzONAwE3yD8OQkw0BWsK9jlczXw7QrVxGD8LrfPT4+P7+/uYGmdjIefHqTVUxx/65xeD0cUROGyG/JmlTCki2f8KLmGpkxWhOfiT0Al4mdJ7IiShc8ngnhbo773kmmfzLlFipu+pZIAy5UpLPil0LVild7Bm3wDCRTMS9kakPwrJb71Rf9QFjL/71+8v/7wmf/eGw97gun8xIpdDcn45eNe/7l8O4Op30hv8Q/7bH7zrEgahgmnYl1yi/+AkxzCa1JERYzUHZsI6pHKW8BlPYF3ZvKBzRubijskMlkNyJpdcYTIVuDcFlJQvuaba/LKxKJzm7EU/QRiGV5JnSMPL8yVMP5FUrtAZsmAU559CihItJGfGR3Jn2QeUEuAgBtascqU0W8ZBgIRXieTAM8WoBC4oE4pt8EhMVUfpQsYxaloF8OMSKTBlGkOVmRBzWTphgHLrP45PRDbj80KaAOI4paei0LHxKqdIdFGCI0NcbpBmCymK+QJJwrI7LkW2ZJkmd1RyQ8oI/P84vupdv+/EQX8GxQX3Uj5tTMldWLp2OTYOpYPGHSalSbVkupAm7QR+ggAlYsrq8dP0M7PrKnOw8jyGosFblV+tfsc+XirEZ5sMG3ubzzInNhGm2hdUTo/QnynkUEPdB6qY+DyYSbEkE6pcUJ0wrH2r/I0JxGrtIoQHVCmoDE2YoCyPRa6PpUiWIZoUKH8UfNGQ9xktUlxPWrAA2RoEUHNCQvpE+U2o8hvogvsGTAqCIEkp1Om5SdEFRjm6MBIIqeqcBgS8V2gGs5DxnOmxm26MroxxaZExs7ny3fQHGWMglaZp6g0CX9+VrLWRLlPu0rYkbpClDo6McYEwXxORnG31D8z5jIRVjENMoVCxYwN60g64tvnkjb4F+4Nt9sZFyWhqpt4w6qxDtXHPDxhuPkJxzcaujMdZsZwwGS3pv0J2CUQM/4NhycKP/5sT+JCfiDEjh3iNV2gNV8bcn8aSvJwjQi+6JKNLVqZn6PIBap6DwyDpAATDTcIgiue+RrnUYPWkqENwn4tY5KxEDmUIO0QGZQPKfhYWenb0S9ix8V+ibxBDyWLzNZLhgZ2IvFLkMLqZHnZC8sp41zX4HTMOMmvsLQqxZQA45sd4DpKVR2867qaLEmhNZOw6AYaOcqhSvwJelzOHX7+ZcrO72k32OobFAXJkYkQOHWz9ExLTV+AGiDoDUTCdxddvsG/eZGEJYeiwHYKDHNq1WEVAFQS/2ZxhdKFoaBp2ylyCs1PLeKu1LcQ7xaWisblVZ5WKnIkJoLtlnF5z+l/BqzTGPJuJY2d4NGV3YZUL9LQcwr5Ad6EiH7BTpqkl6Kbey13LzG7Cj4hmw6iCXwPs1Jx26wG/P92aG6Z1MxSsjULtnFW
UcXewFYKO44zMsIKnUSfGX/Ko07BbT1IfGSvodHQUxi4ctgjPLN3qYz+d3DobU5rtNm+cTe5Ko83m7W0561HYqVAdz1vEAlnQ8jMWXhsvWgSsFF/LNBj41YUnbEEIT5+hZ4D3LWgIpZ2txvcFz2WRab5kO7DeM2548DTqA84xz5K0mLJj/A7/SsB48RJl4OnB6/f9KzK0Xj+3MA7g8QoExFNpaF1sS/UFHyeggSjFDnUbZhz/dTEcQWPfJfcLAc7ZvsJiAVLkmYw/9v6AZ4GfzDZzSGp3+oPLYekC7CKS8Tvb/Jdbk2W+3fSPEpomBezy4B2ToO+Km+cuaOuEaYcszj0FPQU6YSUIAzYpFG5FCvurnMLqoX6hr4OaxycQVSvDth3Pjxhkd2NtYb1KnwqBQQhrJfmmtj0bXKToBpf9utrKXaiu7bxu7T38utqYs1ZdS46quUNlOcP9dhQLUlWWu/xuhfWxfwn4z62pJxAKJrq6GOzLqSaKR6tSvZ+OAr3s+fvw+fvFlry37xg1NvncbkUJT71k7b1/1OZudkwTePzZrWkylg+wHPxTHJ4nx37YTSdSLSZ0MBXP3fXRmuhdo5fYCf0n/rl8dgeN2AqBPaHMaPoI1i8G4batHgcis/qOG8AM+9X2pTTbJfyx+eTj1fSsbJugah+o2LEXwhbnNkwrwwn0aJ99WfBNd5KBMpZ1HagkACMBmlAKQaNRb4/R0xUC8vPbh95oX4nYgHmeRjRh9hWJbXWzta/06rHRWrbgvKxO1GdvCoWEWtxNKIzlXk2mA/Flwgd9+f2wtRCe+/S1M9OGvcG7kml1jtUBHTW8DDSo0RJxaI22pOLRvqg+VZMHs5nejQZouF9TZEF8FsDl0fckAeD/gKaoERwz8umK10TBkU8XvAaKGbmf3rXlfavcrdnUoPQmyouL3XruJse1pAmEbCeaO9u9mW5xfLLbX74Lzy30D6A6qNz1sHd+MXyB/b0J5O3wB+tMVGdSgtn3hgsKD97UMbRcsRNWv1hO9iH9FhZs5X2NYQ3qt2K9OPtrHjTfKyl8l7DjayVnu18BVDj+yyX7y/cT/GqKH1AI1VzuZP7j80R/A+Z5qt+EudpT9rfToL0CmhRrvO1pxXrZCmh64Cpg4yCheRBEyZQneBiH56FiZh8vjSMZwxNctx48IHjspK7lUDDYxu62EnyE4K9H1Vnr+nTVcLuSRgNeEXxzDqtWTOHxKmTo2/ryk33ZrYVIP3NtrMPb8iV67VCvHBEX+ZRqFu1yVtPZMmqXN97bxj72Pm/buEffkTww8OFnpgcGPthkPzDuscblgcA+Jvkdf/swEK5klhRk3NBRy9Vp9SLlM1t1ywO8jCghNZtGm+UVQ/kuVdSpJNv8JUQUvlKn5JXCA89ojWT8d3+h5HEeD7PcGw+1UrH9C4oY/x6IReFNdjEcXg5Pgb43mXf8qLSMALBTDYNi0HhSGgRQguMxnnCOx+TsjITjMa5xPDYKZJcb/B9qO76I \ No newline at end of file diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl index d2533a08de1..ecbb4b5cebc 100644 --- a/third_party/gpus/rocm/BUILD.tpl +++ b/third_party/gpus/rocm/BUILD.tpl @@ -147,7 +147,6 @@ filegroup( name = "rocm_root", srcs = [ "rocm/bin/clang-offload-bundler", - "rocm/bin/bin2c.py", ], ) diff --git a/third_party/gpus/rocm/rocm_config.h.tpl b/third_party/gpus/rocm/rocm_config.h.tpl index 957413b9acd..ec26b00a5b5 100644 --- 
a/third_party/gpus/rocm/rocm_config.h.tpl +++ b/third_party/gpus/rocm/rocm_config.h.tpl @@ -18,4 +18,8 @@ limitations under the License. #define TF_ROCM_TOOLKIT_PATH "%{rocm_toolkit_path}" +#define TF_ROCM_VERSION %{rocm_version_number} +#define TF_MIOPEN_VERSION %{miopen_version_number} +#define TF_HIPRUNTIME_VERSION %{hipruntime_version_number} + #endif // ROCM_ROCM_CONFIG_H_ diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 05082795188..10f03bfec24 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -4,11 +4,7 @@ * `TF_NEED_ROCM`: Whether to enable building with ROCm. * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path - * `ROCM_TOOLKIT_PATH`: The path to the ROCm toolkit. Default is - `/opt/rocm`. - * `TF_ROCM_VERSION`: The version of the ROCm toolkit. If this is blank, then - use the system default. - * `TF_MIOPEN_VERSION`: The version of the MIOpen library. + * `ROCM_PATH`: The path to the ROCm toolkit. Default is `/opt/rocm`. * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. 
""" @@ -27,6 +23,7 @@ load( "get_bash_bin", "get_cpu_value", "get_host_environ", + "get_python_bin", "raw_exec", "realpath", "which", @@ -35,13 +32,9 @@ load( _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" _GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX" _ROCM_TOOLKIT_PATH = "ROCM_PATH" -_TF_ROCM_VERSION = "TF_ROCM_VERSION" -_TF_MIOPEN_VERSION = "TF_MIOPEN_VERSION" _TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS" _TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO" -_DEFAULT_ROCM_VERSION = "" -_DEFAULT_MIOPEN_VERSION = "" _DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm" def verify_build_defines(params): @@ -193,6 +186,7 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/9.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/10.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/11.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/12.0.0/include") # Support hcc based off clang 10.0.0 (for ROCm 3.3) inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/") @@ -212,20 +206,6 @@ def _enable_rocm(repository_ctx): return True return False -def _rocm_toolkit_path(repository_ctx, bash_bin): - """Finds the rocm toolkit directory. - - Args: - repository_ctx: The repository context. - - Returns: - A speculative real path of the rocm toolkit install directory. 
- """ - rocm_toolkit_path = get_host_environ(repository_ctx, _ROCM_TOOLKIT_PATH, _DEFAULT_ROCM_TOOLKIT_PATH) - if files_exist(repository_ctx, [rocm_toolkit_path], bash_bin) != [True]: - auto_configure_fail("Cannot find rocm toolkit path.") - return rocm_toolkit_path - def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin): """Returns a list of strings representing AMDGPU targets.""" amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS) @@ -236,7 +216,7 @@ def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin): amdgpu_targets_str = ",".join(targets) amdgpu_targets = amdgpu_targets_str.split(",") for amdgpu_target in amdgpu_targets: - if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit(): + if amdgpu_target[:3] != "gfx": auto_configure_fail("Invalid AMDGPU target: %s" % amdgpu_target) return amdgpu_targets @@ -265,51 +245,6 @@ def _hipcc_env(repository_ctx): hipcc_env = (hipcc_env + " " + name + "=\"" + env_value + "\";") return hipcc_env.strip() -def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin): - """Returns if hipcc is based on hip-clang toolchain. - - Args: - repository_ctx: The repository context. - rocm_config: The path to the hip compiler. - bash_bin: the path to the bash interpreter - - Returns: - A string "True" if hipcc is based on hip-clang toolchain. - The functions returns "False" if not (ie: based on HIP/HCC toolchain). 
- """ - - # check user-defined hip-clang environment variables - for name in ["HIP_CLANG_PATH", "HIP_VDI_HOME"]: - if get_host_environ(repository_ctx, name): - return "True" - - # grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo - cmd = "grep HIP_COMPILER=clang %s/hip/lib/.hipInfo || true" % rocm_config.rocm_toolkit_path - grep_result = execute(repository_ctx, [bash_bin, "-c", cmd], empty_stdout_fine = True) - result = grep_result.stdout.strip() - if result == "HIP_COMPILER=clang": - return "True" - return "False" - -def _if_hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin, if_true, if_false = []): - """ - Returns either the if_true or if_false arg based on whether hipcc - is based on the hip-clang toolchain - - Args : - repository_ctx: The repository context. - rocm_config: The path to the hip compiler. - if_true : value to return if hipcc is hip-clang based - if_false : value to return if hipcc is not hip-clang based - (optional, defaults to empty list) - - Returns : - either the if_true arg or the of_False arg - """ - if _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin) == "True": - return if_true - return if_false - def _crosstool_verbose(repository_ctx): """Returns the environment variable value CROSSTOOL_VERBOSE. @@ -402,7 +337,40 @@ def _find_libs(repository_ctx, rocm_config, bash_bin): return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin) -def _get_rocm_config(repository_ctx, bash_bin): +def _exec_find_rocm_config(repository_ctx, script_path): + python_bin = get_python_bin(repository_ctx) + + # If used with remote execution then repository_ctx.execute() can't + # access files from the source tree. A trick is to read the contents + # of the file in Starlark and embed them as part of the command. In + # this case the trick is not sufficient as the find_rocm_config.py + # script has more than 8192 characters. 8192 is the command length + # limit of cmd.exe on Windows.
Thus we additionally need to compress + # the contents locally and decompress them as part of the execute(). + compressed_contents = repository_ctx.read(script_path) + decompress_and_execute_cmd = ( + "from zlib import decompress;" + + "from base64 import b64decode;" + + "from os import system;" + + "script = decompress(b64decode('%s'));" % compressed_contents + + "f = open('script.py', 'wb');" + + "f.write(script);" + + "f.close();" + + "system('\"%s\" script.py');" % (python_bin) + ) + + return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd]) + +def find_rocm_config(repository_ctx, script_path): + """Returns ROCm config dictionary from running find_rocm_config.py""" + exec_result = _exec_find_rocm_config(repository_ctx, script_path) + if exec_result.return_code: + auto_configure_fail("Failed to run find_rocm_config.py: %s" % err_out(exec_result)) + + # Parse the dict from stdout. + return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()]) + +def _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script): """Detects and returns information about the ROCm installation on the system. Args: @@ -413,11 +381,21 @@ def _get_rocm_config(repository_ctx, bash_bin): A struct containing the following fields: rocm_toolkit_path: The ROCm toolkit installation directory. amdgpu_targets: A list of the system's AMDGPU targets. + rocm_version_number: The version of ROCm on the system. + miopen_version_number: The version of MIOpen on the system. + hipruntime_version_number: The version of HIP Runtime on the system. 
""" - rocm_toolkit_path = _rocm_toolkit_path(repository_ctx, bash_bin) + config = find_rocm_config(repository_ctx, find_rocm_config_script) + rocm_toolkit_path = config["rocm_toolkit_path"] + rocm_version_number = config["rocm_version_number"] + miopen_version_number = config["miopen_version_number"] + hipruntime_version_number = config["hipruntime_version_number"] return struct( - rocm_toolkit_path = rocm_toolkit_path, amdgpu_targets = _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin), + rocm_toolkit_path = rocm_toolkit_path, + rocm_version_number = rocm_version_number, + miopen_version_number = miopen_version_number, + hipruntime_version_number = hipruntime_version_number, ) def _tpl_path(repository_ctx, labelname): @@ -550,8 +528,10 @@ def _create_local_rocm_repository(repository_ctx): "rocm:rocm_config.h", ]} + find_rocm_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_rocm_config.py.gz.base64")) + bash_bin = get_bash_bin(repository_ctx) - rocm_config = _get_rocm_config(repository_ctx, bash_bin) + rocm_config = _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script) # Copy header and library files to execroot. 
# rocm_toolkit_path @@ -609,13 +589,7 @@ def _create_local_rocm_repository(repository_ctx): outs = rocm_lib_outs, )) - clang_offload_bundler_path = rocm_toolkit_path + _if_hipcc_is_hipclang( - repository_ctx, - rocm_config, - bash_bin, - "/llvm/bin/", - "/hcc/bin/", - ) + "clang-offload-bundler" + clang_offload_bundler_path = rocm_toolkit_path + "/llvm/bin/clang-offload-bundler" # copy files mentioned in third_party/gpus/rocm/BUILD copy_rules.append(make_copy_files_rule( @@ -688,17 +662,7 @@ def _create_local_rocm_repository(repository_ctx): "-DTENSORFLOW_USE_ROCM=1", "-D__HIP_PLATFORM_HCC__", "-DEIGEN_USE_HIP", - ] + _if_hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin, [ - # - # define "TENSORFLOW_COMPILER_IS_HIP_CLANG" when we are using clang - # based hipcc to compile/build tensorflow - # - # Note that this #define should not be used to check whether or not - # tensorflow is being built with ROCm support - # (only TENSORFLOW_USE_ROCM should be used for that purpose) - # - "-DTENSORFLOW_COMPILER_IS_HIP_CLANG=1", - ])) + ]) rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" @@ -729,7 +693,6 @@ def _create_local_rocm_repository(repository_ctx): "%{cpu_compiler}": str(cc), "%{hipcc_path}": rocm_config.rocm_toolkit_path + "/hip/bin/hipcc", "%{hipcc_env}": _hipcc_env(repository_ctx), - "%{hipcc_is_hipclang}": _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin), "%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", "%{rocr_runtime_library}": "hsa-runtime64", "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/hip/lib", @@ -749,6 +712,9 @@ def _create_local_rocm_repository(repository_ctx): ["\"%s\"" % c for c in rocm_config.amdgpu_targets], ), "%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path, + "%{rocm_version_number}": rocm_config.rocm_version_number, + "%{miopen_version_number}": rocm_config.miopen_version_number, + "%{hipruntime_version_number}": rocm_config.hipruntime_version_number, }, ) @@ 
-813,8 +779,6 @@ _ENVIRONS = [ _GCC_HOST_COMPILER_PREFIX, "TF_NEED_ROCM", _ROCM_TOOLKIT_PATH, - _TF_ROCM_VERSION, - _TF_MIOPEN_VERSION, _TF_ROCM_AMDGPU_TARGETS, ]