Merge pull request #46368 from ROCmSoftwarePlatform:google_upstream_rocm_csb_fix_210112

PiperOrigin-RevId: 351987218
Change-Id: I6bb9d208a9edc005de51a647df8003e0823c5d45
This commit is contained in:
TensorFlower Gardener 2021-01-15 05:07:56 -08:00
commit 0699826027
2 changed files with 9 additions and 1 deletions

View File

@@ -69,6 +69,7 @@ cc_library(
"//tensorflow/compiler/xla/service/gpu:gpu_compiler",
"//tensorflow/compiler/xla/service/gpu:target_constants",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+"//tensorflow/core/common_runtime/gpu:gpu_init",
"@llvm-project//llvm:Core",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",

View File

@@ -23,13 +23,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
namespace xla {
namespace gpu {
MlirGpuTestBase::MlirGpuTestBase() {
se::Platform* platform =
-se::MultiPlatformManager::PlatformWithName("cuda").ConsumeValueOrDie();
+se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName())
+    .ConsumeValueOrDie();
BackendOptions options;
options.set_platform(platform);
backend_ = xla::Backend::CreateBackend(options).ConsumeValueOrDie();
@@ -40,8 +42,13 @@ StatusOr<ExecutionOutput> MlirGpuTestBase::RunMlirModule(
absl::Span<const se::DeviceMemoryBase> arguments) {
llvm::LLVMContext llvm_context;
auto llvm_module = absl::make_unique<llvm::Module>("", llvm_context);
+#if TENSORFLOW_USE_ROCM
+llvm_module->setTargetTriple(amdgpu::kTargetTriple);
+llvm_module->setDataLayout(amdgpu::kDataLayout);
+#else
llvm_module->setTargetTriple(nvptx::kTargetTriple);
llvm_module->setDataLayout(nvptx::kDataLayout);
+#endif
se::StreamExecutor* stream_exec = stream->parent();
GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);