Move DequantizeInputs and QuantizeOutputs into common utility

PiperOrigin-RevId: 314093553
Change-Id: Ib15015f738c3e0bbaa1363ba7df9808940075d2c
Author: Taehee Jeong (2020-06-01 01:14:56 -07:00), committed by TensorFlower Gardener
Parent: 3ffb4ad2d4
Commit: 626bb2c4d0
6 changed files with 345 additions and 66 deletions

tensorflow/lite/delegates/gpu/BUILD

@@ -244,6 +244,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_builder",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
+        "//tensorflow/lite/delegates/gpu/common:quantization_util",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/gl:api2",
         "//tensorflow/lite/kernels/internal:optimized_base",

tensorflow/lite/delegates/gpu/common/BUILD

@@ -203,6 +203,30 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "quantization_util",
+    srcs = ["quantization_util.cc"],
+    hdrs = ["quantization_util.h"],
+    deps = [
+        ":status",
+        "//tensorflow/lite:kernel_api",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/kernels/internal:optimized_base",
+        "//tensorflow/lite/kernels/internal:types",
+    ],
+)
+
+cc_test(
+    name = "quantization_util_test",
+    srcs = ["quantization_util_test.cc"],
+    deps = [
+        ":quantization_util",
+        "//tensorflow/lite:util",
+        "//tensorflow/lite/micro/testing:micro_test",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 # TODO(impjdi): Add unit test for operations.
 cc_library(

tensorflow/lite/delegates/gpu/common/quantization_util.cc

@@ -0,0 +1,120 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace gpu {
namespace {
void DequantizeInput(TfLiteContext* context, int input_index,
const std::unordered_map<int, int>& quant_conversion_map) {
if (quant_conversion_map.find(input_index) == quant_conversion_map.end()) {
return;
}
int original_tensor_idx = quant_conversion_map.at(input_index);
const TfLiteTensor& dequantized_tflite_tensor = context->tensors[input_index];
const TfLiteTensor& original_tflite_tensor =
context->tensors[original_tensor_idx];
DequantizationParams op_params;
op_params.zero_point = original_tflite_tensor.params.zero_point;
op_params.scale = original_tflite_tensor.params.scale;
if (original_tflite_tensor.type == kTfLiteInt8) {
optimized_ops::Dequantize(op_params,
GetTensorShape(&original_tflite_tensor),
original_tflite_tensor.data.int8,
GetTensorShape(&original_tflite_tensor),
dequantized_tflite_tensor.data.f);
} else if (original_tflite_tensor.type == kTfLiteUInt8) {
optimized_ops::Dequantize(op_params,
GetTensorShape(&original_tflite_tensor),
original_tflite_tensor.data.uint8,
GetTensorShape(&original_tflite_tensor),
dequantized_tflite_tensor.data.f);
}
}
void QuantizeOutput(TfLiteContext* context, int output_index,
const std::unordered_map<int, int>& quant_conversion_map) {
if (quant_conversion_map.find(output_index) == quant_conversion_map.end()) {
return;
}
int original_tensor_idx = quant_conversion_map.at(output_index);
const TfLiteTensor& dequantized_tflite_tensor =
context->tensors[output_index];
const TfLiteTensor& original_tflite_tensor =
context->tensors[original_tensor_idx];
tflite::QuantizationParams op_params;
op_params.zero_point = original_tflite_tensor.params.zero_point;
op_params.scale = original_tflite_tensor.params.scale;
if (original_tflite_tensor.type == kTfLiteInt8) {
optimized_ops::AffineQuantize(op_params,
GetTensorShape(&original_tflite_tensor),
dequantized_tflite_tensor.data.f,
GetTensorShape(&original_tflite_tensor),
original_tflite_tensor.data.int8);
} else if (original_tflite_tensor.type == kTfLiteUInt8) {
optimized_ops::AffineQuantize(op_params,
GetTensorShape(&original_tflite_tensor),
dequantized_tflite_tensor.data.f,
GetTensorShape(&original_tflite_tensor),
original_tflite_tensor.data.uint8);
}
}
} // namespace
absl::Status DequantizeInputs(
TfLiteContext* context, const std::vector<uint32_t>& input_indices,
const std::unordered_map<int, int>& quant_conversion_map) {
for (auto index : input_indices) {
DequantizeInput(context, static_cast<int>(index), quant_conversion_map);
}
return absl::OkStatus();
}
absl::Status DequantizeInputs(
TfLiteContext* context, const std::vector<int64_t>& input_indices,
const std::unordered_map<int, int>& quant_conversion_map) {
for (auto index : input_indices) {
DequantizeInput(context, static_cast<int>(index), quant_conversion_map);
}
return absl::OkStatus();
}
absl::Status QuantizeOutputs(
TfLiteContext* context, const std::vector<uint32_t>& output_indices,
const std::unordered_map<int, int>& quant_conversion_map) {
for (auto index : output_indices) {
QuantizeOutput(context, static_cast<int>(index), quant_conversion_map);
}
return absl::OkStatus();
}
absl::Status QuantizeOutputs(
TfLiteContext* context, const std::vector<int64_t>& output_indices,
const std::unordered_map<int, int>& quant_conversion_map) {
for (auto index : output_indices) {
QuantizeOutput(context, static_cast<int>(index), quant_conversion_map);
}
return absl::OkStatus();
}
} // namespace gpu
} // namespace tflite
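
Note (illustrative sketch, not part of this change): the two helpers above hand the per-element math to the optimized kernels, which apply TFLite's affine quantization scheme real_value = scale * (quantized_value - zero_point). A minimal scalar sketch of that mapping and its inverse, using hypothetical helper names:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Sketch only: scalar equivalents of what optimized_ops::Dequantize and
// optimized_ops::AffineQuantize compute element-wise for an int8 tensor.
inline float DequantizeScalar(int8_t q, float scale, int zero_point) {
  return scale * (static_cast<int>(q) - zero_point);
}

inline int8_t QuantizeScalar(float v, float scale, int zero_point) {
  int q = static_cast<int>(std::round(v / scale)) + zero_point;
  return static_cast<int8_t>(std::min(127, std::max(-128, q)));
}

// Example: with scale = 0.1f and zero_point = 0, the int8 value -3
// dequantizes to -0.3f, and quantizing -0.3f recovers -3.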

tensorflow/lite/delegates/gpu/common/quantization_util.h

@@ -0,0 +1,56 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_QUANTIZATION_UTIL_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_QUANTIZATION_UTIL_H_

#include <unordered_map>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {

// Dequantizes input tensors pre-inference, leaving float tensors intact.
// input_indices contains the indices of the dequantized (fp32) tensors that
// are used as inputs to the GPU delegate.
// quant_conversion_map contains the bidirectional mapping between each
// dequantized tensor and its original quantized one.
absl::Status DequantizeInputs(
    TfLiteContext* context, const std::vector<uint32_t>& input_indices,
    const std::unordered_map<int, int>& quant_conversion_map);

absl::Status DequantizeInputs(
    TfLiteContext* context, const std::vector<int64_t>& input_indices,
    const std::unordered_map<int, int>& quant_conversion_map);

// Quantizes output tensors post-inference, leaving float tensors intact.
// output_indices contains the indices of the (fp32) tensors to be quantized,
// which are outputs of the GPU delegate.
// quant_conversion_map contains the bidirectional mapping between each
// dequantized tensor and its original quantized one.
absl::Status QuantizeOutputs(
    TfLiteContext* context, const std::vector<uint32_t>& output_indices,
    const std::unordered_map<int, int>& quant_conversion_map);

absl::Status QuantizeOutputs(
    TfLiteContext* context, const std::vector<int64_t>& output_indices,
    const std::unordered_map<int, int>& quant_conversion_map);

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_QUANTIZATION_UTIL_H_
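
For context (an illustrative sketch, not part of this change): a delegate that keeps fp32 copies of quantized I/O tensors is expected to call these helpers around inference roughly as follows; this mirrors the DelegateKernel::Invoke change further down. The tensor indices, the InvokeWithQuantIo name, and the RunGpuInference placeholder are hypothetical. Because quant_conversion_map holds the fp32-copy <-> quantized-original mapping in both directions, passing only the fp32 indices is enough to locate the quantized counterparts.

// Sketch only. Assumes tensor 0 is a quantized model input whose fp32 copy is
// tensor 2, and tensor 1 is a quantized model output whose fp32 copy is
// tensor 3; RunGpuInference() stands in for the delegate's actual runner.
absl::Status InvokeWithQuantIo(TfLiteContext* context) {
  const std::unordered_map<int, int> quant_conversion_map = {
      {2, 0}, {0, 2},   // input: fp32 copy <-> quantized original
      {3, 1}, {1, 3}};  // output: fp32 copy <-> quantized original
  const std::vector<uint32_t> input_indices = {2};   // fp32 tensors the GPU reads
  const std::vector<uint32_t> output_indices = {3};  // fp32 tensors the GPU writes

  RETURN_IF_ERROR(
      DequantizeInputs(context, input_indices, quant_conversion_map));
  RETURN_IF_ERROR(RunGpuInference());  // hypothetical placeholder
  RETURN_IF_ERROR(
      QuantizeOutputs(context, output_indices, quant_conversion_map));
  return absl::OkStatus();
}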

tensorflow/lite/delegates/gpu/common/quantization_util_test.cc

@@ -0,0 +1,139 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/micro/testing/test_utils.h"
#include "tensorflow/lite/util.h"
using ::testing::Eq;
using ::testing::FloatNear;
using ::testing::Pointwise;
namespace tflite {
namespace gpu {
namespace {
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
const std::vector<int>& data) {
std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> result(
TfLiteIntArrayCreate(data.size()));
std::copy(data.begin(), data.end(), result->data);
return result;
}
TEST(DequantizeInputs, Int8) {
TfLiteContext context;
auto input_dims = BuildTfLiteIntArray({1, 3, 2, 1});
std::vector<int8_t> data = {-3, -2, -1, 1, 2, 3};
std::vector<float> dequantized_data(data.size());
TfLiteTensor input = tflite::testing::CreateQuantizedTensor(
data.data(), input_dims.get(), "input",
/*min=*/-12.8f, /*max=*/12.7f, /*is_variable=*/false);
TfLiteTensor dequantized_input = tflite::testing::CreateFloatTensor(
dequantized_data.data(), input_dims.get(), "input_dequant",
/*is_variable=*/true);
std::vector<TfLiteTensor> tensors{input, dequantized_input};
tflite::testing::PopulateContext(tensors.data(), tensors.size(),
/*error_reporter=*/nullptr, &context);
std::vector<uint32_t> input_indices = {1};
std::unordered_map<int, int> quant_conversion_map = {{1, 0}};
auto status = DequantizeInputs(&context, input_indices, quant_conversion_map);
EXPECT_TRUE(status.ok());
EXPECT_THAT(dequantized_data,
Pointwise(FloatNear(1e-6), {-0.3, -0.2, -0.1, 0.1, 0.2, 0.3}));
}
TEST(DequantizeInputs, UInt8) {
TfLiteContext context;
auto input_dims = BuildTfLiteIntArray({1, 3, 2, 1});
std::vector<uint8_t> data = {0, 1, 2, 3, 4, 5};
std::vector<float> dequantized_data(data.size());
TfLiteTensor input = tflite::testing::CreateQuantizedTensor(
data.data(), input_dims.get(), "input",
/*min=*/0.0f, /*max=*/25.5f, /*is_variable=*/false);
TfLiteTensor dequantized_input = tflite::testing::CreateFloatTensor(
dequantized_data.data(), input_dims.get(), "input_dequant",
/*is_variable=*/true);
std::vector<TfLiteTensor> tensors{input, dequantized_input};
tflite::testing::PopulateContext(tensors.data(), tensors.size(),
/*error_reporter=*/nullptr, &context);
std::vector<int64_t> input_indices = {1};
std::unordered_map<int, int> quant_conversion_map = {{1, 0}};
auto status = DequantizeInputs(&context, input_indices, quant_conversion_map);
EXPECT_TRUE(status.ok());
EXPECT_THAT(dequantized_data,
Pointwise(FloatNear(1e-6), {0.0, 0.1, 0.2, 0.3, 0.4, 0.5}));
}
TEST(QuantizeOutputs, Int8) {
TfLiteContext context;
auto input_dims = BuildTfLiteIntArray({1, 3, 2, 1});
std::vector<float> data = {-0.3, -0.2, -0.1, 0.1, 0.2, 0.3};
std::vector<int8_t> quantized_data(data.size());
TfLiteTensor output = tflite::testing::CreateFloatTensor(
data.data(), input_dims.get(), "output", /*is_variable=*/false);
TfLiteTensor quantized_output = tflite::testing::CreateQuantizedTensor(
quantized_data.data(), input_dims.get(), "output_quant",
/*min=*/-12.8f, /*max=*/12.7f, /*is_variable=*/true);
std::vector<TfLiteTensor> tensors{output, quantized_output};
tflite::testing::PopulateContext(tensors.data(), tensors.size(),
/*error_reporter=*/nullptr, &context);
std::vector<uint32_t> output_indices = {0};
std::unordered_map<int, int> quant_conversion_map = {{0, 1}};
auto status = QuantizeOutputs(&context, output_indices, quant_conversion_map);
EXPECT_TRUE(status.ok());
EXPECT_THAT(quantized_data, Pointwise(Eq(), {-3, -2, -1, 1, 2, 3}));
}
TEST(QuantizeOutputs, UInt8) {
TfLiteContext context;
auto input_dims = BuildTfLiteIntArray({1, 3, 2, 1});
std::vector<float> data = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5};
std::vector<uint8_t> quantized_data(data.size());
TfLiteTensor output = tflite::testing::CreateFloatTensor(
data.data(), input_dims.get(), "output", /*is_variable=*/false);
TfLiteTensor quantized_output = tflite::testing::CreateQuantizedTensor(
quantized_data.data(), input_dims.get(), "output_quant",
/*min=*/0.0f, /*max=*/25.5f, /*is_variable=*/true);
std::vector<TfLiteTensor> tensors{output, quantized_output};
tflite::testing::PopulateContext(tensors.data(), tensors.size(),
/*error_reporter=*/nullptr, &context);
std::vector<int64_t> output_indices = {0};
std::unordered_map<int, int> quant_conversion_map = {{0, 1}};
auto status = QuantizeOutputs(&context, output_indices, quant_conversion_map);
EXPECT_TRUE(status.ok());
EXPECT_THAT(quantized_data, Pointwise(Eq(), {0, 1, 2, 3, 4, 5}));
}
} // namespace
} // namespace gpu
} // namespace tflite
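
Note (illustrative, not part of this change): the expected values in these tests follow from how the micro test helpers derive quantization parameters from a (min, max) range. Assuming the usual derivation scale = (max - min) / (qmax - qmin), with zero_point chosen so that min maps to qmin, both ranges used above yield scale = 0.1 and zero_point = 0:

// Sketch only: int8 range [-12.8, 12.7] and uint8 range [0.0, 25.5] both
// span 25.5, so scale = 25.5 / 255 = 0.1 and zero_point works out to 0.
float scale = (12.7f - (-12.8f)) / 255.0f;                              // 0.1
int zero_point = -128 - static_cast<int>(std::round(-12.8f / scale));   // 0
// Hence int8 -3 <-> -0.3f and uint8 5 <-> 0.5f, matching the EXPECT_THATs.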

tensorflow/lite/delegates/gpu/delegate.cc

@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_builder.h"
 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
+#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/gl/api2.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
@@ -210,12 +211,14 @@ class DelegateKernel {
     const bool is_dequant_required = !quant_conversion_map_.empty();
     if (is_dequant_required) {
-      RETURN_IF_ERROR(DequantizeInputs(context));
+      RETURN_IF_ERROR(
+          DequantizeInputs(context, input_indices_, quant_conversion_map_));
     }
     RETURN_IF_ERROR(SetInputsAndOutputs(context));
     RETURN_IF_ERROR(runner_->Run());
     if (is_dequant_required) {
-      RETURN_IF_ERROR(QuantizeOutputs(context));
+      RETURN_IF_ERROR(
+          QuantizeOutputs(context, output_indices_, quant_conversion_map_));
     }
     return absl::OkStatus();
   }
@@ -277,70 +280,6 @@ class DelegateKernel {
     return absl::OkStatus();
   }
 
-  // TODO(b/150798231): Refactor these two into common utils when generalizing
-  // to other backends.
-  // Dequantizes input tensors pre-inference, leaving float tensors intact.
-  absl::Status DequantizeInputs(TfLiteContext* context) {
-    for (auto index : input_indices_) {
-      if (quant_conversion_map_.find(index) == quant_conversion_map_.end()) {
-        continue;
-      }
-      int original_tensor_idx = quant_conversion_map_[index];
-      const TfLiteTensor& dequantized_tflite_tensor = context->tensors[index];
-      const TfLiteTensor& original_tflite_tensor =
-          context->tensors[original_tensor_idx];
-      DequantizationParams op_params;
-      op_params.zero_point = original_tflite_tensor.params.zero_point;
-      op_params.scale = original_tflite_tensor.params.scale;
-      if (original_tflite_tensor.type == kTfLiteInt8) {
-        optimized_ops::Dequantize(op_params,
-                                  GetTensorShape(&original_tflite_tensor),
-                                  original_tflite_tensor.data.int8,
-                                  GetTensorShape(&original_tflite_tensor),
-                                  dequantized_tflite_tensor.data.f);
-      } else if (original_tflite_tensor.type == kTfLiteUInt8) {
-        optimized_ops::Dequantize(op_params,
-                                  GetTensorShape(&original_tflite_tensor),
-                                  original_tflite_tensor.data.uint8,
-                                  GetTensorShape(&original_tflite_tensor),
-                                  dequantized_tflite_tensor.data.f);
-      }
-    }
-    return absl::OkStatus();
-  }
-
-  // Quantizes output tensors post-inference, leaving float tensors intact.
-  absl::Status QuantizeOutputs(TfLiteContext* context) {
-    for (auto index : output_indices_) {
-      if (quant_conversion_map_.find(index) == quant_conversion_map_.end()) {
-        continue;
-      }
-      int original_tensor_idx = quant_conversion_map_[index];
-      const TfLiteTensor& dequantized_tflite_tensor = context->tensors[index];
-      const TfLiteTensor& original_tflite_tensor =
-          context->tensors[original_tensor_idx];
-      tflite::QuantizationParams op_params;
-      op_params.zero_point = original_tflite_tensor.params.zero_point;
-      op_params.scale = original_tflite_tensor.params.scale;
-      if (original_tflite_tensor.type == kTfLiteInt8) {
-        optimized_ops::AffineQuantize(op_params,
-                                      GetTensorShape(&original_tflite_tensor),
-                                      dequantized_tflite_tensor.data.f,
-                                      GetTensorShape(&original_tflite_tensor),
-                                      original_tflite_tensor.data.int8);
-      } else if (original_tflite_tensor.type == kTfLiteUInt8) {
-        optimized_ops::AffineQuantize(op_params,
-                                      GetTensorShape(&original_tflite_tensor),
-                                      dequantized_tflite_tensor.data.f,
-                                      GetTensorShape(&original_tflite_tensor),
-                                      original_tflite_tensor.data.uint8);
-      }
-    }
-    return absl::OkStatus();
-  }
-
   absl::Status InitializeOpenClApi(GraphFloat32* graph,
                                    std::unique_ptr<InferenceBuilder>* builder,
                                    bool* graph_is_destroyed) {