Add float64 tensor support in TFLite

PiperOrigin-RevId: 304756523
Change-Id: I6e9f3196a700b3b43cc9b9fb06b72938db651582
This commit is contained in:
Jaesung Chung 2020-04-03 23:16:32 -07:00 committed by TensorFlower Gardener
parent 0d61fc79f5
commit 845adc40a6
33 changed files with 183 additions and 12 deletions

View File

@ -138,6 +138,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
return tflite::TensorType_FLOAT32; return tflite::TensorType_FLOAT32;
case mlir::StandardTypes::F16: case mlir::StandardTypes::F16:
return tflite::TensorType_FLOAT16; return tflite::TensorType_FLOAT16;
case mlir::StandardTypes::F64:
return tflite::TensorType_FLOAT64;
case mlir::TF::TensorFlowTypes::STRING: case mlir::TF::TensorFlowTypes::STRING:
return tflite::TensorType_STRING; return tflite::TensorType_STRING;
case mlir::TF::TensorFlowTypes::QUINT8: case mlir::TF::TensorFlowTypes::QUINT8:

View File

@ -353,6 +353,22 @@ StatusOr<mlir::ElementsAttr> ConvertFloatBuffer(
} }
return DenseElementsAttr::get(shaped_type, ArrayRef<float>(values)); return DenseElementsAttr::get(shaped_type, ArrayRef<float>(values));
} }
case 64: {
assert(bytes_len % 8 == 0);
size_t elem_count = bytes_len / 8;
std::vector<double> values;
values.reserve(elem_count);
const char* data = reinterpret_cast<const char*>(buffer.data());
for (int i = 0; i < elem_count; i++) {
uint64_t bit_repr =
llvm::support::endian::readNext<uint64_t, llvm::support::little,
llvm::support::unaligned>(data);
values.push_back(absl::bit_cast<double>(bit_repr));
}
return DenseElementsAttr::get(shaped_type, ArrayRef<double>(values));
}
} }
return errors::InvalidArgument("unsupported bit width", elem_type.getWidth()); return errors::InvalidArgument("unsupported bit width", elem_type.getWidth());
} }

View File

