Merge pull request #45408 from ROCmSoftwarePlatform/google_upstream_r24_rocm_updates
[ROCm][r2.4]Porting changes from master to switch to ROCm 3.9
This commit is contained in:
commit
5bf048973c
@ -67,6 +67,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||||
#include "tensorflow/core/util/env_var.h"
|
#include "tensorflow/core/util/env_var.h"
|
||||||
|
|
||||||
|
#if TENSORFLOW_USE_ROCM
|
||||||
|
#include "rocm/rocm_config.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
namespace {
|
namespace {
|
||||||
@ -560,11 +564,18 @@ namespace {
|
|||||||
static std::vector<string> GetROCDLPaths(int amdgpu_version,
|
static std::vector<string> GetROCDLPaths(int amdgpu_version,
|
||||||
const string& rocdl_dir_path) {
|
const string& rocdl_dir_path) {
|
||||||
// AMDGPU version-neutral bitcodes.
|
// AMDGPU version-neutral bitcodes.
|
||||||
|
#if TF_ROCM_VERSION >= 30900
|
||||||
|
static std::vector<string>* rocdl_filenames = new std::vector<string>(
|
||||||
|
{"hc.bc", "opencl.bc", "ocml.bc", "ockl.bc", "oclc_finite_only_off.bc",
|
||||||
|
"oclc_daz_opt_off.bc", "oclc_correctly_rounded_sqrt_on.bc",
|
||||||
|
"oclc_unsafe_math_off.bc", "oclc_wavefrontsize64_on.bc"});
|
||||||
|
#else
|
||||||
static std::vector<string>* rocdl_filenames = new std::vector<string>(
|
static std::vector<string>* rocdl_filenames = new std::vector<string>(
|
||||||
{"hc.amdgcn.bc", "opencl.amdgcn.bc", "ocml.amdgcn.bc", "ockl.amdgcn.bc",
|
{"hc.amdgcn.bc", "opencl.amdgcn.bc", "ocml.amdgcn.bc", "ockl.amdgcn.bc",
|
||||||
"oclc_finite_only_off.amdgcn.bc", "oclc_daz_opt_off.amdgcn.bc",
|
"oclc_finite_only_off.amdgcn.bc", "oclc_daz_opt_off.amdgcn.bc",
|
||||||
"oclc_correctly_rounded_sqrt_on.amdgcn.bc",
|
"oclc_correctly_rounded_sqrt_on.amdgcn.bc",
|
||||||
"oclc_unsafe_math_off.amdgcn.bc", "oclc_wavefrontsize64_on.amdgcn.bc"});
|
"oclc_unsafe_math_off.amdgcn.bc", "oclc_wavefrontsize64_on.amdgcn.bc"});
|
||||||
|
#endif
|
||||||
|
|
||||||
// Construct full path to ROCDL bitcode libraries.
|
// Construct full path to ROCDL bitcode libraries.
|
||||||
std::vector<string> result;
|
std::vector<string> result;
|
||||||
@ -575,7 +586,11 @@ static std::vector<string> GetROCDLPaths(int amdgpu_version,
|
|||||||
// Add AMDGPU version-specific bitcodes.
|
// Add AMDGPU version-specific bitcodes.
|
||||||
result.push_back(tensorflow::io::JoinPath(
|
result.push_back(tensorflow::io::JoinPath(
|
||||||
rocdl_dir_path,
|
rocdl_dir_path,
|
||||||
|
#if TF_ROCM_VERSION >= 30900
|
||||||
|
absl::StrCat("oclc_isa_version_", amdgpu_version, ".bc")));
|
||||||
|
#else
|
||||||
absl::StrCat("oclc_isa_version_", amdgpu_version, ".amdgcn.bc")));
|
absl::StrCat("oclc_isa_version_", amdgpu_version, ".amdgcn.bc")));
|
||||||
|
#endif
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -287,7 +287,7 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(
|
|||||||
// One extra line in the inner dimension to avoid share memory bank conflict.
|
// One extra line in the inner dimension to avoid share memory bank conflict.
|
||||||
// This is to mimic the following, but no constructor of T can be invoked.
|
// This is to mimic the following, but no constructor of T can be invoked.
|
||||||
// __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
|
// __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
|
||||||
#if GOOGLE_CUDA // || TENSORFLOW_COMPILER_IS_HIP_CLANG
|
#if GOOGLE_CUDA
|
||||||
__shared__ __align__(
|
__shared__ __align__(
|
||||||
alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)];
|
alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)];
|
||||||
typedef T(*SharedMemoryTile)[TileSizeJ + 1];
|
typedef T(*SharedMemoryTile)[TileSizeJ + 1];
|
||||||
|
@ -387,7 +387,7 @@ __global__ __launch_bounds__(1024) void ColumnReduceKernel(
|
|||||||
// - =
|
// - =
|
||||||
// =
|
// =
|
||||||
const int numRowsThisBlock =
|
const int numRowsThisBlock =
|
||||||
min(blockDim.y, num_rows - blockIdx.y * blockDim.y);
|
min(int(blockDim.y), num_rows - blockIdx.y * blockDim.y);
|
||||||
|
|
||||||
for (int row = 1; row < numRowsThisBlock; ++row) {
|
for (int row = 1; row < numRowsThisBlock; ++row) {
|
||||||
value_type t = partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1) + row];
|
value_type t = partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1) + row];
|
||||||
|
@ -248,10 +248,8 @@ void LaunchScan(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
|
|||||||
int num_blocks = dimx * dimz;
|
int num_blocks = dimx * dimz;
|
||||||
|
|
||||||
int ideal_block_size = dimy / items_per_thread;
|
int ideal_block_size = dimy / items_per_thread;
|
||||||
#if TENSORFLOW_COMPILER_IS_HIP_CLANG
|
|
||||||
const int rocm_threads_per_warp = 64;
|
const int rocm_threads_per_warp = 64;
|
||||||
ideal_block_size = std::max(ideal_block_size, rocm_threads_per_warp);
|
ideal_block_size = std::max(ideal_block_size, rocm_threads_per_warp);
|
||||||
#endif
|
|
||||||
|
|
||||||
// There seems to be a bug when the type is not float and block_size 1024.
|
// There seems to be a bug when the type is not float and block_size 1024.
|
||||||
// Launch on the smallest power of 2 block size that we can.
|
// Launch on the smallest power of 2 block size that we can.
|
||||||
|
@ -36,10 +36,10 @@ string RocmRoot() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
string RocdlRoot() {
|
string RocdlRoot() {
|
||||||
#if TENSORFLOW_COMPILER_IS_HIP_CLANG
|
#if TF_ROCM_VERSION >= 30900
|
||||||
return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "lib");
|
return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "amdgcn/bitcode");
|
||||||
#else
|
#else
|
||||||
return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "hcc/lib");
|
return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "lib");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/platform/rocm_rocdl_path.h"
|
#include "tensorflow/core/platform/rocm_rocdl_path.h"
|
||||||
|
|
||||||
|
#include "rocm/rocm_config.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/path.h"
|
#include "tensorflow/core/platform/path.h"
|
||||||
@ -27,7 +28,11 @@ TEST(RocmRocdlPathTest, ROCDLPath) {
|
|||||||
VLOG(2) << "ROCm-Device-Libs root = " << RocdlRoot();
|
VLOG(2) << "ROCm-Device-Libs root = " << RocdlRoot();
|
||||||
std::vector<string> rocdl_files;
|
std::vector<string> rocdl_files;
|
||||||
TF_EXPECT_OK(Env::Default()->GetMatchingPaths(
|
TF_EXPECT_OK(Env::Default()->GetMatchingPaths(
|
||||||
|
#if TF_ROCM_VERSION >= 30900
|
||||||
|
io::JoinPath(RocdlRoot(), "*.bc"), &rocdl_files));
|
||||||
|
#else
|
||||||
io::JoinPath(RocdlRoot(), "*.amdgcn.bc"), &rocdl_files));
|
io::JoinPath(RocdlRoot(), "*.amdgcn.bc"), &rocdl_files));
|
||||||
|
#endif
|
||||||
EXPECT_LT(0, rocdl_files.size());
|
EXPECT_LT(0, rocdl_files.size());
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -168,25 +168,10 @@ GpuLaunchConfig GetGpuLaunchConfig(int work_element_count,
|
|||||||
block_size_limit);
|
block_size_limit);
|
||||||
CHECK_EQ(err, cudaSuccess);
|
CHECK_EQ(err, cudaSuccess);
|
||||||
#elif TENSORFLOW_USE_ROCM
|
#elif TENSORFLOW_USE_ROCM
|
||||||
#if TENSORFLOW_COMPILER_IS_HIP_CLANG
|
|
||||||
hipError_t err = hipOccupancyMaxPotentialBlockSize(
|
hipError_t err = hipOccupancyMaxPotentialBlockSize(
|
||||||
&block_count, &thread_per_block, func, dynamic_shared_memory_size,
|
&block_count, &thread_per_block, func, dynamic_shared_memory_size,
|
||||||
block_size_limit);
|
block_size_limit);
|
||||||
CHECK_EQ(err, hipSuccess);
|
CHECK_EQ(err, hipSuccess);
|
||||||
#else
|
|
||||||
// Earlier versions of this HIP routine incorrectly returned void.
|
|
||||||
// TODO re-enable hipError_t error checking when HIP is fixed.
|
|
||||||
// ROCm interface uses unsigned int, convert after checking
|
|
||||||
uint32_t block_count_uint = 0;
|
|
||||||
uint32_t thread_per_block_uint = 0;
|
|
||||||
CHECK_GE(block_size_limit, 0);
|
|
||||||
uint32_t block_size_limit_uint = static_cast<uint32_t>(block_size_limit);
|
|
||||||
hipOccupancyMaxPotentialBlockSize(&block_count_uint, &thread_per_block_uint,
|
|
||||||
func, dynamic_shared_memory_size,
|
|
||||||
block_size_limit_uint);
|
|
||||||
block_count = static_cast<int>(block_count_uint);
|
|
||||||
thread_per_block = static_cast<int>(thread_per_block_uint);
|
|
||||||
#endif
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
block_count =
|
block_count =
|
||||||
@ -216,22 +201,9 @@ GpuLaunchConfig GetGpuLaunchConfigFixedBlockSize(
|
|||||||
&block_count, func, fixed_block_size, dynamic_shared_memory_size);
|
&block_count, func, fixed_block_size, dynamic_shared_memory_size);
|
||||||
CHECK_EQ(err, cudaSuccess);
|
CHECK_EQ(err, cudaSuccess);
|
||||||
#elif TENSORFLOW_USE_ROCM
|
#elif TENSORFLOW_USE_ROCM
|
||||||
#if TENSORFLOW_COMPILER_IS_HIP_CLANG
|
|
||||||
hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||||
&block_count, func, fixed_block_size, dynamic_shared_memory_size);
|
&block_count, func, fixed_block_size, dynamic_shared_memory_size);
|
||||||
CHECK_EQ(err, hipSuccess);
|
CHECK_EQ(err, hipSuccess);
|
||||||
#else
|
|
||||||
// Apply the heuristic in GetGpuLaunchConfig(int, const Eigen::GpuDevice&)
|
|
||||||
// that the kernel is quite simple and will largely be memory-limited.
|
|
||||||
const int physical_thread_count = std::min(
|
|
||||||
d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(),
|
|
||||||
work_element_count);
|
|
||||||
// Assume the kernel be simple enough that it is okay to use 1024 threads
|
|
||||||
// per workgroup.
|
|
||||||
int thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock());
|
|
||||||
block_count = std::min(DivUp(physical_thread_count, thread_per_block),
|
|
||||||
d.getNumGpuMultiProcessors());
|
|
||||||
#endif
|
|
||||||
#endif
|
#endif
|
||||||
block_count = std::min(block_count * d.getNumGpuMultiProcessors(),
|
block_count = std::min(block_count * d.getNumGpuMultiProcessors(),
|
||||||
DivUp(work_element_count, fixed_block_size));
|
DivUp(work_element_count, fixed_block_size));
|
||||||
|
@ -856,6 +856,11 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) {
|
|||||||
|
|
||||||
float clock_rate_ghz = static_cast<float>(prop.clockRate) / 1e6;
|
float clock_rate_ghz = static_cast<float>(prop.clockRate) / 1e6;
|
||||||
builder.set_clock_rate_ghz(clock_rate_ghz);
|
builder.set_clock_rate_ghz(clock_rate_ghz);
|
||||||
|
|
||||||
|
// mem_bandwidth = 2 * mem_bus_width_in_bytes * mem_clock_rate_in_hz
|
||||||
|
int64 memory_bandwidth = 2 * (int64(prop.memoryBusWidth) / 8) *
|
||||||
|
(int64(prop.memoryClockRate) * 1000);
|
||||||
|
builder.set_memory_bandwidth(memory_bandwidth);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -3,10 +3,10 @@
|
|||||||
FROM ubuntu:bionic
|
FROM ubuntu:bionic
|
||||||
MAINTAINER Jeff Poznanovic <jeffrey.poznanovic@amd.com>
|
MAINTAINER Jeff Poznanovic <jeffrey.poznanovic@amd.com>
|
||||||
|
|
||||||
ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/3.7/
|
ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/3.9/
|
||||||
ARG ROCM_BUILD_NAME=xenial
|
ARG ROCM_BUILD_NAME=xenial
|
||||||
ARG ROCM_BUILD_NUM=main
|
ARG ROCM_BUILD_NUM=main
|
||||||
ARG ROCM_PATH=/opt/rocm-3.7.0
|
ARG ROCM_PATH=/opt/rocm-3.9.0
|
||||||
|
|
||||||
ENV DEBIAN_FRONTEND noninteractive
|
ENV DEBIAN_FRONTEND noninteractive
|
||||||
ENV TF_NEED_ROCM 1
|
ENV TF_NEED_ROCM 1
|
||||||
|
@ -28,7 +28,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS}
|
|||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
|
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
|
||||||
ROCM_INSTALL_DIR=/opt/rocm-3.7.0
|
ROCM_INSTALL_DIR=/opt/rocm-3.9.0
|
||||||
if [[ -n $1 ]]; then
|
if [[ -n $1 ]]; then
|
||||||
ROCM_INSTALL_DIR=$1
|
ROCM_INSTALL_DIR=$1
|
||||||
fi
|
fi
|
||||||
|
@ -28,7 +28,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS}
|
|||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
|
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
|
||||||
ROCM_INSTALL_DIR=/opt/rocm-3.7.0
|
ROCM_INSTALL_DIR=/opt/rocm-3.9.0
|
||||||
if [[ -n $1 ]]; then
|
if [[ -n $1 ]]; then
|
||||||
ROCM_INSTALL_DIR=$1
|
ROCM_INSTALL_DIR=$1
|
||||||
fi
|
fi
|
||||||
|
@ -28,7 +28,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS}
|
|||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
|
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
|
||||||
ROCM_INSTALL_DIR=/opt/rocm-3.7.0
|
ROCM_INSTALL_DIR=/opt/rocm-3.9.0
|
||||||
if [[ -n $1 ]]; then
|
if [[ -n $1 ]]; then
|
||||||
ROCM_INSTALL_DIR=$1
|
ROCM_INSTALL_DIR=$1
|
||||||
fi
|
fi
|
||||||
|
@ -28,7 +28,7 @@ echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS}
|
|||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
|
# First positional argument (if any) specifies the ROCM_INSTALL_DIR
|
||||||
ROCM_INSTALL_DIR=/opt/rocm-3.7.0
|
ROCM_INSTALL_DIR=/opt/rocm-3.9.0
|
||||||
if [[ -n $1 ]]; then
|
if [[ -n $1 ]]; then
|
||||||
ROCM_INSTALL_DIR=$1
|
ROCM_INSTALL_DIR=$1
|
||||||
fi
|
fi
|
||||||
|
36
third_party/gpus/compress_find_rocm_config.py
vendored
Normal file
36
third_party/gpus/compress_find_rocm_config.py
vendored
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Compresses the contents of 'find_rocm_config.py'.
|
||||||
|
|
||||||
|
The compressed file is what is actually being used. It works around remote
|
||||||
|
config not being able to upload files yet.
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
import zlib
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
with open('find_rocm_config.py', 'rb') as f:
|
||||||
|
data = f.read()
|
||||||
|
|
||||||
|
compressed = zlib.compress(data)
|
||||||
|
b64encoded = base64.b64encode(compressed)
|
||||||
|
|
||||||
|
with open('find_rocm_config.py.gz.base64', 'wb') as f:
|
||||||
|
f.write(b64encoded)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -26,12 +26,9 @@ import pipes
|
|||||||
|
|
||||||
# Template values set by rocm_configure.bzl.
|
# Template values set by rocm_configure.bzl.
|
||||||
CPU_COMPILER = ('%{cpu_compiler}')
|
CPU_COMPILER = ('%{cpu_compiler}')
|
||||||
GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}')
|
|
||||||
|
|
||||||
HIPCC_PATH = '%{hipcc_path}'
|
HIPCC_PATH = '%{hipcc_path}'
|
||||||
PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
|
|
||||||
HIPCC_ENV = '%{hipcc_env}'
|
HIPCC_ENV = '%{hipcc_env}'
|
||||||
HIPCC_IS_HIPCLANG = '%{hipcc_is_hipclang}'=="True"
|
|
||||||
HIP_RUNTIME_PATH = '%{hip_runtime_path}'
|
HIP_RUNTIME_PATH = '%{hip_runtime_path}'
|
||||||
HIP_RUNTIME_LIBRARY = '%{hip_runtime_library}'
|
HIP_RUNTIME_LIBRARY = '%{hip_runtime_library}'
|
||||||
ROCR_RUNTIME_PATH = '%{rocr_runtime_path}'
|
ROCR_RUNTIME_PATH = '%{rocr_runtime_path}'
|
||||||
@ -98,27 +95,6 @@ def GetHostCompilerOptions(argv):
|
|||||||
|
|
||||||
return opts
|
return opts
|
||||||
|
|
||||||
def GetHipccOptions(argv):
|
|
||||||
"""Collect the -hipcc_options values from argv.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
argv: A list of strings, possibly the argv passed to main().
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The string that can be passed directly to hipcc.
|
|
||||||
"""
|
|
||||||
|
|
||||||
parser = ArgumentParser()
|
|
||||||
parser.add_argument('-hipcc_options', nargs='*', action='append')
|
|
||||||
|
|
||||||
args, _ = parser.parse_known_args(argv)
|
|
||||||
|
|
||||||
if args.hipcc_options:
|
|
||||||
options = _update_options(sum(args.hipcc_options, []))
|
|
||||||
return ' '.join(['--'+a for a in options])
|
|
||||||
return ''
|
|
||||||
|
|
||||||
|
|
||||||
def system(cmd):
|
def system(cmd):
|
||||||
"""Invokes cmd with os.system().
|
"""Invokes cmd with os.system().
|
||||||
|
|
||||||
@ -148,7 +124,6 @@ def InvokeHipcc(argv, log=False):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
host_compiler_options = GetHostCompilerOptions(argv)
|
host_compiler_options = GetHostCompilerOptions(argv)
|
||||||
hipcc_compiler_options = GetHipccOptions(argv)
|
|
||||||
opt_option = GetOptionValue(argv, 'O')
|
opt_option = GetOptionValue(argv, 'O')
|
||||||
m_options = GetOptionValue(argv, 'm')
|
m_options = GetOptionValue(argv, 'm')
|
||||||
m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
|
m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
|
||||||
@ -193,14 +168,13 @@ def InvokeHipcc(argv, log=False):
|
|||||||
# Otherwise, we get build error.
|
# Otherwise, we get build error.
|
||||||
# Also we need to retain warning about uninitialised shared variable as
|
# Also we need to retain warning about uninitialised shared variable as
|
||||||
# warning only, even when -Werror option is specified.
|
# warning only, even when -Werror option is specified.
|
||||||
if HIPCC_IS_HIPCLANG:
|
hipccopts += ' --include=hip/hip_runtime.h '
|
||||||
hipccopts += ' --include=hip/hip_runtime.h '
|
|
||||||
hipccopts += ' ' + hipcc_compiler_options
|
|
||||||
# Use -fno-gpu-rdc by default for early GPU kernel finalization
|
# Use -fno-gpu-rdc by default for early GPU kernel finalization
|
||||||
# This flag would trigger GPU kernels be generated at compile time, instead
|
# This flag would trigger GPU kernels be generated at compile time, instead
|
||||||
# of link time. This allows the default host compiler (gcc) be used as the
|
# of link time. This allows the default host compiler (gcc) be used as the
|
||||||
# linker for TensorFlow on ROCm platform.
|
# linker for TensorFlow on ROCm platform.
|
||||||
hipccopts += ' -fno-gpu-rdc '
|
hipccopts += ' -fno-gpu-rdc '
|
||||||
|
hipccopts += ' -fcuda-flush-denormals-to-zero '
|
||||||
hipccopts += undefines
|
hipccopts += undefines
|
||||||
hipccopts += defines
|
hipccopts += defines
|
||||||
hipccopts += std_options
|
hipccopts += std_options
|
||||||
@ -211,22 +185,19 @@ def InvokeHipcc(argv, log=False):
|
|||||||
depfile = depfiles[0]
|
depfile = depfiles[0]
|
||||||
cmd = (HIPCC_PATH + ' ' + hipccopts +
|
cmd = (HIPCC_PATH + ' ' + hipccopts +
|
||||||
host_compiler_options +
|
host_compiler_options +
|
||||||
' ' + GCC_HOST_COMPILER_PATH +
|
|
||||||
' -I .' + includes + ' ' + srcs + ' -M -o ' + depfile)
|
' -I .' + includes + ' ' + srcs + ' -M -o ' + depfile)
|
||||||
|
cmd = HIPCC_ENV.replace(';', ' ') + ' ' + cmd
|
||||||
if log: Log(cmd)
|
if log: Log(cmd)
|
||||||
|
if VERBOSE: print(cmd)
|
||||||
exit_status = os.system(cmd)
|
exit_status = os.system(cmd)
|
||||||
if exit_status != 0:
|
if exit_status != 0:
|
||||||
return exit_status
|
return exit_status
|
||||||
|
|
||||||
cmd = (HIPCC_PATH + ' ' + hipccopts +
|
cmd = (HIPCC_PATH + ' ' + hipccopts +
|
||||||
host_compiler_options + ' -fPIC' +
|
host_compiler_options + ' -fPIC' +
|
||||||
' ' + GCC_HOST_COMPILER_PATH +
|
|
||||||
' -I .' + opt + includes + ' -c ' + srcs + out)
|
' -I .' + opt + includes + ' -c ' + srcs + out)
|
||||||
|
|
||||||
# TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'.
|
cmd = HIPCC_ENV.replace(';', ' ') + ' '\
|
||||||
# Need to investigate and fix.
|
|
||||||
cmd = 'PATH=' + PREFIX_DIR + ':$PATH '\
|
|
||||||
+ HIPCC_ENV.replace(';', ' ') + ' '\
|
|
||||||
+ cmd
|
+ cmd
|
||||||
if log: Log(cmd)
|
if log: Log(cmd)
|
||||||
if VERBOSE: print(cmd)
|
if VERBOSE: print(cmd)
|
||||||
@ -268,8 +239,7 @@ def main():
|
|||||||
gpu_linker_flags.append('-L' + HIP_RUNTIME_PATH)
|
gpu_linker_flags.append('-L' + HIP_RUNTIME_PATH)
|
||||||
gpu_linker_flags.append('-Wl,-rpath=' + HIP_RUNTIME_PATH)
|
gpu_linker_flags.append('-Wl,-rpath=' + HIP_RUNTIME_PATH)
|
||||||
gpu_linker_flags.append('-l' + HIP_RUNTIME_LIBRARY)
|
gpu_linker_flags.append('-l' + HIP_RUNTIME_LIBRARY)
|
||||||
if HIPCC_IS_HIPCLANG:
|
gpu_linker_flags.append("-lrt")
|
||||||
gpu_linker_flags.append("-lrt")
|
|
||||||
|
|
||||||
if VERBOSE: print(' '.join([CPU_COMPILER] + gpu_linker_flags))
|
if VERBOSE: print(' '.join([CPU_COMPILER] + gpu_linker_flags))
|
||||||
return subprocess.call([CPU_COMPILER] + gpu_linker_flags)
|
return subprocess.call([CPU_COMPILER] + gpu_linker_flags)
|
||||||
|
286
third_party/gpus/find_rocm_config.py
vendored
Normal file
286
third_party/gpus/find_rocm_config.py
vendored
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Prints ROCm library and header directories and versions found on the system.
|
||||||
|
|
||||||
|
The script searches for ROCm library and header files on the system, inspects
|
||||||
|
them to determine their version and prints the configuration to stdout.
|
||||||
|
The path to inspect is specified through an environment variable (ROCM_PATH).
|
||||||
|
If no valid configuration is found, the script prints to stderr and
|
||||||
|
returns an error code.
|
||||||
|
|
||||||
|
The script takes the directory specified by the ROCM_PATH environment variable.
|
||||||
|
The script looks for headers and library files in a hard-coded set of
|
||||||
|
subdirectories from base path of the specified directory. If ROCM_PATH is not
|
||||||
|
specified, then "/opt/rocm" is used as it default value
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _get_default_rocm_path():
|
||||||
|
return "/opt/rocm"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rocm_install_path():
|
||||||
|
"""Determines and returns the ROCm installation path."""
|
||||||
|
rocm_install_path = _get_default_rocm_path()
|
||||||
|
if "ROCM_PATH" in os.environ:
|
||||||
|
rocm_install_path = os.environ["ROCM_PATH"]
|
||||||
|
# rocm_install_path = os.path.realpath(rocm_install_path)
|
||||||
|
return rocm_install_path
|
||||||
|
|
||||||
|
|
||||||
|
def _get_composite_version_number(major, minor, patch):
|
||||||
|
return 10000 * major + 100 * minor + patch
|
||||||
|
|
||||||
|
|
||||||
|
def _get_header_version(path, name):
|
||||||
|
"""Returns preprocessor defines in C header file."""
|
||||||
|
for line in io.open(path, "r", encoding="utf-8"):
|
||||||
|
match = re.match(r"#define %s +(\d+)" % name, line)
|
||||||
|
if match:
|
||||||
|
value = match.group(1)
|
||||||
|
return int(value)
|
||||||
|
|
||||||
|
raise ConfigError('#define "{}" is either\n'.format(name) +
|
||||||
|
" not present in file {} OR\n".format(path) +
|
||||||
|
" its value is not an integer literal")
|
||||||
|
|
||||||
|
|
||||||
|
def _find_rocm_config(rocm_install_path):
|
||||||
|
|
||||||
|
def rocm_version_numbers(path):
|
||||||
|
version_file = os.path.join(path, ".info/version-dev")
|
||||||
|
if not os.path.exists(version_file):
|
||||||
|
raise ConfigError('ROCm version file "{}" not found'.format(version_file))
|
||||||
|
version_numbers = []
|
||||||
|
with open(version_file) as f:
|
||||||
|
version_string = f.read().strip()
|
||||||
|
version_numbers = version_string.split(".")
|
||||||
|
major = int(version_numbers[0])
|
||||||
|
minor = int(version_numbers[1])
|
||||||
|
patch = int(version_numbers[2].split("-")[0])
|
||||||
|
return major, minor, patch
|
||||||
|
|
||||||
|
major, minor, patch = rocm_version_numbers(rocm_install_path)
|
||||||
|
|
||||||
|
rocm_config = {
|
||||||
|
"rocm_version_number": _get_composite_version_number(major, minor, patch)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rocm_config
|
||||||
|
|
||||||
|
|
||||||
|
def _find_hipruntime_config(rocm_install_path):
|
||||||
|
|
||||||
|
def hipruntime_version_number(path):
|
||||||
|
version_file = os.path.join(path, "hip/include/hip/hip_version.h")
|
||||||
|
if not os.path.exists(version_file):
|
||||||
|
raise ConfigError(
|
||||||
|
'HIP Runtime version file "{}" not found'.format(version_file))
|
||||||
|
# This header file has an explicit #define for HIP_VERSION, whose value
|
||||||
|
# is (HIP_VERSION_MAJOR * 100 + HIP_VERSION_MINOR)
|
||||||
|
# Retreive the major + minor and re-calculate here, since we do not
|
||||||
|
# want get into the business of parsing arith exprs
|
||||||
|
major = _get_header_version(version_file, "HIP_VERSION_MAJOR")
|
||||||
|
minor = _get_header_version(version_file, "HIP_VERSION_MINOR")
|
||||||
|
return 100 * major + minor
|
||||||
|
|
||||||
|
hipruntime_config = {
|
||||||
|
"hipruntime_version_number": hipruntime_version_number(rocm_install_path)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hipruntime_config
|
||||||
|
|
||||||
|
|
||||||
|
def _find_miopen_config(rocm_install_path):
|
||||||
|
|
||||||
|
def miopen_version_numbers(path):
|
||||||
|
version_file = os.path.join(path, "miopen/include/miopen/version.h")
|
||||||
|
if not os.path.exists(version_file):
|
||||||
|
raise ConfigError(
|
||||||
|
'MIOpen version file "{}" not found'.format(version_file))
|
||||||
|
major = _get_header_version(version_file, "MIOPEN_VERSION_MAJOR")
|
||||||
|
minor = _get_header_version(version_file, "MIOPEN_VERSION_MINOR")
|
||||||
|
patch = _get_header_version(version_file, "MIOPEN_VERSION_PATCH")
|
||||||
|
return major, minor, patch
|
||||||
|
|
||||||
|
major, minor, patch = miopen_version_numbers(rocm_install_path)
|
||||||
|
|
||||||
|
miopen_config = {
|
||||||
|
"miopen_version_number":
|
||||||
|
_get_composite_version_number(major, minor, patch)
|
||||||
|
}
|
||||||
|
|
||||||
|
return miopen_config
|
||||||
|
|
||||||
|
|
||||||
|
def _find_rocblas_config(rocm_install_path):
|
||||||
|
|
||||||
|
def rocblas_version_numbers(path):
|
||||||
|
possible_version_files = [
|
||||||
|
"rocblas/include/rocblas-version.h", # ROCm 3.7 and prior
|
||||||
|
"rocblas/include/internal/rocblas-version.h", # ROCm 3.8
|
||||||
|
]
|
||||||
|
version_file = None
|
||||||
|
for f in possible_version_files:
|
||||||
|
version_file_path = os.path.join(path, f)
|
||||||
|
if os.path.exists(version_file_path):
|
||||||
|
version_file = version_file_path
|
||||||
|
break
|
||||||
|
if not version_file:
|
||||||
|
raise ConfigError(
|
||||||
|
"rocblas version file not found in {}".format(
|
||||||
|
possible_version_files))
|
||||||
|
major = _get_header_version(version_file, "ROCBLAS_VERSION_MAJOR")
|
||||||
|
minor = _get_header_version(version_file, "ROCBLAS_VERSION_MINOR")
|
||||||
|
patch = _get_header_version(version_file, "ROCBLAS_VERSION_PATCH")
|
||||||
|
return major, minor, patch
|
||||||
|
|
||||||
|
major, minor, patch = rocblas_version_numbers(rocm_install_path)
|
||||||
|
|
||||||
|
rocblas_config = {
|
||||||
|
"rocblas_version_number":
|
||||||
|
_get_composite_version_number(major, minor, patch)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rocblas_config
|
||||||
|
|
||||||
|
|
||||||
|
def _find_rocrand_config(rocm_install_path):
|
||||||
|
|
||||||
|
def rocrand_version_number(path):
|
||||||
|
version_file = os.path.join(path, "rocrand/include/rocrand_version.h")
|
||||||
|
if not os.path.exists(version_file):
|
||||||
|
raise ConfigError(
|
||||||
|
'rocblas version file "{}" not found'.format(version_file))
|
||||||
|
version_number = _get_header_version(version_file, "ROCRAND_VERSION")
|
||||||
|
return version_number
|
||||||
|
|
||||||
|
rocrand_config = {
|
||||||
|
"rocrand_version_number": rocrand_version_number(rocm_install_path)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rocrand_config
|
||||||
|
|
||||||
|
|
||||||
|
def _find_rocfft_config(rocm_install_path):
|
||||||
|
|
||||||
|
def rocfft_version_numbers(path):
|
||||||
|
version_file = os.path.join(path, "rocfft/include/rocfft-version.h")
|
||||||
|
if not os.path.exists(version_file):
|
||||||
|
raise ConfigError(
|
||||||
|
'rocfft version file "{}" not found'.format(version_file))
|
||||||
|
major = _get_header_version(version_file, "rocfft_version_major")
|
||||||
|
minor = _get_header_version(version_file, "rocfft_version_minor")
|
||||||
|
patch = _get_header_version(version_file, "rocfft_version_patch")
|
||||||
|
return major, minor, patch
|
||||||
|
|
||||||
|
major, minor, patch = rocfft_version_numbers(rocm_install_path)
|
||||||
|
|
||||||
|
rocfft_config = {
|
||||||
|
"rocfft_version_number":
|
||||||
|
_get_composite_version_number(major, minor, patch)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rocfft_config
|
||||||
|
|
||||||
|
|
||||||
|
def _find_roctracer_config(rocm_install_path):
|
||||||
|
|
||||||
|
def roctracer_version_numbers(path):
|
||||||
|
version_file = os.path.join(path, "roctracer/include/roctracer.h")
|
||||||
|
if not os.path.exists(version_file):
|
||||||
|
raise ConfigError(
|
||||||
|
'roctracer version file "{}" not found'.format(version_file))
|
||||||
|
major = _get_header_version(version_file, "ROCTRACER_VERSION_MAJOR")
|
||||||
|
minor = _get_header_version(version_file, "ROCTRACER_VERSION_MINOR")
|
||||||
|
# roctracer header does not have a patch version number
|
||||||
|
patch = 0
|
||||||
|
return major, minor, patch
|
||||||
|
|
||||||
|
major, minor, patch = roctracer_version_numbers(rocm_install_path)
|
||||||
|
|
||||||
|
roctracer_config = {
|
||||||
|
"roctracer_version_number":
|
||||||
|
_get_composite_version_number(major, minor, patch)
|
||||||
|
}
|
||||||
|
|
||||||
|
return roctracer_config
|
||||||
|
|
||||||
|
|
||||||
|
def _find_hipsparse_config(rocm_install_path):
|
||||||
|
|
||||||
|
def hipsparse_version_numbers(path):
|
||||||
|
version_file = os.path.join(path, "hipsparse/include/hipsparse-version.h")
|
||||||
|
if not os.path.exists(version_file):
|
||||||
|
raise ConfigError(
|
||||||
|
'hipsparse version file "{}" not found'.format(version_file))
|
||||||
|
major = _get_header_version(version_file, "hipsparseVersionMajor")
|
||||||
|
minor = _get_header_version(version_file, "hipsparseVersionMinor")
|
||||||
|
patch = _get_header_version(version_file, "hipsparseVersionPatch")
|
||||||
|
return major, minor, patch
|
||||||
|
|
||||||
|
major, minor, patch = hipsparse_version_numbers(rocm_install_path)
|
||||||
|
|
||||||
|
hipsparse_config = {
|
||||||
|
"hipsparse_version_number":
|
||||||
|
_get_composite_version_number(major, minor, patch)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hipsparse_config
|
||||||
|
|
||||||
|
|
||||||
|
def find_rocm_config():
|
||||||
|
"""Returns a dictionary of ROCm components config info."""
|
||||||
|
rocm_install_path = _get_rocm_install_path()
|
||||||
|
if not os.path.exists(rocm_install_path):
|
||||||
|
raise ConfigError(
|
||||||
|
'Specified ROCM_PATH "{}" does not exist'.format(rocm_install_path))
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
result["rocm_toolkit_path"] = rocm_install_path
|
||||||
|
result.update(_find_rocm_config(rocm_install_path))
|
||||||
|
result.update(_find_hipruntime_config(rocm_install_path))
|
||||||
|
result.update(_find_miopen_config(rocm_install_path))
|
||||||
|
result.update(_find_rocblas_config(rocm_install_path))
|
||||||
|
result.update(_find_rocrand_config(rocm_install_path))
|
||||||
|
result.update(_find_rocfft_config(rocm_install_path))
|
||||||
|
result.update(_find_roctracer_config(rocm_install_path))
|
||||||
|
result.update(_find_hipsparse_config(rocm_install_path))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
try:
|
||||||
|
for key, value in sorted(find_rocm_config().items()):
|
||||||
|
print("%s: %s" % (key, value))
|
||||||
|
except ConfigError as e:
|
||||||
|
sys.stderr.write("\nERROR: {}\n\n".format(str(e)))
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1
third_party/gpus/find_rocm_config.py.gz.base64
vendored
Normal file
1
third_party/gpus/find_rocm_config.py.gz.base64
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
eJy9WtFu27gSfddXEAqKyhtHSXsfdpFFHrxpFvXe1gns7C4WTWDQMm1zK4u6JJXUKPrvd4akZEqWEydOa6CoZQ0PhzNnDkdiDsi5yFeSzxeavD15e0KuF4xcs0wJ+Xsq7kmv0AshVUx6aUqGaKbIkCkm79g0Dg6CA/KBJ2DOpqTIpkwSDeN7OU3gP3enS/5iUnGRkbfxCYnQIHS3ws6vgLASBVnSFcmEJoViAMEVmfGUEfYlYbkmPCOJWOYpp1nCyD3XCzONAwE3yD8OQkw0BWsK9jlczXw7QrVxGD8LrfPT4+P7+/uYGmdjIefHqTVUxx/65xeD0cUROGyG/JmlTCki2f8KLmGpkxWhOfiT0Al4mdJ7IiShc8ngnhbo773kmmfzLlFipu+pZIAy5UpLPil0LVild7Bm3wDCRTMS9kakPwrJb71Rf9QFjL/71+8v/7wmf/eGw97gun8xIpdDcn45eNe/7l8O4Op30hv8Q/7bH7zrEgahgmnYl1yi/+AkxzCa1JERYzUHZsI6pHKW8BlPYF3ZvKBzRubijskMlkNyJpdcYTIVuDcFlJQvuaba/LKxKJzm7EU/QRiGV5JnSMPL8yVMP5FUrtAZsmAU559CihItJGfGR3Jn2QeUEuAgBtascqU0W8ZBgIRXieTAM8WoBC4oE4pt8EhMVUfpQsYxaloF8OMSKTBlGkOVmRBzWTphgHLrP45PRDbj80KaAOI4paei0LHxKqdIdFGCI0NcbpBmCymK+QJJwrI7LkW2ZJkmd1RyQ8oI/P84vupdv+/EQX8GxQX3Uj5tTMldWLp2OTYOpYPGHSalSbVkupAm7QR+ggAlYsrq8dP0M7PrKnOw8jyGosFblV+tfsc+XirEZ5sMG3ubzzInNhGm2hdUTo/QnynkUEPdB6qY+DyYSbEkE6pcUJ0wrH2r/I0JxGrtIoQHVCmoDE2YoCyPRa6PpUiWIZoUKH8UfNGQ9xktUlxPWrAA2RoEUHNCQvpE+U2o8hvogvsGTAqCIEkp1Om5SdEFRjm6MBIIqeqcBgS8V2gGs5DxnOmxm26MroxxaZExs7ny3fQHGWMglaZp6g0CX9+VrLWRLlPu0rYkbpClDo6McYEwXxORnG31D8z5jIRVjENMoVCxYwN60g64tvnkjb4F+4Nt9sZFyWhqpt4w6qxDtXHPDxhuPkJxzcaujMdZsZwwGS3pv0J2CUQM/4NhycKP/5sT+JCfiDEjh3iNV2gNV8bcn8aSvJwjQi+6JKNLVqZn6PIBap6DwyDpAATDTcIgiue+RrnUYPWkqENwn4tY5KxEDmUIO0QGZQPKfhYWenb0S9ix8V+ibxBDyWLzNZLhgZ2IvFLkMLqZHnZC8sp41zX4HTMOMmvsLQqxZQA45sd4DpKVR2867qaLEmhNZOw6AYaOcqhSvwJelzOHX7+ZcrO72k32OobFAXJkYkQOHWz9ExLTV+AGiDoDUTCdxddvsG/eZGEJYeiwHYKDHNq1WEVAFQS/2ZxhdKFoaBp2ylyCs1PLeKu1LcQ7xaWisblVZ5WKnIkJoLtlnF5z+l/BqzTGPJuJY2d4NGV3YZUL9LQcwr5Ad6EiH7BTpqkl6Kbey13LzG7Cj4hmw6iCXwPs1Jx26wG/P92aG6Z1MxSsjULtnFWUcXewFYKO44zMsIKnUSfGX/Ko07BbT1IfGSvodHQUxi4ctgjPLN3qYz+d3DobU5rtNm+cTe5Ko83m7W0561HYqVAdz1vEAlnQ8jMWXhsvWgSsFF/LNBj41YUnbEEIT5+hZ4D3LWgIpZ2txvcFz2WRab5kO7DeM2548DTqA84xz5K0mLJj/A7/SsB48RJl4OnB6/f9KzK0Xj+3MA7g8QoExFNpaF1sS/UFHyeggSjFDnUbZhz/dTEcQWPfJfcLAc7ZvsJiAVLkmYw/9v6AZ4GfzDZzSGp3+oPLYekC7CKS8Tvb/Jdbk2W+3fSPEpomBezy4B2ToO+Km+cuaOuEaYcszj0FPQU6YSUIAzYpFG5FCvurnMLqoX6hr4OaxycQVSvDth3Pjxhkd2NtYb1KnwqBQQhrJfmmtj0bXKToBpf9utrKXaiu7bxu7T38utqYs1ZdS46quUNlOcP9dhQLUlWWu/xuhfWxfwn4z62pJxAKJrq6GOzLqSaKR6tSvZ+OAr3s+fvw+fvFlry37xg1NvncbkUJT71k7b1/1OZudkwTePzZrWkylg+wHPxTHJ4nx37YTSdSLSZ0MBXP3fXRmuhdo5fYCf0n/rl8dgeN2AqBPaHMaPoI1i8G4batHgcis/qOG8AM+9X2pTTbJfyx+eTj1fSsbJugah+o2LEXwhbnNkwrwwn0aJ99WfBNd5KBMpZ1HagkACMBmlAKQaNRb4/R0xUC8vPbh95oX4nYgHmeRjRh9hWJbXWzta/06rHRWrbgvKxO1GdvCoWEWtxNKIzlXk2mA/Flwgd9+f2wtRCe+/S1M9OGvcG7kml1jtUBHTW8DDSo0RJxaI22pOLRvqg+VZMHs5nejQZouF9TZEF8FsDl0fckAeD/gKaoERwz8umK10TBkU8XvAaKGbmf3rXlfavcrdnUoPQmyouL3XruJse1pAmEbCeaO9u9mW5xfLLbX74Lzy30D6A6qNz1sHd+MXyB/b0J5O3wB+tMVGdSgtn3hgsKD97UMbRcsRNWv1hO9iH9FhZs5X2NYQ3qt2K9OPtrHjTfKyl8l7DjayVnu18BVDj+yyX7y/cT/GqKH1AI1VzuZP7j80R/A+Z5qt+EudpT9rfToL0CmhRrvO1pxXrZCmh64Cpg4yCheRBEyZQneBiH56FiZh8vjSMZwxNctx48IHjspK7lUDDYxu62EnyE4K9H1Vnr+nTVcLuSRgNeEXxzDqtWTOHxKmTo2/ryk33ZrYVIP3NtrMPb8iV67VCvHBEX+ZRqFu1yVtPZMmqXN97bxj72Pm/buEffkTww8OFnpgcGPthkPzDuscblgcA+Jvkdf/swEK5klhRk3NBRy9Vp9SLlM1t1ywO8jCghNZtGm+UVQ/kuVdSpJNv8JUQUvlKn5JXCA89ojWT8d3+h5HEeD7PcGw+1UrH9C4oY/x6IReFNdjEcXg5Pgb43mXf8qLSMALBTDYNi0HhSGgRQguMxnnCOx+TsjITjMa5xPDYKZJcb/B9qO76I
|
1
third_party/gpus/rocm/BUILD.tpl
vendored
1
third_party/gpus/rocm/BUILD.tpl
vendored
@ -147,7 +147,6 @@ filegroup(
|
|||||||
name = "rocm_root",
|
name = "rocm_root",
|
||||||
srcs = [
|
srcs = [
|
||||||
"rocm/bin/clang-offload-bundler",
|
"rocm/bin/clang-offload-bundler",
|
||||||
"rocm/bin/bin2c.py",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
4
third_party/gpus/rocm/rocm_config.h.tpl
vendored
4
third_party/gpus/rocm/rocm_config.h.tpl
vendored
@ -18,4 +18,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#define TF_ROCM_TOOLKIT_PATH "%{rocm_toolkit_path}"
|
#define TF_ROCM_TOOLKIT_PATH "%{rocm_toolkit_path}"
|
||||||
|
|
||||||
|
#define TF_ROCM_VERSION %{rocm_version_number}
|
||||||
|
#define TF_MIOPEN_VERSION %{miopen_version_number}
|
||||||
|
#define TF_HIPRUNTIME_VERSION %{hipruntime_version_number}
|
||||||
|
|
||||||
#endif // ROCM_ROCM_CONFIG_H_
|
#endif // ROCM_ROCM_CONFIG_H_
|
||||||
|
152
third_party/gpus/rocm_configure.bzl
vendored
152
third_party/gpus/rocm_configure.bzl
vendored
@ -4,11 +4,7 @@
|
|||||||
|
|
||||||
* `TF_NEED_ROCM`: Whether to enable building with ROCm.
|
* `TF_NEED_ROCM`: Whether to enable building with ROCm.
|
||||||
* `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
|
* `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
|
||||||
* `ROCM_TOOLKIT_PATH`: The path to the ROCm toolkit. Default is
|
* `ROCM_PATH`: The path to the ROCm toolkit. Default is `/opt/rocm`.
|
||||||
`/opt/rocm`.
|
|
||||||
* `TF_ROCM_VERSION`: The version of the ROCm toolkit. If this is blank, then
|
|
||||||
use the system default.
|
|
||||||
* `TF_MIOPEN_VERSION`: The version of the MIOpen library.
|
|
||||||
* `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets.
|
* `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -27,6 +23,7 @@ load(
|
|||||||
"get_bash_bin",
|
"get_bash_bin",
|
||||||
"get_cpu_value",
|
"get_cpu_value",
|
||||||
"get_host_environ",
|
"get_host_environ",
|
||||||
|
"get_python_bin",
|
||||||
"raw_exec",
|
"raw_exec",
|
||||||
"realpath",
|
"realpath",
|
||||||
"which",
|
"which",
|
||||||
@ -35,13 +32,9 @@ load(
|
|||||||
_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
|
_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
|
||||||
_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
|
_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
|
||||||
_ROCM_TOOLKIT_PATH = "ROCM_PATH"
|
_ROCM_TOOLKIT_PATH = "ROCM_PATH"
|
||||||
_TF_ROCM_VERSION = "TF_ROCM_VERSION"
|
|
||||||
_TF_MIOPEN_VERSION = "TF_MIOPEN_VERSION"
|
|
||||||
_TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS"
|
_TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS"
|
||||||
_TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO"
|
_TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO"
|
||||||
|
|
||||||
_DEFAULT_ROCM_VERSION = ""
|
|
||||||
_DEFAULT_MIOPEN_VERSION = ""
|
|
||||||
_DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm"
|
_DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm"
|
||||||
|
|
||||||
def verify_build_defines(params):
|
def verify_build_defines(params):
|
||||||
@ -193,6 +186,7 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin):
|
|||||||
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/9.0.0/include")
|
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/9.0.0/include")
|
||||||
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/10.0.0/include")
|
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/10.0.0/include")
|
||||||
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/11.0.0/include")
|
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/11.0.0/include")
|
||||||
|
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/12.0.0/include")
|
||||||
|
|
||||||
# Support hcc based off clang 10.0.0 (for ROCm 3.3)
|
# Support hcc based off clang 10.0.0 (for ROCm 3.3)
|
||||||
inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
|
inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
|
||||||
@ -212,20 +206,6 @@ def _enable_rocm(repository_ctx):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _rocm_toolkit_path(repository_ctx, bash_bin):
|
|
||||||
"""Finds the rocm toolkit directory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
repository_ctx: The repository context.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A speculative real path of the rocm toolkit install directory.
|
|
||||||
"""
|
|
||||||
rocm_toolkit_path = get_host_environ(repository_ctx, _ROCM_TOOLKIT_PATH, _DEFAULT_ROCM_TOOLKIT_PATH)
|
|
||||||
if files_exist(repository_ctx, [rocm_toolkit_path], bash_bin) != [True]:
|
|
||||||
auto_configure_fail("Cannot find rocm toolkit path.")
|
|
||||||
return rocm_toolkit_path
|
|
||||||
|
|
||||||
def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin):
|
def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin):
|
||||||
"""Returns a list of strings representing AMDGPU targets."""
|
"""Returns a list of strings representing AMDGPU targets."""
|
||||||
amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS)
|
amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS)
|
||||||
@ -236,7 +216,7 @@ def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin):
|
|||||||
amdgpu_targets_str = ",".join(targets)
|
amdgpu_targets_str = ",".join(targets)
|
||||||
amdgpu_targets = amdgpu_targets_str.split(",")
|
amdgpu_targets = amdgpu_targets_str.split(",")
|
||||||
for amdgpu_target in amdgpu_targets:
|
for amdgpu_target in amdgpu_targets:
|
||||||
if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit():
|
if amdgpu_target[:3] != "gfx":
|
||||||
auto_configure_fail("Invalid AMDGPU target: %s" % amdgpu_target)
|
auto_configure_fail("Invalid AMDGPU target: %s" % amdgpu_target)
|
||||||
return amdgpu_targets
|
return amdgpu_targets
|
||||||
|
|
||||||
@ -265,51 +245,6 @@ def _hipcc_env(repository_ctx):
|
|||||||
hipcc_env = (hipcc_env + " " + name + "=\"" + env_value + "\";")
|
hipcc_env = (hipcc_env + " " + name + "=\"" + env_value + "\";")
|
||||||
return hipcc_env.strip()
|
return hipcc_env.strip()
|
||||||
|
|
||||||
def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin):
|
|
||||||
"""Returns if hipcc is based on hip-clang toolchain.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
repository_ctx: The repository context.
|
|
||||||
rocm_config: The path to the hip compiler.
|
|
||||||
bash_bin: the path to the bash interpreter
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string "True" if hipcc is based on hip-clang toolchain.
|
|
||||||
The functions returns "False" if not (ie: based on HIP/HCC toolchain).
|
|
||||||
"""
|
|
||||||
|
|
||||||
# check user-defined hip-clang environment variables
|
|
||||||
for name in ["HIP_CLANG_PATH", "HIP_VDI_HOME"]:
|
|
||||||
if get_host_environ(repository_ctx, name):
|
|
||||||
return "True"
|
|
||||||
|
|
||||||
# grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo
|
|
||||||
cmd = "grep HIP_COMPILER=clang %s/hip/lib/.hipInfo || true" % rocm_config.rocm_toolkit_path
|
|
||||||
grep_result = execute(repository_ctx, [bash_bin, "-c", cmd], empty_stdout_fine = True)
|
|
||||||
result = grep_result.stdout.strip()
|
|
||||||
if result == "HIP_COMPILER=clang":
|
|
||||||
return "True"
|
|
||||||
return "False"
|
|
||||||
|
|
||||||
def _if_hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin, if_true, if_false = []):
|
|
||||||
"""
|
|
||||||
Returns either the if_true or if_false arg based on whether hipcc
|
|
||||||
is based on the hip-clang toolchain
|
|
||||||
|
|
||||||
Args :
|
|
||||||
repository_ctx: The repository context.
|
|
||||||
rocm_config: The path to the hip compiler.
|
|
||||||
if_true : value to return if hipcc is hip-clang based
|
|
||||||
if_false : value to return if hipcc is not hip-clang based
|
|
||||||
(optional, defaults to empty list)
|
|
||||||
|
|
||||||
Returns :
|
|
||||||
either the if_true arg or the of_False arg
|
|
||||||
"""
|
|
||||||
if _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin) == "True":
|
|
||||||
return if_true
|
|
||||||
return if_false
|
|
||||||
|
|
||||||
def _crosstool_verbose(repository_ctx):
|
def _crosstool_verbose(repository_ctx):
|
||||||
"""Returns the environment variable value CROSSTOOL_VERBOSE.
|
"""Returns the environment variable value CROSSTOOL_VERBOSE.
|
||||||
|
|
||||||
@ -402,7 +337,40 @@ def _find_libs(repository_ctx, rocm_config, bash_bin):
|
|||||||
|
|
||||||
return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin)
|
return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin)
|
||||||
|
|
||||||
def _get_rocm_config(repository_ctx, bash_bin):
|
def _exec_find_rocm_config(repository_ctx, script_path):
|
||||||
|
python_bin = get_python_bin(repository_ctx)
|
||||||
|
|
||||||
|
# If used with remote execution then repository_ctx.execute() can't
|
||||||
|
# access files from the source tree. A trick is to read the contents
|
||||||
|
# of the file in Starlark and embed them as part of the command. In
|
||||||
|
# this case the trick is not sufficient as the find_cuda_config.py
|
||||||
|
# script has more than 8192 characters. 8192 is the command length
|
||||||
|
# limit of cmd.exe on Windows. Thus we additionally need to compress
|
||||||
|
# the contents locally and decompress them as part of the execute().
|
||||||
|
compressed_contents = repository_ctx.read(script_path)
|
||||||
|
decompress_and_execute_cmd = (
|
||||||
|
"from zlib import decompress;" +
|
||||||
|
"from base64 import b64decode;" +
|
||||||
|
"from os import system;" +
|
||||||
|
"script = decompress(b64decode('%s'));" % compressed_contents +
|
||||||
|
"f = open('script.py', 'wb');" +
|
||||||
|
"f.write(script);" +
|
||||||
|
"f.close();" +
|
||||||
|
"system('\"%s\" script.py');" % (python_bin)
|
||||||
|
)
|
||||||
|
|
||||||
|
return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])
|
||||||
|
|
||||||
|
def find_rocm_config(repository_ctx, script_path):
|
||||||
|
"""Returns ROCm config dictionary from running find_rocm_config.py"""
|
||||||
|
exec_result = _exec_find_rocm_config(repository_ctx, script_path)
|
||||||
|
if exec_result.return_code:
|
||||||
|
auto_configure_fail("Failed to run find_rocm_config.py: %s" % err_out(exec_result))
|
||||||
|
|
||||||
|
# Parse the dict from stdout.
|
||||||
|
return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()])
|
||||||
|
|
||||||
|
def _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script):
|
||||||
"""Detects and returns information about the ROCm installation on the system.
|
"""Detects and returns information about the ROCm installation on the system.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -413,11 +381,21 @@ def _get_rocm_config(repository_ctx, bash_bin):
|
|||||||
A struct containing the following fields:
|
A struct containing the following fields:
|
||||||
rocm_toolkit_path: The ROCm toolkit installation directory.
|
rocm_toolkit_path: The ROCm toolkit installation directory.
|
||||||
amdgpu_targets: A list of the system's AMDGPU targets.
|
amdgpu_targets: A list of the system's AMDGPU targets.
|
||||||
|
rocm_version_number: The version of ROCm on the system.
|
||||||
|
miopen_version_number: The version of MIOpen on the system.
|
||||||
|
hipruntime_version_number: The version of HIP Runtime on the system.
|
||||||
"""
|
"""
|
||||||
rocm_toolkit_path = _rocm_toolkit_path(repository_ctx, bash_bin)
|
config = find_rocm_config(repository_ctx, find_rocm_config_script)
|
||||||
|
rocm_toolkit_path = config["rocm_toolkit_path"]
|
||||||
|
rocm_version_number = config["rocm_version_number"]
|
||||||
|
miopen_version_number = config["miopen_version_number"]
|
||||||
|
hipruntime_version_number = config["hipruntime_version_number"]
|
||||||
return struct(
|
return struct(
|
||||||
rocm_toolkit_path = rocm_toolkit_path,
|
|
||||||
amdgpu_targets = _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin),
|
amdgpu_targets = _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin),
|
||||||
|
rocm_toolkit_path = rocm_toolkit_path,
|
||||||
|
rocm_version_number = rocm_version_number,
|
||||||
|
miopen_version_number = miopen_version_number,
|
||||||
|
hipruntime_version_number = hipruntime_version_number,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _tpl_path(repository_ctx, labelname):
|
def _tpl_path(repository_ctx, labelname):
|
||||||
@ -550,8 +528,10 @@ def _create_local_rocm_repository(repository_ctx):
|
|||||||
"rocm:rocm_config.h",
|
"rocm:rocm_config.h",
|
||||||
]}
|
]}
|
||||||
|
|
||||||
|
find_rocm_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_rocm_config.py.gz.base64"))
|
||||||
|
|
||||||
bash_bin = get_bash_bin(repository_ctx)
|
bash_bin = get_bash_bin(repository_ctx)
|
||||||
rocm_config = _get_rocm_config(repository_ctx, bash_bin)
|
rocm_config = _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script)
|
||||||
|
|
||||||
# Copy header and library files to execroot.
|
# Copy header and library files to execroot.
|
||||||
# rocm_toolkit_path
|
# rocm_toolkit_path
|
||||||
@ -609,13 +589,7 @@ def _create_local_rocm_repository(repository_ctx):
|
|||||||
outs = rocm_lib_outs,
|
outs = rocm_lib_outs,
|
||||||
))
|
))
|
||||||
|
|
||||||
clang_offload_bundler_path = rocm_toolkit_path + _if_hipcc_is_hipclang(
|
clang_offload_bundler_path = rocm_toolkit_path + "/llvm/bin/clang-offload-bundler"
|
||||||
repository_ctx,
|
|
||||||
rocm_config,
|
|
||||||
bash_bin,
|
|
||||||
"/llvm/bin/",
|
|
||||||
"/hcc/bin/",
|
|
||||||
) + "clang-offload-bundler"
|
|
||||||
|
|
||||||
# copy files mentioned in third_party/gpus/rocm/BUILD
|
# copy files mentioned in third_party/gpus/rocm/BUILD
|
||||||
copy_rules.append(make_copy_files_rule(
|
copy_rules.append(make_copy_files_rule(
|
||||||
@ -688,17 +662,7 @@ def _create_local_rocm_repository(repository_ctx):
|
|||||||
"-DTENSORFLOW_USE_ROCM=1",
|
"-DTENSORFLOW_USE_ROCM=1",
|
||||||
"-D__HIP_PLATFORM_HCC__",
|
"-D__HIP_PLATFORM_HCC__",
|
||||||
"-DEIGEN_USE_HIP",
|
"-DEIGEN_USE_HIP",
|
||||||
] + _if_hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin, [
|
])
|
||||||
#
|
|
||||||
# define "TENSORFLOW_COMPILER_IS_HIP_CLANG" when we are using clang
|
|
||||||
# based hipcc to compile/build tensorflow
|
|
||||||
#
|
|
||||||
# Note that this #define should not be used to check whether or not
|
|
||||||
# tensorflow is being built with ROCm support
|
|
||||||
# (only TENSORFLOW_USE_ROCM should be used for that purpose)
|
|
||||||
#
|
|
||||||
"-DTENSORFLOW_COMPILER_IS_HIP_CLANG=1",
|
|
||||||
]))
|
|
||||||
|
|
||||||
rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
|
rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
|
||||||
|
|
||||||
@ -729,7 +693,6 @@ def _create_local_rocm_repository(repository_ctx):
|
|||||||
"%{cpu_compiler}": str(cc),
|
"%{cpu_compiler}": str(cc),
|
||||||
"%{hipcc_path}": rocm_config.rocm_toolkit_path + "/hip/bin/hipcc",
|
"%{hipcc_path}": rocm_config.rocm_toolkit_path + "/hip/bin/hipcc",
|
||||||
"%{hipcc_env}": _hipcc_env(repository_ctx),
|
"%{hipcc_env}": _hipcc_env(repository_ctx),
|
||||||
"%{hipcc_is_hipclang}": _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin),
|
|
||||||
"%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib",
|
"%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib",
|
||||||
"%{rocr_runtime_library}": "hsa-runtime64",
|
"%{rocr_runtime_library}": "hsa-runtime64",
|
||||||
"%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/hip/lib",
|
"%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/hip/lib",
|
||||||
@ -749,6 +712,9 @@ def _create_local_rocm_repository(repository_ctx):
|
|||||||
["\"%s\"" % c for c in rocm_config.amdgpu_targets],
|
["\"%s\"" % c for c in rocm_config.amdgpu_targets],
|
||||||
),
|
),
|
||||||
"%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path,
|
"%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path,
|
||||||
|
"%{rocm_version_number}": rocm_config.rocm_version_number,
|
||||||
|
"%{miopen_version_number}": rocm_config.miopen_version_number,
|
||||||
|
"%{hipruntime_version_number}": rocm_config.hipruntime_version_number,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -813,8 +779,6 @@ _ENVIRONS = [
|
|||||||
_GCC_HOST_COMPILER_PREFIX,
|
_GCC_HOST_COMPILER_PREFIX,
|
||||||
"TF_NEED_ROCM",
|
"TF_NEED_ROCM",
|
||||||
_ROCM_TOOLKIT_PATH,
|
_ROCM_TOOLKIT_PATH,
|
||||||
_TF_ROCM_VERSION,
|
|
||||||
_TF_MIOPEN_VERSION,
|
|
||||||
_TF_ROCM_AMDGPU_TARGETS,
|
_TF_ROCM_AMDGPU_TARGETS,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user