Patch to enable r2.0 ROCm non-XLA support

The goal of this PR is to patch the TensorFlow r2.0 release so that it fully enables
the ROCm non-XLA path.

Most of the PRs cherry-picked into this patch have already landed on the upstream
master branch.

The following related commits were cherry-picked:

Commits on Aug 20, 2019

- d5a0eee (deven-amd and sunway513): adding/updating ROCm support in the ci_build scripts
- e335575 (deven-amd and sunway513): updating Dockerfile.rocm to pick a specific version of the rocm libra…
- ae83a20 (deven-amd and sunway513): adding a script for testing the ROCm Community Supported Build

Commits on Aug 22, 2019

- 73ff708 (deven-amd and sunway513): Resolve merge conflicts for PR #31393
- 614bdb5 (deven-amd and sunway513): The following PR/commit breaks the --config=rocm build…
- 1685240 (deven-amd and sunway513): updating testcases to work correctly with ROCm
- 3fbb049 (jeffdaily and sunway513): improve concurrency between compute and nccl streams…
- 1d5f440 (whchung and sunway513): [ROCm] enable roll op on ROCm.
- 941f713 (whchung and sunway513): [ROCm] enable InTopK op on ROCm.
- 73ce64e (deven-amd and sunway513): updating README.md with information on ROCm Community Supported Builds

Commits on Aug 25, 2019

- 0832b33 (houtoms and sunway513): fixed potential rocm breaks from use_padded_io
- 7aed626 (deven-amd and sunway513): adding no_rocm tag on unit-tests that check features that are current…
- 82bd216 (deven-amd and sunway513): Adding ROCm support for reduction ops
- 5dba305 (sunway513): Fix ROCm path build error in rocm_dnn.h

Commits on Aug 27, 2019

- be6378c (deven-amd): fixing test failures by skipping parts that functionality not yet sup…
- d98a943 (sunway513): Merge pull request #616 from ROCmSoftwarePlatform/r2.0-rocm-upstream-…
- d05a47f (sunway513): Add no_rocm tag to //tensorflow/python:stateful_random_ops_test_gpu

Commits on Sep 04, 2019

- b1148e4 (sunway513): Merge branch 'r2.0-rocm-upstream' of https://github.com/ROCmSoftwareP…

Commits on Sep 06, 2019

- b908324 (deven-amd and sunway513): adding ROCm support in the build_pip_package script
Authored by Peng Sun on 2019-08-08 14:59:46 +00:00; committed by sunway513.
Commit 4da1ccb9ff (parent 4096702326): 48 changed files with 237 additions and 63 deletions. The combined diff follows.


@ -116,6 +116,8 @@ The TensorFlow project strives to abide by generally accepted best practices in
Build Type | Status | Artifacts
--------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
**Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | [Release](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/)
**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)


@ -83,7 +83,10 @@ void ExecuteWithProfiling(bool async) {
if (!gpu_device_name.empty()) {
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
// device name with "stream:all" is collected by Device Tracer.
#ifndef TENSORFLOW_USE_ROCM
// ROCm platform does not yet support stream level tracing
EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all"));
#endif
}
// "/host:CPU" is collected by TraceMe
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));


@ -9,6 +9,7 @@ tf_cuda_cc_test(
name = "profiler_test",
srcs = ["profiler_test.cc"],
tags = [
"no_rocm", # stream level tracing not supported on ROCm
"nogpu", # b/77649654
],
deps = [


@ -4605,7 +4605,7 @@ tf_cc_test(
size = "small",
srcs = ["common_runtime/constant_folding_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags(),
tags = tf_cuda_tests_tags() + ["no_rocm"],
deps = [
":core",
":core_cpu",
@ -4671,6 +4671,7 @@ tf_cuda_cc_test(
size = "small",
srcs = ["common_runtime/process_function_library_runtime_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
tags = ["no_rocm"],
deps = [
":core_cpu",
":core_cpu_internal",


@ -51,9 +51,11 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
#ifdef GOOGLE_CUDA
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_runtime.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
@ -2089,6 +2091,12 @@ bool IsCUDATensor(const Tensor& t) {
if (err == cudaErrorInvalidValue) return false;
CHECK_EQ(cudaSuccess, err) << cudaGetErrorString(err);
return (attributes.memoryType == cudaMemoryTypeDevice);
#elif TENSORFLOW_USE_ROCM
hipPointerAttribute_t attributes;
hipError_t err = hipPointerGetAttributes(&attributes, t.tensor_data().data());
if (err == hipErrorInvalidValue) return false;
CHECK_EQ(hipSuccess, err) << hipGetErrorString(err);
return (attributes.memoryType == hipMemoryTypeDevice);
#else
return false;
#endif
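
Taken together, the two branches above are the same probe expressed through CUDA and then HIP. As a consolidated sketch of the pattern (hypothetical name IsDeviceTensor; assumes the ROCm 2.x HIP API, where hipPointerAttribute_t still exposes a memoryType field):

// Sketch only: report whether a tensor's backing buffer is device memory.
// An invalid-value error means the runtime does not recognize the pointer,
// i.e. it is plain host memory.
bool IsDeviceTensor(const Tensor& t) {
#if GOOGLE_CUDA
  cudaPointerAttributes attributes;
  cudaError_t err =
      cudaPointerGetAttributes(&attributes, t.tensor_data().data());
  if (err == cudaErrorInvalidValue) return false;
  CHECK_EQ(cudaSuccess, err) << cudaGetErrorString(err);
  return attributes.memoryType == cudaMemoryTypeDevice;
#elif TENSORFLOW_USE_ROCM
  hipPointerAttribute_t attributes;
  hipError_t err =
      hipPointerGetAttributes(&attributes, t.tensor_data().data());
  if (err == hipErrorInvalidValue) return false;
  CHECK_EQ(hipSuccess, err) << hipGetErrorString(err);
  return attributes.memoryType == hipMemoryTypeDevice;
#else
  return false;  // CPU-only build: nothing lives on a device.
#endif
}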


@ -33,9 +33,11 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#ifdef GOOGLE_CUDA
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_runtime.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
@ -122,7 +124,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
}
Tensor GPUToCPU(const Tensor& device_tensor) {
#ifdef GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
CHECK(gpu_device_);
CHECK(gpu_device_->tensorflow_gpu_device_info() != nullptr);
DeviceContext* device_context =
@ -146,7 +148,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
}
Tensor CPUToGPU(const Tensor& cpu_tensor) {
#ifdef GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
CHECK(gpu_device_);
CHECK(gpu_device_->tensorflow_gpu_device_info() != nullptr);
DeviceContext* device_context =
@ -461,6 +463,12 @@ bool IsCUDATensor(const Tensor& t) {
if (err == cudaErrorInvalidValue) return false;
CHECK_EQ(cudaSuccess, err) << cudaGetErrorString(err);
return (attributes.memoryType == cudaMemoryTypeDevice);
#elif TENSORFLOW_USE_ROCM
hipPointerAttribute_t attributes;
hipError_t err = hipPointerGetAttributes(&attributes, t.tensor_data().data());
if (err == hipErrorInvalidValue) return false;
CHECK_EQ(hipSuccess, err) << hipGetErrorString(err);
return (attributes.memoryType == hipMemoryTypeDevice);
#else
CHECK(false)
<< "IsCUDATensor should not be called when CUDA is not available";


@ -40,6 +40,18 @@ TEST(UtilsTest, GetLocalGPUInfo) {
properties = GetLocalGPUInfo(PlatformGpuId(0));
EXPECT_EQ("GPU", properties.type());
EXPECT_EQ("NVIDIA", properties.vendor());
#elif TENSORFLOW_USE_ROCM
LOG(INFO) << "ROCm is enabled.";
DeviceProperties properties;
// Invalid platform GPU ID.
properties = GetLocalGPUInfo(PlatformGpuId(100));
EXPECT_EQ("UNKNOWN", properties.type());
// Succeed when a valid platform GPU id was inserted.
properties = GetLocalGPUInfo(PlatformGpuId(0));
EXPECT_EQ("GPU", properties.type());
EXPECT_EQ("Advanced Micro Devices, Inc", properties.vendor());
#else
LOG(INFO) << "CUDA is not enabled.";
DeviceProperties properties;
@ -73,6 +85,8 @@ TEST(UtilsTest, GetDeviceInfo) {
EXPECT_EQ("GPU", properties.type());
#if GOOGLE_CUDA
EXPECT_EQ("NVIDIA", properties.vendor());
#elif TENSORFLOW_USE_ROCM
EXPECT_EQ("Advanced Micro Devices, Inc", properties.vendor());
#endif
// TF to platform GPU id mapping entry doesn't exist.
@ -81,7 +95,7 @@ TEST(UtilsTest, GetDeviceInfo) {
properties = GetDeviceInfo(device);
EXPECT_EQ("UNKNOWN", properties.type());
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Invalid platform GPU id.
TF_ASSERT_OK(
GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(0), PlatformGpuId(100)));
@ -94,7 +108,11 @@ TEST(UtilsTest, GetDeviceInfo) {
device.id = 1;
properties = GetDeviceInfo(device);
EXPECT_EQ("GPU", properties.type());
#if GOOGLE_CUDA
EXPECT_EQ("NVIDIA", properties.vendor());
#elif TENSORFLOW_USE_ROCM
EXPECT_EQ("Advanced Micro Devices, Inc", properties.vendor());
#endif
#endif
}


@ -203,7 +203,7 @@ TEST_F(PinToHostOptimizerTest, Identity) {
// If CUDA or ROCm, then there is a GPU kernel registration that is pinned
// to Host memory. Consequently, `b` will be mapped to Host correctly if
// there is a GPU kernel registered.
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
EXPECT_EQ(node.device(), "/device:CPU:0");
#else
EXPECT_TRUE(node.device().empty());
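
The comment above is the crux of this hunk: the optimizer only maps `b` to "/device:CPU:0" when the op has a GPU kernel registered with host-memory constraints, and with this patch such registrations now exist under ROCm too, so the expectation applies to both platforms. A hedged sketch of what such a registration looks like (hypothetical op and class names, not lines from this patch):

// Hypothetical registration, for illustration only: a DEVICE_GPU kernel
// whose input and output are constrained to host memory. Registrations of
// this shape are what let the pin-to-host optimizer place a node on
// "/device:CPU:0" on CUDA and ROCm builds alike.
REGISTER_KERNEL_BUILDER(Name("ExampleOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ExampleOp<int32>);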


@ -970,6 +970,15 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
#if GOOGLE_CUDA
se::TfAllocatorAdapter tf_allocator_adapter(
stream->parent()->platform(), ctx->device()->GetAllocator({}));
se::cuda::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
se::cuda::PtxCompilationOptions());
se::DeviceMemory<T> filter_backprop_ptr_rz(
WrapRedzoneBestEffort(&rz_allocator, filter_backprop_ptr));
std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),


@ -1096,6 +1096,16 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
#if GOOGLE_CUDA
se::TfAllocatorAdapter tf_allocator_adapter(
stream->parent()->platform(), ctx->device()->GetAllocator({}));
se::cuda::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
se::cuda::PtxCompilationOptions());
se::DeviceMemory<T> in_backprop_ptr_rz(
WrapRedzoneBestEffort(&rz_allocator, in_backprop_ptr));
std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),


@ -1001,6 +1001,10 @@ class FusedConv2DWithBatchNormOpTest : public FusedConv2DOpTest<T> {};
TYPED_TEST_SUITE_P(FusedConv2DWithBiasOpTest);
TYPED_TEST_SUITE_P(FusedConv2DWithBatchNormOpTest);
// ROCm does not yet support the _FusedConv2D op, so disable the tests
// that exercise _FusedConv2D when building with ROCm.
#ifndef TENSORFLOW_USE_ROCM
// -------------------------------------------------------------------------- //
// Conv2D + BiasAdd + {Activation} //
// -------------------------------------------------------------------------- //
@ -1165,4 +1169,5 @@ using FusedBatchNormDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedConv2DWithBatchNormOpTest,
FusedBatchNormDataTypes);
#endif // TENSORFLOW_USE_ROCM
} // namespace tensorflow


@ -116,7 +116,7 @@ REGISTER_KERNEL_BUILDER(Name("InTopKV2")
.TypeConstraint<int64>("T"),
InTopK<CPUDevice, float, int64>);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
@ -142,6 +142,6 @@ REGISTER_KERNEL_BUILDER(
Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint<int64>("T"),
InTopK<GPUDevice, float, int64>);
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow


@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
#define TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@ -41,7 +41,7 @@ __global__ void ComputePredictionMaskKernel(
const TargetT* targets, // dims: [ num_targets ]
int64* mask, // dims: [ num_targets x num_classes ]
int num_targets, int num_classes) {
CUDA_1D_KERNEL_LOOP(i, num_targets * num_classes) {
GPU_1D_KERNEL_LOOP(i, num_targets * num_classes) {
const int batch_index = i / num_classes;
TargetT target_idx = ldg(targets + batch_index);
@ -118,7 +118,7 @@ struct InTopKFunctor<GPUDevice, T, TargetT> {
const auto& d = context->eigen_device<GPUDevice>();
// Compute a mask for all predictions.
CudaLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
GpuLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
OP_REQUIRES_OK(
context, GpuLaunchKernel(ComputePredictionMaskKernel<T, TargetT>,
config.block_count, config.thread_per_block, 0,
@ -173,4 +173,4 @@ DEFINE_GPU_KERNELS(float, int64);
} // end namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
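
The renames in this file are the standard recipe for making a .cu.cc kernel portable: guard with GOOGLE_CUDA || TENSORFLOW_USE_ROCM and switch from the CUDA_*-prefixed launch helpers to their GPU-neutral twins. A minimal sketch of the recipe applied to a hypothetical element-wise kernel (not part of the patch):

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU

// Hypothetical kernel, shown only to illustrate the platform-neutral
// macros; the same source now compiles under both NVCC and HIP.
template <typename T>
__global__ void ScaleKernel(const int nthreads, const T* in, T scale,
                            T* out) {
  GPU_1D_KERNEL_LOOP(i, nthreads) {  // was CUDA_1D_KERNEL_LOOP
    out[i] = in[i] * scale;
  }
}

template <typename T>
void LaunchScale(OpKernelContext* context, const T* in, T scale, T* out,
                 int nthreads) {
  const auto& d = context->eigen_device<Eigen::GpuDevice>();
  GpuLaunchConfig config = GetGpuLaunchConfig(nthreads, d);  // was CudaLaunchConfig
  OP_REQUIRES_OK(context,
                 GpuLaunchKernel(ScaleKernel<T>, config.block_count,
                                 config.thread_per_block, 0, d.stream(),
                                 nthreads, in, scale, out));
}
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM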


@ -117,6 +117,10 @@ struct Identity {
FIX_MEAN_IDENTITY(Eigen::half)
FIX_MEAN_IDENTITY(float)
FIX_MEAN_IDENTITY(double)
#if GOOGLE_CUDA
FIX_MEAN_IDENTITY(complex64)
FIX_MEAN_IDENTITY(complex128)
#endif
#undef FIX_MEAN_IDENTITY
template <typename Device, typename OUT_T, typename Reducer>


@ -30,7 +30,7 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, bool, int64, Eigen::internal::AndReducer>);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
Name("All")
.TypeConstraint<int32>("Tidx")


@ -30,7 +30,7 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, bool, int64, Eigen::internal::OrReducer>);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
Name("Any")
.TypeConstraint<int32>("Tidx")


@ -15,8 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_
#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_
#if !GOOGLE_CUDA
#error This file must only be included when building with Cuda support
#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
#error This file must only be included when building with GPU support
#endif
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"


@ -33,7 +33,7 @@ namespace tensorflow {
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("EuclideanNorm") \
@ -51,8 +51,10 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
ReductionOp<GPUDevice, type, int64, \
functor::EuclideanNormReducer<type>>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#if GOOGLE_CUDA
TF_CALL_complex64(REGISTER_GPU_KERNELS);
TF_CALL_complex128(REGISTER_GPU_KERNELS);
#endif
#undef REGISTER_GPU_KERNELS
#endif


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@ -59,4 +59,4 @@ DEFINE_FOR_TYPE_AND_R(bool, Eigen::internal::OrReducer);
} // end namespace functor
} // end namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@ -67,4 +67,4 @@ DEFINE_FOR_ALL_REDUCERS(double);
} // end namespace functor
} // end namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@ -67,4 +67,4 @@ DEFINE_FOR_ALL_REDUCERS(float);
} // end namespace functor
} // end namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@ -68,4 +68,4 @@ DEFINE_FOR_ALL_REDUCERS(int64);
} // end namespace functor
} // end namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@ -64,4 +64,4 @@ DEFINE_FOR_ALL_REDUCERS(Eigen::half);
} // end namespace functor
} // end namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@ -64,4 +64,4 @@ DEFINE_FOR_ALL_REDUCERS(Eigen::half);
} // end namespace functor
} // end namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -33,7 +33,7 @@ namespace tensorflow {
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \


@ -33,7 +33,7 @@ namespace tensorflow {
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
@ -51,8 +51,10 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, functor::MeanReducer<type>>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#if GOOGLE_CUDA
TF_CALL_complex64(REGISTER_GPU_KERNELS);
TF_CALL_complex128(REGISTER_GPU_KERNELS);
#endif
#undef REGISTER_GPU_KERNELS
#endif


@ -33,7 +33,7 @@ namespace tensorflow {
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \


@ -33,7 +33,7 @@ namespace tensorflow {
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("Prod") \
@ -52,8 +52,10 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
Eigen::internal::ProdReducer<type>>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int32(REGISTER_GPU_KERNELS);
#if GOOGLE_CUDA
TF_CALL_complex64(REGISTER_GPU_KERNELS);
TF_CALL_complex128(REGISTER_GPU_KERNELS);
#endif
#undef REGISTER_GPU_KERNELS
#endif


@ -33,7 +33,7 @@ namespace tensorflow {
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
@ -52,8 +52,10 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
ReductionOp<GPUDevice, type, int64, Eigen::internal::SumReducer<type>>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#if GOOGLE_CUDA
TF_CALL_complex64(REGISTER_GPU_KERNELS);
TF_CALL_complex128(REGISTER_GPU_KERNELS);
#endif
#undef REGISTER_GPU_KERNELS
// A special GPU kernel for int32.


@ -360,7 +360,7 @@ struct Roll<CPUDevice, T> {
TF_CALL_ALL_TYPES(REGISTER_CPU);
#undef REGISTER_CPU
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Roll") \
.Device(DEVICE_GPU) \
@ -402,5 +402,5 @@ TF_CALL_complex64(REGISTER_KERNEL);
TF_CALL_complex128(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@ -33,7 +33,7 @@ template <typename T>
__global__ void RollKernel(const int32 nthreads, const int32 num_dims,
const T* input, T* output, const int32* dim_size,
const int32* threshold, const int64* dim_range) {
CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
int64 offset = 0;
for (int i = 0; i < num_dims; i++) {
const int64 stride = dim_range[i] / dim_size[i];
@ -71,7 +71,7 @@ struct Roll<GPUDevice, T> {
d.memcpyHostToDevice(thres_buf, threshold.data(), thres_bytes);
d.memcpyHostToDevice(range_buf, dim_range.data(), range_bytes);
CudaLaunchConfig cfg = GetGpuLaunchConfig(num_elements, d);
GpuLaunchConfig cfg = GetGpuLaunchConfig(num_elements, d);
TF_CHECK_OK(GpuLaunchKernel(RollKernel<T>, cfg.block_count,
cfg.thread_per_block, 0, d.stream(),
@ -98,4 +98,4 @@ TF_CALL_complex128(DEFINE_GPU_SPECS);
} // namespace functor
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -29,6 +29,7 @@ cc_library(
copts = tf_copts(),
deps = if_cuda([
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@local_config_nccl//:nccl",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",


@ -533,7 +533,7 @@ void NcclManager::RunCollective(Collective* collective) {
// Wait to ensure that the kernel that produces the data in the input
// tensor has finished running before the nccl kernel runs on the
// communication stream.
nccl_stream->stream->ThenWaitFor(p->tensor_stream);
nccl_stream->stream->ThenWaitFor(p->input_event.get());
}
if (p->root) {
CHECK_EQ(collective->root_rank, -1);


@ -27,6 +27,7 @@ limitations under the License.
#endif
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "third_party/nccl/nccl.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/tensor.h"
@ -63,6 +64,7 @@ class NcclManager {
event_mgr(event_mgr),
gpu_device_id(gpu_device_id),
input(input),
input_event(nullptr),
output(output),
global_rank(global_rank),
done_callback(std::move(done_callback)),
@ -70,6 +72,11 @@ class NcclManager {
DCHECK(executor != nullptr);
DCHECK(event_mgr != nullptr);
DCHECK(tensor_stream != nullptr);
if (input != nullptr) {
input_event = absl::make_unique<se::Event>(executor);
input_event->Init();
tensor_stream->ThenRecordEvent(input_event.get());
}
}
// StreamExecutor for the device. Expected to be live for process lifetime.
@ -94,6 +101,10 @@ class NcclManager {
// called. Is NULL for participants that only receive data.
const Tensor* input;
// Wait on this event rather than synchronizing on the entire stream.
// This allows greater concurrency between compute and nccl streams.
std::unique_ptr<se::Event> input_event;
// Owned by the caller, who must keep it live until `done_callback` is
// called. Is NULL for participants that only send data.
Tensor* output;
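
The reasoning behind 3fbb049: waiting on an event recorded at participant-creation time orders the NCCL kernel after only the producer work enqueued on the compute stream so far, whereas the old ThenWaitFor(p->tensor_stream) also serialized it against anything enqueued on that stream later. A condensed sketch of the two halves of the handshake, using the names from the hunks above:

// Producer side (Participant constructor): snapshot the point in the
// compute stream after which the input tensor is guaranteed to be ready.
input_event = absl::make_unique<se::Event>(executor);
input_event->Init();
tensor_stream->ThenRecordEvent(input_event.get());

// Consumer side (RunCollective): the communication stream waits only for
// that snapshot, not for everything later enqueued on tensor_stream.
nccl_stream->stream->ThenWaitFor(p->input_event.get());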


@ -2501,6 +2501,7 @@ cuda_py_test(
],
tags = [
"no_cuda_on_cpu_tap",
"no_rocm",
"no_windows",
],
)
@ -3431,6 +3432,7 @@ cuda_py_test(
"//tensorflow/python/kernel_tests/random:util",
"//tensorflow/python/distribute:mirrored_strategy",
],
tags = ["no_rocm"],
xla_enable_strict_auto_jit = False,
)


@ -61,7 +61,7 @@ class Conv3DTest(test.TestCase):
# as we will be using its gradients as reference for fp16 gradients.
return optional_float64 + [dtypes.float32, dtypes.float16]
else:
return [dtypes.float64, dtypes.float32, dtypes.float16]
return [dtypes.float32, dtypes.float16] + ([dtypes.float64] if not test.is_built_with_rocm() else [])
def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, stride,
padding, data_format, dtype, use_gpu):


@ -753,6 +753,13 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
else:
shape = [4, 16, 16, 16, 64]
convolution = convolutional.conv3d
if test.is_built_with_rocm():
  # This subtest triggers a known bug in the ROCm runtime code.
  # The bug has been fixed and will be available in ROCm 2.7;
  # re-enable this test once ROCm 2.7 is released.
  continue
inputs = random_ops.random_normal(shape, dtype=dtype)
inputs_2norm = linalg_ops.norm(inputs)
outputs = convolution(


@ -766,7 +766,7 @@ class PoolingTest(test.TestCase):
# The following are tests that verify that the CPU and GPU implementations
# produce the same results.
def _CompareMaxPoolingFwd(self, input_shape, ksize, strides, padding):
for dtype in np.float64, np.float32, np.float16:
for dtype in [np.float32, np.float16] + ([np.float64] if not test.is_built_with_rocm() else []):
tensor_input = np.random.rand(*input_shape).astype(dtype)
with self.cached_session(use_gpu=True):
t = constant_op.constant(tensor_input, shape=input_shape)
@ -780,7 +780,7 @@ class PoolingTest(test.TestCase):
def _CompareMaxPoolingBk(self, input_shape, output_shape, ksize, strides,
padding):
for dtype in np.float64, np.float32, np.float16:
for dtype in [np.float32, np.float16] + ([np.float64] if not test.is_built_with_rocm() else []):
# Generate numbers in a narrow range, so that there are many duplicates
# in the input.
tensor_input = np.random.random_integers(0, 3, input_shape).astype(dtype)
@ -810,7 +810,7 @@ class PoolingTest(test.TestCase):
def _CompareMaxPoolingGradBk(self, input_shape, output_shape, ksize, strides,
padding):
for dtype in np.float64, np.float32, np.float16:
for dtype in [np.float32, np.float16] + ([np.float64] if not test.is_built_with_rocm() else []):
# Generate numbers in a narrow range, so that there are many duplicates
# in the input.
tensor_input = np.random.random_integers(0, 3, input_shape).astype(dtype)


@ -2280,7 +2280,8 @@ MIOpenSupport::createRnnDescriptor(
int batch_size, dnn::RnnInputMode input_mode,
dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
float dropout, uint64 seed, ScratchAllocator* state_allocator) {
float dropout, uint64 seed, ScratchAllocator* state_allocator,
bool use_padded_io) {
// ROCM TODO: cell_size is ignored for now
// ROCM TODO: batch_size is ignored for now


@ -50,7 +50,8 @@ class MIOpenSupport : public dnn::DnnSupport {
int batch_size, dnn::RnnInputMode input_mode,
dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
float dropout, uint64 seed, ScratchAllocator* state_allocator) override;
float dropout, uint64 seed, ScratchAllocator* state_allocator,
bool use_padded_io) override;
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
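
For context on why 0832b33 was needed: upstream added a trailing use_padded_io parameter to createRnnDescriptor in the common dnn::DnnSupport interface, so the MIOpen override must grow the same parameter (even though MIOpen ignores it for now) or the ROCm build breaks. A toy model of that failure mode, with stand-in types rather than the real stream_executor ones:

#include <memory>

struct RnnDescriptor {};

struct DnnSupport {
  virtual ~DnnSupport() = default;
  // After the upstream change: trailing use_padded_io parameter.
  virtual std::unique_ptr<RnnDescriptor> createRnnDescriptor(
      float dropout, unsigned long long seed, bool use_padded_io) {
    return nullptr;
  }
};

struct MIOpenSupport : DnnSupport {
  // Without `bool use_padded_io` here, `override` no longer matches any
  // virtual in the base class and compilation fails -- the ROCm build
  // break that the cherry-picked fix addresses.
  std::unique_ptr<RnnDescriptor> createRnnDescriptor(
      float dropout, unsigned long long seed, bool use_padded_io) override {
    return std::make_unique<RnnDescriptor>();  // the flag is ignored for now
  }
};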


@ -1,15 +1,15 @@
# This Dockerfile provides a starting point for a ROCm installation of
# MIOpen and tensorflow.
FROM ubuntu:xenial
MAINTAINER Jeff Poznanovic <jeffrey.poznanovic@amd.com>
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/debian/
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/2.6/
ARG ROCM_PATH=/opt/rocm
ENV DEBIAN_FRONTEND noninteractive
ENV TF_NEED_ROCM 1
ENV HOME /root/
RUN apt update && apt install -y wget software-properties-common
# Add rocm repository
RUN apt-get clean all


@ -131,8 +131,8 @@ echo "Using Bazel flags: ${BAZEL_FLAGS}"
PIP_BUILD_TARGET="//tensorflow/tools/pip_package:build_pip_package"
GPU_FLAG=""
ROCM_FLAG=""
if [[ ${CONTAINER_TYPE} == "cpu" ]] || \
[[ ${CONTAINER_TYPE} == "rocm" ]] || \
[[ ${CONTAINER_TYPE} == "debian.jessie.cpu" ]]; then
bazel build ${BAZEL_FLAGS} ${PIP_BUILD_TARGET} || \
die "Build failed."
@ -140,6 +140,10 @@ elif [[ ${CONTAINER_TYPE} == "gpu" ]]; then
bazel build ${BAZEL_FLAGS} ${PIP_BUILD_TARGET} || \
die "Build failed."
GPU_FLAG="--gpu"
elif [[ ${CONTAINER_TYPE} == "rocm" ]]; then
bazel build ${BAZEL_FLAGS} ${PIP_BUILD_TARGET} || \
die "Build failed."
ROCM_FLAG="--rocm"
else
die "Unrecognized container type: \"${CONTAINER_TYPE}\""
fi
@ -193,7 +197,7 @@ fi
PIP_WHL_DIR="${PIP_TEST_ROOT}/whl"
PIP_WHL_DIR=$(realpath ${PIP_WHL_DIR}) # Get absolute path
rm -rf ${PIP_WHL_DIR} && mkdir -p ${PIP_WHL_DIR}
bazel-bin/tensorflow/tools/pip_package/build_pip_package ${PIP_WHL_DIR} ${GPU_FLAG} ${NIGHTLY_FLAG} || \
bazel-bin/tensorflow/tools/pip_package/build_pip_package ${PIP_WHL_DIR} ${GPU_FLAG} ${ROCM_FLAG} ${NIGHTLY_FLAG} || \
die "build_pip_package FAILED"
WHL_PATH=$(ls ${PIP_WHL_DIR}/${PROJECT_NAME}*.whl)
@ -406,7 +410,7 @@ do_virtualenv_pip_test() {
return ${SKIP_RETURN_CODE}
else
# Call run_pip_tests.sh to perform test-on-install
"${SCRIPT_DIR}/run_pip_tests.sh" --virtualenv ${GPU_FLAG} ${MAC_FLAG}
"${SCRIPT_DIR}/run_pip_tests.sh" --virtualenv ${GPU_FLAG} ${ROCM_FLAG} ${MAC_FLAG}
if [[ $? != 0 ]]; then
echo "PIP tests-on-install FAILED"
return 1
@ -426,7 +430,7 @@ do_virtualenv_oss_serial_pip_test() {
else
# Call run_pip_tests.sh to perform test-on-install
"${SCRIPT_DIR}/run_pip_tests.sh" \
--virtualenv ${GPU_FLAG} ${MAC_FLAG} --oss_serial
--virtualenv ${GPU_FLAG} ${ROCM_FLAG} ${MAC_FLAG} --oss_serial
if [[ $? != 0 ]]; then
echo "PIP tests-on-install (oss_serial) FAILED"
return 1
@ -439,7 +443,7 @@ do_virtualenv_oss_serial_pip_test() {
################################################################################
do_test_user_ops() {
if [[ "${DO_TEST_USER_OPS}" == "1" ]]; then
"${SCRIPT_DIR}/test_user_ops.sh" --virtualenv ${GPU_FLAG}
"${SCRIPT_DIR}/test_user_ops.sh" --virtualenv ${GPU_FLAG} ${ROCM_FLAG}
if [[ $? != 0 ]]; then
echo "PIP user-op tests-on-install FAILED"
return 1


@ -28,6 +28,7 @@ echo ""
export PYTHON_BIN_PATH=`which python3`
export CC_OPT_FLAGS='-mavx'
export TF_NEED_ROCM=0
export TF_NEED_CUDA=1
export TF_CUDA_COMPUTE_CAPABILITIES=3.7


@ -28,6 +28,7 @@ echo ""
export PYTHON_BIN_PATH=`which python3`
export CC_OPT_FLAGS='-mavx'
export TF_NEED_ROCM=0
export TF_NEED_CUDA=1
export TF_CUDA_COMPUTE_CAPABILITIES=3.7


@ -0,0 +1,56 @@
#!/usr/bin/env bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
set -e
set -x
N_JOBS=$(grep -c ^processor /proc/cpuinfo)
N_GPUS=$(lspci|grep 'VGA'|grep 'AMD/ATI'|wc -l)
echo ""
echo "Bazel will use ${N_JOBS} concurrent build job(s) and ${N_GPUS} concurrent test job(s)."
echo ""
# Run configure.
export PYTHON_BIN_PATH=`which python3`
export CC_OPT_FLAGS='-mavx'
export TF_NEED_ROCM=1
export TF_GPU_COUNT=${N_GPUS}
yes "" | $PYTHON_BIN_PATH configure.py
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test \
--config=rocm \
-k \
--test_tag_filters=gpu,-no_gpu,-no_rocm,-benchmark-test,-no_oss,-oss_serial \
--test_timeout 600,900,2400,7200 \
--test_output=errors \
--jobs=${N_JOBS} \
--local_test_jobs=${TF_GPU_COUNT} \
--test_sharding_strategy=disabled \
--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
-- \
//tensorflow/... \
-//tensorflow/compiler/... \
-//tensorflow/contrib/... \
-//tensorflow/lite/... \
-//tensorflow/python/compiler/tensorrt/... \


@ -19,23 +19,27 @@ set -e
set -x
N_JOBS=$(grep -c ^processor /proc/cpuinfo)
N_GPUS=$(lspci|grep 'VGA'|grep 'AMD/ATI'|wc -l)
echo ""
echo "Bazel will use ${N_JOBS} concurrent job(s)."
echo "Bazel will use ${N_JOBS} concurrent build job(s) and ${N_GPUS} concurrent test job(s)."
echo ""
# Run configure.
export PYTHON_BIN_PATH=`which python3`
export TF_NEED_ROCM=1
export TF_GPU_COUNT=${N_GPUS}
yes "" | $PYTHON_BIN_PATH configure.py
echo "build --distinct_host_configuration=false" >> .tf_configure.bazelrc
bazel clean
# Run bazel test command. Double test timeouts to avoid flakes.
bazel test --config=rocm --test_tag_filters=-no_gpu,-benchmark-test,-no_oss -k \
--jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
--build_tests_only --test_output=errors --local_test_jobs=1 \
bazel test --config=rocm --test_tag_filters=-no_gpu,-benchmark-test,-no_oss,-no_rocm -k \
--jobs=${N_JOBS} --test_timeout 600,900,2400,7200 \
--build_tests_only --test_output=errors --local_test_jobs=${TF_GPU_COUNT} \
--test_sharding_strategy=disabled \
--run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \
--config=xla -- \
//tensorflow/compiler/...


@ -229,6 +229,7 @@ function usage() {
echo " --project_name <name> set project name to name"
echo " --gpu build tensorflow_gpu"
echo " --gpudirect build tensorflow_gpudirect"
echo " --rocm build tensorflow_rocm"
echo " --nightly_flag build tensorflow nightly"
echo ""
exit 1
@ -238,6 +239,7 @@ function main() {
PKG_NAME_FLAG=""
PROJECT_NAME=""
GPU_BUILD=0
ROCM_BUILD=0
NIGHTLY_BUILD=0
SRCDIR=""
DSTDIR=""
@ -252,6 +254,8 @@ function main() {
GPU_BUILD=1
elif [[ "$1" == "--gpudirect" ]]; then
PKG_NAME_FLAG="--project_name tensorflow_gpudirect"
elif [[ "$1" == "--rocm" ]]; then
ROCM_BUILD=1
elif [[ "$1" == "--project_name" ]]; then
shift
if [[ -z "$1" ]]; then
@ -297,10 +301,14 @@ function main() {
PKG_NAME_FLAG="--project_name ${PROJECT_NAME}"
elif [[ ${NIGHTLY_BUILD} == "1" && ${GPU_BUILD} == "1" ]]; then
PKG_NAME_FLAG="--project_name tf_nightly_gpu"
elif [[ ${NIGHTLY_BUILD} == "1" && ${ROCM_BUILD} == "1" ]]; then
PKG_NAME_FLAG="--project_name tf_nightly_rocm"
elif [[ ${NIGHTLY_BUILD} == "1" ]]; then
PKG_NAME_FLAG="--project_name tf_nightly"
elif [[ ${GPU_BUILD} == "1" ]]; then
PKG_NAME_FLAG="--project_name tensorflow_gpu"
elif [[ ${ROCM_BUILD} == "1" ]]; then
PKG_NAME_FLAG="--project_name tensorflow_rocm"
fi
build_wheel "$SRCDIR" "$DSTDIR" "$PKG_NAME_FLAG"