Add ComplexAbs Op to TensorFlow Lite
Also promote Real and Image to builtin op. Converter support will be added in a follow-up cl. PiperOrigin-RevId: 356169843 Change-Id: I766ca461a99038f592fe2008b2d4cf86c9496acc
This commit is contained in:
parent
221c2d5d53
commit
ffe6762608
tensorflow/lite
@ -160,6 +160,9 @@ typedef enum {
|
||||
kTfLiteBuiltinBroadcastTo = 130,
|
||||
kTfLiteBuiltinRfft2d = 131,
|
||||
kTfLiteBuiltinConv3d = 132,
|
||||
kTfLiteBuiltinImag = 133,
|
||||
kTfLiteBuiltinReal = 134,
|
||||
kTfLiteBuiltinComplexAbs = 135,
|
||||
} TfLiteBuiltinOperator;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -821,6 +821,9 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
|
||||
case BuiltinOperator_SEGMENT_SUM:
|
||||
case BuiltinOperator_BROADCAST_TO:
|
||||
case BuiltinOperator_RFFT2D:
|
||||
case BuiltinOperator_IMAG:
|
||||
case BuiltinOperator_REAL:
|
||||
case BuiltinOperator_COMPLEX_ABS:
|
||||
return kTfLiteOk;
|
||||
case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES:
|
||||
return kTfLiteError;
|
||||
|
@ -558,6 +558,7 @@ BUILTIN_KERNEL_SRCS = [
|
||||
"cast.cc",
|
||||
"ceil.cc",
|
||||
"comparisons.cc",
|
||||
"complex_support.cc",
|
||||
"concatenation.cc",
|
||||
"conv.cc",
|
||||
"conv3d.cc",
|
||||
@ -731,7 +732,6 @@ cc_test(
|
||||
cc_library(
|
||||
name = "custom_ops",
|
||||
srcs = [
|
||||
"complex_support.cc",
|
||||
"multinomial.cc",
|
||||
"random_standard_normal.cc",
|
||||
],
|
||||
@ -2378,7 +2378,6 @@ cc_test(
|
||||
name = "complex_support_test",
|
||||
srcs = ["complex_support_test.cc"],
|
||||
deps = [
|
||||
":custom_ops",
|
||||
":test_main",
|
||||
":test_util",
|
||||
"//tensorflow/lite:framework",
|
||||
|
@ -43,6 +43,7 @@ TfLiteRegistration* Register_BROADCAST_TO();
|
||||
TfLiteRegistration* Register_CALL_ONCE();
|
||||
TfLiteRegistration* Register_CAST();
|
||||
TfLiteRegistration* Register_CEIL();
|
||||
TfLiteRegistration* Register_COMPLEX_ABS();
|
||||
TfLiteRegistration* Register_CONCATENATION();
|
||||
TfLiteRegistration* Register_CONV_2D();
|
||||
TfLiteRegistration* Register_CONV_3D();
|
||||
@ -72,6 +73,7 @@ TfLiteRegistration* Register_GREATER_EQUAL();
|
||||
TfLiteRegistration* Register_HARD_SWISH();
|
||||
TfLiteRegistration* Register_HASHTABLE_LOOKUP();
|
||||
TfLiteRegistration* Register_IF();
|
||||
TfLiteRegistration* Register_IMAG();
|
||||
TfLiteRegistration* Register_L2_NORMALIZATION();
|
||||
TfLiteRegistration* Register_L2_POOL_2D();
|
||||
TfLiteRegistration* Register_LEAKY_RELU();
|
||||
@ -107,6 +109,7 @@ TfLiteRegistration* Register_PRELU();
|
||||
TfLiteRegistration* Register_QUANTIZE();
|
||||
TfLiteRegistration* Register_RANGE();
|
||||
TfLiteRegistration* Register_RANK();
|
||||
TfLiteRegistration* Register_REAL();
|
||||
TfLiteRegistration* Register_REDUCE_ANY();
|
||||
TfLiteRegistration* Register_REDUCE_MAX();
|
||||
TfLiteRegistration* Register_REDUCE_MIN();
|
||||
|
@ -20,12 +20,9 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
// TODO(b/165735381): Promote this op to builtin-op when we can add new builtin
|
||||
// ops.
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
namespace builtin {
|
||||
namespace complex {
|
||||
|
||||
static const int kInputTensor = 0;
|
||||
@ -43,9 +40,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
if (input->type == kTfLiteComplex64) {
|
||||
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
|
||||
} else {
|
||||
TF_LITE_ENSURE(context, output->type = kTfLiteFloat64);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat64);
|
||||
}
|
||||
|
||||
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
|
||||
@ -127,6 +124,37 @@ TfLiteStatus EvalImag(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus EvalAbs(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
switch (input->type) {
|
||||
case kTfLiteComplex64: {
|
||||
ExtractData<float>(
|
||||
input,
|
||||
static_cast<float (*)(const std::complex<float>&)>(std::abs<float>),
|
||||
output);
|
||||
break;
|
||||
}
|
||||
case kTfLiteComplex128: {
|
||||
ExtractData<double>(input,
|
||||
static_cast<double (*)(const std::complex<double>&)>(
|
||||
std::abs<double>),
|
||||
output);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Unsupported input type, ComplexAbs op only supports "
|
||||
"complex input, but got: ",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace complex
|
||||
|
||||
TfLiteRegistration* Register_REAL() {
|
||||
@ -141,6 +169,12 @@ TfLiteRegistration* Register_IMAG() {
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace custom
|
||||
TfLiteRegistration* Register_COMPLEX_ABS() {
|
||||
static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
|
||||
complex::Prepare, complex::EvalAbs};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
@ -19,18 +19,11 @@ limitations under the License.
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/custom_ops_register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
|
||||
TfLiteRegistration* Register_REAL();
|
||||
TfLiteRegistration* Register_IMAG();
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
@ -42,7 +35,7 @@ class RealOpModel : public SingleOpModel {
|
||||
output_ = AddOutput(output);
|
||||
|
||||
const std::vector<uint8_t> custom_option;
|
||||
SetCustomOp("Real", custom_option, Register_REAL);
|
||||
SetBuiltinOp(BuiltinOperator_REAL, BuiltinOptions_NONE, 0);
|
||||
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
}
|
||||
@ -103,7 +96,7 @@ class ImagOpModel : public SingleOpModel {
|
||||
output_ = AddOutput(output);
|
||||
|
||||
const std::vector<uint8_t> custom_option;
|
||||
SetCustomOp("Imag", custom_option, Register_IMAG);
|
||||
SetBuiltinOp(BuiltinOperator_IMAG, BuiltinOptions_NONE, 0);
|
||||
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
}
|
||||
@ -155,7 +148,86 @@ TEST(ImagOpTest, SimpleDoubleTest) {
|
||||
{7, -1, 3.5f, 5, 2, 11, 0, 33.3f})));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class ComplexAbsOpModel : public SingleOpModel {
|
||||
public:
|
||||
ComplexAbsOpModel(const TensorData& input, const TensorData& output) {
|
||||
input_ = AddInput(input);
|
||||
|
||||
output_ = AddOutput(output);
|
||||
|
||||
const std::vector<uint8_t> custom_option;
|
||||
SetBuiltinOp(BuiltinOperator_COMPLEX_ABS, BuiltinOptions_NONE, 0);
|
||||
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
}
|
||||
|
||||
int input() { return input_; }
|
||||
|
||||
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
|
||||
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
private:
|
||||
int input_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(ComplexAbsOpTest, IncompatibleType64Test) {
|
||||
EXPECT_DEATH_IF_SUPPORTED(
|
||||
ComplexAbsOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
|
||||
{TensorType_FLOAT64, {}}),
|
||||
"output->type != kTfLiteFloat32");
|
||||
}
|
||||
|
||||
TEST(ComplexAbsOpTest, IncompatibleType128Test) {
|
||||
EXPECT_DEATH_IF_SUPPORTED(
|
||||
ComplexAbsOpModel<float> m({TensorType_COMPLEX128, {2, 4}},
|
||||
{TensorType_FLOAT32, {}}),
|
||||
"output->type != kTfLiteFloat64");
|
||||
}
|
||||
|
||||
TEST(ComplexAbsOpTest, SimpleFloatTest) {
|
||||
ComplexAbsOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
|
||||
{TensorType_FLOAT32, {}});
|
||||
|
||||
m.PopulateTensor<std::complex<float>>(m.input(), {{75, 7},
|
||||
{-6, -1},
|
||||
{9, 3.5},
|
||||
{-10, 5},
|
||||
{-3, 2},
|
||||
{-6, 11},
|
||||
{0, 0},
|
||||
{22.1, 33.3}});
|
||||
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), testing::ElementsAre(2, 4));
|
||||
EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
|
||||
{75.32596f, 6.0827627f, 9.656604f, 11.18034f,
|
||||
3.6055512f, 12.529964f, 0.f, 39.966236f})));
|
||||
}
|
||||
|
||||
TEST(ComplexAbsOpTest, SimpleDoubleTest) {
|
||||
ComplexAbsOpModel<double> m({TensorType_COMPLEX128, {2, 4}},
|
||||
{TensorType_FLOAT64, {}});
|
||||
|
||||
m.PopulateTensor<std::complex<double>>(m.input(), {{75, 7},
|
||||
{-6, -1},
|
||||
{9, 3.5},
|
||||
{-10, 5},
|
||||
{-3, 2},
|
||||
{-6, 11},
|
||||
{0, 0},
|
||||
{22.1, 33.3}});
|
||||
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), testing::ElementsAre(2, 4));
|
||||
EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
|
||||
{75.32596f, 6.0827627f, 9.656604f, 11.18034f,
|
||||
3.6055512f, 12.529964f, 0.f, 39.966236f})));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
@ -25,10 +25,8 @@ TfLiteRegistration* Register_HASHTABLE();
|
||||
TfLiteRegistration* Register_HASHTABLE_FIND();
|
||||
TfLiteRegistration* Register_HASHTABLE_IMPORT();
|
||||
TfLiteRegistration* Register_HASHTABLE_SIZE();
|
||||
TfLiteRegistration* Register_IMAG();
|
||||
TfLiteRegistration* Register_MULTINOMIAL();
|
||||
TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL();
|
||||
TfLiteRegistration* Register_REAL();
|
||||
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
|
@ -313,6 +313,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
tflite::ops::builtin::Register_CALL_ONCE());
|
||||
AddBuiltin(BuiltinOperator_RFFT2D, Register_RFFT2D());
|
||||
AddBuiltin(BuiltinOperator_CONV_3D, Register_CONV_3D());
|
||||
AddBuiltin(BuiltinOperator_IMAG, Register_IMAG());
|
||||
AddBuiltin(BuiltinOperator_REAL, Register_REAL());
|
||||
AddBuiltin(BuiltinOperator_COMPLEX_ABS, Register_COMPLEX_ABS());
|
||||
|
||||
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
// custom ops aren't always included by default.
|
||||
|
@ -158,6 +158,9 @@ TfLiteRegistration* Register_SELECT_V2();
|
||||
TfLiteRegistration* Register_SEGMENT_SUM();
|
||||
TfLiteRegistration* Register_BROADCAST_TO();
|
||||
TfLiteRegistration* Register_CONV_3D();
|
||||
TfLiteRegistration* Register_IMAG();
|
||||
TfLiteRegistration* Register_REAL();
|
||||
TfLiteRegistration* Register_COMPLEX_ABS();
|
||||
|
||||
namespace {
|
||||
|
||||
@ -463,6 +466,9 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 3);
|
||||
AddBuiltin(BuiltinOperator_CONV_3D, Register_CONV_3D());
|
||||
AddBuiltin(BuiltinOperator_IMAG, Register_IMAG());
|
||||
AddBuiltin(BuiltinOperator_REAL, Register_REAL());
|
||||
AddBuiltin(BuiltinOperator_COMPLEX_ABS, Register_COMPLEX_ABS());
|
||||
AddCustom("NumericVerify",
|
||||
tflite::ops::custom::Register_NUMERIC_VERIFY_REF());
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
|
@ -360,6 +360,9 @@ enum BuiltinOperator : int32 {
|
||||
BROADCAST_TO = 130,
|
||||
RFFT2D = 131,
|
||||
CONV_3D = 132,
|
||||
IMAG=133,
|
||||
REAL=134,
|
||||
COMPLEX_ABS=135,
|
||||
}
|
||||
// LINT.ThenChange(nnapi_linter/linter.proto)
|
||||
|
||||
|
@ -817,11 +817,14 @@ enum BuiltinOperator {
|
||||
BuiltinOperator_BROADCAST_TO = 130,
|
||||
BuiltinOperator_RFFT2D = 131,
|
||||
BuiltinOperator_CONV_3D = 132,
|
||||
BuiltinOperator_IMAG = 133,
|
||||
BuiltinOperator_REAL = 134,
|
||||
BuiltinOperator_COMPLEX_ABS = 135,
|
||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||
BuiltinOperator_MAX = BuiltinOperator_CONV_3D
|
||||
BuiltinOperator_MAX = BuiltinOperator_COMPLEX_ABS
|
||||
};
|
||||
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[133] {
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[136] {
|
||||
static const BuiltinOperator values[] = {
|
||||
BuiltinOperator_ADD,
|
||||
BuiltinOperator_AVERAGE_POOL_2D,
|
||||
@ -955,13 +958,16 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[133] {
|
||||
BuiltinOperator_CALL_ONCE,
|
||||
BuiltinOperator_BROADCAST_TO,
|
||||
BuiltinOperator_RFFT2D,
|
||||
BuiltinOperator_CONV_3D
|
||||
BuiltinOperator_CONV_3D,
|
||||
BuiltinOperator_IMAG,
|
||||
BuiltinOperator_REAL,
|
||||
BuiltinOperator_COMPLEX_ABS
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
||||
inline const char * const *EnumNamesBuiltinOperator() {
|
||||
static const char * const names[134] = {
|
||||
static const char * const names[137] = {
|
||||
"ADD",
|
||||
"AVERAGE_POOL_2D",
|
||||
"CONCATENATION",
|
||||
@ -1095,13 +1101,16 @@ inline const char * const *EnumNamesBuiltinOperator() {
|
||||
"BROADCAST_TO",
|
||||
"RFFT2D",
|
||||
"CONV_3D",
|
||||
"IMAG",
|
||||
"REAL",
|
||||
"COMPLEX_ABS",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
|
||||
if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_CONV_3D)) return "";
|
||||
if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_COMPLEX_ABS)) return "";
|
||||
const size_t index = static_cast<size_t>(e);
|
||||
return EnumNamesBuiltinOperator()[index];
|
||||
}
|
||||
|
@ -193,6 +193,9 @@ class OpOptionData {
|
||||
op_to_option_["RSQRT"] = "";
|
||||
op_to_option_["ELU"] = "";
|
||||
op_to_option_["REVERSE_SEQUENCE"] = "";
|
||||
op_to_option_["REAL"] = "";
|
||||
op_to_option_["IMAG"] = "";
|
||||
op_to_option_["COMPLEX_ABS"] = "";
|
||||
|
||||
// TODO(aselle): These are undesirable hacks. Consider changing C structs
|
||||
option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
|
||||
|
@ -343,6 +343,9 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
||||
{{BuiltinOperator_CALL_ONCE, 1}, kPendingReleaseVersion},
|
||||
{{BuiltinOperator_RFFT2D, 1}, kPendingReleaseVersion},
|
||||
{{BuiltinOperator_CONV_3D, 1}, kPendingReleaseVersion},
|
||||
{{BuiltinOperator_IMAG, 1}, kPendingReleaseVersion},
|
||||
{{BuiltinOperator_REAL, 1}, kPendingReleaseVersion},
|
||||
{{BuiltinOperator_COMPLEX_ABS, 1}, kPendingReleaseVersion},
|
||||
});
|
||||
|
||||
std::pair<BuiltinOperator, int> version_key = {op_code, op_version};
|
||||
|
Loading…
Reference in New Issue
Block a user