Scatter HLO -> LHLO conversion test

- Change xla-opt to depend on just the CPU and GPU plugins rather than on
  xla_cpu_jit and xla_gpu_jit.
- Implement GpuCompiler::RunHloPassesAndBufferAssignement() so that the GPU
  backend can be exercised using xla-opt (see the usage sketch below).
- Add a GPU-only unit test for the MLIR HLO Scatter -> LHLO Scatter
  conversion using XLA.
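A minimal sketch of how a caller (e.g. the HLO-to-LHLO conversion) might
consume the new API; `compiler`, `module`, and `executor` here are
illustrative assumptions, not names from this change:

  TF_ASSIGN_OR_RETURN(
      auto module_and_assignment,
      compiler->RunHloPassesAndBufferAssignement(std::move(module), executor,
                                                 /*device_allocator=*/nullptr,
                                                 /*optimize=*/true));
  std::unique_ptr<HloModule> optimized_module =
      std::move(std::get<0>(module_and_assignment));
  std::unique_ptr<BufferAssignment> assignment =
      std::move(std::get<1>(module_and_assignment));
  // `assignment` maps HLO values to buffer slices; the HLO -> LHLO
  // conversion uses it to materialize the memref views checked in the
  // tests below.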

PiperOrigin-RevId: 338340350
Change-Id: I10035214adb38ff06ec5f0c5cba90a247f880a75
Author:    Rahul Joshi
Date:      2020-10-21 14:28:38 -07:00
Committer: TensorFlower Gardener
Commit:    7912c72005 (parent adce39ad55)

7 changed files with 138 additions and 24 deletions


@@ -428,8 +428,8 @@ tf_cc_binary(
     name = "xla-opt",
     deps = [
         ":all_xla_passes_for_testing",
-        "//tensorflow/compiler/jit:xla_cpu_jit",
-        "//tensorflow/compiler/jit:xla_gpu_jit",
         "//tensorflow/compiler/mlir:tf_mlir_opt_main",
+        "//tensorflow/compiler/xla/service:cpu_plugin",
+        "//tensorflow/compiler/xla/service:gpu_plugin",
     ],
 )
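Linking the cpu_plugin and gpu_plugin targets registers the CPU and GPU
compilers with XLA's platform registry without pulling in the TensorFlow JIT.
A hedged sketch of the lookup a tool like xla-opt can then perform at runtime
(assumes the Compiler and MultiPlatformManager APIs of this era; error
handling elided):

  // Resolve the registered platform by name, then the compiler registered
  // for it; both registrations happen via static initializers linked in
  // through the cpu_plugin/gpu_plugin deps.
  TF_ASSIGN_OR_RETURN(se::Platform* platform,
                      se::MultiPlatformManager::PlatformWithName("CUDA"));
  TF_ASSIGN_OR_RETURN(xla::Compiler* compiler,
                      xla::Compiler::GetForPlatform(platform));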


@@ -1,12 +1,19 @@
 load("//tensorflow:tensorflow.bzl", "filegroup")
 load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+    "//tensorflow/core/platform:build_config_root.bzl",
+    "tf_cuda_tests_tags",
+)
 
 package(licenses = ["notice"])
 
 glob_lit_tests(
     data = [":test_utilities"],
     driver = "@llvm-project//mlir:run_lit.sh",
+    tags_override = {
+        "hlo_to_lhlo_with_xla/gpu_ops.mlir": tf_cuda_tests_tags(),
+    },
     test_file_exts = [
         "mlir",
         "hlotxt",


@@ -1,16 +0,0 @@
-// RUN: tf-mlir-translate -hlo-text-to-lhlo -optimize-xla-hlo=false %s | FileCheck %s
-
-HloModule TestModule
-
-// CHECK: func @TestComputation
-FusedComputation {
-  // CHECK: tensor_load %arg0 {minor_to_major = dense<[0, 1]> : tensor<2xindex>}
-  x = f32[3, 2]{0,1} parameter(0)
-  ROOT y = f32[3, 2]{0,1} add(x, x)
-}
-
-ENTRY TestComputation {
-  x = f32[3, 2]{0,1} parameter(0)
-  ROOT y = f32[3, 2]{0,1} fusion(x), kind=kLoop, calls=FusedComputation
-}
-


@@ -0,0 +1,33 @@
+// RUN: xla-opt -split-input-file "-xla-hlo-to-lhlo-with-xla=platform=CUDA" %s \
+// RUN:   | FILECHECK_OPTS="" FileCheck --enable-var-scope %s
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<3x3xi32>
+// CHECK-SAME: %[[ARG1:.*]]: memref<2xi32>
+// CHECK-SAME: %[[ARG2:.*]]: memref<2x3xi32>
+// CHECK-SAME: %[[ARG3:.*]]: memref<36xi8> {lmhlo.alloc = 0
+// CHECK: %[[VIEW0:.*]] = std.view %[[ARG3]]{{.*}} : memref<36xi8> to memref<3x3xi32>
+// CHECK: "lmhlo.copy"(%[[ARG0]], %[[VIEW0]])
+// CHECK: %[[VIEW1:.*]] = std.view %[[ARG3]]{{.*}} : memref<36xi8> to memref<3x3xi32>
+// CHECK: "lmhlo.scatter"(%[[VIEW0]], %[[ARG1]], %[[ARG2]], %[[VIEW1]])
+// CHECK: mhlo.add
+// CHECK: indices_are_sorted = false
+// CHECK: index_vector_dim = 1 : i64
+// CHECK: inserted_window_dims = dense<0> : tensor<1xi64>
+// CHECK: scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>
+// CHECK: update_window_dims = dense<1> : tensor<1xi64>
+// CHECK: unique_indices = false
+
+func @main(%operand: tensor<3x3xi32>, %indices: tensor<2xi32>, %updates: tensor<2x3xi32>) -> tensor<3x3xi32> {
+  %result = "mhlo.scatter"(%operand, %indices, %updates) ( {
+  ^bb0(%x: tensor<i32>, %y: tensor<i32>):
+    %result = "mhlo.add"(%x, %y) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "mhlo.return"(%result) : (tensor<i32>) -> ()
+  }) {scatter_dimension_numbers = {index_vector_dim = 1 : i64,
+                                   inserted_window_dims = dense<0> : tensor<1xi64>,
+                                   scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
+                                   update_window_dims = dense<1> : tensor<1xi64>},
+      indices_are_sorted = false,
+      unique_indices = false} : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
+  return %result : tensor<3x3xi32>
+}


@@ -0,0 +1,48 @@
+// RUN: tf-mlir-translate -split-input-file -hlo-text-to-lhlo -optimize-xla-hlo=false %s | FileCheck %s
+
+HloModule TestModule
+
+// CHECK: func @TestComputation
+FusedComputation {
+  // CHECK: tensor_load %arg0 {minor_to_major = dense<[0, 1]> : tensor<2xindex>}
+  x = f32[3, 2]{0,1} parameter(0)
+  ROOT y = f32[3, 2]{0,1} add(x, x)
+}
+
+ENTRY TestComputation {
+  x = f32[3, 2]{0,1} parameter(0)
+  ROOT y = f32[3, 2]{0,1} fusion(x), kind=kLoop, calls=FusedComputation
+}
+
+// -----
+
+HloModule ScatterModule
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+  lhs = s32[] parameter(0)
+  ROOT rhs = s32[] parameter(1)
+}
+
+// CHECK: func @main
+// CHECK: "lmhlo.scatter"
+// CHECK: ^bb0(%[[ARG5:.*]]: tensor<i32>, %[[ARG6:.*]]: tensor<i32>):
+// CHECK:   "mhlo.return"(%[[ARG6]])
+// CHECK: indices_are_sorted = false
+// CHECK: index_vector_dim = 1 : i64
+// CHECK: inserted_window_dims = dense<0> : tensor<1xi64>
+// CHECK: scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>
+// CHECK: update_window_dims = dense<1> : tensor<1xi64>
+// CHECK: unique_indices = false
+// CHECK: (memref<3x3xi32>, memref<2xi32>, memref<2x3xi32>, memref<3x3xi32>) -> ()
+
+ENTRY main {
+  operand = s32[3,3] parameter(0)
+  indices = s32[2] parameter(1)
+  updates = s32[2,3] parameter(2)
+  ROOT scatter_op = s32[3,3] scatter(operand, indices, updates),
+      to_apply=update_s32,
+      update_window_dims={1},
+      inserted_window_dims={0},
+      scatter_dims_to_operand_dims={0},
+      index_vector_dim=1
+}
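For intuition, the scatter in this test selects a row of `operand` for each
entry of `indices` and combines it with the corresponding row of `updates`;
since `update_s32` returns its second parameter, the update simply overwrites
the operand row. A standalone C++ illustration of that semantics (not XLA
code; the concrete values are made up):

  #include <array>
  #include <cstdio>

  int main() {
    // operand: s32[3,3], indices: s32[2], updates: s32[2,3]
    std::array<std::array<int, 3>, 3> operand{{{1, 1, 1}, {1, 1, 1}, {1, 1, 1}}};
    std::array<int, 2> indices{0, 2};
    std::array<std::array<int, 3>, 2> updates{{{10, 20, 30}, {70, 80, 90}}};
    // update_window_dims={1}, inserted_window_dims={0}: row i of `updates`
    // is scattered onto row indices[i] of `operand`; the combiner
    // (update_s32) returns the update value, i.e. a plain overwrite.
    for (int i = 0; i < 2; ++i)
      for (int j = 0; j < 3; ++j) operand[indices[i]][j] = updates[i][j];
    for (const auto& row : operand)
      std::printf("%d %d %d\n", row[0], row[1], row[2]);
    // Prints: 10 20 30 / 1 1 1 / 70 80 90
  }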


@@ -479,6 +479,47 @@ StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
   return std::move(module);
 }
 
+static absl::optional<bool> DummyCanShareBufferFunction(const HloInstruction*,
+                                                        const HloInstruction*,
+                                                        const ShapeIndex&) {
+  return absl::nullopt;
+}
+
+StatusOr<
+    std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
+GpuCompiler::RunHloPassesAndBufferAssignement(
+    std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* executor,
+    se::DeviceMemoryAllocator* device_allocator, bool optimize) {
+  if (optimize) {
+    TF_ASSIGN_OR_RETURN(hlo_module, RunHloPasses(std::move(hlo_module),
+                                                 executor, device_allocator));
+  }
+
+  std::unique_ptr<StreamAssignment> stream_assignment =
+      AssignStreams(*hlo_module);
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<GpuHloSchedule> hlo_schedule,
+      GpuHloSchedule::Build(*hlo_module, *stream_assignment, pointer_size_));
+
+  auto buffer_size_bytes_function =
+      [this](const BufferValue& buffer_value) -> int64 {
+    return GpuCompiler::GetSizeOfShape(buffer_value.shape(), pointer_size_);
+  };
+
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<BufferAssignment> assignment,
+      BufferAssigner::Run(
+          hlo_module.get(), hlo_schedule->ConsumeHloOrdering(),
+          buffer_size_bytes_function,
+          /*color_alignment=*/
+          [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; },
+          /*allocate_buffers_for_constants=*/true,
+          /*colorer=*/BufferAssigner::DefaultColorer(),
+          /*must_not_live_out=*/{}, DummyCanShareBufferFunction));
+
+  return std::make_tuple(std::move(hlo_module), std::move(assignment));
+}
+
 // The order of `thunk_sequence` corresponds to
 // `hlo_schedule->ThunkLaunchOrder()`.
 static Status CompileModuleToLlvmIrImpl(
@@ -724,12 +765,6 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
   return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime");
 }
 
-static absl::optional<bool> DummyCanShareBufferFunction(const HloInstruction*,
-                                                        const HloInstruction*,
-                                                        const ShapeIndex&) {
-  return absl::nullopt;
-}
-
 StatusOr<std::unique_ptr<llvm::Module>> CompileModuleToLlvmIr(
     HloModule* hlo_module, llvm::LLVMContext* llvm_context,
     const std::string& target_triple, const std::string& data_layout,


@@ -55,6 +55,13 @@ class GpuCompiler : public LLVMCompiler {
       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
       se::DeviceMemoryAllocator* device_allocator) override;
 
+  StatusOr<
+      std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
+  RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> hlo_module,
+                                   se::StreamExecutor* executor,
+                                   se::DeviceMemoryAllocator* device_allocator,
+                                   bool optimize) override;
+
   Status OptimizeHloModule(HloModule* hlo_module,
                            se::StreamExecutor* stream_exec,
                            se::DeviceMemoryAllocator* device_allocator);