Transform rfft to rfft2d & fix rfft2d kernel for height == 1 case.

PiperOrigin-RevId: 327368554
Change-Id: I625f2b75b2e1b762b0536380e4e57c7377eb5c59
This commit is contained in:
Renjie Liu 2020-08-18 21:38:44 -07:00 committed by TensorFlower Gardener
parent 3b9cb438e5
commit f04a2215fa
4 changed files with 118 additions and 4 deletions

View File

@ -615,4 +615,18 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3
// CHECK: return [[MUL]] : tensor<3x3xi32>
}
// CHECK-LABEL: lower_rfft_to_rfft2d
// Verifies that a 1-D tf.RFFT is lowered to tf.RFFT2D: the input gains a
// unit dimension at axis -2 (ExpandDims), fft_length is prepended with 1
// (ConcatV2 of [1] and the original length), and the unit dimension is
// removed from the result with a Squeeze at axis -2.
func @lower_rfft_to_rfft2d(%input: tensor<10x20x30xf32>, %fft_len: tensor<1xi32>) -> tensor<10x20x30xcomplex<f64>> {
%0 = "tf.RFFT"(%input, %fft_len) : (tensor<10x20x30xf32>, tensor<1xi32>) -> tensor<10x20x30xcomplex<f64>>
return %0: tensor<10x20x30xcomplex<f64>>
// CHECK: %[[CST:.*]] = constant dense<-2> : tensor<i32>
// CHECK: %[[CST0:.*]] = constant dense<1> : tensor<1xi32>
// CHECK: %[[CST1:.*]] = constant dense<0> : tensor<i32>
// CHECK: %[[EXP:.*]] = "tf.ExpandDims"(%arg0, %[[CST]]) : (tensor<10x20x30xf32>, tensor<i32>) -> tensor<10x20x1x30xf32>
// CHECK: %[[CON:.*]] = "tf.ConcatV2"(%[[CST0]], %arg1, %[[CST1]]) : (tensor<1xi32>, tensor<1xi32>, tensor<i32>) -> tensor<2xi32>
// CHECK: %[[RFF:.*]] = "tf.RFFT2D"(%[[EXP]], %[[CON]]) : (tensor<10x20x1x30xf32>, tensor<2xi32>) -> tensor<10x20x1x30xcomplex<f64>>
// CHECK: %[[SQE:.*]] = "tf.Squeeze"(%[[RFF]]) {squeeze_dims = [-2]} : (tensor<10x20x1x30xcomplex<f64>>) -> tensor<10x20x30xcomplex<f64>>
}
}

View File

@ -762,6 +762,102 @@ LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) {
return applyPartialConversion(func, target, patterns);
}
// Convert rfft to rfft2d.
// The transformation pattern looks like below:
//
// input fft_len
// \ /
// rfft
//
// ||
// \/
//
// input fft_len
// \ /
// expand_dim concat with [1] at the front
// \ /
// rfft_2d
// |
// squeeze
// Rewrites a 1-D tf.RFFT as a tf.RFFT2D over an inserted unit height
// dimension: ExpandDims(input, -2) -> RFFT2D with fft_length = [1, fft_len]
// -> Squeeze(-2). Fails to match when the input/output tensors are unranked
// or fft_length has no shaped type.
struct ConvertRfftToRfft2d : public RewritePattern {
  explicit ConvertRfftToRfft2d(MLIRContext *context)
      : RewritePattern(TF::RFFTOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto rfft_op = dyn_cast<TF::RFFTOp>(op);

    // Static shape information is required to build the expanded types below.
    auto input = rfft_op.input();
    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
    if (!input_type) return failure();
    auto fft_len = rfft_op.fft_length();
    auto fft_len_type = fft_len.getType().dyn_cast_or_null<ShapedType>();
    if (!fft_len_type) return failure();
    auto output_type =
        rfft_op.getResult().getType().dyn_cast_or_null<RankedTensorType>();
    if (!output_type) return failure();

    // Expand the input by inserting a unit dimension at axis -2 (i.e. just
    // before the FFT axis), so the 1-D FFT becomes a [1, fft_len] 2-D FFT.
    auto one_ele_type =
        mlir::RankedTensorType::get({1}, rewriter.getIntegerType(32));
    auto minus_two = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
                                                  one_ele_type, -2);

