PR #44471: [ROCm] Update to use ROCm 3.9 (when building TF with --config=rocm)

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/44471

PR https://github.com/tensorflow/tensorflow/pull/43636 is a pre-requisite for this PR.

For the time being, this PR includes commits from it's pre-req as well.  Once the pre-req PR is merged, I will rebase this PR to remove those commits.

--------------------------------------

/cc @cheshire @chsigg @nvining-work

Copybara import of the project:

--
3f0d378c14f55ac850ace17ac154e2333169329b by Deven Desai <deven.desai.amd@gmail.com>:

Adding #defines for ROCm / MIOpen / HIP Runtime version numbers

This PR/commit introduces the following #defines in the `rocm/rocm_config.h` file

```
#define TF_ROCM_VERSION <Version Number of ROCm install>
#define TF_MIOPEN_VERSION <Verion Number of MIOpen in ROCm install>
#define TF_HIPRUNTIME_VERSION <Version Number of HIP Runtinme in ROCm install>
```

These #defines should be used within TF code to add ROCm/MIOpen/HIp Runtime version specific code.

Details on how we go about determining these version numbers can found on the following wiki-page

https://github.com/ROCmSoftwarePlatform/tensorflow-internal/wiki/How-to-add-ROCm-version-specific-code-changes-in-the-TensorFlow-code%3F

A new script `find_rocm_config.py` is being added by this commit. This script does all the work of determining the version number information and it is pretty to extend it to query more information about the ROCM install.

The information collected by the script is available to `rocm_configure.bzl` and hence can be used to add version specific code in `rocm_configure.bzl` as well.

--
922e0e556c4f31f7ff8da1053f014964d01c0859 by Deven Desai <deven.desai.amd@gmail.com>:

Updating Dockerfile.rocm to use ROCm 3.9

--
cc0b4ae28218a83b3cc262ac83d0b2cf476939c8 by Deven Desai <deven.desai.amd@gmail.com>:

Changing CI scripts to use ROCm 3.9

--
fbfdb64c3375f79674a4f56433f944e1e4fd6b6e by Deven Desai <deven.desai.amd@gmail.com>:

Updating rocm_config.py to account for the new location of the rocblas version header file (in ROCm 3.8)

--
3f191faf8b8f2a0111bc386f41316079cad4aaaa by Deven Desai <deven.desai.amd@gmail.com>:

Removing references to TENSORFLOW_COMPILER_IS_HIP_CLANG

Now that we are way past the switch to use ROCm 3.5 and above (i.e. hip-clang), the codes within `#ifdef TENSORFLOW_COMPILER_IS_HIP_CLANG` are always enabled, and the codes within the corresponding `#else` blocks are deadcodes.

This commit removes the references to `#ifdef TENSORFLOW_COMPILER_IS_HIP_CLANG` and their corresponding `#else` blocks

--
9a4841c9bb8117e8228946be1f3752bdaea4a359 by Deven Desai <deven.desai.amd@gmail.com>:

Removing -DTENSORFLOW_COMPILER_IS_HIP_CLANG from the list of compile flags

--
745e2ad6db4282f5efcfef3155d9a46d9235dbf6 by Deven Desai <deven.desai.amd@gmail.com>:

Removing deadcode for the ROCm platform within the third_party/gpus dir

--
c96dc03986636badce7dbd87fb85cf26dff7a43b by Deven Desai <deven.desai.amd@gmail.com>:

Updating XLA code to account for the device lib files location change in ROCm 3.9

The location of the ROCm device lib files is changing in ROCm 3.9

Current (ROCm 3.8 and before) location is $ROCM_PATH/lib

