From 846ae3505389ef6c56dbab19ab8a31d2c811d705 Mon Sep 17 00:00:00 2001
From: Deven Desai
Date: Tue, 12 Jan 2021 16:49:16 +0000
Subject: [PATCH] [ROCm] Fix for breakage in ROCm support - 210112

The following commit introduces a new unit test,
`//tensorflow/compiler/xla/service/gpu/tests:mlir_sorting_test`,
which fails on the ROCm platform:

https://github.com/tensorflow/tensorflow/commit/70502be4a5941b6ffa2b8c33bb549657b33976da

The test fails on the ROCm platform because the underlying code for it is
CUDA-centric. This PR/commit updates the test to make it work on the ROCm
platform as well.
---
 .../compiler/xla/service/gpu/tests/mlir_gpu_test_base.cc | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.cc b/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.cc
index 1fa8de26ada..d0ba544289c 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.cc
@@ -23,13 +23,15 @@ limitations under the License.
 #include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
 #include "tensorflow/compiler/xla/service/gpu/target_constants.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
 
 namespace xla {
 namespace gpu {
 
 MlirGpuTestBase::MlirGpuTestBase() {
   se::Platform* platform =
-      se::MultiPlatformManager::PlatformWithName("cuda").ConsumeValueOrDie();
+      se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName())
+          .ConsumeValueOrDie();
   BackendOptions options;
   options.set_platform(platform);
   backend_ = xla::Backend::CreateBackend(options).ConsumeValueOrDie();
@@ -40,8 +42,13 @@ StatusOr MlirGpuTestBase::RunMlirModule(
     absl::Span arguments) {
   llvm::LLVMContext llvm_context;
   auto llvm_module = absl::make_unique<llvm::Module>("", llvm_context);
+#if TENSORFLOW_USE_ROCM
+  llvm_module->setTargetTriple(amdgpu::kTargetTriple);
+  llvm_module->setDataLayout(amdgpu::kDataLayout);
+#else
   llvm_module->setTargetTriple(nvptx::kTargetTriple);
   llvm_module->setDataLayout(nvptx::kDataLayout);
+#endif
 
   se::StreamExecutor* stream_exec = stream->parent();
   GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
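
Note on the pattern used by this fix (not part of the patch): the constructor
now asks MultiPlatformManager for the platform named by
tensorflow::GpuPlatformName(), which matches whichever GPU platform the build
targets, instead of hard-coding "cuda"; and the LLVM module is configured for
the AMDGPU backend when TENSORFLOW_USE_ROCM is defined, falling back to NVPTX
otherwise. The standalone sketch below illustrates that compile-time dispatch
in isolation. It is only an illustration: kAmdGpuTriple, kNvptxTriple, and
SelectTargetTriple are made-up placeholder names; the real constants are the
amdgpu::/nvptx:: symbols, presumably declared in target_constants.h, which the
file already includes.

// Standalone sketch (not part of the patch) of the compile-time backend
// dispatch shown above. All names here are illustrative placeholders.
#include <iostream>
#include <string>

// Placeholder triples standing in for amdgpu::kTargetTriple and
// nvptx::kTargetTriple.
constexpr char kAmdGpuTriple[] = "amdgcn-amd-amdhsa";
constexpr char kNvptxTriple[] = "nvptx64-nvidia-cuda";

// Picks the target triple at compile time, mirroring the #if/#else/#endif
// block the patch adds around setTargetTriple()/setDataLayout().
std::string SelectTargetTriple() {
#if TENSORFLOW_USE_ROCM
  return kAmdGpuTriple;  // ROCm build: configure for the AMDGPU backend.
#else
  return kNvptxTriple;   // Default (CUDA) build: keep the NVPTX configuration.
#endif
}

int main() {
  std::cout << "Selected target triple: " << SelectTargetTriple() << "\n";
  return 0;
}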