Add real/imag custom ops. The ops will be migrated to builtin ops soon.

PiperOrigin-RevId: 327742515
Change-Id: I0699f469c98270bf895cb1b8826fcc5a2c6fdd46
This commit is contained in:
Renjie Liu 2020-08-20 19:36:30 -07:00 committed by TensorFlower Gardener
parent c37ef0c195
commit 89cbc0882f
4 changed files with 335 additions and 5 deletions

View File

@ -697,16 +697,16 @@ cc_test(
cc_library(
name = "custom_ops",
srcs = ["rfft2d.cc"],
srcs = [
"complex_support.cc",
"rfft2d.cc",
],
hdrs = ["custom_ops_register.h"],
copts = tflite_copts(),
deps = [
":kernel_util",
":op_macros",
"//tensorflow/lite:context",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/kernels/internal:optimized_base",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/kernels/internal:types",
"//third_party/fft2d:fft2d_headers",
@ -2288,4 +2288,19 @@ cc_test(
],
)
cc_test(
name = "complex_support_test",
srcs = ["complex_support_test.cc"],
deps = [
":custom_ops",
":test_main",
":test_util",
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
"@flatbuffers",
],
)
tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]})

View File

@ -0,0 +1,146 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <complex>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
// TODO(b/165735381): Promote this op to builtin-op when we can add new builtin
// ops.
namespace tflite {
namespace ops {
namespace custom {
namespace complex {
static const int kInputTensor = 0;
static const int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input->type == kTfLiteComplex64 ||
input->type == kTfLiteComplex128);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (input->type == kTfLiteComplex64) {
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
} else {
TF_LITE_ENSURE(context, output->type = kTfLiteFloat64);
}
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
return context->ResizeTensor(context, output, output_shape);
}
template <typename T, typename ExtractF>
void ExtractData(const TfLiteTensor* input, ExtractF extract_func,
TfLiteTensor* output) {
const std::complex<T>* input_data = GetTensorData<std::complex<T>>(input);
T* output_data = GetTensorData<T>(output);
const int input_size = NumElements(input);
for (int i = 0; i < input_size; ++i) {
*output_data++ = extract_func(*input_data++);
}
}
TfLiteStatus EvalReal(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input->type) {
case kTfLiteComplex64: {
ExtractData<float>(
input,
static_cast<float (*)(const std::complex<float>&)>(std::real<float>),
output);
break;
}
case kTfLiteComplex128: {
ExtractData<double>(input,
static_cast<double (*)(const std::complex<double>&)>(
std::real<double>),
output);
break;
}
default: {
TF_LITE_KERNEL_LOG(context,
"Unsupported input type, Real op only supports "
"complex input, but got: ",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
return kTfLiteOk;
}
TfLiteStatus EvalImag(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input->type) {
case kTfLiteComplex64: {
ExtractData<float>(
input,
static_cast<float (*)(const std::complex<float>&)>(std::imag<float>),
output);
break;
}
case kTfLiteComplex128: {
ExtractData<double>(input,
static_cast<double (*)(const std::complex<double>&)>(
std::imag<double>),
output);
break;
}
default: {
TF_LITE_KERNEL_LOG(context,
"Unsupported input type, Imag op only supports "
"complex input, but got: ",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
return kTfLiteOk;
}
} // namespace complex
TfLiteRegistration* Register_REAL() {
static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
complex::Prepare, complex::EvalReal};
return &r;
}
TfLiteRegistration* Register_IMAG() {
static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
complex::Prepare, complex::EvalImag};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,167 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <complex>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_REAL();
TfLiteRegistration* Register_IMAG();
namespace {
template <typename T>
class RealOpModel : public SingleOpModel {
public:
RealOpModel(const TensorData& input, const TensorData& output) {
input_ = AddInput(input);
output_ = AddOutput(output);
const std::vector<uint8_t> custom_option;
SetCustomOp("Real", custom_option, Register_REAL);
BuildInterpreter({GetShape(input_)});
}
int input() { return input_; }
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
private:
int input_;
int output_;
};
TEST(RealOpTest, SimpleFloatTest) {
RealOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
{TensorType_FLOAT32, {}});
m.PopulateTensor<std::complex<float>>(m.input(), {{75, 0},
{-6, -1},
{9, 0},
{-10, 5},
{-3, 2},
{-6, 11},
{0, 0},
{22.1, 33.3}});
m.Invoke();
EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
{75, -6, 9, -10, -3, -6, 0, 22.1f})));
}
TEST(RealOpTest, SimpleDoubleTest) {
RealOpModel<double> m({TensorType_COMPLEX128, {2, 4}},
{TensorType_FLOAT64, {}});
m.PopulateTensor<std::complex<double>>(m.input(), {{75, 0},
{-6, -1},
{9, 0},
{-10, 5},
{-3, 2},
{-6, 11},
{0, 0},
{22.1, 33.3}});
m.Invoke();
EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
{75, -6, 9, -10, -3, -6, 0, 22.1f})));
}
template <typename T>
class ImagOpModel : public SingleOpModel {
public:
ImagOpModel(const TensorData& input, const TensorData& output) {
input_ = AddInput(input);
output_ = AddOutput(output);
const std::vector<uint8_t> custom_option;
SetCustomOp("Imag", custom_option, Register_IMAG);
BuildInterpreter({GetShape(input_)});
}
int input() { return input_; }
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
private:
int input_;
int output_;
};
TEST(ImagOpTest, SimpleFloatTest) {
ImagOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
{TensorType_FLOAT32, {}});
m.PopulateTensor<std::complex<float>>(m.input(), {{75, 7},
{-6, -1},
{9, 3.5},
{-10, 5},
{-3, 2},
{-6, 11},
{0, 0},
{22.1, 33.3}});
m.Invoke();
EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
{7, -1, 3.5f, 5, 2, 11, 0, 33.3f})));
}
TEST(ImagOpTest, SimpleDoubleTest) {
ImagOpModel<double> m({TensorType_COMPLEX128, {2, 4}},
{TensorType_FLOAT64, {}});
m.PopulateTensor<std::complex<double>>(m.input(), {{75, 7},
{-6, -1},
{9, 3.5},
{-10, 5},
{-3, 2},
{-6, 11},
{0, 0},
{22.1, 33.3}});
m.Invoke();
EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
{7, -1, 3.5f, 5, 2, 11, 0, 33.3f})));
}
} // namespace
} // namespace custom
} // namespace ops
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -26,6 +26,8 @@ TfLiteRegistration* Register_HASHTABLE();
TfLiteRegistration* Register_HASHTABLE_FIND();
TfLiteRegistration* Register_HASHTABLE_IMPORT();
TfLiteRegistration* Register_HASHTABLE_SIZE();
TfLiteRegistration* Register_REAL();
TfLiteRegistration* Register_IMAG();
}
} // namespace ops
} // namespace tflite