Merge pull request from ANSHUMAN87:unpack-quant-support

PiperOrigin-RevId: 247638716
TensorFlower Gardener 2019-05-10 13:57:12 -07:00
commit d38eaa19b3
6 changed files with 153 additions and 7 deletions
tensorflow/lite

@@ -369,7 +369,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
   AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
   AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
-  AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
+  AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK(),
+             /* min_version */ 1,
+             /* max_version */ 2);
   AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
   AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
   AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE());
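Clients that build their own resolver instead of BuiltinOpResolver would need to declare the same version range for version-2 (quantized) UNPACK models to resolve. A minimal sketch, assuming the public MutableOpResolver API and forward-declaring the builtin registration symbol (not part of this change):

// Sketch only (assumed usage): register UNPACK on a custom resolver with the
// same min/max version range as above.
#include "tensorflow/lite/mutable_op_resolver.h"

namespace tflite {
namespace ops {
namespace builtin {
TfLiteRegistration* Register_UNPACK();  // Provided by the builtin kernel library.
}  // namespace builtin
}  // namespace ops
}  // namespace tflite

tflite::MutableOpResolver MakeUnpackResolver() {
  tflite::MutableOpResolver resolver;
  resolver.AddBuiltin(tflite::BuiltinOperator_UNPACK,
                      tflite::ops::builtin::Register_UNPACK(),
                      /* min_version */ 1,
                      /* max_version */ 2);
  return resolver;
}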

@@ -42,9 +42,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     axis += NumDimensions(input);
   }
   TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
-  if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32) {
-    context->ReportError(context,
-                         "Currently pack only supports int32 and float32.");
+  if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
+      input->type != kTfLiteUInt8 && input->type != kTfLiteInt8) {
+    context->ReportError(context, "Type '%s' is not supported by unpack.",
+                         TfLiteTypeGetName(input->type));
     return kTfLiteError;
   }
@@ -64,6 +65,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
     TfLiteTensor* output = GetOutput(context, node, i);
     TF_LITE_ENSURE_EQ(context, output->type, input->type);
+    // Guarantee input/output quantization params match as we do not support
+    // rescaling of unpacked quantized tensors.
+    TF_LITE_ENSURE_EQ(context, input->params.zero_point,
+                      output->params.zero_point);
+    TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
     TF_LITE_ENSURE_OK(
         context, context->ResizeTensor(context, output, copied_output_shape));
   }
@@ -98,9 +104,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       UnpackImpl<int32_t>(context, node, input, data->num, data->axis);
      break;
     }
+    case kTfLiteUInt8: {
+      UnpackImpl<uint8_t>(context, node, input, data->num, data->axis);
+      break;
+    }
+    case kTfLiteInt8: {
+      UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
+      break;
+    }
     default: {
-      context->ReportError(context,
-                           "Currently pack only supports int32 and float32.");
+      context->ReportError(context, "Type '%s' is not supported by unpack.",
+                           TfLiteTypeGetName(input->type));
       return kTfLiteError;
     }
   }
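Because Prepare() above pins each output's zero_point and scale to the input's, the new uint8/int8 paths are pure data movement and never requantize. A standalone sketch of that idea, with simplified {outer, num, inner} shapes and a hypothetical helper name (not the real UnpackImpl):

#include <cstdint>
#include <cstring>
#include <vector>

// Splits `input` of logical shape {outer, num, inner} along the middle axis
// into `num` outputs of shape {outer, inner}. Raw uint8 values are copied
// verbatim; with matching scale/zero_point this preserves the represented
// real values, so no requantization step is needed.
void UnpackUint8Sketch(const uint8_t* input, int outer, int num, int inner,
                       std::vector<std::vector<uint8_t>>* outputs) {
  outputs->assign(num, std::vector<uint8_t>(outer * inner));
  for (int o = 0; o < outer; ++o) {
    for (int n = 0; n < num; ++n) {
      std::memcpy((*outputs)[n].data() + o * inner,
                  input + (o * num + n) * inner, inner * sizeof(uint8_t));
    }
  }
}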

@@ -159,6 +159,104 @@ TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
/*type=*/TensorType_INT32);
}
// uint8 tests.
TEST(UnpackOpTest, Uint8ThreeOutputs) {
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
/*expected_output_shape=*/{{2}, {2}, {2}},
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
/*type=*/TensorType_UINT8);
}
TEST(UnpackOpTest, Uint8ThreeOutputsAxisOne) {
Check<uint8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
/*expected_output_shape=*/{{3}, {3}},
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
/*type=*/TensorType_UINT8);
}
TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisOne) {
Check<uint8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
/*expected_output_shape=*/{{3}, {3}},
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
/*type=*/TensorType_UINT8);
}
TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisTwo) {
Check<uint8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
/*expected_output_shape=*/{{2}, {2}, {2}},
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
/*type=*/TensorType_UINT8);
}
TEST(UnpackOpTest, Uint8OneOutput) {
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
/*input_data=*/{1, 2, 3, 4, 5, 6},
/*expected_output_shape=*/{{6}},
/*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
/*type=*/TensorType_UINT8);
}
TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) {
Check<uint8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
/*expected_output_shape=*/{{2, 2}, {2, 2}},
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
/*type=*/TensorType_UINT8);
}
// int8 tests.
TEST(UnpackOpTest, Int8ThreeOutputs) {
Check<int8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
/*expected_output_shape=*/{{2}, {2}, {2}},
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
/*type=*/TensorType_INT8);
}
TEST(UnpackOpTest, Int8ThreeOutputsAxisOne) {
Check<int8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
/*expected_output_shape=*/{{3}, {3}},
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
/*type=*/TensorType_INT8);
}
TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisOne) {
Check<int8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
/*expected_output_shape=*/{{3}, {3}},
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
/*type=*/TensorType_INT8);
}
TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisTwo) {
Check<int8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6},
/*expected_output_shape=*/{{2}, {2}, {2}},
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
/*type=*/TensorType_INT8);
}
TEST(UnpackOpTest, Int8OneOutput) {
Check<int8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
/*input_data=*/{1, 2, 3, 4, 5, 6},
/*expected_output_shape=*/{{6}},
/*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
/*type=*/TensorType_INT8);
}
TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) {
Check<int8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
/*expected_output_shape=*/{{2, 2}, {2, 2}},
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
/*type=*/TensorType_INT8);
}
} // namespace
} // namespace tflite

@@ -64,7 +64,7 @@ bool SupportsQuantization(const Operator& op) {
          type == OperatorType::kRelu1 || type == OperatorType::kRelu6 ||
          type == OperatorType::kLeakyRelu || type == OperatorType::kShape ||
          type == OperatorType::kExpandDims || type == OperatorType::kPack ||
-         type == OperatorType::kTopK_V2 ||
+         type == OperatorType::kUnpack || type == OperatorType::kTopK_V2 ||
          type == OperatorType::kRandomUniform ||
          type == OperatorType::kResizeNearestNeighbor ||
          type == OperatorType::kPRelu || type == OperatorType::kReduceMax ||

@@ -1806,6 +1806,13 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
}
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
// If the op takes int8/uint8 input, it is version 2.
if (input_array.data_type == ArrayDataType::kInt8 ||
input_array.data_type == ArrayDataType::kUint8) {
return 2;
}
return 1;
}
};

@@ -818,6 +818,31 @@ TEST_F(OperatorTest, VersioningPackTest) {
SimpleVersioningTest<PackOperator>();
}
TEST_F(OperatorTest, VersioningUnpackTest) {
UnpackOperator op;
op.inputs = {"input1"};
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
Model int32_model;
Array& int32_array = int32_model.GetOrCreateArray(op.inputs[0]);
int32_array.data_type = ArrayDataType::kInt32;
OperatorSignature int32_signature = {.op = &op, .model = &int32_model};
EXPECT_EQ(base_op->GetVersion(int32_signature), 1);
Model uint8_model;
Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]);
uint8_array.data_type = ArrayDataType::kUint8;
OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
EXPECT_EQ(base_op->GetVersion(uint8_signature), 2);
Model int8_model;
Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]);
int8_array.data_type = ArrayDataType::kInt8;
OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
}
TEST_F(OperatorTest, VersioningBatchToSpaceNDTest) {
SimpleVersioningTest<BatchToSpaceNDOperator>();
}