Refactor mhlo->tf legalization for mhlo::ReduceOp

* Improve reduction function matching to be more strict against invalid
  reduction functions.
* Share more code between the 3 reduce op rewrite patterns.
* Move the reduction function matching into a new function for future
 reuse by the mhlo::ReduceWindowOp legaliser.

PiperOrigin-RevId: 341637842
Change-Id: I9546edd8c6be6a5d54e676bd040d84dc024c2125
This commit is contained in:
A. Unique TensorFlower 2020-11-10 09:42:44 -08:00 committed by TensorFlower Gardener
parent c0ec1e2525
commit 7d2981e88c

View File

@ -282,7 +282,7 @@ void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
// Appends all elements in `range` to `values`.
template <typename ValueT, typename Range, typename... RangeTs>
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
RangeTs &&... ranges) {
RangeTs &&...ranges) {
values.insert(values.end(), range.begin(), range.end());
Append(values, ranges...);
}
@ -295,13 +295,13 @@ size_t Size(Range &&range) {
// Returns the total number of elements in a variadic number of `ranges`.
template <typename Range, typename... RangeTs>
size_t Size(Range &&range, RangeTs &&... ranges) {
size_t Size(Range &&range, RangeTs &&...ranges) {
return range.size() + Size(std::forward<RangeTs>(ranges)...);
}
// Concats all elements in `ranges` and returns a small vector as a result.
template <typename ValueT, typename... RangeTs>
llvm::SmallVector<ValueT, 4> Concat(RangeTs &&... ranges) {
llvm::SmallVector<ValueT, 4> Concat(RangeTs &&...ranges) {
llvm::SmallVector<int64_t, 4> results;
results.reserve(Size(std::forward<RangeTs>(ranges)...));
Append(results, std::forward<RangeTs>(ranges)...);
@ -472,29 +472,34 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
return reshaped.getResult();
}
// This function tries to match that the "mhlo::ReduceOp" only has one
// input, one init_value and one result. Also "mhlo::ReduceOp" has two ops
// in the region, and the last one is return op.
LogicalResult MatchReduceOpInput(mhlo::ReduceOp reduce_op) {
if (reduce_op.operands().size() != 1 || reduce_op.init_values().size() != 1 ||
reduce_op.getResults().size() != 1)
// Checks if the specified region is a binary reduction function what takes 2
// inputs, passes it to an instance of the specifiied reduction op and then
// returns the result.
template <typename ReductionOp>
LogicalResult MatchBinaryReduceFunction(mlir::Region &function) {
Block &body = function.front();
if (body.getNumArguments() != 2) return failure();
if (body.getOperations().size() != 2) return failure();
ReductionOp reduce_op = dyn_cast<ReductionOp>(body.front());
if (!reduce_op) return failure();
if (reduce_op.lhs() != body.getArgument(0) ||
reduce_op.rhs() != body.getArgument(1))
return failure();
if (!reduce_op.operands()[0].getType().isa<RankedTensorType>())
return failure();
if (!reduce_op.getType(0).isa<RankedTensorType>()) return failure();
auto block = &reduce_op.body().front();
if (block->getOperations().size() != 2 || isa<ReturnOp>(block->back()))
mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
if (!return_op) return failure();
if (return_op.getNumOperands() != 1 ||
return_op.results().front() != reduce_op)
return failure();
return success();
}
// TODO(jingpu): This "mhlo::ReduceOp" can corresponds to many TF ops
// with different ops in reduce_op.body. Now we only match to "tf.Max", "tf.Min"
// and "tf.Sum".
class ConvertReduceOpToTfSum : public OpConversionPattern<mhlo::ReduceOp> {
// Converts an mhlo.reduce op with the specified BinaryOp as the reduction
// operation into the specified TfOp.
template <typename BinaryOp, typename TfOp>
class ConvertReduceOpToTfOp : public OpConversionPattern<mhlo::ReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;
@ -503,116 +508,96 @@ class ConvertReduceOpToTfSum : public OpConversionPattern<mhlo::ReduceOp> {
ConversionPatternRewriter &rewriter) const final {
if (failed(MatchReduceOpInput(reduce_op))) return failure();
Operation *first_op = &reduce_op.body().front().front();
if (!llvm::isa<mhlo::AddOp>(first_op)) return failure();
if (failed(MatchBinaryReduceFunction<BinaryOp>(reduce_op.body())))
return failure();
// In `MatchReduceOpInput` function, we already match that the
// "mhlo::ReduceOp" only has one input, one init_value and one result.
if (failed(MatchInitValue(reduce_op.init_values()[0]))) return failure();
auto input = reduce_op.operands()[0];
// Get reduction dimension.
DenseIntElementsAttr dimension = reduce_op.dimensions();
SmallVector<int64_t, 4> reduce_dims;
for (const int64_t &dim : dimension.getValues<int64_t>()) {
reduce_dims.emplace_back(dim);
}
// Check initial value is zero.
DenseFPElementsAttr init_value;
if (!matchPattern(reduce_op.init_values()[0], m_Constant(&init_value)) ||
!init_value.isSplat() || !init_value.getSplatValue<APFloat>().isZero())
return failure();
auto dim_type = RankedTensorType::get(
{static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
auto reduction_indices = rewriter.create<ConstOp>(
reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));
rewriter.replaceOpWithNewOp<SumOp>(
reduce_op, reduce_op.getType(0), input, reduction_indices,
/*keep_dim=*/rewriter.getBoolAttr(false));
rewriter.replaceOpWithNewOp<TfOp>(reduce_op, reduce_op.getType(0), input,
reduction_indices,
/*keep_dim=*/rewriter.getBoolAttr(false));
return success();
};
}
private:
// Checks that the init value matches with the init value expected for the
// target TfOp.
virtual LogicalResult MatchInitValue(Value init_value) const = 0;
// This function tries to match that the "mhlo::ReduceOp" only has one
// input, one init_value and one result.
LogicalResult MatchReduceOpInput(mhlo::ReduceOp reduce_op) const {
if (reduce_op.operands().size() != 1 ||
reduce_op.init_values().size() != 1 ||
reduce_op.getResults().size() != 1)
return failure();
if (!reduce_op.operands()[0].getType().isa<RankedTensorType>())
return failure();
if (!reduce_op.getType(0).isa<RankedTensorType>()) return failure();
return success();
}
};
class ConvertReduceOpToTfMax : public OpConversionPattern<mhlo::ReduceOp> {
class ConvertReduceOpToTfSum
: public ConvertReduceOpToTfOp<mhlo::AddOp, TF::SumOp> {
public:
using OpConversionPattern::OpConversionPattern;
using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
LogicalResult matchAndRewrite(
mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
if (failed(MatchReduceOpInput(reduce_op))) return failure();
Operation *first_op = &reduce_op.body().front().front();
if (!llvm::isa<mhlo::MaxOp>(first_op)) return failure();
// In `MatchReduceOpInput` function, we already match that the
// "mhlo::ReduceOp" only has one input, one init_value and one result.
auto input = reduce_op.operands()[0];
// Get reduction dimension.
DenseIntElementsAttr dimension = reduce_op.dimensions();
SmallVector<int64_t, 4> reduce_dims;
for (const int64_t &dim : dimension.getValues<int64_t>()) {
reduce_dims.emplace_back(dim);
}
// Check initial value is float.minimum.
DenseFPElementsAttr init_value;
if (!matchPattern(reduce_op.init_values()[0], m_Constant(&init_value)) ||
!init_value.isSplat() ||
!init_value.getSplatValue<APFloat>().isInfinity() ||
!init_value.getSplatValue<APFloat>().isNegative())
LogicalResult MatchInitValue(Value init_value) const override {
DenseFPElementsAttr init_attr;
if (!matchPattern(init_value, m_Constant(&init_attr)) ||
!init_attr.isSplat() || !init_attr.getSplatValue<APFloat>().isZero())
return failure();
auto dim_type = RankedTensorType::get(
{static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
auto reduction_indices = rewriter.create<ConstOp>(
reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));
rewriter.replaceOpWithNewOp<MaxOp>(
reduce_op, reduce_op.getType(0), input, reduction_indices,
/*keep_dim=*/rewriter.getBoolAttr(false));
return success();
};
}
};
class ConvertReduceOpToTfMin : public OpConversionPattern<mhlo::ReduceOp> {
class ConvertReduceOpToTfMax
: public ConvertReduceOpToTfOp<mhlo::MaxOp, TF::MaxOp> {
public:
using OpConversionPattern::OpConversionPattern;
using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
LogicalResult matchAndRewrite(
mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
if (failed(MatchReduceOpInput(reduce_op))) return failure();
Operation *first_op = &reduce_op.body().front().front();
if (!llvm::isa<mhlo::MinOp>(first_op)) return failure();
// In `MatchReduceOpInput` function, we already match that the
// "mhlo::ReduceOp" only has one input, one init_value and one result.
Value input = reduce_op.operands()[0];
// Get reduction dimension.
DenseIntElementsAttr dimension = reduce_op.dimensions();
SmallVector<int64_t, 4> reduce_dims;
for (const int64_t &dim : dimension.getValues<int64_t>()) {
reduce_dims.emplace_back(dim);
}
// Check initial value is +INF.
DenseFPElementsAttr init_value;
if (!matchPattern(reduce_op.init_values()[0], m_Constant(&init_value)) ||
!init_value.isSplat() ||
!init_value.getSplatValue<APFloat>().isInfinity() ||
init_value.getSplatValue<APFloat>().isNegative())
LogicalResult MatchInitValue(Value init_value) const override {
DenseFPElementsAttr init_attr;
if (!matchPattern(init_value, m_Constant(&init_attr)) ||
!init_attr.isSplat() ||
!init_attr.getSplatValue<APFloat>().isInfinity() ||
!init_attr.getSplatValue<APFloat>().isNegative())
return failure();
auto dim_type = RankedTensorType::get(
{static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
auto reduction_indices = rewriter.create<ConstOp>(
reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));
rewriter.replaceOpWithNewOp<MinOp>(
reduce_op, reduce_op.getType(0), input, reduction_indices,
/*keep_dim=*/rewriter.getBoolAttr(false));
return success();
};
}
};
class ConvertReduceOpToTfMin
: public ConvertReduceOpToTfOp<mhlo::MinOp, TF::MinOp> {
public:
using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
LogicalResult MatchInitValue(Value init_value) const override {
DenseFPElementsAttr init_attr;
if (!matchPattern(init_value, m_Constant(&init_attr)) ||
!init_attr.isSplat() ||
!init_attr.getSplatValue<APFloat>().isInfinity() ||
init_attr.getSplatValue<APFloat>().isNegative())
return failure();
return success();
}
};
class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> {