Merge pull request #46368 from ROCmSoftwarePlatform:google_upstream_rocm_csb_fix_210112

PiperOrigin-RevId: 351987218
Change-Id: I6bb9d208a9edc005de51a647df8003e0823c5d45

Commit: 0699826027
@@ -69,6 +69,7 @@ cc_library(
         "//tensorflow/compiler/xla/service/gpu:gpu_compiler",
         "//tensorflow/compiler/xla/service/gpu:target_constants",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/core/common_runtime/gpu:gpu_init",
         "@llvm-project//llvm:Core",
         "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "@llvm-project//mlir:IR",
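A hedged aside on the hunk above (not part of the diff): the new gpu_init dependency exposes tensorflow/core/common_runtime/gpu/gpu_init.h, whose tensorflow::GpuPlatformName() is used in the next hunk. A minimal standalone sketch of the compile-time behavior that function is expected to have; the function name and return values here are illustrative stand-ins, not TensorFlow's actual implementation:

#include <iostream>
#include <string>

// Illustrative stand-in for tensorflow::GpuPlatformName(): resolve the
// StreamExecutor platform name for whichever GPU backend was compiled in.
std::string GpuPlatformNameSketch() {
#if TENSORFLOW_USE_ROCM
  return "ROCM";  // assumption: ROCm builds report the ROCM platform
#else
  return "CUDA";  // assumption: CUDA builds report the CUDA platform
#endif
}

int main() { std::cout << "GPU platform: " << GpuPlatformNameSketch() << "\n"; }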
@@ -23,13 +23,15 @@ limitations under the License.
 #include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
 #include "tensorflow/compiler/xla/service/gpu/target_constants.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
 
 namespace xla {
 namespace gpu {
 
 MlirGpuTestBase::MlirGpuTestBase() {
   se::Platform* platform =
-      se::MultiPlatformManager::PlatformWithName("cuda").ConsumeValueOrDie();
+      se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName())
+          .ConsumeValueOrDie();
   BackendOptions options;
   options.set_platform(platform);
   backend_ = xla::Backend::CreateBackend(options).ConsumeValueOrDie();
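Reassembled from the hunk above, the patched constructor reads as follows: instead of hardcoding the "cuda" platform, it asks the build for its GPU platform name, so the same test base works under ROCm builds:

// Constructor body as it reads after this hunk (reassembled from the diff).
MlirGpuTestBase::MlirGpuTestBase() {
  // Look up the platform for the GPU backend this build targets
  // (CUDA or ROCm) rather than hardcoding "cuda".
  se::Platform* platform =
      se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName())
          .ConsumeValueOrDie();
  BackendOptions options;
  options.set_platform(platform);
  backend_ = xla::Backend::CreateBackend(options).ConsumeValueOrDie();
}

PlatformWithName() returns a StatusOr, and ConsumeValueOrDie() check-fails if no such platform is registered, e.g. when the binary was built without GPU support.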
@@ -40,8 +42,13 @@ StatusOr<ExecutionOutput> MlirGpuTestBase::RunMlirModule(
     absl::Span<const se::DeviceMemoryBase> arguments) {
   llvm::LLVMContext llvm_context;
   auto llvm_module = absl::make_unique<llvm::Module>("", llvm_context);
+#if TENSORFLOW_USE_ROCM
+  llvm_module->setTargetTriple(amdgpu::kTargetTriple);
+  llvm_module->setDataLayout(amdgpu::kDataLayout);
+#else
   llvm_module->setTargetTriple(nvptx::kTargetTriple);
   llvm_module->setDataLayout(nvptx::kDataLayout);
+#endif
 
   se::StreamExecutor* stream_exec = stream->parent();
   GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
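Reassembled from the hunk above, the module setup in RunMlirModule now selects the LLVM target triple and data layout at compile time instead of assuming NVPTX:

// Module setup as it reads after this hunk (reassembled from the diff).
llvm::LLVMContext llvm_context;
auto llvm_module = absl::make_unique<llvm::Module>("", llvm_context);
#if TENSORFLOW_USE_ROCM
// ROCm build: target AMD GPUs.
llvm_module->setTargetTriple(amdgpu::kTargetTriple);
llvm_module->setDataLayout(amdgpu::kDataLayout);
#else
// CUDA build: target NVIDIA's PTX backend.
llvm_module->setTargetTriple(nvptx::kTargetTriple);
llvm_module->setDataLayout(nvptx::kDataLayout);
#endif

The constants come from xla/service/gpu/target_constants.h; presumably they wrap LLVM's amdgcn-amd-amdhsa and nvptx64-nvidia-cuda triples plus the matching data layout strings.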