Implement Rfft2d as a custom ops.
PiperOrigin-RevId: 252874759
This commit is contained in:
parent
73a5307689
commit
03195f1345
@ -40,7 +40,7 @@ readonly PROTOBUF_TAG="$(grep -o 'https://github.com/protocolbuffers/protobuf/ar
|
||||
# TODO (yongtang): Replace the following with 'http://mirror.tensorflow.org/github.com/google/re2/.*tar\.gz' once
|
||||
# the archive has been propagated in mirror.tensorflow.org.
|
||||
RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
|
||||
FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
|
||||
FFT2D_URL="$(grep -o 'http.*fft2d\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
|
||||
DOUBLE_CONVERSION_URL="$(grep -o "https.*google/double-conversion.*\.zip" "${BZL_FILE_PATH}" | head -n1)"
|
||||
ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)"
|
||||
CUB_URL="$(grep -o 'https.*cub/archive.*zip' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)"
|
||||
|
@ -315,6 +315,7 @@ def generated_test_models():
|
||||
"resolve_constant_strided_slice",
|
||||
"reverse_sequence",
|
||||
"reverse_v2",
|
||||
"rfft2d",
|
||||
"round",
|
||||
"rsqrt",
|
||||
"shape",
|
||||
|
@ -436,6 +436,23 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "custom_ops",
|
||||
srcs = ["rfft2d.cc"],
|
||||
hdrs = ["custom_ops_register.h"],
|
||||
deps = [
|
||||
":kernel_util",
|
||||
":op_macros",
|
||||
"//tensorflow/lite:context",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels/internal:kernel_utils",
|
||||
"//tensorflow/lite/kernels/internal:tensor",
|
||||
"//third_party/fft2d:fft2d_headers",
|
||||
"@fft2d",
|
||||
"@gemmlowp//:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lstm_eval",
|
||||
srcs = ["lstm_eval.cc"],
|
||||
@ -1619,6 +1636,19 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "rfft2d_test",
|
||||
size = "small",
|
||||
srcs = ["rfft2d_test.cc"],
|
||||
deps = [
|
||||
":custom_ops",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
30
tensorflow/lite/kernels/custom_ops_register.h
Normal file
30
tensorflow/lite/kernels/custom_ops_register.h
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_CUSTOM_OPS_REGISTER_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_CUSTOM_OPS_REGISTER_H_
|
||||
|
||||
#include "tensorflow/lite/context.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
|
||||
TfLiteRegistration* Register_RFFT2D();
|
||||
|
||||
}
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_CUSTOM_OPS_REGISTER_H_
|
426
tensorflow/lite/kernels/rfft2d.cc
Normal file
426
tensorflow/lite/kernels/rfft2d.cc
Normal file
@ -0,0 +1,426 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "third_party/fft2d/fft2d.h"
|
||||
#include "profiling/instrumentation.h"
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
namespace rfft2d {
|
||||
|
||||
using std::complex;
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kFftLengthTensor = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
constexpr int kFftIntegerWorkingAreaTensor = 0;
|
||||
constexpr int kFftDoubleWorkingAreaTensor = 1;
|
||||
constexpr int kTensorNotAllocated = -1;
|
||||
|
||||
struct OpData {
|
||||
// IDs are the arbitrary identifiers used by TF Lite to identify and access
|
||||
// memory buffers.
|
||||
int fft_integer_working_area_id = kTensorNotAllocated;
|
||||
int fft_double_working_area_id = kTensorNotAllocated;
|
||||
};
|
||||
|
||||
bool IsPowerOfTwo(uint32_t v) { return v && !(v & (v - 1)); }
|
||||
|
||||
static TfLiteStatus InitTemporaryTensors(TfLiteContext* context,
|
||||
TfLiteNode* node) {
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
// The prepare function may be executed multiple times. But temporary tensors
|
||||
// only need to be initiated once.
|
||||
if (data->fft_integer_working_area_id != kTensorNotAllocated &&
|
||||
data->fft_double_working_area_id != kTensorNotAllocated) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteIntArrayFree(node->temporaries);
|
||||
// Create two temporary tensors.
|
||||
node->temporaries = TfLiteIntArrayCreate(2);
|
||||
int first_new_index;
|
||||
TF_LITE_ENSURE_STATUS(context->AddTensors(context, 2, &first_new_index));
|
||||
node->temporaries->data[kFftIntegerWorkingAreaTensor] = first_new_index;
|
||||
data->fft_integer_working_area_id = first_new_index;
|
||||
node->temporaries->data[kFftDoubleWorkingAreaTensor] = first_new_index + 1;
|
||||
data->fft_double_working_area_id = first_new_index + 1;
|
||||
|
||||
// Set up FFT integer working area buffer.
|
||||
TfLiteTensor* fft_integer_working_area =
|
||||
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
|
||||
fft_integer_working_area->type = kTfLiteInt32;
|
||||
// If fft_length is not a constant tensor, fft_integer_working_area will be
|
||||
// set to dynamic later in Prepare.
|
||||
fft_integer_working_area->allocation_type = kTfLiteArenaRw;
|
||||
|
||||
// Set up FFT double working area buffer.
|
||||
TfLiteTensor* fft_double_working_area =
|
||||
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
|
||||
// fft_double_working_area is a double tensor. Ideally, double should be
|
||||
// added into tflite data types. However, since fft_double_working_area is a
|
||||
// temporary tensor, and there are no ops having double input/output tensors
|
||||
// in tflite at this point, adding double as a tflite data type may confuse
|
||||
// users that double is supported. As a results, kTfLiteInt64 is used here
|
||||
// for memory allocation. And it will be cast into double in Eval when being
|
||||
// used.
|
||||
fft_double_working_area->type = kTfLiteInt64;
|
||||
// If fft_length is not a constant tensor, fft_double_working_area will be
|
||||
// set to dynamic later in Prepare.
|
||||
fft_double_working_area->allocation_type = kTfLiteArenaRw;
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
|
||||
TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const int num_dims = NumDimensions(input);
|
||||
TF_LITE_ENSURE(context, num_dims >= 2);
|
||||
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
|
||||
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
|
||||
// The lib, fft2d, can only handle fft_lengths of power of 2.
|
||||
TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[0]));
|
||||
TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[1]));
|
||||
|
||||
int fft_height, fft_width;
|
||||
fft_height = fft_length_data[0];
|
||||
fft_width = fft_length_data[1];
|
||||
int fft_working_length = std::max(fft_height, fft_width / 2);
|
||||
int half_fft_working_length = fft_working_length / 2;
|
||||
|
||||
// Resize output tensor.
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
|
||||
output_shape->data[num_dims - 2] = fft_length_data[0];
|
||||
output_shape->data[num_dims - 1] = fft_length_data[1] / 2 + 1;
|
||||
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
|
||||
|
||||
// Resize temporary tensors, fft_integer_working_area.
|
||||
TfLiteTensor* fft_integer_working_area =
|
||||
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
|
||||
TfLiteIntArray* fft_integer_working_area_shape = TfLiteIntArrayCreate(1);
|
||||
fft_integer_working_area_shape->data[0] =
|
||||
2 + static_cast<int>(sqrt(fft_working_length));
|
||||
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_integer_working_area,
|
||||
fft_integer_working_area_shape));
|
||||
|
||||
// Resize temporary tensors, fft_double_working_area.
|
||||
TfLiteTensor* fft_double_working_area =
|
||||
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
|
||||
TfLiteIntArray* fft_double_working_area_shape = TfLiteIntArrayCreate(1);
|
||||
fft_double_working_area_shape->data[0] =
|
||||
half_fft_working_length + fft_width / 4;
|
||||
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_double_working_area,
|
||||
fft_double_working_area_shape));
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
auto* data = new OpData;
|
||||
return data;
|
||||
}
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) {
|
||||
delete reinterpret_cast<OpData*>(buffer);
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
// Check type and shape of the input tensor
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) >= 2);
|
||||
if (input->type != kTfLiteFloat32) {
|
||||
context->ReportError(context,
|
||||
"Type '%s' for input is not supported by rfft2d.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Check type and shape of the fft_length tensor
|
||||
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
|
||||
const RuntimeShape fft_length_shape = GetTensorShape(fft_length);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(fft_length), 1);
|
||||
TF_LITE_ENSURE_EQ(context, fft_length_shape.Dims(0), 2);
|
||||
if (fft_length->type != kTfLiteInt32) {
|
||||
context->ReportError(context,
|
||||
"Type '%s' for fft_length is not supported by rfft2d.",
|
||||
TfLiteTypeGetName(fft_length->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Setup temporary tensors for fft computation.
|
||||
TF_LITE_ENSURE_STATUS(InitTemporaryTensors(context, node));
|
||||
|
||||
// Set output type
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
output->type = kTfLiteComplex64;
|
||||
|
||||
// Exit early if fft_length is a non-const tensor. Set output tensor and
|
||||
// temporary tensors to dynamic, so that their tensor sizes can be determined
|
||||
// in Eval.
|
||||
if (!IsConstantTensor(fft_length)) {
|
||||
TfLiteTensor* fft_integer_working_area =
|
||||
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
|
||||
TfLiteTensor* fft_double_working_area =
|
||||
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
|
||||
SetTensorToDynamic(fft_integer_working_area);
|
||||
SetTensorToDynamic(fft_double_working_area);
|
||||
SetTensorToDynamic(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// Reorder the result so that it matches the pattern of tf.signal.rfft2d.
|
||||
// In tf.signal.fft2d the frequency matrix of a 4x4 input is
|
||||
// [[F(0, 0), F(0, 1/4), F(0, 2/4)],
|
||||
// [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)],
|
||||
// [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)],
|
||||
// [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]]
|
||||
// While in rdft2d, the frequency matrix of a 4x4 input is
|
||||
// [[(F(0, 0), F(0, -2/4)) F(0, -1/4), 0],
|
||||
// [ F(-1/4, 0), F(-1/4, -1/4), 0],
|
||||
// [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0],
|
||||
// [ j*F(-3/4, -2/4), F(-3/4, -1/4), 0]]
|
||||
// Since real fft has the property that
|
||||
// Real(u,v) = Real(-u, -v)
|
||||
// Img(u,v) = - Img(-u, -v)
|
||||
// Result of rdft2d can be reordered and match the pattern of tf.signal.rfft2d.
|
||||
// For example,
|
||||
// Real(-3/4, 0) = Real(1/4, 0) = Real(-1/4, 0)
|
||||
// Img(-3/4, 0) = Img(1/4, 0) = -Img(-1/4, 0)
|
||||
void Rfft2dReorder(int fft_height, int fft_width, double** fft_input_output) {
|
||||
int fft_height_half;
|
||||
gemmlowp::ScopedProfilingLabel label("Rfft2dReorder");
|
||||
double real, img;
|
||||
|
||||
fft_height_half = fft_height >> 1;
|
||||
// Use 4x4 input as an example, reorder the frequency matrix from
|
||||
// [[(F(0, 0), F(0, -2/4)) F(0, -1/4), 0],
|
||||
// [ F(-1/4, 0), F(-1/4, -1/4), 0],
|
||||
// [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0],
|
||||
// [ j*F(-3/4, -2/4), F(-3/4, -1/4), 0]]
|
||||
// to
|
||||
// [[F(0, 0), F(0, -1/4), F(0, -2/4)],
|
||||
// [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)],
|
||||
// [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)],
|
||||
// [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]]
|
||||
for (int i = fft_height_half + 1; i < fft_height; ++i) {
|
||||
real = fft_input_output[i][0];
|
||||
img = fft_input_output[i][1];
|
||||
fft_input_output[i][fft_width] = img;
|
||||
fft_input_output[i][fft_width + 1] = real;
|
||||
fft_input_output[fft_height - i][fft_width] = img;
|
||||
fft_input_output[fft_height - i][fft_width + 1] = -real;
|
||||
fft_input_output[i][0] = fft_input_output[fft_height - i][0];
|
||||
fft_input_output[i][1] = -fft_input_output[fft_height - i][1];
|
||||
}
|
||||
fft_input_output[0][fft_width] = fft_input_output[0][1];
|
||||
fft_input_output[0][fft_width + 1] = 0;
|
||||
fft_input_output[0][1] = 0;
|
||||
fft_input_output[fft_height_half][fft_width] =
|
||||
fft_input_output[fft_height_half][1];
|
||||
fft_input_output[fft_height_half][fft_width + 1] = 0;
|
||||
fft_input_output[fft_height_half][1] = 0;
|
||||
|
||||
// Reorder the frequency matrix from
|
||||
// [[F(0, 0), F(0, -1/4), F(0, -2/4)],
|
||||
// [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)],
|
||||
// [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)],
|
||||
// [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]]
|
||||
// to
|
||||
// [[F(0, 0), F(0, 1/4), F(0, 2/4)],
|
||||
// [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)],
|
||||
// [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)],
|
||||
// [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]]
|
||||
for (int i = 0; i < fft_height; ++i) {
|
||||
for (int j = 1; j < fft_width + 2; j += 2) {
|
||||
fft_input_output[i][j] = -fft_input_output[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Rfft2dImpl(int fft_height, int fft_width, double** fft_input_output,
|
||||
int* fft_integer_working_area_data,
|
||||
double* fft_double_working_area_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("Rfft2dImpl");
|
||||
|
||||
// Working data areas for the FFT routines.
|
||||
double* fft_dynamic_working_area = nullptr;
|
||||
const int kForwardFft = 1;
|
||||
rdft2d(fft_height, fft_width, kForwardFft, fft_input_output,
|
||||
fft_dynamic_working_area, fft_integer_working_area_data,
|
||||
fft_double_working_area_data);
|
||||
Rfft2dReorder(fft_height, fft_width, fft_input_output);
|
||||
}
|
||||
|
||||
void PrepareInputBuffer(const float* input_data, int input_height,
|
||||
int input_width, int fft_height, int fft_width,
|
||||
double** fft_input_output) {
|
||||
int valid_input_height = std::min(input_height, fft_height);
|
||||
int valid_input_width = std::min(input_width, fft_width);
|
||||
for (int i = 0; i < valid_input_height; ++i) {
|
||||
int in_pos = i * input_width;
|
||||
for (int j = 0; j < valid_input_width; ++j) {
|
||||
fft_input_output[i][j] = input_data[in_pos++];
|
||||
}
|
||||
// Zero-pad the rest of the input buffer
|
||||
for (int j = valid_input_width; j < fft_width + 2; ++j) {
|
||||
fft_input_output[i][j] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Zero-pad input buffer, if fft_height is greater than valid_input_height.
|
||||
for (int i = valid_input_height; i < fft_height; ++i) {
|
||||
for (int j = 0; j < fft_width + 2; ++j) {
|
||||
fft_input_output[i][j] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PrepareOutputBuffer(complex<float>* output_data, int fft_height,
|
||||
int fft_width, double** fft_input_output) {
|
||||
int cnt = 0;
|
||||
for (int i = 0; i < fft_height; ++i) {
|
||||
for (int j = 0; j < fft_width / 2 + 1; ++j) {
|
||||
output_data[cnt++] = complex<float>(fft_input_output[i][j * 2],
|
||||
fft_input_output[i][j * 2 + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const float* input_data = GetTensorData<float>(input);
|
||||
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
|
||||
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
complex<float>* output_data = GetTensorData<complex<float>>(output);
|
||||
|
||||
int fft_height, fft_width;
|
||||
fft_height = fft_length_data[0];
|
||||
fft_width = fft_length_data[1];
|
||||
|
||||
// FFT is processed for every slice on the inner most 2 dimensions.
|
||||
// Count the number of slices in the input tensor.
|
||||
const RuntimeShape input_shape = GetTensorShape(input);
|
||||
const int input_dims_count = input_shape.DimensionsCount();
|
||||
const auto* input_dims_data = input_shape.DimsData();
|
||||
int num_slices = 1;
|
||||
for (int i = 0; i < input_dims_count - 2; ++i) {
|
||||
num_slices *= input_dims_data[i];
|
||||
}
|
||||
|
||||
int input_height = input_dims_data[input_dims_count - 2];
|
||||
int input_width = input_dims_data[input_dims_count - 1];
|
||||
int input_slice_size = input_height * input_width;
|
||||
int output_slice_size = fft_height * (fft_width / 2 + 1);
|
||||
|
||||
// Create input/output buffer for FFT
|
||||
double** fft_input_output = new double*[fft_height];
|
||||
for (int i = 0; i < fft_height; ++i) {
|
||||
fft_input_output[i] = new double[fft_width + 2];
|
||||
}
|
||||
|
||||
// Get buffer for integer working area.
|
||||
TfLiteTensor* fft_integer_working_area =
|
||||
GetTemporary(context, node, kFftIntegerWorkingAreaTensor);
|
||||
int* fft_integer_working_area_data =
|
||||
GetTensorData<int>(fft_integer_working_area);
|
||||
|
||||
// Get buffer for double working area.
|
||||
TfLiteTensor* fft_double_working_area =
|
||||
GetTemporary(context, node, kFftDoubleWorkingAreaTensor);
|
||||
// Get double value out of the memory of fft_double_working_area_data.
|
||||
double* fft_double_working_area_data = reinterpret_cast<double*>(
|
||||
GetTensorData<int64_t>(fft_double_working_area));
|
||||
|
||||
// Process evert slice in the input buffer
|
||||
for (int i = 0; i < num_slices; ++i) {
|
||||
PrepareInputBuffer(input_data, input_height, input_width, fft_height,
|
||||
fft_width, fft_input_output);
|
||||
memset(fft_integer_working_area_data, 0, fft_integer_working_area->bytes);
|
||||
memset(fft_double_working_area_data, 0, fft_double_working_area->bytes);
|
||||
Rfft2dImpl(fft_height, fft_width, fft_input_output,
|
||||
fft_integer_working_area_data, fft_double_working_area_data);
|
||||
PrepareOutputBuffer(output_data, fft_height, fft_width, fft_input_output);
|
||||
input_data += input_slice_size;
|
||||
output_data += output_slice_size;
|
||||
}
|
||||
|
||||
// Delete the input buffer
|
||||
for (int i = 0; i < fft_height; ++i) {
|
||||
delete[] fft_input_output[i];
|
||||
}
|
||||
delete[] fft_input_output;
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* fft_length = GetInput(context, node, kFftLengthTensor);
|
||||
const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
if (output->type != kTfLiteComplex64) {
|
||||
context->ReportError(context,
|
||||
"Type '%s' for output is not supported by rfft2d.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Resize the output tensor if the fft_length tensor is not constant.
|
||||
// Otherwise, check if the output shape is correct.
|
||||
if (!IsConstantTensor(fft_length)) {
|
||||
TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node));
|
||||
} else {
|
||||
int num_dims_output = NumDimensions(output);
|
||||
const RuntimeShape output_shape = GetTensorShape(output);
|
||||
TF_LITE_ENSURE_EQ(context, num_dims_output, NumDimensions(input));
|
||||
TF_LITE_ENSURE(context, num_dims_output >= 2);
|
||||
TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 2),
|
||||
fft_length_data[0]);
|
||||
TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 1),
|
||||
fft_length_data[1] / 2 + 1);
|
||||
}
|
||||
|
||||
return Rfft2dHelper(context, node);
|
||||
}
|
||||
|
||||
} // namespace rfft2d
|
||||
|
||||
TfLiteRegistration* Register_RFFT2D() {
|
||||
static TfLiteRegistration r = {rfft2d::Init, rfft2d::Free, rfft2d::Prepare,
|
||||
rfft2d::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
154
tensorflow/lite/kernels/rfft2d_test.cc
Normal file
154
tensorflow/lite/kernels/rfft2d_test.cc
Normal file
@ -0,0 +1,154 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <initializer_list>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/custom_ops_register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
|
||||
TfLiteRegistration* Register_RFFT2D();
|
||||
|
||||
namespace {
|
||||
|
||||
using std::complex;
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
class Rfft2dOpModel : public SingleOpModel {
|
||||
public:
|
||||
Rfft2dOpModel(const TensorData& input, const TensorData& fft_lengths) {
|
||||
input_ = AddInput(input);
|
||||
fft_lengths_ = AddInput(fft_lengths);
|
||||
TensorType output_type = TensorType_COMPLEX64;
|
||||
output_ = AddOutput({output_type, {}});
|
||||
|
||||
const std::vector<uint8_t> custom_option;
|
||||
SetCustomOp("Rfft2d", custom_option, Register_RFFT2D);
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
}
|
||||
|
||||
int input() { return input_; }
|
||||
int fft_lengths() { return fft_lengths_; }
|
||||
|
||||
std::vector<complex<float>> GetOutput() {
|
||||
return ExtractVector<complex<float>>(output_);
|
||||
}
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
private:
|
||||
int input_;
|
||||
int fft_lengths_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(Rfft2dOpTest, FftLengthMatchesInputSize) {
|
||||
Rfft2dOpModel model({TensorType_FLOAT32, {4, 4}}, {TensorType_INT32, {2}});
|
||||
// clang-format off
|
||||
model.PopulateTensor<float>(model.input(),
|
||||
{1, 2, 3, 4,
|
||||
3, 8, 6, 3,
|
||||
5, 2, 7, 6,
|
||||
9, 5, 8, 3});
|
||||
// clang-format on
|
||||
model.PopulateTensor<int32_t>(model.fft_lengths(), {4, 4});
|
||||
model.Invoke();
|
||||
|
||||
std::complex<float> expected_result[12] = {
|
||||
{75, 0}, {-6, -1}, {9, 0}, {-10, 5}, {-3, 2}, {-6, 11},
|
||||
{-15, 0}, {-2, 13}, {-5, 0}, {-10, -5}, {3, -6}, {-6, -11}};
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_result));
|
||||
}
|
||||
|
||||
TEST(Rfft2dOpTest, FftLengthSmallerThanInputSize) {
|
||||
Rfft2dOpModel model({TensorType_FLOAT32, {4, 5}}, {TensorType_INT32, {2}});
|
||||
// clang-format off
|
||||
model.PopulateTensor<float>(model.input(),
|
||||
{1, 2, 3, 4, 0,
|
||||
3, 8, 6, 3, 0,
|
||||
5, 2, 7, 6, 0,
|
||||
9, 5, 8, 3, 0});
|
||||
// clang-format on
|
||||
model.PopulateTensor<int32_t>(model.fft_lengths(), {4, 4});
|
||||
model.Invoke();
|
||||
|
||||
std::complex<float> expected_result[12] = {
|
||||
{75, 0}, {-6, -1}, {9, 0}, {-10, 5}, {-3, 2}, {-6, 11},
|
||||
{-15, 0}, {-2, 13}, {-5, 0}, {-10, -5}, {3, -6}, {-6, -11}};
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_result));
|
||||
}
|
||||
|
||||
TEST(Rfft2dOpTest, FftLengthGreaterThanInputSize) {
|
||||
Rfft2dOpModel model({TensorType_FLOAT32, {3, 4}}, {TensorType_INT32, {2}});
|
||||
// clang-format off
|
||||
model.PopulateTensor<float>(model.input(),
|
||||
{1, 2, 3, 4,
|
||||
3, 8, 6, 3,
|
||||
5, 2, 7, 6});
|
||||
// clang-format on
|
||||
model.PopulateTensor<int32_t>(model.fft_lengths(), {4, 8});
|
||||
model.Invoke();
|
||||
|
||||
// clang-format off
|
||||
std::complex<float> expected_result[20] = {
|
||||
{50, 0}, {8.29289341, -33.6776695}, {-7, 1}, {9.70710659, -1.67766953},
|
||||
{0, 0},
|
||||
{-10, -20}, {-16.3639603, -1.12132037}, {-5, 1}, {-7.19238806, -2.05025244},
|
||||
{-6, 2},
|
||||
{10, 0}, {-4.7781744, -6.12132025}, {-1, 11}, {10.7781744, 1.87867963},
|
||||
{4, 0},
|
||||
{-10, 20}, {11.1923885, 11.9497471}, {5, -5}, {-3.63603902, -3.12132025},
|
||||
{-6, -2}};
|
||||
// clang-format on
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_result));
|
||||
}
|
||||
|
||||
TEST(Rfft2dOpTest, InputDimsGreaterThan2) {
|
||||
Rfft2dOpModel model({TensorType_FLOAT32, {2, 2, 4}}, {TensorType_INT32, {2}});
|
||||
// clang-format off
|
||||
model.PopulateTensor<float>(model.input(),
|
||||
{1., 2., 3., 4.,
|
||||
3., 8., 6., 3.,
|
||||
5., 2., 7., 6.,
|
||||
7., 3., 23., 5.});
|
||||
// clang-format on
|
||||
model.PopulateTensor<int32_t>(model.fft_lengths(), {2, 4});
|
||||
model.Invoke();
|
||||
|
||||
// clang-format off
|
||||
std::complex<float> expected_result[12] = {
|
||||
{30., 0.}, {-5, -3.}, { -4., 0.},
|
||||
{-10., 0.}, {1., 7.}, { 0., 0.},
|
||||
{58., 0.}, {-18., 6.}, { 26., 0.},
|
||||
{-18., 0.}, { 14., 2.}, {-18., 0.}};
|
||||
// clang-format on
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_result));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -188,6 +188,7 @@ cc_library(
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/delegates/flex:delegate",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/kernels:custom_ops",
|
||||
"//tensorflow/lite/kernels:reference_ops",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
|
@ -56,6 +56,7 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import graph_util as tf_graph_util
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import spectral_ops_test_util
|
||||
|
||||
|
||||
RANDOM_SEED = 342
|
||||
@ -5070,6 +5071,39 @@ def make_unfused_gru_tests(options):
|
||||
build_inputs,
|
||||
use_frozen_graph=True)
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
def make_rfft2d_tests(options):
|
||||
"""Make a set of tests to do rfft2d."""
|
||||
|
||||
test_parameters = [{
|
||||
"input_dtype": [tf.float32],
|
||||
"input_shape": [[8, 8], [3, 8, 8]],
|
||||
"fft_length": [
|
||||
None, [4, 4], [4, 8], [8, 4], [8, 8], [8, 16], [16, 8], [16, 16]
|
||||
]
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
input_value = tf.placeholder(
|
||||
dtype=parameters["input_dtype"],
|
||||
name="input",
|
||||
shape=parameters["input_shape"])
|
||||
with spectral_ops_test_util.fft_kernel_label_map():
|
||||
outs = tf.signal.rfft2d(input_value, fft_length=parameters["fft_length"])
|
||||
return [input_value], [outs]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
input_value = create_tensor_data(parameters["input_dtype"],
|
||||
parameters["input_shape"])
|
||||
return [input_value], sess.run(
|
||||
outputs, feed_dict=dict(zip(inputs, [input_value])))
|
||||
|
||||
extra_toco_options = ExtraTocoOptions()
|
||||
extra_toco_options.allow_custom_ops = True
|
||||
make_zip_of_tests(options, test_parameters, build_graph, build_inputs,
|
||||
extra_toco_options)
|
||||
|
||||
# Toco binary path provided by the generate rule.
|
||||
bin_path = None
|
||||
|
||||
|
@ -15,10 +15,14 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_TESTING_SPLIT_H_
|
||||
#define TENSORFLOW_LITE_TESTING_SPLIT_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <complex>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/string.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -99,6 +103,26 @@ inline std::vector<bool> Split(const string& s, const string& delimiter) {
|
||||
return fields;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<std::complex<float>> Split(const string& s,
|
||||
const string& delimiter) {
|
||||
std::vector<std::complex<float>> fields;
|
||||
for (const auto& p : SplitToPos(s, delimiter)) {
|
||||
std::string sc = s.substr(p.first, p.second - p.first);
|
||||
std::string::size_type sz_real, sz_img;
|
||||
float real = std::stof(sc, &sz_real);
|
||||
float img = std::stof(sc.substr(sz_real), &sz_img);
|
||||
if (sz_real + sz_img + 1 != sc.length()) {
|
||||
std::cerr << "There were errors in parsing string, " << sc
|
||||
<< ", to complex value." << std::endl;
|
||||
return fields;
|
||||
}
|
||||
std::complex<float> c(real, img);
|
||||
fields.push_back(c);
|
||||
}
|
||||
return fields;
|
||||
}
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -14,9 +14,12 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/testing/tflite_driver.h"
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "absl/strings/escaping.h"
|
||||
#include "tensorflow/lite/builtin_op_data.h"
|
||||
#include "tensorflow/lite/delegates/flex/delegate.h"
|
||||
#include "tensorflow/lite/kernels/custom_ops_register.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/register_ref.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
@ -58,6 +61,11 @@ bool Value(const TfLitePtrUnion& data, int index) {
|
||||
return data.b[index];
|
||||
}
|
||||
|
||||
template <>
|
||||
std::complex<float> Value(const TfLitePtrUnion& data, int index) {
|
||||
return std::complex<float>(data.c64[index].re, data.c64[index].im);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SetTensorData(const std::vector<T>& values, TfLitePtrUnion* data) {
|
||||
T* input_ptr = reinterpret_cast<T*>(data->raw);
|
||||
@ -123,7 +131,29 @@ class TfLiteDriver::Expectation {
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool CompareTwoValuesHelper(float v1, float v2) {
|
||||
float diff = std::abs(v1 - v2);
|
||||
bool error_is_large = false;
|
||||
// For very small numbers, try absolute error, otherwise go with
|
||||
// relative.
|
||||
if (std::abs(v2) < relative_threshold_) {
|
||||
error_is_large = (diff > absolute_threshold_);
|
||||
} else {
|
||||
error_is_large = (diff > relative_threshold_ * std::abs(v2));
|
||||
}
|
||||
return error_is_large;
|
||||
}
|
||||
|
||||
bool CompareTwoValues(std::complex<float> v1, std::complex<float> v2) {
|
||||
return CompareTwoValues(v1.real(), v2.real()) ||
|
||||
CompareTwoValues(v1.imag(), v2.imag());
|
||||
}
|
||||
|
||||
bool CompareTwoValues(float v1, float v2) {
|
||||
return CompareTwoValuesHelper(v1, v2);
|
||||
}
|
||||
|
||||
template <typename T, typename TS>
|
||||
bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
|
||||
size_t tensor_size = tensor.bytes / sizeof(T);
|
||||
|
||||
@ -136,18 +166,9 @@ class TfLiteDriver::Expectation {
|
||||
|
||||
bool good_output = true;
|
||||
for (int i = 0; i < tensor_size; ++i) {
|
||||
float computed = Value<T>(tensor.data, i);
|
||||
float reference = Value<T>(data_, i);
|
||||
float diff = std::abs(computed - reference);
|
||||
bool error_is_large = false;
|
||||
// For very small numbers, try absolute error, otherwise go with
|
||||
// relative.
|
||||
if (std::abs(reference) < relative_threshold_) {
|
||||
error_is_large = (diff > absolute_threshold_);
|
||||
} else {
|
||||
error_is_large = (diff > relative_threshold_ * std::abs(reference));
|
||||
}
|
||||
if (error_is_large) {
|
||||
TS computed = Value<T>(tensor.data, i);
|
||||
TS reference = Value<T>(data_, i);
|
||||
if (CompareTwoValues(computed, reference)) {
|
||||
good_output = false;
|
||||
if (verbose) {
|
||||
std::cerr << " index " << i << ": got " << computed
|
||||
@ -158,6 +179,8 @@ class TfLiteDriver::Expectation {
|
||||
return good_output;
|
||||
}
|
||||
|
||||
bool TypedCheckString(bool verbose, const TfLiteTensor& tensor);
|
||||
|
||||
TfLitePtrUnion data_;
|
||||
size_t num_elements_;
|
||||
double relative_threshold_;
|
||||
@ -171,9 +194,8 @@ void TfLiteDriver::Expectation::SetData<string>(const string& csv_values) {
|
||||
memcpy(data_.raw, s.data(), s.size());
|
||||
}
|
||||
|
||||
template <>
|
||||
bool TfLiteDriver::Expectation::TypedCheck<string>(bool verbose,
|
||||
const TfLiteTensor& tensor) {
|
||||
bool TfLiteDriver::Expectation::TypedCheckString(bool verbose,
|
||||
const TfLiteTensor& tensor) {
|
||||
if (tensor.data.raw == nullptr) {
|
||||
if (verbose) {
|
||||
std::cerr << " got empty string" << std::endl;
|
||||
@ -215,19 +237,22 @@ bool TfLiteDriver::Expectation::Check(bool verbose,
|
||||
const TfLiteTensor& tensor) {
|
||||
switch (tensor.type) {
|
||||
case kTfLiteFloat32:
|
||||
return TypedCheck<float>(verbose, tensor);
|
||||
return TypedCheck<float, float>(verbose, tensor);
|
||||
case kTfLiteInt32:
|
||||
return TypedCheck<int32_t>(verbose, tensor);
|
||||
return TypedCheck<int32_t, float>(verbose, tensor);
|
||||
case kTfLiteInt64:
|
||||
return TypedCheck<int64_t>(verbose, tensor);
|
||||
return TypedCheck<int64_t, float>(verbose, tensor);
|
||||
case kTfLiteUInt8:
|
||||
return TypedCheck<uint8_t>(verbose, tensor);
|
||||
return TypedCheck<uint8_t, float>(verbose, tensor);
|
||||
case kTfLiteInt8:
|
||||
return TypedCheck<int8_t>(verbose, tensor);
|
||||
return TypedCheck<int8_t, float>(verbose, tensor);
|
||||
case kTfLiteBool:
|
||||
return TypedCheck<bool>(verbose, tensor);
|
||||
return TypedCheck<bool, float>(verbose, tensor);
|
||||
case kTfLiteString:
|
||||
return TypedCheck<string>(verbose, tensor);
|
||||
return TypedCheckString(verbose, tensor);
|
||||
case kTfLiteComplex64:
|
||||
return TypedCheck<std::complex<float>, std::complex<float>>(verbose,
|
||||
tensor);
|
||||
default:
|
||||
fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
|
||||
return false;
|
||||
@ -243,6 +268,10 @@ TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name,
|
||||
resolver_.reset(new ops::builtin::BuiltinRefOpResolver);
|
||||
} else {
|
||||
resolver_.reset(new ops::builtin::BuiltinOpResolver);
|
||||
ops::builtin::BuiltinOpResolver* buildinop_resolver_ =
|
||||
reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
|
||||
buildinop_resolver_->AddCustom("RFFT2D",
|
||||
tflite::ops::custom::Register_RFFT2D());
|
||||
}
|
||||
|
||||
if (delegate_name == "FLEX") {
|
||||
@ -402,6 +431,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
|
||||
case kTfLiteString:
|
||||
expected_output_[id]->SetData<string>(csv_values);
|
||||
break;
|
||||
case kTfLiteComplex64:
|
||||
expected_output_[id]->SetData<std::complex<float>>(csv_values);
|
||||
break;
|
||||
default:
|
||||
Invalidate(absl::StrCat("Unsupported tensor type ",
|
||||
TfLiteTypeGetName(tensor->type),
|
||||
|
@ -36,7 +36,7 @@ ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_
|
||||
NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip"
|
||||
FARMHASH_URL="http://mirror.tensorflow.org/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz"
|
||||
FLATBUFFERS_URL="http://mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.11.0.tar.gz"
|
||||
FFT2D_URL="http://mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz"
|
||||
FFT2D_URL="http://mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz"
|
||||
|
||||
# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,
|
||||
# so work around it by patching the source.
|
||||
|
@ -181,6 +181,11 @@ tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
|
||||
tensorflow/third_party/llvm/llvm.bzl
|
||||
tensorflow/third_party/icu/udata.patch
|
||||
tensorflow/third_party/fft2d/fft2d.h
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
tensorflow/third_party/fft2d/fft.h
|
||||
tensorflow/third_party/fft2d/LICENSE
|
||||
tensorflow/third_party/fft2d/fft2d.BUILD
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
@ -188,10 +193,6 @@ tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/archive.patch
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
tensorflow/third_party/fft2d/fft.h
|
||||
tensorflow/third_party/fft2d/LICENSE
|
||||
tensorflow/third_party/fft2d/fft2d.BUILD
|
||||
tensorflow/third_party/boringssl/BUILD
|
||||
tensorflow/third_party/mpi/.gitignore
|
||||
tensorflow/third_party/mpi/BUILD
|
||||
|
@ -148,7 +148,7 @@ genrule(
|
||||
"@double_conversion//:LICENSE",
|
||||
"@eigen_archive//:COPYING.MPL2",
|
||||
"@farmhash_archive//:COPYING",
|
||||
"@fft2d//:fft/readme.txt",
|
||||
"@fft2d//:fft2d/readme2d.txt",
|
||||
"@gemmlowp//:LICENSE",
|
||||
"@gif_archive//:COPYING",
|
||||
"@highwayhash//:LICENSE",
|
||||
@ -219,7 +219,7 @@ genrule(
|
||||
"@double_conversion//:LICENSE",
|
||||
"@eigen_archive//:COPYING.MPL2",
|
||||
"@farmhash_archive//:COPYING",
|
||||
"@fft2d//:fft/readme.txt",
|
||||
"@fft2d//:fft2d/readme2d.txt",
|
||||
"@gemmlowp//:LICENSE",
|
||||
"@gif_archive//:COPYING",
|
||||
"@highwayhash//:LICENSE",
|
||||
|
@ -169,7 +169,7 @@ filegroup(
|
||||
"@eigen_archive//:COPYING.MPL2",
|
||||
"@enum34_archive//:LICENSE",
|
||||
"@farmhash_archive//:COPYING",
|
||||
"@fft2d//:fft/readme.txt",
|
||||
"@fft2d//:fft2d/readme2d.txt",
|
||||
"@flatbuffers//:LICENSE.txt",
|
||||
"@gast_archive//:PKG-INFO",
|
||||
"@gemmlowp//:LICENSE",
|
||||
|
@ -580,10 +580,10 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
|
||||
tf_http_archive(
|
||||
name = "fft2d",
|
||||
build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"),
|
||||
sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296",
|
||||
sha256 = "ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9",
|
||||
urls = [
|
||||
"http://mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
|
||||
"http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
|
||||
"http://mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz",
|
||||
"http://www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz",
|
||||
],
|
||||
)
|
||||
|
||||
|
17
third_party/fft2d/BUILD
vendored
17
third_party/fft2d/BUILD
vendored
@ -1,5 +1,5 @@
|
||||
# Headers for 2D Fast Fourier Transform package
|
||||
# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html
|
||||
# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html
|
||||
# This is a separate package because the original downloaded archive doesn't
|
||||
# contain any header files.
|
||||
|
||||
@ -15,18 +15,27 @@ exports_files(["LICENSE"])
|
||||
|
||||
cc_library(
|
||||
name = "fft2d_headers",
|
||||
srcs = ["fft.h"],
|
||||
srcs = [
|
||||
"fft.h",
|
||||
"fft2d.h",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "fft2d_headersd_ios",
|
||||
srcs = ["fft.h"],
|
||||
srcs = [
|
||||
"fft.h",
|
||||
"fft2d.h",
|
||||
],
|
||||
)
|
||||
|
||||
# Export the source code so that it could be compiled for Andoid native apps.
|
||||
filegroup(
|
||||
name = "fft2d_headers_srcs",
|
||||
srcs = ["fft.h"],
|
||||
srcs = [
|
||||
"fft.h",
|
||||
"fft2d.h",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
|
2
third_party/fft2d/fft.h
vendored
2
third_party/fft2d/fft.h
vendored
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Declarations for 1D FFT routines in third_party/fft2d/fft.
|
||||
// Declarations for 1D FFT routines in third_party/fft2d/fft2d.
|
||||
|
||||
#ifndef FFT2D_FFT_H__
|
||||
#define FFT2D_FFT_H__
|
||||
|
9
third_party/fft2d/fft2d.BUILD
vendored
9
third_party/fft2d/fft2d.BUILD
vendored
@ -1,5 +1,5 @@
|
||||
# 2D Fast Fourier Transform package
|
||||
# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html
|
||||
# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
@ -8,10 +8,11 @@ package(
|
||||
# Unrestricted use; can only distribute original package.
|
||||
licenses(["notice"])
|
||||
|
||||
exports_files(["fft/readme.txt"])
|
||||
exports_files(["fft2d/readme2d.txt"])
|
||||
|
||||
FFT2D_SRCS = [
|
||||
"fft/fftsg.c",
|
||||
"fft2d/fftsg.c",
|
||||
"fft2d/fftsg2d.c",
|
||||
]
|
||||
|
||||
config_setting(
|
||||
@ -22,7 +23,7 @@ config_setting(
|
||||
# This is the main 2D FFT library. The 2D FFTs in this library call
|
||||
# 1D FFTs. In addition, fast DCTs are provided for the special case
|
||||
# of 8x8 and 16x16. This code in this library is referred to as
|
||||
# "Version II" on http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html.
|
||||
# "Version II" on http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html.
|
||||
cc_library(
|
||||
name = "fft2d",
|
||||
srcs = FFT2D_SRCS,
|
||||
|
36
third_party/fft2d/fft2d.h
vendored
Normal file
36
third_party/fft2d/fft2d.h
vendored
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Declarations for 2D FFT routines in third_party/fft2d/fft2d.
|
||||
|
||||
#ifndef FFT2D_FFT_H__
|
||||
#define FFT2D_FFT_H__
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
extern void cdft2d(int, int, int, double **, double *, int *, double *);
|
||||
extern void rdft2d(int, int, int, double **, double *, int *, double *);
|
||||
extern void ddct2d(int, int, int, double **, double *, int *, double *);
|
||||
extern void ddst2d(int, int, int, double **, double *, int *, double *);
|
||||
extern void ddct8x8s(int isgn, double **a);
|
||||
extern void ddct16x16s(int isgn, double **a);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // FFT2D_FFT_H__
|
Loading…
Reference in New Issue
Block a user