Create int8 sum Op.
PiperOrigin-RevId: 254397376
This commit is contained in:
parent
a858d3ae90
commit
69ba31ace4
@ -3371,8 +3371,14 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point,
|
||||
const int num_axis_dimensions, bool keep_dims,
|
||||
int* temp_index, int* resolved_axis, U* temp_sum,
|
||||
bool compute_sum) {
|
||||
gemmlowp::ScopedProfilingLabel label(compute_sum ? "Sum/Uint8"
|
||||
: "Mean/Uint8");
|
||||
const bool uint8_case = std::is_same<T, int8_t>::value;
|
||||
if (uint8_case) {
|
||||
gemmlowp::ScopedProfilingLabel label(compute_sum ? "Sum/Uint8"
|
||||
: "Mean/Uint8");
|
||||
} else {
|
||||
gemmlowp::ScopedProfilingLabel label(compute_sum ? "Sum/Int8"
|
||||
: "Mean/Int8");
|
||||
}
|
||||
// Reset output data.
|
||||
size_t num_outputs = 1;
|
||||
for (int idx = 0; idx < output_num_dims; ++idx) {
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <string.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
@ -524,11 +525,13 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpContext op_context(context, node);
|
||||
const auto& input = op_context.input;
|
||||
const auto& output = op_context.output;
|
||||
if (input->type != kTfLiteUInt8 ||
|
||||
const bool same_scale =
|
||||
(input->params.scale == output->params.scale &&
|
||||
input->params.zero_point == output->params.zero_point)) {
|
||||
return EvalGeneric<kReference, kSum>(context, node);
|
||||
} else {
|
||||
input->params.zero_point == output->params.zero_point);
|
||||
const bool eight_bit_quantized =
|
||||
input->type == kTfLiteUInt8 || input->type == kTfLiteInt8;
|
||||
const bool need_rescale = (eight_bit_quantized && !same_scale);
|
||||
if (need_rescale) {
|
||||
// Rescaling 8bit reduce sum.
|
||||
int num_axis = static_cast<int>(NumElements(op_context.axis));
|
||||
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
|
||||
@ -541,20 +544,42 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
|
||||
TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::QuantizedMeanOrSum<>(
|
||||
GetTensorData<uint8_t>(op_context.input),
|
||||
op_context.input->params.zero_point, op_context.input->params.scale,
|
||||
op_context.input->dims->data, op_context.input->dims->size,
|
||||
GetTensorData<uint8_t>(op_context.output),
|
||||
op_context.output->params.zero_point,
|
||||
op_context.output->params.scale, op_context.output->dims->data,
|
||||
op_context.output->dims->size, GetTensorData<int>(op_context.axis),
|
||||
num_axis, op_context.params->keep_dims,
|
||||
GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
|
||||
GetTensorData<int32>(temp_sum), /*compute_sum=*/true));
|
||||
if (input->type == kTfLiteUInt8) {
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::QuantizedMeanOrSum<>(
|
||||
GetTensorData<uint8_t>(op_context.input),
|
||||
op_context.input->params.zero_point,
|
||||
op_context.input->params.scale, op_context.input->dims->data,
|
||||
op_context.input->dims->size,
|
||||
GetTensorData<uint8_t>(op_context.output),
|
||||
op_context.output->params.zero_point,
|
||||
op_context.output->params.scale, op_context.output->dims->data,
|
||||
op_context.output->dims->size,
|
||||
GetTensorData<int>(op_context.axis), num_axis,
|
||||
op_context.params->keep_dims, GetTensorData<int>(temp_index),
|
||||
GetTensorData<int>(resolved_axis), GetTensorData<int32>(temp_sum),
|
||||
/*compute_sum=*/true));
|
||||
}
|
||||
if (input->type == kTfLiteInt8) {
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::QuantizedMeanOrSum<>(
|
||||
GetTensorData<int8_t>(op_context.input),
|
||||
op_context.input->params.zero_point,
|
||||
op_context.input->params.scale, op_context.input->dims->data,
|
||||
op_context.input->dims->size,
|
||||
GetTensorData<int8_t>(op_context.output),
|
||||
op_context.output->params.zero_point,
|
||||
op_context.output->params.scale, op_context.output->dims->data,
|
||||
op_context.output->dims->size,
|
||||
GetTensorData<int>(op_context.axis), num_axis,
|
||||
op_context.params->keep_dims, GetTensorData<int>(temp_index),
|
||||
GetTensorData<int>(resolved_axis), GetTensorData<int32>(temp_sum),
|
||||
/*compute_sum=*/true));
|
||||
}
|
||||
} else {
|
||||
return EvalGeneric<kReference, kSum>(context, node);
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
|
@ -642,6 +642,21 @@ TEST(DynamicUint8SumOpTest, KeepDims) {
|
||||
ElementsAreArray(ArrayFloatNear({6.47059, 10.698}, kQuantizedTolerance)));
|
||||
}
|
||||
|
||||
// Int8 sum where the input and output tensors are declared with different
// float bounds (-1..1 vs. -5..5; presumably the quantization min/max, confirm
// against SumOpConstModel).
TEST(ConstInt8SumOpTest, Rescale) {
  SumOpConstModel model({TensorType_INT8, {1, 3, 2}, -1.0, 1.0},
                        {TensorType_INT8, {2}, -5.0, 5.0}, {1}, {1}, false);

  const std::vector<float> input_values = {0.4, 0.2, 0.3, 0.4, 0.5, 0.3};
  model.QuantizeAndPopulate<int8_t>(model.Input(), input_values);
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));

  // Expected sums: 0.4 + 0.3 + 0.5 = 1.2 and 0.2 + 0.4 + 0.3 = 0.9.
  const std::vector<float> want = {1.2, 0.9};
  const float tolerance = GetTolerance(-5.0, 5.0);
  EXPECT_THAT(model.GetDequantizedOutput<int8_t>(),
              ElementsAreArray(ArrayFloatNear(want, tolerance)));
}
|
||||
|
||||
// Tests for reduce_prod
|
||||
|
||||
TEST(ConstFloatProdOpTest, NotKeepDims) {
|
||||
|
@ -320,7 +320,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_COS, Register_COS());
|
||||
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
|
||||
AddBuiltin(BuiltinOperator_TILE, Register_TILE());
|
||||
AddBuiltin(BuiltinOperator_SUM, Register_SUM());
|
||||
AddBuiltin(BuiltinOperator_SUM, Register_SUM(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD());
|
||||
AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX(),
|
||||
/* min_version */ 1,
|
||||
|
@ -100,6 +100,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
|
||||
{{OperatorType::kMean, 1}, "1.6.0"},
|
||||
{{OperatorType::kMean, 2}, "1.14.0"},
|
||||
{{OperatorType::kSum, 1}, "1.10.0"},
|
||||
{{OperatorType::kSum, 2}, kPendingReleaseOpVersion},
|
||||
{{OperatorType::kReduceMax, 1}, "1.11.0"},
|
||||
{{OperatorType::kReduceMax, 2}, "1.14.0"},
|
||||
{{OperatorType::kReduceMin, 1}, "1.11.0"},
|
||||
|
@ -1232,6 +1232,11 @@ class Sum
|
||||
}
|
||||
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const string& input_name = op_signature.op->inputs[0];
|
||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
@ -895,6 +895,10 @@ TEST_F(OperatorTest, VersioningMeanTest) {
|
||||
SimpleVersioningTest<MeanOperator>();
|
||||
}
|
||||
|
||||
// Runs the shared SimpleVersioningTest helper against TensorFlowSumOperator.
TEST_F(OperatorTest, VersioningSumTest) { SimpleVersioningTest<TensorFlowSumOperator>(); }
|
||||
|
||||
// Runs the shared SimpleVersioningTest helper against AddOperator.
TEST_F(OperatorTest, VersioningAddTest) {
  SimpleVersioningTest<AddOperator>();
}
|
||||
|
||||
// Runs the shared SimpleVersioningTest helper against SubOperator.
TEST_F(OperatorTest, VersioningSubTest) {
  SimpleVersioningTest<SubOperator>();
}
|
||||
|
Loading…
Reference in New Issue
Block a user