```
root@ixt-rack-04:/opt/rocm-3.8.0# find . -name *.bc
./lib/oclc_isa_version_701.amdgcn.bc
./lib/ocml.amdgcn.bc
./lib/oclc_daz_opt_on.amdgcn.bc
./lib/oclc_isa_version_700.amdgcn.bc
./lib/oclc_isa_version_810.amdgcn.bc
./lib/oclc_unsafe_math_off.amdgcn.bc
./lib/oclc_wavefrontsize64_off.amdgcn.bc
./lib/oclc_isa_version_803.amdgcn.bc
./lib/oclc_isa_version_1011.amdgcn.bc
./lib/oclc_isa_version_1012.amdgcn.bc
./lib/opencl.amdgcn.bc
./lib/oclc_unsafe_math_on.amdgcn.bc
./lib/oclc_isa_version_1010.amdgcn.bc
./lib/oclc_finite_only_off.amdgcn.bc
./lib/oclc_correctly_rounded_sqrt_on.amdgcn.bc
./lib/oclc_daz_opt_off.amdgcn.bc
./lib/oclc_isa_version_802.amdgcn.bc
./lib/ockl.amdgcn.bc
./lib/oclc_isa_version_906.amdgcn.bc
./lib/oclc_isa_version_1030.amdgcn.bc
./lib/oclc_correctly_rounded_sqrt_off.amdgcn.bc
./lib/hip.amdgcn.bc
./lib/oclc_isa_version_908.amdgcn.bc
./lib/oclc_isa_version_900.amdgcn.bc
./lib/oclc_isa_version_702.amdgcn.bc
./lib/oclc_wavefrontsize64_on.amdgcn.bc
./lib/hc.amdgcn.bc
./lib/oclc_isa_version_902.amdgcn.bc
./lib/oclc_isa_version_801.amdgcn.bc
./lib/oclc_finite_only_on.amdgcn.bc
./lib/oclc_isa_version_904.amdgcn.bc
```

New (ROCm 3.9 and above) location is $ROCM_PATH/amdgcn/bitcode
```
root@ixt-hq-99:/opt/rocm-3.9.0-3703# find -name *.bc
./amdgcn/bitcode/oclc_isa_version_700.bc
./amdgcn/bitcode/ocml.bc
./amdgcn/bitcode/oclc_isa_version_1030.bc
./amdgcn/bitcode/oclc_isa_version_1010.bc
./amdgcn/bitcode/oclc_isa_version_904.bc
./amdgcn/bitcode/hip.bc
./amdgcn/bitcode/hc.bc
./amdgcn/bitcode/oclc_daz_opt_off.bc
./amdgcn/bitcode/oclc_wavefrontsize64_off.bc
./amdgcn/bitcode/oclc_wavefrontsize64_on.bc
./amdgcn/bitcode/oclc_isa_version_900.bc
./amdgcn/bitcode/oclc_isa_version_1012.bc
./amdgcn/bitcode/oclc_isa_version_702.bc
./amdgcn/bitcode/oclc_daz_opt_on.bc
./amdgcn/bitcode/oclc_unsafe_math_off.bc
./amdgcn/bitcode/ockl.bc
./amdgcn/bitcode/oclc_isa_version_803.bc
./amdgcn/bitcode/oclc_isa_version_908.bc
./amdgcn/bitcode/oclc_isa_version_802.bc
./amdgcn/bitcode/oclc_correctly_rounded_sqrt_off.bc
./amdgcn/bitcode/oclc_finite_only_on.bc
./amdgcn/bitcode/oclc_isa_version_701.bc
./amdgcn/bitcode/oclc_unsafe_math_on.bc
./amdgcn/bitcode/oclc_isa_version_902.bc
./amdgcn/bitcode/oclc_finite_only_off.bc
./amdgcn/bitcode/opencl.bc
./amdgcn/bitcode/oclc_isa_version_906.bc
./amdgcn/bitcode/oclc_isa_version_810.bc
./amdgcn/bitcode/oclc_isa_version_801.bc
./amdgcn/bitcode/oclc_correctly_rounded_sqrt_on.bc
./amdgcn/bitcode/oclc_isa_version_1011.bc
```

Also not the change in the filename(s)

This commit updates the XLA code, that has the device lib path + filename(s) hardcoded, to account for the change in location / filename

--
6f981a91c8d8a349c88b450c2191df9c62b2b38b by Deven Desai <deven.desai.amd@gmail.com>:

Adding "-fcuda-flush-denormals-to-zero" as a default hipcc option

Prior to ROCm 3.8, hipcc (hipclang) flushed denormal values to zero by default. Starting with ROCm 3.8 that is no longer true, denormal values are kept as is.

TF expects denormals to be flushed to zero. This is enforced on the CUDA side by explicitly passing the "-fcuda-flush-denormals-to-zero" (see tensorflow.bzl). This commit does the same for the ROCm side.

