Materialize broadcasts for more xla_hlo ops, respect arg element types.

Handling xla_hlo.compare revealed that the conversion should keep the argument element types for the generated broadcast_in_dim ops, rather than rely on the result element type matching the argument element types. Note the test case where the function signature is `(tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>`.
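As a sketch of the new behavior (mirroring the compareBroadcastRhs test case added in this change), a compare op carrying broadcast_dimensions such as

  %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>

is materialized so that each broadcast_in_dim keeps its argument's f32 element type, and only the compare itself produces the i1 result:

  %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
  %1 = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
  %2 = "xla_hlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1>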

PiperOrigin-RevId: 292040732
Change-Id: I998104bc5d238f006f17dcaedd94496880900403
Scott Todd 2020-01-28 16:41:26 -08:00 committed by TensorFlower Gardener
parent f779661b60
commit b862920996
2 changed files with 230 additions and 62 deletions


@@ -55,18 +55,6 @@ func @addBroadcastScalar(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf
// -----
// TODO(scotttodd): Check if this use of dynamic shapes should pass verification
// CHECK-LABEL: @addBroadcastLhsDynamicShape
func @addBroadcastLhsDynamicShape(%arg0: tensor<?xf32>, %arg1: tensor<1x3xf32>) -> tensor<?x3xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>) -> tensor<?x3xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x3xf32>) -> tensor<?x3xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<?x3xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[0]> : tensor<1xi64>} : (tensor<?xf32>, tensor<1x3xf32>) -> tensor<?x3xf32>
return %0 : tensor<?x3xf32>
}
// -----
// CHECK-LABEL: @addWithoutBroadcast
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<4xf32>
@@ -85,6 +73,17 @@ func @addUnranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// -----
// CHECK-LABEL: @atan2BroadcastRhs
func @atan2BroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.atan2 %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.atan2"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @divBroadcastRhs
func @divBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
@@ -140,6 +139,50 @@ func @powBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x
// -----
// CHECK-LABEL: @remainderBroadcastRhs
func @remainderBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.remainder %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @shiftLeftBroadcastRhs
func @shiftLeftBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_left %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.shift_left"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @shiftRightArithmeticBroadcastRhs
func @shiftRightArithmeticBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_arithmetic %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @shiftRightLogicalBroadcastRhs
func @shiftRightLogicalBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_logical %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32>
%0 = "xla_hlo.shift_right_logical"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @subBroadcastRhs
func @subBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
@@ -148,3 +191,47 @@ func @subBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x
%0 = "xla_hlo.sub"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// CHECK-LABEL: @andBroadcastRhs
func @andBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xi32>) -> tensor<1x4xi32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.and %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xi32>
%0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// -----
// CHECK-LABEL: @orBroadcastRhs
func @orBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xi32>) -> tensor<1x4xi32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.or %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xi32>
%0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// -----
// CHECK-LABEL: @xorBroadcastRhs
func @xorBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xi32>) -> tensor<1x4xi32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32>
// CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.xor %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xi32>
%0 = "xla_hlo.xor"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// -----
// CHECK-LABEL: @compareBroadcastRhs
func @compareBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xi1> {
// CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = "xla_hlo.compare"(%[[BROADCAST0]], %[[BROADCAST1]]) {comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1>
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>
return %0 : tensor<1x4xi1>
}


@@ -41,6 +41,83 @@ static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
return DenseIntElementsAttr::get(ty, vals);
}
// Helper function for OpRewritePattern classes to materialize broadcasts on
// LHS and RHS arguments to a binary op.
//
// Returns true and sets out_lhs and out_rhs to BroadcastInDimOps if successful,
// returns false otherwise.
template <typename SrcOp>
bool CreateBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter,
Value *out_lhs, Value *out_rhs) {
if (!op.broadcast_dimensions().hasValue()) {
// Note: the op may still have an implicit broadcast on it, such as
// for (tensor<1xf32>, tensor<4xf32>).
return false;
}
// Insert BroadcastInDimOps for the left-hand-side and right-hand-side args,
// replacing the original LHS and RHS args in the source op with the results
// of the broadcasts.
//
// If the higher dimensional argument does not actually need the broadcast,
// a canonicalization pass should be able to remove that op later.
Value lhs = op.lhs();
Value rhs = op.rhs();
auto op_ranked_type = op.getType().template dyn_cast<RankedTensorType>();
auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
if (!op_ranked_type || !lhs_ranked_type || !rhs_ranked_type) {
// Unranked, can't determine at this point how to perform the broadcast.
return false;
}
if (!op_ranked_type.hasStaticShape()) {
// Dynamic result shape, can't use BroadcastInDimOp.
return false;
}
auto lhs_rank = lhs_ranked_type.getRank();
auto rhs_rank = rhs_ranked_type.getRank();
// Set broadcast_dimensions to [0, ..., rank] for the higher rank arg.
// Use the original op.broadcast_dimensions for the lower rank arg.
auto higher_rank_broadcast_dims =
GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), rewriter);
DenseIntElementsAttr lhs_broadcast_dims;
DenseIntElementsAttr rhs_broadcast_dims;
if (lhs_rank > rhs_rank) {
lhs_broadcast_dims = higher_rank_broadcast_dims;
rhs_broadcast_dims = op.broadcast_dimensions().getValue();
} else if (lhs_rank < rhs_rank) {
lhs_broadcast_dims = op.broadcast_dimensions().getValue();
rhs_broadcast_dims = higher_rank_broadcast_dims;
} else {
// This shouldn't happen for legal ops. If the broadcast_dimensions
// attribute is set, the ranks should be different.
// TODO(scotttodd): Add a custom verification for ops and assert here.
return false;
}
// BroadcastInDimOp must have the same element type for operands and results,
// so preserve the original output shape and the original input element type.
// For example, `SrcOp (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>`:
// broadcast_in_dim (tensor<1x4xf32>) -> tensor<1x4xf32>
// broadcast_in_dim (tensor<4xf32>) -> tensor<1x4xf32>
// SrcOp (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1>
ArrayRef<int64_t> op_shape = op_ranked_type.getShape();
auto lhs_type =
RankedTensorType::get(op_shape, lhs_ranked_type.getElementType());
auto rhs_type =
RankedTensorType::get(op_shape, rhs_ranked_type.getElementType());
*out_lhs = rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), lhs_type,
lhs, lhs_broadcast_dims);
*out_rhs = rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), rhs_type,
rhs, rhs_broadcast_dims);
return true;
}
template <typename SrcOp>
struct BinaryOpWithBroadcastConvert : public OpRewritePattern<SrcOp> {
explicit BinaryOpWithBroadcastConvert(MLIRContext *context)
@@ -48,62 +125,36 @@ struct BinaryOpWithBroadcastConvert : public OpRewritePattern<SrcOp> {
PatternMatchResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
if (!op.broadcast_dimensions().hasValue()) {
// Note: the op may still have an implicit broadcast on it, such as
// for (tensor<1xf32>, tensor<4xf32>).
Value new_lhs;
Value new_rhs;
if (!CreateBroadcastsForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) {
return this->matchFailure();
}
auto result_type = op.getType();
// Insert BroadcastInDimOps for the left-hand-side and right-hand-side args,
// replacing the original LHS and RHS args in the source op with the results
// of the broadcasts.
//
// If the higher dimensional argument does not actually need the broadcast,
// a canonicalization pass should be able to remove that op later.
Value lhs = op.lhs();
Value rhs = op.rhs();
auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhs_ranked_type || !rhs_ranked_type) {
// Unranked, can't determine at this point how to perform the broadcast.
return this->matchFailure();
}
auto lhs_rank = lhs_ranked_type.getRank();
auto rhs_rank = rhs_ranked_type.getRank();
// Set broadcast_dimensions to [0, ..., rank] for the higher rank arg.
// Use the original op.broadcast_dimensions for the lower rank arg.
auto higher_rank_broadcast_dims =
GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), &rewriter);
DenseIntElementsAttr lhs_broadcast_dims;
DenseIntElementsAttr rhs_broadcast_dims;
if (lhs_rank > rhs_rank) {
lhs_broadcast_dims = higher_rank_broadcast_dims;
rhs_broadcast_dims = op.broadcast_dimensions().getValue();
} else if (lhs_rank < rhs_rank) {
lhs_broadcast_dims = op.broadcast_dimensions().getValue();
rhs_broadcast_dims = higher_rank_broadcast_dims;
} else {
// This shouldn't happen for legal ops. If the broadcast_dimensions
// attribute is set, the ranks should be different.
// TODO(scotttodd): Add a custom verification for ops and assert here.
return this->matchFailure();
}
lhs = rewriter.createOrFold<BroadcastInDimOp>(op.getLoc(), result_type, lhs,
lhs_broadcast_dims);
rhs = rewriter.createOrFold<BroadcastInDimOp>(op.getLoc(), result_type, rhs,
rhs_broadcast_dims);
// Replace the original op with a new one that uses the new args.
// As the new args are broadcasts, no broadcast dimensions are needed on
// the replacement op.
rewriter.replaceOpWithNewOp<SrcOp>(op, result_type, lhs, rhs,
// New args are broadcasts, so no dims are needed on the replacement op.
rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), new_lhs, new_rhs,
/*broadcast_dims=*/nullptr);
return this->matchSuccess();
}
};
// Specialized class for CompareOp, as it has an additional builder argument.
struct CompareWithBroadcastConvert : public OpRewritePattern<CompareOp> {
explicit CompareWithBroadcastConvert(MLIRContext *context)
: OpRewritePattern<CompareOp>(context) {}
PatternMatchResult matchAndRewrite(CompareOp op,
PatternRewriter &rewriter) const override {
Value new_lhs;
Value new_rhs;
if (!CreateBroadcastsForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) {
return this->matchFailure();
}
rewriter.replaceOpWithNewOp<CompareOp>(op, op.getType(), new_lhs, new_rhs,
/*broadcast_dims=*/nullptr,
op.comparison_direction());
return this->matchSuccess();
}
};
@@ -115,25 +166,55 @@ void SetupMaterializeBroadcastsLegality(MLIRContext *context,
#define ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OpType) \
conversionTarget->addDynamicallyLegalOp<OpType>( \
[](OpType op) { return !op.broadcast_dimensions().hasValue(); });
// Binary elementwise ops.
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AddOp);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(Atan2Op);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(DivOp);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MaxOp);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MinOp);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MulOp);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(PowOp);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(RemOp);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftLeftOp);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightArithmeticOp);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightLogicalOp);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(SubOp);
// Binary logical elementwise ops.
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AndOp);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OrOp);
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(XorOp);
// CompareOp.
ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(CompareOp);
#undef ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST
}
void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
// Binary elementwise ops.
patterns->insert<BinaryOpWithBroadcastConvert<AddOp>>(context);
patterns->insert<BinaryOpWithBroadcastConvert<Atan2Op>>(context);
patterns->insert<BinaryOpWithBroadcastConvert<DivOp>>(context);
patterns->insert<BinaryOpWithBroadcastConvert<MaxOp>>(context);
patterns->insert<BinaryOpWithBroadcastConvert<MinOp>>(context);
patterns->insert<BinaryOpWithBroadcastConvert<MulOp>>(context);
patterns->insert<BinaryOpWithBroadcastConvert<PowOp>>(context);
patterns->insert<BinaryOpWithBroadcastConvert<RemOp>>(context);
patterns->insert<BinaryOpWithBroadcastConvert<ShiftLeftOp>>(context);
patterns->insert<BinaryOpWithBroadcastConvert<ShiftRightArithmeticOp>>(
context);
patterns->insert<BinaryOpWithBroadcastConvert<ShiftRightLogicalOp>>(context);
patterns->insert<BinaryOpWithBroadcastConvert<SubOp>>(context);
// Binary logical elementwise ops.
patterns->insert<BinaryOpWithBroadcastConvert<AndOp>>(context);
patterns->insert<BinaryOpWithBroadcastConvert<OrOp>>(context);
patterns->insert<BinaryOpWithBroadcastConvert<XorOp>>(context);
// CompareOp. Note the specialized class instead of using the template.
patterns->insert<CompareWithBroadcastConvert>(context);
}
} // namespace xla_hlo