Merge pull request #27881 from ANSHUMAN87:unpack-quant-support
PiperOrigin-RevId: 247638716
This commit is contained in:
commit
d38eaa19b3
@ -369,7 +369,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
|
AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
|
||||||
AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
|
AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
|
||||||
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
|
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
|
||||||
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
|
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK(),
|
||||||
|
/* min_version */ 1,
|
||||||
|
/* max_version */ 2);
|
||||||
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
|
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
|
||||||
AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
|
AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
|
||||||
AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE());
|
AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE());
|
||||||
|
@ -42,9 +42,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
axis += NumDimensions(input);
|
axis += NumDimensions(input);
|
||||||
}
|
}
|
||||||
TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
|
TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
|
||||||
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32) {
|
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
|
||||||
context->ReportError(context,
|
input->type != kTfLiteUInt8 && input->type != kTfLiteInt8) {
|
||||||
"Currently pack only supports int32 and float32.");
|
context->ReportError(context, "Type '%s' is not supported by unpack.",
|
||||||
|
TfLiteTypeGetName(input->type));
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,6 +65,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
|
TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
|
||||||
TfLiteTensor* output = GetOutput(context, node, i);
|
TfLiteTensor* output = GetOutput(context, node, i);
|
||||||
TF_LITE_ENSURE_EQ(context, output->type, input->type);
|
TF_LITE_ENSURE_EQ(context, output->type, input->type);
|
||||||
|
// Guarantee input/output quantization params match as we do not support
|
||||||
|
// rescaling of unpacked quantized tensors.
|
||||||
|
TF_LITE_ENSURE_EQ(context, input->params.zero_point,
|
||||||
|
output->params.zero_point);
|
||||||
|
TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
|
||||||
TF_LITE_ENSURE_OK(
|
TF_LITE_ENSURE_OK(
|
||||||
context, context->ResizeTensor(context, output, copied_output_shape));
|
context, context->ResizeTensor(context, output, copied_output_shape));
|
||||||
}
|
}
|
||||||
@ -98,9 +104,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
UnpackImpl<int32_t>(context, node, input, data->num, data->axis);
|
UnpackImpl<int32_t>(context, node, input, data->num, data->axis);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case kTfLiteUInt8: {
|
||||||
|
UnpackImpl<uint8_t>(context, node, input, data->num, data->axis);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case kTfLiteInt8: {
|
||||||
|
UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
context->ReportError(context,
|
context->ReportError(context, "Type '%s' is not supported by unpack.",
|
||||||
"Currently pack only supports int32 and float32.");
|
TfLiteTypeGetName(input->type));
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -159,6 +159,104 @@ TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
|
|||||||
/*type=*/TensorType_INT32);
|
/*type=*/TensorType_INT32);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// uint8 tests.
|
||||||
|
TEST(UnpackOpTest, Uint8ThreeOutputs) {
|
||||||
|
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
|
/*expected_output_shape=*/{{2}, {2}, {2}},
|
||||||
|
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
||||||
|
/*type=*/TensorType_UINT8);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, Uint8ThreeOutputsAxisOne) {
|
||||||
|
Check<uint8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
|
/*expected_output_shape=*/{{3}, {3}},
|
||||||
|
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
||||||
|
/*type=*/TensorType_UINT8);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisOne) {
|
||||||
|
Check<uint8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
|
/*expected_output_shape=*/{{3}, {3}},
|
||||||
|
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
||||||
|
/*type=*/TensorType_UINT8);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisTwo) {
|
||||||
|
Check<uint8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
|
/*expected_output_shape=*/{{2}, {2}, {2}},
|
||||||
|
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
||||||
|
/*type=*/TensorType_UINT8);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, Uint8OneOutput) {
|
||||||
|
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
|
/*expected_output_shape=*/{{6}},
|
||||||
|
/*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
|
||||||
|
/*type=*/TensorType_UINT8);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) {
|
||||||
|
Check<uint8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
|
/*expected_output_shape=*/{{2, 2}, {2, 2}},
|
||||||
|
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
|
||||||
|
/*type=*/TensorType_UINT8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// int8 tests.
|
||||||
|
TEST(UnpackOpTest, Int8ThreeOutputs) {
|
||||||
|
Check<int8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
|
/*expected_output_shape=*/{{2}, {2}, {2}},
|
||||||
|
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
||||||
|
/*type=*/TensorType_INT8);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, Int8ThreeOutputsAxisOne) {
|
||||||
|
Check<int8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
|
/*expected_output_shape=*/{{3}, {3}},
|
||||||
|
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
||||||
|
/*type=*/TensorType_INT8);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisOne) {
|
||||||
|
Check<int8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
|
/*expected_output_shape=*/{{3}, {3}},
|
||||||
|
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
||||||
|
/*type=*/TensorType_INT8);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisTwo) {
|
||||||
|
Check<int8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
|
/*expected_output_shape=*/{{2}, {2}, {2}},
|
||||||
|
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
||||||
|
/*type=*/TensorType_INT8);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, Int8OneOutput) {
|
||||||
|
Check<int8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
|
/*expected_output_shape=*/{{6}},
|
||||||
|
/*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
|
||||||
|
/*type=*/TensorType_INT8);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) {
|
||||||
|
Check<int8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
|
||||||
|
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
|
/*expected_output_shape=*/{{2, 2}, {2, 2}},
|
||||||
|
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
|
||||||
|
/*type=*/TensorType_INT8);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ bool SupportsQuantization(const Operator& op) {
|
|||||||
type == OperatorType::kRelu1 || type == OperatorType::kRelu6 ||
|
type == OperatorType::kRelu1 || type == OperatorType::kRelu6 ||
|
||||||
type == OperatorType::kLeakyRelu || type == OperatorType::kShape ||
|
type == OperatorType::kLeakyRelu || type == OperatorType::kShape ||
|
||||||
type == OperatorType::kExpandDims || type == OperatorType::kPack ||
|
type == OperatorType::kExpandDims || type == OperatorType::kPack ||
|
||||||
type == OperatorType::kTopK_V2 ||
|
type == OperatorType::kUnpack || type == OperatorType::kTopK_V2 ||
|
||||||
type == OperatorType::kRandomUniform ||
|
type == OperatorType::kRandomUniform ||
|
||||||
type == OperatorType::kResizeNearestNeighbor ||
|
type == OperatorType::kResizeNearestNeighbor ||
|
||||||
type == OperatorType::kPRelu || type == OperatorType::kReduceMax ||
|
type == OperatorType::kPRelu || type == OperatorType::kReduceMax ||
|
||||||
|
@ -1806,6 +1806,13 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
|
|||||||
}
|
}
|
||||||
|
|
||||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||||
|
const string& input_name = op_signature.op->inputs[0];
|
||||||
|
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||||
|
// If the op take int8/uint8 input, it is version 2.
|
||||||
|
if (input_array.data_type == ArrayDataType::kInt8 ||
|
||||||
|
input_array.data_type == ArrayDataType::kUint8) {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -818,6 +818,31 @@ TEST_F(OperatorTest, VersioningPackTest) {
|
|||||||
SimpleVersioningTest<PackOperator>();
|
SimpleVersioningTest<PackOperator>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OperatorTest, VersioningUnpackTest) {
|
||||||
|
UnpackOperator op;
|
||||||
|
op.inputs = {"input1"};
|
||||||
|
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
|
||||||
|
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
|
||||||
|
|
||||||
|
Model int32_model;
|
||||||
|
Array& int32_array = int32_model.GetOrCreateArray(op.inputs[0]);
|
||||||
|
int32_array.data_type = ArrayDataType::kInt32;
|
||||||
|
OperatorSignature int32_signature = {.op = &op, .model = &int32_model};
|
||||||
|
EXPECT_EQ(base_op->GetVersion(int32_signature), 1);
|
||||||
|
|
||||||
|
Model uint8_model;
|
||||||
|
Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]);
|
||||||
|
uint8_array.data_type = ArrayDataType::kUint8;
|
||||||
|
OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
|
||||||
|
EXPECT_EQ(base_op->GetVersion(uint8_signature), 2);
|
||||||
|
|
||||||
|
Model int8_model;
|
||||||
|
Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]);
|
||||||
|
int8_array.data_type = ArrayDataType::kInt8;
|
||||||
|
OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
|
||||||
|
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(OperatorTest, VersioningBatchToSpaceNDTest) {
|
TEST_F(OperatorTest, VersioningBatchToSpaceNDTest) {
|
||||||
SimpleVersioningTest<BatchToSpaceNDOperator>();
|
SimpleVersioningTest<BatchToSpaceNDOperator>();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user