Also removing the no_rocm tag from the corresponding unit test - //tensorflow/python/kernel_tests:denormal_test_gpu

--
74810439720e0692f81ffb0cc3b97dc6ed50876d by Deven Desai <deven.desai.amd@gmail.com>:

Fix for TF build failure with ROCm 3.9 (error: call to 'min' is ambiguous)

When building TF with ROCm 3.9, we are running into the following compile error

```
In file included from tensorflow/core/kernels/reduction_ops_half_mean_sum.cu.cc:20:
./tensorflow/core/kernels/reduction_gpu_kernels.cu.h:430:9: error: call to 'min' is ambiguous
        min(blockDim.y, num_rows - blockIdx.y * blockDim.y);
        ^~~
/opt/rocm-3.9.0-3805/llvm/lib/clang/12.0.0/include/__clang_hip_math.h:1183:23: note: candidate function
__DEVICE__ inline int min(int __arg1, int __arg2) {
                      ^
/opt/rocm-3.9.0-3805/llvm/lib/clang/12.0.0/include/__clang_hip_math.h:1197:14: note: candidate function
inline float min(float __x, float __y) { return fminf(__x, __y); }
             ^
/opt/rocm-3.9.0-3805/llvm/lib/clang/12.0.0/include/__clang_hip_math.h:1200:15: note: candidate function
inline double min(double __x, double __y) { return fmin(__x, __y); }
              ^
1 error generated when compiling for gfx803.
```

The build error seems to be because ROCm 3.9 uses llvm header files from `llvm/lib/clang/12.0.0/include` (ROCm 3.8 uses the `11.0.0` version). `12.0.0` has a new `__clang_hip_math.h` file, which is not present in `11.0.0`. This file has the `min` function overloaded for the `float` and `double` types.

The first argument in the call to `min` (which leads to the error) is `blockDim.y` which has a `uint` type, and hence the compiler gets confused as to which overloaded type to resole to. Previously (i.e. ROCm 3.8 and before) there was only one option (`int`), with ROCm 3.9 there are three (`int`, `float`, and `double`) and hence the error.

The "fix" is to explicitly cast the first argument to `int` to remove the ambiguity (the second argument is already an `int` type).

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/44471 from ROCmSoftwarePlatform:google_upstream_rocm_switch_to_rocm39 74810439720e0692f81ffb0cc3b97dc6ed50876d
PiperOrigin-RevId: 341569721
Change-Id: Ia614893881bf8db1ef8901034c35cc585a82dba8
This commit is contained in:
Deven Desai 2020-11-10 00:50:23 -08:00 committed by TensorFlower Gardener
parent c417a3048a
commit 312e6bacca
17 changed files with 65 additions and 156 deletions

View File

