Adds kernel impl for dequantizing per-channel quantized tensors to float

PiperOrigin-RevId: 299955689
Change-Id: I03c7b9693a2b86518b4b8dd9303b231fc4c025ee
This commit is contained in:
Sachin Joglekar 2020-03-09 15:52:47 -07:00 committed by TensorFlower Gardener
parent 0008a03e43
commit 185dd2376b
4 changed files with 169 additions and 0 deletions

View File

@@ -1028,6 +1028,17 @@ cc_test(
],
)
# Unit-test target for the per-channel dequantization reference kernel.
cc_test(
    name = "per_channel_dequantize_test",
    srcs = ["per_channel_dequantize_test.cc"],
    deps = [
        ":reference_base",
        ":types",
        "//tensorflow/lite/kernels:test_util",
        "@com_google_googletest//:gtest_main",
    ],
)
exports_files(["optimized/eigen_tensor_reduced_instantiations_oss.h"])
filegroup(

View File

@@ -0,0 +1,121 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/test_util.h"
namespace tflite {
namespace {
using ::testing::ElementsAreArray;
// Dequantizes a 2x5 int8 tensor quantized along dimension 0, so each row
// of five values shares one (scale, zero_point) pair.
TEST(PerChannelDequantize, TestInt8ToFloat_2D) {
  const std::vector<float> ch_scales = {0.5, 0.25};
  const std::vector<int> ch_offsets = {-1, -1};
  const RuntimeShape tensor_shape({2, 5});
  const std::vector<int8_t> quantized = {-128, -127, -126, -125, -124,
                                         123,  124,  125,  126,  127};
  // Pre-fill with a sentinel so untouched elements would be visible.
  std::vector<float> dequantized(10, -1);

  PerChannelDequantizationParams params;
  params.zero_point = ch_offsets.data();
  params.scale = ch_scales.data();
  params.quantized_dimension = 0;
  reference_ops::PerChannelDequantize(params, tensor_shape, quantized.data(),
                                      tensor_shape, dequantized.data());
  EXPECT_THAT(dequantized,
              ElementsAreArray(ArrayFloatNear({-63.5, -63, -62.5, -62, -61.5,
                                               31, 31.25, 31.5, 31.75, 32})));
}
// Dequantizes a 1x2x5 int8 tensor quantized along dimension 2: five
// channels, each with its own (scale, zero_point) pair.
TEST(PerChannelDequantize, TestInt8ToFloat_3D) {
  const std::vector<float> ch_scales = {0.5, 0.25, 0.5, 0.25, 1.0};
  const std::vector<int> ch_offsets = {-1, 1, -1, 1, 0};
  const RuntimeShape tensor_shape({1, 2, 5});
  const std::vector<int8_t> quantized = {-128, -127, -126, -125, -124,
                                         123,  124,  125,  126,  127};
  // Pre-fill with a sentinel so untouched elements would be visible.
  std::vector<float> dequantized(10, -1);

  PerChannelDequantizationParams params;
  params.zero_point = ch_offsets.data();
  params.scale = ch_scales.data();
  params.quantized_dimension = 2;
  reference_ops::PerChannelDequantize(params, tensor_shape, quantized.data(),
                                      tensor_shape, dequantized.data());
  EXPECT_THAT(dequantized,
              ElementsAreArray(ArrayFloatNear({-63.5, -32, -62.5, -31.5, -124,
                                               62, 30.75, 63, 31.25, 127})));
}
// Dequantizes a 2x2x5x1 int8 tensor quantized along dimension 0: the first
// ten elements use channel 0's parameters, the last ten use channel 1's.
TEST(PerChannelDequantize, TestInt8ToFloat_4DDim0) {
  const std::vector<float> ch_scales = {0.5, 0.25};
  const std::vector<int> ch_offsets = {-1, 1};
  const RuntimeShape tensor_shape({2, 2, 5, 1});
  const std::vector<int8_t> quantized = {
      -128, -127, -126, -125, -124, 123,  124,  125,  126,  127,
      -128, -127, -126, -125, -124, 123,  124,  125,  126,  127};
  // Pre-fill with a sentinel so untouched elements would be visible.
  std::vector<float> dequantized(20, -1);

  PerChannelDequantizationParams params;
  params.zero_point = ch_offsets.data();
  params.scale = ch_scales.data();
  params.quantized_dimension = 0;
  reference_ops::PerChannelDequantize(params, tensor_shape, quantized.data(),
                                      tensor_shape, dequantized.data());
  EXPECT_THAT(dequantized, ElementsAreArray(ArrayFloatNear(
                               {-63.5, -63, -62.5, -62, -61.5, 62, 62.5,
                                63, 63.5, 64, -32.25, -32, -31.75, -31.5,
                                -31.25, 30.5, 30.75, 31, 31.25, 31.5})));
}
// Dequantizes a 1x2x2x5 int8 tensor quantized along the innermost
// dimension (dim 3): parameters cycle every five elements.
TEST(PerChannelDequantize, TestInt8ToFloat_4DDim3) {
  const std::vector<float> ch_scales = {0.5, 0.25, 0.5, 0.25, 1.0};
  const std::vector<int> ch_offsets = {-1, 1, -1, 1, 0};
  const RuntimeShape tensor_shape({1, 2, 2, 5});
  const std::vector<int8_t> quantized = {
      -128, -127, -126, -125, -124, 123,  124,  125,  126,  127,
      -128, -127, -126, -125, -124, 123,  124,  125,  126,  127};
  // Pre-fill with a sentinel so untouched elements would be visible.
  std::vector<float> dequantized(20, -1);

  PerChannelDequantizationParams params;
  params.zero_point = ch_offsets.data();
  params.scale = ch_scales.data();
  params.quantized_dimension = 3;
  reference_ops::PerChannelDequantize(params, tensor_shape, quantized.data(),
                                      tensor_shape, dequantized.data());
  EXPECT_THAT(dequantized, ElementsAreArray(ArrayFloatNear(
                               {-63.5, -32, -62.5, -31.5, -124, 62, 30.75,
                                63, 31.25, 127, -63.5, -32, -62.5, -31.5,
                                -124, 62, 30.75, 63, 31.25, 127})));
}
} // namespace
} // namespace tflite

View File

@@ -17,6 +17,8 @@ limitations under the License.
#include <limits.h>
#include <vector>
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
@@ -60,6 +62,35 @@ inline void DequantizeInteger(const tflite::DequantizationParams& op_params,
}
}
// Dequantizes a per-channel quantized tensor to float:
//   output[i] = scale[c] * (input[i] - zero_point[c])
// where c is the element's coordinate along `quantized_dimension`.
template <typename T>
inline void PerChannelDequantize(
    const tflite::PerChannelDequantizationParams& op_params,
    const RuntimeShape& input_shape, const T* input_data,
    const RuntimeShape& output_shape, float* output_data) {
  // Input and output must hold the same number of elements.
  MatchingFlatSize(input_shape, output_shape);
  const int32* zero_points = op_params.zero_point;
  const float* scales = op_params.scale;
  const int32 quant_dim = op_params.quantized_dimension;
  const int32 rank = input_shape.DimensionsCount();
  const int* dims = reinterpret_cast<const int*>(input_shape.DimsData());
  // Walk every N-D coordinate; `coords` tracks the current multi-index.
  std::vector<int> coords(rank, 0);
  do {
    const size_t flat_index =
        ReducedOutputOffset(rank, dims, coords.data(), 0, nullptr);
    const int channel = coords[quant_dim];
    const int32 quantized_value = input_data[flat_index];
    output_data[flat_index] = static_cast<float>(
        scales[channel] * (quantized_value - zero_points[channel]));
  } while (NextIndex(rank, dims, coords.data()));
}
} // namespace reference_ops
} // namespace tflite

View File

@@ -863,6 +863,12 @@ struct DequantizationParams {
int32 zero_point;
};
// Parameters for dequantizing a per-channel quantized tensor
// (see reference_ops::PerChannelDequantize).
struct PerChannelDequantizationParams {
// Per-channel scale factors, indexed by position along the quantized
// dimension. Non-owning pointer; caller keeps the array alive.
const float* scale;
// Per-channel zero points, parallel to `scale`. Non-owning pointer.
const int32* zero_point;
// Index of the tensor dimension along which quantization parameters vary.
int32 quantized_dimension;
};
struct FakeQuantParams {
MinMax minmax;
int32 num_bits;