Resource & variant type additions to TFLite schema
PiperOrigin-RevId: 354638976 Change-Id: I104c4de542b68e2660887a4e3ec45631a056dd74
This commit is contained in:
		
							parent
							
								
									c6de267c22
								
							
						
					
					
						commit
						6574fc4e08
					
				@ -131,6 +131,10 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
 | 
			
		||||
      return DT_COMPLEX64;
 | 
			
		||||
    case toco::IODataType::COMPLEX128:
 | 
			
		||||
      return DT_COMPLEX128;
 | 
			
		||||
    case toco::IODataType::RESOURCE:
 | 
			
		||||
      return DT_RESOURCE;
 | 
			
		||||
    case toco::IODataType::VARIANT:
 | 
			
		||||
      return DT_VARIANT;
 | 
			
		||||
    default:
 | 
			
		||||
      return DT_INVALID;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -59,6 +59,10 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
 | 
			
		||||
      return builder.getIntegerType(8);
 | 
			
		||||
    case tflite::TensorType_UINT64:
 | 
			
		||||
      return builder.getIntegerType(64, /*isSigned=*/false);
 | 
			
		||||
    case tflite::TensorType_RESOURCE:
 | 
			
		||||
      return mlir::TF::ResourceType::get(builder.getContext());
 | 
			
		||||
    case tflite::TensorType_VARIANT:
 | 
			
		||||
      return mlir::TF::VariantType::get(builder.getContext());
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -90,6 +94,10 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
 | 
			
		||||
      return tensorflow::DT_UINT8;
 | 
			
		||||
    case tflite::TensorType_UINT64:
 | 
			
		||||
      return tensorflow::DT_UINT64;
 | 
			
		||||
    case tflite::TensorType_RESOURCE:
 | 
			
		||||
      return tensorflow::DT_RESOURCE;
 | 
			
		||||
    case tflite::TensorType_VARIANT:
 | 
			
		||||
      return tensorflow::DT_VARIANT;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -99,10 +107,14 @@ StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) {
 | 
			
		||||
      return tflite::TensorType_BOOL;
 | 
			
		||||
    case tensorflow::DT_COMPLEX64:
 | 
			
		||||
      return tflite::TensorType_COMPLEX64;
 | 
			
		||||
    case tensorflow::DT_COMPLEX128:
 | 
			
		||||
      return tflite::TensorType_COMPLEX128;
 | 
			
		||||
    case tensorflow::DT_HALF:
 | 
			
		||||
      return tflite::TensorType_FLOAT16;
 | 
			
		||||
    case tensorflow::DT_FLOAT:
 | 
			
		||||
      return tflite::TensorType_FLOAT32;
 | 
			
		||||
    case tensorflow::DT_DOUBLE:
 | 
			
		||||
      return tflite::TensorType_FLOAT64;
 | 
			
		||||
    case tensorflow::DT_INT8:
 | 
			
		||||
      return tflite::TensorType_INT8;
 | 
			
		||||
    case tensorflow::DT_INT16:
 | 
			
		||||
@ -111,10 +123,16 @@ StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) {
 | 
			
		||||
      return tflite::TensorType_INT32;
 | 
			
		||||
    case tensorflow::DT_INT64:
 | 
			
		||||
      return tflite::TensorType_INT64;
 | 
			
		||||
    case tensorflow::DT_UINT64:
 | 
			
		||||
      return tflite::TensorType_UINT64;
 | 
			
		||||
    case tensorflow::DT_STRING:
 | 
			
		||||
      return tflite::TensorType_STRING;
 | 
			
		||||
    case tensorflow::DT_UINT8:
 | 
			
		||||
      return tflite::TensorType_UINT8;
 | 
			
		||||
    case tensorflow::DT_RESOURCE:
 | 
			
		||||
      return tflite::TensorType_RESOURCE;
 | 
			
		||||
    case tensorflow::DT_VARIANT:
 | 
			
		||||
      return tflite::TensorType_VARIANT;
 | 
			
		||||
    default:
 | 
			
		||||
      return errors::InvalidArgument("unsupported tensor data type", type);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -73,6 +73,8 @@ typedef enum {
 | 
			
		||||
  kTfLiteFloat64 = 11,
 | 
			
		||||
  kTfLiteComplex128 = 12,
 | 
			
		||||
  kTfLiteUInt64 = 13,
 | 
			
		||||
  kTfLiteResource = 14,
 | 
			
		||||
  kTfLiteVariant = 15,
 | 
			
		||||
} TfLiteType;
 | 
			
		||||
 | 
			
		||||
// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
 | 
			
		||||
 | 
			
		||||
@ -219,6 +219,10 @@ const char* TfLiteTypeGetName(TfLiteType type) {
 | 
			
		||||
      return "FLOAT16";
 | 
			
		||||
    case kTfLiteFloat64:
 | 
			
		||||
      return "FLOAT64";
 | 
			
		||||
    case kTfLiteResource:
 | 
			
		||||
      return "RESOURCE";
 | 
			
		||||
    case kTfLiteVariant:
 | 
			
		||||
      return "VARIANT";
 | 
			
		||||
  }
 | 
			
		||||
  return "Unknown type";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -91,6 +91,8 @@ TEST(Types, TestTypeNames) {
 | 
			
		||||
  EXPECT_EQ(type_name(kTfLiteComplex64), "COMPLEX64");
 | 
			
		||||
  EXPECT_EQ(type_name(kTfLiteComplex128), "COMPLEX128");
 | 
			
		||||
  EXPECT_EQ(type_name(kTfLiteString), "STRING");
 | 
			
		||||
  EXPECT_EQ(type_name(kTfLiteResource), "RESOURCE");
 | 
			
		||||
  EXPECT_EQ(type_name(kTfLiteVariant), "VARIANT");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(Quantization, TestQuantizationFree) {
 | 
			
		||||
 | 
			
		||||
@ -872,6 +872,12 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
 | 
			
		||||
    case TensorType_COMPLEX128:
 | 
			
		||||
      *type = kTfLiteComplex128;
 | 
			
		||||
      return kTfLiteOk;
 | 
			
		||||
    case TensorType_RESOURCE:
 | 
			
		||||
      *type = kTfLiteResource;
 | 
			
		||||
      return kTfLiteOk;
 | 
			
		||||
    case TensorType_VARIANT:
 | 
			
		||||
      *type = kTfLiteVariant;
 | 
			
		||||
      return kTfLiteOk;
 | 
			
		||||
    default:
 | 
			
		||||
      *type = kTfLiteNoType;
 | 
			
		||||
      TF_LITE_REPORT_ERROR(error_reporter,
 | 
			
		||||
 | 
			
		||||
@ -1210,7 +1210,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadOnly(
 | 
			
		||||
  // ensure the buffer is large enough. However, we need to skip string tensors
 | 
			
		||||
  // and sparse tensors because their sizes change with the contents.
 | 
			
		||||
  // TODO(b/145615516): Extend BytesRequired to check sparse tensors.
 | 
			
		||||
  if (type != kTfLiteString && sparsity == nullptr) {
 | 
			
		||||
  if (type != kTfLiteString && type != kTfLiteResource &&
 | 
			
		||||
      type != kTfLiteVariant && sparsity == nullptr) {
 | 
			
		||||
    size_t required_bytes;
 | 
			
		||||
    TF_LITE_ENSURE_OK(&context_,
 | 
			
		||||
                      BytesRequired(type, dims, rank, &required_bytes));
 | 
			
		||||
@ -1262,7 +1263,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite(
 | 
			
		||||
  TF_LITE_ENSURE(&context_,
 | 
			
		||||
                 tensor_index < context_.tensors_size && tensor_index >= 0);
 | 
			
		||||
  size_t required_bytes = 0;
 | 
			
		||||
  if (type != kTfLiteString) {
 | 
			
		||||
  if (type != kTfLiteString && type != kTfLiteResource &&
 | 
			
		||||
      type != kTfLiteVariant) {
 | 
			
		||||
    // These types will be allocated in our arena so we need to record how
 | 
			
		||||
    // many bytes we will need based on the dimensions. String tensors are
 | 
			
		||||
    // allocated dynamically and we can't know ahead of time how much space
 | 
			
		||||
@ -1272,7 +1274,8 @@ TfLiteStatus Subgraph::SetTensorParametersReadWrite(
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TfLiteAllocationType allocation_type = kTfLiteArenaRw;
 | 
			
		||||
  if (type == kTfLiteString) {
 | 
			
		||||
  if (type == kTfLiteString || type == kTfLiteResource ||
 | 
			
		||||
      type == kTfLiteVariant) {
 | 
			
		||||
    if (is_variable) {
 | 
			
		||||
      // We don't have a real use case for string variable tensor.
 | 
			
		||||
      ReportError("String variable tensor isn't supported.");
 | 
			
		||||
@ -1315,7 +1318,8 @@ TfLiteStatus Subgraph::ResizeTensorImpl(TfLiteTensor* tensor,
 | 
			
		||||
      tensor->allocation_type == kTfLiteCustom) {
 | 
			
		||||
    tensor_resized_since_op_invoke_ |=
 | 
			
		||||
        TfLiteIntArrayEqual(tensor->dims, new_size) == 0;
 | 
			
		||||
    if (tensor->type != kTfLiteString) {
 | 
			
		||||
    if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
 | 
			
		||||
        tensor->type != kTfLiteVariant) {
 | 
			
		||||
      size_t bytesRequired;
 | 
			
		||||
      TfLiteStatus status = BytesRequired(tensor->type, new_size->data,
 | 
			
		||||
                                          new_size->size, &bytesRequired);
 | 
			
		||||
 | 
			
		||||
@ -84,6 +84,10 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
 | 
			
		||||
      return TF_STRING;
 | 
			
		||||
    case kTfLiteBool:
 | 
			
		||||
      return TF_BOOL;
 | 
			
		||||
    case kTfLiteResource:
 | 
			
		||||
      return TF_RESOURCE;
 | 
			
		||||
    case kTfLiteVariant:
 | 
			
		||||
      return TF_VARIANT;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -115,6 +119,10 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) {
 | 
			
		||||
      return kTfLiteString;
 | 
			
		||||
    case TF_BOOL:
 | 
			
		||||
      return kTfLiteBool;
 | 
			
		||||
    case TF_RESOURCE:
 | 
			
		||||
      return kTfLiteResource;
 | 
			
		||||
    case TF_VARIANT:
 | 
			
		||||
      return kTfLiteVariant;
 | 
			
		||||
    default:
 | 
			
		||||
      return kTfLiteNoType;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -120,6 +120,8 @@ TEST(UtilTest, TypeConversionsFromTFLite) {
 | 
			
		||||
  EXPECT_EQ(TF_COMPLEX128, GetTensorFlowDataType(kTfLiteComplex128));
 | 
			
		||||
  EXPECT_EQ(TF_STRING, GetTensorFlowDataType(kTfLiteString));
 | 
			
		||||
  EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool));
 | 
			
		||||
  EXPECT_EQ(TF_RESOURCE, GetTensorFlowDataType(kTfLiteResource));
 | 
			
		||||
  EXPECT_EQ(TF_VARIANT, GetTensorFlowDataType(kTfLiteVariant));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(UtilTest, TypeConversionsFromTensorFlow) {
 | 
			
		||||
@ -135,8 +137,8 @@ TEST(UtilTest, TypeConversionsFromTensorFlow) {
 | 
			
		||||
  EXPECT_EQ(kTfLiteComplex128, GetTensorFlowLiteType(TF_COMPLEX128));
 | 
			
		||||
  EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING));
 | 
			
		||||
  EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL));
 | 
			
		||||
  EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE));
 | 
			
		||||
  EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_VARIANT));
 | 
			
		||||
  EXPECT_EQ(kTfLiteResource, GetTensorFlowLiteType(TF_RESOURCE));
 | 
			
		||||
  EXPECT_EQ(kTfLiteVariant, GetTensorFlowLiteType(TF_VARIANT));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -48,9 +48,15 @@ size_t AlignSizeUp(size_t size, size_t alignment) {
 | 
			
		||||
 | 
			
		||||
TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) {
 | 
			
		||||
  switch (type) {
 | 
			
		||||
    case kTfLiteFloat16:
 | 
			
		||||
      *size = sizeof(int16_t);
 | 
			
		||||
      break;
 | 
			
		||||
    case kTfLiteFloat32:
 | 
			
		||||
      *size = sizeof(float);
 | 
			
		||||
      break;
 | 
			
		||||
    case kTfLiteFloat64:
 | 
			
		||||
      *size = sizeof(double);
 | 
			
		||||
      break;
 | 
			
		||||
    case kTfLiteInt16:
 | 
			
		||||
      *size = sizeof(int16_t);
 | 
			
		||||
      break;
 | 
			
		||||
 | 
			
		||||
@ -115,10 +115,18 @@ TF_LITE_MICRO_TEST(TestAlignSizeUp) {
 | 
			
		||||
 | 
			
		||||
TF_LITE_MICRO_TEST(TestTypeSizeOf) {
 | 
			
		||||
  size_t size;
 | 
			
		||||
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
 | 
			
		||||
                          tflite::TfLiteTypeSizeOf(kTfLiteFloat16, &size));
 | 
			
		||||
  TF_LITE_MICRO_EXPECT_EQ(sizeof(int16_t), size);
 | 
			
		||||
 | 
			
		||||
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
 | 
			
		||||
                          tflite::TfLiteTypeSizeOf(kTfLiteFloat32, &size));
 | 
			
		||||
  TF_LITE_MICRO_EXPECT_EQ(sizeof(float), size);
 | 
			
		||||
 | 
			
		||||
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
 | 
			
		||||
                          tflite::TfLiteTypeSizeOf(kTfLiteFloat64, &size));
 | 
			
		||||
  TF_LITE_MICRO_EXPECT_EQ(sizeof(double), size);
 | 
			
		||||
 | 
			
		||||
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
 | 
			
		||||
                          tflite::TfLiteTypeSizeOf(kTfLiteInt16, &size));
 | 
			
		||||
  TF_LITE_MICRO_EXPECT_EQ(sizeof(int16_t), size);
 | 
			
		||||
 | 
			
		||||
@ -422,8 +422,11 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_
 | 
			
		||||
    case kTfLiteComplex64:
 | 
			
		||||
    case kTfLiteComplex128:
 | 
			
		||||
    case kTfLiteUInt64:
 | 
			
		||||
      // kTfLiteString, kTfLiteUInt64, kTfLiteComplex64 and kTfLiteComplex128 are not supported in
 | 
			
		||||
      // TensorFlow Lite Objc API.
 | 
			
		||||
    case kTfLiteResource:
 | 
			
		||||
    case kTfLiteVariant:
 | 
			
		||||
      // kTfLiteString, kTfLiteUInt64, kTfLiteComplex64, kTfLiteComplex128,
 | 
			
		||||
      // kTfLiteResource and kTfLiteVariant are not supported in TensorFlow Lite
 | 
			
		||||
      // Objc API.
 | 
			
		||||
      return TFLTensorDataTypeNoType;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -73,6 +73,10 @@ const char* TensorTypeName(TfLiteType type) {
 | 
			
		||||
      return "kTfLiteFloat16";
 | 
			
		||||
    case kTfLiteFloat64:
 | 
			
		||||
      return "kTfLiteFloat64";
 | 
			
		||||
    case kTfLiteResource:
 | 
			
		||||
      return "kTfLiteResource";
 | 
			
		||||
    case kTfLiteVariant:
 | 
			
		||||
      return "kTfLiteVariant";
 | 
			
		||||
  }
 | 
			
		||||
  return "(invalid)";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -619,7 +619,8 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
 | 
			
		||||
 | 
			
		||||
  std::vector<npy_intp> dims(tensor->dims->data,
 | 
			
		||||
                             tensor->dims->data + tensor->dims->size);
 | 
			
		||||
  if (tensor->type != kTfLiteString) {
 | 
			
		||||
  if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
 | 
			
		||||
      tensor->type != kTfLiteVariant) {
 | 
			
		||||
    // Make a buffer copy but we must tell Numpy It owns that data or else
 | 
			
		||||
    // it will leak.
 | 
			
		||||
    void* data = malloc(tensor->bytes);
 | 
			
		||||
 | 
			
		||||
@ -60,6 +60,9 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
 | 
			
		||||
      return NPY_COMPLEX64;
 | 
			
		||||
    case kTfLiteComplex128:
 | 
			
		||||
      return NPY_COMPLEX128;
 | 
			
		||||
    case kTfLiteResource:
 | 
			
		||||
    case kTfLiteVariant:
 | 
			
		||||
      return NPY_OBJECT;
 | 
			
		||||
    case kTfLiteNoType:
 | 
			
		||||
      return NPY_NOTYPE;
 | 
			
		||||
      // Avoid default so compiler errors created when new types are made.
 | 
			
		||||
 | 
			
		||||
@ -91,6 +91,10 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
 | 
			
		||||
      return TensorType_COMPLEX64;
 | 
			
		||||
    case kTfLiteComplex128:
 | 
			
		||||
      return TensorType_COMPLEX128;
 | 
			
		||||
    case kTfLiteResource:
 | 
			
		||||
      return TensorType_RESOURCE;
 | 
			
		||||
    case kTfLiteVariant:
 | 
			
		||||
      return TensorType_VARIANT;
 | 
			
		||||
  }
 | 
			
		||||
  // No default to get compiler error when new type is introduced.
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -64,6 +64,8 @@ _MAP_TF_TO_TFLITE_TYPES = {
 | 
			
		||||
    dtypes.int8: _types_pb2.INT8,
 | 
			
		||||
    dtypes.float64: _types_pb2.FLOAT64,
 | 
			
		||||
    dtypes.complex128: _types_pb2.COMPLEX128,
 | 
			
		||||
    dtypes.resource: _types_pb2.RESOURCE,
 | 
			
		||||
    dtypes.variant: _types_pb2.VARIANT,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
 | 
			
		||||
 | 
			
		||||
@ -45,6 +45,8 @@ enum TensorType : byte {
 | 
			
		||||
  FLOAT64 = 10,
 | 
			
		||||
  COMPLEX128 = 11,
 | 
			
		||||
  UINT64 = 12,
 | 
			
		||||
  RESOURCE = 13,
 | 
			
		||||
  VARIANT = 14,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Custom quantization parameters for experimenting with new quantization
 | 
			
		||||
 | 
			
		||||
@ -402,11 +402,13 @@ enum TensorType {
 | 
			
		||||
  TensorType_FLOAT64 = 10,
 | 
			
		||||
  TensorType_COMPLEX128 = 11,
 | 
			
		||||
  TensorType_UINT64 = 12,
 | 
			
		||||
  TensorType_RESOURCE = 13,
 | 
			
		||||
  TensorType_VARIANT = 14,
 | 
			
		||||
  TensorType_MIN = TensorType_FLOAT32,
 | 
			
		||||
  TensorType_MAX = TensorType_UINT64
 | 
			
		||||
  TensorType_MAX = TensorType_VARIANT
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
inline const TensorType (&EnumValuesTensorType())[13] {
 | 
			
		||||
inline const TensorType (&EnumValuesTensorType())[15] {
 | 
			
		||||
  static const TensorType values[] = {
 | 
			
		||||
    TensorType_FLOAT32,
 | 
			
		||||
    TensorType_FLOAT16,
 | 
			
		||||
@ -420,13 +422,15 @@ inline const TensorType (&EnumValuesTensorType())[13] {
 | 
			
		||||
    TensorType_INT8,
 | 
			
		||||
    TensorType_FLOAT64,
 | 
			
		||||
    TensorType_COMPLEX128,
 | 
			
		||||
    TensorType_UINT64
 | 
			
		||||
    TensorType_UINT64,
 | 
			
		||||
    TensorType_RESOURCE,
 | 
			
		||||
    TensorType_VARIANT
 | 
			
		||||
  };
 | 
			
		||||
  return values;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline const char * const *EnumNamesTensorType() {
 | 
			
		||||
  static const char * const names[14] = {
 | 
			
		||||
  static const char * const names[16] = {
 | 
			
		||||
    "FLOAT32",
 | 
			
		||||
    "FLOAT16",
 | 
			
		||||
    "INT32",
 | 
			
		||||
@ -440,13 +444,15 @@ inline const char * const *EnumNamesTensorType() {
 | 
			
		||||
    "FLOAT64",
 | 
			
		||||
    "COMPLEX128",
 | 
			
		||||
    "UINT64",
 | 
			
		||||
    "RESOURCE",
 | 
			
		||||
    "VARIANT",
 | 
			
		||||
    nullptr
 | 
			
		||||
  };
 | 
			
		||||
  return names;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline const char *EnumNameTensorType(TensorType e) {
 | 
			
		||||
  if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_UINT64)) return "";
 | 
			
		||||
  if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_VARIANT)) return "";
 | 
			
		||||
  const size_t index = static_cast<size_t>(e);
 | 
			
		||||
  return EnumNamesTensorType()[index];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -2323,6 +2323,8 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
 | 
			
		||||
      return ArrayDataType::kFloat16;
 | 
			
		||||
    case FLOAT64:
 | 
			
		||||
      return ArrayDataType::kFloat64;
 | 
			
		||||
    case RESOURCE:
 | 
			
		||||
    case VARIANT:
 | 
			
		||||
    default:
 | 
			
		||||
      return ArrayDataType::kNone;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -58,4 +58,10 @@ enum IODataType {
 | 
			
		||||
 | 
			
		||||
  // Uint64, not quantized
 | 
			
		||||
  UINT64 = 13;
 | 
			
		||||
 | 
			
		||||
  // Resource type
 | 
			
		||||
  RESOURCE = 14;
 | 
			
		||||
 | 
			
		||||
  // Variant type
 | 
			
		||||
  VARIANT = 15;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -86,6 +86,10 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
 | 
			
		||||
      return TensorType_COMPLEX64;
 | 
			
		||||
    case kTfLiteComplex128:
 | 
			
		||||
      return TensorType_COMPLEX128;
 | 
			
		||||
    case kTfLiteResource:
 | 
			
		||||
      return TensorType_RESOURCE;
 | 
			
		||||
    case kTfLiteVariant:
 | 
			
		||||
      return TensorType_VARIANT;
 | 
			
		||||
  }
 | 
			
		||||
  // TODO(aselle): consider an error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -558,7 +558,9 @@ TEST(VerifyModel, TypedTensorShapeMatchesTensorBufferSize) {
 | 
			
		||||
  TfLiteFlatbufferModelBuilder builder;
 | 
			
		||||
  for (int tensor_type = TensorType_MIN; tensor_type <= TensorType_MAX;
 | 
			
		||||
       ++tensor_type) {
 | 
			
		||||
    if (tensor_type == TensorType_STRING) continue;
 | 
			
		||||
    if (tensor_type == TensorType_STRING ||
 | 
			
		||||
        tensor_type == TensorType_RESOURCE || tensor_type == TensorType_VARIANT)
 | 
			
		||||
      continue;
 | 
			
		||||
    TfLiteType lite_type = kTfLiteNoType;
 | 
			
		||||
    ASSERT_EQ(ConvertTensorType(static_cast<TensorType>(tensor_type),
 | 
			
		||||
                                &lite_type, /*error_reporter=*/nullptr),
 | 
			
		||||
 | 
			
		||||
@ -132,8 +132,9 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
 | 
			
		||||
      if (context) {
 | 
			
		||||
        context->ReportError(
 | 
			
		||||
            context,
 | 
			
		||||
            "Type %d is unsupported. Only float32, int8, int16, int32, int64, "
 | 
			
		||||
            "uint8, bool, complex64 supported currently.",
 | 
			
		||||
            "Type %d is unsupported. Only float16, float32, float64, int8, "
 | 
			
		||||
            "int16, int32, int64, uint8, uint64, bool, complex64 and "
 | 
			
		||||
            "complex128 supported currently.",
 | 
			
		||||
            type);
 | 
			
		||||
      }
 | 
			
		||||
      return kTfLiteError;
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user