@ -1,5 +1,9 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load(
"@local_config_rocm//rocm:build_defs.bzl",
"if_rocm_is_configured",
)
package(
default_visibility = [":friends"],
@ -26,21 +30,11 @@ cc_library(
"utils.h",
],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service/gpu:gpu_types",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/profiler/lib:traceme",
"@llvm-project//llvm:AMDGPUCodeGen",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@llvm-project//llvm:AMDGPUCodeGen",
"@llvm-project//llvm:Analysis",
"@llvm-project//llvm:BitReader",
"@llvm-project//llvm:BitWriter",
@ -54,7 +48,19 @@ cc_library(
"@llvm-project//llvm:Scalar",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
],
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service/gpu:gpu_types",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/profiler/lib:traceme",
] + if_rocm_is_configured([
"@local_config_rocm//rocm:rocm_headers",
]),
)
tf_cc_test(

View File

@ -67,6 +67,10 @@ limitations under the License.
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/env_var.h"
#if !defined(PLATFORM_GOOGLE) && TENSORFLOW_USE_ROCM
#include "rocm/rocm_config.h"
#endif
namespace xla {
namespace gpu {
namespace {
@ -560,11 +564,18 @@ namespace {
static std::vector<string> GetROCDLPaths(int amdgpu_version,
const string& rocdl_dir_path) {
// AMDGPU version-neutral bitcodes.
#if TF_ROCM_VERSION >= 30900
static std::vector<string>* rocdl_filenames = new std::vector<string>(
{"hc.bc", "opencl.bc", "ocml.bc", "ockl.bc", "oclc_finite_only_off.bc",
"oclc_daz_opt_off.bc", "oclc_correctly_rounded_sqrt_on.bc",
"oclc_unsafe_math_off.bc", "oclc_wavefrontsize64_on.bc"});
#else
static std::vector<string>* rocdl_filenames = new std::vector<string>(
{"hc.amdgcn.bc", "opencl.amdgcn.bc", "ocml.amdgcn.bc", "ockl.amdgcn.bc",
"oclc_finite_only_off.amdgcn.bc", "oclc_daz_opt_off.amdgcn.bc",
"oclc_correctly_rounded_sqrt_on.amdgcn.bc",
"oclc_unsafe_math_off.amdgcn.bc", "oclc_wavefrontsize64_on.amdgcn.bc"});
#endif
// Construct full path to ROCDL bitcode libraries.
std::vector<string> result;
@ -575,7 +586,11 @@ static std::vector<string> GetROCDLPaths(int amdgpu_version,
// Add AMDGPU version-specific bitcodes.
result.push_back(tensorflow::io::JoinPath(
rocdl_dir_path,
#if TF_ROCM_VERSION >= 30900
absl::StrCat("oclc_isa_version_", amdgpu_version, ".bc")));
#else
absl::StrCat("oclc_isa_version_", amdgpu_version, ".amdgcn.bc")));
#endif
return result;
}

View File

@ -287,7 +287,7 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(
// One extra line in the inner dimension to avoid share memory bank conflict.
// This is to mimic the following, but no constructor of T can be invoked.
// __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
#if GOOGLE_CUDA // || TENSORFLOW_COMPILER_IS_HIP_CLANG
#if GOOGLE_CUDA
__shared__ __align__(
alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)];
typedef T(*SharedMemoryTile)[TileSizeJ + 1];

View File

@ -387,7 +387,7 @@ __global__ __launch_bounds__(1024) void ColumnReduceKernel(
// - =
// =
const int numRowsThisBlock =
min(blockDim.y, num_rows - blockIdx.y * blockDim.y);
min(static_cast<int>(blockDim.y), num_rows - blockIdx.y * blockDim.y);
for (int row = 1; row < numRowsThisBlock; ++row) {
value_type t = partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1) + row];

View File

@ -248,10 +248,8 @@ void LaunchScan(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
int num_blocks = dimx * dimz;
int ideal_block_size = dimy / items_per_thread;
#if TENSORFLOW_COMPILER_IS_HIP_CLANG
const int rocm_threads_per_warp = 64;
ideal_block_size = std::max(ideal_block_size, rocm_threads_per_warp);
#endif
// There seems to be a bug when the type is not float and block_size 1024.
// Launch on the smallest power of 2 block size that we can.

View File

@ -1297,7 +1297,9 @@ tf_cuda_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
] + if_rocm_is_configured([
"@local_config_rocm//rocm:rocm_headers",
]),
)
# --------------------------------------------------------------------------

View File

@ -36,10 +36,10 @@ string RocmRoot() {
}
string RocdlRoot() {
#if TENSORFLOW_COMPILER_IS_HIP_CLANG
return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "lib");
#if TF_ROCM_VERSION >= 30900
return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "amdgcn/bitcode");
#else
return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "hcc/lib");
return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "lib");
#endif
}

View File

@ -20,6 +20,10 @@ limitations under the License.
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/test.h"
#if !defined(PLATFORM_GOOGLE) && TENSORFLOW_USE_ROCM
#include "rocm/rocm_config.h"
#endif
namespace tensorflow {
#if TENSORFLOW_USE_ROCM
@ -27,7 +31,11 @@ TEST(RocmRocdlPathTest, ROCDLPath) {
VLOG(2) << "ROCm-Device-Libs root = " << RocdlRoot();
std::vector<string> rocdl_files;
TF_EXPECT_OK(Env::Default()->GetMatchingPaths(
#if TF_ROCM_VERSION >= 30900
io::JoinPath(RocdlRoot(), "*.bc"), &rocdl_files));
#else
io::JoinPath(RocdlRoot(), "*.amdgcn.bc"), &rocdl_files));
#endif
EXPECT_LT(0, rocdl_files.size());
}
#endif

