From 69ba31ace44af28399393b7e630ac8562f707e60 Mon Sep 17 00:00:00 2001 From: Jian Li Date: Fri, 21 Jun 2019 07:54:28 -0700 Subject: [PATCH] Create int8 sum Op. PiperOrigin-RevId: 254397376 --- .../internal/reference/reference_ops.h | 10 ++- tensorflow/lite/kernels/reduce.cc | 61 +++++++++++++------ tensorflow/lite/kernels/reduce_test.cc | 15 +++++ tensorflow/lite/kernels/register.cc | 4 +- tensorflow/lite/toco/tflite/op_version.cc | 1 + tensorflow/lite/toco/tflite/operator.cc | 5 ++ tensorflow/lite/toco/tflite/operator_test.cc | 4 ++ 7 files changed, 79 insertions(+), 21 deletions(-) diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 25e98066562..4e8211d178c 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -3371,8 +3371,14 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point, const int num_axis_dimensions, bool keep_dims, int* temp_index, int* resolved_axis, U* temp_sum, bool compute_sum) { - gemmlowp::ScopedProfilingLabel label(compute_sum ? "Sum/Uint8" - : "Mean/Uint8"); + const bool uint8_case = std::is_same::value; + if (uint8_case) { + gemmlowp::ScopedProfilingLabel label(compute_sum ? "Sum/Uint8" + : "Mean/Uint8"); + } else { + gemmlowp::ScopedProfilingLabel label(compute_sum ? "Sum/Int8" + : "Mean/Int8"); + } // Reset output data. size_t num_outputs = 1; for (int idx = 0; idx < output_num_dims; ++idx) { diff --git a/tensorflow/lite/kernels/reduce.cc b/tensorflow/lite/kernels/reduce.cc index 4e222915fd7..c5e2c673497 100644 --- a/tensorflow/lite/kernels/reduce.cc +++ b/tensorflow/lite/kernels/reduce.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include @@ -524,11 +525,13 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) { OpContext op_context(context, node); const auto& input = op_context.input; const auto& output = op_context.output; - if (input->type != kTfLiteUInt8 || + const bool same_scale = (input->params.scale == output->params.scale && - input->params.zero_point == output->params.zero_point)) { - return EvalGeneric(context, node); - } else { + input->params.zero_point == output->params.zero_point); + const bool eight_bit_quantized = + input->type == kTfLiteUInt8 || input->type == kTfLiteInt8; + const bool need_rescale = (eight_bit_quantized && !same_scale); + if (need_rescale) { // Rescaling 8bit reduce sum. int num_axis = static_cast(NumElements(op_context.axis)); TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); @@ -541,20 +544,42 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum)); } - - TF_LITE_ENSURE( - context, - reference_ops::QuantizedMeanOrSum<>( - GetTensorData(op_context.input), - op_context.input->params.zero_point, op_context.input->params.scale, - op_context.input->dims->data, op_context.input->dims->size, - GetTensorData(op_context.output), - op_context.output->params.zero_point, - op_context.output->params.scale, op_context.output->dims->data, - op_context.output->dims->size, GetTensorData(op_context.axis), - num_axis, op_context.params->keep_dims, - GetTensorData(temp_index), GetTensorData(resolved_axis), - GetTensorData(temp_sum), /*compute_sum=*/true)); + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE( + context, + reference_ops::QuantizedMeanOrSum<>( + GetTensorData(op_context.input), + op_context.input->params.zero_point, + op_context.input->params.scale, op_context.input->dims->data, + op_context.input->dims->size, + GetTensorData(op_context.output), + op_context.output->params.zero_point, + op_context.output->params.scale, op_context.output->dims->data, + op_context.output->dims->size, + GetTensorData(op_context.axis), num_axis, + op_context.params->keep_dims, GetTensorData(temp_index), + GetTensorData(resolved_axis), GetTensorData(temp_sum), + /*compute_sum=*/true)); + } + if (input->type == kTfLiteInt8) { + TF_LITE_ENSURE( + context, + reference_ops::QuantizedMeanOrSum<>( + GetTensorData(op_context.input), + op_context.input->params.zero_point, + op_context.input->params.scale, op_context.input->dims->data, + op_context.input->dims->size, + GetTensorData(op_context.output), + op_context.output->params.zero_point, + op_context.output->params.scale, op_context.output->dims->data, + op_context.output->dims->size, + GetTensorData(op_context.axis), num_axis, + op_context.params->keep_dims, GetTensorData(temp_index), + GetTensorData(resolved_axis), GetTensorData(temp_sum), + /*compute_sum=*/true)); + } + } else { + return EvalGeneric(context, node); } return kTfLiteOk; diff --git a/tensorflow/lite/kernels/reduce_test.cc b/tensorflow/lite/kernels/reduce_test.cc index a7fa3fcd054..e3518cbe22e 100644 --- a/tensorflow/lite/kernels/reduce_test.cc +++ b/tensorflow/lite/kernels/reduce_test.cc @@ -642,6 +642,21 @@ TEST(DynamicUint8SumOpTest, KeepDims) { ElementsAreArray(ArrayFloatNear({6.47059, 10.698}, kQuantizedTolerance))); } +TEST(ConstInt8SumOpTest, Rescale) { + const std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.3}; + SumOpConstModel m({TensorType_INT8, {1, 3, 2}, -1.0, 1.0}, + {TensorType_INT8, {2}, -5.0, 5.0}, {1}, {1}, false); + // Expect the sum to be 0.4 + 0.3 + 0.5 = 1.2 and 0.2 + 0.4 + 0.3 = 0.9. + const std::vector expected_sum = {1.2, 0.9}; + const float kQuantizedTolerance = GetTolerance(-5.0, 5.0); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(expected_sum, kQuantizedTolerance))); +} + // Tests for reduce_prod TEST(ConstFloatProdOpTest, NotKeepDims) { diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 0ba29ec95e6..226a3714d33 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -320,7 +320,9 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_COS, Register_COS()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); AddBuiltin(BuiltinOperator_TILE, Register_TILE()); - AddBuiltin(BuiltinOperator_SUM, Register_SUM()); + AddBuiltin(BuiltinOperator_SUM, Register_SUM(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD()); AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX(), /* min_version */ 1, diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index 6c8a507cada..fb48541721c 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -100,6 +100,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kMean, 1}, "1.6.0"}, {{OperatorType::kMean, 2}, "1.14.0"}, {{OperatorType::kSum, 1}, "1.10.0"}, + {{OperatorType::kSum, 2}, kPendingReleaseOpVersion}, {{OperatorType::kReduceMax, 1}, "1.11.0"}, {{OperatorType::kReduceMax, 2}, "1.14.0"}, {{OperatorType::kReduceMin, 1}, "1.11.0"}, diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 0a8449cd710..1cf3f31089b 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -1232,6 +1232,11 @@ class Sum } int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc index ae3ec2ef2f6..28d342382f8 100644 --- a/tensorflow/lite/toco/tflite/operator_test.cc +++ b/tensorflow/lite/toco/tflite/operator_test.cc @@ -895,6 +895,10 @@ TEST_F(OperatorTest, VersioningMeanTest) { SimpleVersioningTest(); } +TEST_F(OperatorTest, VersioningSumTest) { + SimpleVersioningTest(); +} + TEST_F(OperatorTest, VersioningAddTest) { SimpleVersioningTest(); } TEST_F(OperatorTest, VersioningSubTest) { SimpleVersioningTest(); }