diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index abaad272acd..0ce90f0a445 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -598,6 +598,28 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) {
   return success();
 }
 
+// If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary
+// BroadcastInDimOp.
+class DynamicBroadcastInDimOpNotActuallyDynamic
+    : public OpRewritePattern<DynamicBroadcastInDimOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
+                                PatternRewriter& rewriter) const override {
+    // Only fires when the result type is fully static; a static result shape
+    // means the output-dimensions operand carries no extra information.
+    auto type = op.getType().dyn_cast<RankedTensorType>();
+    if (!type || !type.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(op, "requires static shape");
+    }
+    rewriter.replaceOpWithNewOp<BroadcastInDimOp>(
+        op, op.getType(), op.operand(), op.broadcast_dimensions());
+    return success();
+  }
+};
+
+void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
+    OwningRewritePatternList& results, MLIRContext* context) {
+  results.insert<DynamicBroadcastInDimOpNotActuallyDynamic>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ClampOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index bc05a1c100c..1c38d3ae3e1 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -802,6 +802,7 @@ def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
   let results = (outs HLO_Tensor);
 
+  let hasCanonicalizer = 1;
   // Cannot be exported to legacy formats.
   let hasCustomHLOConverter = 1;
 }
 
diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
index 18a29968600..1b7d879ca03 100644
--- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
@@ -31,6 +31,14 @@ func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1:
   return %1 : tensor<1x4xi32>
 }
 
+// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic
+func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> {
+  // CHECK: %[[RESULT:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
+  %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
+  // CHECK: return %[[RESULT]] : tensor<5x4xf32>
+  return %0 : tensor<5x4xf32>
+}
+
 // CHECK-LABEL: @complex_expand_fold
 func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
   %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index e910020de18..ac3d31bc25e 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -941,7 +941,7 @@ class ConvertBroadcastToOp : public OpRewritePattern<TF::BroadcastToOp> {
     auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
     auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
     if (!input_type || !output_type) {
-      return failure();
+      return rewriter.notifyMatchFailure(op, "requires ranked shape");
     }
     auto rank_diff = output_type.getRank() - input_type.getRank();
     // The tf.BroadcastTo op performs "right-aligned" numpy-style broadcasting.