View File

@ -168,25 +168,10 @@ GpuLaunchConfig GetGpuLaunchConfig(int work_element_count,
block_size_limit);
CHECK_EQ(err, cudaSuccess);
#elif TENSORFLOW_USE_ROCM
#if TENSORFLOW_COMPILER_IS_HIP_CLANG
hipError_t err = hipOccupancyMaxPotentialBlockSize(
&block_count, &thread_per_block, func, dynamic_shared_memory_size,
block_size_limit);
CHECK_EQ(err, hipSuccess);
#else
// Earlier versions of this HIP routine incorrectly returned void.
// TODO re-enable hipError_t error checking when HIP is fixed.
// ROCm interface uses unsigned int, convert after checking
uint32_t block_count_uint = 0;
uint32_t thread_per_block_uint = 0;
CHECK_GE(block_size_limit, 0);
uint32_t block_size_limit_uint = static_cast<uint32_t>(block_size_limit);
hipOccupancyMaxPotentialBlockSize(&block_count_uint, &thread_per_block_uint,
func, dynamic_shared_memory_size,
block_size_limit_uint);
block_count = static_cast<int>(block_count_uint);
thread_per_block = static_cast<int>(thread_per_block_uint);
#endif
#endif
block_count =
@ -216,22 +201,9 @@ GpuLaunchConfig GetGpuLaunchConfigFixedBlockSize(
&block_count, func, fixed_block_size, dynamic_shared_memory_size);
CHECK_EQ(err, cudaSuccess);
#elif TENSORFLOW_USE_ROCM
#if TENSORFLOW_COMPILER_IS_HIP_CLANG
hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&block_count, func, fixed_block_size, dynamic_shared_memory_size);
CHECK_EQ(err, hipSuccess);
#else
// Apply the heuristic in GetGpuLaunchConfig(int, const Eigen::GpuDevice&)
// that the kernel is quite simple and will largely be memory-limited.
const int physical_thread_count = std::min(
d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(),
work_element_count);
// Assume the kernel be simple enough that it is okay to use 1024 threads
// per workgroup.
int thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock());
block_count = std::min(DivUp(physical_thread_count, thread_per_block),
d.getNumGpuMultiProcessors());
#endif
#endif
block_count = std::min(block_count * d.getNumGpuMultiProcessors(),
DivUp(work_element_count, fixed_block_size));

View File

@ -3,10 +3,10 @@
FROM ubuntu:bionic
MAINTAINER Jeff Poznanovic <jeffrey.poznanovic@amd.com>
ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/3.7/
ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/3.9/
ARG ROCM_BUILD_NAME=xenial
ARG ROCM_BUILD_NUM=main
ARG ROCM_PATH=/opt/rocm-3.7.0
ARG ROCM_PATH=/opt/rocm-3.9.0
ENV DEBIAN_FRONTEND noninteractive
ENV TF_NEED_ROCM 1

View File

@ -28,7 +28,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS}
echo ""
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
ROCM_INSTALL_DIR=/opt/rocm-3.7.0
ROCM_INSTALL_DIR=/opt/rocm-3.9.0
if [[ -n $1 ]]; then
ROCM_INSTALL_DIR=$1
fi

View File

@ -28,7 +28,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS}
echo ""
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
ROCM_INSTALL_DIR=/opt/rocm-3.7.0
ROCM_INSTALL_DIR=/opt/rocm-3.9.0
if [[ -n $1 ]]; then
ROCM_INSTALL_DIR=$1
fi

View File

@ -28,7 +28,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS}
echo ""
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
ROCM_INSTALL_DIR=/opt/rocm-3.7.0
ROCM_INSTALL_DIR=/opt/rocm-3.9.0
if [[ -n $1 ]]; then
ROCM_INSTALL_DIR=$1
fi

View File

@ -28,7 +28,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS}
echo ""
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
ROCM_INSTALL_DIR=/opt/rocm-3.7.0
ROCM_INSTALL_DIR=/opt/rocm-3.9.0
if [[ -n $1 ]]; then
ROCM_INSTALL_DIR=$1
fi

View File

