[XLA][MLIR] Emit ReduceWindow HLO instruction as xla_lhlo.ReduceWindowOp.
PiperOrigin-RevId: 302002238
Change-Id: I93d202b3729b63d85ae56e1515b1c1971e794aa0
parent c0306ef626
commit 0e0fc5a791
@@ -108,12 +108,12 @@ StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
 }
 
 mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
-    const llvm::ArrayRef<int64> vector, mlir::Builder builder) {
+    const llvm::ArrayRef<int64> vector, mlir::Builder builder,
+    llvm::ArrayRef<int64_t> shape) {
   return mlir::DenseIntElementsAttr::get(
-             mlir::RankedTensorType::get(vector.size(),
-                                         builder.getIntegerType(64)),
-             vector)
-      .cast<mlir::DenseIntElementsAttr>();
+      mlir::RankedTensorType::get(shape.empty() ? vector.size() : shape,
+                                  builder.getIntegerType(64)),
+      vector);
 }
 
 StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
@@ -30,8 +30,11 @@ namespace xla {
 StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
     const LiteralBase& literal, mlir::Builder builder);
 
+// Creates a DenseIntElementsAttr using the elements of the vector and the
+// optional shape.
 mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
-    const llvm::ArrayRef<int64> vector, mlir::Builder builder);
+    const llvm::ArrayRef<int64> vector, mlir::Builder builder,
+    llvm::ArrayRef<int64_t> shape = {});
 
 StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
                                                     mlir::Builder builder);
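A minimal usage sketch of the extended helper, assuming an mlir::Builder is already in scope and that xla's int64 matches int64_t; the variable names are illustrative, not taken from the change:

    // Default (empty) shape: a 1-D attribute sized by the vector, e.g.
    // dense<[1, 1, 2, 2]> : tensor<4xi64>.
    llvm::SmallVector<int64, 4> strides = {1, 1, 2, 2};
    auto strides_attr = CreateDenseIntElementsAttrFromVector(strides, builder);

    // Explicit shape: the same flat data reinterpreted; a {rank, 2} shape
    // turns 2 * rank interleaved values into rows of (low, high) pairs, e.g.
    // dense<[[0, 0], [0, 0], [0, 1], [0, 1]]> : tensor<4x2xi64>.
    llvm::SmallVector<int64, 8> padding = {0, 0, 0, 0, 0, 1, 0, 1};
    auto padding_attr =
        CreateDenseIntElementsAttrFromVector(padding, builder, /*shape=*/{4, 2});

HandleReduceWindow below relies on exactly this {rank, 2} case for the padding attribute.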
@@ -70,12 +70,10 @@ struct FusionToLhloConverter
   target.addLegalDialect<::mlir::xla_lhlo::XlaLhloDialect>();
   ::mlir::xla_hlo::populateHLOToLHLOConversionPattern(&ctx, &patterns);
 
-  getFunction().walk([&](FusionOp op) {
+  getFunction().walk([&](mlir::Operation* op) {
+    if (op->getNumRegions() == 0) {
+      return;
+    }
     if (failed(applyPartialConversion(op, target, patterns, nullptr))) {
       signalPassFailure();
     }
   });
-  getFunction().walk([&](mlir::xla_lhlo::ReduceOp op) {
-    if (failed(applyPartialConversion(op, target, patterns, nullptr))) {
-      signalPassFailure();
-    }
-  });
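For readers new to MLIR traversal, a generic illustration of why this works (standard mlir walk behavior, not code from this change): the callback's parameter type acts as a filter, so the old version needed one walk per region-carrying op kind, while a callback taking mlir::Operation* fires on every nested op and can filter on regions itself:

    // Type-filtered: the callback runs only on FusionOp instances.
    getFunction().walk([&](FusionOp op) { /* convert */ });

    // Unfiltered: the callback runs on every nested operation, so any
    // region-holding op (fusion, reduce, and now reduce_window) is converted
    // without this pass enumerating op kinds.
    getFunction().walk([&](mlir::Operation* op) {
      if (op->getNumRegions() == 0) return;
      /* convert */
    });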
@@ -42,6 +42,7 @@ namespace {
 using ::mlir::ArrayRef;
 using ::mlir::Attribute;
 using ::mlir::Builder;
+using ::mlir::DenseIntElementsAttr;
 using ::mlir::FuncOp;
 using ::mlir::Identifier;
 using ::mlir::Location;
@@ -143,6 +144,37 @@ StatusOr<llvm::SmallVector<Type, 4>> GetInstructionArgTypes(
   return arg_types;
 }
 
+// Converts HloComputation into a block with HLO dialect ops. The block gets
+// memref arguments corresponding to HloComputation arguments and results.
+Status SpliceHloComputation(OpBuilder builder, mlir::Location loc,
+                            const HloComputation& hlo_computation,
+                            xla::mlir_gpu::EmissionContext* emission_context) {
+  auto block = builder.getInsertionBlock();
+  llvm::SmallVector<Value, 4> arg_values;
+  // First map parameters to memrefs on the operation.
+  for (auto param : hlo_computation.parameter_instructions()) {
+    TF_ASSIGN_OR_RETURN(
+        auto arg_type, ConvertShapeToType<MemRefType>(param->shape(), builder));
+    auto block_arg = block->addArgument(arg_type);
+    arg_values.push_back(builder.create<::mlir::TensorLoadOp>(loc, block_arg));
+  }
+  HloDialectEmitter hlo_emitter(emission_context, builder, arg_values);
+
+  TF_ASSIGN_OR_RETURN(auto result,
+                      hlo_emitter.EmitComputation(hlo_computation));
+
+  // Now add a block arg and store for the result.
+  builder.setInsertionPoint(block->getTerminator());
+  TF_ASSIGN_OR_RETURN(
+      auto result_type,
+      ConvertShapeToType<MemRefType>(
+          hlo_computation.root_instruction()->shape(), builder));
+  auto block_arg = block->addArgument(result_type);
+  builder.create<::mlir::TensorStoreOp>(loc, result, block_arg);
+
+  return Status::OK();
+}
+
 }  // namespace
 
 mlir::Location LhloDialectEmitter::getLocation(
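For orientation, the block SpliceHloComputation builds for a two-parameter reduction computation looks roughly like the following (a sketch assembled from the FileCheck expectations at the end of this commit, not captured emitter output):

    ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>):
      %lhs_tensor = tensor_load %lhs
      %rhs_tensor = tensor_load %rhs
      %result = xla_hlo.maximum %lhs_tensor, %rhs_tensor
      tensor_store %result, %out
      "xla_lhlo.terminator"() : () -> ()

The parameters arrive as memref block arguments and are loaded into tensors for the spliced xla_hlo ops; the root value is stored back through a trailing result argument, which is why the helper re-points the builder at the block terminator before adding the store.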
@@ -268,33 +300,47 @@ Status LhloDialectEmitter::HandleReduce(HloInstruction* reduce) {
   auto reduce_op = builder.create<lhlo::ReduceOp>(loc, inputs, init_values,
                                                   results, dimensions_attr);
   reduce_op.ensureTerminator(reduce_op.body(), builder, getLocation(reduce));
-
-  OpBuilder body_builder(reduce_op.body());
-  auto block = body_builder.getInsertionBlock();
-  auto to_apply = reduce->to_apply();
-  llvm::SmallVector<Value, 4> reduce_arg_values;
-  // First map parameters to memrefs on the operation.
-  for (auto param : to_apply->parameter_instructions()) {
-    TF_ASSIGN_OR_RETURN(auto arg_type, ConvertShapeToType<MemRefType>(
-                                           param->shape(), builder_));
-    auto block_arg = block->addArgument(arg_type);
-    reduce_arg_values.push_back(
-        body_builder.create<::mlir::TensorLoadOp>(loc, block_arg));
-  }
-  HloDialectEmitter hlo_emitter(emission_context_, body_builder,
-                                reduce_arg_values);
-
-  TF_ASSIGN_OR_RETURN(auto result, hlo_emitter.EmitComputation(*to_apply));
-
-  // Now add a block arg and store for the result.
-  body_builder.setInsertionPoint(block->getTerminator());
-  TF_ASSIGN_OR_RETURN(auto result_type,
-                      ConvertShapeToType<MemRefType>(
-                          to_apply->root_instruction()->shape(), builder));
-  auto block_arg = block->addArgument(result_type);
-  body_builder.create<::mlir::TensorStoreOp>(loc, result, block_arg);
-
-  return Status::OK();
+  return SpliceHloComputation(OpBuilder{&reduce_op.body()}, loc,
+                              *reduce->to_apply(), emission_context_);
+}
+
+Status LhloDialectEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
+  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*reduce_window));
+  llvm::SmallVector<Value, 4> arg_values{function.args_begin(),
+                                         function.args_end()};
+  OpBuilder builder(function.getBody());
+  auto loc = getLocation(reduce_window);
+
+  // Collect attribute values.
+  llvm::SmallVector<int64, 2> window_dimensions, window_strides, base_dilations,
+      window_dilations;
+  llvm::SmallVector<int64, 4> padding;
+  int64 rank = reduce_window->window().dimensions_size();
+  window_dimensions.reserve(rank);
+  window_strides.reserve(rank);
+  base_dilations.reserve(rank);
+  window_dilations.reserve(rank);
+  padding.reserve(2 * rank);
+  for (const auto& window : reduce_window->window().dimensions()) {
+    window_dimensions.push_back(window.size());
+    window_strides.push_back(window.stride());
+    base_dilations.push_back(window.base_dilation());
+    window_dilations.push_back(window.window_dilation());
+    padding.push_back(window.padding_low());
+    padding.push_back(window.padding_high());
+  }
+
+  auto reduce_window_op = builder.create<lhlo::ReduceWindowOp>(
+      loc, /*operand=*/arg_values[0], /*init_value=*/arg_values[1],
+      /*out=*/arg_values[2],
+      CreateDenseIntElementsAttrFromVector(window_dimensions, builder),
+      CreateDenseIntElementsAttrFromVector(window_strides, builder),
+      CreateDenseIntElementsAttrFromVector(base_dilations, builder),
+      CreateDenseIntElementsAttrFromVector(window_dilations, builder),
+      CreateDenseIntElementsAttrFromVector(padding, builder, {rank, 2}));
+  reduce_window_op.ensureTerminator(reduce_window_op.body(), builder, loc);
+  return SpliceHloComputation(OpBuilder{&reduce_window_op.body()}, loc,
+                              *reduce_window->to_apply(), emission_context_);
 }
 
 Status LhloDialectEmitter::HandleCustomCall(HloInstruction* custom_call) {
@@ -61,6 +61,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault,
   Status HandleIota(HloInstruction* iota) override;
   Status HandleParameter(HloInstruction* parameter) override;
   Status HandleReduce(HloInstruction* reduce) override;
+  Status HandleReduceWindow(HloInstruction* reduce_window) override;
   Status HandleTuple(HloInstruction* tuple) override;
 
   Status FinishVisit(HloInstruction* root) override;
@@ -47,6 +47,7 @@ tf_cc_test(
         "iota_add_multiply.hlo",
         "log.hlo",
         "neg.hlo",
+        "reduce_window.hlo",
        "rem.hlo",
        "rsqrt.hlo",
        "select.hlo",
@@ -175,6 +175,12 @@ TEST_F(LhloGenTest, Neg) {
                                        "neg.hlo"));
 }
 
+TEST_F(LhloGenTest, ReduceWindow) {
+  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
+                                              "service", "mlir_gpu", "tests",
+                                              "reduce_window.hlo"));
+}
+
 TEST_F(LhloGenTest, Rem) {
   CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                               "service", "mlir_gpu", "tests",
@@ -0,0 +1,34 @@
+HloModule ReduceWindow
+
+%max (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %max = f32[] maximum(f32[] %x, f32[] %y)
+}
+
+ENTRY %ReduceWindow (x: f32[128,64,112,112], y: f32[]) -> f32[128,64,56,56] {
+  %x = f32[128,64,112,112] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %reduce-window = f32[128,64,56,56] reduce-window(
+      f32[128,64,112,112] %x,
+      f32[] %y
+    ),
+    window={size=1x1x3x3 stride=1x1x2x2 pad=0_0x0_0x0_1x0_1}, to_apply=%max
+}
+
+// CHECK: func @"reduce-window"(
+// CHECK-SAME: [[ARG:%.*]]: [[ARGT:.*]], [[CST:%.*]]: memref<f32>, [[RES:%.*]]: [[REST:.*]]) {
+// CHECK: "xla_lhlo.reduce_window"([[LHS:%.*]], [[RHS:%.*]], [[OUT:%.*]]) ( {
+// CHECK: ^bb0([[LHS:%.*]]: memref<f32>, [[RHS:%.*]]: memref<f32>, [[OUT:%.*]]: memref<f32>):
+// CHECK: [[LHS_TENSOR:%.*]] = tensor_load [[LHS]]
+// CHECK: [[RHS_TENSOR:%.*]] = tensor_load [[RHS]]
+// CHECK: [[OUT_TENSOR:%.*]] = xla_hlo.maximum [[LHS_TENSOR]], [[RHS_TENSOR]]
+// CHECK: tensor_store [[OUT_TENSOR]], [[OUT]]
+// CHECK: "xla_lhlo.terminator"() : () -> ()
+// CHECK: }) {
+// CHECK-SAME: base_dilations = dense<1> : tensor<4xi64>
+// CHECK-SAME: padding = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 1]]>
+// CHECK-SAME: window_dilations = dense<1> : tensor<4xi64>
+// CHECK-SAME: window_dimensions = dense<[1, 1, 3, 3]>
+// CHECK-SAME: window_strides = dense<[1, 1, 2, 2]>
+// CHECK: } : ([[ARGT]], memref<f32>, [[REST]]) -> ()
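As a sanity check on the shapes in the module above (a worked application of the usual window output-size formula; both dilations are 1 here, so they drop out):

    out_dim = floor((in_dim + pad_low + pad_high - window_size) / stride) + 1

    dims 0, 1:  floor((128 + 0 + 0 - 1) / 1) + 1 = 128   (likewise 64)
    dims 2, 3:  floor((112 + 0 + 1 - 3) / 2) + 1 = 56

which matches the declared f32[128,64,56,56] result.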