diff --git a/RELEASE.md b/RELEASE.md index f764b5ae543..9566460462e 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -100,6 +100,9 @@ function for a given signaturedef. * Add int8 support for `ReshapeV2`. * Add experimental support for optimization with sparsity. + * Add nominal support for unsigned 32-bit integer tensor types. Note that + very few TFLite kernels support this type natively, so its use in mobile + ML authoring is generally discouraged. * TF Core: * Corrected higher-order gradients of control flow constructs (`tf.cond`, `tf.while_loop`, and compositions like `tf.foldl`) computed with diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 23d5d9612f7..4ab877bac01 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -165,7 +165,8 @@ static StatusOr GetTFLiteType(Type type, case 16: return tflite::TensorType_INT16; case 32: - return tflite::TensorType_INT32; + return itype.isUnsigned() ? tflite::TensorType_UINT32 + : tflite::TensorType_INT32; case 64: return itype.isUnsigned() ? tflite::TensorType_UINT64 : tflite::TensorType_INT64; diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc index 6e5af0889c5..af9fda26597 100644 --- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc +++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc @@ -79,6 +79,8 @@ static std::string TfLiteTensorString(const TfLiteTensor& tensor) { switch (tensor.type) { case kTfLiteInt32: return TfLiteTypedTensorString(tensor); + case kTfLiteUInt32: + return TfLiteTypedTensorString(tensor); case kTfLiteInt64: return TfLiteTypedTensorString(tensor); case kTfLiteFloat32: diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index 735e8be269f..213186f23c3 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -119,6 +119,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) { return DT_INT16; case toco::IODataType::INT32: return DT_INT32; + case toco::IODataType::UINT32: + return DT_UINT32; case toco::IODataType::INT64: return DT_INT64; case toco::IODataType::UINT64: diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index 811796bcbcd..733a5a33a3f 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -41,6 +41,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { return builder.getF64Type(); case tflite::TensorType_INT32: return builder.getIntegerType(32); + case tflite::TensorType_UINT32: + return builder.getIntegerType(32, /*isSigned=*/false); case tflite::TensorType_UINT8: return builder.getIntegerType(8, /*isSigned=*/false); case tflite::TensorType_INT64: @@ -86,6 +88,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) { return tensorflow::DT_INT16; case tflite::TensorType_INT32: return tensorflow::DT_INT32; + case tflite::TensorType_UINT32: + return tensorflow::DT_UINT32; case tflite::TensorType_INT64: return tensorflow::DT_INT64; case tflite::TensorType_STRING: @@ -121,6 +125,8 @@ StatusOr TfTypeToTflType(tensorflow::DataType type) { return tflite::TensorType_INT16; case tensorflow::DT_INT32: return tflite::TensorType_INT32; + case tensorflow::DT_UINT32: + return tflite::TensorType_UINT32; case tensorflow::DT_INT64: return tflite::TensorType_INT64; case tensorflow::DT_UINT64: diff --git a/tensorflow/lite/c/c_api_types.h b/tensorflow/lite/c/c_api_types.h index d6dc5141aa9..01284778711 100644 --- a/tensorflow/lite/c/c_api_types.h +++ b/tensorflow/lite/c/c_api_types.h @@ -75,6 +75,7 @@ typedef enum { kTfLiteUInt64 = 13, kTfLiteResource = 14, kTfLiteVariant = 15, + kTfLiteUInt32 = 16, } TfLiteType; // Legacy. Will be deprecated in favor of TfLiteAffineQuantization. diff --git a/tensorflow/lite/c/common.c b/tensorflow/lite/c/common.c index d47ec4e4bb2..aaa98a98ebe 100644 --- a/tensorflow/lite/c/common.c +++ b/tensorflow/lite/c/common.c @@ -199,6 +199,8 @@ const char* TfLiteTypeGetName(TfLiteType type) { return "INT16"; case kTfLiteInt32: return "INT32"; + case kTfLiteUInt32: + return "UINT32"; case kTfLiteUInt8: return "UINT8"; case kTfLiteInt8: diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h index 59ad97761f9..e7d97edc995 100644 --- a/tensorflow/lite/c/common.h +++ b/tensorflow/lite/c/common.h @@ -296,6 +296,7 @@ typedef union TfLitePtrUnion { * GetTensorData(tensor) instead, otherwise only access .data, as other * members are deprecated. */ int32_t* i32; + uint32_t* u32; int64_t* i64; uint64_t* u64; float* f; diff --git a/tensorflow/lite/c/common_test.cc b/tensorflow/lite/c/common_test.cc index e8425f110a9..7a45db15aba 100644 --- a/tensorflow/lite/c/common_test.cc +++ b/tensorflow/lite/c/common_test.cc @@ -83,6 +83,7 @@ TEST(Types, TestTypeNames) { EXPECT_EQ(type_name(kTfLiteFloat16), "FLOAT16"); EXPECT_EQ(type_name(kTfLiteInt16), "INT16"); EXPECT_EQ(type_name(kTfLiteInt32), "INT32"); + EXPECT_EQ(type_name(kTfLiteUInt32), "UINT32"); EXPECT_EQ(type_name(kTfLiteUInt8), "UINT8"); EXPECT_EQ(type_name(kTfLiteUInt64), "UINT64"); EXPECT_EQ(type_name(kTfLiteInt8), "INT8"); diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 9eb9048f3bf..40b0c9be792 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -851,6 +851,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, case TensorType_INT32: *type = kTfLiteInt32; return kTfLiteOk; + case TensorType_UINT32: + *type = kTfLiteUInt32; + return kTfLiteOk; case TensorType_UINT8: *type = kTfLiteUInt8; return kTfLiteOk; diff --git a/tensorflow/lite/delegates/flex/util.cc b/tensorflow/lite/delegates/flex/util.cc index 2ba9161bb30..ffb5bc210d8 100644 --- a/tensorflow/lite/delegates/flex/util.cc +++ b/tensorflow/lite/delegates/flex/util.cc @@ -68,6 +68,8 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) { return TF_INT16; case kTfLiteInt32: return TF_INT32; + case kTfLiteUInt32: + return TF_UINT32; case kTfLiteUInt8: return TF_UINT8; case kTfLiteInt8: diff --git a/tensorflow/lite/interpreter_test.cc b/tensorflow/lite/interpreter_test.cc index 010beeb8f76..0b5b6c232c6 100644 --- a/tensorflow/lite/interpreter_test.cc +++ b/tensorflow/lite/interpreter_test.cc @@ -189,9 +189,10 @@ TEST(BasicInterpreter, CheckAllocate) { TfLiteType type; size_t size; } cases[] = { - {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)}, - {kTfLiteUInt8, sizeof(uint8_t)}, {kTfLiteInt64, sizeof(int64_t)}, - {kTfLiteInt16, sizeof(int16_t)}, {kTfLiteFloat16, sizeof(TfLiteFloat16)}, + {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)}, + {kTfLiteUInt32, sizeof(uint32_t)}, {kTfLiteUInt8, sizeof(uint8_t)}, + {kTfLiteInt64, sizeof(int64_t)}, {kTfLiteInt16, sizeof(int16_t)}, + {kTfLiteFloat16, sizeof(TfLiteFloat16)}, }; for (auto test : cases) { @@ -261,6 +262,7 @@ TEST(BasicInterpreter, CheckQuantization) { TEST(BasicInterpreter, CheckResize) { const float floats[] = {-3., -4.}; const int32_t int32s[] = {-3, -4}; + const uint32_t uint32s[] = {3, 4}; const uint8_t uint8s[] = {3, 4}; const int64_t int64s[] = {6, -7}; const int16_t int16s[] = {8, -9}; @@ -274,6 +276,7 @@ TEST(BasicInterpreter, CheckResize) { } cases[] = { {kTfLiteFloat32, sizeof(float), reinterpret_cast(floats)}, {kTfLiteInt32, sizeof(int32_t), reinterpret_cast(int32s)}, + {kTfLiteUInt32, sizeof(uint32_t), reinterpret_cast(uint32s)}, {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast(uint8s)}, {kTfLiteInt64, sizeof(int64_t), reinterpret_cast(int64s)}, {kTfLiteInt16, sizeof(int16_t), reinterpret_cast(int16s)}, @@ -313,8 +316,9 @@ TEST(BasicInterpreter, CheckResize) { TEST(BasicInterpreter, CheckAlignment) { struct { TfLiteType type; - } cases[] = {{kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8}, - {kTfLiteInt64}, {kTfLiteInt16}, {kTfLiteFloat16}}; + } cases[] = {{kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt32}, + {kTfLiteUInt8}, {kTfLiteInt64}, {kTfLiteInt16}, + {kTfLiteFloat16}}; for (auto test : cases) { Interpreter interpreter; diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc index a781cf3fcad..41bd44aea62 100644 --- a/tensorflow/lite/kernels/kernel_util.cc +++ b/tensorflow/lite/kernels/kernel_util.cc @@ -486,6 +486,9 @@ int TfLiteTypeGetSize(TfLiteType type) { case kTfLiteInt32: TF_LITE_ASSERT_EQ(sizeof(int32_t), 4); return 4; + case kTfLiteUInt32: + TF_LITE_ASSERT_EQ(sizeof(uint32_t), 4); + return 4; case kTfLiteInt64: TF_LITE_ASSERT_EQ(sizeof(int64_t), 8); return 8; diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index 7cc986ab483..ec2d2484fcb 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -915,6 +915,7 @@ TensorType GetTensorType() { if (std::is_same::value) return TensorType_INT8; if (std::is_same::value) return TensorType_INT16; if (std::is_same::value) return TensorType_INT32; + if (std::is_same::value) return TensorType_UINT32; if (std::is_same::value) return TensorType_INT64; if (std::is_same::value) return TensorType_UINT8; if (std::is_same::value) return TensorType_STRING; @@ -955,6 +956,16 @@ struct TypeUnion { typedef int32_t ScalarType; }; +template <> +struct TypeUnion { + public: + // NOLINTNEXTLINE + static constexpr TensorType tensor_type = TensorType::TensorType_UINT32; + // NOLINTNEXTLINE + static constexpr TfLiteType tflite_type = TfLiteType::kTfLiteUInt32; + typedef uint32_t ScalarType; +}; + template <> struct TypeUnion { public: diff --git a/tensorflow/lite/micro/memory_helpers.cc b/tensorflow/lite/micro/memory_helpers.cc index 08bd9a893aa..2d8f7597a21 100644 --- a/tensorflow/lite/micro/memory_helpers.cc +++ b/tensorflow/lite/micro/memory_helpers.cc @@ -63,6 +63,9 @@ TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) { case kTfLiteInt32: *size = sizeof(int32_t); break; + case kTfLiteUInt32: + *size = sizeof(uint32_t); + break; case kTfLiteUInt8: *size = sizeof(uint8_t); break; diff --git a/tensorflow/lite/micro/memory_helpers_test.cc b/tensorflow/lite/micro/memory_helpers_test.cc index 5f28dea3750..230539c30db 100644 --- a/tensorflow/lite/micro/memory_helpers_test.cc +++ b/tensorflow/lite/micro/memory_helpers_test.cc @@ -136,6 +136,10 @@ TF_LITE_MICRO_TEST(TestTypeSizeOf) { tflite::TfLiteTypeSizeOf(kTfLiteInt32, &size)); TF_LITE_MICRO_EXPECT_EQ(sizeof(int32_t), size); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, + tflite::TfLiteTypeSizeOf(kTfLiteUInt32, &size)); + TF_LITE_MICRO_EXPECT_EQ(sizeof(uint32_t), size); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, tflite::TfLiteTypeSizeOf(kTfLiteUInt8, &size)); TF_LITE_MICRO_EXPECT_EQ(sizeof(uint8_t), size); diff --git a/tensorflow/lite/objc/sources/TFLInterpreter.mm b/tensorflow/lite/objc/sources/TFLInterpreter.mm index 03a20f01d67..58b009dd79c 100644 --- a/tensorflow/lite/objc/sources/TFLInterpreter.mm +++ b/tensorflow/lite/objc/sources/TFLInterpreter.mm @@ -421,6 +421,7 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_ case kTfLiteString: case kTfLiteComplex64: case kTfLiteComplex128: + case kTfLiteUInt32: case kTfLiteUInt64: case kTfLiteResource: case kTfLiteVariant: diff --git a/tensorflow/lite/optional_debug_tools.cc b/tensorflow/lite/optional_debug_tools.cc index 1ec0d1ad523..d02d2d29921 100644 --- a/tensorflow/lite/optional_debug_tools.cc +++ b/tensorflow/lite/optional_debug_tools.cc @@ -51,6 +51,8 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteFloat32"; case kTfLiteInt32: return "kTfLiteInt32"; + case kTfLiteUInt32: + return "kTfLiteUInt32"; case kTfLiteUInt8: return "kTfLiteUInt8"; case kTfLiteInt8: diff --git a/tensorflow/lite/portable_type_to_tflitetype.h b/tensorflow/lite/portable_type_to_tflitetype.h index 9fbcfb8ed1e..83a0ac6c5ad 100644 --- a/tensorflow/lite/portable_type_to_tflitetype.h +++ b/tensorflow/lite/portable_type_to_tflitetype.h @@ -59,6 +59,7 @@ struct TfLiteTypeToType {}; // Specializations below // No string mapping is included here, since the TF Lite packed representation // doesn't correspond to a C++ type well. MATCH_TYPE_AND_TFLITE_TYPE(int32_t, kTfLiteInt32); +MATCH_TYPE_AND_TFLITE_TYPE(uint32_t, kTfLiteUInt32); MATCH_TYPE_AND_TFLITE_TYPE(int16_t, kTfLiteInt16); MATCH_TYPE_AND_TFLITE_TYPE(int64_t, kTfLiteInt64); MATCH_TYPE_AND_TFLITE_TYPE(float, kTfLiteFloat32); diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.cc b/tensorflow/lite/python/interpreter_wrapper/numpy.cc index 1785aa02c4b..5fabf660e2e 100644 --- a/tensorflow/lite/python/interpreter_wrapper/numpy.cc +++ b/tensorflow/lite/python/interpreter_wrapper/numpy.cc @@ -42,6 +42,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { return NPY_FLOAT64; case kTfLiteInt32: return NPY_INT32; + case kTfLiteUInt32: + return NPY_UINT32; case kTfLiteInt16: return NPY_INT16; case kTfLiteUInt8: @@ -80,6 +82,8 @@ TfLiteType TfLiteTypeFromPyType(int py_type) { return kTfLiteFloat64; case NPY_INT32: return kTfLiteInt32; + case NPY_UINT32: + return kTfLiteUInt32; case NPY_INT16: return kTfLiteInt16; case NPY_UINT8: diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc index a70ea5bb932..2a744fba452 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc @@ -73,6 +73,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { return TensorType_FLOAT64; case kTfLiteInt32: return TensorType_INT32; + case kTfLiteUInt32: + return TensorType_UINT32; case kTfLiteUInt8: return TensorType_UINT8; case kTfLiteInt8: diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index f69bd326ff0..1ed390711ff 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -66,6 +66,7 @@ _MAP_TF_TO_TFLITE_TYPES = { dtypes.complex128: _types_pb2.COMPLEX128, dtypes.resource: _types_pb2.RESOURCE, dtypes.variant: _types_pb2.VARIANT, + dtypes.uint32: _types_pb2.UINT32, } _MAP_TFLITE_ENUM_TO_TF_TYPES = { @@ -81,6 +82,7 @@ _MAP_TFLITE_ENUM_TO_TF_TYPES = { 9: dtypes.int8, 10: dtypes.float64, 11: dtypes.complex128, + 16: dtypes.uint32, } _TFLITE_FILE_IDENTIFIER = b"TFL3" diff --git a/tensorflow/lite/python/util_test.py b/tensorflow/lite/python/util_test.py index 7e0f2c9bf3c..0fd2be19362 100644 --- a/tensorflow/lite/python/util_test.py +++ b/tensorflow/lite/python/util_test.py @@ -47,6 +47,8 @@ class UtilTest(test_util.TensorFlowTestCase): util.convert_dtype_to_tflite_type(dtypes.float16), _types_pb2.FLOAT16) self.assertEqual( util.convert_dtype_to_tflite_type(dtypes.int32), _types_pb2.INT32) + self.assertEqual( + util.convert_dtype_to_tflite_type(dtypes.uint32), _types_pb2.UINT32) self.assertEqual( util.convert_dtype_to_tflite_type(dtypes.uint8), _types_pb2.QUANTIZED_UINT8) @@ -89,13 +91,16 @@ class UtilTest(test_util.TensorFlowTestCase): util._convert_tflite_enum_type_to_tf_type(10), dtypes.float64) self.assertEqual( util._convert_tflite_enum_type_to_tf_type(11), dtypes.complex128) + self.assertEqual( + util._convert_tflite_enum_type_to_tf_type(16), dtypes.uint32) with self.assertRaises(ValueError) as error: util._convert_tflite_enum_type_to_tf_type(20) self.assertEqual( "Unsupported enum 20. The valid map of enum to tf types is : " "{0: tf.float32, 1: tf.float16, 2: tf.int32, 3: tf.uint8, 4: tf.int64, " "5: tf.string, 6: tf.bool, 7: tf.int16, 8: tf.complex64, 9: tf.int8, " - "10: tf.float64, 11: tf.complex128}", str(error.exception)) + "10: tf.float64, 11: tf.complex128, 16: tf.uint32}", + str(error.exception)) def testTensorName(self): with ops.Graph().as_default(): @@ -108,6 +113,30 @@ class UtilTest(test_util.TensorFlowTestCase): got_name = util.get_tensor_name(out_tensors[i]) self.assertEqual(got_name, expect_names[i]) + def testUint32PassThrough(self): + model = tf.keras.Sequential([ + tf.keras.layers.InputLayer(input_shape=(4,), dtype=tf.uint32), + tf.keras.layers.Reshape(target_shape=(2, 2)) + ]) + converter = tf.lite.TFLiteConverter.from_keras_model(model) + tflite_model = converter.convert() + interpreter = tf.lite.Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details()[0] + output_details = interpreter.get_output_details()[0] + + self.assertEqual(input_details["dtype"], np.uint32) + self.assertEqual(output_details["dtype"], np.uint32) + + in_array = np.array([[1, 1, 1, 1]], dtype="uint32") * ((1 << 32) - 1) + expected_out = np.reshape(in_array, (2, 2)) + + interpreter.set_tensor(input_details["index"], in_array) + interpreter.invoke() + + output_data = interpreter.get_tensor(output_details["index"])[0] + self.assertAllEqual(expected_out, output_data) + @test_util.enable_control_flow_v2 def testRemoveLowerUsingSwitchMerge(self): with ops.Graph().as_default(): diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 278afef007c..dae3a465fde 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -47,6 +47,7 @@ enum TensorType : byte { UINT64 = 12, RESOURCE = 13, VARIANT = 14, + UINT32 = 15, } // Custom quantization parameters for experimenting with new quantization diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index 23e055b6935..c96049c297c 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -404,11 +404,12 @@ enum TensorType { TensorType_UINT64 = 12, TensorType_RESOURCE = 13, TensorType_VARIANT = 14, + TensorType_UINT32 = 15, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_VARIANT + TensorType_MAX = TensorType_UINT32 }; -inline const TensorType (&EnumValuesTensorType())[15] { +inline const TensorType (&EnumValuesTensorType())[16] { static const TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, @@ -424,13 +425,14 @@ inline const TensorType (&EnumValuesTensorType())[15] { TensorType_COMPLEX128, TensorType_UINT64, TensorType_RESOURCE, - TensorType_VARIANT + TensorType_VARIANT, + TensorType_UINT32 }; return values; } inline const char * const *EnumNamesTensorType() { - static const char * const names[16] = { + static const char * const names[17] = { "FLOAT32", "FLOAT16", "INT32", @@ -446,13 +448,14 @@ inline const char * const *EnumNamesTensorType() { "UINT64", "RESOURCE", "VARIANT", + "UINT32", nullptr }; return names; } inline const char *EnumNameTensorType(TensorType e) { - if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_VARIANT)) return ""; + if (flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_UINT32)) return ""; const size_t index = static_cast(e); return EnumNamesTensorType()[index]; } diff --git a/tensorflow/lite/testing/split.h b/tensorflow/lite/testing/split.h index c23f6f90ce0..d70ed28a3c6 100644 --- a/tensorflow/lite/testing/split.h +++ b/tensorflow/lite/testing/split.h @@ -58,6 +58,16 @@ inline std::vector Split(const string& s, const string& delimiter) { return fields; } +template <> +inline std::vector Split(const string& s, const string& delimiter) { + std::vector fields; + for (const auto& p : SplitToPos(s, delimiter)) { + // NOLINTNEXTLINE(runtime/deprecated_fn) + fields.push_back(strtol(s.data() + p.first, nullptr, 10)); + } + return fields; +} + template <> inline std::vector Split(const string& s, const string& delimiter) { std::vector fields; diff --git a/tensorflow/lite/testing/tf_driver.cc b/tensorflow/lite/testing/tf_driver.cc index b63aeccafbd..481030c596f 100644 --- a/tensorflow/lite/testing/tf_driver.cc +++ b/tensorflow/lite/testing/tf_driver.cc @@ -162,6 +162,10 @@ void TfDriver::SetInput(const string& values_as_string, num_values_available = FillTensorWithData(tensor, values_as_string); break; + case tensorflow::DT_UINT32: + num_values_available = + FillTensorWithData(tensor, values_as_string); + break; case tensorflow::DT_UINT8: num_values_available = FillTensorWithData(tensor, values_as_string); @@ -224,6 +228,8 @@ string TfDriver::ReadOutput(const tensorflow::Tensor& tensor) { return TensorDataToCsvString(tensor); case tensorflow::DT_INT32: return TensorDataToCsvString(tensor); + case tensorflow::DT_UINT32: + return TensorDataToCsvString(tensor); case tensorflow::DT_INT64: return TensorDataToCsvString(tensor); case tensorflow::DT_UINT8: diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc index c858bf051d1..7ac7e0747bd 100644 --- a/tensorflow/lite/testing/tflite_driver.cc +++ b/tensorflow/lite/testing/tflite_driver.cc @@ -333,6 +333,8 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose, return TypedCheck(verbose, tensor); case kTfLiteInt32: return TypedCheck(verbose, tensor); + case kTfLiteUInt32: + return TypedCheck(verbose, tensor); case kTfLiteInt64: return TypedCheck(verbose, tensor); case kTfLiteUInt64: @@ -485,6 +487,12 @@ void TfLiteDriver::SetInput(int id, const string& csv_values) { SetTensorData(values, tensor->data.raw); break; } + case kTfLiteUInt32: { + const auto& values = testing::Split(csv_values, ","); + if (!CheckSizes(tensor->bytes, values.size())) return; + SetTensorData(values, tensor->data.raw); + break; + } case kTfLiteInt64: { const auto& values = testing::Split(csv_values, ","); if (!CheckSizes(tensor->bytes, values.size())) return; @@ -586,6 +594,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) { case kTfLiteInt32: expected_output_[id]->SetData(csv_values); break; + case kTfLiteUInt32: + expected_output_[id]->SetData(csv_values); + break; case kTfLiteInt64: expected_output_[id]->SetData(csv_values); break; @@ -692,6 +703,8 @@ string TfLiteDriver::ReadOutput(int id) { return JoinDefault(tensor->data.f, num_elements, ","); case kTfLiteInt32: return JoinDefault(tensor->data.i32, num_elements, ","); + case kTfLiteUInt32: + return JoinDefault(tensor->data.u32, num_elements, ","); case kTfLiteInt64: return JoinDefault(tensor->data.i64, num_elements, ","); case kTfLiteUInt64: diff --git a/tensorflow/lite/toco/export_tensorflow.cc b/tensorflow/lite/toco/export_tensorflow.cc index 7ecf6cc7d44..f3bdc9b2dc4 100644 --- a/tensorflow/lite/toco/export_tensorflow.cc +++ b/tensorflow/lite/toco/export_tensorflow.cc @@ -41,6 +41,7 @@ using tensorflow::DT_FLOAT; using tensorflow::DT_INT16; using tensorflow::DT_INT32; using tensorflow::DT_INT64; +using tensorflow::DT_UINT32; using tensorflow::DT_UINT8; using tensorflow::GraphDef; using tensorflow::TensorProto; @@ -59,6 +60,8 @@ tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type, return tensorflow::DT_UINT8; case ArrayDataType::kInt32: return tensorflow::DT_INT32; + case ArrayDataType::kUint32: + return tensorflow::DT_UINT32; case ArrayDataType::kInt64: return tensorflow::DT_INT64; case ArrayDataType::kString: @@ -2438,6 +2441,9 @@ void AddPlaceholder(const std::string& name, ArrayDataType type, case ArrayDataType::kInt32: (*placeholder->mutable_attr())["dtype"].set_type(DT_INT32); break; + case ArrayDataType::kUint32: + (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT32); + break; case ArrayDataType::kInt64: (*placeholder->mutable_attr())["dtype"].set_type(DT_INT64); break; diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc index 2adfe838c3d..27e004751aa 100644 --- a/tensorflow/lite/toco/import_tensorflow.cc +++ b/tensorflow/lite/toco/import_tensorflow.cc @@ -57,6 +57,7 @@ using tensorflow::DT_INT32; using tensorflow::DT_INT64; using tensorflow::DT_QUINT8; using tensorflow::DT_STRING; +using tensorflow::DT_UINT32; using tensorflow::DT_UINT8; using tensorflow::GraphDef; using tensorflow::NodeDef; @@ -185,6 +186,8 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) { return ArrayDataType::kBool; else if (dtype == DT_INT32) return ArrayDataType::kInt32; + else if (dtype == DT_UINT32) + return ArrayDataType::kUint32; else if (dtype == DT_INT64) return ArrayDataType::kInt64; else if (dtype == DT_STRING) @@ -295,6 +298,18 @@ struct TensorTraits { } }; +template <> +struct TensorTraits { + static int size(const TensorProto& p) { return p.uint32_val_size(); } + static int32 get(const TensorProto& p, int i) { return p.uint32_val(i); } + static std::string accessor_name() { return "uint32_val"; } + static std::string type_name() { return "uint32"; } + static void CopyFromContent(const TensorProto& p, std::vector* data) { + toco::port::CopyToBuffer(p.tensor_content(), + reinterpret_cast(data->data())); + } +}; + template <> struct TensorTraits { static int size(const TensorProto& p) { return p.int64_val_size(); } @@ -432,6 +447,23 @@ tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, &output_int_data); } +tensorflow::Status ImportUint32Array(const TensorProto& input_tensor, + Array* output_array) { + CHECK_EQ(input_tensor.dtype(), DT_UINT32); + const auto& input_shape = input_tensor.tensor_shape(); + CHECK_LE(input_shape.dim_size(), 6); + int input_flat_size; + auto status = ImportShape(input_shape.dim(), &input_flat_size, + output_array->mutable_shape()); + if (!status.ok()) return status; + + auto& output_int_data = + output_array->GetMutableBuffer().data; + output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); + return ImportTensorData(input_tensor, input_flat_size, + &output_int_data); +} + tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT64); @@ -757,6 +789,10 @@ tensorflow::Status ConvertConstOperator( array.data_type = ArrayDataType::kInt32; status = ImportInt32Array(tensor, &array); break; + case DT_UINT32: + array.data_type = ArrayDataType::kUint32; + status = ImportUint32Array(tensor, &array); + break; case DT_QUINT8: array.data_type = ArrayDataType::kUint8; status = ImportQuint8Array(tensor, &array); @@ -1473,7 +1509,6 @@ tensorflow::Status ConditionallyConvertConstOperator( model); } } - switch (GetDataTypeAttr(node, "dtype")) { case DT_FLOAT: case DT_INT32: diff --git a/tensorflow/lite/toco/import_tensorflow_test.cc b/tensorflow/lite/toco/import_tensorflow_test.cc index 98ce18bf38e..ef5a077b766 100644 --- a/tensorflow/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/lite/toco/import_tensorflow_test.cc @@ -37,6 +37,7 @@ using tensorflow::DT_INT64; using tensorflow::DT_INVALID; using tensorflow::DT_QUINT8; using tensorflow::DT_STRING; +using tensorflow::DT_UINT32; using tensorflow::NodeDef; using tensorflow::Status; using ::testing::ElementsAre; @@ -127,6 +128,11 @@ void BuildConstNode(std::initializer_list shape, t.add_int_val(i % std::numeric_limits::max() + 1); } break; + case DT_UINT32: + for (int64_t i = 0; i < num_elements; ++i) { + t.add_int_val(i % std::numeric_limits::max() + 1); + } + break; case DT_QUINT8: for (int64_t i = 0; i < num_elements; ++i) { t.add_int_val(i % std::numeric_limits::max() + 1); diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 077206dcd20..cf76e626849 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -48,6 +48,7 @@ namespace tflite { {ArrayDataType::kUint8, ::tflite::TensorType_UINT8}, {ArrayDataType::kInt16, ::tflite::TensorType_INT16}, {ArrayDataType::kInt32, ::tflite::TensorType_INT32}, + {ArrayDataType::kUint32, ::tflite::TensorType_UINT32}, {ArrayDataType::kInt64, ::tflite::TensorType_INT64}, {ArrayDataType::kUint64, ::tflite::TensorType_UINT64}, {ArrayDataType::kString, ::tflite::TensorType_STRING}, diff --git a/tensorflow/lite/toco/tflite/types.cc b/tensorflow/lite/toco/tflite/types.cc index 9d4ab8434d1..d241b560e19 100644 --- a/tensorflow/lite/toco/tflite/types.cc +++ b/tensorflow/lite/toco/tflite/types.cc @@ -92,6 +92,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { return ::tflite::TensorType_INT16; case ArrayDataType::kInt32: return ::tflite::TensorType_INT32; + case ArrayDataType::kUint32: + return ::tflite::TensorType_UINT32; case ArrayDataType::kInt64: return ::tflite::TensorType_INT64; case ArrayDataType::kUint8: @@ -117,6 +119,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) { return ArrayDataType::kInt16; case ::tflite::TensorType_INT32: return ArrayDataType::kInt32; + case ::tflite::TensorType_UINT32: + return ArrayDataType::kUint32; case ::tflite::TensorType_INT64: return ArrayDataType::kInt64; case ::tflite::TensorType_STRING: @@ -143,6 +147,8 @@ flatbuffers::Offset> DataBuffer::Serialize( return CopyBuffer(array, builder); case ArrayDataType::kInt32: return CopyBuffer(array, builder); + case ArrayDataType::kUint32: + return CopyBuffer(array, builder); case ArrayDataType::kInt64: return CopyBuffer(array, builder); case ArrayDataType::kString: @@ -170,6 +176,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, return CopyBuffer(buffer, array); case ::tflite::TensorType_INT32: return CopyBuffer(buffer, array); + case ::tflite::TensorType_UINT32: + return CopyBuffer(buffer, array); case ::tflite::TensorType_INT64: return CopyBuffer(buffer, array); case ::tflite::TensorType_STRING: diff --git a/tensorflow/lite/toco/tflite/types_test.cc b/tensorflow/lite/toco/tflite/types_test.cc index efa2911b5b8..e1f4a65bc28 100644 --- a/tensorflow/lite/toco/tflite/types_test.cc +++ b/tensorflow/lite/toco/tflite/types_test.cc @@ -71,6 +71,7 @@ TEST(DataType, SupportedTypes) { std::vector> testdata = { {ArrayDataType::kUint8, ::tflite::TensorType_UINT8}, {ArrayDataType::kInt32, ::tflite::TensorType_INT32}, + {ArrayDataType::kUint32, ::tflite::TensorType_UINT32}, {ArrayDataType::kInt64, ::tflite::TensorType_INT64}, {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}, {ArrayDataType::kBool, ::tflite::TensorType_BOOL}, @@ -154,6 +155,12 @@ TEST(DataBuffer, Int32) { ::testing::ElementsAre(1, 1 << 30)); } +TEST(DataBuffer, Uint32) { + Array recovered = ToFlatBufferAndBack({1, 1U << 31}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(1, 1U << 31)); +} + TEST(DataBuffer, Int16) { Array recovered = ToFlatBufferAndBack({1, 1 << 14}); EXPECT_THAT(recovered.GetBuffer().data, diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc index b34f4922154..35a422948e9 100644 --- a/tensorflow/lite/toco/tooling_util.cc +++ b/tensorflow/lite/toco/tooling_util.cc @@ -2307,6 +2307,8 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) { return ArrayDataType::kInt16; case INT32: return ArrayDataType::kInt32; + case UINT32: + return ArrayDataType::kUint32; case INT64: return ArrayDataType::kInt64; case UINT64: diff --git a/tensorflow/lite/toco/types.proto b/tensorflow/lite/toco/types.proto index 45489984f14..7e886b49be4 100644 --- a/tensorflow/lite/toco/types.proto +++ b/tensorflow/lite/toco/types.proto @@ -64,4 +64,7 @@ enum IODataType { // Variant type VARIANT = 15; + + // Uint32 + UINT32 = 16; } diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 036af73ea09..cb08bf38168 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -485,6 +485,12 @@ BenchmarkTfLiteModel::CreateRandomTensorData(const TfLiteTensor& t, return CreateInputTensorData( num_elements, std::uniform_int_distribution(low, high)); } + case kTfLiteUInt32: { + int low = has_value_range ? low_range : 0; + int high = has_value_range ? high_range : 99; + return CreateInputTensorData( + num_elements, std::uniform_int_distribution(low, high)); + } case kTfLiteInt16: { int low = has_value_range ? low_range : 0; int high = has_value_range ? high_range : 99; diff --git a/tensorflow/lite/tools/serialization/enum_mapping.h b/tensorflow/lite/tools/serialization/enum_mapping.h index 721ce3b3c32..a21271aa6c7 100644 --- a/tensorflow/lite/tools/serialization/enum_mapping.h +++ b/tensorflow/lite/tools/serialization/enum_mapping.h @@ -68,6 +68,8 @@ inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { return TensorType_FLOAT64; case kTfLiteInt32: return TensorType_INT32; + case kTfLiteUInt32: + return TensorType_UINT32; case kTfLiteUInt8: return TensorType_UINT8; case kTfLiteInt8: diff --git a/tensorflow/lite/tools/verifier.cc b/tensorflow/lite/tools/verifier.cc index dcb154a7e0e..e23b36e4695 100644 --- a/tensorflow/lite/tools/verifier.cc +++ b/tensorflow/lite/tools/verifier.cc @@ -422,6 +422,9 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer, case TensorType_INT32: bytes_required *= sizeof(int32_t); break; + case TensorType_UINT32: + bytes_required *= sizeof(uint32_t); + break; case TensorType_UINT8: bytes_required *= sizeof(uint8_t); break; diff --git a/tensorflow/lite/type_to_tflitetype_test.cc b/tensorflow/lite/type_to_tflitetype_test.cc index da6d7a63cc7..30bc2e5860f 100644 --- a/tensorflow/lite/type_to_tflitetype_test.cc +++ b/tensorflow/lite/type_to_tflitetype_test.cc @@ -30,6 +30,8 @@ TEST(TypeToTfLiteType, TypeMapsAreInverseOfEachOther) { typeToTfLiteType::Type>()); EXPECT_EQ(kTfLiteInt32, typeToTfLiteType::Type>()); + EXPECT_EQ(kTfLiteUInt32, + typeToTfLiteType::Type>()); EXPECT_EQ(kTfLiteFloat32, typeToTfLiteType::Type>()); EXPECT_EQ(kTfLiteUInt8, diff --git a/tensorflow/lite/util.cc b/tensorflow/lite/util.cc index 1395a4ea239..995d52bee9b 100644 --- a/tensorflow/lite/util.cc +++ b/tensorflow/lite/util.cc @@ -96,7 +96,10 @@ TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type, *bytes = sizeof(float); break; case kTfLiteInt32: - *bytes = sizeof(int); + *bytes = sizeof(int32_t); + break; + case kTfLiteUInt32: + *bytes = sizeof(uint32_t); break; case kTfLiteUInt8: *bytes = sizeof(uint8_t);