- Add names for patterns in legalize pass in TFL.

- Move pattern for squareddiff->relu to optimize pass from legalize, and add unit-test.
- Remove redundant fusing pattern for add with activation, Should be handled already in optimize pass.

PiperOrigin-RevId: 321889085
Change-Id: I8886667e2167dc476b651a0c125203f4eca8512e
This commit is contained in:
Karim Nosir 2020-07-17 17:54:28 -07:00 committed by TensorFlower Gardener
parent 19448cf8b9
commit 0721b70578
3 changed files with 204 additions and 136 deletions

View File

@ -992,3 +992,13 @@ func @RemoveCast(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: return %arg0
}
func @squaredDifferenceReluRemoveRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "tfl.relu"(%0) : (tensor<1xf32>) -> tensor<1xf32>
return %1: tensor<1xf32>
// CHECK-LABEL: squaredDifferenceReluRemoveRelu
// CHECK: %[[RESULT:.*]] = tfl.squared_difference %arg0, %arg1 : tensor<1xf32>
// CHECK: return %[[RESULT]]
}

View File

@ -66,8 +66,10 @@ def LegalizeTFConstToTFLConst: Pat<(TF_ConstOp ElementsAttr:$value),
(TFL_ConstOp $value)>;
// Convert to std constant for statically shaped, non-opaque constants.
def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value),
[(AnyStaticShapeTensor $res)], (addBenefit 10)>;
def ConvertTfConstToStdConst : Pat<
(TF_ConstOp:$res NonOpaqueElementsAttr:$value),
(ConstantOp $value),
[(AnyStaticShapeTensor $res)], (addBenefit 10)>;
//===----------------------------------------------------------------------===//
// Unary ops patterns.
@ -162,186 +164,234 @@ def LegalizeMaximum : Pat<(TF_MaximumOp $arg1, $arg2),
def LegalizeMinimum : Pat<(TF_MinimumOp $arg1, $arg2),
(TFL_MinimumOp $arg1, $arg2)>;
def : Pat<(TF_NegOp $arg), (TFL_NegOp $arg)>;
def : Pat<(TF_OneHotOp $indices, $depth, $on_value, $off_value, $axis),
(TFL_OneHotOp $indices, $depth, $on_value, $off_value,
(convertIntAttrTo32Bit $axis))>;
def : Pat<(TF_PowOp $x, $y), (TFL_PowOp $x, $y)>;
def : Pat<(TF_RangeOp $start, $limit, $delta), (TFL_RangeOp $start, $limit, $delta)>;
def : Pat<(TF_Relu6Op $arg), (TFL_Relu6Op $arg)>;
def : Pat<(TF_ReluOp $arg), (TFL_ReluOp $arg)>;
def : Pat<(TF_ReverseSequenceOp $input, $seq_lengths, $seq_dim, $batch_dim),
(TFL_ReverseSequenceOp $input, $seq_lengths,
(convertIntAttrTo32Bit $seq_dim),
(convertIntAttrTo32Bit $batch_dim))>;
def : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>;
def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
def : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
def : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids), (TFL_SegmentSumOp $data, $segment_ids)>;
def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>;
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y), [(HasSameStaticShapes $src_op)]>;
def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectV2Op $cond, $x, $y), [(HasNotSameStaticShapes $src_op)]>;
def : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>;
def : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>;
def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>;
def : Pat<(TF_SliceOp $input, $begin, $size), (TFL_SliceOp $input, $begin, $size)>;
def : Pat<(TF_SoftmaxOp $arg), (TFL_SoftmaxOp $arg, ConstF32Attr<"1.0">)>;
def : Pat<(TF_SoftplusOp F32Tensor:$arg0), (TFL_LogOp (TFL_AddOp (TFL_ExpOp $arg0), (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">), TFL_AF_None))>;
def : Pat<(TF_SqueezeOp $arg, $squeeze_dims), (TFL_SqueezeOp $arg, $squeeze_dims)>;
def : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>;
def : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>;
def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
def LegalizeNeg : Pat<(TF_NegOp $arg), (TFL_NegOp $arg)>;
def LegalizeOneHot : Pat<
(TF_OneHotOp $indices, $depth, $on_value, $off_value, $axis),
(TFL_OneHotOp $indices, $depth, $on_value, $off_value,
(convertIntAttrTo32Bit $axis))>;
def LegalizePow : Pat<(TF_PowOp $x, $y), (TFL_PowOp $x, $y)>;
def LegalizeRange : Pat<(TF_RangeOp $start, $limit, $delta),
(TFL_RangeOp $start, $limit, $delta)>;
def LegalizeRelu6 : Pat<(TF_Relu6Op $arg), (TFL_Relu6Op $arg)>;
def LegalizeRelu : Pat<(TF_ReluOp $arg), (TFL_ReluOp $arg)>;
def LegalizeReverseSequence : Pat<
(TF_ReverseSequenceOp $input, $seq_lengths, $seq_dim, $batch_dim),
(TFL_ReverseSequenceOp $input, $seq_lengths,
(convertIntAttrTo32Bit $seq_dim), (convertIntAttrTo32Bit $batch_dim))>;
def LegalizeRound : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>;
def LegalizeRsqrt : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
def LegalizeSqrt : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
def LegalizeSquare : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids),
(TFL_SegmentSumOp $data, $segment_ids)>;
def LegalizeSelect : Pat<(TF_SelectOp $cond, $x, $y),
(TFL_SelectOp $cond, $x, $y)>;
def LegalizeSelectV2SameStaticShape : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y),
(TFL_SelectOp $cond, $x, $y),
[(HasSameStaticShapes $src_op)]>;
def LegalizeSelectV2NotSameStaticShape : Pat<
(TF_SelectV2Op:$src_op $cond, $x, $y),
(TFL_SelectV2Op $cond, $x, $y),
[(HasNotSameStaticShapes $src_op)]>;
def LegalizeShape : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>;
def LegalizeSigmoid : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>;
def LegalizeSin : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>;
def LegalizeSlice : Pat<(TF_SliceOp $input, $begin, $size),
(TFL_SliceOp $input, $begin, $size)>;
def LegalizeSoftmax : Pat<(TF_SoftmaxOp $arg),
(TFL_SoftmaxOp $arg, ConstF32Attr<"1.0">)>;
def LegalizeSoftPlus : Pat<(TF_SoftplusOp F32Tensor:$arg0),
(TFL_LogOp (TFL_AddOp (TFL_ExpOp $arg0),
(ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">),
TFL_AF_None))>;
def LegalizeSqueeze : Pat<(TF_SqueezeOp $arg, $squeeze_dims),
(TFL_SqueezeOp $arg, $squeeze_dims)>;
def LegalizeTanh : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
def LegalizeTranspose : Pat<(TF_TransposeOp $arg, $perm),
(TFL_TransposeOp $arg, $perm)>;
def LegalizeWhere : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>;
def LegalizeZerosLike : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
//===----------------------------------------------------------------------===//
// Binary ops patterns.
//===----------------------------------------------------------------------===//
def : Pat<(TF_LessOp $l, $r), (TFL_LessOp $l, $r)>;
def : Pat<(TF_GreaterOp $l, $r), (TFL_GreaterOp $l, $r)>;
def LegalizeLess : Pat<(TF_LessOp $l, $r), (TFL_LessOp $l, $r)>;
def LegalizeGreater : Pat<(TF_GreaterOp $l, $r), (TFL_GreaterOp $l, $r)>;
def : Pat<(TF_LessEqualOp $l, $r), (TFL_LessEqualOp $l, $r)>;
def : Pat<(TF_GreaterEqualOp $l, $r), (TFL_GreaterEqualOp $l, $r)>;
def LegalizeLessEqual : Pat<(TF_LessEqualOp $l, $r), (TFL_LessEqualOp $l, $r)>;
def LegalizeGreaterEqual : Pat<(TF_GreaterEqualOp $l, $r),
(TFL_GreaterEqualOp $l, $r)>;
// Gather in TF -> Gather in TFL with axis=0
// The 'validate_indices' attribute is deprecated.
def : Pat<(TF_GatherOp $params, $indices, $ignored_validate_indices),
(TFL_GatherOp $params, $indices, ConstantAttr<I32Attr, "0">)>;
def LegalizeGather: Pat<
(TF_GatherOp $params, $indices, $ignored_validate_indices),
(TFL_GatherOp $params, $indices, ConstantAttr<I32Attr, "0">)>;
def : Pat<(TF_GatherNdOp $params, $indices),
(TFL_GatherNdOp $params, $indices)>;
def LegalizeGatherNd : Pat<(TF_GatherNdOp $params, $indices),
(TFL_GatherNdOp $params, $indices)>;
def : Pat<(TF_GatherV2Op $params, $indices,
(ConstantOp ElementsAttr:$axis),
ConstantAttr<I64Attr, "0">:$batch_dims),
(TFL_GatherOp $params, $indices,
ExtractSingleElementAsInt32:$axis)>;
def LegalizeGatherV2 : Pat<
(TF_GatherV2Op $params, $indices, (ConstantOp ElementsAttr:$axis),
ConstantAttr<I64Attr, "0">:$batch_dims),
(TFL_GatherOp $params, $indices, ExtractSingleElementAsInt32:$axis)>;
def : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>;
def LegalizeFloorDiv : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>;
def : Pat<(TF_NotEqualOp $l, $r, /*incompatible_shape_error=*/ConstBoolAttrTrue),
(TFL_NotEqualOp $l, $r)>;
def LegalizeNotEqual : Pat<
(TF_NotEqualOp $l, $r, /*incompatible_shape_error=*/ConstBoolAttrTrue),
(TFL_NotEqualOp $l, $r)>;
def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>;
def LegalizeLogicalAnd : Pat<(TF_LogicalAndOp $l, $r),
(TFL_LogicalAndOp $l, $r)>;
def : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>;
def LegalizeLogicalOr : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>;
def LegalizeAdd : Pat<(TF_AddOp $lhs, $rhs),
(TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
def LegalizeAddv2 : Pat<(TF_AddV2Op $lhs, $rhs),
(TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
def LegalizeBiasAdd : Pat<
(TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, IsDataFormatNHWC:$data_format),
(TFL_AddOp $l, $r, TFL_AF_None)>;
def LegalizeSub : Pat<(TF_SubOp $lhs, $rhs),
(TFL_SubOp $lhs, $rhs, TFL_AF_None)>;
def LegalizeMul : Pat<(TF_MulOp $lhs, $rhs),
(TFL_MulOp $lhs, $rhs, TFL_AF_None)>;
def LegalizeRealDiv : Pat<(TF_RealDivOp $lhs, $rhs),
(TFL_DivOp $lhs, $rhs, TFL_AF_None)>;
def LegalizeDiv : Pat<(TF_DivOp $lhs, $rhs),
(TFL_DivOp $lhs, $rhs, TFL_AF_None)>;
def : Pat<(TF_AddOp $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
def : Pat<(TF_AddV2Op $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>;
// When batch size is known, TF BatchMatMul gets unfolded to TFL FullyConnected
// with additional ops. In the case of unknown batch size, the match will
// fall through to here and convert to TF Lite BatchMatMul.
def : Pat<(TF_BatchMatMulV2Op $lhs, $rhs, $adj_x, $adj_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>;
def : Pat<(TF_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>;
def : Pat<(TF_SubOp $lhs, $rhs), (TFL_SubOp $lhs, $rhs, TFL_AF_None)>;
def : Pat<(TF_MulOp $lhs, $rhs), (TFL_MulOp $lhs, $rhs, TFL_AF_None)>;
def : Pat<(TF_RealDivOp $lhs, $rhs), (TFL_DivOp $lhs, $rhs, TFL_AF_None)>;
def : Pat<(TF_DivOp $lhs, $rhs), (TFL_DivOp $lhs, $rhs, TFL_AF_None)>;
def LegalizeBatchMatMulV2UnknownBatch : Pat<
(TF_BatchMatMulV2Op $lhs, $rhs, $adj_x, $adj_y),
(TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>;
def LegalizeBatchMatMulUnknownBatch : Pat<
(TF_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y),
(TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>;
def : Pat<(TF_BiasAddOp F32Tensor:$l, F32Tensor:$r,
IsDataFormatNHWC:$data_format),
(TFL_AddOp $l, $r, TFL_AF_None)>;
// TODO(jpienaar): These should be handled by the pattern rewriter, find out
// why it isn't.
def : Pat<(TF_Relu6Op (TF_BiasAddOp F32Tensor:$l, F32Tensor:$r,
IsDataFormatNHWC:$data_format)),
(TFL_AddOp $l, $r, TFL_AF_Relu6)>;
def : Pat<(TF_FakeQuantWithMinMaxVarsOp $inputs,
(ConstantOp F32ElementsAttr:$min),
(ConstantOp F32ElementsAttr:$max),
$num_bits, $narrow_range),
(TFL_DequantizeOp
(TFL_QuantizeOp $inputs,
(ConvertToQuantTypeFromAttrs $inputs, $min, $max,
$num_bits, $narrow_range)))>;
def LegalizeFakeQuantWithMinMaxVars: Pat<
(TF_FakeQuantWithMinMaxVarsOp $inputs, (ConstantOp F32ElementsAttr:$min),
(ConstantOp F32ElementsAttr:$max), $num_bits, $narrow_range),
(TFL_DequantizeOp
(TFL_QuantizeOp $inputs, (ConvertToQuantTypeFromAttrs $inputs, $min, $max,
$num_bits, $narrow_range)))>;
// TODO(rocky): Not all of the attributes are handled correctly. Make this
// more general if there is a need.
def : Pat<(TF_QuantizeAndDequantizeV2Op $inputs,
(ConstantOp F32ElementsAttr:$min),
(ConstantOp F32ElementsAttr:$max),
$signed_input, $num_bits, $range_given, $round_mode,
$narrow_range, $axis),
(TFL_DequantizeOp
(TFL_QuantizeOp $inputs,
(ConvertToQuantTypeFromAttrs $inputs, $min, $max,
$num_bits, $narrow_range)))>;
def LegalizeQuantizeAndDequantizeV2 : Pat<
(TF_QuantizeAndDequantizeV2Op $inputs, (ConstantOp F32ElementsAttr:$min),
(ConstantOp F32ElementsAttr:$max),
$signed_input, $num_bits, $range_given, $round_mode, $narrow_range, $axis),
(TFL_DequantizeOp
(TFL_QuantizeOp $inputs, (ConvertToQuantTypeFromAttrs $inputs, $min, $max,
$num_bits, $narrow_range)))>;
def : Pat<(TF_RankOp $input), (TFL_RankOp $input)>;
def LegalizeRank : Pat<(TF_RankOp $input), (TFL_RankOp $input)>;
def : Pat<(TF_SquaredDifferenceOp $l, $r), (TFL_SquaredDifferenceOp $l, $r)>;
def LegalizeSquaredDifference : Pat<(TF_SquaredDifferenceOp $l, $r),
(TFL_SquaredDifferenceOp $l, $r)>;
// Note(ycling): We can eliminate Relu from Relu(SquaredDifference(x, y)),
// since the result of SquaredDifference is always non-negative.
// TFLite interpreter doesn't support Relu+int32 for now. So the test cases
// are failing without the following pattern to optimize Relu away fixes
// the problem.
def : Pat<(TF_ReluOp (TF_SquaredDifferenceOp $l, $r)),
(TFL_SquaredDifferenceOp $l, $r)>;
def LegalizeReverseV2 : Pat<(TF_ReverseV2Op $arg0, $arg1),
(TFL_ReverseV2Op $arg0, $arg1)>;
def : Pat<(TF_ReverseV2Op $arg0, $arg1), (TFL_ReverseV2Op $arg0, $arg1)>;
def LegalizeEqual : Pat<(TF_EqualOp $arg0, $arg1,
/*incompatible_shape_error=*/ConstBoolAttrTrue),
(TFL_EqualOp $arg0, $arg1)>;
def : Pat<(TF_EqualOp $arg0, $arg1, /*incompatible_shape_error=*/ConstBoolAttrTrue), (TFL_EqualOp $arg0, $arg1)>;
def LegalizePad : Pat<(TF_PadOp $arg0, $arg1), (TFL_PadOp $arg0, $arg1)>;
def : Pat<(TF_PadOp $arg0, $arg1), (TFL_PadOp $arg0, $arg1)>;
def LegalizeTile : Pat<(TF_TileOp $arg0, $arg1), (TFL_TileOp $arg0, $arg1)>;
def : Pat<(TF_TileOp $arg0, $arg1), (TFL_TileOp $arg0, $arg1)>;
def LegalizePadV2 : Pat<(TF_PadV2Op $arg0, $arg1, $cst),
(TFL_PadV2Op $arg0, $arg1, $cst)>;
def : Pat<(TF_PadV2Op $arg0, $arg1, $cst), (TFL_PadV2Op $arg0, $arg1, $cst)>;
def LegalizeMean : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2),
(TFL_MeanOp $arg0, $arg1, $arg2)>;
def : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2), (TFL_MeanOp $arg0, $arg1, $arg2)>;
def : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), (TFL_SumOp $arg, $axes, $arg2)>;
def LegalizeSum : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2),
(TFL_SumOp $arg, $axes, $arg2)>;
// TopK in TFL is always sorted so we ignore that attribute here.
def : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted), (TFL_TopKV2Op $input, $k)>;
def LegalizeTopKV2 : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted),
(TFL_TopKV2Op $input, $k)>;
def : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMinOp $arg0, $arg1, $arg2)>;
def LegalizeMin : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2),
(TFL_ReduceMinOp $arg0, $arg1, $arg2)>;
def : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMaxOp $arg0, $arg1, $arg2)>;
def LegalizeMax : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2),
(TFL_ReduceMaxOp $arg0, $arg1, $arg2)>;
def : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceProdOp $arg0, $arg1, $arg2)>;
def LegalizeProd : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2),
(TFL_ReduceProdOp $arg0, $arg1, $arg2)>;
def : Pat<(TF_AnyOp $input, $reduction_indices, $keep_dims),
(TFL_ReduceAnyOp $input, $reduction_indices, $keep_dims)>;
def LegalizeAny : Pat<(TF_AnyOp $input, $reduction_indices, $keep_dims),
(TFL_ReduceAnyOp $input, $reduction_indices, $keep_dims)>;
def : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;
def LegalizeCast : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;
def : Pat<(TF_BatchToSpaceNDOp $input, $block_shape, $crops), (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>;
def LegalizeBatchToSpaceND : Pat<
(TF_BatchToSpaceNDOp $input, $block_shape, $crops),
(TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>;
def : Pat<(TF_SpaceToBatchNDOp $input, $block_shape, $paddings), (TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>;
def LegalizeSpaceToBatchND : Pat<
(TF_SpaceToBatchNDOp $input, $block_shape, $paddings),
(TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>;
def : Pat<(TF_SpaceToDepthOp $input, $block_size, IsDataFormatNHWC:$data_format),
(TFL_SpaceToDepthOp $input, (convertIntAttrTo32Bit $block_size))>;
def LegalizeSpaceToDepth : Pat<
(TF_SpaceToDepthOp $input, $block_size, IsDataFormatNHWC:$data_format),
(TFL_SpaceToDepthOp $input, (convertIntAttrTo32Bit $block_size))>;
def : Pat<(TF_DepthToSpaceOp $input, $block_size, IsDataFormatNHWC:$data_format),
(TFL_DepthToSpaceOp $input, (convertIntAttrTo32Bit $block_size))>;
def LegalizeDepthToSpace : Pat<
(TF_DepthToSpaceOp $input, $block_size, IsDataFormatNHWC:$data_format),
(TFL_DepthToSpaceOp $input, (convertIntAttrTo32Bit $block_size))>;
def : Pat<(TF_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers), (TFL_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers)>;
def : Pat<(TF_ResizeNearestNeighborOp $images, $size, $align_corners, $half_pixel_centers), (TFL_ResizeNearestNeighborOp $images, $size, $align_corners, $half_pixel_centers)>;
def LegalizeResizeBilinear : Pat<
(TF_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers),
(TFL_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers)>;
def LegalizeResizeNearestNeighbor : Pat<
(TF_ResizeNearestNeighborOp $images, $size, $align_corners,
$half_pixel_centers),
(TFL_ResizeNearestNeighborOp $images, $size, $align_corners,
$half_pixel_centers)>;
def : Pat<(TF_MirrorPadOp $arg0, $arg1, $cst), (TFL_MirrorPadOp $arg0, $arg1, $cst)>;
def LegalizeMirrorPad : Pat<(TF_MirrorPadOp $arg0, $arg1, $cst),
(TFL_MirrorPadOp $arg0, $arg1, $cst)>;
def : Pat<(TF_SparseToDenseOp $sparse_indices, $output_shape, $sparse_values, $default_value, $validate_indices),
(TFL_SparseToDenseOp $sparse_indices, $output_shape, $sparse_values, $default_value)>;
def LegalizeSparseToDense : Pat<
(TF_SparseToDenseOp $sparse_indices, $output_shape, $sparse_values,
$default_value, $validate_indices),
(TFL_SparseToDenseOp $sparse_indices, $output_shape, $sparse_values,
$default_value)>;
def : Pat<(TF_UniqueOp $arg0),(TFL_UniqueOp $arg0)>;
def LegalizeUnique : Pat<(TF_UniqueOp $arg0),(TFL_UniqueOp $arg0)>;
def : Pat<(TF_FloorModOp $arg0, $arg1), (TFL_FloorModOp $arg0, $arg1)>;
def : Pat<(TF_ExpOp $arg0), (TFL_ExpOp $arg0)>;
def LegalizeFloorMod : Pat<(TF_FloorModOp $arg0, $arg1),
(TFL_FloorModOp $arg0, $arg1)>;
def LegalizeExp : Pat<(TF_ExpOp $arg0), (TFL_ExpOp $arg0)>;
def : Pat<(TF_LRNOp $arg0, $radius, F32Attr:$bias, F32Attr:$alpha, F32Attr:$beta), (TFL_LocalResponseNormalizationOp $arg0, (convertIntAttrTo32Bit $radius), $bias, $alpha, $beta)>;
def LegalizeLRN : Pat<
(TF_LRNOp $arg0, $radius, F32Attr:$bias, F32Attr:$alpha, F32Attr:$beta),
(TFL_LocalResponseNormalizationOp $arg0, (convertIntAttrTo32Bit $radius),
$bias, $alpha, $beta)>;
def : Pat<
(TF_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $pad_to_max_output_size),
(TFL_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold)>;
def LegalizeNonMaxSuppressionV4 : Pat<
(TF_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold,
$score_threshold, $pad_to_max_output_size),
(TFL_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold,
$score_threshold)>;
def : Pat<
(TF_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma, $pad_to_max_output_size),
(TFL_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma)>;
def LegalizeNonMaxSuppressionV5 : Pat<
(TF_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold,
$score_threshold, $soft_nms_sigma, $pad_to_max_output_size),
(TFL_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold,
$score_threshold, $soft_nms_sigma)>;
def : Pat<(TF_MatrixDiagOp $diagonal), (TFL_MatrixDiagOp $diagonal)>;
def LegalizeMatrixDiag : Pat<(TF_MatrixDiagOp $diagonal),
(TFL_MatrixDiagOp $diagonal)>;
class I32VectorElementsAttr<int len> : ElementsAttrBase<
CPred<"$_self.isa<DenseIntElementsAttr>() &&"
@ -356,7 +406,7 @@ class I32VectorElementsAttr<int len> : ElementsAttrBase<
"RankedTensorType::get({" # len # "}, $_builder.getIntegerType(32)), $0)";
}
def : Pat<
def LegalizeConv2DBackpropInput : Pat<
(TF_Conv2DBackpropInputOp $input_sizes, $filter, $out_backprop,
IsIntList1XY1:$strides,
BoolAttr:$use_cudnn_on_gpu,
@ -373,9 +423,10 @@ def : Pat<
/*stride_h=*/ ExtractI32At<1>:$strides,
/*stride_w=*/ ExtractI32At<2>:$strides)>;
def : Pat<
def LegalizeMatrixSetDiag : Pat<
(TF_MatrixSetDiagOp $input, $diagonal),
(TFL_MatrixSetDiagOp $input, $diagonal)>;
def : Pat<(TF_ScatterNdOp I32Tensor:$indices, $updates, $shape),
(TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>;
def LegalizeScatterNd : Pat<
(TF_ScatterNdOp I32Tensor:$indices, $updates, $shape),
(TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>;

View File

@ -485,4 +485,11 @@ foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
[(HasOneUse $first_output)]>;
}
// We can eliminate Relu from Relu(SquaredDifference(x, y)),
// since the result of SquaredDifference is always non-negative.
// TFLite interpreter doesn't support Relu+int32 for now. So the test cases
// are failing without the following pattern to optimize Relu away fixes
// the problem.
def OptimizeReluSquaredDifference : Pat<
(TFL_ReluOp (TFL_SquaredDifferenceOp $l, $r)),
(TFL_SquaredDifferenceOp $l, $r)>;