Create quantized mean.

PiperOrigin-RevId: 238455331
Author: Jian Li, 2019-03-14 09:14:53 -07:00 (committed by TensorFlower Gardener)
Parent: aaa0ea6191
Commit: 872b9bd9fc
4 changed files with 148 additions and 8 deletions
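
What the new kernel computes, in one formula (this just summarizes the reference code added below; s_in and s_out are the input/output scales, z_in and z_out the zero points, N = width * height):

    y = clamp_[-128, 127]( round( (s_in / s_out) * (1/N) * sum_i (x_i - z_in) ) + z_out )

The scale ratio s_in / s_out is precomputed in Prepare as a fixed-point multiplier/shift pair, so Eval runs in pure integer arithmetic.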

tensorflow/lite/kernels/internal/BUILD

@@ -322,6 +322,7 @@ cc_library(
         "reference/integer_ops/l2normalization.h",
         "reference/integer_ops/log_softmax.h",
         "reference/integer_ops/logistic.h",
+        "reference/integer_ops/mean.h",
         "reference/integer_ops/mul.h",
         "reference/integer_ops/pooling.h",
         "reference/integer_ops/softmax.h",

tensorflow/lite/kernels/internal/reference/integer_ops/mean.h (new file)

@@ -0,0 +1,73 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+
+namespace tflite {
+namespace reference_integer_ops {
+
+inline void Mean(const tflite::MeanParams& op_params, int32_t multiplier,
+                 int32_t shift, const RuntimeShape& unextended_input_shape,
+                 const int8_t* input_data, int32 input_zero_point,
+                 const RuntimeShape& unextended_output_shape,
+                 int8_t* output_data, int32 output_zero_point) {
+  // Current implementation only supports dimension equals 4 and simultaneous
+  // reduction over width and height.
+  TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+  const int output_batch = output_shape.Dims(0);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int output_depth = output_shape.Dims(3);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int num_elements_in_axis = input_width * input_height;
+
+  TFLITE_DCHECK_EQ(op_params.axis_count, 2);
+  TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
+                (op_params.axis[0] == 2 && op_params.axis[1] == 1));
+  TFLITE_DCHECK_EQ(output_height, 1);
+  TFLITE_DCHECK_EQ(output_width, 1);
+
+  for (int out_b = 0; out_b < output_batch; ++out_b) {
+    for (int out_d = 0; out_d < output_depth; ++out_d) {
+      int32 acc = 0;
+      for (int in_h = 0; in_h < input_height; ++in_h) {
+        for (int in_w = 0; in_w < input_width; ++in_w) {
+          acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)] -
+                 input_zero_point;
+        }
+      }
+      acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
+      acc = acc > 0 ? (acc + num_elements_in_axis / 2) / num_elements_in_axis
+                    : (acc - num_elements_in_axis / 2) / num_elements_in_axis;
+      acc += output_zero_point;
+      acc = std::min(std::max(acc, -128), 127);
+      output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
+          static_cast<int8_t>(acc);
+    }
+  }
+}
+
+}  // namespace reference_integer_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_
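
To see the integer pipeline end to end, here is a minimal standalone sketch of the same arithmetic. It is illustrative, not the TFLite code: all quantization parameters are made up, and the real-valued scale ratio is applied in double where the kernel uses the fixed-point MultiplyByQuantizedMultiplier.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical quantization parameters (not from the tests in this commit).
  const double input_scale = 2.0 / 255.0;   // e.g. an input range of [-1, 1]
  const double output_scale = 8.0 / 255.0;  // e.g. an output range of [-4, 4]
  const int32_t input_zero_point = -1;
  const int32_t output_zero_point = -1;

  // Four int8 values being averaged (width * height = 4).
  const int8_t input[] = {12, 25, 38, 51};
  const int num_elements = 4;

  // Accumulate (x - input_zero_point), exactly as the reference kernel does.
  int32_t acc = 0;
  for (int8_t v : input) acc += v - input_zero_point;

  // The kernel applies MultiplyByQuantizedMultiplier(acc, multiplier, shift);
  // here the real-valued scale ratio is used directly for clarity.
  const double scaled = acc * (input_scale / output_scale);

  // Rounded division by the element count, then re-add the output zero point.
  int32_t result = static_cast<int32_t>(std::lround(scaled / num_elements)) +
                   output_zero_point;
  result = std::min(result, 127);
  result = std::max(result, -128);
  std::printf("quantized mean = %d\n", result);
  return 0;
}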

tensorflow/lite/kernels/reduce.cc

@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/gemm_support.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
@@ -37,6 +38,13 @@ enum KernelType {
   kReference,
 };
 
+struct OpData {
+  int32_t multiplier;
+  int shift;
+  // The index of the temporary tensor where the quantized inputs are cached.
+  int scratch_tensor_index;
+};
+
 struct OpContext {
   OpContext(TfLiteContext* context, TfLiteNode* node) {
     params = reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
@@ -54,14 +62,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   gemm_support::IncrementUsageCounter(context);
   // Creates two temp tensors to store index and axis for internal
   // implementation only.
-  auto* scratch_tensor_index = new int;
-  context->AddTensors(context, 3, scratch_tensor_index);
-  return scratch_tensor_index;
+  auto* op_data = new OpData();
+  context->AddTensors(context, 3, &op_data->scratch_tensor_index);
+  return op_data;
 }
 
 void Free(TfLiteContext* context, void* buffer) {
   gemm_support::DecrementUsageCounter(context);
-  delete reinterpret_cast<int*>(buffer);
+  delete reinterpret_cast<OpData*>(buffer);
 }
 
 // Resizes the temp tensor that stores resolved axis.
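
The Init/Free change above is the standard TfLiteRegistration lifecycle: whatever Init returns is stored as node->user_data and handed back to Prepare, Eval, and Free, which is what lets the multiplier/shift computed in Prepare survive until Eval. A minimal sketch of that round trip (hypothetical op; header path as of this era of the codebase):

#include "tensorflow/lite/context.h"  // TfLiteContext, TfLiteNode, TfLiteStatus

// Hypothetical per-node state, mirroring the OpData pattern above.
struct MyOpData {
  int scratch_tensor_index;
};

void* MyInit(TfLiteContext* context, const char* buffer, size_t length) {
  auto* data = new MyOpData();
  // Reserve scratch tensor slots once; the base index lives in user_data.
  context->AddTensors(context, 3, &data->scratch_tensor_index);
  return data;  // becomes node->user_data
}

void MyFree(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<MyOpData*>(buffer);
}

TfLiteStatus MyPrepare(TfLiteContext* context, TfLiteNode* node) {
  auto* data = reinterpret_cast<MyOpData*>(node->user_data);  // same object
  (void)data;
  return kTfLiteOk;
}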
@@ -152,10 +160,10 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) {
 TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
                                    OpContext* op_context) {
   // Creates a temp index to iterate through input data.
-  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
   TfLiteIntArrayFree(node->temporaries);
   node->temporaries = TfLiteIntArrayCreate(3);
-  node->temporaries->data[0] = *scratch_tensor_index;
+  node->temporaries->data[0] = op_data->scratch_tensor_index;
   TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0);
   scratch_tensor->type = kTfLiteInt32;
   scratch_tensor->allocation_type = kTfLiteArenaRw;
@@ -165,11 +173,11 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
       context->ResizeTensor(context, scratch_tensor, index_size));
 
   // Creates a temp tensor to store resolved axis given input data.
-  node->temporaries->data[1] = *scratch_tensor_index + 1;
+  node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
   TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
   resolved_axis->type = kTfLiteInt32;
   // Creates a temp tensor to store temp sums when calculating mean.
-  node->temporaries->data[2] = *scratch_tensor_index + 2;
+  node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
   TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
   switch (op_context->input->type) {
     case kTfLiteFloat32:
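
Note the indexing convention in this function: AddTensors reserved three consecutive tensor slots in Init and wrote the first index into OpData, so the temporaries are addressed as base, base + 1, base + 2. A condensed sketch of just that wiring, assuming those slot-allocation semantics:

#include "tensorflow/lite/context.h"

// Sketch: attach three consecutive scratch tensors to a node, given the base
// slot index reserved by context->AddTensors in Init.
void WireTemporaries(TfLiteNode* node, int base) {
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(3);
  node->temporaries->data[0] = base;      // iteration index scratch
  node->temporaries->data[1] = base + 1;  // resolved axis
  node->temporaries->data[2] = base + 2;  // intermediate sums
}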
@@ -226,9 +234,18 @@ TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
 TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
 
   // reduce_mean requires a buffer to store intermediate sum result.
   OpContext op_context(context, node);
+  if (op_context.input->type == kTfLiteInt8) {
+    const double real_multiplier =
+        static_cast<double>(op_context.input->params.scale) /
+        static_cast<double>(op_context.output->params.scale);
+    int exponent;
+    QuantizeMultiplier(real_multiplier, &data->multiplier, &exponent);
+    data->shift = exponent;
+  }
   TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
   if (!IsConstantTensor(op_context.axis)) {
     SetTensorToDynamic(temp_sum);
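
QuantizeMultiplier (from tensorflow/lite/kernels/internal/quantization_util.h) splits the real scale ratio into a Q31 fixed-point multiplier and a power-of-two exponent. A sketch of the assumed decomposition, built on std::frexp; treat it as illustrative rather than the library routine:

#include <cmath>
#include <cstdint>

// Assumed behavior of QuantizeMultiplier: real ~= quantized * 2^(shift - 31),
// with quantized in [2^30, 2^31) for positive input.
void QuantizeMultiplierSketch(double real, int32_t* quantized, int* shift) {
  if (real == 0.0) {
    *quantized = 0;
    *shift = 0;
    return;
  }
  const double fraction = std::frexp(real, shift);  // real = fraction * 2^shift
  auto q = static_cast<int64_t>(std::round(fraction * (1LL << 31)));
  if (q == (1LL << 31)) {  // rounding can push the fraction up to exactly 1.0
    q /= 2;
    ++*shift;
  }
  *quantized = static_cast<int32_t>(q);
}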
@@ -252,6 +269,7 @@ void ResolveAxis(const int* axis_data, int axis_count,
 template <KernelType kernel_type>
 TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);
+  OpData* data = reinterpret_cast<OpData*>(node->user_data);
   int num_axis = static_cast<int>(NumElements(op_context.axis));
   TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
@@ -296,6 +314,20 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
     }
   }
 
+  if (op_context.input->type == kTfLiteInt8) {
+    tflite::MeanParams op_params;
+    op_params.axis_count = num_axis;
+    ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
+    const TfLiteTensor* input = op_context.input;
+    reference_integer_ops::Mean(
+        op_params, data->multiplier, data->shift, GetTensorShape(input),
+        GetTensorData<int8_t>(input), op_context.input->params.zero_point,
+        GetTensorShape(op_context.output),
+        GetTensorData<int8_t>(op_context.output),
+        op_context.output->params.zero_point);
+    return kTfLiteOk;
+  }
+
 #define TF_LITE_MEAN(kernel_type, data_type, temp_data_type) \
   kernel_type::Mean<>(                                       \
       GetTensorData<data_type>(op_context.input),            \
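
The contract this path relies on: MultiplyByQuantizedMultiplier(acc, multiplier, shift) approximates acc * multiplier * 2^(shift - 31), i.e. acc times the original real scale ratio. A double-precision stand-in for intuition only; the real routine is pure fixed-point with saturating rounding, so edge-case rounding can differ:

#include <cmath>
#include <cstdint>

// Double-precision stand-in for the fixed-point helper used by the kernel.
int32_t MultiplyByQuantizedMultiplierRef(int32_t x, int32_t multiplier,
                                         int shift) {
  const double scale = multiplier * std::ldexp(1.0, shift - 31);
  return static_cast<int32_t>(std::llround(x * scale));
}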

tensorflow/lite/kernels/reduce_test.cc

@@ -397,6 +397,40 @@ TEST(ConstUint8MeanOpTest, KeepDims) {
       ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance)));
 }
 
+TEST(ConstInt8MeanOpTest, QuantizedSameScale) {
+  float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
+  std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1,
+                             0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9,
+                             0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3,
+                             0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
+  MeanOpConstModel m({TensorType_INT8, {1, 2, 2, 9}, -1.0, 1.0},
+                     {TensorType_INT8, {2}, -1.0, 1.0}, {2}, {1, 2}, true);
+  m.QuantizeAndPopulate<int8_t>(m.Input(), data);
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 9}));
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {0.35, 0.325, 0.2, 0.35, 0.375, 0.325, 0.225, 0.45, 0.425},
+                  kQuantizedTolerance)));
+}
+
+TEST(ConstInt8MeanOpTest, QuantizedDifferentScale) {
+  float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
+  std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1,
+                             0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9,
+                             0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3,
+                             0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
+  MeanOpConstModel m({TensorType_INT8, {1, 2, 2, 9}, -1.0, 1.0},
+                     {TensorType_INT8, {2}, -4.0, 4.0}, {2}, {1, 2}, true);
+  m.QuantizeAndPopulate<int8_t>(m.Input(), data);
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 9}));
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {0.35, 0.325, 0.2, 0.35, 0.375, 0.325, 0.225, 0.45, 0.425},
+                  kQuantizedTolerance)));
+}
+
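Both tests expect the same dequantized output because the mean is exact before the final requantization step; only the output scale differs between them. The expected vector is just the per-channel average of the four spatial entries. A quick check of where those numbers come from:

#include <cstdio>

int main() {
  // Same input as the two tests above, shape {1, 2, 2, 9} flattened.
  const float data[] = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1,
                        0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9,
                        0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3,
                        0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
  // Reducing over axes {1, 2} leaves 9 channels; each averages 4 values
  // spaced 9 apart. Channel 0: (0.1 + 0.1 + 0.9 + 0.3) / 4 = 0.35.
  for (int c = 0; c < 9; ++c) {
    float sum = 0.f;
    for (int i = 0; i < 4; ++i) sum += data[c + 9 * i];
    std::printf("channel %d mean = %.3f\n", c, sum / 4);
  }
  return 0;
}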
 TEST(DynamicUint8MeanOpTest, NotKeepDims) {
   float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
   std::vector<float> data = {1.3, -4.8, -3.6, 0.24};