Create quantized mean.
PiperOrigin-RevId: 238455331
This commit is contained in:
parent
aaa0ea6191
commit
872b9bd9fc
@ -322,6 +322,7 @@ cc_library(
|
|||||||
"reference/integer_ops/l2normalization.h",
|
"reference/integer_ops/l2normalization.h",
|
||||||
"reference/integer_ops/log_softmax.h",
|
"reference/integer_ops/log_softmax.h",
|
||||||
"reference/integer_ops/logistic.h",
|
"reference/integer_ops/logistic.h",
|
||||||
|
"reference/integer_ops/mean.h",
|
||||||
"reference/integer_ops/mul.h",
|
"reference/integer_ops/mul.h",
|
||||||
"reference/integer_ops/pooling.h",
|
"reference/integer_ops/pooling.h",
|
||||||
"reference/integer_ops/softmax.h",
|
"reference/integer_ops/softmax.h",
|
||||||
|
@ -0,0 +1,73 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_
|
||||||
|
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace reference_integer_ops {
|
||||||
|
|
||||||
|
inline void Mean(const tflite::MeanParams& op_params, int32_t multiplier,
|
||||||
|
int32_t shift, const RuntimeShape& unextended_input_shape,
|
||||||
|
const int8_t* input_data, int32 input_zero_point,
|
||||||
|
const RuntimeShape& unextended_output_shape,
|
||||||
|
int8_t* output_data, int32 output_zero_point) {
|
||||||
|
// Current implementation only supports dimension equals 4 and simultaneous
|
||||||
|
// reduction over width and height.
|
||||||
|
TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
|
||||||
|
TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
|
||||||
|
const RuntimeShape input_shape =
|
||||||
|
RuntimeShape::ExtendedShape(4, unextended_input_shape);
|
||||||
|
const RuntimeShape output_shape =
|
||||||
|
RuntimeShape::ExtendedShape(4, unextended_output_shape);
|
||||||
|
const int output_batch = output_shape.Dims(0);
|
||||||
|
const int output_height = output_shape.Dims(1);
|
||||||
|
const int output_width = output_shape.Dims(2);
|
||||||
|
const int output_depth = output_shape.Dims(3);
|
||||||
|
const int input_height = input_shape.Dims(1);
|
||||||
|
const int input_width = input_shape.Dims(2);
|
||||||
|
const int num_elements_in_axis = input_width * input_height;
|
||||||
|
|
||||||
|
TFLITE_DCHECK_EQ(op_params.axis_count, 2);
|
||||||
|
TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
|
||||||
|
(op_params.axis[0] == 2 && op_params.axis[1] == 1));
|
||||||
|
TFLITE_DCHECK_EQ(output_height, 1);
|
||||||
|
TFLITE_DCHECK_EQ(output_width, 1);
|
||||||
|
|
||||||
|
for (int out_b = 0; out_b < output_batch; ++out_b) {
|
||||||
|
for (int out_d = 0; out_d < output_depth; ++out_d) {
|
||||||
|
int32 acc = 0;
|
||||||
|
for (int in_h = 0; in_h < input_height; ++in_h) {
|
||||||
|
for (int in_w = 0; in_w < input_width; ++in_w) {
|
||||||
|
acc += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)] -
|
||||||
|
input_zero_point;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
|
||||||
|
acc = acc > 0 ? (acc + num_elements_in_axis / 2) / num_elements_in_axis
|
||||||
|
: (acc - num_elements_in_axis / 2) / num_elements_in_axis;
|
||||||
|
acc += output_zero_point;
|
||||||
|
acc = std::min(std::max(acc, -128), 127);
|
||||||
|
output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
|
||||||
|
static_cast<int8_t>(acc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace reference_integer_ops
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_MEAN_H_
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/kernels/gemm_support.h"
|
#include "tensorflow/lite/kernels/gemm_support.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||||
@ -37,6 +38,13 @@ enum KernelType {
|
|||||||
kReference,
|
kReference,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct OpData {
|
||||||
|
int32_t multiplier;
|
||||||
|
int shift;
|
||||||
|
// The index of the temporary tensor where the quantized inputs are cached.
|
||||||
|
int scratch_tensor_index;
|
||||||
|
};
|
||||||
|
|
||||||
struct OpContext {
|
struct OpContext {
|
||||||
OpContext(TfLiteContext* context, TfLiteNode* node) {
|
OpContext(TfLiteContext* context, TfLiteNode* node) {
|
||||||
params = reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
|
params = reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
|
||||||
@ -54,14 +62,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
|||||||
gemm_support::IncrementUsageCounter(context);
|
gemm_support::IncrementUsageCounter(context);
|
||||||
// Creates two temp tensors to store index and axis for internal
|
// Creates two temp tensors to store index and axis for internal
|
||||||
// implementation only.
|
// implementation only.
|
||||||
auto* scratch_tensor_index = new int;
|
auto* op_data = new OpData();
|
||||||
context->AddTensors(context, 3, scratch_tensor_index);
|
context->AddTensors(context, 3, &op_data->scratch_tensor_index);
|
||||||
return scratch_tensor_index;
|
return op_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Free(TfLiteContext* context, void* buffer) {
|
void Free(TfLiteContext* context, void* buffer) {
|
||||||
gemm_support::DecrementUsageCounter(context);
|
gemm_support::DecrementUsageCounter(context);
|
||||||
delete reinterpret_cast<int*>(buffer);
|
delete reinterpret_cast<OpData*>(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resizes the temp tensor that stores resolved axis.
|
// Resizes the temp tensor that stores resolved axis.
|
||||||
@ -152,10 +160,10 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) {
|
|||||||
TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
|
TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
|
||||||
OpContext* op_context) {
|
OpContext* op_context) {
|
||||||
// Creates a temp index to iterate through input data.
|
// Creates a temp index to iterate through input data.
|
||||||
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
|
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
TfLiteIntArrayFree(node->temporaries);
|
TfLiteIntArrayFree(node->temporaries);
|
||||||
node->temporaries = TfLiteIntArrayCreate(3);
|
node->temporaries = TfLiteIntArrayCreate(3);
|
||||||
node->temporaries->data[0] = *scratch_tensor_index;
|
node->temporaries->data[0] = op_data->scratch_tensor_index;
|
||||||
TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0);
|
TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0);
|
||||||
scratch_tensor->type = kTfLiteInt32;
|
scratch_tensor->type = kTfLiteInt32;
|
||||||
scratch_tensor->allocation_type = kTfLiteArenaRw;
|
scratch_tensor->allocation_type = kTfLiteArenaRw;
|
||||||
@ -165,11 +173,11 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
|
|||||||
context->ResizeTensor(context, scratch_tensor, index_size));
|
context->ResizeTensor(context, scratch_tensor, index_size));
|
||||||
|
|
||||||
// Creates a temp tensor to store resolved axis given input data.
|
// Creates a temp tensor to store resolved axis given input data.
|
||||||
node->temporaries->data[1] = *scratch_tensor_index + 1;
|
node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
|
||||||
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
|
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
|
||||||
resolved_axis->type = kTfLiteInt32;
|
resolved_axis->type = kTfLiteInt32;
|
||||||
// Creates a temp tensor to store temp sums when calculating mean.
|
// Creates a temp tensor to store temp sums when calculating mean.
|
||||||
node->temporaries->data[2] = *scratch_tensor_index + 2;
|
node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
|
||||||
TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
|
TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
|
||||||
switch (op_context->input->type) {
|
switch (op_context->input->type) {
|
||||||
case kTfLiteFloat32:
|
case kTfLiteFloat32:
|
||||||
@ -226,9 +234,18 @@ TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
|
||||||
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
|
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
|
||||||
|
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
|
|
||||||
// reduce_mean requires a buffer to store intermediate sum result.
|
// reduce_mean requires a buffer to store intermediate sum result.
|
||||||
OpContext op_context(context, node);
|
OpContext op_context(context, node);
|
||||||
|
if (op_context.input->type == kTfLiteInt8) {
|
||||||
|
const double real_multiplier =
|
||||||
|
static_cast<double>(op_context.input->params.scale) /
|
||||||
|
static_cast<double>(op_context.output->params.scale);
|
||||||
|
int exponent;
|
||||||
|
QuantizeMultiplier(real_multiplier, &data->multiplier, &exponent);
|
||||||
|
data->shift = exponent;
|
||||||
|
}
|
||||||
TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
|
TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
|
||||||
if (!IsConstantTensor(op_context.axis)) {
|
if (!IsConstantTensor(op_context.axis)) {
|
||||||
SetTensorToDynamic(temp_sum);
|
SetTensorToDynamic(temp_sum);
|
||||||
@ -252,6 +269,7 @@ void ResolveAxis(const int* axis_data, int axis_count,
|
|||||||
template <KernelType kernel_type>
|
template <KernelType kernel_type>
|
||||||
TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
|
||||||
OpContext op_context(context, node);
|
OpContext op_context(context, node);
|
||||||
|
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
|
|
||||||
int num_axis = static_cast<int>(NumElements(op_context.axis));
|
int num_axis = static_cast<int>(NumElements(op_context.axis));
|
||||||
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
|
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
|
||||||
@ -296,6 +314,20 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (op_context.input->type == kTfLiteInt8) {
|
||||||
|
tflite::MeanParams op_params;
|
||||||
|
op_params.axis_count = num_axis;
|
||||||
|
ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
|
||||||
|
const TfLiteTensor* input = op_context.input;
|
||||||
|
reference_integer_ops::Mean(
|
||||||
|
op_params, data->multiplier, data->shift, GetTensorShape(input),
|
||||||
|
GetTensorData<int8_t>(input), op_context.input->params.zero_point,
|
||||||
|
GetTensorShape(op_context.output),
|
||||||
|
GetTensorData<int8_t>(op_context.output),
|
||||||
|
op_context.output->params.zero_point);
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
#define TF_LITE_MEAN(kernel_type, data_type, temp_data_type) \
|
#define TF_LITE_MEAN(kernel_type, data_type, temp_data_type) \
|
||||||
kernel_type::Mean<>( \
|
kernel_type::Mean<>( \
|
||||||
GetTensorData<data_type>(op_context.input), \
|
GetTensorData<data_type>(op_context.input), \
|
||||||
|
@ -397,6 +397,40 @@ TEST(ConstUint8MeanOpTest, KeepDims) {
|
|||||||
ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance)));
|
ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ConstInt8MeanOpTest, QuantizedSameScale) {
|
||||||
|
float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
|
||||||
|
std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1,
|
||||||
|
0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9,
|
||||||
|
0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3,
|
||||||
|
0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
|
||||||
|
MeanOpConstModel m({TensorType_INT8, {1, 2, 2, 9}, -1.0, 1.0},
|
||||||
|
{TensorType_INT8, {2}, -1.0, 1.0}, {2}, {1, 2}, true);
|
||||||
|
m.QuantizeAndPopulate<int8_t>(m.Input(), data);
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 9}));
|
||||||
|
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear(
|
||||||
|
{0.35, 0.325, 0.2, 0.35, 0.375, 0.325, 0.225, 0.45, 0.425},
|
||||||
|
kQuantizedTolerance)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ConstInt8MeanOpTest, QuantizedDifferentScale) {
|
||||||
|
float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
|
||||||
|
std::vector<float> data = {0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.4, 0.5, 0.1,
|
||||||
|
0.1, 0.1, 0.1, 0.4, 0.2, 0.2, 0.2, 0.9, 0.9,
|
||||||
|
0.9, 0.9, 0.2, 0.3, 0.7, 0.7, 0.1, 0.1, 0.3,
|
||||||
|
0.3, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4};
|
||||||
|
MeanOpConstModel m({TensorType_INT8, {1, 2, 2, 9}, -1.0, 1.0},
|
||||||
|
{TensorType_INT8, {2}, -4.0, 4.0}, {2}, {1, 2}, true);
|
||||||
|
m.QuantizeAndPopulate<int8_t>(m.Input(), data);
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 9}));
|
||||||
|
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear(
|
||||||
|
{0.35, 0.325, 0.2, 0.35, 0.375, 0.325, 0.225, 0.45, 0.425},
|
||||||
|
kQuantizedTolerance)));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(DynamicUint8MeanOpTest, NotKeepDims) {
|
TEST(DynamicUint8MeanOpTest, NotKeepDims) {
|
||||||
float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
|
float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
|
||||||
std::vector<float> data = {1.3, -4.8, -3.6, 0.24};
|
std::vector<float> data = {1.3, -4.8, -3.6, 0.24};
|
||||||
|
Loading…
Reference in New Issue
Block a user