Add complex<double> tensor support in TFLite

Even though we do not support complex<double> op kernels on mobile, it is
inevitable to support complex<double> tensors in order to enable TF
complex<double> ops via flex delegate.

This CL enables the complex<double> tensor type in MLIR converter only.

PiperOrigin-RevId: 321072365
Change-Id: I5ecd631339b3d5e00b3d999b9f2c6102b554cea5
This commit is contained in:
Jaesung Chung 2020-07-13 18:08:07 -07:00 committed by TensorFlower Gardener
parent afa9b26c2f
commit 3749694080
33 changed files with 217 additions and 6 deletions

View File

@ -149,6 +149,9 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
if (ftype && ftype.isF32()) {
return tflite::TensorType_COMPLEX64;
}
if (ftype && ftype.isF64()) {
return tflite::TensorType_COMPLEX128;
}
return Status(error::INVALID_ARGUMENT, "Unsupported type");
}
case mlir::StandardTypes::Integer: {

View File

@ -123,6 +123,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
return DT_BOOL;
case toco::IODataType::COMPLEX64:
return DT_COMPLEX64;
case toco::IODataType::COMPLEX128:
return DT_COMPLEX128;
default:
return DT_INVALID;
}

View File

@ -15,6 +15,13 @@ func @complex64() -> tensor<4xcomplex<f32>> {
return %0 : tensor<4xcomplex<f32>>
}
func @complex128() -> tensor<4xcomplex<f64>> {
// CHECK-LABEL: @complex128
// CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex<f64>>
%0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex<f64>> } : () -> tensor<4xcomplex<f64>>
return %0 : tensor<4xcomplex<f64>>
}
// TODO(b/138847107) this should work but doesn't
// func @f16() -> tensor<4xf16> {
// %0 = "tfl.pseudo_const"() { value = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf16> } : () -> tensor<4xf16>

View File

@ -0,0 +1,66 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck %s
func @main(tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> {
^bb0(%arg0: tensor<4xcomplex<f64>>, %arg1: tensor<4xcomplex<f64>>):
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "FlexAdd"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: COMPLEX128,
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: COMPLEX128,
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: COMPLEX128,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "add",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: custom_options: [ 3, 65, 100, 100, 0, 20, 18, 3, 65, 100, 100, 26, 0, 26, 0, 42, 7, 10, 1, 84, 18, 2, 48, 18, 50, 0, 0, 2, 27, 23, 20, 20, 4, 40, 1 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT:}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> loc("add")
return %0 : tensor<4xcomplex<f64>>
}

View File

@ -53,6 +53,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
return builder.getIntegerType(16);
case tflite::TensorType_COMPLEX64:
return mlir::ComplexType::get(builder.getF32Type());
case tflite::TensorType_COMPLEX128:
return mlir::ComplexType::get(builder.getF64Type());
case tflite::TensorType_INT8:
return builder.getIntegerType(8);
}
@ -64,6 +66,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
return tensorflow::DT_BOOL;
case tflite::TensorType_COMPLEX64:
return tensorflow::DT_COMPLEX64;
case tflite::TensorType_COMPLEX128:
return tensorflow::DT_COMPLEX128;
case tflite::TensorType_FLOAT16:
return tensorflow::DT_HALF;
case tflite::TensorType_FLOAT32:

View File

