Transform cast_complex -> fft to rfft directly
PiperOrigin-RevId: 344934048 Change-Id: I419df9d316bbbd1e60c73e680f41c6c71ee8647b
This commit is contained in:
parent
cc6440591e
commit
4e17145925
@ -33,3 +33,22 @@ func @convaddv2mul(%arg: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> {
|
||||
// CHECK-NEXT: %[[add:.*]] = "tf.AddV2"(%[[conv]], %[[cst_0]])
|
||||
// CHECK-NEXT: return %[[add]] : tensor<256x30x30x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_cast_fft_to_rfft
|
||||
func @fold_cast_fft_to_rfft(%arg0: tensor<10x20x30xf32>) -> tensor<10x20x30xcomplex<f32>> {
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<10x20x30xf32>) -> tensor<10x20x30xcomplex<f32>>
|
||||
%1 = "tf.FFT"(%0) : (tensor<10x20x30xcomplex<f32>>) -> tensor<10x20x30xcomplex<f32>>
|
||||
return %1: tensor<10x20x30xcomplex<f32>>
|
||||
|
||||
// CHECK: %[[cst:.*]] = constant dense<30> : tensor<1xi32>
|
||||
// CHECK: %[[rff:.*]] = "tf.RFFT"(%arg0, %[[cst]]) : (tensor<10x20x30xf32>, tensor<1xi32>) -> tensor<10x20x30xcomplex<f32>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: not_fold_cast_fft_to_rfft
|
||||
func @not_fold_cast_fft_to_rfft(%arg0: tensor<10x20x30xcomplex<f64>>) -> tensor<10x20x30xcomplex<f32>> {
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<10x20x30xcomplex<f64>>) -> tensor<10x20x30xcomplex<f32>>
|
||||
%1 = "tf.FFT"(%0) : (tensor<10x20x30xcomplex<f32>>) -> tensor<10x20x30xcomplex<f32>>
|
||||
return %1: tensor<10x20x30xcomplex<f32>>
|
||||
|
||||
// CHECK: %[[fft:.*]] = "tf.FFT"(%0) : (tensor<10x20x30xcomplex<f32>>) -> tensor<10x20x30xcomplex<f32>>
|
||||
}
|
||||
|
@ -18,6 +18,23 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;
|
||||
|
||||
// Get the last dimension size as a 1-d single element attr.
|
||||
def GetLastDimSizeAsI32 : NativeCodeCall<
|
||||
"DenseElementsAttr::get(RankedTensorType::get({1}, $_builder.getIntegerType(32)), "
|
||||
"static_cast<int32_t>($0.getType().cast<RankedTensorType>().getDimSize( "
|
||||
" $0.getType().cast<RankedTensorType>().getRank() - 1)))">;
|
||||
|
||||
// Check whether the tensor is ranked and whether its last dim is static.
|
||||
def IsRankedShapeLastDimStatic : Constraint<And<[
|
||||
CPred<"$0.getType().isa<RankedTensorType>()">,
|
||||
CPred<"!$0.getType().cast<ShapedType>().isDynamicDim( "
|
||||
" $0.getType().cast<RankedTensorType>().getRank() - 1)">]>>;
|
||||
|
||||
def IsNotComplexType : Constraint<And<[
|
||||
CPred<"$0.getType().isa<RankedTensorType>()">,
|
||||
CPred<"!$0.getType().cast<ShapedType>().getElementType().isa<ComplexType>()">
|
||||
]>>;
|
||||
|
||||
// Only fuse multiplier if all dimensions other than the channel dimension
|
||||
// are equal to 1.
|
||||
def CanFuseMulAndConv2D :
|
||||
@ -85,3 +102,10 @@ def PassthroughMulAndAddV2 :
|
||||
(TF_MulOp $input, (ConstantOp $value)),
|
||||
(TF_MulOp (ConstantOp $bias), (ConstantOp $value))),
|
||||
[(DefinedByConv2D $input), (HasOneUse $output)]>;
|
||||
|
||||
// input -> cast -> FFT => input -> RFFT
|
||||
def ConvertCastComplexFFTToRFFT: Pat<
|
||||
(TF_FFTOp (TF_CastOp $input, ConstBoolAttrFalse)),
|
||||
(TF_RFFTOp $input,
|
||||
(ConstantOp (GetLastDimSizeAsI32 $input))),
|
||||
[(IsRankedShapeLastDimStatic $input), (IsNotComplexType $input)]>;
|
||||
|
Loading…
Reference in New Issue
Block a user