[XLA][MLIR] Emit SelectAndScatter HLO instruction as lhlo.SelectAndScatterOp.
PiperOrigin-RevId: 302009183
Change-Id: If4ef8d3d23118c5815e33affcfb58ac7c7612352
parent 5c91eab5d4
commit abf182a882
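
For orientation, the emitted op has roughly the following shape in the xla_lhlo dialect. This is a minimal sketch assembled from the FileCheck expectations in the new select_and_scatter.hlo test below; the value names, region-argument names, and memref types are illustrative placeholders rather than actual emitter output:

    "xla_lhlo.select_and_scatter"(%operand, %source, %init_value, %out) ( {
    ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %pred: memref<i1>):  // `select` region, spliced from the select computation
      ...
      xla_lhlo.terminator
    }, {
    ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %acc: memref<f32>):  // `scatter` region, spliced from the scatter computation
      ...
      xla_lhlo.terminator
    }) {padding = ..., window_dimensions = ..., window_strides = ...}
       : (memref<...>, memref<...>, memref<f32>, memref<...>) -> ()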
@@ -150,6 +150,7 @@ Status SpliceHloComputation(OpBuilder builder, mlir::Location loc,
                            const HloComputation& hlo_computation,
                            xla::mlir_gpu::EmissionContext* emission_context) {
  auto block = builder.getInsertionBlock();
  builder.setInsertionPoint(block->getTerminator());
  llvm::SmallVector<Value, 4> arg_values;
  // First map parameters to memrefs on the operation.
  for (auto param : hlo_computation.parameter_instructions()) {
@@ -242,7 +243,7 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
}

Status LhloDialectEmitter::HandleBroadcast(HloInstruction* broadcast) {
  mlir::DenseIntElementsAttr broadcast_dim =
  DenseIntElementsAttr broadcast_dim =
      CreateDenseIntElementsAttrFromVector(broadcast->dimensions(), builder_);

  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*broadcast));
@@ -343,6 +344,51 @@ Status LhloDialectEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
                                          *reduce_window->to_apply(), emission_context_);
}

Status LhloDialectEmitter::HandleSelectAndScatter(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*hlo));
  llvm::SmallVector<Value, 4> arg_values{function.args_begin(),
                                         function.args_end()};
  OpBuilder builder(function.getBody());
  auto loc = getLocation(hlo);

  // Collect attribute values.
  llvm::SmallVector<int64, 2> window_dimensions, window_strides, padding;
  int64 rank = hlo->window().dimensions_size();
  window_dimensions.reserve(rank);
  window_strides.reserve(rank);
  padding.reserve(2 * rank);
  for (const auto& window : hlo->window().dimensions()) {
    window_dimensions.push_back(window.size());
    window_strides.push_back(window.stride());
    padding.push_back(window.padding_low());
    padding.push_back(window.padding_high());
  }

  auto select_scatter_op = builder.create<lhlo::SelectAndScatterOp>(
      loc, /*operand=*/arg_values[0], /*source=*/arg_values[1],
      /*init_value=*/arg_values[2],
      /*out=*/arg_values[3],
      CreateDenseIntElementsAttrFromVector(window_dimensions, builder),
      CreateDenseIntElementsAttrFromVector(window_strides, builder),
      CreateDenseIntElementsAttrFromVector(padding, builder, {rank, 2}));

  // Convert `select` computation.
  builder.createBlock(&select_scatter_op.select());
  OpBuilder select_builder{&select_scatter_op.select()};
  select_builder.create<lhlo::TerminatorOp>(loc);
  TF_RETURN_IF_ERROR(SpliceHloComputation(select_builder, loc, *hlo->select(),
                                          emission_context_));

  // Convert `scatter` computation.
  builder.createBlock(&select_scatter_op.scatter());
  OpBuilder scatter_builder{&select_scatter_op.scatter()};
  scatter_builder.create<lhlo::TerminatorOp>(loc);
  TF_RETURN_IF_ERROR(SpliceHloComputation(scatter_builder, loc, *hlo->scatter(),
                                          emission_context_));

  return Status::OK();
}

Status LhloDialectEmitter::HandleCustomCall(HloInstruction* custom_call) {
  return ThunkEmitter(this).HandleCustomCall(custom_call);
}
@@ -62,6 +62,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault,
  Status HandleParameter(HloInstruction* parameter) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleReduceWindow(HloInstruction* reduce_window) override;
  Status HandleSelectAndScatter(HloInstruction* hlo) override;
  Status HandleTuple(HloInstruction* tuple) override;

  Status FinishVisit(HloInstruction* root) override;
@@ -51,6 +51,7 @@ tf_cc_test(
        "rem.hlo",
        "rsqrt.hlo",
        "select.hlo",
        "select_and_scatter.hlo",
        "sign.hlo",
        "sqrt.hlo",
        "tanh.hlo",
@@ -193,6 +193,12 @@ TEST_F(LhloGenTest, Rsqrt) {
                                              "rsqrt.hlo"));
}

TEST_F(LhloGenTest, SelectAndScatter) {
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
                                              "select_and_scatter.hlo"));
}

TEST_F(LhloGenTest, Sign) {
  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                              "service", "mlir_gpu", "tests",
@@ -0,0 +1,53 @@
HloModule SelectAndScatter

%ge (x: f32[], y: f32[]) -> pred[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %compare = pred[] compare(f32[] %x, f32[] %y), direction=GE
}

%add (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %x, f32[] %y)
}

ENTRY %SelectAndScatter (x: f32[128,64,112,112],
                         y: f32[128,64,56,56],
                         z: f32[]) -> f32[128,64,112,112] {
  %x = f32[128,64,112,112] parameter(0)
  %y = f32[128,64,56,56] parameter(1)
  %z = f32[] parameter(2)
  ROOT %result = f32[128,64,112,112] select-and-scatter(
      f32[128,64,112,112] %x,
      f32[128,64,56,56] %y,
      f32[] %z),
    window={size=1x1x3x3 stride=1x1x2x2 pad=0_0x0_0x0_1x0_1},
    select=%ge,
    scatter=%add
}

// CHECK: func @"select-and-scatter"(
// CHECK-SAME: [[ARG:%.*]]: [[ARGT:.*]], [[SRC:%.*]]: [[SRCT:.*]], [[CST:%.*]]: memref<f32>, [[RES:%.*]]: [[REST:.*]]) {
// CHECK: "xla_lhlo.select_and_scatter"([[ARG]], [[SRC]], [[CST]], [[RES]]) ( {
// CHECK: ^bb0([[LHS:%.*]]: memref<f32>, [[RHS:%.*]]: memref<f32>,
// CHECK-SAME: [[OUT:%.*]]: memref<i1>):
// CHECK: [[LHS_TENSOR:%.*]] = tensor_load [[LHS]]
// CHECK: [[RHS_TENSOR:%.*]] = tensor_load [[RHS]]
// CHECK: [[OUT_TENSOR:%.*]] = "xla_hlo.compare"
// CHECK-SAME: ([[LHS_TENSOR]], [[RHS_TENSOR]]) {comparison_direction = "GE"}
// CHECK: tensor_store [[OUT_TENSOR]], [[OUT]]
// CHECK: xla_lhlo.terminator
// CHECK: }, {
// CHECK: ^bb0([[LHS_:%.*]]: memref<f32>, [[RHS_:%.*]]: memref<f32>,
// CHECK-SAME: [[OUT_:%.*]]: memref<f32>):
// CHECK: [[LHS_TENSOR_:%.*]] = tensor_load [[LHS_]]
// CHECK: [[RHS_TENSOR_:%.*]] = tensor_load [[RHS_]]
// CHECK: [[OUT_TENSOR_:%.*]] = xla_hlo.add [[LHS_TENSOR_]], [[RHS_TENSOR_]]
// CHECK: tensor_store [[OUT_TENSOR_]], [[OUT_]]
// CHECK: xla_lhlo.terminator
// CHECK: }) {
// CHECK-SAME: padding = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 1]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 3, 3]>
// CHECK-SAME: window_strides = dense<[1, 1, 2, 2]>
// CHECK-SAME: } : ([[ARGT]], [[SRCT]], memref<f32>, [[REST]]) -> ()
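
As a hedged reading of the expectations above: the HLO attribute window={size=1x1x3x3 stride=1x1x2x2 pad=0_0x0_0x0_1x0_1} becomes window_dimensions = dense<[1, 1, 3, 3]>, window_strides = dense<[1, 1, 2, 2]>, and padding = dense<[[0, 0], [0, 0], [0, 1], [0, 1]]>, i.e. one [low, high] pair per dimension; HandleSelectAndScatter collects the pads as a flat vector of 2 * rank entries and reshapes it to {rank, 2}.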