@ -207,6 +207,8 @@ const char* TfLiteTypeGetName(TfLiteType type) {
return "BOOL";
case kTfLiteComplex64:
return "COMPLEX64";
case kTfLiteComplex128:
return "COMPLEX128";
case kTfLiteString:
return "STRING";
case kTfLiteFloat16:

View File

@ -238,6 +238,11 @@ typedef struct TfLiteComplex64 {
float re, im; // real and imaginary parts, respectively.
} TfLiteComplex64;
// Double-precision complex data type compatible with the C99 definition.
typedef struct TfLiteComplex128 {
double re, im; // real and imaginary parts, respectively.
} TfLiteComplex128;
// Half precision data type compatible with the C99 definition.
typedef struct TfLiteFloat16 {
uint16_t data;
@ -257,6 +262,7 @@ typedef enum {
kTfLiteInt8 = 9,
kTfLiteFloat16 = 10,
kTfLiteFloat64 = 11,
kTfLiteComplex128 = 12,
} TfLiteType;
// Return the name of a given type, for error reporting purposes.
@ -313,12 +319,14 @@ typedef union TfLitePtrUnion {
int64_t* i64;
float* f;
TfLiteFloat16* f16;
double* f64;
char* raw;
const char* raw_const;
uint8_t* uint8;
bool* b;
int16_t* i16;
TfLiteComplex64* c64;
TfLiteComplex128* c128;
int8_t* int8;
/* Only use this member. */
void* data;

View File

@ -78,6 +78,7 @@ TEST(Types, TestTypeNames) {
return std::string(TfLiteTypeGetName(t));
};
EXPECT_EQ(type_name(kTfLiteNoType), "NOTYPE");
EXPECT_EQ(type_name(kTfLiteFloat64), "FLOAT64");
EXPECT_EQ(type_name(kTfLiteFloat32), "FLOAT32");
EXPECT_EQ(type_name(kTfLiteFloat16), "FLOAT16");
EXPECT_EQ(type_name(kTfLiteInt16), "INT16");
@ -87,6 +88,7 @@ TEST(Types, TestTypeNames) {
EXPECT_EQ(type_name(kTfLiteInt64), "INT64");
EXPECT_EQ(type_name(kTfLiteBool), "BOOL");
EXPECT_EQ(type_name(kTfLiteComplex64), "COMPLEX64");
EXPECT_EQ(type_name(kTfLiteComplex128), "COMPLEX128");
EXPECT_EQ(type_name(kTfLiteString), "STRING");
}

View File

@ -863,6 +863,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
case TensorType_COMPLEX64:
*type = kTfLiteComplex64;
return kTfLiteOk;
case TensorType_COMPLEX128:
*type = kTfLiteComplex128;
return kTfLiteOk;
default:
*type = kTfLiteNoType;
TF_LITE_REPORT_ERROR(error_reporter,

View File

@ -76,6 +76,8 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
return TF_INT64;
case kTfLiteComplex64:
return TF_COMPLEX64;
case kTfLiteComplex128:
return TF_COMPLEX128;
case kTfLiteString:
return TF_STRING;
case kTfLiteBool:
@ -89,6 +91,8 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) {
return kTfLiteFloat32;
case TF_HALF:
return kTfLiteFloat16;
case TF_DOUBLE:
return kTfLiteFloat64;
case TF_INT16:
return kTfLiteInt16;
case TF_INT32:
@ -101,6 +105,8 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) {
return kTfLiteInt64;
case TF_COMPLEX64:
return kTfLiteComplex64;
case TF_COMPLEX128:
return kTfLiteComplex128;
case TF_STRING:
return kTfLiteString;
case TF_BOOL:

View File

@ -109,22 +109,28 @@ TEST(UtilTest, CopyShapeAndType) {
TEST(UtilTest, TypeConversionsFromTFLite) {
EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType));
EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32));
EXPECT_EQ(TF_HALF, GetTensorFlowDataType(kTfLiteFloat16));
EXPECT_EQ(TF_DOUBLE, GetTensorFlowDataType(kTfLiteFloat64));
EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16));
EXPECT_EQ(TF_INT32, GetTensorFlowDataType(kTfLiteInt32));
EXPECT_EQ(TF_UINT8, GetTensorFlowDataType(kTfLiteUInt8));
EXPECT_EQ(TF_INT64, GetTensorFlowDataType(kTfLiteInt64));
EXPECT_EQ(TF_COMPLEX64, GetTensorFlowDataType(kTfLiteComplex64));
EXPECT_EQ(TF_COMPLEX128, GetTensorFlowDataType(kTfLiteComplex128));
EXPECT_EQ(TF_STRING, GetTensorFlowDataType(kTfLiteString));
EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool));
}
TEST(UtilTest, TypeConversionsFromTensorFlow) {
EXPECT_EQ(kTfLiteFloat16, GetTensorFlowLiteType(TF_HALF));
EXPECT_EQ(kTfLiteFloat32, GetTensorFlowLiteType(TF_FLOAT));
EXPECT_EQ(kTfLiteFloat64, GetTensorFlowLiteType(TF_DOUBLE));
EXPECT_EQ(kTfLiteInt16, GetTensorFlowLiteType(TF_INT16));
EXPECT_EQ(kTfLiteInt32, GetTensorFlowLiteType(TF_INT32));
EXPECT_EQ(kTfLiteUInt8, GetTensorFlowLiteType(TF_UINT8));
EXPECT_EQ(kTfLiteInt64, GetTensorFlowLiteType(TF_INT64));
EXPECT_EQ(kTfLiteComplex64, GetTensorFlowLiteType(TF_COMPLEX64));
EXPECT_EQ(kTfLiteComplex128, GetTensorFlowLiteType(TF_COMPLEX128));
EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING));
EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL));
EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE));

