diff --git a/tensorflow/lite/delegates/gpu/cl/serialization.cc b/tensorflow/lite/delegates/gpu/cl/serialization.cc index 812782c60bc..4f20c41e546 100644 --- a/tensorflow/lite/delegates/gpu/cl/serialization.cc +++ b/tensorflow/lite/delegates/gpu/cl/serialization.cc @@ -53,7 +53,25 @@ data::DataType ToFB(DataType type) { return data::DataType::FLOAT16; case DataType::FLOAT32: return data::DataType::FLOAT32; - default: + case DataType::FLOAT64: + return data::DataType::FLOAT64; + case DataType::UINT8: + return data::DataType::UINT8; + case DataType::INT8: + return data::DataType::INT8; + case DataType::UINT16: + return data::DataType::UINT16; + case DataType::INT16: + return data::DataType::INT16; + case DataType::UINT32: + return data::DataType::UINT32; + case DataType::INT32: + return data::DataType::INT32; + case DataType::UINT64: + return data::DataType::UINT64; + case DataType::INT64: + return data::DataType::INT64; + case DataType::UNKNOWN: return data::DataType::UNKNOWN; } } @@ -118,7 +136,25 @@ DataType ToEnum(data::DataType type) { return DataType::FLOAT16; case data::DataType::FLOAT32: return DataType::FLOAT32; - default: + case data::DataType::FLOAT64: + return DataType::FLOAT64; + case data::DataType::UINT8: + return DataType::UINT8; + case data::DataType::INT8: + return DataType::INT8; + case data::DataType::UINT16: + return DataType::UINT16; + case data::DataType::INT16: + return DataType::INT16; + case data::DataType::UINT32: + return DataType::UINT32; + case data::DataType::INT32: + return DataType::INT32; + case data::DataType::UINT64: + return DataType::UINT64; + case data::DataType::INT64: + return DataType::INT64; + case data::DataType::UNKNOWN: return DataType::UNKNOWN; } } diff --git a/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs b/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs index 8d9434786a2..5b1918d26ec 100644 --- a/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs +++ b/tensorflow/lite/delegates/gpu/common/task/serialization_base.fbs @@ -68,8 +68,17 @@ table HalfValue { enum DataType : byte { UNKNOWN = 0, - FLOAT32 = 1, - FLOAT16 = 2, + FLOAT16 = 1, + FLOAT32 = 2, + FLOAT64 = 3, + UINT8 = 4, + INT8 = 5, + UINT16 = 6, + INT16 = 7, + UINT32 = 8, + INT32 = 9, + UINT64 = 10, + INT64 = 11, } enum MemoryType : byte { diff --git a/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h b/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h index e3c3a5c33df..82fadccb0b6 100644 --- a/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h +++ b/tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h @@ -122,33 +122,39 @@ inline const char *EnumNameAccessType(AccessType e) { enum class DataType : int8_t { UNKNOWN = 0, - FLOAT32 = 1, - FLOAT16 = 2, + FLOAT16 = 1, + FLOAT32 = 2, + FLOAT64 = 3, + UINT8 = 4, + INT8 = 5, + UINT16 = 6, + INT16 = 7, + UINT32 = 8, + INT32 = 9, + UINT64 = 10, + INT64 = 11, MIN = UNKNOWN, - MAX = FLOAT16 + MAX = INT64 }; -inline const DataType (&EnumValuesDataType())[3] { +inline const DataType (&EnumValuesDataType())[12] { static const DataType values[] = { - DataType::UNKNOWN, - DataType::FLOAT32, - DataType::FLOAT16 - }; + DataType::UNKNOWN, DataType::FLOAT16, DataType::FLOAT32, + DataType::FLOAT64, DataType::UINT8, DataType::INT8, + DataType::UINT16, DataType::INT16, DataType::UINT32, + DataType::INT32, DataType::UINT64, DataType::INT64}; return values; } inline const char * const *EnumNamesDataType() { - static const char * const names[4] = { - "UNKNOWN", - "FLOAT32", - "FLOAT16", - nullptr - }; + static const char *const names[13] = { + "UNKNOWN", "FLOAT16", "FLOAT32", "FLOAT64", "UINT8", "INT8", "UINT16", + "INT16", "UINT32", "INT32", "UINT64", "INT64", nullptr}; return names; } inline const char *EnumNameDataType(DataType e) { - if (flatbuffers::IsOutRange(e, DataType::UNKNOWN, DataType::FLOAT16)) return ""; + if (flatbuffers::IsOutRange(e, DataType::UNKNOWN, DataType::INT64)) return ""; const size_t index = static_cast(e); return EnumNamesDataType()[index]; }