diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index bedf77f726a..7240f85ed3c 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -1202,6 +1202,15 @@ func @DontConvertSqueezeToReshape(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: return %[[RESULT]] } +func @DontConvertSqueezeToReshapeOnMultiDynamicDims(%arg0: tensor) -> tensor { + %0 = "tfl.squeeze"(%arg0) {squeeze_dims = [0]}: (tensor) -> tensor + return %0: tensor + +// CHECK-LABEL: DontConvertSqueezeToReshapeOnMultiDynamicDims +// CHECK: %[[RESULT:.*]] = "tfl.squeeze"(%arg0) +// CHECK: return %[[RESULT]] +} + func @ConvertPow1ToIdentity(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { %cst = constant dense<1.000000e+00> : tensor %0 = "tfl.pow"(%arg0, %cst) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 8b12cefd07a..fa266c5e44e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -473,14 +473,15 @@ foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp, // if called without a ranked tensor it will fail. def GetShape: NativeCodeCall<"GetShape($0)">; -// Returns True if the operand type is RankedTensorType. -def HasRankedTensor : Constraint< - CPred<"$0.getType().isa()">>; +// Returns True if the operand type is RankedTensorType and valid. +def HasValidRankedTensor : Constraint() && " + "$0.getType().cast().getNumDynamicDims() <= 1">>; def ConvertSqueezeToReshape : Pat< (TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims), (TFL_ReshapeOp $input, (ConstantOp (GetShape $squeeze_op))), - [(HasRankedTensor $squeeze_op)]>; + [(HasValidRankedTensor $squeeze_op)]>; // Convert expand_dims to reshape if possible. def ConvertExpandDimsToReshape : Pat<