Add float64 tensor support in TFLite
PiperOrigin-RevId: 304756523 Change-Id: I6e9f3196a700b3b43cc9b9fb06b72938db651582
This commit is contained in:
parent
0d61fc79f5
commit
845adc40a6
|
@ -138,6 +138,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
|||
return tflite::TensorType_FLOAT32;
|
||||
case mlir::StandardTypes::F16:
|
||||
return tflite::TensorType_FLOAT16;
|
||||
case mlir::StandardTypes::F64:
|
||||
return tflite::TensorType_FLOAT64;
|
||||
case mlir::TF::TensorFlowTypes::STRING:
|
||||
return tflite::TensorType_STRING;
|
||||
case mlir::TF::TensorFlowTypes::QUINT8:
|
||||
|
|
|
@ -353,6 +353,22 @@ StatusOr<mlir::ElementsAttr> ConvertFloatBuffer(
|
|||
}
|
||||
return DenseElementsAttr::get(shaped_type, ArrayRef<float>(values));
|
||||
}
|
||||
case 64: {
|
||||
assert(bytes_len % 8 == 0);
|
||||
size_t elem_count = bytes_len / 8;
|
||||
std::vector<double> values;
|
||||
values.reserve(elem_count);
|
||||
|
||||
const char* data = reinterpret_cast<const char*>(buffer.data());
|
||||
|
||||
for (int i = 0; i < elem_count; i++) {
|
||||
uint64_t bit_repr =
|
||||
llvm::support::endian::readNext<uint64_t, llvm::support::little,
|
||||
llvm::support::unaligned>(data);
|
||||
values.push_back(absl::bit_cast<double>(bit_repr));
|
||||
}
|
||||
return DenseElementsAttr::get(shaped_type, ArrayRef<double>(values));
|
||||
}
|
||||
}
|
||||
return errors::InvalidArgument("unsupported bit width", elem_type.getWidth());
|
||||
}
|
||||
|
|
|
@ -105,6 +105,10 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
|||
switch (dtype) {
|
||||
case toco::IODataType::FLOAT:
|
||||
return DT_FLOAT;
|
||||
case toco::IODataType::FLOAT16:
|
||||
return DT_HALF;
|
||||
case toco::IODataType::FLOAT64:
|
||||
return DT_DOUBLE;
|
||||
case toco::IODataType::QUANTIZED_UINT8:
|
||||
return DT_QUINT8;
|
||||
case toco::IODataType::INT8:
|
||||
|
|
|
@ -28,6 +28,13 @@ func @f32() -> tensor<4xf32> {
|
|||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
||||
func @f64() -> tensor<4xf64> {
|
||||
// CHECK-LABEL: @f64
|
||||
// CHECK: value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf64>
|
||||
%0 = "tfl.pseudo_const"() { value = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64> } : () -> tensor<4xf64>
|
||||
return %0 : tensor<4xf64>
|
||||
}
|
||||
|
||||
func @i8() -> tensor<4xi8> {
|
||||
// CHECK-LABEL: @i8
|
||||
// CHECK: value = dense<[1, 2, 3, 4]> : tensor<4xi8>
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> {
|
||||
^bb0(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>):
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "FlexAdd"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: FLOAT64,
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: FLOAT64,
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: FLOAT64,
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "add",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2 ],
|
||||
// CHECK-NEXT: custom_options: [ 3, 65, 100, 100, 0, 20, 18, 3, 65, 100, 100, 26, 0, 26, 0, 42, 7, 10, 1, 84, 18, 2, 48, 2, 50, 0, 0, 2, 27, 23, 20, 20, 4, 40, 1 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 4
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> loc("add")
|
||||
return %0 : tensor<4xf64>
|
||||
}
|
|
@ -32,10 +32,12 @@ namespace errors = tensorflow::errors;
|
|||
|
||||
mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
|
||||
switch (type) {
|
||||
case tflite::TensorType_FLOAT32:
|
||||
return builder.getF32Type();
|
||||
case tflite::TensorType_FLOAT16:
|
||||
return builder.getF16Type();
|
||||
case tflite::TensorType_FLOAT32:
|
||||
return builder.getF32Type();
|
||||
case tflite::TensorType_FLOAT64:
|
||||
return builder.getF64Type();
|
||||
case tflite::TensorType_INT32:
|
||||
return builder.getIntegerType(32);
|
||||
case tflite::TensorType_UINT8:
|
||||
|
@ -65,6 +67,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
|
|||
return tensorflow::DT_HALF;
|
||||
case tflite::TensorType_FLOAT32:
|
||||
return tensorflow::DT_FLOAT;
|
||||
case tflite::TensorType_FLOAT64:
|
||||
return tensorflow::DT_DOUBLE;
|
||||
case tflite::TensorType_INT8:
|
||||
return tensorflow::DT_INT8;
|
||||
case tflite::TensorType_INT16:
|
||||
|
|
|
@ -209,6 +209,8 @@ const char* TfLiteTypeGetName(TfLiteType type) {
|
|||
return "STRING";
|
||||
case kTfLiteFloat16:
|
||||
return "FLOAT16";
|
||||
case kTfLiteFloat64:
|
||||
return "FLOAT64";
|
||||
}
|
||||
return "Unknown type";
|
||||
}
|
||||
|
|
|
@ -236,6 +236,7 @@ typedef enum {
|
|||
kTfLiteComplex64 = 8,
|
||||
kTfLiteInt8 = 9,
|
||||
kTfLiteFloat16 = 10,
|
||||
kTfLiteFloat64 = 11,
|
||||
} TfLiteType;
|
||||
|
||||
// Return the name of a given type, for error reporting purposes.
|
||||
|
|
|
@ -91,11 +91,14 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
|
|||
ErrorReporter* error_reporter) {
|
||||
*type = kTfLiteNoType;
|
||||
switch (tensor_type) {
|
||||
case TensorType_FLOAT16:
|
||||
*type = kTfLiteFloat16;
|
||||
break;
|
||||
case TensorType_FLOAT32:
|
||||
*type = kTfLiteFloat32;
|
||||
break;
|
||||
case TensorType_FLOAT16:
|
||||
*type = kTfLiteFloat16;
|
||||
case TensorType_FLOAT64:
|
||||
*type = kTfLiteFloat64;
|
||||
break;
|
||||
case TensorType_INT16:
|
||||
*type = kTfLiteInt16;
|
||||
|
|
|
@ -62,6 +62,8 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
|
|||
return TF_FLOAT;
|
||||
case kTfLiteFloat16:
|
||||
return TF_HALF;
|
||||
case kTfLiteFloat64:
|
||||
return TF_DOUBLE;
|
||||
case kTfLiteInt16:
|
||||
return TF_INT16;
|
||||
case kTfLiteInt32:
|
||||
|
|
|
@ -49,6 +49,9 @@ typedef NS_ENUM(NSUInteger, TFLTensorDataType) {
|
|||
|
||||
/** 8-bit signed integer. */
|
||||
TFLTensorDataTypeInt8,
|
||||
|
||||
/** 64-bit double precision floating point. */
|
||||
TFLTensorDataTypeFloat64,
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
@ -373,6 +373,8 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_
|
|||
return TFLTensorDataTypeFloat32;
|
||||
case kTfLiteFloat16:
|
||||
return TFLTensorDataTypeFloat16;
|
||||
case kTfLiteFloat64:
|
||||
return TFLTensorDataTypeFloat64;
|
||||
case kTfLiteInt32:
|
||||
return TFLTensorDataTypeInt32;
|
||||
case kTfLiteUInt8:
|
||||
|
|
|
@ -73,6 +73,8 @@ extension Tensor {
|
|||
case float16
|
||||
/// A 32-bit single precision floating point.
|
||||
case float32
|
||||
/// A 64-bit double precision floating point.
|
||||
case float64
|
||||
|
||||
/// Creates a new instance from the given `TfLiteType` or `nil` if the data type is unsupported
|
||||
/// or could not be determined because there was an error.
|
||||
|
@ -94,6 +96,8 @@ extension Tensor {
|
|||
self = .float16
|
||||
case kTfLiteFloat32:
|
||||
self = .float32
|
||||
case kTfLiteFloat64:
|
||||
self = .float64
|
||||
case kTfLiteNoType:
|
||||
fallthrough
|
||||
default:
|
||||
|
|
|
@ -64,6 +64,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
|
|||
return TensorType_FLOAT32;
|
||||
case kTfLiteFloat16:
|
||||
return TensorType_FLOAT16;
|
||||
case kTfLiteFloat64:
|
||||
return TensorType_FLOAT64;
|
||||
case kTfLiteInt32:
|
||||
return TensorType_INT32;
|
||||
case kTfLiteUInt8:
|
||||
|
|
|
@ -790,6 +790,7 @@ template <typename T>
|
|||
TensorType GetTensorType() {
|
||||
if (std::is_same<T, float>::value) return TensorType_FLOAT32;
|
||||
if (std::is_same<T, TfLiteFloat16>::value) return TensorType_FLOAT16;
|
||||
if (std::is_same<T, double>::value) return TensorType_FLOAT64;
|
||||
if (std::is_same<T, int8_t>::value) return TensorType_INT8;
|
||||
if (std::is_same<T, int16_t>::value) return TensorType_INT16;
|
||||
if (std::is_same<T, int32_t>::value) return TensorType_INT32;
|
||||
|
|
|
@ -77,6 +77,8 @@ const char* TensorTypeName(TfLiteType type) {
|
|||
return "kTfLiteComplex64";
|
||||
case kTfLiteFloat16:
|
||||
return "kTfLiteFloat16";
|
||||
case kTfLiteFloat64:
|
||||
return "kTfLiteFloat64";
|
||||
}
|
||||
return "(invalid)";
|
||||
}
|
||||
|
|
|
@ -59,6 +59,8 @@ const char* TensorTypeName(TfLiteType type) {
|
|||
return "kTfLiteComplex64";
|
||||
case kTfLiteFloat16:
|
||||
return "kTfLiteFloat16";
|
||||
case kTfLiteFloat64:
|
||||
return "kTfLiteFloat64";
|
||||
}
|
||||
return "(invalid)";
|
||||
}
|
||||
|
|
|
@ -38,6 +38,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
|
|||
return NPY_FLOAT32;
|
||||
case kTfLiteFloat16:
|
||||
return NPY_FLOAT16;
|
||||
case kTfLiteFloat64:
|
||||
return NPY_FLOAT64;
|
||||
case kTfLiteInt32:
|
||||
return NPY_INT32;
|
||||
case kTfLiteInt16:
|
||||
|
|
|
@ -68,6 +68,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
|
|||
return TensorType_FLOAT32;
|
||||
case kTfLiteFloat16:
|
||||
return TensorType_FLOAT16;
|
||||
case kTfLiteFloat64:
|
||||
return TensorType_FLOAT64;
|
||||
case kTfLiteInt32:
|
||||
return TensorType_INT32;
|
||||
case kTfLiteUInt8:
|
||||
|
|
|
@ -43,6 +43,7 @@ from tensorflow.python.training.saver import export_meta_graph as _export_meta_g
|
|||
_MAP_TF_TO_TFLITE_TYPES = {
|
||||
dtypes.float32: _types_pb2.FLOAT,
|
||||
dtypes.float16: _types_pb2.FLOAT16,
|
||||
dtypes.float64: _types_pb2.FLOAT64,
|
||||
dtypes.int32: _types_pb2.INT32,
|
||||
dtypes.int64: _types_pb2.INT64,
|
||||
dtypes.string: _types_pb2.STRING,
|
||||
|
|
|
@ -40,6 +40,7 @@ enum TensorType : byte {
|
|||
INT16 = 7,
|
||||
COMPLEX64 = 8,
|
||||
INT8 = 9,
|
||||
FLOAT64 = 10,
|
||||
}
|
||||
|
||||
// Custom quantization parameters for experimenting with new quantization
|
||||
|
|
|
@ -378,11 +378,12 @@ enum TensorType {
|
|||
TensorType_INT16 = 7,
|
||||
TensorType_COMPLEX64 = 8,
|
||||
TensorType_INT8 = 9,
|
||||
TensorType_FLOAT64 = 10,
|
||||
TensorType_MIN = TensorType_FLOAT32,
|
||||
TensorType_MAX = TensorType_INT8
|
||||
TensorType_MAX = TensorType_FLOAT64
|
||||
};
|
||||
|
||||
inline const TensorType (&EnumValuesTensorType())[10] {
|
||||
inline const TensorType (&EnumValuesTensorType())[11] {
|
||||
static const TensorType values[] = {
|
||||
TensorType_FLOAT32,
|
||||
TensorType_FLOAT16,
|
||||
|
@ -393,13 +394,14 @@ inline const TensorType (&EnumValuesTensorType())[10] {
|
|||
TensorType_BOOL,
|
||||
TensorType_INT16,
|
||||
TensorType_COMPLEX64,
|
||||
TensorType_INT8
|
||||
TensorType_INT8,
|
||||
TensorType_FLOAT64
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
||||
inline const char * const *EnumNamesTensorType() {
|
||||
static const char * const names[11] = {
|
||||
static const char * const names[12] = {
|
||||
"FLOAT32",
|
||||
"FLOAT16",
|
||||
"INT32",
|
||||
|
@ -410,13 +412,14 @@ inline const char * const *EnumNamesTensorType() {
|
|||
"INT16",
|
||||
"COMPLEX64",
|
||||
"INT8",
|
||||
"FLOAT64",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameTensorType(TensorType e) {
|
||||
if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_INT8)) return "";
|
||||
if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_FLOAT64)) return "";
|
||||
const size_t index = static_cast<size_t>(e);
|
||||
return EnumNamesTensorType()[index];
|
||||
}
|
||||
|
|
|
@ -114,6 +114,18 @@ def make_binary_op_tests(options,
|
|||
},
|
||||
]
|
||||
|
||||
# float64 types are supported via flex only.
|
||||
if options.run_with_flex and options.use_experimental_converter:
|
||||
test_parameters = test_parameters + [
|
||||
{
|
||||
"dtype": [tf.float64],
|
||||
"input_shape_1": [[7]],
|
||||
"input_shape_2": [[7]],
|
||||
"activation": [False],
|
||||
"fully_quantize": [False],
|
||||
},
|
||||
]
|
||||
|
||||
# test_parameters include fully_quantize option only when
|
||||
# allow_fully_quantize is True.
|
||||
if not allow_fully_quantize:
|
||||
|
|
|
@ -75,6 +75,7 @@ RANDOM_SEED = 342
|
|||
TF_TYPE_INFO = {
|
||||
tf.float32: (np.float32, "FLOAT"),
|
||||
tf.float16: (np.float16, "FLOAT"),
|
||||
tf.float64: (np.double, "FLOAT64"),
|
||||
tf.int32: (np.int32, "INT32"),
|
||||
tf.uint8: (np.uint8, "QUANTIZED_UINT8"),
|
||||
tf.int16: (np.int16, "QUANTIZED_INT16"),
|
||||
|
@ -108,7 +109,7 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
|
|||
if dtype in TF_TYPE_INFO:
|
||||
dtype = TF_TYPE_INFO[dtype][0]
|
||||
|
||||
if dtype in (tf.float32, tf.float16):
|
||||
if dtype in (tf.float32, tf.float16, tf.float64):
|
||||
value = (max_value - min_value) * np.random.random_sample(shape) + min_value
|
||||
elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
|
||||
value = np.random.randint(min_value, max_value + 1, shape)
|
||||
|
@ -128,7 +129,7 @@ def create_scalar_data(dtype, min_value=-100, max_value=100):
|
|||
if dtype in TF_TYPE_INFO:
|
||||
dtype = TF_TYPE_INFO[dtype][0]
|
||||
|
||||
if dtype in (tf.float32, tf.float16):
|
||||
if dtype in (tf.float32, tf.float16, tf.float64):
|
||||
value = (max_value - min_value) * np.random.random() + min_value
|
||||
elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
|
||||
value = np.random.randint(min_value, max_value + 1)
|
||||
|
|
|
@ -234,6 +234,7 @@ enum class ArrayDataType : uint8 {
|
|||
kString,
|
||||
kComplex64,
|
||||
kFloat16,
|
||||
kFloat64,
|
||||
};
|
||||
|
||||
// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type
|
||||
|
|
|
@ -51,7 +51,8 @@ namespace tflite {
|
|||
{ArrayDataType::kInt64, ::tflite::TensorType_INT64},
|
||||
{ArrayDataType::kString, ::tflite::TensorType_STRING},
|
||||
{ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64},
|
||||
{ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16}};
|
||||
{ArrayDataType::kFloat16, ::tflite::TensorType_FLOAT16},
|
||||
{ArrayDataType::kFloat64, ::tflite::TensorType_FLOAT64}};
|
||||
|
||||
auto it = tensor_type_map.find(type);
|
||||
if (it != tensor_type_map.end()) {
|
||||
|
|
|
@ -1766,6 +1766,8 @@ int ElementSize(ArrayDataType data_type) {
|
|||
return 8;
|
||||
case ArrayDataType::kComplex64:
|
||||
return 8;
|
||||
case ArrayDataType::kFloat64:
|
||||
return 8;
|
||||
|
||||
// Usually not critical limitation because strings are only input and/or
|
||||
// output.
|
||||
|
@ -2307,6 +2309,10 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
|
|||
return ArrayDataType::kString;
|
||||
case COMPLEX64:
|
||||
return ArrayDataType::kComplex64;
|
||||
case FLOAT16:
|
||||
return ArrayDataType::kFloat16;
|
||||
case FLOAT64:
|
||||
return ArrayDataType::kFloat64;
|
||||
default:
|
||||
return ArrayDataType::kNone;
|
||||
}
|
||||
|
|
|
@ -49,4 +49,7 @@ enum IODataType {
|
|||
|
||||
// Half precision float, not quantized.
|
||||
FLOAT16 = 10;
|
||||
|
||||
// Double precision float, not quantized.
|
||||
FLOAT64 = 11;
|
||||
}
|
||||
|
|
|
@ -497,6 +497,10 @@ BenchmarkTfLiteModel::CreateRandomTensorData(const TfLiteTensor& t,
|
|||
#endif // TFLITE_ENABLE_FP16_CPU_BENCHMARKS
|
||||
break;
|
||||
}
|
||||
case kTfLiteFloat64: {
|
||||
return CreateInputTensorData<double>(
|
||||
num_elements, std::uniform_real_distribution<double>(-0.5, 0.5));
|
||||
}
|
||||
case kTfLiteInt64: {
|
||||
int low = has_value_range ? low_range : 0;
|
||||
int high = has_value_range ? high_range : 99;
|
||||
|
|
|
@ -236,6 +236,7 @@ typedef enum {
|
|||
kTfLiteComplex64 = 8,
|
||||
kTfLiteInt8 = 9,
|
||||
kTfLiteFloat16 = 10,
|
||||
kTfLiteFloat64 = 11,
|
||||
} TfLiteType;
|
||||
|
||||
// Return the name of a given type, for error reporting purposes.
|
||||
|
|
|
@ -350,6 +350,9 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer,
|
|||
case TensorType_FLOAT16:
|
||||
bytes_required *= sizeof(uint16_t);
|
||||
break;
|
||||
case TensorType_FLOAT64:
|
||||
bytes_required *= sizeof(double);
|
||||
break;
|
||||
case TensorType_INT32:
|
||||
bytes_required *= sizeof(int32_t);
|
||||
break;
|
||||
|
|
|
@ -74,5 +74,9 @@ template <>
|
|||
constexpr TfLiteType typeToTfLiteType<TfLiteFloat16>() {
|
||||
return kTfLiteFloat16;
|
||||
}
|
||||
template <>
|
||||
constexpr TfLiteType typeToTfLiteType<double>() {
|
||||
return kTfLiteFloat64;
|
||||
}
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_TYPE_TO_TFLITETYPE_H_
|
||||
|
|
|
@ -103,6 +103,9 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
|
|||
case kTfLiteFloat16:
|
||||
*bytes = sizeof(TfLiteFloat16);
|
||||
break;
|
||||
case kTfLiteFloat64:
|
||||
*bytes = sizeof(double);
|
||||
break;
|
||||
default:
|
||||
if (context) {
|
||||
context->ReportError(
|
||||
|
|
Loading…
Reference in New Issue