Move DequantizeInputs and QuantizeOutputs into common utility
PiperOrigin-RevId: 314093553
Change-Id: Ib15015f738c3e0bbaa1363ba7df9808940075d2c
@@ -244,6 +244,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:model",
         "//tensorflow/lite/delegates/gpu/common:model_builder",
         "//tensorflow/lite/delegates/gpu/common:model_transformer",
+        "//tensorflow/lite/delegates/gpu/common:quantization_util",
         "//tensorflow/lite/delegates/gpu/common:status",
         "//tensorflow/lite/delegates/gpu/gl:api2",
         "//tensorflow/lite/kernels/internal:optimized_base",
@@ -203,6 +203,30 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "quantization_util",
+    srcs = ["quantization_util.cc"],
+    hdrs = ["quantization_util.h"],
+    deps = [
+        ":status",
+        "//tensorflow/lite:kernel_api",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/kernels/internal:optimized_base",
+        "//tensorflow/lite/kernels/internal:types",
+    ],
+)
+
+cc_test(
+    name = "quantization_util_test",
+    srcs = ["quantization_util_test.cc"],
+    deps = [
+        ":quantization_util",
+        "//tensorflow/lite:util",
+        "//tensorflow/lite/micro/testing:micro_test",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 # TODO(impjdi): Add unit test for operations.
 
 cc_library(
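Assuming a standard Bazel checkout, the new test target can presumably be run in isolation with `bazel test //tensorflow/lite/delegates/gpu/common:quantization_util_test` (standard Bazel invocation, not part of the commit itself).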
tensorflow/lite/delegates/gpu/common/quantization_util.cc (new file, 120 lines)
@@ -0,0 +1,120 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace gpu {
namespace {

// Dequantizes the tensor at |input_index| into its fp32 counterpart, using
// the mapping in |quant_conversion_map|. No-op for indices without an entry.
void DequantizeInput(TfLiteContext* context, int input_index,
                     const std::unordered_map<int, int>& quant_conversion_map) {
  if (quant_conversion_map.find(input_index) == quant_conversion_map.end()) {
    return;
  }
  int original_tensor_idx = quant_conversion_map.at(input_index);
  const TfLiteTensor& dequantized_tflite_tensor = context->tensors[input_index];
  const TfLiteTensor& original_tflite_tensor =
      context->tensors[original_tensor_idx];
  DequantizationParams op_params;
  op_params.zero_point = original_tflite_tensor.params.zero_point;
  op_params.scale = original_tflite_tensor.params.scale;
  if (original_tflite_tensor.type == kTfLiteInt8) {
    optimized_ops::Dequantize(op_params,
                              GetTensorShape(&original_tflite_tensor),
                              original_tflite_tensor.data.int8,
                              GetTensorShape(&original_tflite_tensor),
                              dequantized_tflite_tensor.data.f);
  } else if (original_tflite_tensor.type == kTfLiteUInt8) {
    optimized_ops::Dequantize(op_params,
                              GetTensorShape(&original_tflite_tensor),
                              original_tflite_tensor.data.uint8,
                              GetTensorShape(&original_tflite_tensor),
                              dequantized_tflite_tensor.data.f);
  }
}

// Quantizes the fp32 tensor at |output_index| back into its original
// quantized tensor, using the mapping in |quant_conversion_map|.
void QuantizeOutput(TfLiteContext* context, int output_index,
                    const std::unordered_map<int, int>& quant_conversion_map) {
  if (quant_conversion_map.find(output_index) == quant_conversion_map.end()) {
    return;
  }
  int original_tensor_idx = quant_conversion_map.at(output_index);
  const TfLiteTensor& dequantized_tflite_tensor =
      context->tensors[output_index];
  const TfLiteTensor& original_tflite_tensor =
      context->tensors[original_tensor_idx];
  tflite::QuantizationParams op_params;
  op_params.zero_point = original_tflite_tensor.params.zero_point;
  op_params.scale = original_tflite_tensor.params.scale;
  if (original_tflite_tensor.type == kTfLiteInt8) {
    optimized_ops::AffineQuantize(op_params,
                                  GetTensorShape(&original_tflite_tensor),
                                  dequantized_tflite_tensor.data.f,
                                  GetTensorShape(&original_tflite_tensor),
                                  original_tflite_tensor.data.int8);
  } else if (original_tflite_tensor.type == kTfLiteUInt8) {
    optimized_ops::AffineQuantize(op_params,
                                  GetTensorShape(&original_tflite_tensor),
                                  dequantized_tflite_tensor.data.f,
                                  GetTensorShape(&original_tflite_tensor),
                                  original_tflite_tensor.data.uint8);
  }
}
}  // namespace

absl::Status DequantizeInputs(
    TfLiteContext* context, const std::vector<uint32_t>& input_indices,
    const std::unordered_map<int, int>& quant_conversion_map) {
  for (auto index : input_indices) {
    DequantizeInput(context, static_cast<int>(index), quant_conversion_map);
  }
  return absl::OkStatus();
}

absl::Status DequantizeInputs(
    TfLiteContext* context, const std::vector<int64_t>& input_indices,
    const std::unordered_map<int, int>& quant_conversion_map) {
  for (auto index : input_indices) {
    DequantizeInput(context, static_cast<int>(index), quant_conversion_map);
  }
  return absl::OkStatus();
}

absl::Status QuantizeOutputs(
    TfLiteContext* context, const std::vector<uint32_t>& output_indices,
    const std::unordered_map<int, int>& quant_conversion_map) {
  for (auto index : output_indices) {
    QuantizeOutput(context, static_cast<int>(index), quant_conversion_map);
  }
  return absl::OkStatus();
}

absl::Status QuantizeOutputs(
    TfLiteContext* context, const std::vector<int64_t>& output_indices,
    const std::unordered_map<int, int>& quant_conversion_map) {
  for (auto index : output_indices) {
    QuantizeOutput(context, static_cast<int>(index), quant_conversion_map);
  }
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite
tensorflow/lite/delegates/gpu/common/quantization_util.h (new file, 56 lines)
@@ -0,0 +1,56 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_QUANTIZATION_UTIL_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_QUANTIZATION_UTIL_H_

#include <unordered_map>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {

// Dequantizes input tensors pre-inference, leaving float tensors intact.
// input_indices contains the indices of the dequantized (fp32) tensors that
// are used as inputs to the GPU delegate.
// quant_conversion_map contains a bidirectional mapping between each
// dequantized tensor and its original quantized one.
absl::Status DequantizeInputs(
    TfLiteContext* context, const std::vector<uint32_t>& input_indices,
    const std::unordered_map<int, int>& quant_conversion_map);

absl::Status DequantizeInputs(
    TfLiteContext* context, const std::vector<int64_t>& input_indices,
    const std::unordered_map<int, int>& quant_conversion_map);

// Quantizes output tensors post-inference, leaving float tensors intact.
// output_indices contains the indices of the fp32 tensors to be quantized,
// which are outputs of the GPU delegate.
// quant_conversion_map contains a bidirectional mapping between each
// dequantized tensor and its original quantized one.
absl::Status QuantizeOutputs(
    TfLiteContext* context, const std::vector<uint32_t>& output_indices,
    const std::unordered_map<int, int>& quant_conversion_map);

absl::Status QuantizeOutputs(
    TfLiteContext* context, const std::vector<int64_t>& output_indices,
    const std::unordered_map<int, int>& quant_conversion_map);

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_QUANTIZATION_UTIL_H_
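Taken together, the header reads as a small pre/post-inference API. A minimal sketch of the expected calling pattern (mirroring the DelegateKernel::Invoke change further down; RunGpuInference() is a hypothetical stand-in for the actual backend call, not part of this commit):

// Sketch only: wrapping backend inference with the new shared helpers.
// The index vectors and map mirror DelegateKernel's members in this commit;
// RunGpuInference() is hypothetical.
absl::Status InvokeWithQuantSupport(
    TfLiteContext* context, const std::vector<int64_t>& input_indices,
    const std::vector<int64_t>& output_indices,
    const std::unordered_map<int, int>& quant_conversion_map) {
  const bool is_dequant_required = !quant_conversion_map.empty();
  if (is_dequant_required) {
    // int8/uint8 model inputs -> their fp32 proxy tensors.
    RETURN_IF_ERROR(
        gpu::DequantizeInputs(context, input_indices, quant_conversion_map));
  }
  RETURN_IF_ERROR(RunGpuInference(context));  // hypothetical backend call
  if (is_dequant_required) {
    // fp32 proxy outputs -> the original int8/uint8 output tensors.
    RETURN_IF_ERROR(
        gpu::QuantizeOutputs(context, output_indices, quant_conversion_map));
  }
  return absl::OkStatus();
}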
tensorflow/lite/delegates/gpu/common/quantization_util_test.cc (new file, 139 lines)
@@ -0,0 +1,139 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/micro/testing/test_utils.h"
#include "tensorflow/lite/util.h"

using ::testing::Eq;
using ::testing::FloatNear;
using ::testing::Pointwise;

namespace tflite {
namespace gpu {
namespace {

std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
    const std::vector<int>& data) {
  std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> result(
      TfLiteIntArrayCreate(data.size()));
  std::copy(data.begin(), data.end(), result->data);
  return result;
}

TEST(DequantizeInputs, Int8) {
  TfLiteContext context;
  auto input_dims = BuildTfLiteIntArray({1, 3, 2, 1});
  std::vector<int8_t> data = {-3, -2, -1, 1, 2, 3};
  std::vector<float> dequantized_data(data.size());

  TfLiteTensor input = tflite::testing::CreateQuantizedTensor(
      data.data(), input_dims.get(), "input",
      /*min=*/-12.8f, /*max=*/12.7f, /*is_variable=*/false);
  TfLiteTensor dequantized_input = tflite::testing::CreateFloatTensor(
      dequantized_data.data(), input_dims.get(), "input_dequant",
      /*is_variable=*/true);

  std::vector<TfLiteTensor> tensors{input, dequantized_input};
  tflite::testing::PopulateContext(tensors.data(), tensors.size(),
                                   /*error_reporter=*/nullptr, &context);

  std::vector<uint32_t> input_indices = {1};
  std::unordered_map<int, int> quant_conversion_map = {{1, 0}};

  auto status = DequantizeInputs(&context, input_indices, quant_conversion_map);
  EXPECT_TRUE(status.ok());
  EXPECT_THAT(dequantized_data,
              Pointwise(FloatNear(1e-6), {-0.3, -0.2, -0.1, 0.1, 0.2, 0.3}));
}

TEST(DequantizeInputs, UInt8) {
  TfLiteContext context;
  auto input_dims = BuildTfLiteIntArray({1, 3, 2, 1});
  std::vector<uint8_t> data = {0, 1, 2, 3, 4, 5};
  std::vector<float> dequantized_data(data.size());

  TfLiteTensor input = tflite::testing::CreateQuantizedTensor(
      data.data(), input_dims.get(), "input",
      /*min=*/0.0f, /*max=*/25.5f, /*is_variable=*/false);
  TfLiteTensor dequantized_input = tflite::testing::CreateFloatTensor(
      dequantized_data.data(), input_dims.get(), "input_dequant",
      /*is_variable=*/true);

  std::vector<TfLiteTensor> tensors{input, dequantized_input};
  tflite::testing::PopulateContext(tensors.data(), tensors.size(),
                                   /*error_reporter=*/nullptr, &context);

  std::vector<int64_t> input_indices = {1};
  std::unordered_map<int, int> quant_conversion_map = {{1, 0}};

  auto status = DequantizeInputs(&context, input_indices, quant_conversion_map);
  EXPECT_TRUE(status.ok());
  EXPECT_THAT(dequantized_data,
              Pointwise(FloatNear(1e-6), {0.0, 0.1, 0.2, 0.3, 0.4, 0.5}));
}

TEST(QuantizeOutputs, Int8) {
  TfLiteContext context;
  auto input_dims = BuildTfLiteIntArray({1, 3, 2, 1});
  std::vector<float> data = {-0.3, -0.2, -0.1, 0.1, 0.2, 0.3};
  std::vector<int8_t> quantized_data(data.size());
  TfLiteTensor output = tflite::testing::CreateFloatTensor(
      data.data(), input_dims.get(), "output", /*is_variable=*/false);
  TfLiteTensor quantized_output = tflite::testing::CreateQuantizedTensor(
      quantized_data.data(), input_dims.get(), "output_quant",
      /*min=*/-12.8f, /*max=*/12.7f, /*is_variable=*/true);

  std::vector<TfLiteTensor> tensors{output, quantized_output};
  tflite::testing::PopulateContext(tensors.data(), tensors.size(),
                                   /*error_reporter=*/nullptr, &context);

  std::vector<uint32_t> output_indices = {0};
  std::unordered_map<int, int> quant_conversion_map = {{0, 1}};

  auto status = QuantizeOutputs(&context, output_indices, quant_conversion_map);
  EXPECT_TRUE(status.ok());
  EXPECT_THAT(quantized_data, Pointwise(Eq(), {-3, -2, -1, 1, 2, 3}));
}

TEST(QuantizeOutputs, UInt8) {
  TfLiteContext context;
  auto input_dims = BuildTfLiteIntArray({1, 3, 2, 1});
  std::vector<float> data = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5};
  std::vector<uint8_t> quantized_data(data.size());
  TfLiteTensor output = tflite::testing::CreateFloatTensor(
      data.data(), input_dims.get(), "output", /*is_variable=*/false);
  TfLiteTensor quantized_output = tflite::testing::CreateQuantizedTensor(
      quantized_data.data(), input_dims.get(), "output_quant",
      /*min=*/0.0f, /*max=*/25.5f, /*is_variable=*/true);

  std::vector<TfLiteTensor> tensors{output, quantized_output};
  tflite::testing::PopulateContext(tensors.data(), tensors.size(),
                                   /*error_reporter=*/nullptr, &context);

  std::vector<int64_t> output_indices = {0};
  std::unordered_map<int, int> quant_conversion_map = {{0, 1}};

  auto status = QuantizeOutputs(&context, output_indices, quant_conversion_map);
  EXPECT_TRUE(status.ok());
  EXPECT_THAT(quantized_data, Pointwise(Eq(), {0, 1, 2, 3, 4, 5}));
}

}  // namespace
}  // namespace gpu
}  // namespace tflite
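The expected values in these tests follow from the affine quantization relation real = scale * (q - zero_point): with min = -12.8 and max = 12.7 over int8's [-128, 127] range, scale is 0.1 and zero_point is 0, so q = -3 dequantizes to -0.3 and 0.3 quantizes back to 3. A standalone sanity check of that arithmetic (plain C++, parameters derived from the tests' min/max, not part of the commit):

// Sanity check of the mapping the tests assume: scale = (max - min) / 255.
#include <cstdio>
#include <initializer_list>

int main() {
  const float scale = (12.7f - (-12.8f)) / 255.0f;  // == 0.1
  const int zero_point = 0;
  for (int q : {-3, -2, -1, 1, 2, 3}) {
    std::printf("q=%d -> %.1f\n", q, scale * (q - zero_point));
  }
  return 0;
}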
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/model_builder.h"
 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
+#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/gl/api2.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
@@ -210,12 +211,14 @@ class DelegateKernel {
 
     const bool is_dequant_required = !quant_conversion_map_.empty();
     if (is_dequant_required) {
-      RETURN_IF_ERROR(DequantizeInputs(context));
+      RETURN_IF_ERROR(
+          DequantizeInputs(context, input_indices_, quant_conversion_map_));
     }
     RETURN_IF_ERROR(SetInputsAndOutputs(context));
     RETURN_IF_ERROR(runner_->Run());
     if (is_dequant_required) {
-      RETURN_IF_ERROR(QuantizeOutputs(context));
+      RETURN_IF_ERROR(
+          QuantizeOutputs(context, output_indices_, quant_conversion_map_));
     }
     return absl::OkStatus();
   }
@@ -277,70 +280,6 @@ class DelegateKernel {
     return absl::OkStatus();
   }
 
-  // TODO(b/150798231): Refactor these two into common utils when generalizing
-  // to other backends.
-
-  // Dequantizes input tensors pre-inference, leaving float tensors intact.
-  absl::Status DequantizeInputs(TfLiteContext* context) {
-    for (auto index : input_indices_) {
-      if (quant_conversion_map_.find(index) == quant_conversion_map_.end()) {
-        continue;
-      }
-      int original_tensor_idx = quant_conversion_map_[index];
-      const TfLiteTensor& dequantized_tflite_tensor = context->tensors[index];
-      const TfLiteTensor& original_tflite_tensor =
-          context->tensors[original_tensor_idx];
-      DequantizationParams op_params;
-      op_params.zero_point = original_tflite_tensor.params.zero_point;
-      op_params.scale = original_tflite_tensor.params.scale;
-      if (original_tflite_tensor.type == kTfLiteInt8) {
-        optimized_ops::Dequantize(op_params,
-                                  GetTensorShape(&original_tflite_tensor),
-                                  original_tflite_tensor.data.int8,
-                                  GetTensorShape(&original_tflite_tensor),
-                                  dequantized_tflite_tensor.data.f);
-      } else if (original_tflite_tensor.type == kTfLiteUInt8) {
-        optimized_ops::Dequantize(op_params,
-                                  GetTensorShape(&original_tflite_tensor),
-                                  original_tflite_tensor.data.uint8,
-                                  GetTensorShape(&original_tflite_tensor),
-                                  dequantized_tflite_tensor.data.f);
-      }
-    }
-    return absl::OkStatus();
-  }
-
-  // Quantizes output tensors post-inference, leaving float tensors intact.
-  absl::Status QuantizeOutputs(TfLiteContext* context) {
-    for (auto index : output_indices_) {
-      if (quant_conversion_map_.find(index) == quant_conversion_map_.end()) {
-        continue;
-      }
-      int original_tensor_idx = quant_conversion_map_[index];
-      const TfLiteTensor& dequantized_tflite_tensor = context->tensors[index];
-      const TfLiteTensor& original_tflite_tensor =
-          context->tensors[original_tensor_idx];
-      tflite::QuantizationParams op_params;
-      op_params.zero_point = original_tflite_tensor.params.zero_point;
-      op_params.scale = original_tflite_tensor.params.scale;
-      if (original_tflite_tensor.type == kTfLiteInt8) {
-        optimized_ops::AffineQuantize(op_params,
-                                      GetTensorShape(&original_tflite_tensor),
-                                      dequantized_tflite_tensor.data.f,
-                                      GetTensorShape(&original_tflite_tensor),
-                                      original_tflite_tensor.data.int8);
-      } else if (original_tflite_tensor.type == kTfLiteUInt8) {
-        optimized_ops::AffineQuantize(op_params,
-                                      GetTensorShape(&original_tflite_tensor),
-                                      dequantized_tflite_tensor.data.f,
-                                      GetTensorShape(&original_tflite_tensor),
-                                      original_tflite_tensor.data.uint8);
-      }
-    }
-
-    return absl::OkStatus();
-  }
-
   absl::Status InitializeOpenClApi(GraphFloat32* graph,
                                    std::unique_ptr<InferenceBuilder>* builder,
                                    bool* graph_is_destroyed) {