Resource & variant type additions to TFLite schema

PiperOrigin-RevId: 354638976
Change-Id: I104c4de542b68e2660887a4e3ec45631a056dd74
This commit is contained in:
Jaesung Chung 2021-01-29 17:02:50 -08:00 committed by TensorFlower Gardener
parent c6de267c22
commit 6574fc4e08
24 changed files with 121 additions and 17 deletions

View File

@@ -131,6 +131,10 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
return DT_COMPLEX64;
case toco::IODataType::COMPLEX128:
return DT_COMPLEX128;
case toco::IODataType::RESOURCE:
return DT_RESOURCE;
case toco::IODataType::VARIANT:
return DT_VARIANT;
default:
return DT_INVALID;
}

View File

@@ -59,6 +59,10 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
return builder.getIntegerType(8);
case tflite::TensorType_UINT64:
return builder.getIntegerType(64, /*isSigned=*/false);
case tflite::TensorType_RESOURCE:
return mlir::TF::ResourceType::get(builder.getContext());
case tflite::TensorType_VARIANT:
return mlir::TF::VariantType::get(builder.getContext());
}
}
@@ -90,6 +94,10 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
return tensorflow::DT_UINT8;
case tflite::TensorType_UINT64:
return tensorflow::DT_UINT64;
case tflite::TensorType_RESOURCE:
return tensorflow::DT_RESOURCE;
case tflite::TensorType_VARIANT:
return tensorflow::DT_VARIANT;
}
}
@@ -99,10 +107,14 @@ StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) {
return tflite::TensorType_BOOL;
case tensorflow::DT_COMPLEX64:
return tflite::TensorType_COMPLEX64;
case tensorflow::DT_COMPLEX128:
return tflite::TensorType_COMPLEX128;
case tensorflow::DT_HALF:
return tflite::TensorType_FLOAT16;
case tensorflow::DT_FLOAT:
return tflite::TensorType_FLOAT32;
case tensorflow::DT_DOUBLE:
return tflite::TensorType_FLOAT64;
case tensorflow::DT_INT8:
return tflite::TensorType_INT8;
case tensorflow::DT_INT16:
@@ -111,10 +123,16 @@ StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) {
return tflite::TensorType_INT32;
case tensorflow::DT_INT64:
return tflite::TensorType_INT64;
case tensorflow::DT_UINT64:
return tflite::TensorType_UINT64;
case tensorflow::DT_STRING:
return tflite::TensorType_STRING;
case tensorflow::DT_UINT8:
return tflite::TensorType_UINT8;
case tensorflow::DT_RESOURCE:
return tflite::TensorType_RESOURCE;
case tensorflow::DT_VARIANT:
return tflite::TensorType_VARIANT;
default:
return errors::InvalidArgument("unsupported tensor data type", type);
}

View File

@@ -73,6 +73,8 @@ typedef enum {
kTfLiteFloat64 = 11,
kTfLiteComplex128 = 12,
kTfLiteUInt64 = 13,
kTfLiteResource = 14,
kTfLiteVariant = 15,
} TfLiteType;
// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.

View File

@@ -219,6 +219,10 @@ const char* TfLiteTypeGetName(TfLiteType type) {
return "FLOAT16";
case kTfLiteFloat64:
return "FLOAT64";
case kTfLiteResource:
return "RESOURCE";
case kTfLiteVariant:
return "VARIANT";
}
return "Unknown type";
}

View File

