Implement Rfft2d as a custom op.

PiperOrigin-RevId: 252874759
Authored by Lu Wang on 2019-06-12 12:18:09 -07:00; committed by TensorFlower Gardener
parent 73a5307689
commit 03195f1345
19 changed files with 823 additions and 44 deletions
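Rfft2d is added as a custom op rather than a builtin, so it is not part of the standard BuiltinOpResolver; a client has to register it explicitly, the way the test driver in this change does with AddCustom. A minimal sketch of that wiring (the surrounding interpreter setup is assumed; the registered name must match the custom op name recorded in the flatbuffer, which is "Rfft2d" in the kernel test below and "RFFT2D" in the generated-test driver):

#include <memory>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Sketch: build an interpreter for a model that contains the Rfft2d custom op.
std::unique_ptr<tflite::Interpreter> BuildRfft2dInterpreter(
    const tflite::FlatBufferModel& model) {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  // Register the custom kernel under the name the converter wrote out.
  resolver.AddCustom("Rfft2d", tflite::ops::custom::Register_RFFT2D());
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(model, resolver)(&interpreter);
  return interpreter;
}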


@ -40,7 +40,7 @@ readonly PROTOBUF_TAG="$(grep -o 'https://github.com/protocolbuffers/protobuf/ar
# TODO (yongtang): Replace the following with 'http://mirror.tensorflow.org/github.com/google/re2/.*tar\.gz' once
# the archive has been propagated in mirror.tensorflow.org.
RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
FFT2D_URL="$(grep -o 'http.*fft2d\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
DOUBLE_CONVERSION_URL="$(grep -o "https.*google/double-conversion.*\.zip" "${BZL_FILE_PATH}" | head -n1)"
ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)"
CUB_URL="$(grep -o 'https.*cub/archive.*zip' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)"


@ -315,6 +315,7 @@ def generated_test_models():
"resolve_constant_strided_slice",
"reverse_sequence",
"reverse_v2",
"rfft2d",
"round",
"rsqrt",
"shape",


@ -436,6 +436,23 @@ cc_library(
],
)
cc_library(
name = "custom_ops",
srcs = ["rfft2d.cc"],
hdrs = ["custom_ops_register.h"],
deps = [
":kernel_util",
":op_macros",
"//tensorflow/lite:context",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/kernels/internal:tensor",
"//third_party/fft2d:fft2d_headers",
"@fft2d",
"@gemmlowp//:profiler",
],
)
cc_library(
name = "lstm_eval",
srcs = ["lstm_eval.cc"],
@ -1619,6 +1636,19 @@ cc_test(
],
)
cc_test(
name = "rfft2d_test",
size = "small",
srcs = ["rfft2d_test.cc"],
deps = [
":custom_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)
filegroup(
name = "all_files",
srcs = glob(


@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_CUSTOM_OPS_REGISTER_H_
#define TENSORFLOW_LITE_KERNELS_CUSTOM_OPS_REGISTER_H_
#include "tensorflow/lite/context.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_RFFT2D();
}
} // namespace ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_CUSTOM_OPS_REGISTER_H_


@ -0,0 +1,426 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/fft2d/fft2d.h"
#include "profiling/instrumentation.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace custom {
namespace rfft2d {
using std::complex;
constexpr int kInputTensor = 0;
constexpr int kFftLengthTensor = 1;
constexpr int kOutputTensor = 0;
constexpr int kFftIntegerWorkingAreaTensor = 0;
constexpr int kFftDoubleWorkingAreaTensor = 1;
constexpr int kTensorNotAllocated = -1;
struct OpData {
// IDs are the arbitrary identifiers used by TF Lite to identify and access
// memory buffers.
int fft_integer_working_area_id = kTensorNotAllocated;
int fft_double_working_area_id = kTensorNotAllocated;
};
bool IsPowerOfTwo(uint32_t v) { return v && !(v & (v - 1)); }
static TfLiteStatus InitTemporaryTensors(TfLiteContext* context,
TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
// The prepare function may be executed multiple times, but the temporary
// tensors only need to be initialized once.
if (data->fft_integer_working_area_id != kTensorNotAllocated &&
data->fft_double_working_area_id != kTensorNotAllocated) {
return kTfLiteOk;
}
TfLiteIntArrayFree(node->temporaries);
// Create two temporary tensors.
node->temporaries = TfLiteIntArrayCreate(2);
int first_new_index;
TF_LITE_ENSURE_STATUS(context->AddTensors(context, 2, &first_new_index));
node->temporaries->data[kFftIntegerWorkingAreaTensor] = first_new_index;
data->fft_integer_working_area_id = first_new_index;
node->temporaries->data[kFftDoubleWorkingAreaTensor] = first_new_index + 1;
data->fft_double_working_area_id = first_new_index + 1;
// Set up FFT integer working area buffer.
TfLiteTensor* fft_integer_working_area =
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
fft_integer_working_area->type = kTfLiteInt32;
// If fft_length is not a constant tensor, fft_integer_working_area will be
// set to dynamic later in Prepare.
fft_integer_working_area->allocation_type = kTfLiteArenaRw;
// Set up FFT double working area buffer.
TfLiteTensor* fft_double_working_area =
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
// fft_double_working_area is logically a double tensor. Ideally, double would
// be added to the tflite data types. However, since fft_double_working_area is
// a temporary tensor and no tflite op has double input/output tensors at this
// point, adding double as a tflite data type could mislead users into thinking
// double is supported. As a result, kTfLiteInt64 is used here for memory
// allocation, and the buffer is reinterpreted as double in Eval when used.
fft_double_working_area->type = kTfLiteInt64;
// If fft_length is not a constant tensor, fft_double_working_area will be
// set to dynamic later in Prepare.
fft_double_working_area->allocation_type = kTfLiteArenaRw;
return kTfLiteOk;
}
TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const int num_dims = NumDimensions(input);
TF_LITE_ENSURE(context, num_dims >= 2);
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
// The fft2d library can only handle fft_lengths that are powers of 2.
TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[0]));
TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[1]));
int fft_height, fft_width;
fft_height = fft_length_data[0];
fft_width = fft_length_data[1];
int fft_working_length = std::max(fft_height, fft_width / 2);
int half_fft_working_length = fft_working_length / 2;
// Resize output tensor.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
output_shape->data[num_dims - 2] = fft_length_data[0];
output_shape->data[num_dims - 1] = fft_length_data[1] / 2 + 1;
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
// Resize the temporary tensor fft_integer_working_area.
TfLiteTensor* fft_integer_working_area =
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
TfLiteIntArray* fft_integer_working_area_shape = TfLiteIntArrayCreate(1);
fft_integer_working_area_shape->data[0] =
2 + static_cast<int>(sqrt(fft_working_length));
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_integer_working_area,
fft_integer_working_area_shape));
// Resize the temporary tensor fft_double_working_area.
TfLiteTensor* fft_double_working_area =
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
TfLiteIntArray* fft_double_working_area_shape = TfLiteIntArrayCreate(1);
fft_double_working_area_shape->data[0] =
half_fft_working_length + fft_width / 4;
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_double_working_area,
fft_double_working_area_shape));
return kTfLiteOk;
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* data = new OpData;
return data;
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
// Check type and shape of the input tensor
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, NumDimensions(input) >= 2);
if (input->type != kTfLiteFloat32) {
context->ReportError(context,
"Type '%s' for input is not supported by rfft2d.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
// Check type and shape of the fft_length tensor
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
const RuntimeShape fft_length_shape = GetTensorShape(fft_length);
TF_LITE_ENSURE_EQ(context, NumDimensions(fft_length), 1);
TF_LITE_ENSURE_EQ(context, fft_length_shape.Dims(0), 2);
if (fft_length->type != kTfLiteInt32) {
context->ReportError(context,
"Type '%s' for fft_length is not supported by rfft2d.",
TfLiteTypeGetName(fft_length->type));
return kTfLiteError;
}
// Set up temporary tensors for the FFT computation.
TF_LITE_ENSURE_STATUS(InitTemporaryTensors(context, node));
// Set output type
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
output->type = kTfLiteComplex64;
// Exit early if fft_length is a non-const tensor. Set output tensor and
// temporary tensors to dynamic, so that their tensor sizes can be determined
// in Eval.
if (!IsConstantTensor(fft_length)) {
TfLiteTensor* fft_integer_working_area =
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
TfLiteTensor* fft_double_working_area =
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
SetTensorToDynamic(fft_integer_working_area);
SetTensorToDynamic(fft_double_working_area);
SetTensorToDynamic(output);
return kTfLiteOk;
}
TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node));
return kTfLiteOk;
}
// Reorder the result so that it matches the pattern of tf.signal.rfft2d.
// In tf.signal.rfft2d the frequency matrix of a 4x4 input is
// [[F(0, 0), F(0, 1/4), F(0, 2/4)],
// [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)],
// [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)],
// [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]]
// While in rdft2d, the frequency matrix of a 4x4 input is
// [[(F(0, 0), F(0, -2/4)) F(0, -1/4), 0],
// [ F(-1/4, 0), F(-1/4, -1/4), 0],
// [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0],
// [ j*F(-3/4, -2/4), F(-3/4, -1/4), 0]]
// Since the real FFT has the properties
//   Real(u, v) = Real(-u, -v)
//   Imag(u, v) = -Imag(-u, -v)
// the result of rdft2d can be reordered to match the pattern of
// tf.signal.rfft2d. For example,
//   Real(-3/4, 0) = Real(1/4, 0) = Real(-1/4, 0)
//   Imag(-3/4, 0) = Imag(1/4, 0) = -Imag(-1/4, 0)
void Rfft2dReorder(int fft_height, int fft_width, double** fft_input_output) {
int fft_height_half;
gemmlowp::ScopedProfilingLabel label("Rfft2dReorder");
double real, img;
fft_height_half = fft_height >> 1;
// Use 4x4 input as an example, reorder the frequency matrix from
// [[(F(0, 0), F(0, -2/4)) F(0, -1/4), 0],
// [ F(-1/4, 0), F(-1/4, -1/4), 0],
// [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0],
// [ j*F(-3/4, -2/4), F(-3/4, -1/4), 0]]
// to
// [[F(0, 0), F(0, -1/4), F(0, -2/4)],
// [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)],
// [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)],
// [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]]
for (int i = fft_height_half + 1; i < fft_height; ++i) {
real = fft_input_output[i][0];
img = fft_input_output[i][1];
fft_input_output[i][fft_width] = img;
fft_input_output[i][fft_width + 1] = real;
fft_input_output[fft_height - i][fft_width] = img;
fft_input_output[fft_height - i][fft_width + 1] = -real;
fft_input_output[i][0] = fft_input_output[fft_height - i][0];
fft_input_output[i][1] = -fft_input_output[fft_height - i][1];
}
fft_input_output[0][fft_width] = fft_input_output[0][1];
fft_input_output[0][fft_width + 1] = 0;
fft_input_output[0][1] = 0;
fft_input_output[fft_height_half][fft_width] =
fft_input_output[fft_height_half][1];
fft_input_output[fft_height_half][fft_width + 1] = 0;
fft_input_output[fft_height_half][1] = 0;
// Reorder the frequency matrix from
// [[F(0, 0), F(0, -1/4), F(0, -2/4)],
// [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)],
// [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)],
// [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]]
// to
// [[F(0, 0), F(0, 1/4), F(0, 2/4)],
// [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)],
// [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)],
// [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]]
for (int i = 0; i < fft_height; ++i) {
for (int j = 1; j < fft_width + 2; j += 2) {
fft_input_output[i][j] = -fft_input_output[i][j];
}
}
}
void Rfft2dImpl(int fft_height, int fft_width, double** fft_input_output,
int* fft_integer_working_area_data,
double* fft_double_working_area_data) {
gemmlowp::ScopedProfilingLabel label("Rfft2dImpl");
// Working data areas for the FFT routines.
double* fft_dynamic_working_area = nullptr;
const int kForwardFft = 1;
rdft2d(fft_height, fft_width, kForwardFft, fft_input_output,
fft_dynamic_working_area, fft_integer_working_area_data,
fft_double_working_area_data);
Rfft2dReorder(fft_height, fft_width, fft_input_output);
}
void PrepareInputBuffer(const float* input_data, int input_height,
int input_width, int fft_height, int fft_width,
double** fft_input_output) {
int valid_input_height = std::min(input_height, fft_height);
int valid_input_width = std::min(input_width, fft_width);
for (int i = 0; i < valid_input_height; ++i) {
int in_pos = i * input_width;
for (int j = 0; j < valid_input_width; ++j) {
fft_input_output[i][j] = input_data[in_pos++];
}
// Zero-pad the rest of the input buffer
for (int j = valid_input_width; j < fft_width + 2; ++j) {
fft_input_output[i][j] = 0;
}
}
// Zero-pad the input buffer if fft_height is greater than valid_input_height.
for (int i = valid_input_height; i < fft_height; ++i) {
for (int j = 0; j < fft_width + 2; ++j) {
fft_input_output[i][j] = 0;
}
}
}
void PrepareOutputBuffer(complex<float>* output_data, int fft_height,
int fft_width, double** fft_input_output) {
int cnt = 0;
for (int i = 0; i < fft_height; ++i) {
for (int j = 0; j < fft_width / 2 + 1; ++j) {
output_data[cnt++] = complex<float>(fft_input_output[i][j * 2],
fft_input_output[i][j * 2 + 1]);
}
}
}
TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const float* input_data = GetTensorData<float>(input);
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
complex<float>* output_data = GetTensorData<complex<float>>(output);
int fft_height, fft_width;
fft_height = fft_length_data[0];
fft_width = fft_length_data[1];
// The FFT is computed for every slice over the innermost 2 dimensions.
// Count the number of slices in the input tensor.
const RuntimeShape input_shape = GetTensorShape(input);
const int input_dims_count = input_shape.DimensionsCount();
const auto* input_dims_data = input_shape.DimsData();
int num_slices = 1;
for (int i = 0; i < input_dims_count - 2; ++i) {
num_slices *= input_dims_data[i];
}
int input_height = input_dims_data[input_dims_count - 2];
int input_width = input_dims_data[input_dims_count - 1];
int input_slice_size = input_height * input_width;
int output_slice_size = fft_height * (fft_width / 2 + 1);
// Create input/output buffer for FFT
double** fft_input_output = new double*[fft_height];
for (int i = 0; i < fft_height; ++i) {
fft_input_output[i] = new double[fft_width + 2];
}
// Get buffer for integer working area.
TfLiteTensor* fft_integer_working_area =
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
int* fft_integer_working_area_data =
GetTensorData<int>(fft_integer_working_area);
// Get buffer for double working area.
TfLiteTensor* fft_double_working_area =
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
// Reinterpret the memory of fft_double_working_area as a double buffer.
double* fft_double_working_area_data = reinterpret_cast<double*>(
GetTensorData<int64_t>(fft_double_working_area));
// Process every slice in the input buffer.
for (int i = 0; i < num_slices; ++i) {
PrepareInputBuffer(input_data, input_height, input_width, fft_height,
fft_width, fft_input_output);
memset(fft_integer_working_area_data, 0, fft_integer_working_area->bytes);
memset(fft_double_working_area_data, 0, fft_double_working_area->bytes);
Rfft2dImpl(fft_height, fft_width, fft_input_output,
fft_integer_working_area_data, fft_double_working_area_data);
PrepareOutputBuffer(output_data, fft_height, fft_width, fft_input_output);
input_data += input_slice_size;
output_data += output_slice_size;
}
// Delete the FFT input/output buffer.
for (int i = 0; i < fft_height; ++i) {
delete[] fft_input_output[i];
}
delete[] fft_input_output;
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type != kTfLiteComplex64) {
context->ReportError(context,
"Type '%s' for output is not supported by rfft2d.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
// Resize the output tensor if the fft_length tensor is not constant.
// Otherwise, check if the output shape is correct.
if (!IsConstantTensor(fft_length)) {
TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node));
} else {
int num_dims_output = NumDimensions(output);
const RuntimeShape output_shape = GetTensorShape(output);
TF_LITE_ENSURE_EQ(context, num_dims_output, NumDimensions(input));
TF_LITE_ENSURE(context, num_dims_output >= 2);
TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 2),
fft_length_data[0]);
TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 1),
fft_length_data[1] / 2 + 1);
}
return Rfft2dHelper(context, node);
}
} // namespace rfft2d
TfLiteRegistration* Register_RFFT2D() {
static TfLiteRegistration r = {rfft2d::Init, rfft2d::Free, rfft2d::Prepare,
rfft2d::Eval};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite
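For reference, the working areas above are sized exactly the way Ooura's fft2d package documents them: the integer area holds 2 + sqrt(max(fft_height, fft_width / 2)) ints, the double area holds max(fft_height, fft_width / 2) / 2 + fft_width / 4 doubles, and the in-place buffer is fft_height rows of fft_width + 2 doubles (the two extra slots per row are only touched by Rfft2dReorder, not by rdft2d itself). Below is a minimal standalone sketch of the same call sequence outside TF Lite, assuming the target links against @fft2d; the dynamic working area is passed as nullptr, as the kernel above does:

#include <algorithm>
#include <cmath>
#include <vector>
#include "third_party/fft2d/fft2d.h"

// Sketch: forward real 2D FFT of an h x w buffer, mirroring Rfft2dHelper.
// Each row in `rows` must hold w + 2 doubles, zero-padded as in
// PrepareInputBuffer; the transform is computed in place.
void ForwardRfft2d(int h, int w, std::vector<std::vector<double>>& rows) {
  // Row pointers for rdft2d's double** interface.
  std::vector<double*> a(h);
  for (int i = 0; i < h; ++i) a[i] = rows[i].data();

  // Working areas sized as in ResizeOutputandTemporaryTensors; ip[0] == 0
  // tells rdft2d to (re)build its internal tables on first use.
  int working_length = std::max(h, w / 2);
  std::vector<int> ip(2 + static_cast<int>(std::sqrt(working_length)), 0);
  std::vector<double> wk(working_length / 2 + w / 4, 0.0);

  const int kForwardFft = 1;
  rdft2d(h, w, kForwardFft, a.data(), /*t=*/nullptr, ip.data(), wk.data());
}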


@ -0,0 +1,154 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <initializer_list>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_RFFT2D();
namespace {
using std::complex;
using ::testing::ElementsAreArray;
class Rfft2dOpModel : public SingleOpModel {
public:
Rfft2dOpModel(const TensorData& input, const TensorData& fft_lengths) {
input_ = AddInput(input);
fft_lengths_ = AddInput(fft_lengths);
TensorType output_type = TensorType_COMPLEX64;
output_ = AddOutput({output_type, {}});
const std::vector<uint8_t> custom_option;
SetCustomOp("Rfft2d", custom_option, Register_RFFT2D);
BuildInterpreter({GetShape(input_)});
}
int input() { return input_; }
int fft_lengths() { return fft_lengths_; }
std::vector<complex<float>> GetOutput() {
return ExtractVector<complex<float>>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
int input_;
int fft_lengths_;
int output_;
};
TEST(Rfft2dOpTest, FftLengthMatchesInputSize) {
Rfft2dOpModel model({TensorType_FLOAT32, {4, 4}}, {TensorType_INT32, {2}});
// clang-format off
model.PopulateTensor<float>(model.input(),
{1, 2, 3, 4,
3, 8, 6, 3,
5, 2, 7, 6,
9, 5, 8, 3});
// clang-format on
model.PopulateTensor<int32_t>(model.fft_lengths(), {4, 4});
model.Invoke();
std::complex<float> expected_result[12] = {
{75, 0}, {-6, -1}, {9, 0}, {-10, 5}, {-3, 2}, {-6, 11},
{-15, 0}, {-2, 13}, {-5, 0}, {-10, -5}, {3, -6}, {-6, -11}};
EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_result));
}
TEST(Rfft2dOpTest, FftLengthSmallerThanInputSize) {
Rfft2dOpModel model({TensorType_FLOAT32, {4, 5}}, {TensorType_INT32, {2}});
// clang-format off
model.PopulateTensor<float>(model.input(),
{1, 2, 3, 4, 0,
3, 8, 6, 3, 0,
5, 2, 7, 6, 0,
9, 5, 8, 3, 0});
// clang-format on
model.PopulateTensor<int32_t>(model.fft_lengths(), {4, 4});
model.Invoke();
std::complex<float> expected_result[12] = {
{75, 0}, {-6, -1}, {9, 0}, {-10, 5}, {-3, 2}, {-6, 11},
{-15, 0}, {-2, 13}, {-5, 0}, {-10, -5}, {3, -6}, {-6, -11}};
EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_result));
}
TEST(Rfft2dOpTest, FftLengthGreaterThanInputSize) {
Rfft2dOpModel model({TensorType_FLOAT32, {3, 4}}, {TensorType_INT32, {2}});
// clang-format off
model.PopulateTensor<float>(model.input(),
{1, 2, 3, 4,
3, 8, 6, 3,
5, 2, 7, 6});
// clang-format on
model.PopulateTensor<int32_t>(model.fft_lengths(), {4, 8});
model.Invoke();
// clang-format off
std::complex<float> expected_result[20] = {
{50, 0}, {8.29289341, -33.6776695}, {-7, 1}, {9.70710659, -1.67766953},
{0, 0},
{-10, -20}, {-16.3639603, -1.12132037}, {-5, 1}, {-7.19238806, -2.05025244},
{-6, 2},
{10, 0}, {-4.7781744, -6.12132025}, {-1, 11}, {10.7781744, 1.87867963},
{4, 0},
{-10, 20}, {11.1923885, 11.9497471}, {5, -5}, {-3.63603902, -3.12132025},
{-6, -2}};
// clang-format on
EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_result));
}
TEST(Rfft2dOpTest, InputDimsGreaterThan2) {
Rfft2dOpModel model({TensorType_FLOAT32, {2, 2, 4}}, {TensorType_INT32, {2}});
// clang-format off
model.PopulateTensor<float>(model.input(),
{1., 2., 3., 4.,
3., 8., 6., 3.,
5., 2., 7., 6.,
7., 3., 23., 5.});
// clang-format on
model.PopulateTensor<int32_t>(model.fft_lengths(), {2, 4});
model.Invoke();
// clang-format off
std::complex<float> expected_result[12] = {
{30., 0.}, {-5, -3.}, { -4., 0.},
{-10., 0.}, {1., 7.}, { 0., 0.},
{58., 0.}, {-18., 6.}, { 26., 0.},
{-18., 0.}, { 14., 2.}, {-18., 0.}};
// clang-format on
EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_result));
}
} // namespace
} // namespace custom
} // namespace ops
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}


@ -188,6 +188,7 @@ cc_library(
"//tensorflow/lite:string_util",
"//tensorflow/lite/delegates/flex:delegate",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:custom_ops",
"//tensorflow/lite/kernels:reference_ops",
"@com_google_absl//absl/strings",
],


@ -56,6 +56,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.ops import rnn
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import spectral_ops_test_util
RANDOM_SEED = 342
@ -5070,6 +5071,39 @@ def make_unfused_gru_tests(options):
build_inputs,
use_frozen_graph=True)
@register_make_test_function()
def make_rfft2d_tests(options):
"""Make a set of tests to do rfft2d."""
test_parameters = [{
"input_dtype": [tf.float32],
"input_shape": [[8, 8], [3, 8, 8]],
"fft_length": [
None, [4, 4], [4, 8], [8, 4], [8, 8], [8, 16], [16, 8], [16, 16]
]
}]
def build_graph(parameters):
input_value = tf.placeholder(
dtype=parameters["input_dtype"],
name="input",
shape=parameters["input_shape"])
with spectral_ops_test_util.fft_kernel_label_map():
outs = tf.signal.rfft2d(input_value, fft_length=parameters["fft_length"])
return [input_value], [outs]
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(parameters["input_dtype"],
parameters["input_shape"])
return [input_value], sess.run(
outputs, feed_dict=dict(zip(inputs, [input_value])))
extra_toco_options = ExtraTocoOptions()
extra_toco_options.allow_custom_ops = True
make_zip_of_tests(options, test_parameters, build_graph, build_inputs,
extra_toco_options)
# Toco binary path provided by the generate rule.
bin_path = None


@ -15,10 +15,14 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_TESTING_SPLIT_H_
#define TENSORFLOW_LITE_TESTING_SPLIT_H_
#include <algorithm>
#include <complex>
#include <cstdlib>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/lite/string.h"
namespace tflite {
@ -99,6 +103,26 @@ inline std::vector<bool> Split(const string& s, const string& delimiter) {
return fields;
}
template <>
inline std::vector<std::complex<float>> Split(const string& s,
const string& delimiter) {
std::vector<std::complex<float>> fields;
for (const auto& p : SplitToPos(s, delimiter)) {
std::string sc = s.substr(p.first, p.second - p.first);
std::string::size_type sz_real, sz_img;
float real = std::stof(sc, &sz_real);
float img = std::stof(sc.substr(sz_real), &sz_img);
if (sz_real + sz_img + 1 != sc.length()) {
std::cerr << "There were errors in parsing string, " << sc
<< ", to complex value." << std::endl;
return fields;
}
std::complex<float> c(real, img);
fields.push_back(c);
}
return fields;
}
} // namespace testing
} // namespace tflite
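The complex specialization above expects each delimited field to hold a real part, then an imaginary part parsable by std::stof, then exactly one trailing character (e.g. an 'i' suffix). The exact textual format produced by the test generator is not shown in this diff, so the snippet below only illustrates a string the parser accepts, assuming that "<re><im>i" form:

#include <complex>
#include <string>
#include <vector>
#include "tensorflow/lite/testing/split.h"

// Hypothetical helper: parse a comma-separated list such as "75+0i,-6-1i".
std::vector<std::complex<float>> ParseComplexCsv(const std::string& csv) {
  return tflite::testing::Split<std::complex<float>>(csv, ",");
}
// ParseComplexCsv("75+0i,-6-1i") yields {75, 0} and {-6, -1}.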


@ -14,9 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/testing/tflite_driver.h"
#include <complex>
#include "absl/strings/escaping.h"
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register_ref.h"
#include "tensorflow/lite/string_util.h"
@ -58,6 +61,11 @@ bool Value(const TfLitePtrUnion& data, int index) {
return data.b[index];
}
template <>
std::complex<float> Value(const TfLitePtrUnion& data, int index) {
return std::complex<float>(data.c64[index].re, data.c64[index].im);
}
template <typename T>
void SetTensorData(const std::vector<T>& values, TfLitePtrUnion* data) {
T* input_ptr = reinterpret_cast<T*>(data->raw);
@ -123,7 +131,29 @@ class TfLiteDriver::Expectation {
}
private:
template <typename T>
bool CompareTwoValuesHelper(float v1, float v2) {
float diff = std::abs(v1 - v2);
bool error_is_large = false;
// For very small numbers, try absolute error, otherwise go with
// relative.
if (std::abs(v2) < relative_threshold_) {
error_is_large = (diff > absolute_threshold_);
} else {
error_is_large = (diff > relative_threshold_ * std::abs(v2));
}
return error_is_large;
}
bool CompareTwoValues(std::complex<float> v1, std::complex<float> v2) {
return CompareTwoValues(v1.real(), v2.real()) ||
CompareTwoValues(v1.imag(), v2.imag());
}
bool CompareTwoValues(float v1, float v2) {
return CompareTwoValuesHelper(v1, v2);
}
template <typename T, typename TS>
bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
size_t tensor_size = tensor.bytes / sizeof(T);
@ -136,18 +166,9 @@ class TfLiteDriver::Expectation {
bool good_output = true;
for (int i = 0; i < tensor_size; ++i) {
float computed = Value<T>(tensor.data, i);
float reference = Value<T>(data_, i);
float diff = std::abs(computed - reference);
bool error_is_large = false;
// For very small numbers, try absolute error, otherwise go with
// relative.
if (std::abs(reference) < relative_threshold_) {
error_is_large = (diff > absolute_threshold_);
} else {
error_is_large = (diff > relative_threshold_ * std::abs(reference));
}
if (error_is_large) {
TS computed = Value<T>(tensor.data, i);
TS reference = Value<T>(data_, i);
if (CompareTwoValues(computed, reference)) {
good_output = false;
if (verbose) {
std::cerr << " index " << i << ": got " << computed
@ -158,6 +179,8 @@ class TfLiteDriver::Expectation {
return good_output;
}
bool TypedCheckString(bool verbose, const TfLiteTensor& tensor);
TfLitePtrUnion data_;
size_t num_elements_;
double relative_threshold_;
@ -171,9 +194,8 @@ void TfLiteDriver::Expectation::SetData<string>(const string& csv_values) {
memcpy(data_.raw, s.data(), s.size());
}
template <>
bool TfLiteDriver::Expectation::TypedCheck<string>(bool verbose,
const TfLiteTensor& tensor) {
bool TfLiteDriver::Expectation::TypedCheckString(bool verbose,
const TfLiteTensor& tensor) {
if (tensor.data.raw == nullptr) {
if (verbose) {
std::cerr << " got empty string" << std::endl;
@ -215,19 +237,22 @@ bool TfLiteDriver::Expectation::Check(bool verbose,
const TfLiteTensor& tensor) {
switch (tensor.type) {
case kTfLiteFloat32:
return TypedCheck<float>(verbose, tensor);
return TypedCheck<float, float>(verbose, tensor);
case kTfLiteInt32:
return TypedCheck<int32_t>(verbose, tensor);
return TypedCheck<int32_t, float>(verbose, tensor);
case kTfLiteInt64:
return TypedCheck<int64_t>(verbose, tensor);
return TypedCheck<int64_t, float>(verbose, tensor);
case kTfLiteUInt8:
return TypedCheck<uint8_t>(verbose, tensor);
return TypedCheck<uint8_t, float>(verbose, tensor);
case kTfLiteInt8:
return TypedCheck<int8_t>(verbose, tensor);
return TypedCheck<int8_t, float>(verbose, tensor);
case kTfLiteBool:
return TypedCheck<bool>(verbose, tensor);
return TypedCheck<bool, float>(verbose, tensor);
case kTfLiteString:
return TypedCheck<string>(verbose, tensor);
return TypedCheckString(verbose, tensor);
case kTfLiteComplex64:
return TypedCheck<std::complex<float>, std::complex<float>>(verbose,
tensor);
default:
fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
return false;
@ -243,6 +268,10 @@ TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name,
resolver_.reset(new ops::builtin::BuiltinRefOpResolver);
} else {
resolver_.reset(new ops::builtin::BuiltinOpResolver);
ops::builtin::BuiltinOpResolver* builtin_op_resolver =
reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
builtin_op_resolver->AddCustom("RFFT2D",
tflite::ops::custom::Register_RFFT2D());
}
if (delegate_name == "FLEX") {
@ -402,6 +431,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
case kTfLiteString:
expected_output_[id]->SetData<string>(csv_values);
break;
case kTfLiteComplex64:
expected_output_[id]->SetData<std::complex<float>>(csv_values);
break;
default:
Invalidate(absl::StrCat("Unsupported tensor type ",
TfLiteTypeGetName(tensor->type),


@ -36,7 +36,7 @@ ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_
NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip"
FARMHASH_URL="http://mirror.tensorflow.org/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz"
FLATBUFFERS_URL="http://mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.11.0.tar.gz"
FFT2D_URL="http://mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz"
FFT2D_URL="http://mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz"
# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,
# so work around it by patching the source.


@ -181,6 +181,11 @@ tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/icu/udata.patch
tensorflow/third_party/fft2d/fft2d.h
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/system.BUILD.tpl
@ -188,10 +193,6 @@ tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/archive.patch
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/boringssl/BUILD
tensorflow/third_party/mpi/.gitignore
tensorflow/third_party/mpi/BUILD


@ -148,7 +148,7 @@ genrule(
"@double_conversion//:LICENSE",
"@eigen_archive//:COPYING.MPL2",
"@farmhash_archive//:COPYING",
"@fft2d//:fft/readme.txt",
"@fft2d//:fft2d/readme2d.txt",
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@highwayhash//:LICENSE",
@ -219,7 +219,7 @@ genrule(
"@double_conversion//:LICENSE",
"@eigen_archive//:COPYING.MPL2",
"@farmhash_archive//:COPYING",
"@fft2d//:fft/readme.txt",
"@fft2d//:fft2d/readme2d.txt",
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@highwayhash//:LICENSE",


@ -169,7 +169,7 @@ filegroup(
"@eigen_archive//:COPYING.MPL2",
"@enum34_archive//:LICENSE",
"@farmhash_archive//:COPYING",
"@fft2d//:fft/readme.txt",
"@fft2d//:fft2d/readme2d.txt",
"@flatbuffers//:LICENSE.txt",
"@gast_archive//:PKG-INFO",
"@gemmlowp//:LICENSE",


@ -580,10 +580,10 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "fft2d",
build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"),
sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296",
sha256 = "ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9",
urls = [
"http://mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
"http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
"http://mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz",
"http://www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz",
],
)


@ -1,5 +1,5 @@
# Headers for 2D Fast Fourier Transform package
# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html
# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html
# This is a separate package because the original downloaded archive doesn't
# contain any header files.
@ -15,18 +15,27 @@ exports_files(["LICENSE"])
cc_library(
name = "fft2d_headers",
srcs = ["fft.h"],
srcs = [
"fft.h",
"fft2d.h",
],
)
objc_library(
name = "fft2d_headersd_ios",
srcs = ["fft.h"],
srcs = [
"fft.h",
"fft2d.h",
],
)
# Export the source code so that it can be compiled for Android native apps.
filegroup(
name = "fft2d_headers_srcs",
srcs = ["fft.h"],
srcs = [
"fft.h",
"fft2d.h",
],
)
filegroup(


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Declarations for 1D FFT routines in third_party/fft2d/fft.
// Declarations for 1D FFT routines in third_party/fft2d/fft2d.
#ifndef FFT2D_FFT_H__
#define FFT2D_FFT_H__


@ -1,5 +1,5 @@
# 2D Fast Fourier Transform package
# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html
# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html
package(
default_visibility = ["//visibility:public"],
@ -8,10 +8,11 @@ package(
# Unrestricted use; can only distribute original package.
licenses(["notice"])
exports_files(["fft/readme.txt"])
exports_files(["fft2d/readme2d.txt"])
FFT2D_SRCS = [
"fft/fftsg.c",
"fft2d/fftsg.c",
"fft2d/fftsg2d.c",
]
config_setting(
@ -22,7 +23,7 @@ config_setting(
# This is the main 2D FFT library. The 2D FFTs in this library call
# 1D FFTs. In addition, fast DCTs are provided for the special case
# of 8x8 and 16x16. The code in this library is referred to as
# "Version II" on http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html.
# "Version II" on http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html.
cc_library(
name = "fft2d",
srcs = FFT2D_SRCS,

third_party/fft2d/fft2d.h (new vendored file, 36 lines)

@ -0,0 +1,36 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Declarations for 2D FFT routines in third_party/fft2d/fft2d.
#ifndef FFT2D_FFT2D_H__
#define FFT2D_FFT2D_H__
#ifdef __cplusplus
extern "C" {
#endif
extern void cdft2d(int, int, int, double **, double *, int *, double *);
extern void rdft2d(int, int, int, double **, double *, int *, double *);
extern void ddct2d(int, int, int, double **, double *, int *, double *);
extern void ddst2d(int, int, int, double **, double *, int *, double *);
extern void ddct8x8s(int isgn, double **a);
extern void ddct16x16s(int isgn, double **a);
#ifdef __cplusplus
}
#endif
#endif // FFT2D_FFT2D_H__