Materialize broadcasts for more xla_hlo ops, respect arg element types.
Handling xla_hlo.compare revealed that the conversion should keep the argument element types for the generated broadcast_in_dim ops, rather than rely on the result element type matching the argument element types. Note the test case where the function signature is `(tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>`.

PiperOrigin-RevId: 292040732
Change-Id: I998104bc5d238f006f17dcaedd94496880900403
parent f779661b60 · commit b862920996
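To illustrate the element-type point from the message above, here is roughly what the materialized compare looks like after conversion (the SSA value names are placeholders; shapes and types mirror the @compareBroadcastRhs test added in this change). Both broadcast_in_dim ops keep the f32 operand element type and only expand to the result shape, while the compare alone produces i1:

  %lhs_bcast = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
  %rhs_bcast = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
  %result = "xla_hlo.compare"(%lhs_bcast, %rhs_bcast) {comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1>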
@@ -55,18 +55,6 @@ func @addBroadcastScalar(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf

// -----

// TODO(scotttodd): Check if this use of dynamic shapes should pass verification
// CHECK-LABEL: @addBroadcastLhsDynamicShape
func @addBroadcastLhsDynamicShape(%arg0: tensor<?xf32>, %arg1: tensor<1x3xf32>) -> tensor<?x3xf32> {
  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>) -> tensor<?x3xf32>
  // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x3xf32>) -> tensor<?x3xf32>
  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<?x3xf32>
  %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[0]> : tensor<1xi64>} : (tensor<?xf32>, tensor<1x3xf32>) -> tensor<?x3xf32>
  return %0 : tensor<?x3xf32>
}

// -----

// CHECK-LABEL: @addWithoutBroadcast
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<4xf32>
@@ -85,6 +73,17 @@ func @addUnranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {

// -----

// CHECK-LABEL: @atan2BroadcastRhs
func @atan2BroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
  // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.atan2 %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32>
  %0 = "xla_hlo.atan2"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
  return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @divBroadcastRhs
func @divBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
@@ -140,6 +139,50 @@ func @powBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x

// -----

// CHECK-LABEL: @remainderBroadcastRhs
func @remainderBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
  // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.remainder %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32>
  %0 = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
  return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @shiftLeftBroadcastRhs
func @shiftLeftBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
  // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_left %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32>
  %0 = "xla_hlo.shift_left"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
  return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @shiftRightArithmeticBroadcastRhs
func @shiftRightArithmeticBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
  // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_arithmetic %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32>
  %0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
  return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @shiftRightLogicalBroadcastRhs
func @shiftRightLogicalBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
  // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_logical %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32>
  %0 = "xla_hlo.shift_right_logical"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
  return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @subBroadcastRhs
func @subBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
@@ -148,3 +191,47 @@ func @subBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x
  %0 = "xla_hlo.sub"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
  return %0 : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: @andBroadcastRhs
func @andBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> {
  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xi32>) -> tensor<1x4xi32>
  // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32>
  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.and %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xi32>
  %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32>
  return %0 : tensor<1x4xi32>
}

// -----

// CHECK-LABEL: @orBroadcastRhs
func @orBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> {
  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xi32>) -> tensor<1x4xi32>
  // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32>
  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.or %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xi32>
  %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32>
  return %0 : tensor<1x4xi32>
}

// -----

// CHECK-LABEL: @xorBroadcastRhs
func @xorBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> {
  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xi32>) -> tensor<1x4xi32>
  // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32>
  // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.xor %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xi32>
  %0 = "xla_hlo.xor"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32>
  return %0 : tensor<1x4xi32>
}

// -----

// CHECK-LABEL: @compareBroadcastRhs
func @compareBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xi1> {
  // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
  // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
  // CHECK-NEXT: %[[RESULT:.*]] = "xla_hlo.compare"(%[[BROADCAST0]], %[[BROADCAST1]]) {comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1>
  %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>
  return %0 : tensor<1x4xi1>
}
@@ -41,6 +41,83 @@ static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
  return DenseIntElementsAttr::get(ty, vals);
}

// Helper function for OpRewritePattern classes to materialize broadcasts on
// LHS and RHS arguments to a binary op.
//
// Returns true and sets out_lhs and out_rhs to BroadcastInDimOps if successful,
// returns false otherwise.
template <typename SrcOp>
bool CreateBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter,
                                 Value *out_lhs, Value *out_rhs) {
  if (!op.broadcast_dimensions().hasValue()) {
    // Note: the op may still have an implicit broadcast on it, such as
    // for (tensor<1xf32>, tensor<4xf32>).
    return false;
  }

  // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args,
  // replacing the original LHS and RHS args in the source op with the results
  // of the broadcasts.
  //
  // If the higher dimensional argument does not actually need the broadcast,
  // a canonicalization pass should be able to remove that op later.
  Value lhs = op.lhs();
  Value rhs = op.rhs();

  auto op_ranked_type = op.getType().template dyn_cast<RankedTensorType>();
  auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
  auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
  if (!op_ranked_type || !lhs_ranked_type || !rhs_ranked_type) {
    // Unranked, can't determine at this point how to perform the broadcast.
    return false;
  }

  if (!op_ranked_type.hasStaticShape()) {
    // Dynamic result shape, can't use BroadcastInDimOp.
    return false;
  }

  auto lhs_rank = lhs_ranked_type.getRank();
  auto rhs_rank = rhs_ranked_type.getRank();

  // Set broadcast_dimensions to [0, ..., rank] for the higher rank arg.
  // Use the original op.broadcast_dimensions for the lower rank arg.
  auto higher_rank_broadcast_dims =
      GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), rewriter);
  DenseIntElementsAttr lhs_broadcast_dims;
  DenseIntElementsAttr rhs_broadcast_dims;
  if (lhs_rank > rhs_rank) {
    lhs_broadcast_dims = higher_rank_broadcast_dims;
    rhs_broadcast_dims = op.broadcast_dimensions().getValue();
  } else if (lhs_rank < rhs_rank) {
    lhs_broadcast_dims = op.broadcast_dimensions().getValue();
    rhs_broadcast_dims = higher_rank_broadcast_dims;
  } else {
    // This shouldn't happen for legal ops. If the broadcast_dimensions
    // attribute is set, the ranks should be different.
    // TODO(scotttodd): Add a custom verification for ops and assert here.
    return false;
  }

  // BroadcastInDimOp must have the same element type for operands and results,
  // so preserve the original output shape and the original input element type.
  // For example, `SrcOp (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>`:
  //   broadcast_in_dim (tensor<1x4xf32>) -> tensor<1x4xf32>
  //   broadcast_in_dim (tensor<4xf32>) -> tensor<1x4xf32>
  //   SrcOp (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1>
  ArrayRef<int64_t> op_shape = op_ranked_type.getShape();
  auto lhs_type =
      RankedTensorType::get(op_shape, lhs_ranked_type.getElementType());
  auto rhs_type =
      RankedTensorType::get(op_shape, rhs_ranked_type.getElementType());

  *out_lhs = rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), lhs_type,
                                                      lhs, lhs_broadcast_dims);
  *out_rhs = rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), rhs_type,
                                                      rhs, rhs_broadcast_dims);
  return true;
}

template <typename SrcOp>
struct BinaryOpWithBroadcastConvert : public OpRewritePattern<SrcOp> {
  explicit BinaryOpWithBroadcastConvert(MLIRContext *context)
@@ -48,62 +125,36 @@ struct BinaryOpWithBroadcastConvert : public OpRewritePattern<SrcOp> {

  PatternMatchResult matchAndRewrite(SrcOp op,
                                     PatternRewriter &rewriter) const override {
    if (!op.broadcast_dimensions().hasValue()) {
      // Note: the op may still have an implicit broadcast on it, such as
      // for (tensor<1xf32>, tensor<4xf32>).
    Value new_lhs;
    Value new_rhs;
    if (!CreateBroadcastsForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) {
      return this->matchFailure();
    }

    auto result_type = op.getType();

    // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args,
    // replacing the original LHS and RHS args in the source op with the results
    // of the broadcasts.
    //
    // If the higher dimensional argument does not actually need the broadcast,
    // a canonicalization pass should be able to remove that op later.
    Value lhs = op.lhs();
    Value rhs = op.rhs();

    auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
    if (!lhs_ranked_type || !rhs_ranked_type) {
      // Unranked, can't determine at this point how to perform the broadcast.
      return this->matchFailure();
    }

    auto lhs_rank = lhs_ranked_type.getRank();
    auto rhs_rank = rhs_ranked_type.getRank();

    // Set broadcast_dimensions to [0, ..., rank] for the higher rank arg.
    // Use the original op.broadcast_dimensions for the lower rank arg.
    auto higher_rank_broadcast_dims =
        GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), &rewriter);
    DenseIntElementsAttr lhs_broadcast_dims;
    DenseIntElementsAttr rhs_broadcast_dims;
    if (lhs_rank > rhs_rank) {
      lhs_broadcast_dims = higher_rank_broadcast_dims;
      rhs_broadcast_dims = op.broadcast_dimensions().getValue();
    } else if (lhs_rank < rhs_rank) {
      lhs_broadcast_dims = op.broadcast_dimensions().getValue();
      rhs_broadcast_dims = higher_rank_broadcast_dims;
    } else {
      // This shouldn't happen for legal ops. If the broadcast_dimensions
      // attribute is set, the ranks should be different.
      // TODO(scotttodd): Add a custom verification for ops and assert here.
      return this->matchFailure();
    }
    lhs = rewriter.createOrFold<BroadcastInDimOp>(op.getLoc(), result_type, lhs,
                                                  lhs_broadcast_dims);
    rhs = rewriter.createOrFold<BroadcastInDimOp>(op.getLoc(), result_type, rhs,
                                                  rhs_broadcast_dims);

    // Replace the original op with a new one that uses the new args.
    // As the new args are broadcasts, no broadcast dimensions are needed on
    // the replacement op.
    rewriter.replaceOpWithNewOp<SrcOp>(op, result_type, lhs, rhs,
    // New args are broadcasts, so no dims are needed on the replacement op.
    rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), new_lhs, new_rhs,
                                       /*broadcast_dims=*/nullptr);
    return this->matchSuccess();
  }
};

// Specialized class for CompareOp, as it has an additional builder argument.
struct CompareWithBroadcastConvert : public OpRewritePattern<CompareOp> {
  explicit CompareWithBroadcastConvert(MLIRContext *context)
      : OpRewritePattern<CompareOp>(context) {}

  PatternMatchResult matchAndRewrite(CompareOp op,
                                     PatternRewriter &rewriter) const override {
    Value new_lhs;
    Value new_rhs;
    if (!CreateBroadcastsForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) {
      return this->matchFailure();
    }

    rewriter.replaceOpWithNewOp<CompareOp>(op, op.getType(), new_lhs, new_rhs,
                                           /*broadcast_dims=*/nullptr,
                                           op.comparison_direction());
    return this->matchSuccess();
  }
};
@@ -115,25 +166,55 @@ void SetupMaterializeBroadcastsLegality(MLIRContext *context,
#define ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OpType) \
  conversionTarget->addDynamicallyLegalOp<OpType>( \
      [](OpType op) { return !op.broadcast_dimensions().hasValue(); });
  // Binary elementwise ops.
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AddOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(Atan2Op);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(DivOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MaxOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MinOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MulOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(PowOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(RemOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftLeftOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightArithmeticOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightLogicalOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(SubOp);

  // Binary logical elementwise ops.
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AndOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OrOp);
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(XorOp);

  // CompareOp.
  ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(CompareOp);

#undef ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST
}

void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
                                           OwningRewritePatternList *patterns) {
  // Binary elementwise ops.
  patterns->insert<BinaryOpWithBroadcastConvert<AddOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<Atan2Op>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<DivOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<MaxOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<MinOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<MulOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<PowOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<RemOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<ShiftLeftOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<ShiftRightArithmeticOp>>(
      context);
  patterns->insert<BinaryOpWithBroadcastConvert<ShiftRightLogicalOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<SubOp>>(context);

  // Binary logical elementwise ops.
  patterns->insert<BinaryOpWithBroadcastConvert<AndOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<OrOp>>(context);
  patterns->insert<BinaryOpWithBroadcastConvert<XorOp>>(context);

  // CompareOp. Note the specialized class instead of using the template.
  patterns->insert<CompareWithBroadcastConvert>(context);
}

} // namespace xla_hlo