Add ComplexAbs Op to TensorFlow Lite

Also promote Real and Imag to builtin ops.
Converter support will be added in a follow-up CL.
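
For context, ComplexAbs maps each complex element z = re + im*i to its magnitude |z| = sqrt(re^2 + im^2); the kernel applies std::abs(std::complex<T>) element-wise, producing a float32 output for complex64 input and a float64 output for complex128 input. A minimal standalone sketch of those semantics (the ComplexAbsReference helper below is illustrative only, not part of this change):

#include <complex>
#include <cstddef>
#include <vector>

// Illustrative reference only: element-wise |z| over a flat buffer, mirroring
// what the ComplexAbs kernel computes through std::abs(std::complex<T>).
template <typename T>
std::vector<T> ComplexAbsReference(const std::vector<std::complex<T>>& input) {
  std::vector<T> output(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    output[i] = std::abs(input[i]);  // sqrt(re^2 + im^2)
  }
  return output;
}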

PiperOrigin-RevId: 356169843
Change-Id: I766ca461a99038f592fe2008b2d4cf86c9496acc
Thai Nguyen 2021-02-07 17:47:41 -08:00 committed by TensorFlower Gardener
parent 221c2d5d53
commit ffe6762608
13 changed files with 167 additions and 27 deletions

View File

@@ -160,6 +160,9 @@ typedef enum {
kTfLiteBuiltinBroadcastTo = 130,
kTfLiteBuiltinRfft2d = 131,
kTfLiteBuiltinConv3d = 132,
kTfLiteBuiltinImag = 133,
kTfLiteBuiltinReal = 134,
kTfLiteBuiltinComplexAbs = 135,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@@ -821,6 +821,9 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_SEGMENT_SUM:
case BuiltinOperator_BROADCAST_TO:
case BuiltinOperator_RFFT2D:
case BuiltinOperator_IMAG:
case BuiltinOperator_REAL:
case BuiltinOperator_COMPLEX_ABS:
return kTfLiteOk;
case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES:
return kTfLiteError;

View File

@@ -558,6 +558,7 @@ BUILTIN_KERNEL_SRCS = [
"cast.cc",
"ceil.cc",
"comparisons.cc",
"complex_support.cc",
"concatenation.cc",
"conv.cc",
"conv3d.cc",
@@ -731,7 +732,6 @@ cc_test(
cc_library(
name = "custom_ops",
srcs = [
"complex_support.cc",
"multinomial.cc",
"random_standard_normal.cc",
],
@@ -2378,7 +2378,6 @@ cc_test(
name = "complex_support_test",
srcs = ["complex_support_test.cc"],
deps = [
":custom_ops",
":test_main",
":test_util",
"//tensorflow/lite:framework",

View File

@@ -43,6 +43,7 @@ TfLiteRegistration* Register_BROADCAST_TO();
TfLiteRegistration* Register_CALL_ONCE();
TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_CEIL();
TfLiteRegistration* Register_COMPLEX_ABS();
TfLiteRegistration* Register_CONCATENATION();
TfLiteRegistration* Register_CONV_2D();
TfLiteRegistration* Register_CONV_3D();
@@ -72,6 +73,7 @@ TfLiteRegistration* Register_GREATER_EQUAL();
TfLiteRegistration* Register_HARD_SWISH();
TfLiteRegistration* Register_HASHTABLE_LOOKUP();
TfLiteRegistration* Register_IF();
TfLiteRegistration* Register_IMAG();
TfLiteRegistration* Register_L2_NORMALIZATION();
TfLiteRegistration* Register_L2_POOL_2D();
TfLiteRegistration* Register_LEAKY_RELU();
@@ -107,6 +109,7 @@ TfLiteRegistration* Register_PRELU();
TfLiteRegistration* Register_QUANTIZE();
TfLiteRegistration* Register_RANGE();
TfLiteRegistration* Register_RANK();
TfLiteRegistration* Register_REAL();
TfLiteRegistration* Register_REDUCE_ANY();
TfLiteRegistration* Register_REDUCE_MAX();
TfLiteRegistration* Register_REDUCE_MIN();

View File

@@ -20,12 +20,9 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
// TODO(b/165735381): Promote this op to builtin-op when we can add new builtin
// ops.
namespace tflite {
namespace ops {
namespace custom {
namespace builtin {
namespace complex {
static const int kInputTensor = 0;
@@ -43,9 +40,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (input->type == kTfLiteComplex64) {
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
} else {
TF_LITE_ENSURE(context, output->type = kTfLiteFloat64);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat64);
}
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
@@ -127,6 +124,37 @@ TfLiteStatus EvalImag(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus EvalAbs(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input->type) {
case kTfLiteComplex64: {
ExtractData<float>(
input,
static_cast<float (*)(const std::complex<float>&)>(std::abs<float>),
output);
break;
}
case kTfLiteComplex128: {
ExtractData<double>(input,
static_cast<double (*)(const std::complex<double>&)>(
std::abs<double>),
output);
break;
}
default: {
TF_LITE_KERNEL_LOG(context,
"Unsupported input type, ComplexAbs op only supports "
"complex input, but got: ",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
return kTfLiteOk;
}
} // namespace complex
TfLiteRegistration* Register_REAL() {
@@ -141,6 +169,12 @@ TfLiteRegistration* Register_IMAG() {
return &r;
}
} // namespace custom
TfLiteRegistration* Register_COMPLEX_ABS() {
static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
complex::Prepare, complex::EvalAbs};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite
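
EvalAbs reuses the file's shared ExtractData helper by passing it a pointer to std::abs; since std::abs is overloaded (and templated for std::complex), the static_cast to an explicit function-pointer type is what selects the complex overload. A self-contained sketch of that pattern, with a hypothetical ApplyExtractor standing in for ExtractData:

#include <complex>
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the kernel's ExtractData helper: applies a
// per-element extractor function to every complex value.
template <typename T>
void ApplyExtractor(const std::vector<std::complex<T>>& input,
                    T (*extractor)(const std::complex<T>&),
                    std::vector<T>* output) {
  output->resize(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    (*output)[i] = extractor(input[i]);
  }
}

int main() {
  std::vector<std::complex<float>> input = {{3.0f, 4.0f}, {0.0f, -1.0f}};
  std::vector<float> output;
  // The cast picks the float(const std::complex<float>&) overload of std::abs,
  // just as EvalAbs does before handing it to ExtractData.
  ApplyExtractor<float>(
      input,
      static_cast<float (*)(const std::complex<float>&)>(std::abs<float>),
      &output);
  std::printf("%f %f\n", output[0], output[1]);  // 5.000000 1.000000
  return 0;
}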

View File

@@ -19,18 +19,11 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_REAL();
TfLiteRegistration* Register_IMAG();
namespace {
template <typename T>
@@ -42,7 +35,7 @@ class RealOpModel : public SingleOpModel {
output_ = AddOutput(output);
const std::vector<uint8_t> custom_option;
SetCustomOp("Real", custom_option, Register_REAL);
SetBuiltinOp(BuiltinOperator_REAL, BuiltinOptions_NONE, 0);
BuildInterpreter({GetShape(input_)});
}
@@ -103,7 +96,7 @@ class ImagOpModel : public SingleOpModel {
output_ = AddOutput(output);
const std::vector<uint8_t> custom_option;
SetCustomOp("Imag", custom_option, Register_IMAG);
SetBuiltinOp(BuiltinOperator_IMAG, BuiltinOptions_NONE, 0);
BuildInterpreter({GetShape(input_)});
}
@@ -155,7 +148,86 @@ TEST(ImagOpTest, SimpleDoubleTest) {
{7, -1, 3.5f, 5, 2, 11, 0, 33.3f})));
}
template <typename T>
class ComplexAbsOpModel : public SingleOpModel {
public:
ComplexAbsOpModel(const TensorData& input, const TensorData& output) {
input_ = AddInput(input);
output_ = AddOutput(output);
const std::vector<uint8_t> custom_option;
SetBuiltinOp(BuiltinOperator_COMPLEX_ABS, BuiltinOptions_NONE, 0);
BuildInterpreter({GetShape(input_)});
}
int input() { return input_; }
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
int input_;
int output_;
};
TEST(ComplexAbsOpTest, IncompatibleType64Test) {
EXPECT_DEATH_IF_SUPPORTED(
ComplexAbsOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
{TensorType_FLOAT64, {}}),
"output->type != kTfLiteFloat32");
}
TEST(ComplexAbsOpTest, IncompatibleType128Test) {
EXPECT_DEATH_IF_SUPPORTED(
ComplexAbsOpModel<float> m({TensorType_COMPLEX128, {2, 4}},
{TensorType_FLOAT32, {}}),
"output->type != kTfLiteFloat64");
}
TEST(ComplexAbsOpTest, SimpleFloatTest) {
ComplexAbsOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
{TensorType_FLOAT32, {}});
m.PopulateTensor<std::complex<float>>(m.input(), {{75, 7},
{-6, -1},
{9, 3.5},
{-10, 5},
{-3, 2},
{-6, 11},
{0, 0},
{22.1, 33.3}});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), testing::ElementsAre(2, 4));
EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
{75.32596f, 6.0827627f, 9.656604f, 11.18034f,
3.6055512f, 12.529964f, 0.f, 39.966236f})));
}
TEST(ComplexAbsOpTest, SimpleDoubleTest) {
ComplexAbsOpModel<double> m({TensorType_COMPLEX128, {2, 4}},
{TensorType_FLOAT64, {}});
m.PopulateTensor<std::complex<double>>(m.input(), {{75, 7},
{-6, -1},
{9, 3.5},
{-10, 5},
{-3, 2},
{-6, 11},
{0, 0},
{22.1, 33.3}});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), testing::ElementsAre(2, 4));
EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
{75.32596f, 6.0827627f, 9.656604f, 11.18034f,
3.6055512f, 12.529964f, 0.f, 39.966236f})));
}
} // namespace
} // namespace custom
} // namespace ops
} // namespace tflite
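
The expectations in SimpleFloatTest and SimpleDoubleTest are simply the element-wise magnitudes of the populated inputs, e.g. |75 + 7i| = sqrt(75^2 + 7^2) = sqrt(5674) ≈ 75.32596. A quick standalone check (illustrative only, not part of the test file):

#include <complex>
#include <cstdio>

int main() {
  const std::complex<float> inputs[] = {{75, 7},  {-6, -1}, {9, 3.5f},
                                        {-10, 5}, {-3, 2},  {-6, 11},
                                        {0, 0},   {22.1f, 33.3f}};
  for (const auto& z : inputs) {
    // Magnitudes matching the ArrayFloatNear expectations above.
    std::printf("%f\n", std::abs(z));
  }
  return 0;
}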

View File

@@ -25,10 +25,8 @@ TfLiteRegistration* Register_HASHTABLE();
TfLiteRegistration* Register_HASHTABLE_FIND();
TfLiteRegistration* Register_HASHTABLE_IMPORT();
TfLiteRegistration* Register_HASHTABLE_SIZE();
TfLiteRegistration* Register_IMAG();
TfLiteRegistration* Register_MULTINOMIAL();
TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL();
TfLiteRegistration* Register_REAL();
} // namespace custom
} // namespace ops

View File

@@ -313,6 +313,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
tflite::ops::builtin::Register_CALL_ONCE());
AddBuiltin(BuiltinOperator_RFFT2D, Register_RFFT2D());
AddBuiltin(BuiltinOperator_CONV_3D, Register_CONV_3D());
AddBuiltin(BuiltinOperator_IMAG, Register_IMAG());
AddBuiltin(BuiltinOperator_REAL, Register_REAL());
AddBuiltin(BuiltinOperator_COMPLEX_ABS, Register_COMPLEX_ABS());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
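
With IMAG, REAL and COMPLEX_ABS now registered here, an interpreter built from BuiltinOpResolver resolves these ops without any AddCustom call. A hedged usage sketch (the model path below is a placeholder, not a file from this change):

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Sketch: build an interpreter for a model that uses the new builtin ops.
TfLiteStatus BuildInterpreterForComplexModel(
    std::unique_ptr<tflite::Interpreter>* interpreter) {
  auto model = tflite::FlatBufferModel::BuildFromFile("complex_model.tflite");
  if (!model) return kTfLiteError;
  // BuiltinOpResolver now includes IMAG, REAL and COMPLEX_ABS.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  return tflite::InterpreterBuilder(*model, resolver)(interpreter);
}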

View File

@@ -158,6 +158,9 @@ TfLiteRegistration* Register_SELECT_V2();
TfLiteRegistration* Register_SEGMENT_SUM();
TfLiteRegistration* Register_BROADCAST_TO();
TfLiteRegistration* Register_CONV_3D();
TfLiteRegistration* Register_IMAG();
TfLiteRegistration* Register_REAL();
TfLiteRegistration* Register_COMPLEX_ABS();
namespace {
@@ -463,6 +466,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
/* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_CONV_3D, Register_CONV_3D());
AddBuiltin(BuiltinOperator_IMAG, Register_IMAG());
AddBuiltin(BuiltinOperator_REAL, Register_REAL());
AddBuiltin(BuiltinOperator_COMPLEX_ABS, Register_COMPLEX_ABS());
AddCustom("NumericVerify",
tflite::ops::custom::Register_NUMERIC_VERIFY_REF());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that

View File

@@ -360,6 +360,9 @@ enum BuiltinOperator : int32 {
BROADCAST_TO = 130,
RFFT2D = 131,
CONV_3D = 132,
IMAG = 133,
REAL = 134,
COMPLEX_ABS = 135,
}
// LINT.ThenChange(nnapi_linter/linter.proto)

View File

@@ -817,11 +817,14 @@ enum BuiltinOperator {
BuiltinOperator_BROADCAST_TO = 130,
BuiltinOperator_RFFT2D = 131,
BuiltinOperator_CONV_3D = 132,
BuiltinOperator_IMAG = 133,
BuiltinOperator_REAL = 134,
BuiltinOperator_COMPLEX_ABS = 135,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_CONV_3D
BuiltinOperator_MAX = BuiltinOperator_COMPLEX_ABS
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[133] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[136] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -955,13 +958,16 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[133] {
BuiltinOperator_CALL_ONCE,
BuiltinOperator_BROADCAST_TO,
BuiltinOperator_RFFT2D,
BuiltinOperator_CONV_3D
BuiltinOperator_CONV_3D,
BuiltinOperator_IMAG,
BuiltinOperator_REAL,
BuiltinOperator_COMPLEX_ABS
};
return values;
}
inline const char * const *EnumNamesBuiltinOperator() {
static const char * const names[134] = {
static const char * const names[137] = {
"ADD",
"AVERAGE_POOL_2D",
"CONCATENATION",
@@ -1095,13 +1101,16 @@ inline const char * const *EnumNamesBuiltinOperator() {
"BROADCAST_TO",
"RFFT2D",
"CONV_3D",
"IMAG",
"REAL",
"COMPLEX_ABS",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_CONV_3D)) return "";
if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_COMPLEX_ABS)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOperator()[index];
}
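
The regenerated tables account for the three new codes: the values array grows from 133 to 136 entries, the names array from 134 to 137 (keeping its trailing nullptr sentinel), and BuiltinOperator_MAX advances to COMPLEX_ABS so the range check above no longer maps the new codes to the empty string. A quick usage sketch of the generated helper:

#include <cstdio>

#include "tensorflow/lite/schema/schema_generated.h"

int main() {
  // Resolves to "COMPLEX_ABS" now that the enum tables include code 135.
  std::printf("%s\n", tflite::EnumNameBuiltinOperator(
                          tflite::BuiltinOperator_COMPLEX_ABS));
  return 0;
}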

View File

@@ -193,6 +193,9 @@ class OpOptionData {
op_to_option_["RSQRT"] = "";
op_to_option_["ELU"] = "";
op_to_option_["REVERSE_SEQUENCE"] = "";
op_to_option_["REAL"] = "";
op_to_option_["IMAG"] = "";
op_to_option_["COMPLEX_ABS"] = "";
// TODO(aselle): These are undesirable hacks. Consider changing C structs
option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";

View File

@@ -343,6 +343,9 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_CALL_ONCE, 1}, kPendingReleaseVersion},
{{BuiltinOperator_RFFT2D, 1}, kPendingReleaseVersion},
{{BuiltinOperator_CONV_3D, 1}, kPendingReleaseVersion},
{{BuiltinOperator_IMAG, 1}, kPendingReleaseVersion},
{{BuiltinOperator_REAL, 1}, kPendingReleaseVersion},
{{BuiltinOperator_COMPLEX_ABS, 1}, kPendingReleaseVersion},
});
std::pair<BuiltinOperator, int> version_key = {op_code, op_version};
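
These entries pin version 1 of IMAG, REAL and COMPLEX_ABS to kPendingReleaseVersion, the placeholder that gets replaced with a concrete version string once the next TFLite release is cut. A hedged sketch of querying the table, assuming the FindMinimumRuntimeVersionForOp declaration in runtime_version.h:

#include <iostream>

#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/versioning/runtime_version.h"

int main() {
  // Minimum runtime required for COMPLEX_ABS at op version 1; while the
  // release is pending this maps to the kPendingReleaseVersion placeholder
  // rather than a concrete version string.
  std::cout << tflite::FindMinimumRuntimeVersionForOp(
                   tflite::BuiltinOperator_COMPLEX_ABS, /*op_version=*/1)
            << std::endl;
  return 0;
}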