    // Compute the expanded input/output shapes: copy all original dims and
    // insert 1 at position rank - 2 of the expanded tensor.
    SmallVector<int64_t, 4> expanded_input_shape;
    SmallVector<int64_t, 4> expanded_output_shape;
    int expanded_rank = input_type.getRank() + 1;
    int r = 0;  // Index into the original (unexpanded) dims.
    for (int i = 0; i < expanded_rank; ++i) {
      if (i == expanded_rank - 2) {
        // The newly inserted unit dimension.
        expanded_input_shape.push_back(1);
        expanded_output_shape.push_back(1);
      } else {
        expanded_input_shape.push_back(input_type.getDimSize(r));
        expanded_output_shape.push_back(output_type.getDimSize(r));
        r++;
      }
    }

    auto expanded_input_type = mlir::RankedTensorType::get(
        expanded_input_shape, input_type.getElementType());
    TF::ExpandDimsOp expanded_input = rewriter.create<TF::ExpandDimsOp>(
        rfft_op.getLoc(), expanded_input_type, input, minus_two->getResult());

    // Build the 2-D fft_length as concat([1], fft_len) along axis 0.
    auto one_attr = mlir::DenseIntElementsAttr::get(one_ele_type, {1});
    auto one = rewriter.create<TF::ConstOp>(rfft_op.getLoc(), one_attr);
    auto zero = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(),
                                             one_ele_type, 0);
    auto expanded_fft_len_type =
        mlir::RankedTensorType::get({2}, fft_len_type.getElementType());
    TF::ConcatV2Op expanded_fft_len = rewriter.create<TF::ConcatV2Op>(
        rfft_op.getLoc(), expanded_fft_len_type,
        SmallVector<Value, 2>({one.getResult(), fft_len}), zero->getResult());

    // Insert the rfft_2d over the expanded input.
    auto rfft2d_out_type = mlir::RankedTensorType::get(
        expanded_output_shape, output_type.getElementType());
    TF::RFFT2DOp rfft2d = rewriter.create<TF::RFFT2DOp>(
        rfft_op.getLoc(), rfft2d_out_type, expanded_input.getResult(),
        expanded_fft_len.getResult());

    // Squeeze away the unit dimension so the result matches the original
    // rfft output type.
    auto squeeze_dim = rewriter.getI64ArrayAttr({-2});
    TF::SqueezeOp squeeze = rewriter.create<TF::SqueezeOp>(
        rfft_op.getLoc(), output_type, rfft2d.getResult(), squeeze_dim);
    rewriter.replaceOp(op, squeeze.getResult());

    return success();
  }
};
void PrepareTFPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
@ -811,7 +907,8 @@ void PrepareTFPass::runOnFunction() {
TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
}
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo, ConvertTFConv2D,
ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice,
ConvertRfftToRfft2d>(ctx);
applyPatternsAndFoldGreedily(func, patterns);
}

View File

@ -248,13 +248,15 @@ void Rfft2dReorder(int fft_height, int fft_width, double** fft_input_output) {
fft_input_output[i][0] = fft_input_output[fft_height - i][0];
fft_input_output[i][1] = -fft_input_output[fft_height - i][1];
}
fft_input_output[0][fft_width] = fft_input_output[0][1];
double temp = fft_input_output[0][1];
fft_input_output[0][fft_width + 1] = 0;
fft_input_output[0][1] = 0;
fft_input_output[fft_height_half][fft_width] =
fft_input_output[fft_height_half][1];
fft_input_output[fft_height_half][fft_width + 1] = 0;
fft_input_output[fft_height_half][1] = 0;
fft_input_output[0][fft_width] = temp;
// Reorder the frequency matrix from
// [[F(0, 0), F(0, -1/4), F(0, -2/4)],

View File

@ -30,9 +30,10 @@ def make_rfft2d_tests(options):
test_parameters = [{
"input_dtype": [tf.float32],
"input_shape": [[8, 8], [3, 8, 8]],
"input_shape": [[8, 8], [3, 8, 8], [3, 1, 16]],
"fft_length": [
None, [4, 4], [4, 8], [8, 4], [8, 8], [8, 16], [16, 8], [16, 16]
None, [4, 4], [4, 8], [8, 4], [8, 8], [8, 16], [16, 8], [16, 16],
[1, 8], [1, 16]
]
}]