435 lines
17 KiB
C++
435 lines
17 KiB
C++
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include <math.h>
|
|
#include <stddef.h>
|
|
#include <stdint.h>
|
|
#include <string.h>
|
|
|
|
#include <algorithm>
|
|
#include <complex>
|
|
|
|
#include "third_party/fft2d/fft2d.h"
|
|
#include "ruy/profiler/instrumentation.h" // from @ruy
|
|
#include "tensorflow/lite/c/common.h"
|
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
|
#include "tensorflow/lite/kernels/internal/types.h"
|
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
|
|
|
namespace tflite {
|
|
namespace ops {
|
|
namespace custom {
|
|
namespace rfft2d {
|
|
|
|
using std::complex;
|
|
|
|
constexpr int kInputTensor = 0;
|
|
constexpr int kFftLengthTensor = 1;
|
|
constexpr int kOutputTensor = 0;
|
|
constexpr int kFftIntegerWorkingAreaTensor = 0;
|
|
constexpr int kFftDoubleWorkingAreaTensor = 1;
|
|
constexpr int kTensorNotAllocated = -1;
|
|
|
|
struct OpData {
|
|
// IDs are the arbitrary identifiers used by TF Lite to identify and access
|
|
// memory buffers.
|
|
int fft_integer_working_area_id = kTensorNotAllocated;
|
|
int fft_double_working_area_id = kTensorNotAllocated;
|
|
};
|
|
|
|
bool IsPowerOfTwo(uint32_t v) { return v && !(v & (v - 1)); }
|
|
|
|
static TfLiteStatus InitTemporaryTensors(TfLiteContext* context,
|
|
TfLiteNode* node) {
|
|
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
|
// The prepare function may be executed multiple times. But temporary tensors
|
|
// only need to be initiated once.
|
|
if (data->fft_integer_working_area_id != kTensorNotAllocated &&
|
|
data->fft_double_working_area_id != kTensorNotAllocated) {
|
|
return kTfLiteOk;
|
|
}
|
|
|
|
TfLiteIntArrayFree(node->temporaries);
|
|
// Create two temporary tensors.
|
|
node->temporaries = TfLiteIntArrayCreate(2);
|
|
int first_new_index;
|
|
TF_LITE_ENSURE_STATUS(context->AddTensors(context, 2, &first_new_index));
|
|
node->temporaries->data[kFftIntegerWorkingAreaTensor] = first_new_index;
|
|
data->fft_integer_working_area_id = first_new_index;
|
|
node->temporaries->data[kFftDoubleWorkingAreaTensor] = first_new_index + 1;
|
|
data->fft_double_working_area_id = first_new_index + 1;
|
|
|
|
// Set up FFT integer working area buffer.
|
|
TfLiteTensor* fft_integer_working_area =
|
|
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
|
|
fft_integer_working_area->type = kTfLiteInt32;
|
|
// If fft_length is not a constant tensor, fft_integer_working_area will be
|
|
// set to dynamic later in Prepare.
|
|
fft_integer_working_area->allocation_type = kTfLiteArenaRw;
|
|
|
|
// Set up FFT double working area buffer.
|
|
TfLiteTensor* fft_double_working_area =
|
|
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
|
|
// fft_double_working_area is a double tensor. Ideally, double should be
|
|
// added into tflite data types. However, since fft_double_working_area is a
|
|
// temporary tensor, and there are no ops having double input/output tensors
|
|
// in tflite at this point, adding double as a tflite data type may confuse
|
|
// users that double is supported. As a results, kTfLiteInt64 is used here
|
|
// for memory allocation. And it will be cast into double in Eval when being
|
|
// used.
|
|
fft_double_working_area->type = kTfLiteInt64;
|
|
// If fft_length is not a constant tensor, fft_double_working_area will be
|
|
// set to dynamic later in Prepare.
|
|
fft_double_working_area->allocation_type = kTfLiteArenaRw;
|
|
|
|
return kTfLiteOk;
|
|
}
|
|
|
|
TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
|
|
TfLiteNode* node) {
|
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
|
const int num_dims = NumDimensions(input);
|
|
TF_LITE_ENSURE(context, num_dims >= 2);
|
|
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
|
|
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
|
|
// The lib, fft2d, can only handle fft_lengths of power of 2.
|
|
TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[0]));
|
|
TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[1]));
|
|
|
|
int fft_height, fft_width;
|
|
fft_height = fft_length_data[0];
|
|
fft_width = fft_length_data[1];
|
|
int fft_working_length = std::max(fft_height, fft_width / 2);
|
|
int half_fft_working_length = fft_working_length / 2;
|
|
|
|
// Resize output tensor.
|
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
|
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
|
|
output_shape->data[num_dims - 2] = fft_length_data[0];
|
|
output_shape->data[num_dims - 1] = fft_length_data[1] / 2 + 1;
|
|
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
|
|
|
|
// Resize temporary tensors, fft_integer_working_area.
|
|
TfLiteTensor* fft_integer_working_area =
|
|
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
|
|
TfLiteIntArray* fft_integer_working_area_shape = TfLiteIntArrayCreate(1);
|
|
fft_integer_working_area_shape->data[0] =
|
|
2 + static_cast<int>(sqrt(fft_working_length));
|
|
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_integer_working_area,
|
|
fft_integer_working_area_shape));
|
|
|
|
// Resize temporary tensors, fft_double_working_area.
|
|
TfLiteTensor* fft_double_working_area =
|
|
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
|
|
TfLiteIntArray* fft_double_working_area_shape = TfLiteIntArrayCreate(1);
|
|
fft_double_working_area_shape->data[0] =
|
|
half_fft_working_length + fft_width / 4;
|
|
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_double_working_area,
|
|
fft_double_working_area_shape));
|
|
|
|
return kTfLiteOk;
|
|
}
|
|
|
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
|
auto* data = new OpData;
|
|
return data;
|
|
}
|
|
|
|
void Free(TfLiteContext* context, void* buffer) {
|
|
delete reinterpret_cast<OpData*>(buffer);
|
|
}
|
|
|
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
|
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
|
|
|
// Check type and shape of the input tensor
|
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
|
TF_LITE_ENSURE(context, NumDimensions(input) >= 2);
|
|
if (input->type != kTfLiteFloat32) {
|
|
context->ReportError(context,
|
|
"Type '%s' for input is not supported by rfft2d.",
|
|
TfLiteTypeGetName(input->type));
|
|
return kTfLiteError;
|
|
}
|
|
|
|
// Check type and shape of the fft_length tensor
|
|
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
|
|
const RuntimeShape fft_length_shape = GetTensorShape(fft_length);
|
|
|
|
TF_LITE_ENSURE_EQ(context, NumDimensions(fft_length), 1);
|
|
TF_LITE_ENSURE_EQ(context, fft_length_shape.Dims(0), 2);
|
|
if (fft_length->type != kTfLiteInt32) {
|
|
context->ReportError(context,
|
|
"Type '%s' for fft_length is not supported by rfft2d.",
|
|
TfLiteTypeGetName(fft_length->type));
|
|
return kTfLiteError;
|
|
}
|
|
|
|
// Setup temporary tensors for fft computation.
|
|
TF_LITE_ENSURE_STATUS(InitTemporaryTensors(context, node));
|
|
|
|
// Set output type
|
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
|
output->type = kTfLiteComplex64;
|
|
|
|
// Exit early if fft_length is a non-const tensor. Set output tensor and
|
|
// temporary tensors to dynamic, so that their tensor sizes can be determined
|
|
// in Eval.
|
|
if (!IsConstantTensor(fft_length)) {
|
|
TfLiteTensor* fft_integer_working_area =
|
|
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
|
|
TfLiteTensor* fft_double_working_area =
|
|
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
|
|
SetTensorToDynamic(fft_integer_working_area);
|
|
SetTensorToDynamic(fft_double_working_area);
|
|
SetTensorToDynamic(output);
|
|
return kTfLiteOk;
|
|
}
|
|
|
|
TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node));
|
|
return kTfLiteOk;
|
|
}
|
|
|
|
// Reorder the result so that it matches the pattern of tf.signal.rfft2d.
|
|
// In tf.signal.fft2d the frequency matrix of a 4x4 input is
|
|
// [[F(0, 0), F(0, 1/4), F(0, 2/4)],
|
|
// [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)],
|
|
// [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)],
|
|
// [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]]
|
|
// While in rdft2d, the frequency matrix of a 4x4 input is
|
|
// [[(F(0, 0), F(0, -2/4)) F(0, -1/4), 0],
|
|
// [ F(-1/4, 0), F(-1/4, -1/4), 0],
|
|
// [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0],
|
|
// [ j*F(-3/4, -2/4), F(-3/4, -1/4), 0]]
|
|
// Since real fft has the property that
|
|
// Real(u,v) = Real(-u, -v)
|
|
// Img(u,v) = - Img(-u, -v)
|
|
// Result of rdft2d can be reordered and match the pattern of tf.signal.rfft2d.
|
|
// For example,
|
|
// Real(-3/4, 0) = Real(1/4, 0) = Real(-1/4, 0)
|
|
// Img(-3/4, 0) = Img(1/4, 0) = -Img(-1/4, 0)
|
|
void Rfft2dReorder(int fft_height, int fft_width, double** fft_input_output) {
|
|
int fft_height_half;
|
|
ruy::profiler::ScopeLabel label("Rfft2dReorder");
|
|
double real, img;
|
|
|
|
fft_height_half = fft_height >> 1;
|
|
// Use 4x4 input as an example, reorder the frequency matrix from
|
|
// [[(F(0, 0), F(0, -2/4)) F(0, -1/4), 0],
|
|
// [ F(-1/4, 0), F(-1/4, -1/4), 0],
|
|
// [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0],
|
|
// [ j*F(-3/4, -2/4), F(-3/4, -1/4), 0]]
|
|
// to
|
|
// [[F(0, 0), F(0, -1/4), F(0, -2/4)],
|
|
// [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)],
|
|
// [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)],
|
|
// [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]]
|
|
for (int i = fft_height_half + 1; i < fft_height; ++i) {
|
|
real = fft_input_output[i][0];
|
|
img = fft_input_output[i][1];
|
|
fft_input_output[i][fft_width] = img;
|
|
fft_input_output[i][fft_width + 1] = real;
|
|
fft_input_output[fft_height - i][fft_width] = img;
|
|
fft_input_output[fft_height - i][fft_width + 1] = -real;
|
|
fft_input_output[i][0] = fft_input_output[fft_height - i][0];
|
|
fft_input_output[i][1] = -fft_input_output[fft_height - i][1];
|
|
}
|
|
fft_input_output[0][fft_width] = fft_input_output[0][1];
|
|
fft_input_output[0][fft_width + 1] = 0;
|
|
fft_input_output[0][1] = 0;
|
|
fft_input_output[fft_height_half][fft_width] =
|
|
fft_input_output[fft_height_half][1];
|
|
fft_input_output[fft_height_half][fft_width + 1] = 0;
|
|
fft_input_output[fft_height_half][1] = 0;
|
|
|
|
// Reorder the frequency matrix from
|
|
// [[F(0, 0), F(0, -1/4), F(0, -2/4)],
|
|
// [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)],
|
|
// [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)],
|
|
// [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]]
|
|
// to
|
|
// [[F(0, 0), F(0, 1/4), F(0, 2/4)],
|
|
// [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)],
|
|
// [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)],
|
|
// [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]]
|
|
for (int i = 0; i < fft_height; ++i) {
|
|
for (int j = 1; j < fft_width + 2; j += 2) {
|
|
fft_input_output[i][j] = -fft_input_output[i][j];
|
|
}
|
|
}
|
|
}
|
|
|
|
void Rfft2dImpl(int fft_height, int fft_width, double** fft_input_output,
|
|
int* fft_integer_working_area_data,
|
|
double* fft_double_working_area_data) {
|
|
ruy::profiler::ScopeLabel label("Rfft2dImpl");
|
|
|
|
// Working data areas for the FFT routines.
|
|
double* fft_dynamic_working_area = nullptr;
|
|
const int kForwardFft = 1;
|
|
rdft2d(fft_height, fft_width, kForwardFft, fft_input_output,
|
|
fft_dynamic_working_area, fft_integer_working_area_data,
|
|
fft_double_working_area_data);
|
|
Rfft2dReorder(fft_height, fft_width, fft_input_output);
|
|
}
|
|
|
|
void PrepareInputBuffer(const float* input_data, int input_height,
|
|
int input_width, int fft_height, int fft_width,
|
|
double** fft_input_output) {
|
|
int valid_input_height = std::min(input_height, fft_height);
|
|
int valid_input_width = std::min(input_width, fft_width);
|
|
for (int i = 0; i < valid_input_height; ++i) {
|
|
int in_pos = i * input_width;
|
|
for (int j = 0; j < valid_input_width; ++j) {
|
|
fft_input_output[i][j] = input_data[in_pos++];
|
|
}
|
|
// Zero-pad the rest of the input buffer
|
|
for (int j = valid_input_width; j < fft_width + 2; ++j) {
|
|
fft_input_output[i][j] = 0;
|
|
}
|
|
}
|
|
|
|
// Zero-pad input buffer, if fft_height is greater than valid_input_height.
|
|
for (int i = valid_input_height; i < fft_height; ++i) {
|
|
for (int j = 0; j < fft_width + 2; ++j) {
|
|
fft_input_output[i][j] = 0;
|
|
}
|
|
}
|
|
}
|
|
|
|
void PrepareOutputBuffer(complex<float>* output_data, int fft_height,
|
|
int fft_width, double** fft_input_output) {
|
|
int cnt = 0;
|
|
for (int i = 0; i < fft_height; ++i) {
|
|
for (int j = 0; j < fft_width / 2 + 1; ++j) {
|
|
output_data[cnt++] = complex<float>(fft_input_output[i][j * 2],
|
|
fft_input_output[i][j * 2 + 1]);
|
|
}
|
|
}
|
|
}
|
|
|
|
TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
|
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
|
const float* input_data = GetTensorData<float>(input);
|
|
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
|
|
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
|
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
|
complex<float>* output_data = GetTensorData<complex<float>>(output);
|
|
|
|
int fft_height, fft_width;
|
|
fft_height = fft_length_data[0];
|
|
fft_width = fft_length_data[1];
|
|
|
|
// FFT is processed for every slice on the inner most 2 dimensions.
|
|
// Count the number of slices in the input tensor.
|
|
const RuntimeShape input_shape = GetTensorShape(input);
|
|
const int input_dims_count = input_shape.DimensionsCount();
|
|
const auto* input_dims_data = input_shape.DimsData();
|
|
int num_slices = 1;
|
|
for (int i = 0; i < input_dims_count - 2; ++i) {
|
|
num_slices *= input_dims_data[i];
|
|
}
|
|
|
|
int input_height = input_dims_data[input_dims_count - 2];
|
|
int input_width = input_dims_data[input_dims_count - 1];
|
|
int input_slice_size = input_height * input_width;
|
|
int output_slice_size = fft_height * (fft_width / 2 + 1);
|
|
|
|
// Create input/output buffer for FFT
|
|
double** fft_input_output = new double*[fft_height];
|
|
for (int i = 0; i < fft_height; ++i) {
|
|
fft_input_output[i] = new double[fft_width + 2];
|
|
}
|
|
|
|
// Get buffer for integer working area.
|
|
TfLiteTensor* fft_integer_working_area =
|
|
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
|
|
int* fft_integer_working_area_data =
|
|
GetTensorData<int>(fft_integer_working_area);
|
|
|
|
// Get buffer for double working area.
|
|
TfLiteTensor* fft_double_working_area =
|
|
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
|
|
// Get double value out of the memory of fft_double_working_area_data.
|
|
double* fft_double_working_area_data = reinterpret_cast<double*>(
|
|
GetTensorData<int64_t>(fft_double_working_area));
|
|
|
|
// Process every slice in the input buffer
|
|
for (int i = 0; i < num_slices; ++i) {
|
|
PrepareInputBuffer(input_data, input_height, input_width, fft_height,
|
|
fft_width, fft_input_output);
|
|
memset(fft_integer_working_area_data, 0, fft_integer_working_area->bytes);
|
|
memset(fft_double_working_area_data, 0, fft_double_working_area->bytes);
|
|
Rfft2dImpl(fft_height, fft_width, fft_input_output,
|
|
fft_integer_working_area_data, fft_double_working_area_data);
|
|
PrepareOutputBuffer(output_data, fft_height, fft_width, fft_input_output);
|
|
input_data += input_slice_size;
|
|
output_data += output_slice_size;
|
|
}
|
|
|
|
// Delete the input buffer
|
|
for (int i = 0; i < fft_height; ++i) {
|
|
delete[] fft_input_output[i];
|
|
}
|
|
delete[] fft_input_output;
|
|
|
|
return kTfLiteOk;
|
|
}
|
|
|
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
|
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
|
|
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
|
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
|
|
|
if (output->type != kTfLiteComplex64) {
|
|
context->ReportError(context,
|
|
"Type '%s' for output is not supported by rfft2d.",
|
|
TfLiteTypeGetName(output->type));
|
|
return kTfLiteError;
|
|
}
|
|
|
|
// Resize the output tensor if the fft_length tensor is not constant.
|
|
// Otherwise, check if the output shape is correct.
|
|
if (!IsConstantTensor(fft_length)) {
|
|
TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node));
|
|
} else {
|
|
int num_dims_output = NumDimensions(output);
|
|
const RuntimeShape output_shape = GetTensorShape(output);
|
|
TF_LITE_ENSURE_EQ(context, num_dims_output, NumDimensions(input));
|
|
TF_LITE_ENSURE(context, num_dims_output >= 2);
|
|
TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 2),
|
|
fft_length_data[0]);
|
|
TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 1),
|
|
fft_length_data[1] / 2 + 1);
|
|
}
|
|
|
|
return Rfft2dHelper(context, node);
|
|
}
|
|
|
|
} // namespace rfft2d
|
|
|
|
TfLiteRegistration* Register_RFFT2D() {
|
|
static TfLiteRegistration r = {rfft2d::Init, rfft2d::Free, rfft2d::Prepare,
|
|
rfft2d::Eval};
|
|
return &r;
|
|
}
|
|
|
|
} // namespace custom
|
|
} // namespace ops
|
|
} // namespace tflite
|