View File

@ -405,7 +405,9 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_
case kTfLiteNoType:
case kTfLiteString:
case kTfLiteComplex64:
// kTfLiteString and kTfLiteComplex64 are not supported in TensorFlow Lite Objc API.
case kTfLiteComplex128:
// kTfLiteString, kTfLiteComplex64 and kTfLiteComplex128 are not supported in TensorFlow Lite
// Objc API.
return TFLTensorDataTypeNoType;
}
}

View File

@ -82,6 +82,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
return TensorType_INT16;
case kTfLiteComplex64:
return TensorType_COMPLEX64;
case kTfLiteComplex128:
return TensorType_COMPLEX128;
}
// TODO(aselle): consider an error
}

View File

@ -72,6 +72,9 @@ TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) {
case kTfLiteComplex64:
*size = sizeof(float) * 2;
break;
case kTfLiteComplex128:
*size = sizeof(double) * 2;
break;
default:
return kTfLiteError;
}

View File

@ -141,6 +141,10 @@ TF_LITE_MICRO_TEST(TestTypeSizeOf) {
tflite::TfLiteTypeSizeOf(kTfLiteComplex64, &size));
TF_LITE_MICRO_EXPECT_EQ(sizeof(float) * 2, size);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::TfLiteTypeSizeOf(kTfLiteComplex128, &size));
TF_LITE_MICRO_EXPECT_EQ(sizeof(double) * 2, size);
TF_LITE_MICRO_EXPECT_NE(
kTfLiteOk, tflite::TfLiteTypeSizeOf(static_cast<TfLiteType>(-1), &size));
}

View File

