diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index fb20e842a75..09c79d90e26 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -149,6 +149,9 @@ static StatusOr GetTFLiteType(Type type, if (ftype && ftype.isF32()) { return tflite::TensorType_COMPLEX64; } + if (ftype && ftype.isF64()) { + return tflite::TensorType_COMPLEX128; + } return Status(error::INVALID_ARGUMENT, "Unsupported type"); } case mlir::StandardTypes::Integer: { diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index 4725eb1ac5f..a4e58123e05 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -123,6 +123,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) { return DT_BOOL; case toco::IODataType::COMPLEX64: return DT_COMPLEX64; + case toco::IODataType::COMPLEX128: + return DT_COMPLEX128; default: return DT_INVALID; } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir index 50fe804f86c..a622c43c2f2 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir @@ -15,6 +15,13 @@ func @complex64() -> tensor<4xcomplex> { return %0 : tensor<4xcomplex> } +func @complex128() -> tensor<4xcomplex> { + // CHECK-LABEL: @complex128 + // CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex> + %0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex> } : () -> tensor<4xcomplex> + return %0 : tensor<4xcomplex> +} + // TODO(b/138847107) this should work but doesn't // func @f16() -> tensor<4xf16> { // %0 = "tfl.pseudo_const"() { value = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf16> } : () -> tensor<4xf16> diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir new file mode 100644 index 00000000000..a5e6d4aabb5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir @@ -0,0 +1,66 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck %s + +func @main(tensor<4xcomplex>, tensor<4xcomplex>) -> tensor<4xcomplex> { +^bb0(%arg0: tensor<4xcomplex>, %arg1: tensor<4xcomplex>): +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: builtin_code: CUSTOM, +// CHECK-NEXT: custom_code: "FlexAdd" +// CHECK-NEXT: } ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: type: COMPLEX128, +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "arg0", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: type: COMPLEX128, +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "arg1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: type: COMPLEX128, +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "add", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0, 1 ], +// CHECK-NEXT: outputs: [ 2 ], +// CHECK-NEXT: operators: [ { +// CHECK-NEXT: inputs: [ 0, 1 ], +// CHECK-NEXT: outputs: [ 2 ], +// CHECK-NEXT: custom_options: [ 3, 65, 100, 100, 0, 20, 18, 3, 65, 100, 100, 26, 0, 26, 0, 42, 7, 10, 1, 84, 18, 2, 48, 18, 50, 0, 0, 2, 27, 23, 20, 20, 4, 40, 1 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", +// CHECK-NEXT: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 4 +// CHECK-NEXT: } ] +// CHECK-NEXT:} + + %0 = "tf.Add"(%arg0, %arg1) : (tensor<4xcomplex>, tensor<4xcomplex>) -> tensor<4xcomplex> loc("add") + return %0 : tensor<4xcomplex> +} diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index 22283d7eace..6b3ad78a830 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -53,6 +53,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { return builder.getIntegerType(16); case tflite::TensorType_COMPLEX64: return mlir::ComplexType::get(builder.getF32Type()); + case tflite::TensorType_COMPLEX128: + return mlir::ComplexType::get(builder.getF64Type()); case tflite::TensorType_INT8: return builder.getIntegerType(8); } @@ -64,6 +66,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) { return tensorflow::DT_BOOL; case tflite::TensorType_COMPLEX64: return tensorflow::DT_COMPLEX64; + case tflite::TensorType_COMPLEX128: + return tensorflow::DT_COMPLEX128; case tflite::TensorType_FLOAT16: return tensorflow::DT_HALF; case tflite::TensorType_FLOAT32: diff --git a/tensorflow/lite/c/common.c b/tensorflow/lite/c/common.c index e6b47896528..0264f420b12 100644 --- a/tensorflow/lite/c/common.c +++ b/tensorflow/lite/c/common.c @@ -207,6 +207,8 @@ const char* TfLiteTypeGetName(TfLiteType type) { return "BOOL"; case kTfLiteComplex64: return "COMPLEX64"; + case kTfLiteComplex128: + return "COMPLEX128"; case kTfLiteString: return "STRING"; case kTfLiteFloat16: diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h index 13e846406e6..cd6eeec4da2 100644 --- a/tensorflow/lite/c/common.h +++ b/tensorflow/lite/c/common.h @@ -238,6 +238,11 @@ typedef struct TfLiteComplex64 { float re, im; // real and imaginary parts, respectively. } TfLiteComplex64; +// Double-precision complex data type compatible with the C99 definition. +typedef struct TfLiteComplex128 { + double re, im; // real and imaginary parts, respectively. +} TfLiteComplex128; + // Half precision data type compatible with the C99 definition. typedef struct TfLiteFloat16 { uint16_t data; @@ -257,6 +262,7 @@ typedef enum { kTfLiteInt8 = 9, kTfLiteFloat16 = 10, kTfLiteFloat64 = 11, + kTfLiteComplex128 = 12, } TfLiteType; // Return the name of a given type, for error reporting purposes. @@ -313,12 +319,14 @@ typedef union TfLitePtrUnion { int64_t* i64; float* f; TfLiteFloat16* f16; + double* f64; char* raw; const char* raw_const; uint8_t* uint8; bool* b; int16_t* i16; TfLiteComplex64* c64; + TfLiteComplex128* c128; int8_t* int8; /* Only use this member. */ void* data; diff --git a/tensorflow/lite/c/common_test.cc b/tensorflow/lite/c/common_test.cc index 0421b50c05e..235c9c1b2cc 100644 --- a/tensorflow/lite/c/common_test.cc +++ b/tensorflow/lite/c/common_test.cc @@ -78,6 +78,7 @@ TEST(Types, TestTypeNames) { return std::string(TfLiteTypeGetName(t)); }; EXPECT_EQ(type_name(kTfLiteNoType), "NOTYPE"); + EXPECT_EQ(type_name(kTfLiteFloat64), "FLOAT64"); EXPECT_EQ(type_name(kTfLiteFloat32), "FLOAT32"); EXPECT_EQ(type_name(kTfLiteFloat16), "FLOAT16"); EXPECT_EQ(type_name(kTfLiteInt16), "INT16"); @@ -87,6 +88,7 @@ TEST(Types, TestTypeNames) { EXPECT_EQ(type_name(kTfLiteInt64), "INT64"); EXPECT_EQ(type_name(kTfLiteBool), "BOOL"); EXPECT_EQ(type_name(kTfLiteComplex64), "COMPLEX64"); + EXPECT_EQ(type_name(kTfLiteComplex128), "COMPLEX128"); EXPECT_EQ(type_name(kTfLiteString), "STRING"); } diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index e5422697acc..0652c64f6c2 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -863,6 +863,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, case TensorType_COMPLEX64: *type = kTfLiteComplex64; return kTfLiteOk; + case TensorType_COMPLEX128: + *type = kTfLiteComplex128; + return kTfLiteOk; default: *type = kTfLiteNoType; TF_LITE_REPORT_ERROR(error_reporter, diff --git a/tensorflow/lite/delegates/flex/util.cc b/tensorflow/lite/delegates/flex/util.cc index 750de7397fa..11cf28073fa 100644 --- a/tensorflow/lite/delegates/flex/util.cc +++ b/tensorflow/lite/delegates/flex/util.cc @@ -76,6 +76,8 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) { return TF_INT64; case kTfLiteComplex64: return TF_COMPLEX64; + case kTfLiteComplex128: + return TF_COMPLEX128; case kTfLiteString: return TF_STRING; case kTfLiteBool: @@ -89,6 +91,8 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) { return kTfLiteFloat32; case TF_HALF: return kTfLiteFloat16; + case TF_DOUBLE: + return kTfLiteFloat64; case TF_INT16: return kTfLiteInt16; case TF_INT32: @@ -101,6 +105,8 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) { return kTfLiteInt64; case TF_COMPLEX64: return kTfLiteComplex64; + case TF_COMPLEX128: + return kTfLiteComplex128; case TF_STRING: return kTfLiteString; case TF_BOOL: diff --git a/tensorflow/lite/delegates/flex/util_test.cc b/tensorflow/lite/delegates/flex/util_test.cc index 751289ef28f..0d4b50256f0 100644 --- a/tensorflow/lite/delegates/flex/util_test.cc +++ b/tensorflow/lite/delegates/flex/util_test.cc @@ -109,22 +109,28 @@ TEST(UtilTest, CopyShapeAndType) { TEST(UtilTest, TypeConversionsFromTFLite) { EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType)); EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32)); + EXPECT_EQ(TF_HALF, GetTensorFlowDataType(kTfLiteFloat16)); + EXPECT_EQ(TF_DOUBLE, GetTensorFlowDataType(kTfLiteFloat64)); EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16)); EXPECT_EQ(TF_INT32, GetTensorFlowDataType(kTfLiteInt32)); EXPECT_EQ(TF_UINT8, GetTensorFlowDataType(kTfLiteUInt8)); EXPECT_EQ(TF_INT64, GetTensorFlowDataType(kTfLiteInt64)); EXPECT_EQ(TF_COMPLEX64, GetTensorFlowDataType(kTfLiteComplex64)); + EXPECT_EQ(TF_COMPLEX128, GetTensorFlowDataType(kTfLiteComplex128)); EXPECT_EQ(TF_STRING, GetTensorFlowDataType(kTfLiteString)); EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool)); } TEST(UtilTest, TypeConversionsFromTensorFlow) { + EXPECT_EQ(kTfLiteFloat16, GetTensorFlowLiteType(TF_HALF)); EXPECT_EQ(kTfLiteFloat32, GetTensorFlowLiteType(TF_FLOAT)); + EXPECT_EQ(kTfLiteFloat64, GetTensorFlowLiteType(TF_DOUBLE)); EXPECT_EQ(kTfLiteInt16, GetTensorFlowLiteType(TF_INT16)); EXPECT_EQ(kTfLiteInt32, GetTensorFlowLiteType(TF_INT32)); EXPECT_EQ(kTfLiteUInt8, GetTensorFlowLiteType(TF_UINT8)); EXPECT_EQ(kTfLiteInt64, GetTensorFlowLiteType(TF_INT64)); EXPECT_EQ(kTfLiteComplex64, GetTensorFlowLiteType(TF_COMPLEX64)); + EXPECT_EQ(kTfLiteComplex128, GetTensorFlowLiteType(TF_COMPLEX128)); EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING)); EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL)); EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE)); diff --git a/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm index 34dd119885d..0ccafd71d1b 100644 --- a/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm +++ b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm @@ -405,7 +405,9 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_ case kTfLiteNoType: case kTfLiteString: case kTfLiteComplex64: - // kTfLiteString and kTfLiteComplex64 are not supported in TensorFlow Lite Objc API. + case kTfLiteComplex128: + // kTfLiteString, kTfLiteComplex64 and kTfLiteComplex128 are not supported in TensorFlow Lite + // Objc API. return TFLTensorDataTypeNoType; } } diff --git a/tensorflow/lite/experimental/writer/enum_mapping.h b/tensorflow/lite/experimental/writer/enum_mapping.h index 5eabbcb2015..0847fb7893d 100644 --- a/tensorflow/lite/experimental/writer/enum_mapping.h +++ b/tensorflow/lite/experimental/writer/enum_mapping.h @@ -82,6 +82,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { return TensorType_INT16; case kTfLiteComplex64: return TensorType_COMPLEX64; + case kTfLiteComplex128: + return TensorType_COMPLEX128; } // TODO(aselle): consider an error } diff --git a/tensorflow/lite/micro/memory_helpers.cc b/tensorflow/lite/micro/memory_helpers.cc index bded4d6895a..0e8f335c049 100644 --- a/tensorflow/lite/micro/memory_helpers.cc +++ b/tensorflow/lite/micro/memory_helpers.cc @@ -72,6 +72,9 @@ TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) { case kTfLiteComplex64: *size = sizeof(float) * 2; break; + case kTfLiteComplex128: + *size = sizeof(double) * 2; + break; default: return kTfLiteError; } diff --git a/tensorflow/lite/micro/memory_helpers_test.cc b/tensorflow/lite/micro/memory_helpers_test.cc index 791e30c944e..82096c6890d 100644 --- a/tensorflow/lite/micro/memory_helpers_test.cc +++ b/tensorflow/lite/micro/memory_helpers_test.cc @@ -141,6 +141,10 @@ TF_LITE_MICRO_TEST(TestTypeSizeOf) { tflite::TfLiteTypeSizeOf(kTfLiteComplex64, &size)); TF_LITE_MICRO_EXPECT_EQ(sizeof(float) * 2, size); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, + tflite::TfLiteTypeSizeOf(kTfLiteComplex128, &size)); + TF_LITE_MICRO_EXPECT_EQ(sizeof(double) * 2, size); + TF_LITE_MICRO_EXPECT_NE( kTfLiteOk, tflite::TfLiteTypeSizeOf(static_cast(-1), &size)); } diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc index 08556a56a54..c16ede174aa 100644 --- a/tensorflow/lite/micro/micro_interpreter.cc +++ b/tensorflow/lite/micro/micro_interpreter.cc @@ -162,6 +162,9 @@ void MicroInterpreter::CorrectTensorEndianness(TfLiteTensor* tensorCorr) { case TfLiteType::kTfLiteComplex64: CorrectTensorDataEndianness(tensorCorr->data.c64, tensorSize); break; + case TfLiteType::kTfLiteComplex128: + CorrectTensorDataEndianness(tensorCorr->data.c128, tensorSize); + break; default: // Do nothing for other data types. break; diff --git a/tensorflow/lite/micro/micro_optional_debug_tools.cc b/tensorflow/lite/micro/micro_optional_debug_tools.cc index f94d67b5ee5..516def3ebe4 100644 --- a/tensorflow/lite/micro/micro_optional_debug_tools.cc +++ b/tensorflow/lite/micro/micro_optional_debug_tools.cc @@ -85,6 +85,8 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteInt16"; case kTfLiteComplex64: return "kTfLiteComplex64"; + case kTfLiteComplex128: + return "kTfLiteComplex128"; case kTfLiteFloat16: return "kTfLiteFloat16"; case kTfLiteFloat64: diff --git a/tensorflow/lite/optional_debug_tools.cc b/tensorflow/lite/optional_debug_tools.cc index 2e25b0a17f7..8ee5c3b3f56 100644 --- a/tensorflow/lite/optional_debug_tools.cc +++ b/tensorflow/lite/optional_debug_tools.cc @@ -57,6 +57,8 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteInt16"; case kTfLiteComplex64: return "kTfLiteComplex64"; + case kTfLiteComplex128: + return "kTfLiteComplex128"; case kTfLiteFloat16: return "kTfLiteFloat16"; case kTfLiteFloat64: diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.cc b/tensorflow/lite/python/interpreter_wrapper/numpy.cc index 00e5064e620..d2f308a74a2 100644 --- a/tensorflow/lite/python/interpreter_wrapper/numpy.cc +++ b/tensorflow/lite/python/interpreter_wrapper/numpy.cc @@ -56,6 +56,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { return NPY_BOOL; case kTfLiteComplex64: return NPY_COMPLEX64; + case kTfLiteComplex128: + return NPY_COMPLEX128; case kTfLiteNoType: return NPY_NOTYPE; // Avoid default so compiler errors created when new types are made. diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc index 4e4584c0fd7..b608d529c85 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc @@ -86,6 +86,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { return TensorType_INT16; case kTfLiteComplex64: return TensorType_COMPLEX64; + case kTfLiteComplex128: + return TensorType_COMPLEX128; } // No default to get compiler error when new type is introduced. } diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index a69f59b2837..ff7caad0f88 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -51,6 +51,7 @@ _MAP_TF_TO_TFLITE_TYPES = { dtypes.int8: _types_pb2.INT8, dtypes.int16: _types_pb2.QUANTIZED_INT16, dtypes.complex64: _types_pb2.COMPLEX64, + dtypes.complex128: _types_pb2.COMPLEX128, dtypes.bool: _types_pb2.BOOL, } diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index b7f41c756e4..878acde1e16 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -41,6 +41,7 @@ enum TensorType : byte { COMPLEX64 = 8, INT8 = 9, FLOAT64 = 10, + COMPLEX128 = 11, } // Custom quantization parameters for experimenting with new quantization diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index b044acb4033..a6117dc72ab 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -379,11 +379,12 @@ enum TensorType { TensorType_COMPLEX64 = 8, TensorType_INT8 = 9, TensorType_FLOAT64 = 10, + TensorType_COMPLEX128 = 11, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_FLOAT64 + TensorType_MAX = TensorType_COMPLEX128 }; -inline const TensorType (&EnumValuesTensorType())[11] { +inline const TensorType (&EnumValuesTensorType())[12] { static const TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, @@ -395,13 +396,14 @@ inline const TensorType (&EnumValuesTensorType())[11] { TensorType_INT16, TensorType_COMPLEX64, TensorType_INT8, - TensorType_FLOAT64 + TensorType_FLOAT64, + TensorType_COMPLEX128 }; return values; } inline const char * const *EnumNamesTensorType() { - static const char * const names[12] = { + static const char * const names[13] = { "FLOAT32", "FLOAT16", "INT32", @@ -413,13 +415,14 @@ inline const char * const *EnumNamesTensorType() { "COMPLEX64", "INT8", "FLOAT64", + "COMPLEX128", nullptr }; return names; } inline const char *EnumNameTensorType(TensorType e) { - if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_FLOAT64)) return ""; + if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_COMPLEX128)) return ""; const size_t index = static_cast(e); return EnumNamesTensorType()[index]; } diff --git a/tensorflow/lite/testing/split.h b/tensorflow/lite/testing/split.h index d4e762164a4..6f7b9a68484 100644 --- a/tensorflow/lite/testing/split.h +++ b/tensorflow/lite/testing/split.h @@ -132,6 +132,26 @@ inline std::vector> Split(const string& s, return fields; } +template <> +inline std::vector> Split(const string& s, + const string& delimiter) { + std::vector> fields; + for (const auto& p : SplitToPos(s, delimiter)) { + std::string sc = s.substr(p.first, p.second - p.first); + std::string::size_type sz_real, sz_img; + double real = std::stod(sc, &sz_real); + double img = std::stod(sc.substr(sz_real), &sz_img); + if (sz_real + sz_img + 1 != sc.length()) { + std::cerr << "There were errors in parsing string, " << sc + << ", to complex value." << std::endl; + return fields; + } + std::complex c(real, img); + fields.push_back(c); + } + return fields; +} + } // namespace testing } // namespace tflite diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc index ae53be09889..ae352ce04c4 100644 --- a/tensorflow/lite/testing/tflite_driver.cc +++ b/tensorflow/lite/testing/tflite_driver.cc @@ -127,15 +127,37 @@ class TfLiteDriver::DataExpectation { return error_is_large; } + bool CompareTwoValuesHelper(double v1, double v2) { + double diff = std::abs(v1 - v2); + bool error_is_large = false; + // For very small numbers, try absolute error, otherwise go with + // relative. + if (std::abs(v2) < relative_threshold_) { + error_is_large = (diff > absolute_threshold_); + } else { + error_is_large = (diff > relative_threshold_ * std::abs(v2)); + } + return error_is_large; + } + bool CompareTwoValues(std::complex v1, std::complex v2) { return CompareTwoValues(v1.real(), v2.real()) || CompareTwoValues(v1.imag(), v2.imag()); } + bool CompareTwoValues(std::complex v1, std::complex v2) { + return CompareTwoValues(v1.real(), v2.real()) || + CompareTwoValues(v1.imag(), v2.imag()); + } + bool CompareTwoValues(float v1, float v2) { return CompareTwoValuesHelper(v1, v2); } + bool CompareTwoValues(double v1, double v2) { + return CompareTwoValuesHelper(v1, v2); + } + template bool TypedCheck(bool verbose, const TfLiteTensor& tensor) { size_t tensor_size = tensor.bytes / sizeof(T); @@ -315,6 +337,9 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose, case kTfLiteComplex64: return TypedCheck, std::complex>(verbose, tensor); + case kTfLiteComplex128: + return TypedCheck, std::complex>(verbose, + tensor); default: fprintf(stderr, "Unsupported type %d in Check\n", tensor.type); return false; @@ -527,6 +552,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) { case kTfLiteComplex64: expected_output_[id]->SetData>(csv_values); break; + case kTfLiteComplex128: + expected_output_[id]->SetData>(csv_values); + break; default: Invalidate(absl::StrCat("Unsupported tensor type ", TfLiteTypeGetName(tensor->type), diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index 58397f5a3eb..b42fed6fbc1 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -236,6 +236,7 @@ enum class ArrayDataType : uint8 { kComplex64, kFloat16, kFloat64, + kComplex128, }; // Compile-time logic to map ArrayDataType to the corresponding C++ scalar type diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index bc12d49a115..794691f5724 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -51,6 +51,7 @@ namespace tflite { {ArrayDataType::kInt64, ::tflite::TensorType_INT64}, {ArrayDataType::kString, ::tflite::TensorType_STRING}, {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64}, + {ArrayDataType::kComplex128, ::tflite::TensorType_COMPLEX128}, {ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16}, {ArrayDataType::kFloat64, ::tflite::TensorType_FLOAT64}}; diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc index be4cda8aa3d..d84763faee6 100644 --- a/tensorflow/lite/toco/tooling_util.cc +++ b/tensorflow/lite/toco/tooling_util.cc @@ -1769,6 +1769,8 @@ int ElementSize(ArrayDataType data_type) { return 8; case ArrayDataType::kComplex64: return 8; + case ArrayDataType::kComplex128: + return 16; case ArrayDataType::kFloat64: return 8; @@ -2313,6 +2315,8 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) { return ArrayDataType::kString; case COMPLEX64: return ArrayDataType::kComplex64; + case COMPLEX128: + return ArrayDataType::kComplex128; case FLOAT16: return ArrayDataType::kFloat16; case FLOAT64: diff --git a/tensorflow/lite/toco/types.proto b/tensorflow/lite/toco/types.proto index 029a159321e..009891c3bcb 100644 --- a/tensorflow/lite/toco/types.proto +++ b/tensorflow/lite/toco/types.proto @@ -52,4 +52,7 @@ enum IODataType { // Double precision float, not quantized. FLOAT64 = 11; + + // Complex128, not quantized + COMPLEX128 = 12; } diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h index 13e846406e6..cd6eeec4da2 100644 --- a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h +++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h @@ -238,6 +238,11 @@ typedef struct TfLiteComplex64 { float re, im; // real and imaginary parts, respectively. } TfLiteComplex64; +// Double-precision complex data type compatible with the C99 definition. +typedef struct TfLiteComplex128 { + double re, im; // real and imaginary parts, respectively. +} TfLiteComplex128; + // Half precision data type compatible with the C99 definition. typedef struct TfLiteFloat16 { uint16_t data; @@ -257,6 +262,7 @@ typedef enum { kTfLiteInt8 = 9, kTfLiteFloat16 = 10, kTfLiteFloat64 = 11, + kTfLiteComplex128 = 12, } TfLiteType; // Return the name of a given type, for error reporting purposes. @@ -313,12 +319,14 @@ typedef union TfLitePtrUnion { int64_t* i64; float* f; TfLiteFloat16* f16; + double* f64; char* raw; const char* raw_const; uint8_t* uint8; bool* b; int16_t* i16; TfLiteComplex64* c64; + TfLiteComplex128* c128; int8_t* int8; /* Only use this member. */ void* data; diff --git a/tensorflow/lite/tools/verifier.cc b/tensorflow/lite/tools/verifier.cc index 9befa7fd6f1..12b24e6f2d8 100644 --- a/tensorflow/lite/tools/verifier.cc +++ b/tensorflow/lite/tools/verifier.cc @@ -384,6 +384,9 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer, case TensorType_COMPLEX64: bytes_required *= sizeof(std::complex); break; + case TensorType_COMPLEX128: + bytes_required *= sizeof(std::complex); + break; default: ReportError(error_reporter, "Tensor %s invalid type: %d", NameOrEmptyString(tensor.name()), tensor.type()); diff --git a/tensorflow/lite/type_to_tflitetype.h b/tensorflow/lite/type_to_tflitetype.h index 84cd54b5718..4ad36688bee 100644 --- a/tensorflow/lite/type_to_tflitetype.h +++ b/tensorflow/lite/type_to_tflitetype.h @@ -67,6 +67,10 @@ constexpr TfLiteType typeToTfLiteType>() { return kTfLiteComplex64; } template <> +constexpr TfLiteType typeToTfLiteType>() { + return kTfLiteComplex128; +} +template <> constexpr TfLiteType typeToTfLiteType() { return kTfLiteString; } diff --git a/tensorflow/lite/util.cc b/tensorflow/lite/util.cc index 09efaa77f15..9cfdaf4d695 100644 --- a/tensorflow/lite/util.cc +++ b/tensorflow/lite/util.cc @@ -102,6 +102,9 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type, case kTfLiteComplex64: *bytes = sizeof(std::complex); break; + case kTfLiteComplex128: + *bytes = sizeof(std::complex); + break; case kTfLiteInt16: *bytes = sizeof(int16_t); break;