@@ -91,6 +91,8 @@ TEST(Types, TestTypeNames) {
EXPECT_EQ(type_name(kTfLiteComplex64), "COMPLEX64");
EXPECT_EQ(type_name(kTfLiteComplex128), "COMPLEX128");
EXPECT_EQ(type_name(kTfLiteString), "STRING");
EXPECT_EQ(type_name(kTfLiteResource), "RESOURCE");
EXPECT_EQ(type_name(kTfLiteVariant), "VARIANT");
}
TEST(Quantization, TestQuantizationFree) {

View File

@@ -872,6 +872,12 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
case TensorType_COMPLEX128:
*type = kTfLiteComplex128;
return kTfLiteOk;
case TensorType_RESOURCE:
*type = kTfLiteResource;
return kTfLiteOk;
case TensorType_VARIANT:
*type = kTfLiteVariant;
return kTfLiteOk;
default:
*type = kTfLiteNoType;
TF_LITE_REPORT_ERROR(error_reporter,

View File

@@ -1210,7 +1210,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly(
// ensure the buffer is large enough. However, we need to skip string tensors
// and sparse tensors because their sizes change with the contents.
// TODO(b/145615516): Extend BytesRequired to check sparse tensors.
if (type != kTfLiteString && sparsity == nullptr) {
if (type != kTfLiteString && type != kTfLiteResource &&
type != kTfLiteVariant && sparsity == nullptr) {
size_t required_bytes;
TF_LITE_ENSURE_OK(&context_,
BytesRequired(type, dims, rank, &required_bytes));
@@ -1262,7 +1263,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite(
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
size_t required_bytes = 0;
if (type != kTfLiteString) {
if (type != kTfLiteString && type != kTfLiteResource &&
type != kTfLiteVariant) {
// These types will be allocated in our arena so we need to record how
// many bytes we will need based on the dimensions. String tensors are
// allocated dynamically and we can't know ahead of time how much space
@@ -1272,7 +1274,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite(
}
TfLiteAllocationType allocation_type = kTfLiteArenaRw;
if (type == kTfLiteString) {
if (type == kTfLiteString || type == kTfLiteResource ||
type == kTfLiteVariant) {
if (is_variable) {
// We don't have a real use case for string variable tensor.
ReportError("String variable tensor isn't supported.");
@@ -1315,7 +1318,8 @@ TfLiteStatus Subgraph::ResizeTensorImpl(TfLiteTensor* tensor,
tensor->allocation_type == kTfLiteCustom) {
tensor_resized_since_op_invoke_ |=
TfLiteIntArrayEqual(tensor->dims, new_size) == 0;
if (tensor->type != kTfLiteString) {
if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
tensor->type != kTfLiteVariant) {
size_t bytesRequired;
TfLiteStatus status = BytesRequired(tensor->type, new_size->data,
new_size->size, &bytesRequired);

View File

@@ -84,6 +84,10 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
return TF_STRING;
case kTfLiteBool:
return TF_BOOL;
case kTfLiteResource:
return TF_RESOURCE;
case kTfLiteVariant:
return TF_VARIANT;
}
}
@@ -115,6 +119,10 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) {
return kTfLiteString;
case TF_BOOL:
return kTfLiteBool;
case TF_RESOURCE:
return kTfLiteResource;
case TF_VARIANT:
return kTfLiteVariant;
default:
return kTfLiteNoType;
}

View File

@@ -120,6 +120,8 @@ TEST(UtilTest, TypeConversionsFromTFLite) {
EXPECT_EQ(TF_COMPLEX128, GetTensorFlowDataType(kTfLiteComplex128));
EXPECT_EQ(TF_STRING, GetTensorFlowDataType(kTfLiteString));
EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool));
EXPECT_EQ(TF_RESOURCE, GetTensorFlowDataType(kTfLiteResource));
EXPECT_EQ(TF_VARIANT, GetTensorFlowDataType(kTfLiteVariant));
}
TEST(UtilTest, TypeConversionsFromTensorFlow) {
@@ -135,8 +137,8 @@ TEST(UtilTest, TypeConversionsFromTensorFlow) {
EXPECT_EQ(kTfLiteComplex128, GetTensorFlowLiteType(TF_COMPLEX128));
EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING));
EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL));
EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE));
EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_VARIANT));
EXPECT_EQ(kTfLiteResource, GetTensorFlowLiteType(TF_RESOURCE));
EXPECT_EQ(kTfLiteVariant, GetTensorFlowLiteType(TF_VARIANT));
}
} // namespace

View File

@@ -48,9 +48,15 @@ size_t AlignSizeUp(size_t size, size_t alignment) {
TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) {
switch (type) {
case kTfLiteFloat16:
*size = sizeof(int16_t);
break;
case kTfLiteFloat32:
*size = sizeof(float);
break;
case kTfLiteFloat64:
*size = sizeof(double);
break;
case kTfLiteInt16:
*size = sizeof(int16_t);
break;

View File

@@ -115,10 +115,18 @@ TF_LITE_MICRO_TEST(TestAlignSizeUp) {
TF_LITE_MICRO_TEST(TestTypeSizeOf) {
size_t size;
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::TfLiteTypeSizeOf(kTfLiteFloat16, &size));
TF_LITE_MICRO_EXPECT_EQ(sizeof(int16_t), size);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::TfLiteTypeSizeOf(kTfLiteFloat32, &size));
TF_LITE_MICRO_EXPECT_EQ(sizeof(float), size);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::TfLiteTypeSizeOf(kTfLiteFloat64, &size));
TF_LITE_MICRO_EXPECT_EQ(sizeof(double), size);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::TfLiteTypeSizeOf(kTfLiteInt16, &size));
TF_LITE_MICRO_EXPECT_EQ(sizeof(int16_t), size);

View File

@@ -422,8 +422,11 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_
case kTfLiteComplex64:
case kTfLiteComplex128:
case kTfLiteUInt64:
// kTfLiteString, kTfLiteUInt64, kTfLiteComplex64 and kTfLiteComplex128 are not supported in
// TensorFlow Lite Objc API.
case kTfLiteResource:
case kTfLiteVariant:
// kTfLiteString, kTfLiteUInt64, kTfLiteComplex64, kTfLiteComplex128,
// kTfLiteResource and kTfLiteVariant are not supported in TensorFlow Lite
// Objc API.
return TFLTensorDataTypeNoType;
}
}

View File

@@ -73,6 +73,10 @@ const char* TensorTypeName(TfLiteType type) {
return "kTfLiteFloat16";
case kTfLiteFloat64:
return "kTfLiteFloat64";
case kTfLiteResource:
return "kTfLiteResource";
case kTfLiteVariant:
return "kTfLiteVariant";
}
return "(invalid)";
}

View File

