diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc index 20807467d45..29f3eb9a8f5 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc @@ -31,9 +31,11 @@ using mlir::OwningRewritePatternList; using mlir::PassRegistration; namespace mlir { -namespace xla_hlo { namespace { #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_to_standard.inc" +} // end anonymous namespace +namespace xla_hlo { +namespace { struct CompareIConvert : public RewritePattern { explicit CompareIConvert(MLIRContext *context) @@ -131,7 +133,7 @@ mlir::xla_hlo::createLegalizeToStdPass() { void mlir::xla_hlo::PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, mlir::MLIRContext *ctx) { - mlir::xla_hlo::populateWithGenerated(ctx, patterns); + mlir::populateWithGenerated(ctx, patterns); patterns ->insert( ctx); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td index fa6f3a5f378..2772b796298 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td @@ -28,6 +28,11 @@ def IsSameSizePred : CPred< "== $1->getType().cast().getShape()">; def IsSameSizeConstraint : Constraint; + +def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r, + IsNullAttr:$broadcast_dimensions), + (AndOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r, IsNullAttr:$broadcast_dimensions), (AddFOp $l, $r),