From 03195f13456354deea8b81c9e583621b1337b952 Mon Sep 17 00:00:00 2001 From: Lu Wang Date: Wed, 12 Jun 2019 12:18:09 -0700 Subject: [PATCH] Implement Rfft2d as a custom ops. PiperOrigin-RevId: 252874759 --- .../contrib/makefile/download_dependencies.sh | 2 +- tensorflow/lite/build_def.bzl | 1 + tensorflow/lite/kernels/BUILD | 30 ++ tensorflow/lite/kernels/custom_ops_register.h | 30 ++ tensorflow/lite/kernels/rfft2d.cc | 426 ++++++++++++++++++ tensorflow/lite/kernels/rfft2d_test.cc | 154 +++++++ tensorflow/lite/testing/BUILD | 1 + .../lite/testing/generate_examples_lib.py | 34 ++ tensorflow/lite/testing/split.h | 24 + tensorflow/lite/testing/tflite_driver.cc | 78 +++- .../lite/tools/make/download_dependencies.sh | 2 +- tensorflow/opensource_only.files | 9 +- tensorflow/tools/lib_package/BUILD | 4 +- tensorflow/tools/pip_package/BUILD | 2 +- tensorflow/workspace.bzl | 6 +- third_party/fft2d/BUILD | 17 +- third_party/fft2d/fft.h | 2 +- third_party/fft2d/fft2d.BUILD | 9 +- third_party/fft2d/fft2d.h | 36 ++ 19 files changed, 823 insertions(+), 44 deletions(-) create mode 100644 tensorflow/lite/kernels/custom_ops_register.h create mode 100644 tensorflow/lite/kernels/rfft2d.cc create mode 100644 tensorflow/lite/kernels/rfft2d_test.cc create mode 100644 third_party/fft2d/fft2d.h diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index c41513a9096..1feca44f6e5 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -40,7 +40,7 @@ readonly PROTOBUF_TAG="$(grep -o 'https://github.com/protocolbuffers/protobuf/ar # TODO (yongtang): Replace the following with 'http://mirror.tensorflow.org/github.com/google/re2/.*tar\.gz' once # the archive has been propagated in mirror.tensorflow.org. RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" +FFT2D_URL="$(grep -o 'http.*fft2d\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" DOUBLE_CONVERSION_URL="$(grep -o "https.*google/double-conversion.*\.zip" "${BZL_FILE_PATH}" | head -n1)" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" CUB_URL="$(grep -o 'https.*cub/archive.*zip' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 15d7d99f074..128b6d33c70 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -315,6 +315,7 @@ def generated_test_models(): "resolve_constant_strided_slice", "reverse_sequence", "reverse_v2", + "rfft2d", "round", "rsqrt", "shape", diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index ea423a15970..de41561b91c 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -436,6 +436,23 @@ cc_library( ], ) +cc_library( + name = "custom_ops", + srcs = ["rfft2d.cc"], + hdrs = ["custom_ops_register.h"], + deps = [ + ":kernel_util", + ":op_macros", + "//tensorflow/lite:context", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels/internal:kernel_utils", + "//tensorflow/lite/kernels/internal:tensor", + "//third_party/fft2d:fft2d_headers", + "@fft2d", + "@gemmlowp//:profiler", + ], +) + cc_library( name = "lstm_eval", srcs = ["lstm_eval.cc"], @@ -1619,6 +1636,19 @@ cc_test( ], ) +cc_test( + name = "rfft2d_test", + size = "small", + srcs = ["rfft2d_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/lite/kernels/custom_ops_register.h b/tensorflow/lite/kernels/custom_ops_register.h new file mode 100644 index 00000000000..31d62d66c0d --- /dev/null +++ b/tensorflow/lite/kernels/custom_ops_register.h @@ -0,0 +1,30 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_CUSTOM_OPS_REGISTER_H_ +#define TENSORFLOW_LITE_KERNELS_CUSTOM_OPS_REGISTER_H_ + +#include "tensorflow/lite/context.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_RFFT2D(); + +} +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_CUSTOM_OPS_REGISTER_H_ diff --git a/tensorflow/lite/kernels/rfft2d.cc b/tensorflow/lite/kernels/rfft2d.cc new file mode 100644 index 00000000000..f3b5b0bf696 --- /dev/null +++ b/tensorflow/lite/kernels/rfft2d.cc @@ -0,0 +1,426 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/fft2d/fft2d.h" +#include "profiling/instrumentation.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace rfft2d { + +using std::complex; + +constexpr int kInputTensor = 0; +constexpr int kFftLengthTensor = 1; +constexpr int kOutputTensor = 0; +constexpr int kFftIntegerWorkingAreaTensor = 0; +constexpr int kFftDoubleWorkingAreaTensor = 1; +constexpr int kTensorNotAllocated = -1; + +struct OpData { + // IDs are the arbitrary identifiers used by TF Lite to identify and access + // memory buffers. + int fft_integer_working_area_id = kTensorNotAllocated; + int fft_double_working_area_id = kTensorNotAllocated; +}; + +bool IsPowerOfTwo(uint32_t v) { return v && !(v & (v - 1)); } + +static TfLiteStatus InitTemporaryTensors(TfLiteContext* context, + TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + // The prepare function may be executed multiple times. But temporary tensors + // only need to be initiated once. + if (data->fft_integer_working_area_id != kTensorNotAllocated && + data->fft_double_working_area_id != kTensorNotAllocated) { + return kTfLiteOk; + } + + TfLiteIntArrayFree(node->temporaries); + // Create two temporary tensors. + node->temporaries = TfLiteIntArrayCreate(2); + int first_new_index; + TF_LITE_ENSURE_STATUS(context->AddTensors(context, 2, &first_new_index)); + node->temporaries->data[kFftIntegerWorkingAreaTensor] = first_new_index; + data->fft_integer_working_area_id = first_new_index; + node->temporaries->data[kFftDoubleWorkingAreaTensor] = first_new_index + 1; + data->fft_double_working_area_id = first_new_index + 1; + + // Set up FFT integer working area buffer. + TfLiteTensor* fft_integer_working_area = + GetTemporary(context, node, kFftIntegerWorkingAreaTensor); + fft_integer_working_area->type = kTfLiteInt32; + // If fft_length is not a constant tensor, fft_integer_working_area will be + // set to dynamic later in Prepare. + fft_integer_working_area->allocation_type = kTfLiteArenaRw; + + // Set up FFT double working area buffer. + TfLiteTensor* fft_double_working_area = + GetTemporary(context, node, kFftDoubleWorkingAreaTensor); + // fft_double_working_area is a double tensor. Ideally, double should be + // added into tflite data types. However, since fft_double_working_area is a + // temporary tensor, and there are no ops having double input/output tensors + // in tflite at this point, adding double as a tflite data type may confuse + // users that double is supported. As a results, kTfLiteInt64 is used here + // for memory allocation. And it will be cast into double in Eval when being + // used. + fft_double_working_area->type = kTfLiteInt64; + // If fft_length is not a constant tensor, fft_double_working_area will be + // set to dynamic later in Prepare. + fft_double_working_area->allocation_type = kTfLiteArenaRw; + + return kTfLiteOk; +} + +TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context, + TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const int num_dims = NumDimensions(input); + TF_LITE_ENSURE(context, num_dims >= 2); + const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor); + const int32_t* fft_length_data = GetTensorData(fft_length); + // The lib, fft2d, can only handle fft_lengths of power of 2. + TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[0])); + TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[1])); + + int fft_height, fft_width; + fft_height = fft_length_data[0]; + fft_width = fft_length_data[1]; + int fft_working_length = std::max(fft_height, fft_width / 2); + int half_fft_working_length = fft_working_length / 2; + + // Resize output tensor. + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims); + output_shape->data[num_dims - 2] = fft_length_data[0]; + output_shape->data[num_dims - 1] = fft_length_data[1] / 2 + 1; + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape)); + + // Resize temporary tensors, fft_integer_working_area. + TfLiteTensor* fft_integer_working_area = + GetTemporary(context, node, kFftIntegerWorkingAreaTensor); + TfLiteIntArray* fft_integer_working_area_shape = TfLiteIntArrayCreate(1); + fft_integer_working_area_shape->data[0] = + 2 + static_cast(sqrt(fft_working_length)); + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_integer_working_area, + fft_integer_working_area_shape)); + + // Resize temporary tensors, fft_double_working_area. + TfLiteTensor* fft_double_working_area = + GetTemporary(context, node, kFftDoubleWorkingAreaTensor); + TfLiteIntArray* fft_double_working_area_shape = TfLiteIntArrayCreate(1); + fft_double_working_area_shape->data[0] = + half_fft_working_length + fft_width / 4; + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_double_working_area, + fft_double_working_area_shape)); + + return kTfLiteOk; +} + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Check type and shape of the input tensor + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE(context, NumDimensions(input) >= 2); + if (input->type != kTfLiteFloat32) { + context->ReportError(context, + "Type '%s' for input is not supported by rfft2d.", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + + // Check type and shape of the fft_length tensor + const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor); + const RuntimeShape fft_length_shape = GetTensorShape(fft_length); + + TF_LITE_ENSURE_EQ(context, NumDimensions(fft_length), 1); + TF_LITE_ENSURE_EQ(context, fft_length_shape.Dims(0), 2); + if (fft_length->type != kTfLiteInt32) { + context->ReportError(context, + "Type '%s' for fft_length is not supported by rfft2d.", + TfLiteTypeGetName(fft_length->type)); + return kTfLiteError; + } + + // Setup temporary tensors for fft computation. + TF_LITE_ENSURE_STATUS(InitTemporaryTensors(context, node)); + + // Set output type + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + output->type = kTfLiteComplex64; + + // Exit early if fft_length is a non-const tensor. Set output tensor and + // temporary tensors to dynamic, so that their tensor sizes can be determined + // in Eval. + if (!IsConstantTensor(fft_length)) { + TfLiteTensor* fft_integer_working_area = + GetTemporary(context, node, kFftIntegerWorkingAreaTensor); + TfLiteTensor* fft_double_working_area = + GetTemporary(context, node, kFftDoubleWorkingAreaTensor); + SetTensorToDynamic(fft_integer_working_area); + SetTensorToDynamic(fft_double_working_area); + SetTensorToDynamic(output); + return kTfLiteOk; + } + + TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node)); + return kTfLiteOk; +} + +// Reorder the result so that it matches the pattern of tf.signal.rfft2d. +// In tf.signal.fft2d the frequency matrix of a 4x4 input is +// [[F(0, 0), F(0, 1/4), F(0, 2/4)], +// [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)], +// [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)], +// [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]] +// While in rdft2d, the frequency matrix of a 4x4 input is +// [[(F(0, 0), F(0, -2/4)) F(0, -1/4), 0], +// [ F(-1/4, 0), F(-1/4, -1/4), 0], +// [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0], +// [ j*F(-3/4, -2/4), F(-3/4, -1/4), 0]] +// Since real fft has the property that +// Real(u,v) = Real(-u, -v) +// Img(u,v) = - Img(-u, -v) +// Result of rdft2d can be reordered and match the pattern of tf.signal.rfft2d. +// For example, +// Real(-3/4, 0) = Real(1/4, 0) = Real(-1/4, 0) +// Img(-3/4, 0) = Img(1/4, 0) = -Img(-1/4, 0) +void Rfft2dReorder(int fft_height, int fft_width, double** fft_input_output) { + int fft_height_half; + gemmlowp::ScopedProfilingLabel label("Rfft2dReorder"); + double real, img; + + fft_height_half = fft_height >> 1; + // Use 4x4 input as an example, reorder the frequency matrix from + // [[(F(0, 0), F(0, -2/4)) F(0, -1/4), 0], + // [ F(-1/4, 0), F(-1/4, -1/4), 0], + // [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0], + // [ j*F(-3/4, -2/4), F(-3/4, -1/4), 0]] + // to + // [[F(0, 0), F(0, -1/4), F(0, -2/4)], + // [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)], + // [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)], + // [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]] + for (int i = fft_height_half + 1; i < fft_height; ++i) { + real = fft_input_output[i][0]; + img = fft_input_output[i][1]; + fft_input_output[i][fft_width] = img; + fft_input_output[i][fft_width + 1] = real; + fft_input_output[fft_height - i][fft_width] = img; + fft_input_output[fft_height - i][fft_width + 1] = -real; + fft_input_output[i][0] = fft_input_output[fft_height - i][0]; + fft_input_output[i][1] = -fft_input_output[fft_height - i][1]; + } + fft_input_output[0][fft_width] = fft_input_output[0][1]; + fft_input_output[0][fft_width + 1] = 0; + fft_input_output[0][1] = 0; + fft_input_output[fft_height_half][fft_width] = + fft_input_output[fft_height_half][1]; + fft_input_output[fft_height_half][fft_width + 1] = 0; + fft_input_output[fft_height_half][1] = 0; + + // Reorder the frequency matrix from + // [[F(0, 0), F(0, -1/4), F(0, -2/4)], + // [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)], + // [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)], + // [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]] + // to + // [[F(0, 0), F(0, 1/4), F(0, 2/4)], + // [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)], + // [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)], + // [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]] + for (int i = 0; i < fft_height; ++i) { + for (int j = 1; j < fft_width + 2; j += 2) { + fft_input_output[i][j] = -fft_input_output[i][j]; + } + } +} + +void Rfft2dImpl(int fft_height, int fft_width, double** fft_input_output, + int* fft_integer_working_area_data, + double* fft_double_working_area_data) { + gemmlowp::ScopedProfilingLabel label("Rfft2dImpl"); + + // Working data areas for the FFT routines. + double* fft_dynamic_working_area = nullptr; + const int kForwardFft = 1; + rdft2d(fft_height, fft_width, kForwardFft, fft_input_output, + fft_dynamic_working_area, fft_integer_working_area_data, + fft_double_working_area_data); + Rfft2dReorder(fft_height, fft_width, fft_input_output); +} + +void PrepareInputBuffer(const float* input_data, int input_height, + int input_width, int fft_height, int fft_width, + double** fft_input_output) { + int valid_input_height = std::min(input_height, fft_height); + int valid_input_width = std::min(input_width, fft_width); + for (int i = 0; i < valid_input_height; ++i) { + int in_pos = i * input_width; + for (int j = 0; j < valid_input_width; ++j) { + fft_input_output[i][j] = input_data[in_pos++]; + } + // Zero-pad the rest of the input buffer + for (int j = valid_input_width; j < fft_width + 2; ++j) { + fft_input_output[i][j] = 0; + } + } + + // Zero-pad input buffer, if fft_height is greater than valid_input_height. + for (int i = valid_input_height; i < fft_height; ++i) { + for (int j = 0; j < fft_width + 2; ++j) { + fft_input_output[i][j] = 0; + } + } +} + +void PrepareOutputBuffer(complex* output_data, int fft_height, + int fft_width, double** fft_input_output) { + int cnt = 0; + for (int i = 0; i < fft_height; ++i) { + for (int j = 0; j < fft_width / 2 + 1; ++j) { + output_data[cnt++] = complex(fft_input_output[i][j * 2], + fft_input_output[i][j * 2 + 1]); + } + } +} + +TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const float* input_data = GetTensorData(input); + const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor); + const int32_t* fft_length_data = GetTensorData(fft_length); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + complex* output_data = GetTensorData>(output); + + int fft_height, fft_width; + fft_height = fft_length_data[0]; + fft_width = fft_length_data[1]; + + // FFT is processed for every slice on the inner most 2 dimensions. + // Count the number of slices in the input tensor. + const RuntimeShape input_shape = GetTensorShape(input); + const int input_dims_count = input_shape.DimensionsCount(); + const auto* input_dims_data = input_shape.DimsData(); + int num_slices = 1; + for (int i = 0; i < input_dims_count - 2; ++i) { + num_slices *= input_dims_data[i]; + } + + int input_height = input_dims_data[input_dims_count - 2]; + int input_width = input_dims_data[input_dims_count - 1]; + int input_slice_size = input_height * input_width; + int output_slice_size = fft_height * (fft_width / 2 + 1); + + // Create input/output buffer for FFT + double** fft_input_output = new double*[fft_height]; + for (int i = 0; i < fft_height; ++i) { + fft_input_output[i] = new double[fft_width + 2]; + } + + // Get buffer for integer working area. + TfLiteTensor* fft_integer_working_area = + GetTemporary(context, node, kFftIntegerWorkingAreaTensor); + int* fft_integer_working_area_data = + GetTensorData(fft_integer_working_area); + + // Get buffer for double working area. + TfLiteTensor* fft_double_working_area = + GetTemporary(context, node, kFftDoubleWorkingAreaTensor); + // Get double value out of the memory of fft_double_working_area_data. + double* fft_double_working_area_data = reinterpret_cast( + GetTensorData(fft_double_working_area)); + + // Process evert slice in the input buffer + for (int i = 0; i < num_slices; ++i) { + PrepareInputBuffer(input_data, input_height, input_width, fft_height, + fft_width, fft_input_output); + memset(fft_integer_working_area_data, 0, fft_integer_working_area->bytes); + memset(fft_double_working_area_data, 0, fft_double_working_area->bytes); + Rfft2dImpl(fft_height, fft_width, fft_input_output, + fft_integer_working_area_data, fft_double_working_area_data); + PrepareOutputBuffer(output_data, fft_height, fft_width, fft_input_output); + input_data += input_slice_size; + output_data += output_slice_size; + } + + // Delete the input buffer + for (int i = 0; i < fft_height; ++i) { + delete[] fft_input_output[i]; + } + delete[] fft_input_output; + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor); + const int32_t* fft_length_data = GetTensorData(fft_length); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type != kTfLiteComplex64) { + context->ReportError(context, + "Type '%s' for output is not supported by rfft2d.", + TfLiteTypeGetName(output->type)); + return kTfLiteError; + } + + // Resize the output tensor if the fft_length tensor is not constant. + // Otherwise, check if the output shape is correct. + if (!IsConstantTensor(fft_length)) { + TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node)); + } else { + int num_dims_output = NumDimensions(output); + const RuntimeShape output_shape = GetTensorShape(output); + TF_LITE_ENSURE_EQ(context, num_dims_output, NumDimensions(input)); + TF_LITE_ENSURE(context, num_dims_output >= 2); + TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 2), + fft_length_data[0]); + TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 1), + fft_length_data[1] / 2 + 1); + } + + return Rfft2dHelper(context, node); +} + +} // namespace rfft2d + +TfLiteRegistration* Register_RFFT2D() { + static TfLiteRegistration r = {rfft2d::Init, rfft2d::Free, rfft2d::Prepare, + rfft2d::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/rfft2d_test.cc b/tensorflow/lite/kernels/rfft2d_test.cc new file mode 100644 index 00000000000..d4b6a0a9d83 --- /dev/null +++ b/tensorflow/lite/kernels/rfft2d_test.cc @@ -0,0 +1,154 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/custom_ops_register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_RFFT2D(); + +namespace { + +using std::complex; +using ::testing::ElementsAreArray; + +class Rfft2dOpModel : public SingleOpModel { + public: + Rfft2dOpModel(const TensorData& input, const TensorData& fft_lengths) { + input_ = AddInput(input); + fft_lengths_ = AddInput(fft_lengths); + TensorType output_type = TensorType_COMPLEX64; + output_ = AddOutput({output_type, {}}); + + const std::vector custom_option; + SetCustomOp("Rfft2d", custom_option, Register_RFFT2D); + BuildInterpreter({GetShape(input_)}); + } + + int input() { return input_; } + int fft_lengths() { return fft_lengths_; } + + std::vector> GetOutput() { + return ExtractVector>(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int fft_lengths_; + int output_; +}; + +TEST(Rfft2dOpTest, FftLengthMatchesInputSize) { + Rfft2dOpModel model({TensorType_FLOAT32, {4, 4}}, {TensorType_INT32, {2}}); + // clang-format off + model.PopulateTensor(model.input(), + {1, 2, 3, 4, + 3, 8, 6, 3, + 5, 2, 7, 6, + 9, 5, 8, 3}); + // clang-format on + model.PopulateTensor(model.fft_lengths(), {4, 4}); + model.Invoke(); + + std::complex expected_result[12] = { + {75, 0}, {-6, -1}, {9, 0}, {-10, 5}, {-3, 2}, {-6, 11}, + {-15, 0}, {-2, 13}, {-5, 0}, {-10, -5}, {3, -6}, {-6, -11}}; + EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_result)); +} + +TEST(Rfft2dOpTest, FftLengthSmallerThanInputSize) { + Rfft2dOpModel model({TensorType_FLOAT32, {4, 5}}, {TensorType_INT32, {2}}); + // clang-format off + model.PopulateTensor(model.input(), + {1, 2, 3, 4, 0, + 3, 8, 6, 3, 0, + 5, 2, 7, 6, 0, + 9, 5, 8, 3, 0}); + // clang-format on + model.PopulateTensor(model.fft_lengths(), {4, 4}); + model.Invoke(); + + std::complex expected_result[12] = { + {75, 0}, {-6, -1}, {9, 0}, {-10, 5}, {-3, 2}, {-6, 11}, + {-15, 0}, {-2, 13}, {-5, 0}, {-10, -5}, {3, -6}, {-6, -11}}; + EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_result)); +} + +TEST(Rfft2dOpTest, FftLengthGreaterThanInputSize) { + Rfft2dOpModel model({TensorType_FLOAT32, {3, 4}}, {TensorType_INT32, {2}}); + // clang-format off + model.PopulateTensor(model.input(), + {1, 2, 3, 4, + 3, 8, 6, 3, + 5, 2, 7, 6}); + // clang-format on + model.PopulateTensor(model.fft_lengths(), {4, 8}); + model.Invoke(); + + // clang-format off + std::complex expected_result[20] = { + {50, 0}, {8.29289341, -33.6776695}, {-7, 1}, {9.70710659, -1.67766953}, + {0, 0}, + {-10, -20}, {-16.3639603, -1.12132037}, {-5, 1}, {-7.19238806, -2.05025244}, + {-6, 2}, + {10, 0}, {-4.7781744, -6.12132025}, {-1, 11}, {10.7781744, 1.87867963}, + {4, 0}, + {-10, 20}, {11.1923885, 11.9497471}, {5, -5}, {-3.63603902, -3.12132025}, + {-6, -2}}; + // clang-format on + EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_result)); +} + +TEST(Rfft2dOpTest, InputDimsGreaterThan2) { + Rfft2dOpModel model({TensorType_FLOAT32, {2, 2, 4}}, {TensorType_INT32, {2}}); + // clang-format off + model.PopulateTensor(model.input(), + {1., 2., 3., 4., + 3., 8., 6., 3., + 5., 2., 7., 6., + 7., 3., 23., 5.}); + // clang-format on + model.PopulateTensor(model.fft_lengths(), {2, 4}); + model.Invoke(); + + // clang-format off + std::complex expected_result[12] = { + {30., 0.}, {-5, -3.}, { -4., 0.}, + {-10., 0.}, {1., 7.}, { 0., 0.}, + {58., 0.}, {-18., 6.}, { 26., 0.}, + {-18., 0.}, { 14., 2.}, {-18., 0.}}; + // clang-format on + EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_result)); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD index 2608f28e48e..e2eb79d713d 100644 --- a/tensorflow/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -188,6 +188,7 @@ cc_library( "//tensorflow/lite:string_util", "//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:custom_ops", "//tensorflow/lite/kernels:reference_ops", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py index a0d51100a1a..2a4e07412b8 100644 --- a/tensorflow/lite/testing/generate_examples_lib.py +++ b/tensorflow/lite/testing/generate_examples_lib.py @@ -56,6 +56,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.ops import rnn from tensorflow.python.ops import array_ops +from tensorflow.python.ops import spectral_ops_test_util RANDOM_SEED = 342 @@ -5070,6 +5071,39 @@ def make_unfused_gru_tests(options): build_inputs, use_frozen_graph=True) + +@register_make_test_function() +def make_rfft2d_tests(options): + """Make a set of tests to do rfft2d.""" + + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape": [[8, 8], [3, 8, 8]], + "fft_length": [ + None, [4, 4], [4, 8], [8, 4], [8, 8], [8, 16], [16, 8], [16, 16] + ] + }] + + def build_graph(parameters): + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + with spectral_ops_test_util.fft_kernel_label_map(): + outs = tf.signal.rfft2d(input_value, fft_length=parameters["fft_length"]) + return [input_value], [outs] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + extra_toco_options = ExtraTocoOptions() + extra_toco_options.allow_custom_ops = True + make_zip_of_tests(options, test_parameters, build_graph, build_inputs, + extra_toco_options) + # Toco binary path provided by the generate rule. bin_path = None diff --git a/tensorflow/lite/testing/split.h b/tensorflow/lite/testing/split.h index 9732147dc51..b3ffab793af 100644 --- a/tensorflow/lite/testing/split.h +++ b/tensorflow/lite/testing/split.h @@ -15,10 +15,14 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TESTING_SPLIT_H_ #define TENSORFLOW_LITE_TESTING_SPLIT_H_ +#include +#include #include +#include #include #include #include + #include "tensorflow/lite/string.h" namespace tflite { @@ -99,6 +103,26 @@ inline std::vector Split(const string& s, const string& delimiter) { return fields; } +template <> +inline std::vector> Split(const string& s, + const string& delimiter) { + std::vector> fields; + for (const auto& p : SplitToPos(s, delimiter)) { + std::string sc = s.substr(p.first, p.second - p.first); + std::string::size_type sz_real, sz_img; + float real = std::stof(sc, &sz_real); + float img = std::stof(sc.substr(sz_real), &sz_img); + if (sz_real + sz_img + 1 != sc.length()) { + std::cerr << "There were errors in parsing string, " << sc + << ", to complex value." << std::endl; + return fields; + } + std::complex c(real, img); + fields.push_back(c); + } + return fields; +} + } // namespace testing } // namespace tflite diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc index 0515bf90814..50981c5f101 100644 --- a/tensorflow/lite/testing/tflite_driver.cc +++ b/tensorflow/lite/testing/tflite_driver.cc @@ -14,9 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/testing/tflite_driver.h" +#include + #include "absl/strings/escaping.h" #include "tensorflow/lite/builtin_op_data.h" #include "tensorflow/lite/delegates/flex/delegate.h" +#include "tensorflow/lite/kernels/custom_ops_register.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register_ref.h" #include "tensorflow/lite/string_util.h" @@ -58,6 +61,11 @@ bool Value(const TfLitePtrUnion& data, int index) { return data.b[index]; } +template <> +std::complex Value(const TfLitePtrUnion& data, int index) { + return std::complex(data.c64[index].re, data.c64[index].im); +} + template void SetTensorData(const std::vector& values, TfLitePtrUnion* data) { T* input_ptr = reinterpret_cast(data->raw); @@ -123,7 +131,29 @@ class TfLiteDriver::Expectation { } private: - template + bool CompareTwoValuesHelper(float v1, float v2) { + float diff = std::abs(v1 - v2); + bool error_is_large = false; + // For very small numbers, try absolute error, otherwise go with + // relative. + if (std::abs(v2) < relative_threshold_) { + error_is_large = (diff > absolute_threshold_); + } else { + error_is_large = (diff > relative_threshold_ * std::abs(v2)); + } + return error_is_large; + } + + bool CompareTwoValues(std::complex v1, std::complex v2) { + return CompareTwoValues(v1.real(), v2.real()) || + CompareTwoValues(v1.imag(), v2.imag()); + } + + bool CompareTwoValues(float v1, float v2) { + return CompareTwoValuesHelper(v1, v2); + } + + template bool TypedCheck(bool verbose, const TfLiteTensor& tensor) { size_t tensor_size = tensor.bytes / sizeof(T); @@ -136,18 +166,9 @@ class TfLiteDriver::Expectation { bool good_output = true; for (int i = 0; i < tensor_size; ++i) { - float computed = Value(tensor.data, i); - float reference = Value(data_, i); - float diff = std::abs(computed - reference); - bool error_is_large = false; - // For very small numbers, try absolute error, otherwise go with - // relative. - if (std::abs(reference) < relative_threshold_) { - error_is_large = (diff > absolute_threshold_); - } else { - error_is_large = (diff > relative_threshold_ * std::abs(reference)); - } - if (error_is_large) { + TS computed = Value(tensor.data, i); + TS reference = Value(data_, i); + if (CompareTwoValues(computed, reference)) { good_output = false; if (verbose) { std::cerr << " index " << i << ": got " << computed @@ -158,6 +179,8 @@ class TfLiteDriver::Expectation { return good_output; } + bool TypedCheckString(bool verbose, const TfLiteTensor& tensor); + TfLitePtrUnion data_; size_t num_elements_; double relative_threshold_; @@ -171,9 +194,8 @@ void TfLiteDriver::Expectation::SetData(const string& csv_values) { memcpy(data_.raw, s.data(), s.size()); } -template <> -bool TfLiteDriver::Expectation::TypedCheck(bool verbose, - const TfLiteTensor& tensor) { +bool TfLiteDriver::Expectation::TypedCheckString(bool verbose, + const TfLiteTensor& tensor) { if (tensor.data.raw == nullptr) { if (verbose) { std::cerr << " got empty string" << std::endl; @@ -215,19 +237,22 @@ bool TfLiteDriver::Expectation::Check(bool verbose, const TfLiteTensor& tensor) { switch (tensor.type) { case kTfLiteFloat32: - return TypedCheck(verbose, tensor); + return TypedCheck(verbose, tensor); case kTfLiteInt32: - return TypedCheck(verbose, tensor); + return TypedCheck(verbose, tensor); case kTfLiteInt64: - return TypedCheck(verbose, tensor); + return TypedCheck(verbose, tensor); case kTfLiteUInt8: - return TypedCheck(verbose, tensor); + return TypedCheck(verbose, tensor); case kTfLiteInt8: - return TypedCheck(verbose, tensor); + return TypedCheck(verbose, tensor); case kTfLiteBool: - return TypedCheck(verbose, tensor); + return TypedCheck(verbose, tensor); case kTfLiteString: - return TypedCheck(verbose, tensor); + return TypedCheckString(verbose, tensor); + case kTfLiteComplex64: + return TypedCheck, std::complex>(verbose, + tensor); default: fprintf(stderr, "Unsupported type %d in Check\n", tensor.type); return false; @@ -243,6 +268,10 @@ TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name, resolver_.reset(new ops::builtin::BuiltinRefOpResolver); } else { resolver_.reset(new ops::builtin::BuiltinOpResolver); + ops::builtin::BuiltinOpResolver* buildinop_resolver_ = + reinterpret_cast(resolver_.get()); + buildinop_resolver_->AddCustom("RFFT2D", + tflite::ops::custom::Register_RFFT2D()); } if (delegate_name == "FLEX") { @@ -402,6 +431,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) { case kTfLiteString: expected_output_[id]->SetData(csv_values); break; + case kTfLiteComplex64: + expected_output_[id]->SetData>(csv_values); + break; default: Invalidate(absl::StrCat("Unsupported tensor type ", TfLiteTypeGetName(tensor->type), diff --git a/tensorflow/lite/tools/make/download_dependencies.sh b/tensorflow/lite/tools/make/download_dependencies.sh index 8d3b114c8c3..1b0df57624f 100755 --- a/tensorflow/lite/tools/make/download_dependencies.sh +++ b/tensorflow/lite/tools/make/download_dependencies.sh @@ -36,7 +36,7 @@ ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_ NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" FARMHASH_URL="http://mirror.tensorflow.org/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" FLATBUFFERS_URL="http://mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.11.0.tar.gz" -FFT2D_URL="http://mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz" +FFT2D_URL="http://mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, # so work around it by patching the source. diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 9f354f4d547..ae3ad19281a 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -181,6 +181,11 @@ tensorflow/third_party/llvm/expand_cmake_vars.py tensorflow/third_party/llvm/llvm.autogenerated.BUILD tensorflow/third_party/llvm/llvm.bzl tensorflow/third_party/icu/udata.patch +tensorflow/third_party/fft2d/fft2d.h +tensorflow/third_party/fft2d/BUILD +tensorflow/third_party/fft2d/fft.h +tensorflow/third_party/fft2d/LICENSE +tensorflow/third_party/fft2d/fft2d.BUILD tensorflow/third_party/nccl/archive.BUILD tensorflow/third_party/nccl/LICENSE tensorflow/third_party/nccl/system.BUILD.tpl @@ -188,10 +193,6 @@ tensorflow/third_party/nccl/nccl_configure.bzl tensorflow/third_party/nccl/build_defs.bzl.tpl tensorflow/third_party/nccl/archive.patch tensorflow/third_party/nccl/BUILD -tensorflow/third_party/fft2d/BUILD -tensorflow/third_party/fft2d/fft.h -tensorflow/third_party/fft2d/LICENSE -tensorflow/third_party/fft2d/fft2d.BUILD tensorflow/third_party/boringssl/BUILD tensorflow/third_party/mpi/.gitignore tensorflow/third_party/mpi/BUILD diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 8563d44625b..b96b412baec 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -148,7 +148,7 @@ genrule( "@double_conversion//:LICENSE", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", - "@fft2d//:fft/readme.txt", + "@fft2d//:fft2d/readme2d.txt", "@gemmlowp//:LICENSE", "@gif_archive//:COPYING", "@highwayhash//:LICENSE", @@ -219,7 +219,7 @@ genrule( "@double_conversion//:LICENSE", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", - "@fft2d//:fft/readme.txt", + "@fft2d//:fft2d/readme2d.txt", "@gemmlowp//:LICENSE", "@gif_archive//:COPYING", "@highwayhash//:LICENSE", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 457ec631919..e0e190343db 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -169,7 +169,7 @@ filegroup( "@eigen_archive//:COPYING.MPL2", "@enum34_archive//:LICENSE", "@farmhash_archive//:COPYING", - "@fft2d//:fft/readme.txt", + "@fft2d//:fft2d/readme2d.txt", "@flatbuffers//:LICENSE.txt", "@gast_archive//:PKG-INFO", "@gemmlowp//:LICENSE", diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 70e5bc7931a..b84687f5d1c 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -580,10 +580,10 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "fft2d", build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"), - sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296", + sha256 = "ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9", urls = [ - "http://mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", - "http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", + "http://mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz", + "http://www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz", ], ) diff --git a/third_party/fft2d/BUILD b/third_party/fft2d/BUILD index 81354424826..987019121a1 100644 --- a/third_party/fft2d/BUILD +++ b/third_party/fft2d/BUILD @@ -1,5 +1,5 @@ # Headers for 2D Fast Fourier Transform package -# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html +# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html # This is a separate package because the original downloaded archive doesn't # contain any header files. @@ -15,18 +15,27 @@ exports_files(["LICENSE"]) cc_library( name = "fft2d_headers", - srcs = ["fft.h"], + srcs = [ + "fft.h", + "fft2d.h", + ], ) objc_library( name = "fft2d_headersd_ios", - srcs = ["fft.h"], + srcs = [ + "fft.h", + "fft2d.h", + ], ) # Export the source code so that it could be compiled for Andoid native apps. filegroup( name = "fft2d_headers_srcs", - srcs = ["fft.h"], + srcs = [ + "fft.h", + "fft2d.h", + ], ) filegroup( diff --git a/third_party/fft2d/fft.h b/third_party/fft2d/fft.h index 31b4935089d..36d838b7f62 100644 --- a/third_party/fft2d/fft.h +++ b/third_party/fft2d/fft.h @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Declarations for 1D FFT routines in third_party/fft2d/fft. +// Declarations for 1D FFT routines in third_party/fft2d/fft2d. #ifndef FFT2D_FFT_H__ #define FFT2D_FFT_H__ diff --git a/third_party/fft2d/fft2d.BUILD b/third_party/fft2d/fft2d.BUILD index 74dd3112fce..a6a455d8523 100644 --- a/third_party/fft2d/fft2d.BUILD +++ b/third_party/fft2d/fft2d.BUILD @@ -1,5 +1,5 @@ # 2D Fast Fourier Transform package -# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html +# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html package( default_visibility = ["//visibility:public"], @@ -8,10 +8,11 @@ package( # Unrestricted use; can only distribute original package. licenses(["notice"]) -exports_files(["fft/readme.txt"]) +exports_files(["fft2d/readme2d.txt"]) FFT2D_SRCS = [ - "fft/fftsg.c", + "fft2d/fftsg.c", + "fft2d/fftsg2d.c", ] config_setting( @@ -22,7 +23,7 @@ config_setting( # This is the main 2D FFT library. The 2D FFTs in this library call # 1D FFTs. In addition, fast DCTs are provided for the special case # of 8x8 and 16x16. This code in this library is referred to as -# "Version II" on http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html. +# "Version II" on http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html. cc_library( name = "fft2d", srcs = FFT2D_SRCS, diff --git a/third_party/fft2d/fft2d.h b/third_party/fft2d/fft2d.h new file mode 100644 index 00000000000..d587b3b441c --- /dev/null +++ b/third_party/fft2d/fft2d.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Declarations for 2D FFT routines in third_party/fft2d/fft2d. + +#ifndef FFT2D_FFT_H__ +#define FFT2D_FFT_H__ + +#ifdef __cplusplus +extern "C" { +#endif + +extern void cdft2d(int, int, int, double **, double *, int *, double *); +extern void rdft2d(int, int, int, double **, double *, int *, double *); +extern void ddct2d(int, int, int, double **, double *, int *, double *); +extern void ddst2d(int, int, int, double **, double *, int *, double *); +extern void ddct8x8s(int isgn, double **a); +extern void ddct16x16s(int isgn, double **a); + +#ifdef __cplusplus +} +#endif + +#endif // FFT2D_FFT_H__