@@ -619,7 +619,8 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
std::vector<npy_intp> dims(tensor->dims->data,
tensor->dims->data + tensor->dims->size);
if (tensor->type != kTfLiteString) {
if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
tensor->type != kTfLiteVariant) {
// Make a buffer copy but we must tell Numpy It owns that data or else
// it will leak.
void* data = malloc(tensor->bytes);

View File

@@ -60,6 +60,9 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
return NPY_COMPLEX64;
case kTfLiteComplex128:
return NPY_COMPLEX128;
case kTfLiteResource:
case kTfLiteVariant:
return NPY_OBJECT;
case kTfLiteNoType:
return NPY_NOTYPE;
// Avoid default so compiler errors created when new types are made.

View File

@@ -91,6 +91,10 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
return TensorType_COMPLEX64;
case kTfLiteComplex128:
return TensorType_COMPLEX128;
case kTfLiteResource:
return TensorType_RESOURCE;
case kTfLiteVariant:
return TensorType_VARIANT;
}
// No default to get compiler error when new type is introduced.
}

View File

@@ -64,6 +64,8 @@ _MAP_TF_TO_TFLITE_TYPES = {
dtypes.int8: _types_pb2.INT8,
dtypes.float64: _types_pb2.FLOAT64,
dtypes.complex128: _types_pb2.COMPLEX128,
dtypes.resource: _types_pb2.RESOURCE,
dtypes.variant: _types_pb2.VARIANT,
}
_MAP_TFLITE_ENUM_TO_TF_TYPES = {

View File

@@ -45,6 +45,8 @@ enum TensorType : byte {
FLOAT64 = 10,
COMPLEX128 = 11,
UINT64 = 12,
RESOURCE = 13,
VARIANT = 14,
}
// Custom quantization parameters for experimenting with new quantization

View File

@@ -402,11 +402,13 @@ enum TensorType {
TensorType_FLOAT64 = 10,
TensorType_COMPLEX128 = 11,
TensorType_UINT64 = 12,
TensorType_RESOURCE = 13,
TensorType_VARIANT = 14,
TensorType_MIN = TensorType_FLOAT32,
TensorType_MAX = TensorType_UINT64
TensorType_MAX = TensorType_VARIANT
};
inline const TensorType (&EnumValuesTensorType())[13] {
inline const TensorType (&EnumValuesTensorType())[15] {
static const TensorType values[] = {
TensorType_FLOAT32,
TensorType_FLOAT16,
@@ -420,13 +422,15 @@ inline const TensorType (&EnumValuesTensorType())[13] {
TensorType_INT8,
TensorType_FLOAT64,
TensorType_COMPLEX128,
TensorType_UINT64
TensorType_UINT64,
TensorType_RESOURCE,
TensorType_VARIANT
};
return values;
}
inline const char * const *EnumNamesTensorType() {
static const char * const names[14] = {
static const char * const names[16] = {
"FLOAT32",
"FLOAT16",
"INT32",
@@ -440,13 +444,15 @@ inline const char * const *EnumNamesTensorType() {
"FLOAT64",
"COMPLEX128",
"UINT64",
"RESOURCE",
"VARIANT",
nullptr
};
return names;
}
inline const char *EnumNameTensorType(TensorType e) {
if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_UINT64)) return "";
if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_VARIANT)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesTensorType()[index];
}

View File

@@ -2323,6 +2323,8 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
return ArrayDataType::kFloat16;
case FLOAT64:
return ArrayDataType::kFloat64;
case RESOURCE:
case VARIANT:
default:
return ArrayDataType::kNone;
}

View File

@@ -58,4 +58,10 @@ enum IODataType {
// Uint64, not quantized
UINT64 = 13;
// Resource type
RESOURCE = 14;
// Variant type
VARIANT = 15;
}

View File

@@ -86,6 +86,10 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
return TensorType_COMPLEX64;
case kTfLiteComplex128:
return TensorType_COMPLEX128;
case kTfLiteResource:
return TensorType_RESOURCE;
case kTfLiteVariant:
return TensorType_VARIANT;
}
// TODO(aselle): consider an error
}

View File

@@ -558,7 +558,9 @@ TEST(VerifyModel, TypedTensorShapeMatchesTensorBufferSize) {
TfLiteFlatbufferModelBuilder builder;
for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX;
++tensor_type) {
if (tensor_type == TensorType_STRING) continue;
if (tensor_type == TensorType_STRING ||
tensor_type == TensorType_RESOURCE || tensor_type == TensorType_VARIANT)
continue;
TfLiteType lite_type = kTfLiteNoType;
ASSERT_EQ(ConvertTensorType(static_cast<TensorType>(tensor_type),
&lite_type, /*error_reporter=*/nullptr),

View File

@@ -132,8 +132,9 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
if (context) {
context->ReportError(
context,
"Type %d is unsupported. Only float32, int8, int16, int32, int64, "
"uint8, bool, complex64 supported currently.",
"Type %d is unsupported. Only float16, float32, float64, int8, "
"int16, int32, int64, uint8, uint64, bool, complex64 and "
"complex128 supported currently.",
type);
}
return kTfLiteError;