Transform cast_complex -> fft to rfft directly

PiperOrigin-RevId: 344934048
Change-Id: I419df9d316bbbd1e60c73e680f41c6c71ee8647b
This commit is contained in:
Renjie Liu 2020-11-30 19:16:48 -08:00 committed by TensorFlower Gardener
parent cc6440591e
commit 4e17145925
2 changed files with 43 additions and 0 deletions

View File

@ -33,3 +33,22 @@ func @convaddv2mul(%arg: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> {
// CHECK-NEXT: %[[add:.*]] = "tf.AddV2"(%[[conv]], %[[cst_0]])
// CHECK-NEXT: return %[[add]] : tensor<256x30x30x16xf32>
}
// CHECK-LABEL: fold_cast_fft_to_rfft
func @fold_cast_fft_to_rfft(%arg0: tensor<10x20x30xf32>) -> tensor<10x20x30xcomplex<f32>> {
%0 = "tf.Cast"(%arg0) : (tensor<10x20x30xf32>) -> tensor<10x20x30xcomplex<f32>>
%1 = "tf.FFT"(%0) : (tensor<10x20x30xcomplex<f32>>) -> tensor<10x20x30xcomplex<f32>>
return %1: tensor<10x20x30xcomplex<f32>>
// CHECK: %[[cst:.*]] = constant dense<30> : tensor<1xi32>
// CHECK: %[[rff:.*]] = "tf.RFFT"(%arg0, %[[cst]]) : (tensor<10x20x30xf32>, tensor<1xi32>) -> tensor<10x20x30xcomplex<f32>>
}
// CHECK-LABEL: not_fold_cast_fft_to_rfft
func @not_fold_cast_fft_to_rfft(%arg0: tensor<10x20x30xcomplex<f64>>) -> tensor<10x20x30xcomplex<f32>> {
%0 = "tf.Cast"(%arg0) : (tensor<10x20x30xcomplex<f64>>) -> tensor<10x20x30xcomplex<f32>>
%1 = "tf.FFT"(%0) : (tensor<10x20x30xcomplex<f32>>) -> tensor<10x20x30xcomplex<f32>>
return %1: tensor<10x20x30xcomplex<f32>>
// CHECK: %[[fft:.*]] = "tf.FFT"(%0) : (tensor<10x20x30xcomplex<f32>>) -> tensor<10x20x30xcomplex<f32>>
}

View File

@ -18,6 +18,23 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;
// Get the last dimension size as a 1-d single element attr.
def GetLastDimSizeAsI32 : NativeCodeCall<
"DenseElementsAttr::get(RankedTensorType::get({1}, $_builder.getIntegerType(32)), "
"static_cast<int32_t>($0.getType().cast<RankedTensorType>().getDimSize( "
" $0.getType().cast<RankedTensorType>().getRank() - 1)))">;
// Check whether the tensor is ranked and whether its last dim is static.
def IsRankedShapeLastDimStatic : Constraint<And<[
CPred<"$0.getType().isa<RankedTensorType>()">,
CPred<"!$0.getType().cast<ShapedType>().isDynamicDim( "
" $0.getType().cast<RankedTensorType>().getRank() - 1)">]>>;
def IsNotComplexType : Constraint<And<[
CPred<"$0.getType().isa<RankedTensorType>()">,
CPred<"!$0.getType().cast<ShapedType>().getElementType().isa<ComplexType>()">
]>>;
// Only fuse multiplier if all dimensions other than the channel dimension
// are equal to 1.
def CanFuseMulAndConv2D :
@ -85,3 +102,10 @@ def PassthroughMulAndAddV2 :
(TF_MulOp $input, (ConstantOp $value)),
(TF_MulOp (ConstantOp $bias), (ConstantOp $value))),
[(DefinedByConv2D $input), (HasOneUse $output)]>;
// input -> cast -> FFT => input -> RFFT
def ConvertCastComplexFFTToRFFT: Pat<
(TF_FFTOp (TF_CastOp $input, ConstBoolAttrFalse)),
(TF_RFFTOp $input,
(ConstantOp (GetLastDimSizeAsI32 $input))),
[(IsRankedShapeLastDimStatic $input), (IsNotComplexType $input)]>;