Create int8 sum Op.

PiperOrigin-RevId: 254397376
This commit is contained in:
Jian Li 2019-06-21 07:54:28 -07:00 committed by TensorFlower Gardener
parent a858d3ae90
commit 69ba31ace4
7 changed files with 79 additions and 21 deletions

View File

@ -3371,8 +3371,14 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point,
const int num_axis_dimensions, bool keep_dims,
int* temp_index, int* resolved_axis, U* temp_sum,
bool compute_sum) {
gemmlowp::ScopedProfilingLabel label(compute_sum ? "Sum/Uint8"
: "Mean/Uint8");
const bool uint8_case = std::is_same<T, int8_t>::value;
if (uint8_case) {
gemmlowp::ScopedProfilingLabel label(compute_sum ? "Sum/Uint8"
: "Mean/Uint8");
} else {
gemmlowp::ScopedProfilingLabel label(compute_sum ? "Sum/Int8"
: "Mean/Int8");
}
// Reset output data.
size_t num_outputs = 1;
for (int idx = 0; idx < output_num_dims; ++idx) {

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <cstdint>
#include <limits>
#include <vector>
@ -524,11 +525,13 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
const auto& input = op_context.input;
const auto& output = op_context.output;
if (input->type != kTfLiteUInt8 ||
const bool same_scale =
(input->params.scale == output->params.scale &&
input->params.zero_point == output->params.zero_point)) {
return EvalGeneric<kReference, kSum>(context, node);
} else {
input->params.zero_point == output->params.zero_point);
const bool eight_bit_quantized =
input->type == kTfLiteUInt8 || input->type == kTfLiteInt8;
const bool need_rescale = (eight_bit_quantized && !same_scale);
if (need_rescale) {
// Rescaling 8bit reduce sum.
int num_axis = static_cast<int>(NumElements(op_context.axis));
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
@ -541,20 +544,42 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
}
TF_LITE_ENSURE(
context,
reference_ops::QuantizedMeanOrSum<>(
GetTensorData<uint8_t>(op_context.input),
op_context.input->params.zero_point, op_context.input->params.scale,
op_context.input->dims->data, op_context.input->dims->size,
GetTensorData<uint8_t>(op_context.output),
op_context.output->params.zero_point,
op_context.output->params.scale, op_context.output->dims->data,
op_context.output->dims->size, GetTensorData<int>(op_context.axis),
num_axis, op_context.params->keep_dims,
GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
GetTensorData<int32>(temp_sum), /*compute_sum=*/true));
if (input->type == kTfLiteUInt8) {
TF_LITE_ENSURE(
context,
reference_ops::QuantizedMeanOrSum<>(
GetTensorData<uint8_t>(op_context.input),
op_context.input->params.zero_point,
op_context.input->params.scale, op_context.input->dims->data,
op_context.input->dims->size,
GetTensorData<uint8_t>(op_context.output),
op_context.output->params.zero_point,
op_context.output->params.scale, op_context.output->dims->data,
op_context.output->dims->size,
GetTensorData<int>(op_context.axis), num_axis,
op_context.params->keep_dims, GetTensorData<int>(temp_index),
GetTensorData<int>(resolved_axis), GetTensorData<int32>(temp_sum),
/*compute_sum=*/true));
}
if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE(
context,
reference_ops::QuantizedMeanOrSum<>(
GetTensorData<int8_t>(op_context.input),
op_context.input->params.zero_point,
op_context.input->params.scale, op_context.input->dims->data,
op_context.input->dims->size,
GetTensorData<int8_t>(op_context.output),
op_context.output->params.zero_point,
op_context.output->params.scale, op_context.output->dims->data,
op_context.output->dims->size,
GetTensorData<int>(op_context.axis), num_axis,
op_context.params->keep_dims, GetTensorData<int>(temp_index),
GetTensorData<int>(resolved_axis), GetTensorData<int32>(temp_sum),
/*compute_sum=*/true));
}
} else {
return EvalGeneric<kReference, kSum>(context, node);
}
return kTfLiteOk;

View File

@ -642,6 +642,21 @@ TEST(DynamicUint8SumOpTest, KeepDims) {
ElementsAreArray(ArrayFloatNear({6.47059, 10.698}, kQuantizedTolerance)));
}
TEST(ConstInt8SumOpTest, Rescale) {
  // Input and output deliberately use different quantization ranges so the
  // op takes the rescaling (QuantizedMeanOrSum) path instead of EvalGeneric.
  const std::vector<float> input_values = {0.4, 0.2, 0.3, 0.4, 0.5, 0.3};
  // Reducing the {1, 3, 2} input over axis 1 gives:
  //   0.4 + 0.3 + 0.5 = 1.2  and  0.2 + 0.4 + 0.3 = 0.9.
  const std::vector<float> expected_values = {1.2, 0.9};
  const float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
  SumOpConstModel m({TensorType_INT8, {1, 3, 2}, -1.0, 1.0},
                    {TensorType_INT8, {2}, -5.0, 5.0}, {1}, {1}, false);
  m.QuantizeAndPopulate<int8_t>(m.Input(), input_values);
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
  EXPECT_THAT(
      m.GetDequantizedOutput<int8_t>(),
      ElementsAreArray(ArrayFloatNear(expected_values, kQuantizedTolerance)));
}
// Tests for reduce_prod
TEST(ConstFloatProdOpTest, NotKeepDims) {

View File

@ -320,7 +320,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_COS, Register_COS());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
AddBuiltin(BuiltinOperator_TILE, Register_TILE());
AddBuiltin(BuiltinOperator_SUM, Register_SUM());
AddBuiltin(BuiltinOperator_SUM, Register_SUM(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD());
AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX(),
/* min_version */ 1,

View File

@ -100,6 +100,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kMean, 1}, "1.6.0"},
{{OperatorType::kMean, 2}, "1.14.0"},
{{OperatorType::kSum, 1}, "1.10.0"},
{{OperatorType::kSum, 2}, kPendingReleaseOpVersion},
{{OperatorType::kReduceMax, 1}, "1.11.0"},
{{OperatorType::kReduceMax, 2}, "1.14.0"},
{{OperatorType::kReduceMin, 1}, "1.11.0"},

View File

@ -1232,6 +1232,11 @@ class Sum
}
int GetVersion(const OperatorSignature& op_signature) const override {
  // Int8-quantized inputs require version 2 of the Sum op; everything
  // else stays at version 1.
  const string& input_name = op_signature.op->inputs[0];
  const Array& input_array = op_signature.model->GetArray(input_name);
  return input_array.data_type == ArrayDataType::kInt8 ? 2 : 1;
}
};

View File

@ -895,6 +895,10 @@ TEST_F(OperatorTest, VersioningMeanTest) {
SimpleVersioningTest<MeanOperator>();
}
// Exercises Sum's GetVersion through the shared versioning helper
// (added alongside the int8 Sum support, which bumps Sum to version 2).
TEST_F(OperatorTest, VersioningSumTest) {
SimpleVersioningTest<TensorFlowSumOperator>();
}
// Exercises Add's GetVersion through the shared versioning helper.
TEST_F(OperatorTest, VersioningAddTest) { SimpleVersioningTest<AddOperator>(); }
// Exercises Sub's GetVersion through the shared versioning helper.
TEST_F(OperatorTest, VersioningSubTest) { SimpleVersioningTest<SubOperator>(); }