Refactor mhlo->tf legalization for mhlo::ReduceOp
* Improve reduction function matching to be more strict against invalid
  reduction functions.
* Share more code between the 3 reduce op rewrite patterns.
* Move the reduction function matching into a new function for future reuse
  by the mhlo::ReduceWindowOp legaliser.

PiperOrigin-RevId: 341637842
Change-Id: I9546edd8c6be6a5d54e676bd040d84dc024c2125
parent c0ec1e2525
commit 7d2981e88c
@@ -282,7 +282,7 @@ void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
 // Appends all elements in `range` to `values`.
 template <typename ValueT, typename Range, typename... RangeTs>
 void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
-            RangeTs &&... ranges) {
+            RangeTs &&...ranges) {
   values.insert(values.end(), range.begin(), range.end());
   Append(values, ranges...);
 }
@@ -295,13 +295,13 @@ size_t Size(Range &&range) {
 
 // Returns the total number of elements in a variadic number of `ranges`.
 template <typename Range, typename... RangeTs>
-size_t Size(Range &&range, RangeTs &&... ranges) {
+size_t Size(Range &&range, RangeTs &&...ranges) {
   return range.size() + Size(std::forward<RangeTs>(ranges)...);
 }
 
 // Concats all elements in `ranges` and returns a small vector as a result.
 template <typename ValueT, typename... RangeTs>
-llvm::SmallVector<ValueT, 4> Concat(RangeTs &&... ranges) {
+llvm::SmallVector<ValueT, 4> Concat(RangeTs &&...ranges) {
   llvm::SmallVector<int64_t, 4> results;
   results.reserve(Size(std::forward<RangeTs>(ranges)...));
   Append(results, std::forward<RangeTs>(ranges)...);
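The two hunks above only adjust `&&...` spacing in the variadic range helpers. For orientation, a minimal usage sketch of those helpers (not part of the commit; assumes the definitions above and llvm/ADT/SmallVector.h):

    llvm::SmallVector<int64_t, 4> batch_dims{0, 1};
    llvm::SmallVector<int64_t, 4> contract_dims{3};
    // Size(...) totals the element counts (3 here) so Concat can reserve
    // once; Append then copies each range in order.
    llvm::SmallVector<int64_t, 4> all_dims =
        Concat<int64_t>(batch_dims, contract_dims);  // yields {0, 1, 3}

Note that Concat's internal buffer is hard-coded to llvm::SmallVector<int64_t, 4> (visible above), so int64_t is the one element type this compiles for as written.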
@@ -472,29 +472,34 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
   return reshaped.getResult();
 }
 
-// This function tries to match that the "mhlo::ReduceOp" only has one
-// input, one init_value and one result. Also "mhlo::ReduceOp" has two ops
-// in the region, and the last one is return op.
-LogicalResult MatchReduceOpInput(mhlo::ReduceOp reduce_op) {
-  if (reduce_op.operands().size() != 1 || reduce_op.init_values().size() != 1 ||
-      reduce_op.getResults().size() != 1)
-    return failure();
-
-  if (!reduce_op.operands()[0].getType().isa<RankedTensorType>())
-    return failure();
-  if (!reduce_op.getType(0).isa<RankedTensorType>()) return failure();
-
-  auto block = &reduce_op.body().front();
-  if (block->getOperations().size() != 2 || isa<ReturnOp>(block->back()))
+// Checks if the specified region is a binary reduction function that takes 2
+// inputs, passes them to an instance of the specified reduction op and then
+// returns the result.
+template <typename ReductionOp>
+LogicalResult MatchBinaryReduceFunction(mlir::Region &function) {
+  Block &body = function.front();
+  if (body.getNumArguments() != 2) return failure();
+  if (body.getOperations().size() != 2) return failure();
+
+  ReductionOp reduce_op = dyn_cast<ReductionOp>(body.front());
+  if (!reduce_op) return failure();
+  if (reduce_op.lhs() != body.getArgument(0) ||
+      reduce_op.rhs() != body.getArgument(1))
+    return failure();
+
+  mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
+  if (!return_op) return failure();
+  if (return_op.getNumOperands() != 1 ||
+      return_op.results().front() != reduce_op)
     return failure();
 
   return success();
 }
 
-// TODO(jingpu): This "mhlo::ReduceOp" can corresponds to many TF ops
-// with different ops in reduce_op.body. Now we only match to "tf.Max", "tf.Min"
-// and "tf.Sum".
-class ConvertReduceOpToTfSum : public OpConversionPattern<mhlo::ReduceOp> {
+// Converts an mhlo.reduce op with the specified BinaryOp as the reduction
+// operation into the specified TfOp.
+template <typename BinaryOp, typename TfOp>
+class ConvertReduceOpToTfOp : public OpConversionPattern<mhlo::ReduceOp> {
  public:
   using OpConversionPattern::OpConversionPattern;
 
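Illustration (not part of the diff): the new matcher accepts exactly a two-argument block whose first op is the reduction and whose second op returns that reduction's result. A sketch of the intended call, including the future mhlo::ReduceWindowOp reuse the commit message mentions (`reduce_window_op` is a hypothetical variable, not from this commit):

    // Accepted region shape, shown as MLIR for orientation:
    //   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    //     %0 = mhlo.add %lhs, %rhs : tensor<f32>
    //     "mhlo.return"(%0) : (tensor<f32>) -> ()
    if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(reduce_op.body())))
      return failure();
    // Anticipated reuse: a ReduceWindow legalizer could detect max-pooling
    // via MatchBinaryReduceFunction<mhlo::MaxOp>(reduce_window_op.body()).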
@@ -503,116 +508,96 @@ class ConvertReduceOpToTfSum : public OpConversionPattern<mhlo::ReduceOp> {
       ConversionPatternRewriter &rewriter) const final {
     if (failed(MatchReduceOpInput(reduce_op))) return failure();
 
-    Operation *first_op = &reduce_op.body().front().front();
-    if (!llvm::isa<mhlo::AddOp>(first_op)) return failure();
+    if (failed(MatchBinaryReduceFunction<BinaryOp>(reduce_op.body())))
+      return failure();
 
     // In `MatchReduceOpInput` function, we already match that the
     // "mhlo::ReduceOp" only has one input, one init_value and one result.
+    if (failed(MatchInitValue(reduce_op.init_values()[0]))) return failure();
+
     auto input = reduce_op.operands()[0];
+
     // Get reduction dimension.
     DenseIntElementsAttr dimension = reduce_op.dimensions();
     SmallVector<int64_t, 4> reduce_dims;
     for (const int64_t &dim : dimension.getValues<int64_t>()) {
       reduce_dims.emplace_back(dim);
     }
 
-    // Check initial value is zero.
-    DenseFPElementsAttr init_value;
-    if (!matchPattern(reduce_op.init_values()[0], m_Constant(&init_value)) ||
-        !init_value.isSplat() || !init_value.getSplatValue<APFloat>().isZero())
-      return failure();
-
     auto dim_type = RankedTensorType::get(
         {static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
     auto reduction_indices = rewriter.create<ConstOp>(
         reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));
-    rewriter.replaceOpWithNewOp<SumOp>(
-        reduce_op, reduce_op.getType(0), input, reduction_indices,
-        /*keep_dim=*/rewriter.getBoolAttr(false));
 
+    rewriter.replaceOpWithNewOp<TfOp>(reduce_op, reduce_op.getType(0), input,
+                                      reduction_indices,
+                                      /*keep_dim=*/rewriter.getBoolAttr(false));
     return success();
-  };
+  }
+
+ private:
+  // Checks that the init value matches with the init value expected for the
+  // target TfOp.
+  virtual LogicalResult MatchInitValue(Value init_value) const = 0;
+
+  // This function tries to match that the "mhlo::ReduceOp" only has one
+  // input, one init_value and one result.
+  LogicalResult MatchReduceOpInput(mhlo::ReduceOp reduce_op) const {
+    if (reduce_op.operands().size() != 1 ||
+        reduce_op.init_values().size() != 1 ||
+        reduce_op.getResults().size() != 1)
+      return failure();
+
+    if (!reduce_op.operands()[0].getType().isa<RankedTensorType>())
+      return failure();
+    if (!reduce_op.getType(0).isa<RankedTensorType>()) return failure();
+    return success();
+  }
 };
 
-class ConvertReduceOpToTfMax : public OpConversionPattern<mhlo::ReduceOp> {
+class ConvertReduceOpToTfSum
+    : public ConvertReduceOpToTfOp<mhlo::AddOp, TF::SumOp> {
  public:
-  using OpConversionPattern::OpConversionPattern;
+  using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
 
-  LogicalResult matchAndRewrite(
-      mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
-      ConversionPatternRewriter &rewriter) const final {
-    if (failed(MatchReduceOpInput(reduce_op))) return failure();
-
-    Operation *first_op = &reduce_op.body().front().front();
-    if (!llvm::isa<mhlo::MaxOp>(first_op)) return failure();
-
-    // In `MatchReduceOpInput` function, we already match that the
-    // "mhlo::ReduceOp" only has one input, one init_value and one result.
-    auto input = reduce_op.operands()[0];
-    // Get reduction dimension.
-    DenseIntElementsAttr dimension = reduce_op.dimensions();
-    SmallVector<int64_t, 4> reduce_dims;
-    for (const int64_t &dim : dimension.getValues<int64_t>()) {
-      reduce_dims.emplace_back(dim);
-    }
-
-    // Check initial value is float.minimum.
-    DenseFPElementsAttr init_value;
-    if (!matchPattern(reduce_op.init_values()[0], m_Constant(&init_value)) ||
-        !init_value.isSplat() ||
-        !init_value.getSplatValue<APFloat>().isInfinity() ||
-        !init_value.getSplatValue<APFloat>().isNegative())
+  LogicalResult MatchInitValue(Value init_value) const override {
+    DenseFPElementsAttr init_attr;
+    if (!matchPattern(init_value, m_Constant(&init_attr)) ||
+        !init_attr.isSplat() || !init_attr.getSplatValue<APFloat>().isZero())
       return failure();
-
-    auto dim_type = RankedTensorType::get(
-        {static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
-    auto reduction_indices = rewriter.create<ConstOp>(
-        reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));
-    rewriter.replaceOpWithNewOp<MaxOp>(
-        reduce_op, reduce_op.getType(0), input, reduction_indices,
-        /*keep_dim=*/rewriter.getBoolAttr(false));
     return success();
-  };
+  }
 };
 
-class ConvertReduceOpToTfMin : public OpConversionPattern<mhlo::ReduceOp> {
+class ConvertReduceOpToTfMax
+    : public ConvertReduceOpToTfOp<mhlo::MaxOp, TF::MaxOp> {
  public:
-  using OpConversionPattern::OpConversionPattern;
+  using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
 
-  LogicalResult matchAndRewrite(
-      mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
-      ConversionPatternRewriter &rewriter) const final {
-    if (failed(MatchReduceOpInput(reduce_op))) return failure();
-
-    Operation *first_op = &reduce_op.body().front().front();
-    if (!llvm::isa<mhlo::MinOp>(first_op)) return failure();
-
-    // In `MatchReduceOpInput` function, we already match that the
-    // "mhlo::ReduceOp" only has one input, one init_value and one result.
-    Value input = reduce_op.operands()[0];
-    // Get reduction dimension.
-    DenseIntElementsAttr dimension = reduce_op.dimensions();
-    SmallVector<int64_t, 4> reduce_dims;
-    for (const int64_t &dim : dimension.getValues<int64_t>()) {
-      reduce_dims.emplace_back(dim);
-    }
-
-    // Check initial value is +INF.
-    DenseFPElementsAttr init_value;
-    if (!matchPattern(reduce_op.init_values()[0], m_Constant(&init_value)) ||
-        !init_value.isSplat() ||
-        !init_value.getSplatValue<APFloat>().isInfinity() ||
-        init_value.getSplatValue<APFloat>().isNegative())
+  LogicalResult MatchInitValue(Value init_value) const override {
+    DenseFPElementsAttr init_attr;
+    if (!matchPattern(init_value, m_Constant(&init_attr)) ||
+        !init_attr.isSplat() ||
+        !init_attr.getSplatValue<APFloat>().isInfinity() ||
+        !init_attr.getSplatValue<APFloat>().isNegative())
       return failure();
-
-    auto dim_type = RankedTensorType::get(
-        {static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
-    auto reduction_indices = rewriter.create<ConstOp>(
-        reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));
-    rewriter.replaceOpWithNewOp<MinOp>(
-        reduce_op, reduce_op.getType(0), input, reduction_indices,
-        /*keep_dim=*/rewriter.getBoolAttr(false));
     return success();
-  };
+  }
 };
 
+class ConvertReduceOpToTfMin
+    : public ConvertReduceOpToTfOp<mhlo::MinOp, TF::MinOp> {
+ public:
+  using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
+
+  LogicalResult MatchInitValue(Value init_value) const override {
+    DenseFPElementsAttr init_attr;
+    if (!matchPattern(init_value, m_Constant(&init_attr)) ||
+        !init_attr.isSplat() ||
+        !init_attr.getSplatValue<APFloat>().isInfinity() ||
+        init_attr.getSplatValue<APFloat>().isNegative())
+      return failure();
+    return success();
+  }
+};
+
 class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> {
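Hypothetical illustration (NOT in this commit): the payoff of the refactor is that a fourth reduction now needs only a pair of template arguments plus an init-value predicate, instead of a third copy of matchAndRewrite. This sketch assumes mhlo::MulOp and TF::ProdOp expose the same interfaces as the ops used above, and that a product reduction is seeded with 1.0:

    class ConvertReduceOpToTfProd
        : public ConvertReduceOpToTfOp<mhlo::MulOp, TF::ProdOp> {
     public:
      using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;

      // A product reduction folds from an initial value of 1.
      LogicalResult MatchInitValue(Value init_value) const override {
        DenseFPElementsAttr init_attr;
        if (!matchPattern(init_value, m_Constant(&init_attr)) ||
            !init_attr.isSplat() ||
            !init_attr.getSplatValue<APFloat>().isExactlyValue(1.0))
          return failure();
        return success();
      }
    };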
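Pattern registration sits outside the hunks shown above; as a hedged sketch (the function name PopulateLegalizeHloToTfPatterns and the pattern-list type are assumptions mirroring common MLIR practice, not confirmed by this diff), the three derived patterns would be registered where the three independent ones were before:

    void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns,
                                         MLIRContext *context) {
      patterns->insert<ConvertReduceOpToTfMax, ConvertReduceOpToTfMin,
                       ConvertReduceOpToTfSum, ConvertIotaOpToTfRange>(context);
    }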