Add complex<double> tensor support in TFLite
Even though we do not support complex<double> op kernels on mobile, supporting complex<double> tensors is necessary to enable TF complex<double> ops via the flex delegate. This CL enables the complex<double> tensor type in the MLIR converter only. PiperOrigin-RevId: 321072365 Change-Id: I5ecd631339b3d5e00b3d999b9f2c6102b554cea5
This commit is contained in:
parent
afa9b26c2f
commit
3749694080
@ -149,6 +149,9 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||
if (ftype && ftype.isF32()) {
|
||||
return tflite::TensorType_COMPLEX64;
|
||||
}
|
||||
if (ftype && ftype.isF64()) {
|
||||
return tflite::TensorType_COMPLEX128;
|
||||
}
|
||||
return Status(error::INVALID_ARGUMENT, "Unsupported type");
|
||||
}
|
||||
case mlir::StandardTypes::Integer: {
|
||||
|
@ -123,6 +123,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||
return DT_BOOL;
|
||||
case toco::IODataType::COMPLEX64:
|
||||
return DT_COMPLEX64;
|
||||
case toco::IODataType::COMPLEX128:
|
||||
return DT_COMPLEX128;
|
||||
default:
|
||||
return DT_INVALID;
|
||||
}
|
||||
|
@ -15,6 +15,13 @@ func @complex64() -> tensor<4xcomplex<f32>> {
|
||||
return %0 : tensor<4xcomplex<f32>>
|
||||
}
|
||||
|
||||
func @complex128() -> tensor<4xcomplex<f64>> {
|
||||
// CHECK-LABEL: @complex128
|
||||
// CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex<f64>>
|
||||
%0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex<f64>> } : () -> tensor<4xcomplex<f64>>
|
||||
return %0 : tensor<4xcomplex<f64>>
|
||||
}
|
||||
|
||||
// TODO(b/138847107) this should work but doesn't
|
||||
// func @f16() -> tensor<4xf16> {
|
||||
// %0 = "tfl.pseudo_const"() { value = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf16> } : () -> tensor<4xf16>
|
||||
|
@ -0,0 +1,66 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
func @main(tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> {
|
||||
^bb0(%arg0: tensor<4xcomplex<f64>>, %arg1: tensor<4xcomplex<f64>>):
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "FlexAdd"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: COMPLEX128,
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: COMPLEX128,
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: COMPLEX128,
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "add",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: custom_options: [ 3, 65, 100, 100, 0, 20, 18, 3, 65, 100, 100, 26, 0, 26, 0, 42, 7, 10, 1, 84, 18, 2, 48, 18, 50, 0, 0, 2, 27, 23, 20, 20, 4, 40, 1 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 4
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> loc("add")
|
||||
return %0 : tensor<4xcomplex<f64>>
|
||||
}
|
@ -53,6 +53,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
|
||||
return builder.getIntegerType(16);
|
||||
case tflite::TensorType_COMPLEX64:
|
||||
return mlir::ComplexType::get(builder.getF32Type());
|
||||
case tflite::TensorType_COMPLEX128:
|
||||
return mlir::ComplexType::get(builder.getF64Type());
|
||||
case tflite::TensorType_INT8:
|
||||
return builder.getIntegerType(8);
|
||||
}
|
||||
@ -64,6 +66,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
|
||||
return tensorflow::DT_BOOL;
|
||||
case tflite::TensorType_COMPLEX64:
|
||||
return tensorflow::DT_COMPLEX64;
|
||||
case tflite::TensorType_COMPLEX128:
|
||||
return tensorflow::DT_COMPLEX128;
|
||||
case tflite::TensorType_FLOAT16:
|
||||
return tensorflow::DT_HALF;
|
||||
case tflite::TensorType_FLOAT32:
|
||||
|
@ -207,6 +207,8 @@ const char* TfLiteTypeGetName(TfLiteType type) {
|
||||
return "BOOL";
|
||||
case kTfLiteComplex64:
|
||||
return "COMPLEX64";
|
||||
case kTfLiteComplex128:
|
||||
return "COMPLEX128";
|
||||
case kTfLiteString:
|
||||
return "STRING";
|
||||
case kTfLiteFloat16:
|
||||
|
@ -238,6 +238,11 @@ typedef struct TfLiteComplex64 {
|
||||
float re, im; // real and imaginary parts, respectively.
|
||||
} TfLiteComplex64;
|
||||
|
||||
// Double-precision complex data type compatible with the C99 definition.
|
||||
typedef struct TfLiteComplex128 {
|
||||
double re, im; // real and imaginary parts, respectively.
|
||||
} TfLiteComplex128;
|
||||
|
||||
// Half precision data type compatible with the C99 definition.
|
||||
typedef struct TfLiteFloat16 {
|
||||
uint16_t data;
|
||||
@ -257,6 +262,7 @@ typedef enum {
|
||||
kTfLiteInt8 = 9,
|
||||
kTfLiteFloat16 = 10,
|
||||
kTfLiteFloat64 = 11,
|
||||
kTfLiteComplex128 = 12,
|
||||
} TfLiteType;
|
||||
|
||||
// Return the name of a given type, for error reporting purposes.
|
||||
@ -313,12 +319,14 @@ typedef union TfLitePtrUnion {
|
||||
int64_t* i64;
|
||||
float* f;
|
||||
TfLiteFloat16* f16;
|
||||
double* f64;
|
||||
char* raw;
|
||||
const char* raw_const;
|
||||
uint8_t* uint8;
|
||||
bool* b;
|
||||
int16_t* i16;
|
||||
TfLiteComplex64* c64;
|
||||
TfLiteComplex128* c128;
|
||||
int8_t* int8;
|
||||
/* Only use this member. */
|
||||
void* data;
|
||||
|
@ -78,6 +78,7 @@ TEST(Types, TestTypeNames) {
|
||||
return std::string(TfLiteTypeGetName(t));
|
||||
};
|
||||
EXPECT_EQ(type_name(kTfLiteNoType), "NOTYPE");
|
||||
EXPECT_EQ(type_name(kTfLiteFloat64), "FLOAT64");
|
||||
EXPECT_EQ(type_name(kTfLiteFloat32), "FLOAT32");
|
||||
EXPECT_EQ(type_name(kTfLiteFloat16), "FLOAT16");
|
||||
EXPECT_EQ(type_name(kTfLiteInt16), "INT16");
|
||||
@ -87,6 +88,7 @@ TEST(Types, TestTypeNames) {
|
||||
EXPECT_EQ(type_name(kTfLiteInt64), "INT64");
|
||||
EXPECT_EQ(type_name(kTfLiteBool), "BOOL");
|
||||
EXPECT_EQ(type_name(kTfLiteComplex64), "COMPLEX64");
|
||||
EXPECT_EQ(type_name(kTfLiteComplex128), "COMPLEX128");
|
||||
EXPECT_EQ(type_name(kTfLiteString), "STRING");
|
||||
}
|
||||
|
||||
|
@ -863,6 +863,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
|
||||
case TensorType_COMPLEX64:
|
||||
*type = kTfLiteComplex64;
|
||||
return kTfLiteOk;
|
||||
case TensorType_COMPLEX128:
|
||||
*type = kTfLiteComplex128;
|
||||
return kTfLiteOk;
|
||||
default:
|
||||
*type = kTfLiteNoType;
|
||||
TF_LITE_REPORT_ERROR(error_reporter,
|
||||
|
@ -76,6 +76,8 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
|
||||
return TF_INT64;
|
||||
case kTfLiteComplex64:
|
||||
return TF_COMPLEX64;
|
||||
case kTfLiteComplex128:
|
||||
return TF_COMPLEX128;
|
||||
case kTfLiteString:
|
||||
return TF_STRING;
|
||||
case kTfLiteBool:
|
||||
@ -89,6 +91,8 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) {
|
||||
return kTfLiteFloat32;
|
||||
case TF_HALF:
|
||||
return kTfLiteFloat16;
|
||||
case TF_DOUBLE:
|
||||
return kTfLiteFloat64;
|
||||
case TF_INT16:
|
||||
return kTfLiteInt16;
|
||||
case TF_INT32:
|
||||
@ -101,6 +105,8 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) {
|
||||
return kTfLiteInt64;
|
||||
case TF_COMPLEX64:
|
||||
return kTfLiteComplex64;
|
||||
case TF_COMPLEX128:
|
||||
return kTfLiteComplex128;
|
||||
case TF_STRING:
|
||||
return kTfLiteString;
|
||||
case TF_BOOL:
|
||||
|
@ -109,22 +109,28 @@ TEST(UtilTest, CopyShapeAndType) {
|
||||
TEST(UtilTest, TypeConversionsFromTFLite) {
|
||||
EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType));
|
||||
EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32));
|
||||
EXPECT_EQ(TF_HALF, GetTensorFlowDataType(kTfLiteFloat16));
|
||||
EXPECT_EQ(TF_DOUBLE, GetTensorFlowDataType(kTfLiteFloat64));
|
||||
EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16));
|
||||
EXPECT_EQ(TF_INT32, GetTensorFlowDataType(kTfLiteInt32));
|
||||
EXPECT_EQ(TF_UINT8, GetTensorFlowDataType(kTfLiteUInt8));
|
||||
EXPECT_EQ(TF_INT64, GetTensorFlowDataType(kTfLiteInt64));
|
||||
EXPECT_EQ(TF_COMPLEX64, GetTensorFlowDataType(kTfLiteComplex64));
|
||||
EXPECT_EQ(TF_COMPLEX128, GetTensorFlowDataType(kTfLiteComplex128));
|
||||
EXPECT_EQ(TF_STRING, GetTensorFlowDataType(kTfLiteString));
|
||||
EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool));
|
||||
}
|
||||
|
||||
TEST(UtilTest, TypeConversionsFromTensorFlow) {
|
||||
EXPECT_EQ(kTfLiteFloat16, GetTensorFlowLiteType(TF_HALF));
|
||||
EXPECT_EQ(kTfLiteFloat32, GetTensorFlowLiteType(TF_FLOAT));
|
||||
EXPECT_EQ(kTfLiteFloat64, GetTensorFlowLiteType(TF_DOUBLE));
|
||||
EXPECT_EQ(kTfLiteInt16, GetTensorFlowLiteType(TF_INT16));
|
||||
EXPECT_EQ(kTfLiteInt32, GetTensorFlowLiteType(TF_INT32));
|
||||
EXPECT_EQ(kTfLiteUInt8, GetTensorFlowLiteType(TF_UINT8));
|
||||
EXPECT_EQ(kTfLiteInt64, GetTensorFlowLiteType(TF_INT64));
|
||||
EXPECT_EQ(kTfLiteComplex64, GetTensorFlowLiteType(TF_COMPLEX64));
|
||||
EXPECT_EQ(kTfLiteComplex128, GetTensorFlowLiteType(TF_COMPLEX128));
|
||||
EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING));
|
||||
EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL));
|
||||
EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE));
|
||||
|
@ -405,7 +405,9 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_
|
||||
case kTfLiteNoType:
|
||||
case kTfLiteString:
|
||||
case kTfLiteComplex64:
|
||||
// kTfLiteString and kTfLiteComplex64 are not supported in TensorFlow Lite Objc API.
|
||||
case kTfLiteComplex128:
|
||||
// kTfLiteString, kTfLiteComplex64 and kTfLiteComplex128 are not supported in TensorFlow Lite
|
||||
// Objc API.
|
||||
return TFLTensorDataTypeNoType;
|
||||
}
|
||||
}
|
||||
|
@ -82,6 +82,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
|
||||
return TensorType_INT16;
|
||||
case kTfLiteComplex64:
|
||||
return TensorType_COMPLEX64;
|
||||
case kTfLiteComplex128:
|
||||
return TensorType_COMPLEX128;
|
||||
}
|
||||
// TODO(aselle): consider an error
|
||||
}
|
||||
|
@ -72,6 +72,9 @@ TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) {
|
||||
case kTfLiteComplex64:
|
||||
*size = sizeof(float) * 2;
|
||||
break;
|
||||
case kTfLiteComplex128:
|
||||
*size = sizeof(double) * 2;
|
||||
break;
|
||||
default:
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
@ -141,6 +141,10 @@ TF_LITE_MICRO_TEST(TestTypeSizeOf) {
|
||||
tflite::TfLiteTypeSizeOf(kTfLiteComplex64, &size));
|
||||
TF_LITE_MICRO_EXPECT_EQ(sizeof(float) * 2, size);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
|
||||
tflite::TfLiteTypeSizeOf(kTfLiteComplex128, &size));
|
||||
TF_LITE_MICRO_EXPECT_EQ(sizeof(double) * 2, size);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_NE(
|
||||
kTfLiteOk, tflite::TfLiteTypeSizeOf(static_cast<TfLiteType>(-1), &size));
|
||||
}
|
||||
|
@ -162,6 +162,9 @@ void MicroInterpreter::CorrectTensorEndianness(TfLiteTensor* tensorCorr) {
|
||||
case TfLiteType::kTfLiteComplex64:
|
||||
CorrectTensorDataEndianness(tensorCorr->data.c64, tensorSize);
|
||||
break;
|
||||
case TfLiteType::kTfLiteComplex128:
|
||||
CorrectTensorDataEndianness(tensorCorr->data.c128, tensorSize);
|
||||
break;
|
||||
default:
|
||||
// Do nothing for other data types.
|
||||
break;
|
||||
|
@ -85,6 +85,8 @@ const char* TensorTypeName(TfLiteType type) {
|
||||
return "kTfLiteInt16";
|
||||
case kTfLiteComplex64:
|
||||
return "kTfLiteComplex64";
|
||||
case kTfLiteComplex128:
|
||||
return "kTfLiteComplex128";
|
||||
case kTfLiteFloat16:
|
||||
return "kTfLiteFloat16";
|
||||
case kTfLiteFloat64:
|
||||
|
@ -57,6 +57,8 @@ const char* TensorTypeName(TfLiteType type) {
|
||||
return "kTfLiteInt16";
|
||||
case kTfLiteComplex64:
|
||||
return "kTfLiteComplex64";
|
||||
case kTfLiteComplex128:
|
||||
return "kTfLiteComplex128";
|
||||
case kTfLiteFloat16:
|
||||
return "kTfLiteFloat16";
|
||||
case kTfLiteFloat64:
|
||||
|
@ -56,6 +56,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
|
||||
return NPY_BOOL;
|
||||
case kTfLiteComplex64:
|
||||
return NPY_COMPLEX64;
|
||||
case kTfLiteComplex128:
|
||||
return NPY_COMPLEX128;
|
||||
case kTfLiteNoType:
|
||||
return NPY_NOTYPE;
|
||||
// Avoid default so compiler errors created when new types are made.
|
||||
|
@ -86,6 +86,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
|
||||
return TensorType_INT16;
|
||||
case kTfLiteComplex64:
|
||||
return TensorType_COMPLEX64;
|
||||
case kTfLiteComplex128:
|
||||
return TensorType_COMPLEX128;
|
||||
}
|
||||
// No default to get compiler error when new type is introduced.
|
||||
}
|
||||
|
@ -51,6 +51,7 @@ _MAP_TF_TO_TFLITE_TYPES = {
|
||||
dtypes.int8: _types_pb2.INT8,
|
||||
dtypes.int16: _types_pb2.QUANTIZED_INT16,
|
||||
dtypes.complex64: _types_pb2.COMPLEX64,
|
||||
dtypes.complex128: _types_pb2.COMPLEX128,
|
||||
dtypes.bool: _types_pb2.BOOL,
|
||||
}
|
||||
|
||||
|
@ -41,6 +41,7 @@ enum TensorType : byte {
|
||||
COMPLEX64 = 8,
|
||||
INT8 = 9,
|
||||
FLOAT64 = 10,
|
||||
COMPLEX128 = 11,
|
||||
}
|
||||
|
||||
// Custom quantization parameters for experimenting with new quantization
|
||||
|
@ -379,11 +379,12 @@ enum TensorType {
|
||||
TensorType_COMPLEX64 = 8,
|
||||
TensorType_INT8 = 9,
|
||||
TensorType_FLOAT64 = 10,
|
||||
TensorType_COMPLEX128 = 11,
|
||||
TensorType_MIN = TensorType_FLOAT32,
|
||||
TensorType_MAX = TensorType_FLOAT64
|
||||
TensorType_MAX = TensorType_COMPLEX128
|
||||
};
|
||||
|
||||
inline const TensorType (&EnumValuesTensorType())[11] {
|
||||
inline const TensorType (&EnumValuesTensorType())[12] {
|
||||
static const TensorType values[] = {
|
||||
TensorType_FLOAT32,
|
||||
TensorType_FLOAT16,
|
||||
@ -395,13 +396,14 @@ inline const TensorType (&EnumValuesTensorType())[11] {
|
||||
TensorType_INT16,
|
||||
TensorType_COMPLEX64,
|
||||
TensorType_INT8,
|
||||
TensorType_FLOAT64
|
||||
TensorType_FLOAT64,
|
||||
TensorType_COMPLEX128
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
||||
inline const char * const *EnumNamesTensorType() {
|
||||
static const char * const names[12] = {
|
||||
static const char * const names[13] = {
|
||||
"FLOAT32",
|
||||
"FLOAT16",
|
||||
"INT32",
|
||||
@ -413,13 +415,14 @@ inline const char * const *EnumNamesTensorType() {
|
||||
"COMPLEX64",
|
||||
"INT8",
|
||||
"FLOAT64",
|
||||
"COMPLEX128",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
// Returns the schema name of tensor type `e`, or an empty string when `e`
// falls outside the valid range. The diff rendering left both the stale
// pre-change guard (upper bound TensorType_FLOAT64) and the updated one in
// this span; only the updated guard is kept, since the old bound would make
// TensorType_COMPLEX128 incorrectly report as out of range.
inline const char *EnumNameTensorType(TensorType e) {
  if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_COMPLEX128)) return "";
  const size_t index = static_cast<size_t>(e);
  return EnumNamesTensorType()[index];
}
|
||||
|
@ -132,6 +132,26 @@ inline std::vector<std::complex<float>> Split(const string& s,
|
||||
return fields;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<std::complex<double>> Split(const string& s,
|
||||
const string& delimiter) {
|
||||
std::vector<std::complex<double>> fields;
|
||||
for (const auto& p : SplitToPos(s, delimiter)) {
|
||||
std::string sc = s.substr(p.first, p.second - p.first);
|
||||
std::string::size_type sz_real, sz_img;
|
||||
double real = std::stod(sc, &sz_real);
|
||||
double img = std::stod(sc.substr(sz_real), &sz_img);
|
||||
if (sz_real + sz_img + 1 != sc.length()) {
|
||||
std::cerr << "There were errors in parsing string, " << sc
|
||||
<< ", to complex value." << std::endl;
|
||||
return fields;
|
||||
}
|
||||
std::complex<double> c(real, img);
|
||||
fields.push_back(c);
|
||||
}
|
||||
return fields;
|
||||
}
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -127,15 +127,37 @@ class TfLiteDriver::DataExpectation {
|
||||
return error_is_large;
|
||||
}
|
||||
|
||||
bool CompareTwoValuesHelper(double v1, double v2) {
|
||||
double diff = std::abs(v1 - v2);
|
||||
bool error_is_large = false;
|
||||
// For very small numbers, try absolute error, otherwise go with
|
||||
// relative.
|
||||
if (std::abs(v2) < relative_threshold_) {
|
||||
error_is_large = (diff > absolute_threshold_);
|
||||
} else {
|
||||
error_is_large = (diff > relative_threshold_ * std::abs(v2));
|
||||
}
|
||||
return error_is_large;
|
||||
}
|
||||
|
||||
bool CompareTwoValues(std::complex<float> v1, std::complex<float> v2) {
|
||||
return CompareTwoValues(v1.real(), v2.real()) ||
|
||||
CompareTwoValues(v1.imag(), v2.imag());
|
||||
}
|
||||
|
||||
// A complex value mismatches when either its real or its imaginary
// component mismatches.
bool CompareTwoValues(std::complex<double> v1, std::complex<double> v2) {
  const bool real_differs = CompareTwoValues(v1.real(), v2.real());
  const bool imag_differs = CompareTwoValues(v1.imag(), v2.imag());
  return real_differs || imag_differs;
}
|
||||
|
||||
bool CompareTwoValues(float v1, float v2) {
|
||||
return CompareTwoValuesHelper(v1, v2);
|
||||
}
|
||||
|
||||
bool CompareTwoValues(double v1, double v2) {
|
||||
return CompareTwoValuesHelper(v1, v2);
|
||||
}
|
||||
|
||||
template <typename T, typename TS>
|
||||
bool TypedCheck(bool verbose, const TfLiteTensor& tensor) {
|
||||
size_t tensor_size = tensor.bytes / sizeof(T);
|
||||
@ -315,6 +337,9 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose,
|
||||
case kTfLiteComplex64:
|
||||
return TypedCheck<std::complex<float>, std::complex<float>>(verbose,
|
||||
tensor);
|
||||
case kTfLiteComplex128:
|
||||
return TypedCheck<std::complex<double>, std::complex<double>>(verbose,
|
||||
tensor);
|
||||
default:
|
||||
fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
|
||||
return false;
|
||||
@ -527,6 +552,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
|
||||
case kTfLiteComplex64:
|
||||
expected_output_[id]->SetData<std::complex<float>>(csv_values);
|
||||
break;
|
||||
case kTfLiteComplex128:
|
||||
expected_output_[id]->SetData<std::complex<double>>(csv_values);
|
||||
break;
|
||||
default:
|
||||
Invalidate(absl::StrCat("Unsupported tensor type ",
|
||||
TfLiteTypeGetName(tensor->type),
|
||||
|
@ -236,6 +236,7 @@ enum class ArrayDataType : uint8 {
|
||||
kComplex64,
|
||||
kFloat16,
|
||||
kFloat64,
|
||||
kComplex128,
|
||||
};
|
||||
|
||||
// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type
|
||||
|
@ -51,6 +51,7 @@ namespace tflite {
|
||||
{ArrayDataType::kInt64, ::tflite::TensorType_INT64},
|
||||
{ArrayDataType::kString, ::tflite::TensorType_STRING},
|
||||
{ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64},
|
||||
{ArrayDataType::kComplex128, ::tflite::TensorType_COMPLEX128},
|
||||
{ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16},
|
||||
{ArrayDataType::kFloat64, ::tflite::TensorType_FLOAT64}};
|
||||
|
||||
|
@ -1769,6 +1769,8 @@ int ElementSize(ArrayDataType data_type) {
|
||||
return 8;
|
||||
case ArrayDataType::kComplex64:
|
||||
return 8;
|
||||
case ArrayDataType::kComplex128:
|
||||
return 16;
|
||||
case ArrayDataType::kFloat64:
|
||||
return 8;
|
||||
|
||||
@ -2313,6 +2315,8 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
|
||||
return ArrayDataType::kString;
|
||||
case COMPLEX64:
|
||||
return ArrayDataType::kComplex64;
|
||||
case COMPLEX128:
|
||||
return ArrayDataType::kComplex128;
|
||||
case FLOAT16:
|
||||
return ArrayDataType::kFloat16;
|
||||
case FLOAT64:
|
||||
|
@ -52,4 +52,7 @@ enum IODataType {
|
||||
|
||||
// Double precision float, not quantized.
|
||||
FLOAT64 = 11;
|
||||
|
||||
// Complex128, not quantized
|
||||
COMPLEX128 = 12;
|
||||
}
|
||||
|
@ -238,6 +238,11 @@ typedef struct TfLiteComplex64 {
|
||||
float re, im; // real and imaginary parts, respectively.
|
||||
} TfLiteComplex64;
|
||||
|
||||
// Double-precision complex data type compatible with the C99 definition.
|
||||
typedef struct TfLiteComplex128 {
|
||||
double re, im; // real and imaginary parts, respectively.
|
||||
} TfLiteComplex128;
|
||||
|
||||
// Half precision data type compatible with the C99 definition.
|
||||
typedef struct TfLiteFloat16 {
|
||||
uint16_t data;
|
||||
@ -257,6 +262,7 @@ typedef enum {
|
||||
kTfLiteInt8 = 9,
|
||||
kTfLiteFloat16 = 10,
|
||||
kTfLiteFloat64 = 11,
|
||||
kTfLiteComplex128 = 12,
|
||||
} TfLiteType;
|
||||
|
||||
// Return the name of a given type, for error reporting purposes.
|
||||
@ -313,12 +319,14 @@ typedef union TfLitePtrUnion {
|
||||
int64_t* i64;
|
||||
float* f;
|
||||
TfLiteFloat16* f16;
|
||||
double* f64;
|
||||
char* raw;
|
||||
const char* raw_const;
|
||||
uint8_t* uint8;
|
||||
bool* b;
|
||||
int16_t* i16;
|
||||
TfLiteComplex64* c64;
|
||||
TfLiteComplex128* c128;
|
||||
int8_t* int8;
|
||||
/* Only use this member. */
|
||||
void* data;
|
||||
|
@ -384,6 +384,9 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer,
|
||||
case TensorType_COMPLEX64:
|
||||
bytes_required *= sizeof(std::complex<float>);
|
||||
break;
|
||||
case TensorType_COMPLEX128:
|
||||
bytes_required *= sizeof(std::complex<double>);
|
||||
break;
|
||||
default:
|
||||
ReportError(error_reporter, "Tensor %s invalid type: %d",
|
||||
NameOrEmptyString(tensor.name()), tensor.type());
|
||||
|
@ -67,6 +67,10 @@ constexpr TfLiteType typeToTfLiteType<std::complex<float>>() {
|
||||
return kTfLiteComplex64;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<std::complex<double>>() {
|
||||
return kTfLiteComplex128;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<std::string>() {
|
||||
return kTfLiteString;
|
||||
}
|
||||
|
@ -102,6 +102,9 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
|
||||
case kTfLiteComplex64:
|
||||
*bytes = sizeof(std::complex<float>);
|
||||
break;
|
||||
case kTfLiteComplex128:
|
||||
*bytes = sizeof(std::complex<double>);
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
*bytes = sizeof(int16_t);
|
||||
break;
|
||||
|
Loading…
x
Reference in New Issue
Block a user