Resource & variant type additions to TFLite schema
PiperOrigin-RevId: 354638976 Change-Id: I104c4de542b68e2660887a4e3ec45631a056dd74
This commit is contained in:
		
							parent
							
								
									c6de267c22
								
							
						
					
					
						commit
						6574fc4e08
					
				| @ -131,6 +131,10 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) { | ||||
|       return DT_COMPLEX64; | ||||
|     case toco::IODataType::COMPLEX128: | ||||
|       return DT_COMPLEX128; | ||||
|     case toco::IODataType::RESOURCE: | ||||
|       return DT_RESOURCE; | ||||
|     case toco::IODataType::VARIANT: | ||||
|       return DT_VARIANT; | ||||
|     default: | ||||
|       return DT_INVALID; | ||||
|   } | ||||
|  | ||||
| @ -59,6 +59,10 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { | ||||
|       return builder.getIntegerType(8); | ||||
|     case tflite::TensorType_UINT64: | ||||
|       return builder.getIntegerType(64, /*isSigned=*/false); | ||||
|     case tflite::TensorType_RESOURCE: | ||||
|       return mlir::TF::ResourceType::get(builder.getContext()); | ||||
|     case tflite::TensorType_VARIANT: | ||||
|       return mlir::TF::VariantType::get(builder.getContext()); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| @ -90,6 +94,10 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) { | ||||
|       return tensorflow::DT_UINT8; | ||||
|     case tflite::TensorType_UINT64: | ||||
|       return tensorflow::DT_UINT64; | ||||
|     case tflite::TensorType_RESOURCE: | ||||
|       return tensorflow::DT_RESOURCE; | ||||
|     case tflite::TensorType_VARIANT: | ||||
|       return tensorflow::DT_VARIANT; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| @ -99,10 +107,14 @@ StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) { | ||||
|       return tflite::TensorType_BOOL; | ||||
|     case tensorflow::DT_COMPLEX64: | ||||
|       return tflite::TensorType_COMPLEX64; | ||||
|     case tensorflow::DT_COMPLEX128: | ||||
|       return tflite::TensorType_COMPLEX128; | ||||
|     case tensorflow::DT_HALF: | ||||
|       return tflite::TensorType_FLOAT16; | ||||
|     case tensorflow::DT_FLOAT: | ||||
|       return tflite::TensorType_FLOAT32; | ||||
|     case tensorflow::DT_DOUBLE: | ||||
|       return tflite::TensorType_FLOAT64; | ||||
|     case tensorflow::DT_INT8: | ||||
|       return tflite::TensorType_INT8; | ||||
|     case tensorflow::DT_INT16: | ||||
| @ -111,10 +123,16 @@ StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) { | ||||
|       return tflite::TensorType_INT32; | ||||
|     case tensorflow::DT_INT64: | ||||
|       return tflite::TensorType_INT64; | ||||
|     case tensorflow::DT_UINT64: | ||||
|       return tflite::TensorType_UINT64; | ||||
|     case tensorflow::DT_STRING: | ||||
|       return tflite::TensorType_STRING; | ||||
|     case tensorflow::DT_UINT8: | ||||
|       return tflite::TensorType_UINT8; | ||||
|     case tensorflow::DT_RESOURCE: | ||||
|       return tflite::TensorType_RESOURCE; | ||||
|     case tensorflow::DT_VARIANT: | ||||
|       return tflite::TensorType_VARIANT; | ||||
|     default: | ||||
|       return errors::InvalidArgument("unsupported tensor data type", type); | ||||
|   } | ||||
|  | ||||
| @ -73,6 +73,8 @@ typedef enum { | ||||
|   kTfLiteFloat64 = 11, | ||||
|   kTfLiteComplex128 = 12, | ||||
|   kTfLiteUInt64 = 13, | ||||
|   kTfLiteResource = 14, | ||||
|   kTfLiteVariant = 15, | ||||
| } TfLiteType; | ||||
| 
 | ||||
| // Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
 | ||||
