Merge pull request #27881 from ANSHUMAN87:unpack-quant-support
PiperOrigin-RevId: 247638716
This commit is contained in:
commit
d38eaa19b3
tensorflow/lite
kernels
toco
@ -369,7 +369,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
|
||||
AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
|
||||
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
|
||||
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
|
||||
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
|
||||
AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
|
||||
AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE());
|
||||
|
@ -42,9 +42,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
axis += NumDimensions(input);
|
||||
}
|
||||
TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
|
||||
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32) {
|
||||
context->ReportError(context,
|
||||
"Currently pack only supports int32 and float32.");
|
||||
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
|
||||
input->type != kTfLiteUInt8 && input->type != kTfLiteInt8) {
|
||||
context->ReportError(context, "Type '%s' is not supported by unpack.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
@ -64,6 +65,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
|
||||
TfLiteTensor* output = GetOutput(context, node, i);
|
||||
TF_LITE_ENSURE_EQ(context, output->type, input->type);
|
||||
// Guarantee input/output quantization params match as we do not support
|
||||
// rescaling of unpacked quantized tensors.
|
||||
TF_LITE_ENSURE_EQ(context, input->params.zero_point,
|
||||
output->params.zero_point);
|
||||
TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, context->ResizeTensor(context, output, copied_output_shape));
|
||||
}
|
||||
@ -98,9 +104,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
UnpackImpl<int32_t>(context, node, input, data->num, data->axis);
|
||||
break;
|
||||
}
|
||||
case kTfLiteUInt8: {
|
||||
UnpackImpl<uint8_t>(context, node, input, data->num, data->axis);
|
||||
break;
|
||||
}
|
||||
case kTfLiteInt8: {
|
||||
UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
context->ReportError(context,
|
||||
"Currently pack only supports int32 and float32.");
|
||||
context->ReportError(context, "Type '%s' is not supported by unpack.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
@ -159,6 +159,104 @@ TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
|
||||
/*type=*/TensorType_INT32);
|
||||
}
|
||||
|
||||
// uint8 tests.
|
||||
TEST(UnpackOpTest, Uint8ThreeOutputs) {
|
||||
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||
/*expected_output_shape=*/{{2}, {2}, {2}},
|
||||
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
||||
/*type=*/TensorType_UINT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Uint8ThreeOutputsAxisOne) {
|
||||
Check<uint8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||
/*expected_output_shape=*/{{3}, {3}},
|
||||
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
||||
/*type=*/TensorType_UINT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisOne) {
|
||||
Check<uint8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||
/*expected_output_shape=*/{{3}, {3}},
|
||||
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
||||
/*type=*/TensorType_UINT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisTwo) {
|
||||
Check<uint8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||
/*expected_output_shape=*/{{2}, {2}, {2}},
|
||||
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
||||
/*type=*/TensorType_UINT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Uint8OneOutput) {
|
||||
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||
/*expected_output_shape=*/{{6}},
|
||||
/*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
|
||||
/*type=*/TensorType_UINT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) {
|
||||
Check<uint8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
/*expected_output_shape=*/{{2, 2}, {2, 2}},
|
||||
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
|
||||
/*type=*/TensorType_UINT8);
|
||||
}
|
||||
|
||||
// int8 tests.
|
||||
TEST(UnpackOpTest, Int8ThreeOutputs) {
|
||||
Check<int8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||
/*expected_output_shape=*/{{2}, {2}, {2}},
|
||||
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
||||
/*type=*/TensorType_INT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Int8ThreeOutputsAxisOne) {
|
||||
Check<int8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||
/*expected_output_shape=*/{{3}, {3}},
|
||||
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
||||
/*type=*/TensorType_INT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisOne) {
|
||||
Check<int8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||
/*expected_output_shape=*/{{3}, {3}},
|
||||
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
||||
/*type=*/TensorType_INT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisTwo) {
|
||||
Check<int8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||
/*expected_output_shape=*/{{2}, {2}, {2}},
|
||||
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
||||
/*type=*/TensorType_INT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Int8OneOutput) {
|
||||
Check<int8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||
/*expected_output_shape=*/{{6}},
|
||||
/*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
|
||||
/*type=*/TensorType_INT8);
|
||||
}
|
||||
|
||||
TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) {
|
||||
Check<int8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
|
||||
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
/*expected_output_shape=*/{{2, 2}, {2, 2}},
|
||||
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
|
||||
/*type=*/TensorType_INT8);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -64,7 +64,7 @@ bool SupportsQuantization(const Operator& op) {
|
||||
type == OperatorType::kRelu1 || type == OperatorType::kRelu6 ||
|
||||
type == OperatorType::kLeakyRelu || type == OperatorType::kShape ||
|
||||
type == OperatorType::kExpandDims || type == OperatorType::kPack ||
|
||||
type == OperatorType::kTopK_V2 ||
|
||||
type == OperatorType::kUnpack || type == OperatorType::kTopK_V2 ||
|
||||
type == OperatorType::kRandomUniform ||
|
||||
type == OperatorType::kResizeNearestNeighbor ||
|
||||
type == OperatorType::kPRelu || type == OperatorType::kReduceMax ||
|
||||
|
@ -1806,6 +1806,13 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
|
||||
}
|
||||
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const string& input_name = op_signature.op->inputs[0];
|
||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||
// If the op take int8/uint8 input, it is version 2.
|
||||
if (input_array.data_type == ArrayDataType::kInt8 ||
|
||||
input_array.data_type == ArrayDataType::kUint8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
@ -818,6 +818,31 @@ TEST_F(OperatorTest, VersioningPackTest) {
|
||||
SimpleVersioningTest<PackOperator>();
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, VersioningUnpackTest) {
|
||||
UnpackOperator op;
|
||||
op.inputs = {"input1"};
|
||||
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
||||
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
||||
|
||||
Model int32_model;
|
||||
Array& int32_array = int32_model.GetOrCreateArray(op.inputs[0]);
|
||||
int32_array.data_type = ArrayDataType::kInt32;
|
||||
OperatorSignature int32_signature = {.op = &op, .model = &int32_model};
|
||||
EXPECT_EQ(base_op->GetVersion(int32_signature), 1);
|
||||
|
||||
Model uint8_model;
|
||||
Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]);
|
||||
uint8_array.data_type = ArrayDataType::kUint8;
|
||||
OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
|
||||
EXPECT_EQ(base_op->GetVersion(uint8_signature), 2);
|
||||
|
||||
Model int8_model;
|
||||
Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]);
|
||||
int8_array.data_type = ArrayDataType::kInt8;
|
||||
OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
|
||||
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, VersioningBatchToSpaceNDTest) {
|
||||
SimpleVersioningTest<BatchToSpaceNDOperator>();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user