diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index 93006fad9ec..4a18ee3c097 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -322,6 +322,7 @@ cc_library(
         "reference/integer_ops/l2normalization.h",
         "reference/integer_ops/log_softmax.h",
         "reference/integer_ops/logistic.h",
+        "reference/integer_ops/mean.h",
         "reference/integer_ops/mul.h",
         "reference/integer_ops/pooling.h",
         "reference/integer_ops/softmax.h",
diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h b/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h
new file mode 100644
index 00000000000..1c3da93b703
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/integer_ops/mean.h
@@ -0,0 +1,73 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+
+namespace tflite {
+namespace reference_integer_ops {
+
+inline void Mean(const tflite::MeanParams& op_params, int32_t multiplier,
+                 int32_t shift, const RuntimeShape& unextended_input_shape,
+                 const int8_t* input_data, int32 input_zero_point,
+                 const RuntimeShape& unextended_output_shape,
+                 int8_t* output_data, int32 output_zero_point) {
+  // Current implementation only supports dimension equals 4 and simultaneous
+  // reduction over width and height.
+  TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+  const int output_batch = output_shape.Dims(0);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int output_depth = output_shape.Dims(3);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int num_elements_in_axis = input_width * input_height;
+
+  TFLITE_DCHECK_EQ(op_params.axis_count, 2);
+  TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
+                (op_params.axis[0] == 2 && op_params.axis[1] == 1));
+  TFLITE_DCHECK_EQ(output_height, 1);
+  TFLITE_DCHECK_EQ(output_width, 1);
+
+  for (int out_b = 0; out_b < output_batch; ++out_b) {
+    for (int out_d = 0; out_d < output_depth; ++out_d) {
+      int32 acc = 0;
+      for (int in_h = 0; in_h < input_height; ++in_h) {
+        for (int in_w = 0; in_w < input_width; ++in_w) {
+          acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)] -
+                 input_zero_point;
+        }
+      }
+      acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
+      acc = acc > 0 ? (acc + num_elements_in_axis / 2) / num_elements_in_axis
+                    : (acc - num_elements_in_axis / 2) / num_elements_in_axis;
+      acc += output_zero_point;
+      acc = std::min(std::max(acc, -128), 127);
+      output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
+          static_cast<int8_t>(acc);
+    }
+  }
+}
+
+}  // namespace reference_integer_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_
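Note on the arithmetic in the new kernel: it accumulates zero-point-adjusted int8 values in 32 bits, rescales the raw sum to the output scale with MultiplyByQuantizedMultiplier, and only then divides by the number of averaged elements, rounding half away from zero, before clamping to [-128, 127]. A minimal standalone sketch of that rounding-division step (illustrative only, not TFLite code):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Round-half-away-from-zero integer division, matching the
// `acc > 0 ? (acc + n / 2) / n : (acc - n / 2) / n` expression in mean.h.
int32_t RoundedAverage(int32_t sum, int32_t n) {
  return sum > 0 ? (sum + n / 2) / n : (sum - n / 2) / n;
}

int main() {
  // 7 / 4 = 1.75 rounds to 2; -7 / 4 = -1.75 rounds to -2 (away from zero).
  std::printf("%d %d\n", RoundedAverage(7, 4), RoundedAverage(-7, 4));
  // Final clamp to the int8 range, as done in the kernel.
  std::printf("%d\n", std::min(std::max(200, -128), 127));  // prints 127
}
```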
diff --git a/tensorflow/lite/kernels/reduce.cc b/tensorflow/lite/kernels/reduce.cc
index 95d2370aaa8..a0f1126048e 100644
--- a/tensorflow/lite/kernels/reduce.cc
+++ b/tensorflow/lite/kernels/reduce.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/gemm_support.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
@@ -37,6 +38,13 @@ enum KernelType {
   kReference,
 };
 
+struct OpData {
+  int32_t multiplier;
+  int shift;
+  // The index of the temporary tensor where the quantized inputs are cached.
+  int scratch_tensor_index;
+};
+
 struct OpContext {
   OpContext(TfLiteContext* context, TfLiteNode* node) {
     params = reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
@@ -54,14 +62,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   gemm_support::IncrementUsageCounter(context);
   // Creates two temp tensors to store index and axis for internal
   // implementation only.
-  auto* scratch_tensor_index = new int;
-  context->AddTensors(context, 3, scratch_tensor_index);
-  return scratch_tensor_index;
+  auto* op_data = new OpData();
+  context->AddTensors(context, 3, &op_data->scratch_tensor_index);
+  return op_data;
 }
 
 void Free(TfLiteContext* context, void* buffer) {
   gemm_support::DecrementUsageCounter(context);
-  delete reinterpret_cast<int*>(buffer);
+  delete reinterpret_cast<OpData*>(buffer);
 }
 
 // Resizes the temp tensor that stores resolved axis.
@@ -152,10 +160,10 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) {
 TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
                                    OpContext* op_context) {
   // Creates a temp index to iterate through input data.
-  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
   TfLiteIntArrayFree(node->temporaries);
   node->temporaries = TfLiteIntArrayCreate(3);
-  node->temporaries->data[0] = *scratch_tensor_index;
+  node->temporaries->data[0] = op_data->scratch_tensor_index;
   TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0);
   scratch_tensor->type = kTfLiteInt32;
   scratch_tensor->allocation_type = kTfLiteArenaRw;
@@ -165,11 +173,11 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
                    context->ResizeTensor(context, scratch_tensor, index_size));
 
   // Creates a temp tensor to store resolved axis given input data.
-  node->temporaries->data[1] = *scratch_tensor_index + 1;
+  node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
   TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
   resolved_axis->type = kTfLiteInt32;
   // Creates a temp tensor to store temp sums when calculating mean.
-  node->temporaries->data[2] = *scratch_tensor_index + 2;
+  node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
   TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
   switch (op_context->input->type) {
     case kTfLiteFloat32:
@@ -226,9 +234,18 @@ TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
 
 TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
 
   // reduce_mean requires a buffer to store intermediate sum result.
   OpContext op_context(context, node);
+  if (op_context.input->type == kTfLiteInt8) {
+    const double real_multiplier =
+        static_cast<double>(op_context.input->params.scale) /
+        static_cast<double>(op_context.output->params.scale);
+    int exponent;
+    QuantizeMultiplier(real_multiplier, &data->multiplier, &exponent);
+    data->shift = exponent;
+  }
   TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
   if (!IsConstantTensor(op_context.axis)) {
     SetTensorToDynamic(temp_sum);
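PrepareMeanOrSum computes the requantization parameters once per plan: the real multiplier input_scale / output_scale is decomposed by QuantizeMultiplier into a Q31 fixed-point multiplier and a power-of-two exponent, which EvalMean later hands to the kernel. A simplified sketch of that decomposition for a positive multiplier (the Decompose helper below is hypothetical, not TFLite's QuantizeMultiplier, which also handles zero and other edge cases):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decompose `real` so that real ~= multiplier * 2^(shift - 31), where
// `multiplier` is a Q31 fixed-point value in [2^30, 2^31).
void Decompose(double real, int32_t* multiplier, int* shift) {
  const double q = std::frexp(real, shift);  // real = q * 2^shift, q in [0.5, 1)
  int64_t q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
  if (q_fixed == (1LL << 31)) {  // rounding pushed q up to 1.0; renormalize
    q_fixed /= 2;
    ++*shift;
  }
  *multiplier = static_cast<int32_t>(q_fixed);
}

int main() {
  int32_t m = 0;
  int s = 0;
  // A 1:4 input-to-output scale ratio, as in the QuantizedDifferentScale
  // test further below.
  Decompose(1.0 / 4.0, &m, &s);
  std::printf("multiplier=%d shift=%d\n", m, s);  // 1073741824, -1
}
```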
@@ -252,6 +269,7 @@ void ResolveAxis(const int* axis_data, int axis_count,
 
 template <KernelType kernel_type>
 TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
   int num_axis = static_cast<int>(NumElements(op_context.axis));
   TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
@@ -296,6 +314,20 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
     }
   }
 
+  if (op_context.input->type == kTfLiteInt8) {
+    tflite::MeanParams op_params;
+    op_params.axis_count = num_axis;
+    ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
+    const TfLiteTensor* input = op_context.input;
+    reference_integer_ops::Mean(
+        op_params, data->multiplier, data->shift, GetTensorShape(input),
+        GetTensorData<int8_t>(input), op_context.input->params.zero_point,
+        GetTensorShape(op_context.output),
+        GetTensorData<int8_t>(op_context.output),
+        op_context.output->params.zero_point);
+    return kTfLiteOk;
+  }
+
 #define TF_LITE_MEAN(kernel_type, data_type, temp_data_type) \
   kernel_type::Mean<>(                                       \
       GetTensorData<data_type>(op_context.input),            \
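With the dispatch in place, the int8 path can also be exercised directly against the reference kernel. A hedged standalone sketch (assumes the TFLite internal headers are on the include path; multiplier 1 << 30 with shift 1 encodes a real rescaling factor of 2^30 * 2^(1 - 31) = 1.0, i.e. input and output share one scale):

```cpp
#include <cstdint>
#include <cstdio>

#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"

int main() {
  // A 1x2x2x1 input averaged over H and W into a 1x1x1x1 output.
  const int8_t input[] = {4, 8, 12, 16};
  int8_t output[1] = {0};

  tflite::MeanParams params;
  params.axis_count = 2;
  params.axis[0] = 1;
  params.axis[1] = 2;

  tflite::reference_integer_ops::Mean(
      params, /*multiplier=*/1 << 30, /*shift=*/1,
      tflite::RuntimeShape({1, 2, 2, 1}), input, /*input_zero_point=*/0,
      tflite::RuntimeShape({1, 1, 1, 1}), output, /*output_zero_point=*/0);

  std::printf("%d\n", output[0]);  // (4 + 8 + 12 + 16 + 2) / 4 = 10
}
```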
diff --git a/tensorflow/lite/kernels/reduce_test.cc b/tensorflow/lite/kernels/reduce_test.cc
index f9f3cdecbdf..dd852b90aef 100644
--- a/tensorflow/lite/kernels/reduce_test.cc
+++ b/tensorflow/lite/kernels/reduce_test.cc
@@ -397,6 +397,40 @@ TEST(ConstUint8MeanOpTest, KeepDims) {
           ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance)));
 }
 
+TEST(ConstInt8MeanOpTest, QuantizedSameScale) {
+  float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
+  std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1,
+                             0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9,
+                             0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3,
+                             0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
+  MeanOpConstModel m({TensorType_INT8, {1, 2, 2, 9}, -1.0, 1.0},
+                     {TensorType_INT8, {2}, -1.0, 1.0}, {2}, {1, 2}, true);
+  m.QuantizeAndPopulate<int8_t>(m.Input(), data);
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 9}));
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {0.35, 0.325, 0.2, 0.35, 0.375, 0.325, 0.225, 0.45, 0.425},
+                  kQuantizedTolerance)));
+}
+
+TEST(ConstInt8MeanOpTest, QuantizedDifferentScale) {
+  float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
+  std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1,
+                             0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9,
+                             0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3,
+                             0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
+  MeanOpConstModel m({TensorType_INT8, {1, 2, 2, 9}, -1.0, 1.0},
+                     {TensorType_INT8, {2}, -4.0, 4.0}, {2}, {1, 2}, true);
+  m.QuantizeAndPopulate<int8_t>(m.Input(), data);
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 9}));
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {0.35, 0.325, 0.2, 0.35, 0.375, 0.325, 0.225, 0.45, 0.425},
+                  kQuantizedTolerance)));
+}
+
 TEST(DynamicUint8MeanOpTest, NotKeepDims) {
   float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
   std::vector<float> data = {1.3, -4.8, -3.6, 0.24};
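As a sanity check on the expected values: with the input laid out as [batch][height][width][depth] and shape {1, 2, 2, 9}, the four elements reduced into the first output channel are data[0], data[9], data[18], and data[27], i.e. {0.1, 0.1, 0.9, 0.3}, whose mean is 1.4 / 4 = 0.35, the first expected output in both tests. Likewise the second channel averages {0.2, 0.1, 0.9, 0.1} to 0.325. The two tests share one expected float vector because the kernel requantizes the sum to the output scale before dividing; only the quantization error (hence the tolerance) differs between the same-scale and different-scale cases.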