@ -26,12 +26,9 @@ import pipes
# Template values set by rocm_configure.bzl.
CPU_COMPILER = ('%{cpu_compiler}')
GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}')
HIPCC_PATH = '%{hipcc_path}'
PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
HIPCC_ENV = '%{hipcc_env}'
HIPCC_IS_HIPCLANG = '%{hipcc_is_hipclang}'=="True"
HIP_RUNTIME_PATH = '%{hip_runtime_path}'
HIP_RUNTIME_LIBRARY = '%{hip_runtime_library}'
ROCR_RUNTIME_PATH = '%{rocr_runtime_path}'
@ -98,27 +95,6 @@ def GetHostCompilerOptions(argv):
return opts
def GetHipccOptions(argv):
"""Collect the -hipcc_options values from argv.
Args:
argv: A list of strings, possibly the argv passed to main().
Returns:
The string that can be passed directly to hipcc.
"""
parser = ArgumentParser()
parser.add_argument('-hipcc_options', nargs='*', action='append')
args, _ = parser.parse_known_args(argv)
if args.hipcc_options:
options = _update_options(sum(args.hipcc_options, []))
return ' '.join(['--'+a for a in options])
return ''
def system(cmd):
"""Invokes cmd with os.system().
@ -148,7 +124,6 @@ def InvokeHipcc(argv, log=False):
"""
host_compiler_options = GetHostCompilerOptions(argv)
hipcc_compiler_options = GetHipccOptions(argv)
opt_option = GetOptionValue(argv, 'O')
m_options = GetOptionValue(argv, 'm')
m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
@ -193,14 +168,13 @@ def InvokeHipcc(argv, log=False):
# Otherwise, we get build error.
# Also we need to retain warning about uninitialised shared variable as
# warning only, even when -Werror option is specified.
if HIPCC_IS_HIPCLANG:
hipccopts += ' --include=hip/hip_runtime.h '
hipccopts += ' ' + hipcc_compiler_options
hipccopts += ' --include=hip/hip_runtime.h '
# Use -fno-gpu-rdc by default for early GPU kernel finalization
# This flag would trigger GPU kernels be generated at compile time, instead
# of link time. This allows the default host compiler (gcc) be used as the
# linker for TensorFlow on ROCm platform.
hipccopts += ' -fno-gpu-rdc '
hipccopts += ' -fcuda-flush-denormals-to-zero '
hipccopts += undefines
hipccopts += defines
hipccopts += std_options
@ -211,22 +185,19 @@ def InvokeHipcc(argv, log=False):
depfile = depfiles[0]
cmd = (HIPCC_PATH + ' ' + hipccopts +
host_compiler_options +
' ' + GCC_HOST_COMPILER_PATH +
' -I .' + includes + ' ' + srcs + ' -M -o ' + depfile)
cmd = HIPCC_ENV.replace(';', ' ') + ' ' + cmd
if log: Log(cmd)
if VERBOSE: print(cmd)
exit_status = os.system(cmd)
if exit_status != 0:
return exit_status
cmd = (HIPCC_PATH + ' ' + hipccopts +
host_compiler_options + ' -fPIC' +
' ' + GCC_HOST_COMPILER_PATH +
' -I .' + opt + includes + ' -c ' + srcs + out)
# TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'.
# Need to investigate and fix.
cmd = 'PATH=' + PREFIX_DIR + ':$PATH '\
+ HIPCC_ENV.replace(';', ' ') + ' '\
cmd = HIPCC_ENV.replace(';', ' ') + ' '\
+ cmd
if log: Log(cmd)
if VERBOSE: print(cmd)
@ -268,8 +239,7 @@ def main():
gpu_linker_flags.append('-L' + HIP_RUNTIME_PATH)
gpu_linker_flags.append('-Wl,-rpath=' + HIP_RUNTIME_PATH)
gpu_linker_flags.append('-l' + HIP_RUNTIME_LIBRARY)
if HIPCC_IS_HIPCLANG:
gpu_linker_flags.append("-lrt")
gpu_linker_flags.append("-lrt")
if VERBOSE: print(' '.join([CPU_COMPILER] + gpu_linker_flags))
return subprocess.call([CPU_COMPILER] + gpu_linker_flags)

View File

