[XLA][MLIR] Emit SelectAndScatter HLO instruction as lhlo.SelectAndScatterOp.

PiperOrigin-RevId: 302009183
Change-Id: If4ef8d3d23118c5815e33affcfb58ac7c7612352
This commit is contained in:
Alexander Belyaev 2020-03-20 05:32:03 -07:00 committed by TensorFlower Gardener
parent 5c91eab5d4
commit abf182a882
5 changed files with 108 additions and 1 deletions

View File

@ -150,6 +150,7 @@ Status SpliceHloComputation(OpBuilder builder, mlir::Location loc,
const HloComputation& hlo_computation,
xla::mlir_gpu::EmissionContext* emission_context) {
auto block = builder.getInsertionBlock();
builder.setInsertionPoint(block->getTerminator());
llvm::SmallVector<Value, 4> arg_values;
// First map parameters to memrefs on the operation.
for (auto param : hlo_computation.parameter_instructions()) {
@ -242,7 +243,7 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
}
Status LhloDialectEmitter::HandleBroadcast(HloInstruction* broadcast) {
mlir::DenseIntElementsAttr broadcast_dim =
DenseIntElementsAttr broadcast_dim =
CreateDenseIntElementsAttrFromVector(broadcast->dimensions(), builder_);
TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*broadcast));
@ -343,6 +344,51 @@ Status LhloDialectEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
*reduce_window->to_apply(), emission_context_);
}
Status LhloDialectEmitter::HandleSelectAndScatter(HloInstruction* hlo) {
TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*hlo));
llvm::SmallVector<Value, 4> arg_values{function.args_begin(),
function.args_end()};
OpBuilder builder(function.getBody());
auto loc = getLocation(hlo);
// Collect attribute values.
llvm::SmallVector<int64, 2> window_dimensions, window_strides, padding;
int64 rank = hlo->window().dimensions_size();
window_dimensions.reserve(rank);
window_strides.reserve(rank);
padding.reserve(2 * rank);
for (const auto& window : hlo->window().dimensions()) {
window_dimensions.push_back(window.size());
window_strides.push_back(window.stride());
padding.push_back(window.padding_low());
padding.push_back(window.padding_high());
}
auto select_scatter_op = builder.create<lhlo::SelectAndScatterOp>(
loc, /*operand=*/arg_values[0], /*source=*/arg_values[1],
/*init_value=*/arg_values[2],
/*out=*/arg_values[3],
CreateDenseIntElementsAttrFromVector(window_dimensions, builder),
CreateDenseIntElementsAttrFromVector(window_strides, builder),
CreateDenseIntElementsAttrFromVector(padding, builder, {rank, 2}));
// Convert `select` computation.
builder.createBlock(&select_scatter_op.select());
OpBuilder select_builder{&select_scatter_op.select()};
select_builder.create<lhlo::TerminatorOp>(loc);
TF_RETURN_IF_ERROR(SpliceHloComputation(select_builder, loc, *hlo->select(),
emission_context_));
// Convert `scatter` computation.
builder.createBlock(&select_scatter_op.scatter());
OpBuilder scatter_builder{&select_scatter_op.scatter()};
scatter_builder.create<lhlo::TerminatorOp>(loc);
TF_RETURN_IF_ERROR(SpliceHloComputation(scatter_builder, loc, *hlo->scatter(),
emission_context_));
return Status::OK();
}
Status LhloDialectEmitter::HandleCustomCall(HloInstruction* custom_call) {
return ThunkEmitter(this).HandleCustomCall(custom_call);
}

View File

@ -62,6 +62,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault,
Status HandleParameter(HloInstruction* parameter) override;
Status HandleReduce(HloInstruction* reduce) override;
Status HandleReduceWindow(HloInstruction* reduce_window) override;
Status HandleSelectAndScatter(HloInstruction* hlo) override;
Status HandleTuple(HloInstruction* tuple) override;
Status FinishVisit(HloInstruction* root) override;

View File

@ -51,6 +51,7 @@ tf_cc_test(
"rem.hlo",
"rsqrt.hlo",
"select.hlo",
"select_and_scatter.hlo",
"sign.hlo",
"sqrt.hlo",
"tanh.hlo",

View File

@ -193,6 +193,12 @@ TEST_F(LhloGenTest, Rsqrt) {
"rsqrt.hlo"));
}
TEST_F(LhloGenTest, SelectAndScatter) {
CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
"service", "mlir_gpu", "tests",
"select_and_scatter.hlo"));
}
TEST_F(LhloGenTest, Sign) {
CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
"service", "mlir_gpu", "tests",

View File

@ -0,0 +1,53 @@
HloModule SelectAndScatter
%ge (x: f32[], y: f32[]) -> pred[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %compare = pred[] compare(f32[] %x, f32[] %y), direction=GE
}
%add (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
ENTRY %SelectAndScatter (x: f32[128,64,112,112],
y: f32[128,64,56,56],
z: f32[]) -> f32[128,64,112,112] {
%x = f32[128,64,112,112] parameter(0)
%y = f32[128,64,56,56] parameter(1)
%z = f32[] parameter(2)
ROOT %result = f32[128,64,112,112] select-and-scatter(
f32[128,64,112,112] %x,
f32[128,64,56,56] %y,
f32[] %z),
window={size=1x1x3x3 stride=1x1x2x2 pad=0_0x0_0x0_1x0_1},
select=%ge,
scatter=%add
}
// CHECK: func @"select-and-scatter"(
// CHECK-SAME: [[ARG:%.*]]: [[ARGT:.*]], [[SRC:%.*]]: [[SRCT:.*]], [[CST:%.*]]: memref<f32>, [[RES:%.*]]: [[REST:.*]]) {
// CHECK: "xla_lhlo.select_and_scatter"([[ARG]], [[SRC]], [[CST]], [[RES]]) ( {
// CHECK: ^bb0([[LHS:%.*]]: memref<f32>, [[RHS:%.*]]: memref<f32>,
// CHECK-SAME: [[OUT:%.*]]: memref<i1>):
// CHECK: [[LHS_TENSOR:%.*]] = tensor_load [[LHS]]
// CHECK: [[RHS_TENSOR:%.*]] = tensor_load [[RHS]]
// CHECK: [[OUT_TENSOR:%.*]] = "xla_hlo.compare"
// CHECK-SAME: ([[LHS_TENSOR]], [[RHS_TENSOR]]) {comparison_direction = "GE"}
// CHECK: tensor_store [[OUT_TENSOR]], [[OUT]]
// CHECK: xla_lhlo.terminator
// CHECK: }, {
// CHECK: ^bb0([[LHS_:%.*]]: memref<f32>, [[RHS_:%.*]]: memref<f32>,
// CHECK-SAME: [[OUT_:%.*]]: memref<f32>):
// CHECK: [[LHS_TENSOR_:%.*]] = tensor_load [[LHS_]]
// CHECK: [[RHS_TENSOR_:%.*]] = tensor_load [[RHS_]]
// CHECK: [[OUT_TENSOR_:%.*]] = xla_hlo.add [[LHS_TENSOR_]], [[RHS_TENSOR_]]
// CHECK: tensor_store [[OUT_TENSOR_]], [[OUT_]]
// CHECK: xla_lhlo.terminator
// CHECK: }) {
// CHECK-SAME: padding = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 1]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 3, 3]>
// CHECK-SAME: window_strides = dense<[1, 1, 2, 2]>
// CHECK-SAME: } : ([[ARGT]], [[SRCT]], memref<f32>, [[REST]]) -> ()