Merge pull request #46368 from ROCmSoftwarePlatform:google_upstream_rocm_csb_fix_210112
PiperOrigin-RevId: 351987218 Change-Id: I6bb9d208a9edc005de51a647df8003e0823c5d45
This commit is contained in:
commit
0699826027
@ -69,6 +69,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service/gpu:gpu_compiler",
|
||||
"//tensorflow/compiler/xla/service/gpu:target_constants",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/core/common_runtime/gpu:gpu_init",
|
||||
"@llvm-project//llvm:Core",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"@llvm-project//mlir:IR",
|
||||
|
@ -23,13 +23,15 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
MlirGpuTestBase::MlirGpuTestBase() {
|
||||
se::Platform* platform =
|
||||
se::MultiPlatformManager::PlatformWithName("cuda").ConsumeValueOrDie();
|
||||
se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName())
|
||||
.ConsumeValueOrDie();
|
||||
BackendOptions options;
|
||||
options.set_platform(platform);
|
||||
backend_ = xla::Backend::CreateBackend(options).ConsumeValueOrDie();
|
||||
@ -40,8 +42,13 @@ StatusOr<ExecutionOutput> MlirGpuTestBase::RunMlirModule(
|
||||
absl::Span<const se::DeviceMemoryBase> arguments) {
|
||||
llvm::LLVMContext llvm_context;
|
||||
auto llvm_module = absl::make_unique<llvm::Module>("", llvm_context);
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
llvm_module->setTargetTriple(amdgpu::kTargetTriple);
|
||||
llvm_module->setDataLayout(amdgpu::kDataLayout);
|
||||
#else
|
||||
llvm_module->setTargetTriple(nvptx::kTargetTriple);
|
||||
llvm_module->setDataLayout(nvptx::kDataLayout);
|
||||
#endif
|
||||
|
||||
se::StreamExecutor* stream_exec = stream->parent();
|
||||
GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
|
||||
|
Loading…
x
Reference in New Issue
Block a user