Transform cast_complex -> fft to rfft directly
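
This adds a DRR pattern that folds tf.FFT of a non-truncating tf.Cast from a real tensor into tf.RFFT on the real input, materializing the static innermost-dimension size as the fft_length constant, together with a positive and a negative FileCheck test. A rough C++ sketch of how the generated patterns reach the optimize pass follows; the generated include path and the populateWithGenerated signature are assumptions for the MLIR revision in use, not part of this change:

    // Illustrative sketch only; not part of this change. mlir-tblgen compiles
    // the DRR patterns in the .td file into a generated .inc file whose
    // populateWithGenerated() entry point registers them, including the new
    // ConvertCastComplexFFTToRFFT pattern.
    #include "mlir/IR/MLIRContext.h"
    #include "mlir/IR/PatternMatch.h"
    #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

    // Assumed generated-file path; the actual name is set by the tblgen BUILD rule.
    #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_optimize.inc"

    // Adds the DRR patterns from the .td file to a pattern list used by the
    // optimize pass.
    void AddTFOptimizePatterns(mlir::MLIRContext *context,
                               mlir::OwningRewritePatternList *patterns) {
      // The signature of populateWithGenerated() is an assumption for this
      // MLIR revision.
      populateWithGenerated(context, patterns);
    }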
PiperOrigin-RevId: 344934048
Change-Id: I419df9d316bbbd1e60c73e680f41c6c71ee8647b

parent cc6440591e
commit 4e17145925
@@ -33,3 +33,22 @@ func @convaddv2mul(%arg: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> {
// CHECK-NEXT: %[[add:.*]] = "tf.AddV2"(%[[conv]], %[[cst_0]])
// CHECK-NEXT: return %[[add]] : tensor<256x30x30x16xf32>
}

// CHECK-LABEL: fold_cast_fft_to_rfft
func @fold_cast_fft_to_rfft(%arg0: tensor<10x20x30xf32>) -> tensor<10x20x30xcomplex<f32>> {
  %0 = "tf.Cast"(%arg0) : (tensor<10x20x30xf32>) -> tensor<10x20x30xcomplex<f32>>
  %1 = "tf.FFT"(%0) : (tensor<10x20x30xcomplex<f32>>) -> tensor<10x20x30xcomplex<f32>>
  return %1: tensor<10x20x30xcomplex<f32>>

// CHECK:  %[[cst:.*]] = constant dense<30> : tensor<1xi32>
// CHECK:  %[[rff:.*]] = "tf.RFFT"(%arg0, %[[cst]]) : (tensor<10x20x30xf32>, tensor<1xi32>) -> tensor<10x20x30xcomplex<f32>>
}

// CHECK-LABEL: not_fold_cast_fft_to_rfft
func @not_fold_cast_fft_to_rfft(%arg0: tensor<10x20x30xcomplex<f64>>) -> tensor<10x20x30xcomplex<f32>> {
  %0 = "tf.Cast"(%arg0) : (tensor<10x20x30xcomplex<f64>>) -> tensor<10x20x30xcomplex<f32>>
  %1 = "tf.FFT"(%0) : (tensor<10x20x30xcomplex<f32>>) -> tensor<10x20x30xcomplex<f32>>
  return %1: tensor<10x20x30xcomplex<f32>>

// CHECK: %[[fft:.*]] = "tf.FFT"(%0) : (tensor<10x20x30xcomplex<f32>>) -> tensor<10x20x30xcomplex<f32>>
}

@@ -18,6 +18,23 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"

def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;

// Get the last dimension size as a 1-d single element attr.
def GetLastDimSizeAsI32 : NativeCodeCall<
  "DenseElementsAttr::get(RankedTensorType::get({1}, $_builder.getIntegerType(32)), "
  "static_cast<int32_t>($0.getType().cast<RankedTensorType>().getDimSize(  "
  "  $0.getType().cast<RankedTensorType>().getRank() - 1)))">;
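
A standalone C++ rendering of what the GetLastDimSizeAsI32 NativeCodeCall above evaluates to for a matched value of ranked tensor type; the free-function form, its name, and the header set are illustrative only, not part of this change:

    // Headers reflect the MLIR layout around this revision; they have since
    // been reorganized.
    #include "mlir/IR/Attributes.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/StandardTypes.h"
    #include "mlir/IR/Value.h"

    // Builds the 1-D, single-element i32 attribute holding the size of the
    // innermost dimension, e.g. dense<30> : tensor<1xi32> for a
    // tensor<10x20x30xf32> input.
    static mlir::DenseElementsAttr GetLastDimSizeAsI32(mlir::Value value,
                                                       mlir::Builder &builder) {
      auto type = value.getType().cast<mlir::RankedTensorType>();
      auto last_dim =
          static_cast<int32_t>(type.getDimSize(type.getRank() - 1));
      return mlir::DenseElementsAttr::get(
          mlir::RankedTensorType::get({1}, builder.getIntegerType(32)),
          last_dim);
    }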

// Check whether the tensor is ranked and whether its last dim is static.
def IsRankedShapeLastDimStatic : Constraint<And<[
  CPred<"$0.getType().isa<RankedTensorType>()">,
  CPred<"!$0.getType().cast<ShapedType>().isDynamicDim( "
  "  $0.getType().cast<RankedTensorType>().getRank() - 1)">]>>;

def IsNotComplexType : Constraint<And<[
  CPred<"$0.getType().isa<RankedTensorType>()">,
  CPred<"!$0.getType().cast<ShapedType>().getElementType().isa<ComplexType>()">
]>>;
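
The IsRankedShapeLastDimStatic and IsNotComplexType constraints above amount to the following predicate over the matched $input; again an illustrative sketch, with the function name invented here and an extra rank-0 guard the DRR constraints leave implicit:

    // Headers reflect the MLIR layout around this revision; they have since
    // been reorganized.
    #include "mlir/IR/StandardTypes.h"
    #include "mlir/IR/Value.h"

    // True iff the value is a ranked, non-complex tensor whose innermost
    // dimension is static, so fft_length can be materialized as a constant.
    static bool IsFoldableRealFFTInput(mlir::Value value) {
      auto type = value.getType().dyn_cast<mlir::RankedTensorType>();
      if (!type) return false;
      if (type.getElementType().isa<mlir::ComplexType>()) return false;
      // Defensive rank check; the DRR constraints assume rank >= 1.
      if (type.getRank() == 0) return false;
      return !type.isDynamicDim(type.getRank() - 1);
    }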

// Only fuse multiplier if all dimensions other than the channel dimension
// are equal to 1.
def CanFuseMulAndConv2D :

@@ -85,3 +102,10 @@ def PassthroughMulAndAddV2 :
          (TF_MulOp $input, (ConstantOp $value)),
          (TF_MulOp (ConstantOp $bias), (ConstantOp $value))),
      [(DefinedByConv2D $input), (HasOneUse $output)]>;

// input -> cast -> FFT  => input -> RFFT
def ConvertCastComplexFFTToRFFT: Pat<
  (TF_FFTOp (TF_CastOp $input, ConstBoolAttrFalse)),
  (TF_RFFTOp $input,
             (ConstantOp (GetLastDimSizeAsI32 $input))),
  [(IsRankedShapeLastDimStatic $input), (IsNotComplexType $input)]>;
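
For readers less familiar with DRR, a hand-written RewritePattern performing roughly the same rewrite as ConvertCastComplexFFTToRFFT. This is a sketch, not the tblgen-generated matcher: accessor names assume the ODS conventions of this revision, and the generated code emits a standard ConstantOp for fft_length (as the CHECK lines in the test expect) rather than a tf.Const:

    // Sketch only. Matches tf.FFT(tf.Cast(real, Truncate = false)) and
    // replaces it with tf.RFFT(real, fft_length = last dim of the real input).
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/IR/StandardTypes.h"
    #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

    namespace {
    struct CastComplexFFTToRFFT
        : public mlir::OpRewritePattern<mlir::TF::FFTOp> {
      using OpRewritePattern::OpRewritePattern;

      mlir::LogicalResult matchAndRewrite(
          mlir::TF::FFTOp fft, mlir::PatternRewriter &rewriter) const override {
        // The FFT operand must come from a non-truncating tf.Cast.
        auto cast_op = fft.input().getDefiningOp<mlir::TF::CastOp>();
        if (!cast_op || cast_op.Truncate()) return mlir::failure();

        // IsRankedShapeLastDimStatic + IsNotComplexType on the cast source.
        mlir::Value real = cast_op.x();
        auto type = real.getType().dyn_cast<mlir::RankedTensorType>();
        if (!type || type.getRank() == 0) return mlir::failure();
        if (type.getElementType().isa<mlir::ComplexType>())
          return mlir::failure();
        if (type.isDynamicDim(type.getRank() - 1)) return mlir::failure();

        // GetLastDimSizeAsI32: dense<last_dim> : tensor<1xi32>.
        auto fft_length_attr = mlir::DenseElementsAttr::get(
            mlir::RankedTensorType::get({1}, rewriter.getIntegerType(32)),
            static_cast<int32_t>(type.getDimSize(type.getRank() - 1)));
        auto fft_length = rewriter.create<mlir::TF::ConstOp>(
            fft.getLoc(), fft_length_attr);

        // Feed the real tensor to tf.RFFT, keeping the original result type.
        rewriter.replaceOpWithNewOp<mlir::TF::RFFTOp>(
            fft, fft.getType(), real, fft_length);
        return mlir::success();
      }
    };
    }  // namespace

The DRR form in the diff expresses the same match and replacement as a declarative DAG-to-DAG mapping and reuses GetLastDimSizeAsI32 and the two constraints defined above.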