Add nominal support for unsigned 32-bit integer tensor types.

PiperOrigin-RevId: 358028374
Change-Id: I89ebe8f549c279d87da74ca4fedc6b49f04ff506
Shlomi Regev 2021-02-17 14:05:09 -08:00 committed by TensorFlower Gardener
parent 2c332fbc4b
commit e43be76009
41 changed files with 221 additions and 14 deletions


@ -100,6 +100,9 @@
function for a given SignatureDef.
* Add int8 support for `ReshapeV2`.
* Add experimental support for optimization with sparsity.
* Add nominal support for unsigned 32-bit integer tensor types. Note that
very few TFLite kernels support this type natively, so its use in mobile
ML authoring is generally discouraged.
* TF Core:
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
`tf.while_loop`, and compositions like `tf.foldl`) computed with

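The release note's "nominal support" is precise: this change plumbs UINT32 through every type enum, switch, and conversion table, but adds essentially no uint32 kernel implementations. Most hunks below repeat one mechanical pattern; a minimal standalone sketch of it (the helper name is hypothetical; the tags and sizes come from this change):

#include <cstddef>
#include <cstdint>
#include "tensorflow/lite/c/common.h"  // declares TfLiteType

// Mirrors the TfLiteTypeGetSize() hunk further down: the new kTfLiteUInt32
// tag dispatches exactly like kTfLiteInt32, since signedness does not
// change the storage size.
inline std::size_t ElementSizeSketch(TfLiteType type) {
  switch (type) {
    case kTfLiteInt32:
    case kTfLiteUInt32:
      return sizeof(uint32_t);  // 4 bytes
    case kTfLiteInt64:
    case kTfLiteUInt64:
      return sizeof(uint64_t);  // 8 bytes
    default:
      return 0;  // remaining types elided in this sketch
  }
}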

@ -165,7 +165,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
case 16:
return tflite::TensorType_INT16;
case 32:
-        return tflite::TensorType_INT32;
+        return itype.isUnsigned() ? tflite::TensorType_UINT32
+                                  : tflite::TensorType_INT32;
case 64:
return itype.isUnsigned() ? tflite::TensorType_UINT64
: tflite::TensorType_INT64;


@ -79,6 +79,8 @@ static std::string TfLiteTensorString(const TfLiteTensor& tensor) {
switch (tensor.type) {
case kTfLiteInt32:
return TfLiteTypedTensorString<int32_t>(tensor);
case kTfLiteUInt32:
return TfLiteTypedTensorString<uint32_t>(tensor);
case kTfLiteInt64:
return TfLiteTypedTensorString<int64_t>(tensor);
case kTfLiteFloat32:


@ -119,6 +119,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
return DT_INT16;
case toco::IODataType::INT32:
return DT_INT32;
case toco::IODataType::UINT32:
return DT_UINT32;
case toco::IODataType::INT64:
return DT_INT64;
case toco::IODataType::UINT64:


@ -41,6 +41,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
return builder.getF64Type();
case tflite::TensorType_INT32:
return builder.getIntegerType(32);
case tflite::TensorType_UINT32:
return builder.getIntegerType(32, /*isSigned=*/false);
case tflite::TensorType_UINT8:
return builder.getIntegerType(8, /*isSigned=*/false);
case tflite::TensorType_INT64:
@ -86,6 +88,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
return tensorflow::DT_INT16;
case tflite::TensorType_INT32:
return tensorflow::DT_INT32;
case tflite::TensorType_UINT32:
return tensorflow::DT_UINT32;
case tflite::TensorType_INT64:
return tensorflow::DT_INT64;
case tflite::TensorType_STRING:
@ -121,6 +125,8 @@ StatusOr<tflite::TensorType> TfTypeToTflType(tensorflow::DataType type) {
return tflite::TensorType_INT16;
case tensorflow::DT_INT32:
return tflite::TensorType_INT32;
case tensorflow::DT_UINT32:
return tflite::TensorType_UINT32;
case tensorflow::DT_INT64:
return tflite::TensorType_INT64;
case tensorflow::DT_UINT64:


@ -75,6 +75,7 @@ typedef enum {
kTfLiteUInt64 = 13,
kTfLiteResource = 14,
kTfLiteVariant = 15,
kTfLiteUInt32 = 16,
} TfLiteType;
// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.

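Note the numbering: kTfLiteUInt32 is appended after kTfLiteVariant rather than slotted next to kTfLiteInt32, so existing tag values (and anything that serialized them) stay stable. A compile-time sketch of that invariant, assuming the header location:

#include "tensorflow/lite/c/common.h"  // assumed home of TfLiteType

static_assert(kTfLiteVariant == 15, "existing tags keep their values");
static_assert(kTfLiteUInt32 == 16, "new tags are append-only");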

@ -199,6 +199,8 @@ const char* TfLiteTypeGetName(TfLiteType type) {
return "INT16";
case kTfLiteInt32:
return "INT32";
case kTfLiteUInt32:
return "UINT32";
case kTfLiteUInt8:
return "UINT8";
case kTfLiteInt8:


@ -296,6 +296,7 @@ typedef union TfLitePtrUnion {
* GetTensorData<TYPE>(tensor) instead, otherwise only access .data, as other
* members are deprecated. */
int32_t* i32;
uint32_t* u32;
int64_t* i64;
uint64_t* u64;
float* f;

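With the new u32 member, typed access to unsigned 32-bit tensors follows the same pattern as the existing union fields. A small usage sketch (assumes `tensor` is a valid, allocated kTfLiteUInt32 tensor):

#include <cstddef>
#include <cstdint>
#include <vector>
#include "tensorflow/lite/c/common.h"  // TfLiteTensor, TfLitePtrUnion

// Copies a uint32 tensor's payload out through the new union member; the
// element count falls out of the byte size, as elsewhere in this change.
inline std::vector<uint32_t> CopyUint32Payload(const TfLiteTensor& tensor) {
  const std::size_t n = tensor.bytes / sizeof(uint32_t);
  return std::vector<uint32_t>(tensor.data.u32, tensor.data.u32 + n);
}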

@ -83,6 +83,7 @@ TEST(Types, TestTypeNames) {
EXPECT_EQ(type_name(kTfLiteFloat16), "FLOAT16");
EXPECT_EQ(type_name(kTfLiteInt16), "INT16");
EXPECT_EQ(type_name(kTfLiteInt32), "INT32");
EXPECT_EQ(type_name(kTfLiteUInt32), "UINT32");
EXPECT_EQ(type_name(kTfLiteUInt8), "UINT8");
EXPECT_EQ(type_name(kTfLiteUInt64), "UINT64");
EXPECT_EQ(type_name(kTfLiteInt8), "INT8");


@ -851,6 +851,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
case TensorType_INT32:
*type = kTfLiteInt32;
return kTfLiteOk;
case TensorType_UINT32:
*type = kTfLiteUInt32;
return kTfLiteOk;
case TensorType_UINT8:
*type = kTfLiteUInt8;
return kTfLiteOk;


@ -68,6 +68,8 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
return TF_INT16;
case kTfLiteInt32:
return TF_INT32;
case kTfLiteUInt32:
return TF_UINT32;
case kTfLiteUInt8:
return TF_UINT8;
case kTfLiteInt8:


@ -189,9 +189,10 @@ TEST(BasicInterpreter, CheckAllocate) {
TfLiteType type;
size_t size;
} cases[] = {
-      {kTfLiteFloat32, sizeof(float)},   {kTfLiteInt32, sizeof(int32_t)},
-      {kTfLiteUInt8, sizeof(uint8_t)},   {kTfLiteInt64, sizeof(int64_t)},
-      {kTfLiteInt16, sizeof(int16_t)},   {kTfLiteFloat16, sizeof(TfLiteFloat16)},
+      {kTfLiteFloat32, sizeof(float)},   {kTfLiteInt32, sizeof(int32_t)},
+      {kTfLiteUInt32, sizeof(uint32_t)}, {kTfLiteUInt8, sizeof(uint8_t)},
+      {kTfLiteInt64, sizeof(int64_t)},   {kTfLiteInt16, sizeof(int16_t)},
+      {kTfLiteFloat16, sizeof(TfLiteFloat16)},
};
for (auto test : cases) {
@ -261,6 +262,7 @@ TEST(BasicInterpreter, CheckQuantization) {
TEST(BasicInterpreter, CheckResize) {
const float floats[] = {-3., -4.};
const int32_t int32s[] = {-3, -4};
const uint32_t uint32s[] = {3, 4};
const uint8_t uint8s[] = {3, 4};
const int64_t int64s[] = {6, -7};
const int16_t int16s[] = {8, -9};
@ -274,6 +276,7 @@ TEST(BasicInterpreter, CheckResize) {
} cases[] = {
{kTfLiteFloat32, sizeof(float), reinterpret_cast<const char*>(floats)},
{kTfLiteInt32, sizeof(int32_t), reinterpret_cast<const char*>(int32s)},
{kTfLiteUInt32, sizeof(uint32_t), reinterpret_cast<const char*>(uint32s)},
{kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast<const char*>(uint8s)},
{kTfLiteInt64, sizeof(int64_t), reinterpret_cast<const char*>(int64s)},
{kTfLiteInt16, sizeof(int16_t), reinterpret_cast<const char*>(int16s)},
@ -313,8 +316,9 @@ TEST(BasicInterpreter, CheckResize) {
TEST(BasicInterpreter, CheckAlignment) {
struct {
TfLiteType type;
-  } cases[] = {{kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8},
-               {kTfLiteInt64}, {kTfLiteInt16}, {kTfLiteFloat16}};
+  } cases[] = {{kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt32},
+               {kTfLiteUInt8}, {kTfLiteInt64}, {kTfLiteInt16},
+               {kTfLiteFloat16}};
for (auto test : cases) {
Interpreter interpreter;


@ -486,6 +486,9 @@ int TfLiteTypeGetSize(TfLiteType type) {
case kTfLiteInt32:
TF_LITE_ASSERT_EQ(sizeof(int32_t), 4);
return 4;
case kTfLiteUInt32:
TF_LITE_ASSERT_EQ(sizeof(uint32_t), 4);
return 4;
case kTfLiteInt64:
TF_LITE_ASSERT_EQ(sizeof(int64_t), 8);
return 8;


@ -915,6 +915,7 @@ TensorType GetTensorType() {
if (std::is_same<T, int8_t>::value) return TensorType_INT8;
if (std::is_same<T, int16_t>::value) return TensorType_INT16;
if (std::is_same<T, int32_t>::value) return TensorType_INT32;
if (std::is_same<T, uint32_t>::value) return TensorType_UINT32;
if (std::is_same<T, int64_t>::value) return TensorType_INT64;
if (std::is_same<T, uint8_t>::value) return TensorType_UINT8;
if (std::is_same<T, string>::value) return TensorType_STRING;
@ -955,6 +956,16 @@ struct TypeUnion<int32_t> {
typedef int32_t ScalarType;
};
template <>
struct TypeUnion<uint32_t> {
public:
// NOLINTNEXTLINE
static constexpr TensorType tensor_type = TensorType::TensorType_UINT32;
// NOLINTNEXTLINE
static constexpr TfLiteType tflite_type = TfLiteType::kTfLiteUInt32;
typedef uint32_t ScalarType;
};
template <>
struct TypeUnion<int16_t> {
public:

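The trait ties the C++ scalar type to both enum spaces at once, which is what the typed test helpers rely on. A compile-time sketch of what the new specialization guarantees (header path as in the TFLite test tree):

#include <cstdint>
#include <type_traits>
#include "tensorflow/lite/kernels/test_util.h"  // TypeUnion

static_assert(tflite::TypeUnion<uint32_t>::tensor_type ==
                  tflite::TensorType_UINT32,
              "schema-side enum");
static_assert(tflite::TypeUnion<uint32_t>::tflite_type == kTfLiteUInt32,
              "runtime-side enum");
static_assert(std::is_same<tflite::TypeUnion<uint32_t>::ScalarType,
                           uint32_t>::value,
              "C++ scalar type");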

@ -63,6 +63,9 @@ TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) {
case kTfLiteInt32:
*size = sizeof(int32_t);
break;
case kTfLiteUInt32:
*size = sizeof(uint32_t);
break;
case kTfLiteUInt8:
*size = sizeof(uint8_t);
break;


@ -136,6 +136,10 @@ TF_LITE_MICRO_TEST(TestTypeSizeOf) {
tflite::TfLiteTypeSizeOf(kTfLiteInt32, &size));
TF_LITE_MICRO_EXPECT_EQ(sizeof(int32_t), size);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::TfLiteTypeSizeOf(kTfLiteUInt32, &size));
TF_LITE_MICRO_EXPECT_EQ(sizeof(uint32_t), size);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::TfLiteTypeSizeOf(kTfLiteUInt8, &size));
TF_LITE_MICRO_EXPECT_EQ(sizeof(uint8_t), size);


@ -421,6 +421,7 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_
case kTfLiteString:
case kTfLiteComplex64:
case kTfLiteComplex128:
case kTfLiteUInt32:
case kTfLiteUInt64:
case kTfLiteResource:
case kTfLiteVariant:


@ -51,6 +51,8 @@ const char* TensorTypeName(TfLiteType type) {
return "kTfLiteFloat32";
case kTfLiteInt32:
return "kTfLiteInt32";
case kTfLiteUInt32:
return "kTfLiteUInt32";
case kTfLiteUInt8:
return "kTfLiteUInt8";
case kTfLiteInt8:


@ -59,6 +59,7 @@ struct TfLiteTypeToType {}; // Specializations below
// No string mapping is included here, since the TF Lite packed representation
// doesn't correspond to a C++ type well.
MATCH_TYPE_AND_TFLITE_TYPE(int32_t, kTfLiteInt32);
MATCH_TYPE_AND_TFLITE_TYPE(uint32_t, kTfLiteUInt32);
MATCH_TYPE_AND_TFLITE_TYPE(int16_t, kTfLiteInt16);
MATCH_TYPE_AND_TFLITE_TYPE(int64_t, kTfLiteInt64);
MATCH_TYPE_AND_TFLITE_TYPE(float, kTfLiteFloat32);

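The MATCH_TYPE_AND_TFLITE_TYPE macro registers both directions of the compile-time map. A sketch of the uint32 round trip (assuming the mapping is constexpr, as the runtime test near the end of this change suggests):

#include <cstdint>
#include <type_traits>
#include "tensorflow/lite/type_to_tflitetype.h"

static_assert(tflite::typeToTfLiteType<uint32_t>() == kTfLiteUInt32,
              "C++ type -> runtime tag");
static_assert(std::is_same<tflite::TfLiteTypeToType<kTfLiteUInt32>::Type,
                           uint32_t>::value,
              "runtime tag -> C++ type");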

@ -42,6 +42,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
return NPY_FLOAT64;
case kTfLiteInt32:
return NPY_INT32;
case kTfLiteUInt32:
return NPY_UINT32;
case kTfLiteInt16:
return NPY_INT16;
case kTfLiteUInt8:
@ -80,6 +82,8 @@ TfLiteType TfLiteTypeFromPyType(int py_type) {
return kTfLiteFloat64;
case NPY_INT32:
return kTfLiteInt32;
case NPY_UINT32:
return kTfLiteUInt32;
case NPY_INT16:
return kTfLiteInt16;
case NPY_UINT8:


@ -73,6 +73,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
return TensorType_FLOAT64;
case kTfLiteInt32:
return TensorType_INT32;
case kTfLiteUInt32:
return TensorType_UINT32;
case kTfLiteUInt8:
return TensorType_UINT8;
case kTfLiteInt8:


@ -66,6 +66,7 @@ _MAP_TF_TO_TFLITE_TYPES = {
dtypes.complex128: _types_pb2.COMPLEX128,
dtypes.resource: _types_pb2.RESOURCE,
dtypes.variant: _types_pb2.VARIANT,
dtypes.uint32: _types_pb2.UINT32,
}
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
@ -81,6 +82,7 @@ _MAP_TFLITE_ENUM_TO_TF_TYPES = {
9: dtypes.int8,
10: dtypes.float64,
11: dtypes.complex128,
15: dtypes.uint32,
}
_TFLITE_FILE_IDENTIFIER = b"TFL3"


@ -47,6 +47,8 @@ class UtilTest(test_util.TensorFlowTestCase):
util.convert_dtype_to_tflite_type(dtypes.float16), _types_pb2.FLOAT16)
self.assertEqual(
util.convert_dtype_to_tflite_type(dtypes.int32), _types_pb2.INT32)
self.assertEqual(
util.convert_dtype_to_tflite_type(dtypes.uint32), _types_pb2.UINT32)
self.assertEqual(
util.convert_dtype_to_tflite_type(dtypes.uint8),
_types_pb2.QUANTIZED_UINT8)
@ -89,13 +91,16 @@ class UtilTest(test_util.TensorFlowTestCase):
util._convert_tflite_enum_type_to_tf_type(10), dtypes.float64)
self.assertEqual(
util._convert_tflite_enum_type_to_tf_type(11), dtypes.complex128)
self.assertEqual(
util._convert_tflite_enum_type_to_tf_type(15), dtypes.uint32)
with self.assertRaises(ValueError) as error:
util._convert_tflite_enum_type_to_tf_type(20)
self.assertEqual(
    "Unsupported enum 20. The valid map of enum to tf types is : "
    "{0: tf.float32, 1: tf.float16, 2: tf.int32, 3: tf.uint8, 4: tf.int64, "
    "5: tf.string, 6: tf.bool, 7: tf.int16, 8: tf.complex64, 9: tf.int8, "
-   "10: tf.float64, 11: tf.complex128}", str(error.exception))
+   "10: tf.float64, 11: tf.complex128, 15: tf.uint32}",
+   str(error.exception))
def testTensorName(self):
with ops.Graph().as_default():
@ -108,6 +113,30 @@ class UtilTest(test_util.TensorFlowTestCase):
got_name = util.get_tensor_name(out_tensors[i])
self.assertEqual(got_name, expect_names[i])
def testUint32PassThrough(self):
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(4,), dtype=tf.uint32),
tf.keras.layers.Reshape(target_shape=(2, 2))
])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]
self.assertEqual(input_details["dtype"], np.uint32)
self.assertEqual(output_details["dtype"], np.uint32)
in_array = np.array([[1, 1, 1, 1]], dtype="uint32") * ((1 << 32) - 1)
expected_out = np.reshape(in_array, (2, 2))
interpreter.set_tensor(input_details["index"], in_array)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details["index"])[0]
self.assertAllEqual(expected_out, output_data)
@test_util.enable_control_flow_v2
def testRemoveLowerUsingSwitchMerge(self):
with ops.Graph().as_default():


@ -47,6 +47,7 @@ enum TensorType : byte {
UINT64 = 12,
RESOURCE = 13,
VARIANT = 14,
UINT32 = 15,
}
// Custom quantization parameters for experimenting with new quantization

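Note that the enum spaces now disagree on the number for the same logical type: the schema assigns UINT32 = 15, while the C runtime enum above uses kTfLiteUInt32 = 16 and the converter's IODataType (later in this change) also uses 16. They are kept in sync only by the explicit switch-based converters in this commit, never by numeric casts; a compile-time sketch of why a cast would be wrong:

#include "tensorflow/lite/c/common.h"                 // kTfLiteUInt32 == 16
#include "tensorflow/lite/schema/schema_generated.h"  // TensorType_UINT32 == 15

static_assert(static_cast<int>(tflite::TensorType_UINT32) !=
                  static_cast<int>(kTfLiteUInt32),
              "same logical type, different values: always go through the "
              "ConvertTensorType()-style switches");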

@ -404,11 +404,12 @@ enum TensorType {
TensorType_UINT64 = 12,
TensorType_RESOURCE = 13,
TensorType_VARIANT = 14,
TensorType_UINT32 = 15,
TensorType_MIN = TensorType_FLOAT32,
-  TensorType_MAX = TensorType_VARIANT
+  TensorType_MAX = TensorType_UINT32
};
-inline const TensorType (&EnumValuesTensorType())[15] {
+inline const TensorType (&EnumValuesTensorType())[16] {
static const TensorType values[] = {
TensorType_FLOAT32,
TensorType_FLOAT16,
@ -424,13 +425,14 @@ inline const TensorType (&EnumValuesTensorType())[15] {
TensorType_COMPLEX128,
TensorType_UINT64,
TensorType_RESOURCE,
-    TensorType_VARIANT
+    TensorType_VARIANT,
+    TensorType_UINT32
};
return values;
}
inline const char * const *EnumNamesTensorType() {
-  static const char * const names[16] = {
+  static const char * const names[17] = {
"FLOAT32",
"FLOAT16",
"INT32",
@ -446,13 +448,14 @@ inline const char * const *EnumNamesTensorType() {
"UINT64",
"RESOURCE",
"VARIANT",
"UINT32",
nullptr
};
return names;
}
inline const char *EnumNameTensorType(TensorType e) {
-  if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_VARIANT)) return "";
+  if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_UINT32)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesTensorType()[index];
}


@ -58,6 +58,16 @@ inline std::vector<int> Split(const string& s, const string& delimiter) {
return fields;
}
template <>
inline std::vector<uint32_t> Split(const string& s, const string& delimiter) {
std::vector<uint32_t> fields;
for (const auto& p : SplitToPos(s, delimiter)) {
    // strtoul rather than strtol, so values above INT32_MAX parse correctly
    // on platforms where long is only 32 bits wide.
    // NOLINTNEXTLINE(runtime/deprecated_fn)
    fields.push_back(strtoul(s.data() + p.first, nullptr, 10));
}
return fields;
}
template <>
inline std::vector<int64_t> Split(const string& s, const string& delimiter) {
std::vector<int64_t> fields;

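A usage sketch for the new specialization, as exercised by TfLiteDriver below (header path assumed from the TFLite testing tree); with strtoul the full unsigned range survives the round trip:

#include <cstdint>
#include <vector>
#include "tensorflow/lite/testing/split.h"

// "4294967295" is UINT32_MAX; strtol would have clamped it through LONG_MAX
// on platforms where long is 32 bits wide.
const std::vector<uint32_t> values =
    tflite::testing::Split<uint32_t>("1,4294967295", ",");
// values == {1u, 4294967295u}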

@ -162,6 +162,10 @@ void TfDriver::SetInput(const string& values_as_string,
num_values_available =
FillTensorWithData<int32_t>(tensor, values_as_string);
break;
case tensorflow::DT_UINT32:
num_values_available =
FillTensorWithData<uint32_t>(tensor, values_as_string);
break;
case tensorflow::DT_UINT8:
num_values_available =
FillTensorWithData<uint8_t>(tensor, values_as_string);
@ -224,6 +228,8 @@ string TfDriver::ReadOutput(const tensorflow::Tensor& tensor) {
return TensorDataToCsvString<float>(tensor);
case tensorflow::DT_INT32:
return TensorDataToCsvString<int32_t>(tensor);
case tensorflow::DT_UINT32:
return TensorDataToCsvString<uint32_t>(tensor);
case tensorflow::DT_INT64:
return TensorDataToCsvString<tensorflow::int64>(tensor);
case tensorflow::DT_UINT8:


@ -333,6 +333,8 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose,
return TypedCheck<float, float>(verbose, tensor);
case kTfLiteInt32:
return TypedCheck<int32_t, float>(verbose, tensor);
case kTfLiteUInt32:
return TypedCheck<uint32_t, float>(verbose, tensor);
case kTfLiteInt64:
return TypedCheck<int64_t, float>(verbose, tensor);
case kTfLiteUInt64:
@ -485,6 +487,12 @@ void TfLiteDriver::SetInput(int id, const string& csv_values) {
SetTensorData(values, tensor->data.raw);
break;
}
case kTfLiteUInt32: {
const auto& values = testing::Split<uint32_t>(csv_values, ",");
if (!CheckSizes<uint32_t>(tensor->bytes, values.size())) return;
SetTensorData(values, tensor->data.raw);
break;
}
case kTfLiteInt64: {
const auto& values = testing::Split<int64_t>(csv_values, ",");
if (!CheckSizes<int64_t>(tensor->bytes, values.size())) return;
@ -586,6 +594,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
case kTfLiteInt32:
expected_output_[id]->SetData<int32_t>(csv_values);
break;
case kTfLiteUInt32:
expected_output_[id]->SetData<uint32_t>(csv_values);
break;
case kTfLiteInt64:
expected_output_[id]->SetData<int64_t>(csv_values);
break;
@ -692,6 +703,8 @@ string TfLiteDriver::ReadOutput(int id) {
return JoinDefault(tensor->data.f, num_elements, ",");
case kTfLiteInt32:
return JoinDefault(tensor->data.i32, num_elements, ",");
case kTfLiteUInt32:
return JoinDefault(tensor->data.u32, num_elements, ",");
case kTfLiteInt64:
return JoinDefault(tensor->data.i64, num_elements, ",");
case kTfLiteUInt64:


@ -41,6 +41,7 @@ using tensorflow::DT_FLOAT;
using tensorflow::DT_INT16;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_UINT32;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::TensorProto;
@ -59,6 +60,8 @@ tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type,
return tensorflow::DT_UINT8;
case ArrayDataType::kInt32:
return tensorflow::DT_INT32;
case ArrayDataType::kUint32:
return tensorflow::DT_UINT32;
case ArrayDataType::kInt64:
return tensorflow::DT_INT64;
case ArrayDataType::kString:
@ -2438,6 +2441,9 @@ void AddPlaceholder(const std::string& name, ArrayDataType type,
case ArrayDataType::kInt32:
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
break;
case ArrayDataType::kUint32:
(*placeholder->mutable_attr())["dtype"].set_type(DT_UINT32);
break;
case ArrayDataType::kInt64:
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
break;


@ -57,6 +57,7 @@ using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_QUINT8;
using tensorflow::DT_STRING;
using tensorflow::DT_UINT32;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::NodeDef;
@ -185,6 +186,8 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
return ArrayDataType::kBool;
else if (dtype == DT_INT32)
return ArrayDataType::kInt32;
else if (dtype == DT_UINT32)
return ArrayDataType::kUint32;
else if (dtype == DT_INT64)
return ArrayDataType::kInt64;
else if (dtype == DT_STRING)
@ -295,6 +298,18 @@ struct TensorTraits<int32> {
}
};
template <>
struct TensorTraits<uint32> {
static int size(const TensorProto& p) { return p.uint32_val_size(); }
  static uint32 get(const TensorProto& p, int i) { return p.uint32_val(i); }
static std::string accessor_name() { return "uint32_val"; }
static std::string type_name() { return "uint32"; }
static void CopyFromContent(const TensorProto& p, std::vector<uint32>* data) {
toco::port::CopyToBuffer(p.tensor_content(),
reinterpret_cast<char*>(data->data()));
}
};
template <>
struct TensorTraits<int64> {
static int size(const TensorProto& p) { return p.int64_val_size(); }
@ -432,6 +447,23 @@ tensorflow::Status ImportInt32Array(const TensorProto& input_tensor,
&output_int_data);
}
tensorflow::Status ImportUint32Array(const TensorProto& input_tensor,
Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_UINT32);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 6);
int input_flat_size;
auto status = ImportShape(input_shape.dim(), &input_flat_size,
output_array->mutable_shape());
if (!status.ok()) return status;
auto& output_int_data =
output_array->GetMutableBuffer<ArrayDataType::kUint32>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
return ImportTensorData<uint32>(input_tensor, input_flat_size,
&output_int_data);
}
tensorflow::Status ImportInt64Array(const TensorProto& input_tensor,
Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT64);
@ -757,6 +789,10 @@ tensorflow::Status ConvertConstOperator(
array.data_type = ArrayDataType::kInt32;
status = ImportInt32Array(tensor, &array);
break;
case DT_UINT32:
array.data_type = ArrayDataType::kUint32;
status = ImportUint32Array(tensor, &array);
break;
case DT_QUINT8:
array.data_type = ArrayDataType::kUint8;
status = ImportQuint8Array(tensor, &array);
@ -1473,7 +1509,6 @@ tensorflow::Status ConditionallyConvertConstOperator(
model);
}
}
switch (GetDataTypeAttr(node, "dtype")) {
case DT_FLOAT:
case DT_INT32:


@ -37,6 +37,7 @@ using tensorflow::DT_INT64;
using tensorflow::DT_INVALID;
using tensorflow::DT_QUINT8;
using tensorflow::DT_STRING;
using tensorflow::DT_UINT32;
using tensorflow::NodeDef;
using tensorflow::Status;
using ::testing::ElementsAre;
@ -127,6 +128,11 @@ void BuildConstNode(std::initializer_list<int64_t> shape,
t.add_int_val(i % std::numeric_limits<int>::max() + 1);
}
break;
case DT_UINT32:
for (int64_t i = 0; i < num_elements; ++i) {
t.add_uint32_val(i % std::numeric_limits<uint32_t>::max() + 1);
}
break;
case DT_QUINT8:
for (int64_t i = 0; i < num_elements; ++i) {
t.add_int_val(i % std::numeric_limits<uint8_t>::max() + 1);


@ -48,6 +48,7 @@ namespace tflite {
{ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
{ArrayDataType::kInt16, ::tflite::TensorType_INT16},
{ArrayDataType::kInt32, ::tflite::TensorType_INT32},
{ArrayDataType::kUint32, ::tflite::TensorType_UINT32},
{ArrayDataType::kInt64, ::tflite::TensorType_INT64},
{ArrayDataType::kUint64, ::tflite::TensorType_UINT64},
{ArrayDataType::kString, ::tflite::TensorType_STRING},


@ -92,6 +92,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
return ::tflite::TensorType_INT16;
case ArrayDataType::kInt32:
return ::tflite::TensorType_INT32;
case ArrayDataType::kUint32:
return ::tflite::TensorType_UINT32;
case ArrayDataType::kInt64:
return ::tflite::TensorType_INT64;
case ArrayDataType::kUint8:
@ -117,6 +119,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) {
return ArrayDataType::kInt16;
case ::tflite::TensorType_INT32:
return ArrayDataType::kInt32;
case ::tflite::TensorType_UINT32:
return ArrayDataType::kUint32;
case ::tflite::TensorType_INT64:
return ArrayDataType::kInt64;
case ::tflite::TensorType_STRING:
@ -143,6 +147,8 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
return CopyBuffer<ArrayDataType::kInt16>(array, builder);
case ArrayDataType::kInt32:
return CopyBuffer<ArrayDataType::kInt32>(array, builder);
case ArrayDataType::kUint32:
return CopyBuffer<ArrayDataType::kUint32>(array, builder);
case ArrayDataType::kInt64:
return CopyBuffer<ArrayDataType::kInt64>(array, builder);
case ArrayDataType::kString:
@ -170,6 +176,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
return CopyBuffer<ArrayDataType::kInt16>(buffer, array);
case ::tflite::TensorType_INT32:
return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
case ::tflite::TensorType_UINT32:
return CopyBuffer<ArrayDataType::kUint32>(buffer, array);
case ::tflite::TensorType_INT64:
return CopyBuffer<ArrayDataType::kInt64>(buffer, array);
case ::tflite::TensorType_STRING:


@ -71,6 +71,7 @@ TEST(DataType, SupportedTypes) {
std::vector<std::pair<ArrayDataType, ::tflite::TensorType>> testdata = {
{ArrayDataType::kUint8, ::tflite::TensorType_UINT8},
{ArrayDataType::kInt32, ::tflite::TensorType_INT32},
{ArrayDataType::kUint32, ::tflite::TensorType_UINT32},
{ArrayDataType::kInt64, ::tflite::TensorType_INT64},
{ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32},
{ArrayDataType::kBool, ::tflite::TensorType_BOOL},
@ -154,6 +155,12 @@ TEST(DataBuffer, Int32) {
::testing::ElementsAre(1, 1 << 30));
}
TEST(DataBuffer, Uint32) {
Array recovered = ToFlatBufferAndBack<ArrayDataType::kUint32>({1, 1U << 31});
EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kUint32>().data,
::testing::ElementsAre(1, 1U << 31));
}
TEST(DataBuffer, Int16) {
Array recovered = ToFlatBufferAndBack<ArrayDataType::kInt16>({1, 1 << 14});
EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kInt16>().data,


@ -2307,6 +2307,8 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
return ArrayDataType::kInt16;
case INT32:
return ArrayDataType::kInt32;
case UINT32:
return ArrayDataType::kUint32;
case INT64:
return ArrayDataType::kInt64;
case UINT64:


@ -64,4 +64,7 @@ enum IODataType {
// Variant type
VARIANT = 15;
// Uint32
UINT32 = 16;
}


@ -485,6 +485,12 @@ BenchmarkTfLiteModel::CreateRandomTensorData(const TfLiteTensor& t,
return CreateInputTensorData<int32_t>(
num_elements, std::uniform_int_distribution<int32_t>(low, high));
}
case kTfLiteUInt32: {
int low = has_value_range ? low_range : 0;
int high = has_value_range ? high_range : 99;
return CreateInputTensorData<uint32_t>(
num_elements, std::uniform_int_distribution<uint32_t>(low, high));
}
case kTfLiteInt16: {
int low = has_value_range ? low_range : 0;
int high = has_value_range ? high_range : 99;


@ -68,6 +68,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
return TensorType_FLOAT64;
case kTfLiteInt32:
return TensorType_INT32;
case kTfLiteUInt32:
return TensorType_UINT32;
case kTfLiteUInt8:
return TensorType_UINT8;
case kTfLiteInt8:


@ -422,6 +422,9 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer,
case TensorType_INT32:
bytes_required *= sizeof(int32_t);
break;
case TensorType_UINT32:
bytes_required *= sizeof(uint32_t);
break;
case TensorType_UINT8:
bytes_required *= sizeof(uint8_t);
break;


@ -30,6 +30,8 @@ TEST(TypeToTfLiteType, TypeMapsAreInverseOfEachOther) {
typeToTfLiteType<TfLiteTypeToType<kTfLiteInt16>::Type>());
EXPECT_EQ(kTfLiteInt32,
typeToTfLiteType<TfLiteTypeToType<kTfLiteInt32>::Type>());
EXPECT_EQ(kTfLiteUInt32,
typeToTfLiteType<TfLiteTypeToType<kTfLiteUInt32>::Type>());
EXPECT_EQ(kTfLiteFloat32,
typeToTfLiteType<TfLiteTypeToType<kTfLiteFloat32>::Type>());
EXPECT_EQ(kTfLiteUInt8,


@ -96,7 +96,10 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
*bytes = sizeof(float);
break;
case kTfLiteInt32:
-      *bytes = sizeof(int);
+      *bytes = sizeof(int32_t);
break;
case kTfLiteUInt32:
*bytes = sizeof(uint32_t);
break;
case kTfLiteUInt8:
*bytes = sizeof(uint8_t);