diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index fb9807b7fa9..d8a6d5d3051 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -369,7 +369,9 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR()); AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND()); AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT()); - AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK()); + AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV()); AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE()); AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE()); diff --git a/tensorflow/lite/kernels/unpack.cc b/tensorflow/lite/kernels/unpack.cc index eed69ee7e53..3af2e969a7b 100644 --- a/tensorflow/lite/kernels/unpack.cc +++ b/tensorflow/lite/kernels/unpack.cc @@ -42,9 +42,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { axis += NumDimensions(input); } TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input)); - if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32) { - context->ReportError(context, - "Currently pack only supports int32 and float32."); + if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 && + input->type != kTfLiteUInt8 && input->type != kTfLiteInt8) { + context->ReportError(context, "Type '%s' is not supported by unpack.", + TfLiteTypeGetName(input->type)); return kTfLiteError; } @@ -64,6 +65,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape); TfLiteTensor* output = GetOutput(context, node, i); TF_LITE_ENSURE_EQ(context, output->type, input->type); + // Guarantee input/output quantization params match as we do not support + // rescaling of unpacked quantized tensors. 
+ TF_LITE_ENSURE_EQ(context, input->params.zero_point, + output->params.zero_point); + TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale); TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, output, copied_output_shape)); } @@ -98,9 +104,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { UnpackImpl<int32_t>(context, node, input, data->num, data->axis); break; } + case kTfLiteUInt8: { + UnpackImpl<uint8_t>(context, node, input, data->num, data->axis); + break; + } + case kTfLiteInt8: { + UnpackImpl<int8_t>(context, node, input, data->num, data->axis); + break; + } default: { - context->ReportError(context, - "Currently pack only supports int32 and float32."); + context->ReportError(context, "Type '%s' is not supported by unpack.", + TfLiteTypeGetName(input->type)); return kTfLiteError; } } diff --git a/tensorflow/lite/kernels/unpack_test.cc b/tensorflow/lite/kernels/unpack_test.cc index 365970d683e..487fc95ea88 100644 --- a/tensorflow/lite/kernels/unpack_test.cc +++ b/tensorflow/lite/kernels/unpack_test.cc @@ -159,6 +159,104 @@ TEST(UnpackOpTest, IntThreeDimensionsOutputs) { /*type=*/TensorType_INT32); } +// uint8 tests. 
+TEST(UnpackOpTest, Uint8ThreeOutputs) { + Check<uint8_t>(/*axis=*/0, /*input_shape=*/{3, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*expected_output_shape=*/{{2}, {2}, {2}}, + /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}}, + /*type=*/TensorType_UINT8); +} + +TEST(UnpackOpTest, Uint8ThreeOutputsAxisOne) { + Check<uint8_t>(/*axis=*/1, /*input_shape=*/{3, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*expected_output_shape=*/{{3}, {3}}, + /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}}, + /*type=*/TensorType_UINT8); +} + +TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisOne) { + Check<uint8_t>(/*axis=*/-1, /*input_shape=*/{3, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*expected_output_shape=*/{{3}, {3}}, + /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}}, + /*type=*/TensorType_UINT8); +} + +TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisTwo) { + Check<uint8_t>(/*axis=*/-2, /*input_shape=*/{3, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*expected_output_shape=*/{{2}, {2}, {2}}, + /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}}, + /*type=*/TensorType_UINT8); +} + +TEST(UnpackOpTest, Uint8OneOutput) { + Check<uint8_t>(/*axis=*/0, /*input_shape=*/{1, 6}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*expected_output_shape=*/{{6}}, + /*expected_output_data=*/{{1, 2, 3, 4, 5, 6}}, + /*type=*/TensorType_UINT8); +} + +TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) { + Check<uint8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8}, + /*expected_output_shape=*/{{2, 2}, {2, 2}}, + /*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}}, + /*type=*/TensorType_UINT8); +} + +// int8 tests. 
+TEST(UnpackOpTest, Int8ThreeOutputs) { + Check<int8_t>(/*axis=*/0, /*input_shape=*/{3, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*expected_output_shape=*/{{2}, {2}, {2}}, + /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}}, + /*type=*/TensorType_INT8); +} + +TEST(UnpackOpTest, Int8ThreeOutputsAxisOne) { + Check<int8_t>(/*axis=*/1, /*input_shape=*/{3, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*expected_output_shape=*/{{3}, {3}}, + /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}}, + /*type=*/TensorType_INT8); +} + +TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisOne) { + Check<int8_t>(/*axis=*/-1, /*input_shape=*/{3, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*expected_output_shape=*/{{3}, {3}}, + /*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}}, + /*type=*/TensorType_INT8); +} + +TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisTwo) { + Check<int8_t>(/*axis=*/-2, /*input_shape=*/{3, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*expected_output_shape=*/{{2}, {2}, {2}}, + /*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}}, + /*type=*/TensorType_INT8); +} + +TEST(UnpackOpTest, Int8OneOutput) { + Check<int8_t>(/*axis=*/0, /*input_shape=*/{1, 6}, + /*input_data=*/{1, 2, 3, 4, 5, 6}, + /*expected_output_shape=*/{{6}}, + /*expected_output_data=*/{{1, 2, 3, 4, 5, 6}}, + /*type=*/TensorType_INT8); +} + +TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) { + Check<int8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2}, + /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8}, + /*expected_output_shape=*/{{2, 2}, {2, 2}}, + /*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}}, + /*type=*/TensorType_INT8); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/toco/graph_transformations/quantize.cc b/tensorflow/lite/toco/graph_transformations/quantize.cc index 03cb6597738..62eaba0d756 100644 --- a/tensorflow/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/lite/toco/graph_transformations/quantize.cc @@ -64,7 +64,7 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kRelu1 || type == OperatorType::kRelu6 
|| type == OperatorType::kLeakyRelu || type == OperatorType::kShape || type == OperatorType::kExpandDims || type == OperatorType::kPack || - type == OperatorType::kTopK_V2 || + type == OperatorType::kUnpack || type == OperatorType::kTopK_V2 || type == OperatorType::kRandomUniform || type == OperatorType::kResizeNearestNeighbor || type == OperatorType::kPRelu || type == OperatorType::kReduceMax || diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 15c4d7457b1..09fe72f1ae0 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -1806,6 +1806,13 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions, ::tflite::BuiltinOptions_UnpackOptions> { int GetVersion(const OperatorSignature& op_signature) const override { const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op takes int8/uint8 input, it is version 2. + if (input_array.data_type == ArrayDataType::kInt8 || + input_array.data_type == ArrayDataType::kUint8) { + return 2; + } return 1; } }; diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc index 937b69e331e..eece77327cb 100644 --- a/tensorflow/lite/toco/tflite/operator_test.cc +++ b/tensorflow/lite/toco/tflite/operator_test.cc @@ -818,6 +818,31 @@ TEST_F(OperatorTest, VersioningPackTest) { SimpleVersioningTest<PackOperator>(); } +TEST_F(OperatorTest, VersioningUnpackTest) { + UnpackOperator op; + op.inputs = {"input1"}; + auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/); + const BaseOperator* base_op = operator_by_type_map.at(op.type).get(); + + Model int32_model; + Array& int32_array = int32_model.GetOrCreateArray(op.inputs[0]); + int32_array.data_type = ArrayDataType::kInt32; + OperatorSignature int32_signature = {.op = &op, .model = &int32_model}; + EXPECT_EQ(base_op->GetVersion(int32_signature), 1); + + Model uint8_model; + Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]); + uint8_array.data_type = ArrayDataType::kUint8; + OperatorSignature uint8_signature = {.op = &op, 
.model = &uint8_model}; + EXPECT_EQ(base_op->GetVersion(uint8_signature), 2); + + Model int8_model; + Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]); + int8_array.data_type = ArrayDataType::kInt8; + OperatorSignature int8_signature = {.op = &op, .model = &int8_model}; + EXPECT_EQ(base_op->GetVersion(int8_signature), 2); +} + TEST_F(OperatorTest, VersioningBatchToSpaceNDTest) { SimpleVersioningTest<BatchToSpaceNDOperator>(); }