[XLA][MLIR] Emit ReduceWindow HLO instruction as xla_lhlo.ReduceWindowOp.
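
For illustration, the lowering produces an op of roughly the following shape (a sketch assembled from the reduce_window.hlo test added in this change, not verbatim emitter output; the SSA value names are placeholders):

  "xla_lhlo.reduce_window"(%operand, %init_value, %out) ( {
  ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
    %0 = tensor_load %lhs : memref<f32>
    %1 = tensor_load %rhs : memref<f32>
    %2 = xla_hlo.maximum %0, %1 : tensor<f32>
    tensor_store %2, %res : memref<f32>
    "xla_lhlo.terminator"() : () -> ()
  }) {base_dilations = dense<1> : tensor<4xi64>,
      padding = dense<[[0, 0], [0, 0], [0, 1], [0, 1]]> : tensor<4x2xi64>,
      window_dilations = dense<1> : tensor<4xi64>,
      window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>,
      window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>}
      : (memref<128x64x112x112xf32>, memref<f32>, memref<128x64x56x56xf32>) -> ()

The body region operates on memref block arguments (two inputs plus one output), so the emitter adds one block argument per parameter of the to_apply computation and one more for the result, bridging to HLO dialect ops via tensor_load/tensor_store.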

PiperOrigin-RevId: 302002238
Change-Id: I93d202b3729b63d85ae56e1515b1c1971e794aa0
Authored by Alexander Belyaev on 2020-03-20 04:28:09 -07:00; committed by TensorFlower Gardener
parent c0306ef626
commit 0e0fc5a791
8 changed files with 124 additions and 35 deletions

View File

@@ -108,12 +108,12 @@ StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
 }

 mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
-    const llvm::ArrayRef<int64> vector, mlir::Builder builder) {
+    const llvm::ArrayRef<int64> vector, mlir::Builder builder,
+    llvm::ArrayRef<int64_t> shape) {
   return mlir::DenseIntElementsAttr::get(
-             mlir::RankedTensorType::get(vector.size(),
-                                         builder.getIntegerType(64)),
-             vector)
-      .cast<mlir::DenseIntElementsAttr>();
+      mlir::RankedTensorType::get(shape.empty() ? vector.size() : shape,
+                                  builder.getIntegerType(64)),
+      vector);
 }

 StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,

View File

@@ -30,8 +30,11 @@ namespace xla {
 StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
     const LiteralBase& literal, mlir::Builder builder);

+// Creates a DenseIntElementsAttr using the elements of the vector and the
+// optional shape.
 mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
-    const llvm::ArrayRef<int64> vector, mlir::Builder builder);
+    const llvm::ArrayRef<int64> vector, mlir::Builder builder,
+    llvm::ArrayRef<int64_t> shape = {});

 StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
                                                     mlir::Builder builder);
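
As a usage note (a hedged sketch; window_strides, padding, rank, and builder refer to the locals in HandleReduceWindow below): with the defaulted shape the attribute stays 1-D, sized by vector.size(), while an explicit shape lays the same elements out differently, e.g. as (low, high) pairs per dimension.

  // Default shape: 1-D attribute, i.e. tensor<Nxi64> with N == vector.size().
  auto strides_attr =
      CreateDenseIntElementsAttrFromVector(window_strides, builder);
  // Explicit shape: the flat (low, high) padding pairs become tensor<RANKx2xi64>.
  auto padding_attr =
      CreateDenseIntElementsAttrFromVector(padding, builder, {rank, 2});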

View File

@@ -70,12 +70,10 @@ struct FusionToLhloConverter
     target.addLegalDialect<::mlir::xla_lhlo::XlaLhloDialect>();
     ::mlir::xla_hlo::populateHLOToLHLOConversionPattern(&ctx, &patterns);
-    getFunction().walk([&](FusionOp op) {
+    getFunction().walk([&](mlir::Operation* op) {
+      if (op->getNumRegions() == 0) {
+        return;
+      }
       if (failed(applyPartialConversion(op, target, patterns, nullptr))) {
         signalPassFailure();
       }
     });
-    getFunction().walk([&](mlir::xla_lhlo::ReduceOp op) {
-      if (failed(applyPartialConversion(op, target, patterns, nullptr))) {
-        signalPassFailure();
-      }
-    });

View File

@@ -42,6 +42,7 @@ namespace {
 using ::mlir::ArrayRef;
 using ::mlir::Attribute;
 using ::mlir::Builder;
+using ::mlir::DenseIntElementsAttr;
 using ::mlir::FuncOp;
 using ::mlir::Identifier;
 using ::mlir::Location;
@@ -143,6 +144,37 @@ StatusOr<llvm::SmallVector<Type, 4>> GetInstructionArgTypes(
   return arg_types;
 }

+// Converts an HloComputation into a block with HLO dialect ops. The block
+// gets memref arguments corresponding to the HloComputation arguments and
+// results.
+Status SpliceHloComputation(OpBuilder builder, mlir::Location loc,
+                            const HloComputation& hlo_computation,
+                            xla::mlir_gpu::EmissionContext* emission_context) {
+  auto block = builder.getInsertionBlock();
+  llvm::SmallVector<Value, 4> arg_values;
+  // First map parameters to memrefs on the operation.
+  for (auto param : hlo_computation.parameter_instructions()) {
+    TF_ASSIGN_OR_RETURN(
+        auto arg_type, ConvertShapeToType<MemRefType>(param->shape(), builder));
+    auto block_arg = block->addArgument(arg_type);
+    arg_values.push_back(builder.create<::mlir::TensorLoadOp>(loc, block_arg));
+  }
+
+  HloDialectEmitter hlo_emitter(emission_context, builder, arg_values);
+  TF_ASSIGN_OR_RETURN(auto result,
+                      hlo_emitter.EmitComputation(hlo_computation));
+
+  // Now add a block arg and store for the result.
+  builder.setInsertionPoint(block->getTerminator());
+  TF_ASSIGN_OR_RETURN(
+      auto result_type,
+      ConvertShapeToType<MemRefType>(
+          hlo_computation.root_instruction()->shape(), builder));
+  auto block_arg = block->addArgument(result_type);
+  builder.create<::mlir::TensorStoreOp>(loc, result, block_arg);
+  return Status::OK();
+}
+
 }  // namespace

 mlir::Location LhloDialectEmitter::getLocation(
@@ -268,33 +300,47 @@ Status LhloDialectEmitter::HandleReduce(HloInstruction* reduce) {
   auto reduce_op = builder.create<lhlo::ReduceOp>(loc, inputs, init_values,
                                                   results, dimensions_attr);
   reduce_op.ensureTerminator(reduce_op.body(), builder, getLocation(reduce));
-  OpBuilder body_builder(reduce_op.body());
-  auto block = body_builder.getInsertionBlock();
-  auto to_apply = reduce->to_apply();
-  llvm::SmallVector<Value, 4> reduce_arg_values;
-  // First map parameters to memrefs on the operation.
-  for (auto param : to_apply->parameter_instructions()) {
-    TF_ASSIGN_OR_RETURN(auto arg_type, ConvertShapeToType<MemRefType>(
-                                           param->shape(), builder_));
-    auto block_arg = block->addArgument(arg_type);
-    reduce_arg_values.push_back(
-        body_builder.create<::mlir::TensorLoadOp>(loc, block_arg));
-  }
-  HloDialectEmitter hlo_emitter(emission_context_, body_builder,
-                                reduce_arg_values);
-  TF_ASSIGN_OR_RETURN(auto result, hlo_emitter.EmitComputation(*to_apply));
-  // Now add a block arg and store for the result.
-  body_builder.setInsertionPoint(block->getTerminator());
-  TF_ASSIGN_OR_RETURN(auto result_type,
-                      ConvertShapeToType<MemRefType>(
-                          to_apply->root_instruction()->shape(), builder));
-  auto block_arg = block->addArgument(result_type);
-  body_builder.create<::mlir::TensorStoreOp>(loc, result, block_arg);
-  return Status::OK();
+  return SpliceHloComputation(OpBuilder{&reduce_op.body()}, loc,
+                              *reduce->to_apply(), emission_context_);
+}
+
+Status LhloDialectEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
+  TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*reduce_window));
+  llvm::SmallVector<Value, 4> arg_values{function.args_begin(),
+                                         function.args_end()};
+  OpBuilder builder(function.getBody());
+  auto loc = getLocation(reduce_window);
+
+  // Collect attribute values.
+  llvm::SmallVector<int64, 2> window_dimensions, window_strides, base_dilations,
+      window_dilations;
+  llvm::SmallVector<int64, 4> padding;
+  int64 rank = reduce_window->window().dimensions_size();
+  window_dimensions.reserve(rank);
+  window_strides.reserve(rank);
+  base_dilations.reserve(rank);
+  window_dilations.reserve(rank);
+  padding.reserve(2 * rank);
+  for (const auto& window : reduce_window->window().dimensions()) {
+    window_dimensions.push_back(window.size());
+    window_strides.push_back(window.stride());
+    base_dilations.push_back(window.base_dilation());
+    window_dilations.push_back(window.window_dilation());
+    padding.push_back(window.padding_low());
+    padding.push_back(window.padding_high());
+  }
+
+  auto reduce_window_op = builder.create<lhlo::ReduceWindowOp>(
+      loc, /*operand=*/arg_values[0], /*init_value=*/arg_values[1],
+      /*out=*/arg_values[2],
+      CreateDenseIntElementsAttrFromVector(window_dimensions, builder),
+      CreateDenseIntElementsAttrFromVector(window_strides, builder),
+      CreateDenseIntElementsAttrFromVector(base_dilations, builder),
+      CreateDenseIntElementsAttrFromVector(window_dilations, builder),
+      CreateDenseIntElementsAttrFromVector(padding, builder, {rank, 2}));
+  reduce_window_op.ensureTerminator(reduce_window_op.body(), builder, loc);
+  return SpliceHloComputation(OpBuilder{&reduce_window_op.body()}, loc,
+                              *reduce_window->to_apply(), emission_context_);
 }

 Status LhloDialectEmitter::HandleCustomCall(HloInstruction* custom_call) {

View File

@@ -61,6 +61,7 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault,
   Status HandleIota(HloInstruction* iota) override;
   Status HandleParameter(HloInstruction* parameter) override;
   Status HandleReduce(HloInstruction* reduce) override;
+  Status HandleReduceWindow(HloInstruction* reduce_window) override;
   Status HandleTuple(HloInstruction* tuple) override;
   Status FinishVisit(HloInstruction* root) override;

View File

@@ -47,6 +47,7 @@ tf_cc_test(
         "iota_add_multiply.hlo",
         "log.hlo",
         "neg.hlo",
+        "reduce_window.hlo",
         "rem.hlo",
         "rsqrt.hlo",
         "select.hlo",

View File

@@ -175,6 +175,12 @@ TEST_F(LhloGenTest, Neg) {
                                              "neg.hlo"));
 }

+TEST_F(LhloGenTest, ReduceWindow) {
+  CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
+                                              "service", "mlir_gpu", "tests",
+                                              "reduce_window.hlo"));
+}
+
 TEST_F(LhloGenTest, Rem) {
   CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla",
                                               "service", "mlir_gpu", "tests",
View File

@@ -0,0 +1,34 @@
+HloModule ReduceWindow
+
+%max (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %max = f32[] maximum(f32[] %x, f32[] %y)
+}
+
+ENTRY %ReduceWindow (x: f32[128,64,112,112], y: f32[]) -> f32[128,64,56,56] {
+  %x = f32[128,64,112,112] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %reduce-window = f32[128,64,56,56] reduce-window(
+      f32[128,64,112,112] %x,
+      f32[] %y
+    ),
+    window={size=1x1x3x3 stride=1x1x2x2 pad=0_0x0_0x0_1x0_1}, to_apply=%max
+}
+
+// CHECK: func @"reduce-window"(
+// CHECK-SAME: [[ARG:%.*]]: [[ARGT:.*]], [[CST:%.*]]: memref<f32>, [[RES:%.*]]: [[REST:.*]]) {
+// CHECK: "xla_lhlo.reduce_window"([[LHS:%.*]], [[RHS:%.*]], [[OUT:%.*]]) ( {
+// CHECK: ^bb0([[LHS:%.*]]: memref<f32>, [[RHS:%.*]]: memref<f32>, [[OUT:%.*]]: memref<f32>):
+// CHECK: [[LHS_TENSOR:%.*]] = tensor_load [[LHS]]
+// CHECK: [[RHS_TENSOR:%.*]] = tensor_load [[RHS]]
+// CHECK: [[OUT_TENSOR:%.*]] = xla_hlo.maximum [[LHS_TENSOR]], [[RHS_TENSOR]]
+// CHECK: tensor_store [[OUT_TENSOR]], [[OUT]]
+// CHECK: "xla_lhlo.terminator"() : () -> ()
+// CHECK: }) {
+// CHECK-SAME: base_dilations = dense<1> : tensor<4xi64>
+// CHECK-SAME: padding = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 1]]>
+// CHECK-SAME: window_dilations = dense<1> : tensor<4xi64>
+// CHECK-SAME: window_dimensions = dense<[1, 1, 3, 3]>
+// CHECK-SAME: window_strides = dense<[1, 1, 2, 2]>
+// CHECK: } : ([[ARGT]], memref<f32>, [[REST]]) -> ()