@ -105,6 +105,10 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
switch (dtype) { switch (dtype) {
case toco::IODataType::FLOAT: case toco::IODataType::FLOAT:
return DT_FLOAT; return DT_FLOAT;
case toco::IODataType::FLOAT16:
return DT_HALF;
case toco::IODataType::FLOAT64:
return DT_DOUBLE;
case toco::IODataType::QUANTIZED_UINT8: case toco::IODataType::QUANTIZED_UINT8:
return DT_QUINT8; return DT_QUINT8;
case toco::IODataType::INT8: case toco::IODataType::INT8:

View File

@ -28,6 +28,13 @@ func @f32() -> tensor<4xf32> {
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
func @f64() -> tensor<4xf64> {
// CHECK-LABEL: @f64
// CHECK: value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf64>
%0 = "tfl.pseudo_const"() { value = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64> } : () -> tensor<4xf64>
return %0 : tensor<4xf64>
}
func @i8() -> tensor<4xi8> { func @i8() -> tensor<4xi8> {
// CHECK-LABEL: @i8 // CHECK-LABEL: @i8
// CHECK: value = dense<[1, 2, 3, 4]> : tensor<4xi8> // CHECK: value = dense<[1, 2, 3, 4]> : tensor<4xi8>

View File

@ -0,0 +1,66 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> {
^bb0(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>):
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "FlexAdd"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: FLOAT64,
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: FLOAT64,
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: FLOAT64,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "add",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: custom_options: [ 3, 65, 100, 100, 0, 20, 18, 3, 65, 100, 100, 26, 0, 26, 0, 42, 7, 10, 1, 84, 18, 2, 48, 2, 50, 0, 0, 2, 27, 23, 20, 20, 4, 40, 1 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT:}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> loc("add")
return %0 : tensor<4xf64>
}

View File

@ -32,10 +32,12 @@ namespace errors = tensorflow::errors;
mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
switch (type) { switch (type) {
case tflite::TensorType_FLOAT32:
return builder.getF32Type();
case tflite::TensorType_FLOAT16: case tflite::TensorType_FLOAT16:
return builder.getF16Type(); return builder.getF16Type();
case tflite::TensorType_FLOAT32:
return builder.getF32Type();
case tflite::TensorType_FLOAT64:
return builder.getF64Type();
case tflite::TensorType_INT32: case tflite::TensorType_INT32:
return builder.getIntegerType(32); return builder.getIntegerType(32);
case tflite::TensorType_UINT8: case tflite::TensorType_UINT8:
@ -65,6 +67,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
return tensorflow::DT_HALF; return tensorflow::DT_HALF;
case tflite::TensorType_FLOAT32: case tflite::TensorType_FLOAT32:
return tensorflow::DT_FLOAT; return tensorflow::DT_FLOAT;
case tflite::TensorType_FLOAT64:
return tensorflow::DT_DOUBLE;
case tflite::TensorType_INT8: case tflite::TensorType_INT8:
return tensorflow::DT_INT8; return tensorflow::DT_INT8;
case tflite::TensorType_INT16: case tflite::TensorType_INT16:

View File

@ -209,6 +209,8 @@ const char* TfLiteTypeGetName(TfLiteType type) {
return "STRING"; return "STRING";
case kTfLiteFloat16: case kTfLiteFloat16:
return "FLOAT16"; return "FLOAT16";
case kTfLiteFloat64:
return "FLOAT64";
} }
return "Unknown type"; return "Unknown type";
} }

View File

@ -236,6 +236,7 @@ typedef enum {
kTfLiteComplex64 = 8, kTfLiteComplex64 = 8,
kTfLiteInt8 = 9, kTfLiteInt8 = 9,
kTfLiteFloat16 = 10, kTfLiteFloat16 = 10,
kTfLiteFloat64 = 11,
} TfLiteType; } TfLiteType;
// Return the name of a given type, for error reporting purposes. // Return the name of a given type, for error reporting purposes.

View File

@ -91,11 +91,14 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
ErrorReporter* error_reporter) { ErrorReporter* error_reporter) {
*type = kTfLiteNoType; *type = kTfLiteNoType;
switch (tensor_type) { switch (tensor_type) {
case TensorType_FLOAT16:
*type = kTfLiteFloat16;
break;
case TensorType_FLOAT32: case TensorType_FLOAT32:
*type = kTfLiteFloat32; *type = kTfLiteFloat32;
break; break;
case TensorType_FLOAT16: case TensorType_FLOAT64:
*type = kTfLiteFloat16; *type = kTfLiteFloat64;
break; break;
case TensorType_INT16: case TensorType_INT16:
*type = kTfLiteInt16; *type = kTfLiteInt16;

View File

@ -62,6 +62,8 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
return TF_FLOAT; return TF_FLOAT;
case kTfLiteFloat16: case kTfLiteFloat16:
return TF_HALF; return TF_HALF;
case kTfLiteFloat64:
return TF_DOUBLE;
case kTfLiteInt16: case kTfLiteInt16:
return TF_INT16; return TF_INT16;
case kTfLiteInt32: case kTfLiteInt32:

View File

@ -49,6 +49,9 @@ typedef NS_ENUM(NSUInteger, TFLTensorDataType) {
/** 8-bit signed integer. */ /** 8-bit signed integer. */
TFLTensorDataTypeInt8, TFLTensorDataTypeInt8,
/** 64-bit double precision floating point. */
TFLTensorDataTypeFloat64,
}; };
/** /**

View File

@ -373,6 +373,8 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_
return TFLTensorDataTypeFloat32; return TFLTensorDataTypeFloat32;
case kTfLiteFloat16: case kTfLiteFloat16:
return TFLTensorDataTypeFloat16; return TFLTensorDataTypeFloat16;
case kTfLiteFloat64:
return TFLTensorDataTypeFloat64;
case kTfLiteInt32: case kTfLiteInt32:
return TFLTensorDataTypeInt32; return TFLTensorDataTypeInt32;
case kTfLiteUInt8: case kTfLiteUInt8:

View File

@ -73,6 +73,8 @@ extension Tensor {
case float16 case float16
/// A 32-bit single precision floating point. /// A 32-bit single precision floating point.
case float32 case float32
/// A 64-bit double precision floating point.
case float64
/// Creates a new instance from the given `TfLiteType` or `nil` if the data type is unsupported /// Creates a new instance from the given `TfLiteType` or `nil` if the data type is unsupported
/// or could not be determined because there was an error. /// or could not be determined because there was an error.
@ -94,6 +96,8 @@ extension Tensor {
self = .float16 self = .float16
case kTfLiteFloat32: case kTfLiteFloat32:
self = .float32 self = .float32
case kTfLiteFloat64:
self = .float64
case kTfLiteNoType: case kTfLiteNoType:
fallthrough fallthrough
default: default:

View File

@ -64,6 +64,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
return TensorType_FLOAT32; return TensorType_FLOAT32;
case kTfLiteFloat16: case kTfLiteFloat16:
return TensorType_FLOAT16; return TensorType_FLOAT16;
case kTfLiteFloat64:
return TensorType_FLOAT64;
case kTfLiteInt32: case kTfLiteInt32:
return TensorType_INT32; return TensorType_INT32;
case kTfLiteUInt8: case kTfLiteUInt8:

View File

@ -790,6 +790,7 @@ template <typename T>
TensorType GetTensorType() { TensorType GetTensorType() {
if (std::is_same<T, float>::value) return TensorType_FLOAT32; if (std::is_same<T, float>::value) return TensorType_FLOAT32;
if (std::is_same<T, TfLiteFloat16>::value) return TensorType_FLOAT16; if (std::is_same<T, TfLiteFloat16>::value) return TensorType_FLOAT16;
if (std::is_same<T, double>::value) return TensorType_FLOAT64;
if (std::is_same<T, int8_t>::value) return TensorType_INT8; if (std::is_same<T, int8_t>::value) return TensorType_INT8;
if (std::is_same<T, int16_t>::value) return TensorType_INT16; if (std::is_same<T, int16_t>::value) return TensorType_INT16;
if (std::is_same<T, int32_t>::value) return TensorType_INT32; if (std::is_same<T, int32_t>::value) return TensorType_INT32;

View File

@ -77,6 +77,8 @@ const char* TensorTypeName(TfLiteType type) {
return "kTfLiteComplex64"; return "kTfLiteComplex64";
case kTfLiteFloat16: case kTfLiteFloat16:
return "kTfLiteFloat16"; return "kTfLiteFloat16";
case kTfLiteFloat64:
return "kTfLiteFloat64";
} }
return "(invalid)"; return "(invalid)";
} }

View File

@ -59,6 +59,8 @@ const char* TensorTypeName(TfLiteType type) {
return "kTfLiteComplex64"; return "kTfLiteComplex64";
case kTfLiteFloat16: case kTfLiteFloat16:
return "kTfLiteFloat16"; return "kTfLiteFloat16";
case kTfLiteFloat64:
return "kTfLiteFloat64";
} }
return "(invalid)"; return "(invalid)";
} }

View File

@ -38,6 +38,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
return NPY_FLOAT32; return NPY_FLOAT32;
case kTfLiteFloat16: case kTfLiteFloat16:
return NPY_FLOAT16; return NPY_FLOAT16;
case kTfLiteFloat64:
return NPY_FLOAT64;
case kTfLiteInt32: case kTfLiteInt32:
return NPY_INT32; return NPY_INT32;
case kTfLiteInt16: case kTfLiteInt16:

View File

@ -68,6 +68,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
return TensorType_FLOAT32; return TensorType_FLOAT32;
case kTfLiteFloat16: case kTfLiteFloat16:
return TensorType_FLOAT16; return TensorType_FLOAT16;
case kTfLiteFloat64:
return TensorType_FLOAT64;
case kTfLiteInt32: case kTfLiteInt32:
return TensorType_INT32; return TensorType_INT32;
case kTfLiteUInt8: case kTfLiteUInt8:

View File

@ -43,6 +43,7 @@ from tensorflow.python.training.saver import export_meta_graph as _export_meta_g
_MAP_TF_TO_TFLITE_TYPES = { _MAP_TF_TO_TFLITE_TYPES = {
dtypes.float32: _types_pb2.FLOAT, dtypes.float32: _types_pb2.FLOAT,
dtypes.float16: _types_pb2.FLOAT16, dtypes.float16: _types_pb2.FLOAT16,
dtypes.float64: _types_pb2.FLOAT64,
dtypes.int32: _types_pb2.INT32, dtypes.int32: _types_pb2.INT32,
dtypes.int64: _types_pb2.INT64, dtypes.int64: _types_pb2.INT64,
dtypes.string: _types_pb2.STRING, dtypes.string: _types_pb2.STRING,

View File

@ -40,6 +40,7 @@ enum TensorType : byte {
INT16 = 7, INT16 = 7,
COMPLEX64 = 8, COMPLEX64 = 8,
INT8 = 9, INT8 = 9,
FLOAT64 = 10,
} }
// Custom quantization parameters for experimenting with new quantization // Custom quantization parameters for experimenting with new quantization

View File

@ -378,11 +378,12 @@ enum TensorType {
TensorType_INT16 = 7, TensorType_INT16 = 7,
TensorType_COMPLEX64 = 8, TensorType_COMPLEX64 = 8,
TensorType_INT8 = 9, TensorType_INT8 = 9,
TensorType_FLOAT64 = 10,
TensorType_MIN = TensorType_FLOAT32, TensorType_MIN = TensorType_FLOAT32,
TensorType_MAX = TensorType_INT8 TensorType_MAX = TensorType_FLOAT64
}; };
inline const TensorType (&EnumValuesTensorType())[10] { inline const TensorType (&EnumValuesTensorType())[11] {
static const TensorType values[] = { static const TensorType values[] = {
TensorType_FLOAT32, TensorType_FLOAT32,
TensorType_FLOAT16, TensorType_FLOAT16,
@ -393,13 +394,14 @@ inline const TensorType (&EnumValuesTensorType())[10] {
TensorType_BOOL, TensorType_BOOL,
TensorType_INT16, TensorType_INT16,
TensorType_COMPLEX64, TensorType_COMPLEX64,
TensorType_INT8 TensorType_INT8,
TensorType_FLOAT64
}; };
return values; return values;
} }
inline const char * const *EnumNamesTensorType() { inline const char * const *EnumNamesTensorType() {
static const char * const names[11] = { static const char * const names[12] = {
"FLOAT32", "FLOAT32",
"FLOAT16", "FLOAT16",
"INT32", "INT32",
@ -410,13 +412,14 @@ inline const char * const *EnumNamesTensorType() {
"INT16", "INT16",
"COMPLEX64", "COMPLEX64",
"INT8", "INT8",
"FLOAT64",
nullptr nullptr
}; };
return names; return names;
} }
inline const char *EnumNameTensorType(TensorType e) { inline const char *EnumNameTensorType(TensorType e) {
if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_INT8)) return ""; if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_FLOAT64)) return "";
const size_t index = static_cast<size_t>(e); const size_t index = static_cast<size_t>(e);
return EnumNamesTensorType()[index]; return EnumNamesTensorType()[index];
} }

View File

@ -114,6 +114,18 @@ def make_binary_op_tests(options,
}, },
] ]
# float64 types are supported via flex only.
if options.run_with_flex and options.use_experimental_converter:
test_parameters = test_parameters + [
{
"dtype": [tf.float64],
"input_shape_1": [[7]],
"input_shape_2": [[7]],
"activation": [False],
"fully_quantize": [False],
},
]
# test_parameters include fully_quantize option only when # test_parameters include fully_quantize option only when
# allow_fully_quantize is True. # allow_fully_quantize is True.
if not allow_fully_quantize: if not allow_fully_quantize:

View File

@ -75,6 +75,7 @@ RANDOM_SEED = 342
TF_TYPE_INFO = { TF_TYPE_INFO = {
tf.float32: (np.float32, "FLOAT"), tf.float32: (np.float32, "FLOAT"),
tf.float16: (np.float16, "FLOAT"), tf.float16: (np.float16, "FLOAT"),
tf.float64: (np.double, "FLOAT64"),
tf.int32: (np.int32, "INT32"), tf.int32: (np.int32, "INT32"),
tf.uint8: (np.uint8, "QUANTIZED_UINT8"), tf.uint8: (np.uint8, "QUANTIZED_UINT8"),
tf.int16: (np.int16, "QUANTIZED_INT16"), tf.int16: (np.int16, "QUANTIZED_INT16"),
@ -108,7 +109,7 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
if dtype in TF_TYPE_INFO: if dtype in TF_TYPE_INFO:
dtype = TF_TYPE_INFO[dtype][0] dtype = TF_TYPE_INFO[dtype][0]
if dtype in (tf.float32, tf.float16): if dtype in (tf.float32, tf.float16, tf.float64):
value = (max_value - min_value) * np.random.random_sample(shape) + min_value value = (max_value - min_value) * np.random.random_sample(shape) + min_value
elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16): elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
value = np.random.randint(min_value, max_value + 1, shape) value = np.random.randint(min_value, max_value + 1, shape)
@ -128,7 +129,7 @@ def create_scalar_data(dtype, min_value=-100, max_value=100):
if dtype in TF_TYPE_INFO: if dtype in TF_TYPE_INFO:
dtype = TF_TYPE_INFO[dtype][0] dtype = TF_TYPE_INFO[dtype][0]
if dtype in (tf.float32, tf.float16): if dtype in (tf.float32, tf.float16, tf.float64):
value = (max_value - min_value) * np.random.random() + min_value value = (max_value - min_value) * np.random.random() + min_value
elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16): elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
value = np.random.randint(min_value, max_value + 1) value = np.random.randint(min_value, max_value + 1)

View File

@ -234,6 +234,7 @@ enum class ArrayDataType : uint8 {
kString, kString,
kComplex64, kComplex64,
kFloat16, kFloat16,
kFloat64,
}; };
// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type // Compile-time logic to map ArrayDataType to the corresponding C++ scalar type

View File

@ -51,7 +51,8 @@ namespace tflite {
{ArrayDataType::kInt64, ::tflite::TensorType_INT64}, {ArrayDataType::kInt64, ::tflite::TensorType_INT64},
{ArrayDataType::kString, ::tflite::TensorType_STRING}, {ArrayDataType::kString, ::tflite::TensorType_STRING},
{ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64}, {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64},
{ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16}}; {ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16},
{ArrayDataType::kFloat64, ::tflite::TensorType_FLOAT64}};
auto it = tensor_type_map.find(type); auto it = tensor_type_map.find(type);
if (it != tensor_type_map.end()) { if (it != tensor_type_map.end()) {

View File

@ -1766,6 +1766,8 @@ int ElementSize(ArrayDataType data_type) {
return 8; return 8;
case ArrayDataType::kComplex64: case ArrayDataType::kComplex64:
return 8; return 8;
case ArrayDataType::kFloat64:
return 8;
// Usually not critical limitation because strings are only input and/or // Usually not critical limitation because strings are only input and/or
// output. // output.
@ -2307,6 +2309,10 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
return ArrayDataType::kString; return ArrayDataType::kString;
case COMPLEX64: case COMPLEX64:
return ArrayDataType::kComplex64; return ArrayDataType::kComplex64;
case FLOAT16:
return ArrayDataType::kFloat16;
case FLOAT64:
return ArrayDataType::kFloat64;
default: default:
return ArrayDataType::kNone; return ArrayDataType::kNone;
} }

View File

@ -49,4 +49,7 @@ enum IODataType {
// Half precision float, not quantized. // Half precision float, not quantized.
FLOAT16 = 10; FLOAT16 = 10;
// Double precision float, not quantized.
FLOAT64 = 11;
} }

View File

@ -497,6 +497,10 @@ BenchmarkTfLiteModel::CreateRandomTensorData(const TfLiteTensor& t,
#endif // TFLITE_ENABLE_FP16_CPU_BENCHMARKS #endif // TFLITE_ENABLE_FP16_CPU_BENCHMARKS
break; break;
} }
case kTfLiteFloat64: {
return CreateInputTensorData<double>(
num_elements, std::uniform_real_distribution<double>(-0.5, 0.5));
}
case kTfLiteInt64: { case kTfLiteInt64: {
int low = has_value_range ? low_range : 0; int low = has_value_range ? low_range : 0;
int high = has_value_range ? high_range : 99; int high = has_value_range ? high_range : 99;

View File

@ -236,6 +236,7 @@ typedef enum {
kTfLiteComplex64 = 8, kTfLiteComplex64 = 8,
kTfLiteInt8 = 9, kTfLiteInt8 = 9,
kTfLiteFloat16 = 10, kTfLiteFloat16 = 10,
kTfLiteFloat64 = 11,
} TfLiteType; } TfLiteType;
// Return the name of a given type, for error reporting purposes. // Return the name of a given type, for error reporting purposes.

View File

@ -350,6 +350,9 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer,
case TensorType_FLOAT16: case TensorType_FLOAT16:
bytes_required *= sizeof(uint16_t); bytes_required *= sizeof(uint16_t);
break; break;
case TensorType_FLOAT64:
bytes_required *= sizeof(double);
break;
case TensorType_INT32: case TensorType_INT32:
bytes_required *= sizeof(int32_t); bytes_required *= sizeof(int32_t);
break; break;

View File

@ -74,5 +74,9 @@ template <>
constexpr TfLiteType typeToTfLiteType<TfLiteFloat16>() { constexpr TfLiteType typeToTfLiteType<TfLiteFloat16>() {
return kTfLiteFloat16; return kTfLiteFloat16;
} }
template <>
constexpr TfLiteType typeToTfLiteType<double>() {
return kTfLiteFloat64;
}
} // namespace tflite } // namespace tflite
#endif // TENSORFLOW_LITE_TYPE_TO_TFLITETYPE_H_ #endif // TENSORFLOW_LITE_TYPE_TO_TFLITETYPE_H_

View File

@ -103,6 +103,9 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
case kTfLiteFloat16: case kTfLiteFloat16:
*bytes = sizeof(TfLiteFloat16); *bytes = sizeof(TfLiteFloat16);
break; break;
case kTfLiteFloat64:
*bytes = sizeof(double);
break;
default: default:
if (context) { if (context) {
context->ReportError( context->ReportError(