Transform rfft to rfft2d & fix rfft2d kernel for height == 1 case.
PiperOrigin-RevId: 327368554 Change-Id: I625f2b75b2e1b762b0536380e4e57c7377eb5c59
This commit is contained in:
parent
3b9cb438e5
commit
f04a2215fa
@ -615,4 +615,18 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3
|
||||
// CHECK: return [[MUL]] : tensor<3x3xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: lower_rfft_to_rfft2d
|
||||
func @lower_rfft_to_rfft2d(%input: tensor<10x20x30xf32>, %fft_len: tensor<1xi32>) -> tensor<10x20x30xcomplex<f64>> {
|
||||
%0 = "tf.RFFT"(%input, %fft_len) : (tensor<10x20x30xf32>, tensor<1xi32>) -> tensor<10x20x30xcomplex<f64>>
|
||||
return %0: tensor<10x20x30xcomplex<f64>>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<-2> : tensor<i32>
|
||||
// CHECK: %[[CST0:.*]] = constant dense<1> : tensor<1xi32>
|
||||
// CHECK: %[[CST1:.*]] = constant dense<0> : tensor<i32>
|
||||
// CHECK: %[[EXP:.*]] = "tf.ExpandDims"(%arg0, %[[CST]]) : (tensor<10x20x30xf32>, tensor<i32>) -> tensor<10x20x1x30xf32>
|
||||
// CHECK: %[[CON:.*]] = "tf.ConcatV2"(%[[CST0]], %arg1, %[[CST1]]) : (tensor<1xi32>, tensor<1xi32>, tensor<i32>) -> tensor<2xi32>
|
||||
// CHECK: %[[RFF:.*]] = "tf.RFFT2D"(%[[EXP]], %[[CON]]) : (tensor<10x20x1x30xf32>, tensor<2xi32>) -> tensor<10x20x1x30xcomplex<f64>>
|
||||
// CHECK: %[[SQE:.*]] = "tf.Squeeze"(%[[RFF]]) {squeeze_dims = [-2]} : (tensor<10x20x1x30xcomplex<f64>>) -> tensor<10x20x30xcomplex<f64>>
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -762,6 +762,102 @@ LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) {
|
||||
return applyPartialConversion(func, target, patterns);
|
||||
}
|
||||
|
||||
// Convert rfft to rfft2d.
|
||||
// The transformation pattern looks like below:
|
||||
//
|
||||
// input fft_len
|
||||
// \ /
|
||||
// rfft
|
||||
//
|
||||
// ||
|
||||
// \/
|
||||
//
|
||||
// input fft_len
|
||||
// \ /
|
||||
// expand_dim concat with [1] at the front
|
||||
// \ /
|
||||
// rfft_2d
|
||||
// |
|
||||
// squeeze
|
||||
struct ConvertRfftToRfft2d : public RewritePattern {
|
||||
explicit ConvertRfftToRfft2d(MLIRContext *context)
|
||||
: RewritePattern(TF::RFFTOp::getOperationName(), 1, context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto rfft_op = dyn_cast<TF::RFFTOp>(op);
|
||||
|
||||
auto input = rfft_op.input();
|
||||
auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
|
||||
if (!input_type) return failure();
|
||||
auto fft_len = rfft_op.fft_length();
|
||||
auto fft_len_type = fft_len.getType().dyn_cast_or_null<ShapedType>();
|
||||
if (!fft_len_type) return failure();
|
||||
|
||||
auto output_type =
|
||||
rfft_op.getResult().getType().dyn_cast_or_null<RankedTensorType>();
|
||||
if (!output_type) return failure();
|
||||
|
||||
// Expanded inputs.
|
||||
// Insert at -2 location.
|
||||
auto one_ele_type =
|
||||
mlir::RankedTensorType::get({1}, rewriter.getIntegerType(32));
|
||||
auto minus_two = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
|
||||
one_ele_type, -2);
|
||||
|
||||
SmallVector<int64_t, 4> expanded_input_shape;
|
||||
SmallVector<int64_t, 4> expanded_output_shape;
|
||||
int expanded_rank = input_type.getRank() + 1;
|
||||
int r = 0;
|
||||
for (int i = 0; i < expanded_rank; ++i) {
|
||||
if (i == expanded_rank - 2) {
|
||||
expanded_input_shape.push_back(1);
|
||||
expanded_output_shape.push_back(1);
|
||||
} else {
|
||||
expanded_input_shape.push_back(input_type.getDimSize(r));
|
||||
expanded_output_shape.push_back(output_type.getDimSize(r));
|
||||
r++;
|
||||
}
|
||||
}
|
||||
|
||||
auto expaned_input_type = mlir::RankedTensorType::get(
|
||||
expanded_input_shape, input_type.getElementType());
|
||||
TF::ExpandDimsOp expanded_input = rewriter.create<TF::ExpandDimsOp>(
|
||||
rfft_op.getLoc(), expaned_input_type, input, minus_two->getResult());
|
||||
|
||||
// Expanded fft_len.
|
||||
auto one_attr = mlir::DenseIntElementsAttr::get(one_ele_type, {1});
|
||||
|
||||
auto one = rewriter.create<TF::ConstOp>(rfft_op.getLoc(), one_attr);
|
||||
|
||||
auto zero = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
|
||||
one_ele_type, 0);
|
||||
|
||||
auto expanded_fft_len_type =
|
||||
mlir::RankedTensorType::get({2}, fft_len_type.getElementType());
|
||||
|
||||
TF::ConcatV2Op expanded_fft_len = rewriter.create<TF::ConcatV2Op>(
|
||||
rfft_op.getLoc(), expanded_fft_len_type,
|
||||
SmallVector<Value, 2>({one.getResult(), fft_len}), zero->getResult());
|
||||
|
||||
// Insert the rfft_2d.
|
||||
auto rfft2d_out_type = mlir::RankedTensorType::get(
|
||||
expanded_output_shape, output_type.getElementType());
|
||||
TF::RFFT2DOp rfft2d = rewriter.create<TF::RFFT2DOp>(
|
||||
rfft_op.getLoc(), rfft2d_out_type, expanded_input.getResult(),
|
||||
expanded_fft_len.getResult());
|
||||
|
||||
// Insert the squeeze op.
|
||||
auto squeeze_dim = rewriter.getI64ArrayAttr({-2});
|
||||
TF::SqueezeOp squeeze = rewriter.create<TF::SqueezeOp>(
|
||||
rfft_op.getLoc(), output_type, rfft2d.getResult(), squeeze_dim);
|
||||
|
||||
rewriter.replaceOp(op, squeeze.getResult());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void PrepareTFPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto func = getFunction();
|
||||
@ -811,7 +907,8 @@ void PrepareTFPass::runOnFunction() {
|
||||
TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
|
||||
}
|
||||
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo, ConvertTFConv2D,
|
||||
ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
|
||||
ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice,
|
||||
ConvertRfftToRfft2d>(ctx);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
|
||||
|
@ -248,13 +248,15 @@ void Rfft2dReorder(int fft_height, int fft_width, double** fft_input_output) {
|
||||
fft_input_output[i][0] = fft_input_output[fft_height - i][0];
|
||||
fft_input_output[i][1] = -fft_input_output[fft_height - i][1];
|
||||
}
|
||||
fft_input_output[0][fft_width] = fft_input_output[0][1];
|
||||
|
||||
double temp = fft_input_output[0][1];
|
||||
fft_input_output[0][fft_width + 1] = 0;
|
||||
fft_input_output[0][1] = 0;
|
||||
fft_input_output[fft_height_half][fft_width] =
|
||||
fft_input_output[fft_height_half][1];
|
||||
fft_input_output[fft_height_half][fft_width + 1] = 0;
|
||||
fft_input_output[fft_height_half][1] = 0;
|
||||
fft_input_output[0][fft_width] = temp;
|
||||
|
||||
// Reorder the frequency matrix from
|
||||
// [[F(0, 0), F(0, -1/4), F(0, -2/4)],
|
||||
|
@ -30,9 +30,10 @@ def make_rfft2d_tests(options):
|
||||
|
||||
test_parameters = [{
|
||||
"input_dtype": [tf.float32],
|
||||
"input_shape": [[8, 8], [3, 8, 8]],
|
||||
"input_shape": [[8, 8], [3, 8, 8], [3, 1, 16]],
|
||||
"fft_length": [
|
||||
None, [4, 4], [4, 8], [8, 4], [8, 8], [8, 16], [16, 8], [16, 16]
|
||||
None, [4, 4], [4, 8], [8, 4], [8, 8], [8, 16], [16, 8], [16, 16],
|
||||
[1, 8], [1, 16]
|
||||
]
|
||||
}]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user