@ -147,7 +147,6 @@ filegroup(
name = "rocm_root",
srcs = [
"rocm/bin/clang-offload-bundler",
"rocm/bin/bin2c.py",
],
)

View File

@ -186,6 +186,7 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin):
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/9.0.0/include")
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/10.0.0/include")
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/11.0.0/include")
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/12.0.0/include")
# Support hcc based off clang 10.0.0 (for ROCm 3.3)
inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
@ -215,7 +216,7 @@ def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin):
amdgpu_targets_str = ",".join(targets)
amdgpu_targets = amdgpu_targets_str.split(",")
for amdgpu_target in amdgpu_targets:
if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit():
if amdgpu_target[:3] != "gfx":
auto_configure_fail("Invalid AMDGPU target: %s" % amdgpu_target)
return amdgpu_targets
@ -244,51 +245,6 @@ def _hipcc_env(repository_ctx):
hipcc_env = (hipcc_env + " " + name + "=\"" + env_value + "\";")
return hipcc_env.strip()
def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin):
"""Returns if hipcc is based on hip-clang toolchain.
Args:
repository_ctx: The repository context.
rocm_config: The path to the hip compiler.
bash_bin: the path to the bash interpreter
Returns:
A string "True" if hipcc is based on hip-clang toolchain.
The functions returns "False" if not (ie: based on HIP/HCC toolchain).
"""
# check user-defined hip-clang environment variables
for name in ["HIP_CLANG_PATH", "HIP_VDI_HOME"]:
if get_host_environ(repository_ctx, name):
return "True"
# grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo
cmd = "grep HIP_COMPILER=clang %s/hip/lib/.hipInfo || true" % rocm_config.rocm_toolkit_path
grep_result = execute(repository_ctx, [bash_bin, "-c", cmd], empty_stdout_fine = True)
result = grep_result.stdout.strip()
if result == "HIP_COMPILER=clang":
return "True"
return "False"
def _if_hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin, if_true, if_false = []):
"""
Returns either the if_true or if_false arg based on whether hipcc
is based on the hip-clang toolchain
Args :
repository_ctx: The repository context.
rocm_config: The path to the hip compiler.
if_true : value to return if hipcc is hip-clang based
if_false : value to return if hipcc is not hip-clang based
(optional, defaults to empty list)
Returns :
either the if_true arg or the of_False arg
"""
if _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin) == "True":
return if_true
return if_false
def _crosstool_verbose(repository_ctx):
"""Returns the environment variable value CROSSTOOL_VERBOSE.
@ -633,13 +589,7 @@ def _create_local_rocm_repository(repository_ctx):
outs = rocm_lib_outs,
))
clang_offload_bundler_path = rocm_toolkit_path + _if_hipcc_is_hipclang(
repository_ctx,
rocm_config,
bash_bin,
"/llvm/bin/",
"/hcc/bin/",
) + "clang-offload-bundler"
clang_offload_bundler_path = rocm_toolkit_path + "/llvm/bin/clang-offload-bundler"
# copy files mentioned in third_party/gpus/rocm/BUILD
copy_rules.append(make_copy_files_rule(
@ -712,17 +662,7 @@ def _create_local_rocm_repository(repository_ctx):
"-DTENSORFLOW_USE_ROCM=1",
"-D__HIP_PLATFORM_HCC__",
"-DEIGEN_USE_HIP",
] + _if_hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin, [
#
# define "TENSORFLOW_COMPILER_IS_HIP_CLANG" when we are using clang
# based hipcc to compile/build tensorflow
#
# Note that this #define should not be used to check whether or not
# tensorflow is being built with ROCm support
# (only TENSORFLOW_USE_ROCM should be used for that purpose)
#
"-DTENSORFLOW_COMPILER_IS_HIP_CLANG=1",
]))
])
rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
@ -753,7 +693,6 @@ def _create_local_rocm_repository(repository_ctx):
"%{cpu_compiler}": str(cc),
"%{hipcc_path}": rocm_config.rocm_toolkit_path + "/hip/bin/hipcc",
"%{hipcc_env}": _hipcc_env(repository_ctx),
"%{hipcc_is_hipclang}": _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin),
"%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib",
"%{rocr_runtime_library}": "hsa-runtime64",
"%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/hip/lib",