Resource & variant type additions to TFLite schema
PiperOrigin-RevId: 354638976 Change-Id: I104c4de542b68e2660887a4e3ec45631a056dd74
This commit is contained in:
parent
c6de267c22
commit
6574fc4e08
@ -131,6 +131,10 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||
return DT_COMPLEX64;
|
||||
case toco::IODataType::COMPLEX128:
|
||||
return DT_COMPLEX128;
|
||||
case toco::IODataType::RESOURCE:
|
||||
return DT_RESOURCE;
|
||||
case toco::IODataType::VARIANT:
|
||||
return DT_VARIANT;
|
||||
default:
|
||||
return DT_INVALID;
|
||||
}
|
||||
|
@ -59,6 +59,10 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
|
||||
return builder.getIntegerType(8);
|
||||
case tflite::TensorType_UINT64:
|
||||
return builder.getIntegerType(64, /*isSigned=*/false);
|
||||
case tflite::TensorType_RESOURCE:
|
||||
return mlir::TF::ResourceType::get(builder.getContext());
|
||||
case tflite::TensorType_VARIANT:
|
||||
return mlir::TF::VariantType::get(builder.getContext());
|
||||
}
|
||||
}
|
||||
|
||||
@ -90,6 +94,10 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
|
||||
return tensorflow::DT_UINT8;
|
||||
case tflite::TensorType_UINT64:
|
||||
return tensorflow::DT_UINT64;
|
||||
case tflite::TensorType_RESOURCE:
|
||||
return tensorflow::DT_RESOURCE;
|
||||
case tflite::TensorType_VARIANT:
|
||||
return tensorflow::DT_VARIANT;
|
||||
}
|
||||
}
|
||||
|
||||
@ -99,10 +107,14 @@ StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) {
|
||||
return tflite::TensorType_BOOL;
|
||||
case tensorflow::DT_COMPLEX64:
|
||||
return tflite::TensorType_COMPLEX64;
|
||||
case tensorflow::DT_COMPLEX128:
|
||||
return tflite::TensorType_COMPLEX128;
|
||||
case tensorflow::DT_HALF:
|
||||
return tflite::TensorType_FLOAT16;
|
||||
case tensorflow::DT_FLOAT:
|
||||
return tflite::TensorType_FLOAT32;
|
||||
case tensorflow::DT_DOUBLE:
|
||||
return tflite::TensorType_FLOAT64;
|
||||
case tensorflow::DT_INT8:
|
||||
return tflite::TensorType_INT8;
|
||||
case tensorflow::DT_INT16:
|
||||
@ -111,10 +123,16 @@ StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) {
|
||||
return tflite::TensorType_INT32;
|
||||
case tensorflow::DT_INT64:
|
||||
return tflite::TensorType_INT64;
|
||||
case tensorflow::DT_UINT64:
|
||||
return tflite::TensorType_UINT64;
|
||||
case tensorflow::DT_STRING:
|
||||
return tflite::TensorType_STRING;
|
||||
case tensorflow::DT_UINT8:
|
||||
return tflite::TensorType_UINT8;
|
||||
case tensorflow::DT_RESOURCE:
|
||||
return tflite::TensorType_RESOURCE;
|
||||
case tensorflow::DT_VARIANT:
|
||||
return tflite::TensorType_VARIANT;
|
||||
default:
|
||||
return errors::InvalidArgument("unsupported tensor data type", type);
|
||||
}
|
||||
|
@ -73,6 +73,8 @@ typedef enum {
|
||||
kTfLiteFloat64 = 11,
|
||||
kTfLiteComplex128 = 12,
|
||||
kTfLiteUInt64 = 13,
|
||||
kTfLiteResource = 14,
|
||||
kTfLiteVariant = 15,
|
||||
} TfLiteType;
|
||||
|
||||
// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
|
||||
|
@ -219,6 +219,10 @@ const char* TfLiteTypeGetName(TfLiteType type) {
|
||||
return "FLOAT16";
|
||||
case kTfLiteFloat64:
|
||||
return "FLOAT64";
|
||||
case kTfLiteResource:
|
||||
return "RESOURCE";
|
||||
case kTfLiteVariant:
|
||||
return "VARIANT";
|
||||
}
|
||||
return "Unknown type";
|
||||
}
|
||||
|
@ -91,6 +91,8 @@ TEST(Types, TestTypeNames) {
|
||||
EXPECT_EQ(type_name(kTfLiteComplex64), "COMPLEX64");
|
||||
EXPECT_EQ(type_name(kTfLiteComplex128), "COMPLEX128");
|
||||
EXPECT_EQ(type_name(kTfLiteString), "STRING");
|
||||
EXPECT_EQ(type_name(kTfLiteResource), "RESOURCE");
|
||||
EXPECT_EQ(type_name(kTfLiteVariant), "VARIANT");
|
||||
}
|
||||
|
||||
TEST(Quantization, TestQuantizationFree) {
|
||||
|
@ -872,6 +872,12 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
|
||||
case TensorType_COMPLEX128:
|
||||
*type = kTfLiteComplex128;
|
||||
return kTfLiteOk;
|
||||
case TensorType_RESOURCE:
|
||||
*type = kTfLiteResource;
|
||||
return kTfLiteOk;
|
||||
case TensorType_VARIANT:
|
||||
*type = kTfLiteVariant;
|
||||
return kTfLiteOk;
|
||||
default:
|
||||
*type = kTfLiteNoType;
|
||||
TF_LITE_REPORT_ERROR(error_reporter,
|
||||
|
@ -1210,7 +1210,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly(
|
||||
// ensure the buffer is large enough. However, we need to skip string tensors
|
||||
// and sparse tensors because their sizes change with the contents.
|
||||
// TODO(b/145615516): Extend BytesRequired to check sparse tensors.
|
||||
if (type != kTfLiteString && sparsity == nullptr) {
|
||||
if (type != kTfLiteString && type != kTfLiteResource &&
|
||||
type != kTfLiteVariant && sparsity == nullptr) {
|
||||
size_t required_bytes;
|
||||
TF_LITE_ENSURE_OK(&context_,
|
||||
BytesRequired(type, dims, rank, &required_bytes));
|
||||
@ -1262,7 +1263,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite(
|
||||
TF_LITE_ENSURE(&context_,
|
||||
tensor_index < context_.tensors_size && tensor_index >= 0);
|
||||
size_t required_bytes = 0;
|
||||
if (type != kTfLiteString) {
|
||||
if (type != kTfLiteString && type != kTfLiteResource &&
|
||||
type != kTfLiteVariant) {
|
||||
// These types will be allocated in our arena so we need to record how
|
||||
// many bytes we will need based on the dimensions. String tensors are
|
||||
// allocated dynamically and we can't know ahead of time how much space
|
||||
@ -1272,7 +1274,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite(
|
||||
}
|
||||
|
||||
TfLiteAllocationType allocation_type = kTfLiteArenaRw;
|
||||
if (type == kTfLiteString) {
|
||||
if (type == kTfLiteString || type == kTfLiteResource ||
|
||||
type == kTfLiteVariant) {
|
||||
if (is_variable) {
|
||||
// We don't have a real use case for string variable tensor.
|
||||
ReportError("String variable tensor isn't supported.");
|
||||
@ -1315,7 +1318,8 @@ TfLiteStatus Subgraph::ResizeTensorImpl(TfLiteTensor* tensor,
|
||||
tensor->allocation_type == kTfLiteCustom) {
|
||||
tensor_resized_since_op_invoke_ |=
|
||||
TfLiteIntArrayEqual(tensor->dims, new_size) == 0;
|
||||
if (tensor->type != kTfLiteString) {
|
||||
if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
|
||||
tensor->type != kTfLiteVariant) {
|
||||
size_t bytesRequired;
|
||||
TfLiteStatus status = BytesRequired(tensor->type, new_size->data,
|
||||
new_size->size, &bytesRequired);
|
||||
|
@ -84,6 +84,10 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
|
||||
return TF_STRING;
|
||||
case kTfLiteBool:
|
||||
return TF_BOOL;
|
||||
case kTfLiteResource:
|
||||
return TF_RESOURCE;
|
||||
case kTfLiteVariant:
|
||||
return TF_VARIANT;
|
||||
}
|
||||
}
|
||||
|
||||
@ -115,6 +119,10 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) {
|
||||
return kTfLiteString;
|
||||
case TF_BOOL:
|
||||
return kTfLiteBool;
|
||||
case TF_RESOURCE:
|
||||
return kTfLiteResource;
|
||||
case TF_VARIANT:
|
||||
return kTfLiteVariant;
|
||||
default:
|
||||
return kTfLiteNoType;
|
||||
}
|
||||
|
@ -120,6 +120,8 @@ TEST(UtilTest, TypeConversionsFromTFLite) {
|
||||
EXPECT_EQ(TF_COMPLEX128, GetTensorFlowDataType(kTfLiteComplex128));
|
||||
EXPECT_EQ(TF_STRING, GetTensorFlowDataType(kTfLiteString));
|
||||
EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool));
|
||||
EXPECT_EQ(TF_RESOURCE, GetTensorFlowDataType(kTfLiteResource));
|
||||
EXPECT_EQ(TF_VARIANT, GetTensorFlowDataType(kTfLiteVariant));
|
||||
}
|
||||
|
||||
TEST(UtilTest, TypeConversionsFromTensorFlow) {
|
||||
@ -135,8 +137,8 @@ TEST(UtilTest, TypeConversionsFromTensorFlow) {
|
||||
EXPECT_EQ(kTfLiteComplex128, GetTensorFlowLiteType(TF_COMPLEX128));
|
||||
EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING));
|
||||
EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL));
|
||||
EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE));
|
||||
EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_VARIANT));
|
||||
EXPECT_EQ(kTfLiteResource, GetTensorFlowLiteType(TF_RESOURCE));
|
||||
EXPECT_EQ(kTfLiteVariant, GetTensorFlowLiteType(TF_VARIANT));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -48,9 +48,15 @@ size_t AlignSizeUp(size_t size, size_t alignment) {
|
||||
|
||||
TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) {
|
||||
switch (type) {
|
||||
case kTfLiteFloat16:
|
||||
*size = sizeof(int16_t);
|
||||
break;
|
||||
case kTfLiteFloat32:
|
||||
*size = sizeof(float);
|
||||
break;
|
||||
case kTfLiteFloat64:
|
||||
*size = sizeof(double);
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
*size = sizeof(int16_t);
|
||||
break;
|
||||
|
@ -115,10 +115,18 @@ TF_LITE_MICRO_TEST(TestAlignSizeUp) {
|
||||
|
||||
TF_LITE_MICRO_TEST(TestTypeSizeOf) {
|
||||
size_t size;
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
|
||||
tflite::TfLiteTypeSizeOf(kTfLiteFloat16, &size));
|
||||
TF_LITE_MICRO_EXPECT_EQ(sizeof(int16_t), size);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
|
||||
tflite::TfLiteTypeSizeOf(kTfLiteFloat32, &size));
|
||||
TF_LITE_MICRO_EXPECT_EQ(sizeof(float), size);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
|
||||
tflite::TfLiteTypeSizeOf(kTfLiteFloat64, &size));
|
||||
TF_LITE_MICRO_EXPECT_EQ(sizeof(double), size);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
|
||||
tflite::TfLiteTypeSizeOf(kTfLiteInt16, &size));
|
||||
TF_LITE_MICRO_EXPECT_EQ(sizeof(int16_t), size);
|
||||
|
@ -422,8 +422,11 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_
|
||||
case kTfLiteComplex64:
|
||||
case kTfLiteComplex128:
|
||||
case kTfLiteUInt64:
|
||||
// kTfLiteString, kTfLiteUInt64, kTfLiteComplex64 and kTfLiteComplex128 are not supported in
|
||||
// TensorFlow Lite Objc API.
|
||||
case kTfLiteResource:
|
||||
case kTfLiteVariant:
|
||||
// kTfLiteString, kTfLiteUInt64, kTfLiteComplex64, kTfLiteComplex128,
|
||||
// kTfLiteResource and kTfLiteVariant are not supported in TensorFlow Lite
|
||||
// Objc API.
|
||||
return TFLTensorDataTypeNoType;
|
||||
}
|
||||
}
|
||||
|
@ -73,6 +73,10 @@ const char* TensorTypeName(TfLiteType type) {
|
||||
return "kTfLiteFloat16";
|
||||
case kTfLiteFloat64:
|
||||
return "kTfLiteFloat64";
|
||||
case kTfLiteResource:
|
||||
return "kTfLiteResource";
|
||||
case kTfLiteVariant:
|
||||
return "kTfLiteVariant";
|
||||
}
|
||||
return "(invalid)";
|
||||
}
|
||||
|
@ -619,7 +619,8 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
|
||||
|
||||
std::vector<npy_intp> dims(tensor->dims->data,
|
||||
tensor->dims->data + tensor->dims->size);
|
||||
if (tensor->type != kTfLiteString) {
|
||||
if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
|
||||
tensor->type != kTfLiteVariant) {
|
||||
// Make a buffer copy but we must tell Numpy It owns that data or else
|
||||
// it will leak.
|
||||
void* data = malloc(tensor->bytes);
|
||||
|
@ -60,6 +60,9 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
|
||||
return NPY_COMPLEX64;
|
||||
case kTfLiteComplex128:
|
||||
return NPY_COMPLEX128;
|
||||
case kTfLiteResource:
|
||||
case kTfLiteVariant:
|
||||
return NPY_OBJECT;
|
||||
case kTfLiteNoType:
|
||||
return NPY_NOTYPE;
|
||||
// Avoid default so compiler errors created when new types are made.
|
||||
|
@ -91,6 +91,10 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
|
||||
return TensorType_COMPLEX64;
|
||||
case kTfLiteComplex128:
|
||||
return TensorType_COMPLEX128;
|
||||
case kTfLiteResource:
|
||||
return TensorType_RESOURCE;
|
||||
case kTfLiteVariant:
|
||||
return TensorType_VARIANT;
|
||||
}
|
||||
// No default to get compiler error when new type is introduced.
|
||||
}
|
||||
|
@ -64,6 +64,8 @@ _MAP_TF_TO_TFLITE_TYPES = {
|
||||
dtypes.int8: _types_pb2.INT8,
|
||||
dtypes.float64: _types_pb2.FLOAT64,
|
||||
dtypes.complex128: _types_pb2.COMPLEX128,
|
||||
dtypes.resource: _types_pb2.RESOURCE,
|
||||
dtypes.variant: _types_pb2.VARIANT,
|
||||
}
|
||||
|
||||
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
|
||||
|
@ -45,6 +45,8 @@ enum TensorType : byte {
|
||||
FLOAT64 = 10,
|
||||
COMPLEX128 = 11,
|
||||
UINT64 = 12,
|
||||
RESOURCE = 13,
|
||||
VARIANT = 14,
|
||||
}
|
||||
|
||||
// Custom quantization parameters for experimenting with new quantization
|
||||
|
@ -402,11 +402,13 @@ enum TensorType {
|
||||
TensorType_FLOAT64 = 10,
|
||||
TensorType_COMPLEX128 = 11,
|
||||
TensorType_UINT64 = 12,
|
||||
TensorType_RESOURCE = 13,
|
||||
TensorType_VARIANT = 14,
|
||||
TensorType_MIN = TensorType_FLOAT32,
|
||||
TensorType_MAX = TensorType_UINT64
|
||||
TensorType_MAX = TensorType_VARIANT
|
||||
};
|
||||
|
||||
inline const TensorType (&EnumValuesTensorType())[13] {
|
||||
inline const TensorType (&EnumValuesTensorType())[15] {
|
||||
static const TensorType values[] = {
|
||||
TensorType_FLOAT32,
|
||||
TensorType_FLOAT16,
|
||||
@ -420,13 +422,15 @@ inline const TensorType (&EnumValuesTensorType())[13] {
|
||||
TensorType_INT8,
|
||||
TensorType_FLOAT64,
|
||||
TensorType_COMPLEX128,
|
||||
TensorType_UINT64
|
||||
TensorType_UINT64,
|
||||
TensorType_RESOURCE,
|
||||
TensorType_VARIANT
|
||||
};
|
||||
return values;
|
||||
}
|
||||
|
||||
inline const char * const *EnumNamesTensorType() {
|
||||
static const char * const names[14] = {
|
||||
static const char * const names[16] = {
|
||||
"FLOAT32",
|
||||
"FLOAT16",
|
||||
"INT32",
|
||||
@ -440,13 +444,15 @@ inline const char * const *EnumNamesTensorType() {
|
||||
"FLOAT64",
|
||||
"COMPLEX128",
|
||||
"UINT64",
|
||||
"RESOURCE",
|
||||
"VARIANT",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameTensorType(TensorType e) {
|
||||
if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_UINT64)) return "";
|
||||
if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_VARIANT)) return "";
|
||||
const size_t index = static_cast<size_t>(e);
|
||||
return EnumNamesTensorType()[index];
|
||||
}
|
||||
|
@ -2323,6 +2323,8 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
|
||||
return ArrayDataType::kFloat16;
|
||||
case FLOAT64:
|
||||
return ArrayDataType::kFloat64;
|
||||
case RESOURCE:
|
||||
case VARIANT:
|
||||
default:
|
||||
return ArrayDataType::kNone;
|
||||
}
|
||||
|
@ -58,4 +58,10 @@ enum IODataType {
|
||||
|
||||
// Uint64, not quantized
|
||||
UINT64 = 13;
|
||||
|
||||
// Resource type
|
||||
RESOURCE = 14;
|
||||
|
||||
// Variant type
|
||||
VARIANT = 15;
|
||||
}
|
||||
|
@ -86,6 +86,10 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
|
||||
return TensorType_COMPLEX64;
|
||||
case kTfLiteComplex128:
|
||||
return TensorType_COMPLEX128;
|
||||
case kTfLiteResource:
|
||||
return TensorType_RESOURCE;
|
||||
case kTfLiteVariant:
|
||||
return TensorType_VARIANT;
|
||||
}
|
||||
// TODO(aselle): consider an error
|
||||
}
|
||||
|
@ -558,7 +558,9 @@ TEST(VerifyModel, TypedTensorShapeMatchesTensorBufferSize) {
|
||||
TfLiteFlatbufferModelBuilder builder;
|
||||
for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX;
|
||||
++tensor_type) {
|
||||
if (tensor_type == TensorType_STRING) continue;
|
||||
if (tensor_type == TensorType_STRING ||
|
||||
tensor_type == TensorType_RESOURCE || tensor_type == TensorType_VARIANT)
|
||||
continue;
|
||||
TfLiteType lite_type = kTfLiteNoType;
|
||||
ASSERT_EQ(ConvertTensorType(static_cast<TensorType>(tensor_type),
|
||||
&lite_type, /*error_reporter=*/nullptr),
|
||||
|
@ -132,8 +132,9 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
|
||||
if (context) {
|
||||
context->ReportError(
|
||||
context,
|
||||
"Type %d is unsupported. Only float32, int8, int16, int32, int64, "
|
||||
"uint8, bool, complex64 supported currently.",
|
||||
"Type %d is unsupported. Only float16, float32, float64, int8, "
|
||||
"int16, int32, int64, uint8, uint64, bool, complex64 and "
|
||||
"complex128 supported currently.",
|
||||
type);
|
||||
}
|
||||
return kTfLiteError;
|
||||
|
Loading…
x
Reference in New Issue
Block a user