|  | ||||
| @ -219,6 +219,10 @@ const char* TfLiteTypeGetName(TfLiteType type) { | ||||
|       return "FLOAT16"; | ||||
|     case kTfLiteFloat64: | ||||
|       return "FLOAT64"; | ||||
|     case kTfLiteResource: | ||||
|       return "RESOURCE"; | ||||
|     case kTfLiteVariant: | ||||
|       return "VARIANT"; | ||||
|   } | ||||
|   return "Unknown type"; | ||||
| } | ||||
|  | ||||
| @ -91,6 +91,8 @@ TEST(Types, TestTypeNames) { | ||||
|   EXPECT_EQ(type_name(kTfLiteComplex64), "COMPLEX64"); | ||||
|   EXPECT_EQ(type_name(kTfLiteComplex128), "COMPLEX128"); | ||||
|   EXPECT_EQ(type_name(kTfLiteString), "STRING"); | ||||
|   EXPECT_EQ(type_name(kTfLiteResource), "RESOURCE"); | ||||
|   EXPECT_EQ(type_name(kTfLiteVariant), "VARIANT"); | ||||
| } | ||||
| 
 | ||||
| TEST(Quantization, TestQuantizationFree) { | ||||
|  | ||||
| @ -872,6 +872,12 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, | ||||
|     case TensorType_COMPLEX128: | ||||
|       *type = kTfLiteComplex128; | ||||
|       return kTfLiteOk; | ||||
|     case TensorType_RESOURCE: | ||||
|       *type = kTfLiteResource; | ||||
|       return kTfLiteOk; | ||||
|     case TensorType_VARIANT: | ||||
|       *type = kTfLiteVariant; | ||||
|       return kTfLiteOk; | ||||
|     default: | ||||
|       *type = kTfLiteNoType; | ||||
|       TF_LITE_REPORT_ERROR(error_reporter, | ||||
|  | ||||
| @ -1210,7 +1210,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly( | ||||
|   // ensure the buffer is large enough. However, we need to skip string tensors
 | ||||
|   // and sparse tensors because their sizes change with the contents.
 | ||||
|   // TODO(b/145615516): Extend BytesRequired to check sparse tensors.
 | ||||
|   if (type != kTfLiteString && sparsity == nullptr) { | ||||
|   if (type != kTfLiteString && type != kTfLiteResource && | ||||
|       type != kTfLiteVariant && sparsity == nullptr) { | ||||
|     size_t required_bytes; | ||||
|     TF_LITE_ENSURE_OK(&context_, | ||||
|                       BytesRequired(type, dims, rank, &required_bytes)); | ||||
| @ -1262,7 +1263,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite( | ||||
|   TF_LITE_ENSURE(&context_, | ||||
|                  tensor_index < context_.tensors_size && tensor_index >= 0); | ||||
|   size_t required_bytes = 0; | ||||
|   if (type != kTfLiteString) { | ||||
|   if (type != kTfLiteString && type != kTfLiteResource && | ||||
|       type != kTfLiteVariant) { | ||||
|     // These types will be allocated in our arena so we need to record how
 | ||||
|     // many bytes we will need based on the dimensions. String tensors are
 | ||||
|     // allocated dynamically and we can't know ahead of time how much space
 | ||||
| @ -1272,7 +1274,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite( | ||||
|   } | ||||
| 
 | ||||
|   TfLiteAllocationType allocation_type = kTfLiteArenaRw; | ||||
|   if (type == kTfLiteString) { | ||||
|   if (type == kTfLiteString || type == kTfLiteResource || | ||||
|       type == kTfLiteVariant) { | ||||
|     if (is_variable) { | ||||
|       // We don't have a real use case for string variable tensor.
 | ||||
|       ReportError("String variable tensor isn't supported."); | ||||
| @ -1315,7 +1318,8 @@ TfLiteStatus Subgraph::ResizeTensorImpl(TfLiteTensor* tensor, | ||||
|       tensor->allocation_type == kTfLiteCustom) { | ||||
|     tensor_resized_since_op_invoke_ |= | ||||
|         TfLiteIntArrayEqual(tensor->dims, new_size) == 0; | ||||
|     if (tensor->type != kTfLiteString) { | ||||
|     if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource && | ||||
|         tensor->type != kTfLiteVariant) { | ||||
|       size_t bytesRequired; | ||||
|       TfLiteStatus status = BytesRequired(tensor->type, new_size->data, | ||||
|                                           new_size->size, &bytesRequired); | ||||
|  | ||||
| @ -84,6 +84,10 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) { | ||||
|       return TF_STRING; | ||||
|     case kTfLiteBool: | ||||
|       return TF_BOOL; | ||||
|     case kTfLiteResource: | ||||
|       return TF_RESOURCE; | ||||
|     case kTfLiteVariant: | ||||
|       return TF_VARIANT; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| @ -115,6 +119,10 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) { | ||||
|       return kTfLiteString; | ||||
|     case TF_BOOL: | ||||
|       return kTfLiteBool; | ||||
|     case TF_RESOURCE: | ||||
|       return kTfLiteResource; | ||||
|     case TF_VARIANT: | ||||
|       return kTfLiteVariant; | ||||
|     default: | ||||
|       return kTfLiteNoType; | ||||
|   } | ||||
|  | ||||
| @ -120,6 +120,8 @@ TEST(UtilTest, TypeConversionsFromTFLite) { | ||||
|   EXPECT_EQ(TF_COMPLEX128, GetTensorFlowDataType(kTfLiteComplex128)); | ||||
|   EXPECT_EQ(TF_STRING, GetTensorFlowDataType(kTfLiteString)); | ||||
|   EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool)); | ||||
|   EXPECT_EQ(TF_RESOURCE, GetTensorFlowDataType(kTfLiteResource)); | ||||
|   EXPECT_EQ(TF_VARIANT, GetTensorFlowDataType(kTfLiteVariant)); | ||||
| } | ||||
| 
 | ||||
| TEST(UtilTest, TypeConversionsFromTensorFlow) { | ||||
| @ -135,8 +137,8 @@ TEST(UtilTest, TypeConversionsFromTensorFlow) { | ||||
|   EXPECT_EQ(kTfLiteComplex128, GetTensorFlowLiteType(TF_COMPLEX128)); | ||||
|   EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING)); | ||||
|   EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL)); | ||||
|   EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE)); | ||||
|   EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_VARIANT)); | ||||
|   EXPECT_EQ(kTfLiteResource, GetTensorFlowLiteType(TF_RESOURCE)); | ||||
|   EXPECT_EQ(kTfLiteVariant, GetTensorFlowLiteType(TF_VARIANT)); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
|  | ||||
| @ -48,9 +48,15 @@ size_t AlignSizeUp(size_t size, size_t alignment) { | ||||
| 
 | ||||
| TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) { | ||||
|   switch (type) { | ||||
|     case kTfLiteFloat16: | ||||
|       *size = sizeof(int16_t); | ||||
|       break; | ||||
|     case kTfLiteFloat32: | ||||
|       *size = sizeof(float); | ||||
|       break; | ||||
|     case kTfLiteFloat64: | ||||
|       *size = sizeof(double); | ||||
|       break; | ||||
|     case kTfLiteInt16: | ||||
|       *size = sizeof(int16_t); | ||||
|       break; | ||||
|  | ||||
| @ -115,10 +115,18 @@ TF_LITE_MICRO_TEST(TestAlignSizeUp) { | ||||
| 
 | ||||
| TF_LITE_MICRO_TEST(TestTypeSizeOf) { | ||||
|   size_t size; | ||||
|   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, | ||||
|                           tflite::TfLiteTypeSizeOf(kTfLiteFloat16, &size)); | ||||
|   TF_LITE_MICRO_EXPECT_EQ(sizeof(int16_t), size); | ||||
| 
 | ||||
|   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, | ||||
|                           tflite::TfLiteTypeSizeOf(kTfLiteFloat32, &size)); | ||||
|   TF_LITE_MICRO_EXPECT_EQ(sizeof(float), size); | ||||
| 
 | ||||
|   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, | ||||
|                           tflite::TfLiteTypeSizeOf(kTfLiteFloat64, &size)); | ||||
|   TF_LITE_MICRO_EXPECT_EQ(sizeof(double), size); | ||||
| 
 | ||||
|   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, | ||||
|                           tflite::TfLiteTypeSizeOf(kTfLiteInt16, &size)); | ||||
|   TF_LITE_MICRO_EXPECT_EQ(sizeof(int16_t), size); | ||||
|  | ||||
| @ -422,8 +422,11 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_ | ||||
|     case kTfLiteComplex64: | ||||
|     case kTfLiteComplex128: | ||||
|     case kTfLiteUInt64: | ||||
|       // kTfLiteString, kTfLiteUInt64, kTfLiteComplex64 and kTfLiteComplex128 are not supported in | ||||
|       // TensorFlow Lite Objc API. | ||||
|     case kTfLiteResource: | ||||
|     case kTfLiteVariant: | ||||
|       // kTfLiteString, kTfLiteUInt64, kTfLiteComplex64, kTfLiteComplex128, | ||||
|       // kTfLiteResource and kTfLiteVariant are not supported in TensorFlow Lite | ||||
|       // Objc API. | ||||
|       return TFLTensorDataTypeNoType; | ||||
|   } | ||||
| } | ||||
|  | ||||
| @ -73,6 +73,10 @@ const char* TensorTypeName(TfLiteType type) { | ||||
|       return "kTfLiteFloat16"; | ||||
|     case kTfLiteFloat64: | ||||
|       return "kTfLiteFloat64"; | ||||
|     case kTfLiteResource: | ||||
|       return "kTfLiteResource"; | ||||
|     case kTfLiteVariant: | ||||
|       return "kTfLiteVariant"; | ||||
|   } | ||||
|   return "(invalid)"; | ||||
| } | ||||
|  | ||||
| @ -619,7 +619,8 @@ PyObject* InterpreterWrapper::GetTensor(int i) const { | ||||
| 
 | ||||
|   std::vector<npy_intp> dims(tensor->dims->data, | ||||
|                              tensor->dims->data + tensor->dims->size); | ||||
|   if (tensor->type != kTfLiteString) { | ||||
|   if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource && | ||||
|       tensor->type != kTfLiteVariant) { | ||||
|     // Make a buffer copy but we must tell Numpy It owns that data or else
 | ||||
|     // it will leak.
 | ||||
|     void* data = malloc(tensor->bytes); | ||||
|  | ||||
| @ -60,6 +60,9 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { | ||||
|       return NPY_COMPLEX64; | ||||
|     case kTfLiteComplex128: | ||||
|       return NPY_COMPLEX128; | ||||
|     case kTfLiteResource: | ||||
|     case kTfLiteVariant: | ||||
|       return NPY_OBJECT; | ||||
|     case kTfLiteNoType: | ||||
|       return NPY_NOTYPE; | ||||
|       // Avoid default so compiler errors created when new types are made.
 | ||||
|  | ||||
| @ -91,6 +91,10 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { | ||||
|       return TensorType_COMPLEX64; | ||||
|     case kTfLiteComplex128: | ||||
|       return TensorType_COMPLEX128; | ||||
|     case kTfLiteResource: | ||||
|       return TensorType_RESOURCE; | ||||
|     case kTfLiteVariant: | ||||
|       return TensorType_VARIANT; | ||||
|   } | ||||
|   // No default to get compiler error when new type is introduced.
 | ||||
| } | ||||
|  | ||||
| @ -64,6 +64,8 @@ _MAP_TF_TO_TFLITE_TYPES = { | ||||
|     dtypes.int8: _types_pb2.INT8, | ||||
|     dtypes.float64: _types_pb2.FLOAT64, | ||||
|     dtypes.complex128: _types_pb2.COMPLEX128, | ||||
|     dtypes.resource: _types_pb2.RESOURCE, | ||||
|     dtypes.variant: _types_pb2.VARIANT, | ||||
| } | ||||
| 
 | ||||
| _MAP_TFLITE_ENUM_TO_TF_TYPES = { | ||||
|  | ||||
| @ -45,6 +45,8 @@ enum TensorType : byte { | ||||
|   FLOAT64 = 10, | ||||
|   COMPLEX128 = 11, | ||||
|   UINT64 = 12, | ||||
|   RESOURCE = 13, | ||||
|   VARIANT = 14, | ||||
| } | ||||
| 
 | ||||
| // Custom quantization parameters for experimenting with new quantization | ||||
|  | ||||
| @ -402,11 +402,13 @@ enum TensorType { | ||||
|   TensorType_FLOAT64 = 10, | ||||
|   TensorType_COMPLEX128 = 11, | ||||
|   TensorType_UINT64 = 12, | ||||
|   TensorType_RESOURCE = 13, | ||||
|   TensorType_VARIANT = 14, | ||||
|   TensorType_MIN = TensorType_FLOAT32, | ||||
|   TensorType_MAX = TensorType_UINT64 | ||||
|   TensorType_MAX = TensorType_VARIANT | ||||
| }; | ||||
| 
 | ||||
| inline const TensorType (&EnumValuesTensorType())[13] { | ||||
| inline const TensorType (&EnumValuesTensorType())[15] { | ||||
|   static const TensorType values[] = { | ||||
|     TensorType_FLOAT32, | ||||
|     TensorType_FLOAT16, | ||||
| @ -420,13 +422,15 @@ inline const TensorType (&EnumValuesTensorType())[13] { | ||||
|     TensorType_INT8, | ||||
|     TensorType_FLOAT64, | ||||
|     TensorType_COMPLEX128, | ||||
|     TensorType_UINT64 | ||||
|     TensorType_UINT64, | ||||
|     TensorType_RESOURCE, | ||||
|     TensorType_VARIANT | ||||
|   }; | ||||
|   return values; | ||||
| } | ||||
| 
 | ||||
| inline const char * const *EnumNamesTensorType() { | ||||
|   static const char * const names[14] = { | ||||
|   static const char * const names[16] = { | ||||
|     "FLOAT32", | ||||
|     "FLOAT16", | ||||
|     "INT32", | ||||
| @ -440,13 +444,15 @@ inline const char * const *EnumNamesTensorType() { | ||||
|     "FLOAT64", | ||||
|     "COMPLEX128", | ||||
|     "UINT64", | ||||
|     "RESOURCE", | ||||
|     "VARIANT", | ||||
|     nullptr | ||||
|   }; | ||||
|   return names; | ||||
| } | ||||
| 
 | ||||
| inline const char *EnumNameTensorType(TensorType e) { | ||||
|   if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_UINT64)) return ""; | ||||
|   if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_VARIANT)) return ""; | ||||
|   const size_t index = static_cast<size_t>(e); | ||||
|   return EnumNamesTensorType()[index]; | ||||
| } | ||||
|  | ||||
| @ -2323,6 +2323,8 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) { | ||||
|       return ArrayDataType::kFloat16; | ||||
|     case FLOAT64: | ||||
|       return ArrayDataType::kFloat64; | ||||
|     case RESOURCE: | ||||
|     case VARIANT: | ||||
|     default: | ||||
|       return ArrayDataType::kNone; | ||||
|   } | ||||
|  | ||||
| @ -58,4 +58,10 @@ enum IODataType { | ||||
| 
 | ||||
|   // Uint64, not quantized | ||||
|   UINT64 = 13; | ||||
| 
 | ||||
|   // Resource type | ||||
|   RESOURCE = 14; | ||||
| 
 | ||||
|   // Variant type | ||||
|   VARIANT = 15; | ||||
| } | ||||
|  | ||||
| @ -86,6 +86,10 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { | ||||
|       return TensorType_COMPLEX64; | ||||
|     case kTfLiteComplex128: | ||||
|       return TensorType_COMPLEX128; | ||||
|     case kTfLiteResource: | ||||
|       return TensorType_RESOURCE; | ||||
|     case kTfLiteVariant: | ||||
|       return TensorType_VARIANT; | ||||
|   } | ||||
|   // TODO(aselle): consider an error
 | ||||
| } | ||||
|  | ||||
| @ -558,7 +558,9 @@ TEST(VerifyModel, TypedTensorShapeMatchesTensorBufferSize) { | ||||
|   TfLiteFlatbufferModelBuilder builder; | ||||
|   for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX; | ||||
|        ++tensor_type) { | ||||
|     if (tensor_type == TensorType_STRING) continue; | ||||
|     if (tensor_type == TensorType_STRING || | ||||
|         tensor_type == TensorType_RESOURCE || tensor_type == TensorType_VARIANT) | ||||
|       continue; | ||||
|     TfLiteType lite_type = kTfLiteNoType; | ||||
|     ASSERT_EQ(ConvertTensorType(static_cast<TensorType>(tensor_type), | ||||
|                                 &lite_type, /*error_reporter=*/nullptr), | ||||
|  | ||||
| @ -132,8 +132,9 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type, | ||||
|       if (context) { | ||||
|         context->ReportError( | ||||
|             context, | ||||
|             "Type %d is unsupported. Only float32, int8, int16, int32, int64, " | ||||
|             "uint8, bool, complex64 supported currently.", | ||||
|             "Type %d is unsupported. Only float16, float32, float64, int8, " | ||||
|             "int16, int32, int64, uint8, uint64, bool, complex64 and " | ||||
|             "complex128 supported currently.", | ||||
|             type); | ||||
|       } | ||||
|       return kTfLiteError; | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user