[mlir][xla] LHLO-to-Affine: simplify DotOp conversion with nested builders
MLIR recently introduced a new idiom for constructing loop nests. Use it to make the legalization of LHLO to affine loops more concise and readable.

PiperOrigin-RevId: 317649515
Change-Id: Idfab27b4655d6df90d940fb7b064ea9941d8a700
This commit is contained in:
parent
12f5cd7dde
commit
bb73be3e36
@ -31,6 +31,17 @@ namespace mlir {
|
|||||||
namespace xla_lhlo {
|
namespace xla_lhlo {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Builds an affine loop nest iterating from zeros to "upper_bounds" with unit
|
||||||
|
// steps, and populates the body of the innermost loop using "body_builder".
|
||||||
|
static void BuildBoundedAffineLoopNest(
|
||||||
|
OpBuilder& builder, Location location, ArrayRef<int64_t> upper_bounds,
|
||||||
|
function_ref<void(OpBuilder&, Location, ValueRange)> body_builder) {
|
||||||
|
SmallVector<int64_t, 3> lower_bounds(upper_bounds.size(), /*Value=*/0);
|
||||||
|
SmallVector<int64_t, 3> steps(upper_bounds.size(), /*Value=*/1);
|
||||||
|
buildAffineLoopNest(builder, location, lower_bounds, upper_bounds, steps,
|
||||||
|
body_builder);
|
||||||
|
}
|
||||||
|
|
||||||
struct DotOpConverter : public OpRewritePattern<DotOp> {
|
struct DotOpConverter : public OpRewritePattern<DotOp> {
|
||||||
using OpRewritePattern<DotOp>::OpRewritePattern;
|
using OpRewritePattern<DotOp>::OpRewritePattern;
|
||||||
|
|
||||||
@ -48,37 +59,29 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
|
|||||||
if ((lhs_type.getRank() != 2) || (rhs_type.getRank() != 2)) {
|
if ((lhs_type.getRank() != 2) || (rhs_type.getRank() != 2)) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
SmallVector<Value, 4> lhs_indices, rhs_indices, result_indices;
|
|
||||||
const auto& loc = op.getLoc();
|
|
||||||
|
|
||||||
// Create the canonical ijk form of matmul.
|
LogicalResult map_status = success();
|
||||||
auto forOp = rewriter.create<AffineForOp>(loc, 0, shape_lhs[0]);
|
auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) {
|
||||||
lhs_indices.push_back(forOp.getInductionVar());
|
SmallVector<Value, 2> lhs_indices{ivs[0], ivs[2]},
|
||||||
result_indices.push_back(forOp.getInductionVar());
|
rhs_indices{ivs[2], ivs[1]}, result_indices{ivs[0], ivs[1]};
|
||||||
|
|
||||||
rewriter.setInsertionPointToStart(forOp.getBody());
|
auto l = builder.create<AffineLoadOp>(loc, lhs, lhs_indices);
|
||||||
forOp = rewriter.create<AffineForOp>(loc, 0, shape_rhs.back());
|
auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
|
||||||
result_indices.push_back(forOp.getInductionVar());
|
auto result =
|
||||||
rhs_indices.resize(2);
|
rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
|
||||||
rhs_indices[1] = forOp.getInductionVar();
|
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<DotOp>(
|
||||||
|
op, element_type, {l, r, result}, &builder);
|
||||||
|
map_status = success(op_result != nullptr);
|
||||||
|
if (failed(map_status)) return;
|
||||||
|
builder.create<AffineStoreOp>(loc, op_result, op.output(),
|
||||||
|
result_indices);
|
||||||
|
};
|
||||||
|
|
||||||
rewriter.setInsertionPointToStart(forOp.getBody());
|
BuildBoundedAffineLoopNest(rewriter, op.getLoc(),
|
||||||
forOp = rewriter.create<AffineForOp>(loc, 0, shape_rhs.front());
|
{shape_lhs[0], shape_rhs[1], shape_rhs[0]},
|
||||||
lhs_indices.push_back(forOp.getInductionVar());
|
body_builder);
|
||||||
rhs_indices[0] = forOp.getInductionVar();
|
if (failed(map_status)) return failure();
|
||||||
|
|
||||||
// Construct the innermost loop body.
|
|
||||||
rewriter.setInsertionPointToStart(forOp.getBody());
|
|
||||||
auto l = rewriter.create<AffineLoadOp>(loc, lhs, lhs_indices);
|
|
||||||
auto r = rewriter.create<AffineLoadOp>(loc, rhs, rhs_indices);
|
|
||||||
auto result =
|
|
||||||
rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
|
|
||||||
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<DotOp>(
|
|
||||||
op, element_type, {l, r, result}, &rewriter);
|
|
||||||
if (op_result == nullptr) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
rewriter.create<AffineStoreOp>(loc, op_result, op.output(), result_indices);
|
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@ -99,22 +102,22 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
|
|||||||
if (lhs_type.getShape() != rhs_type.getShape()) {
|
if (lhs_type.getShape() != rhs_type.getShape()) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
const auto& shape = lhs_type.getShape();
|
|
||||||
SmallVector<Value, 4> induction_vars;
|
LogicalResult map_status = success();
|
||||||
const auto loc = op.getLoc();
|
auto body_builder = [&](OpBuilder& builder, Location loc,
|
||||||
for (int i = 0; i < shape.size(); ++i) {
|
ValueRange induction_vars) {
|
||||||
auto forOp = rewriter.create<AffineForOp>(loc, 0, shape[i]);
|
auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
|
||||||
induction_vars.push_back(forOp.getInductionVar());
|
auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
|
||||||
rewriter.setInsertionPointToStart(forOp.getBody());
|
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
|
||||||
}
|
op, element_type, {l, r}, &builder);
|
||||||
auto l = rewriter.create<AffineLoadOp>(loc, lhs, induction_vars);
|
map_status = success(op_result != nullptr);
|
||||||
auto r = rewriter.create<AffineLoadOp>(loc, rhs, induction_vars);
|
if (failed(map_status)) return;
|
||||||
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
|
rewriter.create<AffineStoreOp>(loc, op_result, op.out(), induction_vars);
|
||||||
op, element_type, {l, r}, &rewriter);
|
};
|
||||||
if (opResult == nullptr) {
|
|
||||||
return failure();
|
BuildBoundedAffineLoopNest(rewriter, op.getLoc(), lhs_type.getShape(),
|
||||||
}
|
body_builder);
|
||||||
rewriter.create<AffineStoreOp>(loc, opResult, op.out(), induction_vars);
|
if (failed(map_status)) return failure();
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user