@ -162,6 +162,9 @@ void MicroInterpreter::CorrectTensorEndianness(TfLiteTensor* tensorCorr) {
case TfLiteType::kTfLiteComplex64:
CorrectTensorDataEndianness(tensorCorr->data.c64, tensorSize);
break;
case TfLiteType::kTfLiteComplex128:
CorrectTensorDataEndianness(tensorCorr->data.c128, tensorSize);
break;
default:
// Do nothing for other data types.
break;

View File

@ -85,6 +85,8 @@ const char* TensorTypeName(TfLiteType type) {
return "kTfLiteInt16";
case kTfLiteComplex64:
return "kTfLiteComplex64";
case kTfLiteComplex128:
return "kTfLiteComplex128";
case kTfLiteFloat16:
return "kTfLiteFloat16";
case kTfLiteFloat64:

View File

@ -57,6 +57,8 @@ const char* TensorTypeName(TfLiteType type) {
return "kTfLiteInt16";
case kTfLiteComplex64:
return "kTfLiteComplex64";
case kTfLiteComplex128:
return "kTfLiteComplex128";
case kTfLiteFloat16:
return "kTfLiteFloat16";
case kTfLiteFloat64:

View File

@ -56,6 +56,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
return NPY_BOOL;
case kTfLiteComplex64:
return NPY_COMPLEX64;
case kTfLiteComplex128:
return NPY_COMPLEX128;
case kTfLiteNoType:
return NPY_NOTYPE;
// Avoid default so compiler errors created when new types are made.

View File

@ -86,6 +86,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
return TensorType_INT16;
case kTfLiteComplex64:
return TensorType_COMPLEX64;
case kTfLiteComplex128:
return TensorType_COMPLEX128;
}
// No default to get compiler error when new type is introduced.
}

View File

@ -51,6 +51,7 @@ _MAP_TF_TO_TFLITE_TYPES = {
dtypes.int8: _types_pb2.INT8,
dtypes.int16: _types_pb2.QUANTIZED_INT16,
dtypes.complex64: _types_pb2.COMPLEX64,
dtypes.complex128: _types_pb2.COMPLEX128,
dtypes.bool: _types_pb2.BOOL,
}

View File

@ -41,6 +41,7 @@ enum TensorType : byte {
COMPLEX64 = 8,
INT8 = 9,
FLOAT64 = 10,
COMPLEX128 = 11,
}
// Custom quantization parameters for experimenting with new quantization

View File

@ -379,11 +379,12 @@ enum TensorType {
TensorType_COMPLEX64 = 8,
TensorType_INT8 = 9,
TensorType_FLOAT64 = 10,
TensorType_COMPLEX128 = 11,
TensorType_MIN = TensorType_FLOAT32,
TensorType_MAX = TensorType_FLOAT64
TensorType_MAX = TensorType_COMPLEX128
};
inline const TensorType (&EnumValuesTensorType())[11] {
inline const TensorType (&EnumValuesTensorType())[12] {
static const TensorType values[] = {
TensorType_FLOAT32,
TensorType_FLOAT16,
@ -395,13 +396,14 @@ inline const TensorType (&EnumValuesTensorType())[11] {
TensorType_INT16,
TensorType_COMPLEX64,
TensorType_INT8,
TensorType_FLOAT64
TensorType_FLOAT64,
TensorType_COMPLEX128
};
return values;
}
inline const char * const *EnumNamesTensorType() {
static const char * const names[12] = {
static const char * const names[13] = {
"FLOAT32",
"FLOAT16",
"INT32",
@ -413,13 +415,14 @@ inline const char * const *EnumNamesTensorType() {
"COMPLEX64",
"INT8",
"FLOAT64",
"COMPLEX128",
nullptr
};
return names;
}
inline const char *EnumNameTensorType(TensorType e) {
if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_FLOAT64)) return "";
if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_COMPLEX128)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesTensorType()[index];
}

View File

@ -132,6 +132,26 @@ inline std::vector<std::complex<float>> Split(const string& s,
return fields;
}
template <>
inline std::vector<std::complex<double>> Split(const string& s,
const string& delimiter) {
std::vector<std::complex<double>> fields;
for (const auto& p : SplitToPos(s, delimiter)) {
std::string sc = s.substr(p.first, p.second - p.first);
std::string::size_type sz_real, sz_img;
double real = std::stod(sc, &sz_real);
double img = std::stod(sc.substr(sz_real), &sz_img);
if (sz_real + sz_img + 1 != sc.length()) {
std::cerr << "There were errors in parsing string, " << sc
<< ", to complex value." << std::endl;
return fields;
}
std::complex<double> c(real, img);
fields.push_back(c);
}
return fields;
}
} // namespace testing
} // namespace tflite

View File

