Do not convert squeeze op to reshape when there are multiple dynamic dims

This change closes #40504 and fixes the regression described in the #40504.

PiperOrigin-RevId: 339352294
Change-Id: Ice688af970c832c6743ebb0bb6f40c9fd0b51446
This commit is contained in:
Jaesung Chung 2020-10-27 16:14:16 -07:00 committed by TensorFlower Gardener
parent 33c556c11d
commit b801054029
2 changed files with 14 additions and 4 deletions

View File

@ -1202,6 +1202,15 @@ func @DontConvertSqueezeToReshape(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: return %[[RESULT]]
}
func @DontConvertSqueezeToReshapeOnMultiDynamicDims(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = "tfl.squeeze"(%arg0) {squeeze_dims = [0]}: (tensor<?x?xf32>) -> tensor<?x?xf32>
return %0: tensor<?x?xf32>
// CHECK-LABEL: DontConvertSqueezeToReshapeOnMultiDynamicDims
// CHECK: %[[RESULT:.*]] = "tfl.squeeze"(%arg0)
// CHECK: return %[[RESULT]]
}
func @ConvertPow1ToIdentity(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<1.000000e+00> : tensor<f32>
%0 = "tfl.pow"(%arg0, %cst) : (tensor<2x2xf32>, tensor<f32>) -> tensor<2x2xf32>

View File

@ -473,14 +473,15 @@ foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
// if called without a ranked tensor it will fail.
def GetShape: NativeCodeCall<"GetShape($0)">;
// Returns True if the operand type is RankedTensorType.
def HasRankedTensor : Constraint<
CPred<"$0.getType().isa<RankedTensorType>()">>;
// Returns True if the operand type is RankedTensorType and valid.
def HasValidRankedTensor : Constraint<CPred<
"$0.getType().isa<RankedTensorType>() && "
"$0.getType().cast<RankedTensorType>().getNumDynamicDims() <= 1">>;
def ConvertSqueezeToReshape : Pat<
(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
(TFL_ReshapeOp $input, (ConstantOp (GetShape $squeeze_op))),
[(HasRankedTensor $squeeze_op)]>;
[(HasValidRankedTensor $squeeze_op)]>;
// Convert expand_dims to reshape if possible.
def ConvertExpandDimsToReshape : Pat<