@ -127,15 +127,37 @@ class TfLiteDriver::DataExpectation {
return error_is_large;
}
bool CompareTwoValuesHelper(double v1, double v2) {
double diff = std::abs(v1 - v2);
bool error_is_large = false;
// For very small numbers, try absolute error, otherwise go with
// relative.
if (std::abs(v2) < relative_threshold_) {
error_is_large = (diff > absolute_threshold_);
} else {
error_is_large = (diff > relative_threshold_ * std::abs(v2));
}
return error_is_large;
}
bool CompareTwoValues(std::complex<float> v1, std::complex<float> v2) {
return CompareTwoValues(v1.real(), v2.real()) ||
CompareTwoValues(v1.imag(), v2.imag());
}
bool CompareTwoValues(std::complex<double> v1, std::complex<double> v2) {
return CompareTwoValues(v1.real(), v2.real()) ||
CompareTwoValues(v1.imag(), v2.imag());
}
bool CompareTwoValues(float v1, float v2) {
return CompareTwoValuesHelper(v1, v2);
}
bool CompareTwoValues(double v1, double v2) {
return CompareTwoValuesHelper(v1, v2);
}
template <typename T, typename TS>
bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
size_t tensor_size = tensor.bytes / sizeof(T);
@ -315,6 +337,9 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose,
case kTfLiteComplex64:
return TypedCheck<std::complex<float>, std::complex<float>>(verbose,
tensor);
case kTfLiteComplex128:
return TypedCheck<std::complex<double>, std::complex<double>>(verbose,
tensor);
default:
fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
return false;
@ -527,6 +552,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
case kTfLiteComplex64:
expected_output_[id]->SetData<std::complex<float>>(csv_values);
break;
case kTfLiteComplex128:
expected_output_[id]->SetData<std::complex<double>>(csv_values);
break;
default:
Invalidate(absl::StrCat("Unsupported tensor type ",
TfLiteTypeGetName(tensor->type),

View File

@ -236,6 +236,7 @@ enum class ArrayDataType : uint8 {
kComplex64,
kFloat16,
kFloat64,
kComplex128,
};
// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type

View File

@ -51,6 +51,7 @@ namespace tflite {
{ArrayDataType::kInt64, ::tflite::TensorType_INT64},
{ArrayDataType::kString, ::tflite::TensorType_STRING},
{ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64},
{ArrayDataType::kComplex128, ::tflite::TensorType_COMPLEX128},
{ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16},
{ArrayDataType::kFloat64, ::tflite::TensorType_FLOAT64}};

View File

@ -1769,6 +1769,8 @@ int ElementSize(ArrayDataType data_type) {
return 8;
case ArrayDataType::kComplex64:
return 8;
case ArrayDataType::kComplex128:
return 16;
case ArrayDataType::kFloat64:
return 8;
@ -2313,6 +2315,8 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
return ArrayDataType::kString;
case COMPLEX64:
return ArrayDataType::kComplex64;
case COMPLEX128:
return ArrayDataType::kComplex128;
case FLOAT16:
return ArrayDataType::kFloat16;
case FLOAT64:

View File

@ -52,4 +52,7 @@ enum IODataType {
// Double precision float, not quantized.
FLOAT64 = 11;
// Complex128, not quantized
COMPLEX128 = 12;
}

View File

@ -238,6 +238,11 @@ typedef struct TfLiteComplex64 {
float re, im; // real and imaginary parts, respectively.
} TfLiteComplex64;
// Double-precision complex data type compatible with the C99 definition.
typedef struct TfLiteComplex128 {
double re, im; // real and imaginary parts, respectively.
} TfLiteComplex128;
// Half precision data type compatible with the C99 definition.
typedef struct TfLiteFloat16 {
uint16_t data;
@ -257,6 +262,7 @@ typedef enum {
kTfLiteInt8 = 9,
kTfLiteFloat16 = 10,
kTfLiteFloat64 = 11,
kTfLiteComplex128 = 12,
} TfLiteType;
// Return the name of a given type, for error reporting purposes.
@ -313,12 +319,14 @@ typedef union TfLitePtrUnion {
int64_t* i64;
float* f;
TfLiteFloat16* f16;
double* f64;
char* raw;
const char* raw_const;
uint8_t* uint8;
bool* b;
int16_t* i16;
TfLiteComplex64* c64;
TfLiteComplex128* c128;
int8_t* int8;
/* Only use this member. */
void* data;

View File

@ -384,6 +384,9 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer,
case TensorType_COMPLEX64:
bytes_required *= sizeof(std::complex<float>);
break;
case TensorType_COMPLEX128:
bytes_required *= sizeof(std::complex<double>);
break;
default:
ReportError(error_reporter, "Tensor %s invalid type: %d",
NameOrEmptyString(tensor.name()), tensor.type());

View File

@ -67,6 +67,10 @@ constexpr TfLiteType typeToTfLiteType<std::complex<float>>() {
return kTfLiteComplex64;
}
template <>
constexpr TfLiteType typeToTfLiteType<std::complex<double>>() {
return kTfLiteComplex128;
}
template <>
constexpr TfLiteType typeToTfLiteType<std::string>() {
return kTfLiteString;
}

View File

@ -102,6 +102,9 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
case kTfLiteComplex64:
*bytes = sizeof(std::complex<float>);
break;
case kTfLiteComplex128:
*bytes = sizeof(std::complex<double>);
break;
case kTfLiteInt16:
*bytes = sizeof(int16_t);
break;