Adds a QuantizeAndDequantize kernel to the OpenGL & OpenCL backends. This is not a TFLite op, but it will be used in future CLs to support inference on quantized models.

PiperOrigin-RevId: 301229478
Change-Id: I7379a801ba355616a6730578a01c077253494670
Sachin Joglekar 2020-03-16 13:43:27 -07:00 committed by TensorFlower Gardener
parent eb6b2831f8
commit e61ff10d8b
16 changed files with 803 additions and 0 deletions
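
Before the per-file diffs, a note on the math being added: both new kernels emulate the rounding error of TFLite's FakeQuant. A minimal C++ sketch of that reference behavior, distilled by the editor from the header comments in this commit (not code from the commit itself):

#include <algorithm>
#include <cmath>

// Clamp to the nudged [qmin, qmax] range, snap to the quantization grid,
// then map back to float. Inputs are assumed pre-nudged (see the headers).
float FakeQuantDequant(float v, float qmin, float qmax, float qscale) {
  const float clamped = std::min(qmax, std::max(qmin, v));
  const float qvalue = std::floor((clamped - qmin) / qscale + 0.5f);
  return qvalue * qscale + qmin;
}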

View File

@@ -991,6 +991,45 @@ cc_test(
],
)
cc_library(
name = "quantize_and_dequantize",
srcs = ["quantize_and_dequantize.cc"],
hdrs = ["quantize_and_dequantize.h"],
deps = [
":flt_type",
":gpu_operation",
":util",
"//tensorflow/lite/delegates/gpu/cl:cl_context",
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
"//tensorflow/lite/delegates/gpu/cl:linear_storage",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:tensor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
],
)
cc_test(
name = "quantize_and_dequantize_test",
srcs = ["quantize_and_dequantize_test.cc"],
linkstatic = True,
tags = tf_gpu_tests_tags() + [
"linux",
"local",
],
deps = [
":cl_test",
":quantize_and_dequantize",
"//tensorflow/lite/delegates/gpu/cl:tensor",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/kernels/internal:quantization_util",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "relu",
srcs = ["relu.cc"],

View File

@@ -0,0 +1,128 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h"
#include <string>
#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
namespace tflite {
namespace gpu {
namespace cl {
QuantizeAndDequantize::QuantizeAndDequantize(
const OperationDef& definition, const QuantizeAndDequantizeAttributes& attr,
CalculationsPrecision scalar_precision)
: ElementwiseOperation(definition) {
min_ = FLT(scalar_precision, attr.min);
max_ = FLT(scalar_precision, attr.max);
scale_ = FLT(scalar_precision, attr.scale);
}
QuantizeAndDequantize::QuantizeAndDequantize(QuantizeAndDequantize&& operation)
: ElementwiseOperation(std::move(operation)),
min_(std::move(operation.min_)),
max_(std::move(operation.max_)),
scale_(std::move(operation.scale_)) {}
QuantizeAndDequantize& QuantizeAndDequantize::operator=(
QuantizeAndDequantize&& operation) {
if (this != &operation) {
min_ = std::move(operation.min_);
max_ = std::move(operation.max_);
scale_ = std::move(operation.scale_);
ElementwiseOperation::operator=(std::move(operation));
}
return *this;
}
void QuantizeAndDequantize::SetLinkIndex(int index) {
min_.SetName(absl::StrCat("quantize_and_dequantize_min_", index));
max_.SetName(absl::StrCat("quantize_and_dequantize_max_", index));
scale_.SetName(absl::StrCat("quantize_and_dequantize_scale_", index));
}
std::string QuantizeAndDequantize::GetCoreCode(
const LinkingContext& context) const {
std::string scale_string, max_string, min_string;
if (!scale_.Active()) {
scale_string = "(FLT4)(1.0f)";
} else {
scale_string = absl::StrCat("(FLT4)(", scale_.GetName(), ")");
}
if (!max_.Active()) {
max_string = "(FLT4)(0.0f)";
} else {
max_string = absl::StrCat("(FLT4)(", max_.GetName(), ")");
}
if (!min_.Active()) {
min_string = "(FLT4)(0.0f)";
} else {
min_string = absl::StrCat("(FLT4)(", min_.GetName(), ")");
}
std::string clamped_value = absl::StrCat(
"min(", max_string, ", max(", min_string, ", ", context.var_name, "))");
std::string quantized_value = absl::StrCat(
"round((", clamped_value, " - ", min_string, ") / ", scale_string, ")");
std::string dequantized_value =
absl::StrCat(quantized_value, " * ", scale_string, " + ", min_string);
return absl::StrCat(context.var_name, " = ", dequantized_value, ";\n");
}
std::string QuantizeAndDequantize::GetArgsDeclaration() const {
return absl::StrCat(",\n ", min_.GetDeclaration(), ",\n ",
max_.GetDeclaration(), ",\n ",
scale_.GetDeclaration());
}
Status QuantizeAndDequantize::BindArguments(CLKernel* kernel) {
RETURN_IF_ERROR(kernel->SetBytesAuto(min_));
RETURN_IF_ERROR(kernel->SetBytesAuto(max_));
RETURN_IF_ERROR(kernel->SetBytesAuto(scale_));
return OkStatus();
}
Status CreateQuantizeAndDequantize(const CreationContext& creation_context,
const OperationDef& definition,
const QuantizeAndDequantizeAttributes& attr,
QuantizeAndDequantize* result) {
const auto scalar_precision = creation_context.device->IsPowerVR()
? CalculationsPrecision::F32
: definition.precision;
const bool is_fp16 = definition.precision == CalculationsPrecision::F16 ||
definition.precision == CalculationsPrecision::F32_F16;
if (is_fp16 && attr.scale < 0.000062f) {
// The smallest positive normal number in half-precision floating point is
// 2^-14 ≈ 0.000061. If the scale is less than that, clamp it to 0.000062f
// (just above the threshold) so it stays representable in fp16.
QuantizeAndDequantizeAttributes adjusted_attr = attr;
adjusted_attr.scale = 0.000062f;
*result =
QuantizeAndDequantize(definition, adjusted_attr, scalar_precision);
} else {
*result = QuantizeAndDequantize(definition, attr, scalar_precision);
}
result->SetLinkIndex(0);
return OkStatus();
}
} // namespace cl
} // namespace gpu
} // namespace tflite
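
As a quick check of the string assembly in GetCoreCode above, the following standalone sketch (an editor's illustration; the linked-variable name "value_0" is an assumption, while the uniform names follow SetLinkIndex(0)) prints the OpenCL statement the operation emits when all three uniforms are active:

#include <iostream>
#include <string>

#include "absl/strings/str_cat.h"

int main() {
  // Assumed linked-variable name; uniform names as set by SetLinkIndex(0).
  const std::string var = "value_0";
  const std::string min_s = "(FLT4)(quantize_and_dequantize_min_0)";
  const std::string max_s = "(FLT4)(quantize_and_dequantize_max_0)";
  const std::string scale_s = "(FLT4)(quantize_and_dequantize_scale_0)";
  const std::string clamped =
      absl::StrCat("min(", max_s, ", max(", min_s, ", ", var, "))");
  const std::string quantized =
      absl::StrCat("round((", clamped, " - ", min_s, ") / ", scale_s, ")");
  std::cout << absl::StrCat(var, " = ", quantized, " * ", scale_s, " + ",
                            min_s, ";\n");
  return 0;
}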

View File

@@ -0,0 +1,100 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_QUANTIZE_AND_DEQUANTIZE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_QUANTIZE_AND_DEQUANTIZE_H_
#include <string>
#include "tensorflow/lite/delegates/gpu/cl/cl_context.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/flt_type.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
namespace tflite {
namespace gpu {
namespace cl {
// Performs the operation: {Quantize, Dequantize} on floating-point data.
// We need this operation to emulate the error introduced by quantization
// on the GPU, which cannot represent int8 tensors.
//
// Implemented as:
// qvalue = floor((min(qmax, max(qmin, src_val)) - qmin) * (1/qscale) + 0.5)
// dq_value = qvalue * qscale + qmin
// Here, qmin, qmax & qscale refer to the quantization parameters as used in
// TensorFlow Lite's 'FakeQuant' kernel. floor(x + 0.5) rounds to the nearest
// integer, with halfway cases rounded away from zero (x is non-negative here).
//
// NOTE: We do not need to nudge min/max values in this op, since they would
// already be adjusted while generating the quantized model.
class QuantizeAndDequantize : public ElementwiseOperation {
public:
QuantizeAndDequantize() = default;
// Move only
QuantizeAndDequantize(QuantizeAndDequantize&& operation);
QuantizeAndDequantize& operator=(QuantizeAndDequantize&& operation);
QuantizeAndDequantize(const QuantizeAndDequantize&) = delete;
QuantizeAndDequantize& operator=(const QuantizeAndDequantize&) = delete;
void SetLinkIndex(int index) override;
std::string GetCoreCode(const LinkingContext& context) const override;
std::string GetArgsDeclaration() const override;
Status BindArguments(CLKernel* kernel) override;
friend Status CreateQuantizeAndDequantize(
const CreationContext& creation_context, const OperationDef& definition,
const QuantizeAndDequantizeAttributes& attr,
QuantizeAndDequantize* result);
private:
QuantizeAndDequantize(const OperationDef& definition,
const QuantizeAndDequantizeAttributes& attr,
CalculationsPrecision scalar_precision);
template <DataType T>
Status UploadParameters(const ::tflite::gpu::Tensor<Linear, T>& parameters,
CLContext* context);
FLT min_;
FLT max_;
FLT scale_;
};
Status CreateQuantizeAndDequantize(const CreationContext& creation_context,
const OperationDef& definition,
const QuantizeAndDequantizeAttributes& attr,
QuantizeAndDequantize* result);
template <DataType T>
Status QuantizeAndDequantize::UploadParameters(
const ::tflite::gpu::Tensor<Linear, T>& parameters, CLContext* context) {
LinearStorageCreateInfo create_info;
create_info.storage_type =
DeduceLinearStorageType(definition_.GetPrimaryStorageType());
create_info.data_type = definition_.GetPrimaryDataType();
// Note: no upload is performed yet; create_info is only prepared here.
return OkStatus();
}
} // namespace cl
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_QUANTIZE_AND_DEQUANTIZE_H_

View File

@@ -0,0 +1,182 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h"
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
using ::testing::FloatNear;
using ::testing::Pointwise;
namespace tflite {
namespace gpu {
namespace cl {
namespace {
TEST_F(OpenCLOperationTest, QuantAndDequant_Dim2Bits8) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 3, 2, 1);
src_tensor.data = {0.0f, 1.0f, 0.25f, 0.50f, 0.4444444f, 0.00001f};
// Unlike TFLite's FakeQuant kernel, we assume that the incoming values are
// pre-nudged, since this should be done during model conversion.
const int num_bits = 8;
const int quant_min = 0;
const int quant_max = (1 << num_bits) - 1;
QuantizeAndDequantizeAttributes attr;
NudgeQuantizationRange(/*original_min=*/0.0, /*original_max=*/1.0,
quant_min, quant_max, &attr.min, &attr.max,
&attr.scale);
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
QuantizeAndDequantize operation;
ASSERT_OK(CreateQuantizeAndDequantize(creation_context_, op_def, attr,
&operation));
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 3, 2, 1), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {0.0f, 1.0f, 0.25098f, 0.498039f,
0.443137f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, QuantAndDequant_Dim3Bits8_NegativeRange) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 3, 1, 2);
src_tensor.data = {0.0f, -0.9f, 0.25f, 0.50f, 0.4444444f, -0.00001f};
// Unlike TFLite's FakeQuant kernel, we assume that the incoming values are
// pre-nudged, since this should be done during model conversion.
const int num_bits = 8;
const int quant_min = 0;
const int quant_max = (1 << num_bits) - 1;
QuantizeAndDequantizeAttributes attr;
NudgeQuantizationRange(/*original_min=*/-0.9, /*original_max=*/0.9,
quant_min, quant_max, &attr.min, &attr.max,
&attr.scale);
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
QuantizeAndDequantize operation;
ASSERT_OK(CreateQuantizeAndDequantize(creation_context_, op_def, attr,
&operation));
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 3, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {0.0f, -0.896471f, 0.247059f,
0.501176f, 0.444706f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, QuantAndDequant_Dim3Bits16) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 3, 1, 2);
src_tensor.data = {0.0f, 1.0f, 0.25f, 0.50f, 0.4444444f, 0.00001f};
// Unlike TFLite's FakeQuant kernel, we assume that the incoming values are
// pre-nudged, since this should be done during model conversion.
const int num_bits = 16;
const int quant_min = 0;
const int quant_max = (1 << num_bits) - 1;
QuantizeAndDequantizeAttributes attr;
NudgeQuantizationRange(/*original_min=*/0.0, /*original_max=*/1.0,
quant_min, quant_max, &attr.min, &attr.max,
&attr.scale);
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
QuantizeAndDequantize operation;
ASSERT_OK(CreateQuantizeAndDequantize(creation_context_, op_def, attr,
&operation));
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 3, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {0.0f, 1.0f, 0.250004f, 0.500008f,
0.44445f, 1.5259e-05f}));
}
}
}
TEST_F(OpenCLOperationTest, QuantAndDequant_Dim2Bits16_NegativeRange) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 3, 2, 1);
src_tensor.data = {0.0f, -0.9f, 0.25f, 0.50f, 0.4444444f, -0.00001f};
// Unlike TFLite's FakeQuant kernel, we assume that the incoming values are
// pre-nudged, since this should be done during model conversion.
const int num_bits = 16;
const int quant_min = 0;
const int quant_max = (1 << num_bits) - 1;
QuantizeAndDequantizeAttributes attr;
NudgeQuantizationRange(/*original_min=*/-0.9, /*original_max=*/0.9,
quant_min, quant_max, &attr.min, &attr.max,
&attr.scale);
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
QuantizeAndDequantize operation;
ASSERT_OK(CreateQuantizeAndDequantize(creation_context_, op_def, attr,
&operation));
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 3, 2, 1), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {0.0f, -0.900014f, 0.249998f,
0.499995f, 0.444431f, 0.0f}));
}
}
}
} // namespace
} // namespace cl
} // namespace gpu
} // namespace tflite
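
The expected values in these tests can be reproduced by hand. For the 8-bit range [0, 1], NudgeQuantizationRange yields scale = 1/255; this editor's sketch verifies the 0.25 -> 0.25098 case:

#include <cmath>
#include <cstdio>

int main() {
  const float scale = 1.0f / 255.0f;      // nudged scale for [0, 1], 8 bits
  const float v = 0.25f;
  const float q = std::round(v / scale);  // 0.25 * 255 = 63.75 -> 64
  std::printf("%.6f\n", q * scale);       // 64 / 255 = 0.250980
  return 0;
}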

View File

@@ -132,6 +132,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/cl/kernels:padding",
"//tensorflow/lite/delegates/gpu/cl/kernels:pooling",
"//tensorflow/lite/delegates/gpu/cl/kernels:prelu",
"//tensorflow/lite/delegates/gpu/cl/kernels:quantize_and_dequantize",
"//tensorflow/lite/delegates/gpu/cl/kernels:relu",
"//tensorflow/lite/delegates/gpu/cl/kernels:reshape",
"//tensorflow/lite/delegates/gpu/cl/kernels:reshapex4",

View File

@@ -279,6 +279,12 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
auto attr = absl::any_cast<PReLUAttributes>(node.operation.attributes);
return SelectPReLU(attr, creation_context, op_def, gpu_op);
}
case OperationType::QUANTIZE_AND_DEQUANTIZE: {
auto attr = absl::any_cast<QuantizeAndDequantizeAttributes>(
node.operation.attributes);
return SelectQuantizeAndDequantize(attr, creation_context, op_def,
gpu_op);
}
case OperationType::RELU: {
auto attr = absl::any_cast<ReLUAttributes>(node.operation.attributes);
SelectReLU(creation_context, attr, op_def, gpu_op);

View File

@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/kernels/padding.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/pooling.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/prelu.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/quantize_and_dequantize.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/relu.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/reshape.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/reshapex4.h"
@@ -218,6 +219,17 @@ Status SelectWinograd36To4x4(
return OkStatus();
}
Status SelectQuantizeAndDequantize(const QuantizeAndDequantizeAttributes& attr,
const CreationContext& creation_context,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {
QuantizeAndDequantize operation;
RETURN_IF_ERROR(
CreateQuantizeAndDequantize(creation_context, op_def, attr, &operation));
*ptr = absl::make_unique<QuantizeAndDequantize>(std::move(operation));
return OkStatus();
}
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@@ -100,6 +100,11 @@ Status SelectWinograd36To4x4(
const ::tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases,
std::unique_ptr<GPUOperation>* ptr);
Status SelectQuantizeAndDequantize(const QuantizeAndDequantizeAttributes& attr,
const CreationContext& creation_context,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@@ -118,6 +118,8 @@ std::string ToString(enum OperationType op) {
return "pow";
case OperationType::PRELU:
return "prelu";
case OperationType::QUANTIZE_AND_DEQUANTIZE:
return "quantize_and_dequantize";
case OperationType::RELU:
return "relu";
case OperationType::RESHAPE:
@@ -183,6 +185,7 @@ OperationType OperationTypeFromString(const std::string& name) {
{"pooling_2d", OperationType::POOLING_2D},
{"pow", OperationType::POW},
{"prelu", OperationType::PRELU},
{"quantize_and_dequantize", OperationType::QUANTIZE_AND_DEQUANTIZE},
{"relu", OperationType::RELU},
{"resize", OperationType::RESIZE},
{"reshape", OperationType::RESHAPE},

View File

@@ -57,6 +57,8 @@ enum class OperationType {
POOLING_2D,
POW,
PRELU,
// Used to accurately run inference on quantized models.
QUANTIZE_AND_DEQUANTIZE,
RELU,
RESHAPE,
RESIZE,
@@ -478,6 +480,14 @@ struct SpaceToDepthAttributes {
int block_size;
};
// Simulates the error introduced by quantization: performs Quantize followed
// by Dequantize, so float values are adjusted the way quantized inference
// would adjust them.
struct QuantizeAndDequantizeAttributes {
float min = 0;
float max = 0;
float scale = 0;
};
} // namespace gpu
} // namespace tflite
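
As a usage sketch (an editor's illustration; the helper name MakeAttr8Bit is hypothetical, mirroring how the tests in this commit fill the struct), the attributes are normally derived from a nudged quantization range rather than set by hand:

#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"

tflite::gpu::QuantizeAndDequantizeAttributes MakeAttr8Bit(float rmin,
                                                          float rmax) {
  tflite::gpu::QuantizeAndDequantizeAttributes attr;
  // Nudge [rmin, rmax] so it contains an exactly representable zero, as
  // TFLite's FakeQuant does during model conversion.
  tflite::NudgeQuantizationRange(rmin, rmax, /*quant_min=*/0,
                                 /*quant_max=*/255, &attr.min, &attr.max,
                                 &attr.scale);
  return attr;
}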

View File

@@ -451,6 +451,38 @@ cc_test(
],
)
cc_library(
name = "quantize_and_dequantize",
srcs = ["quantize_and_dequantize.cc"],
hdrs = ["quantize_and_dequantize.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:convert",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/gl:node_shader",
"@com_google_absl//absl/memory",
],
)
cc_test(
name = "quantize_and_dequantize_test",
srcs = ["quantize_and_dequantize_test.cc"],
tags = tf_gpu_tests_tags() + [
"notap",
"tflite_not_portable_ios",
],
deps = [
":quantize_and_dequantize",
":test_util",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/kernels/internal:quantization_util",
"@com_google_googletest//:gtest",
],
)
cc_library(
name = "relu",
srcs = ["relu.cc"],
@@ -699,6 +731,7 @@ TFLITE_GPU_BINARY_RELEASE_OPERATORS = [
"pad",
"pooling",
"prelu",
"quantize_and_dequantize",
"relu",
"mean",
"reshape",

View File

@@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.h"
#include <memory>
#include <string>
#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
namespace tflite {
namespace gpu {
namespace gl {
namespace {
class QuantizeAndDequantize : public NodeShader {
public:
Status GenerateCode(const GenerationContext& ctx,
GeneratedCode* generated_code) const final {
std::string code;
// Constants
code += "vec4 scale = vec4($quant_scale$);";
code += "vec4 min_bound = vec4($quant_min$);";
code += "vec4 max_bound = vec4($quant_max$);";
// Quantize
code += "value_0 = clamp(value_0, min_bound, max_bound);";
code += "value_0 = (value_0 - min_bound) / scale;";
code += "value_0 = floor(value_0 + vec4(0.5));";
// Dequantize
code += "value_0 = value_0 * scale + min_bound;";
auto attr = absl::any_cast<const QuantizeAndDequantizeAttributes&>(
ctx.node->operation.attributes);
*generated_code = {
/*parameters=*/{{"quant_min", attr.min},
{"quant_max", attr.max},
{"quant_scale", attr.scale}},
/*objects=*/{},
/*shared_variables=*/{},
/*workload=*/uint3(),
/*workgroup=*/uint3(),
/*source_code=*/code,
/*input=*/IOStructure::AUTO,
/*output=*/IOStructure::AUTO,
};
return OkStatus();
}
};
} // namespace
std::unique_ptr<NodeShader> NewQuantizeAndDequantizeNodeShader() {
return absl::make_unique<QuantizeAndDequantize>();
}
} // namespace gl
} // namespace gpu
} // namespace tflite
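
For reference, once the $quant_*$ parameters are bound, the generated snippet is equivalent to the following (an editor's reconstruction for the nudged 8-bit [0, 1] attributes used in the tests; shown as a C++ raw string to keep one language across these notes):

#include <string>

// Approximate post-substitution shader body for
// attr = {min = 0.0f, max = 1.0f, scale = 1.0f / 255.0f}.
const std::string kExpandedBody = R"(
vec4 scale = vec4(0.003922);
vec4 min_bound = vec4(0.0);
vec4 max_bound = vec4(1.0);
value_0 = clamp(value_0, min_bound, max_bound);
value_0 = (value_0 - min_bound) / scale;
value_0 = floor(value_0 + vec4(0.5));
value_0 = value_0 * scale + min_bound;
)";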

View File

@@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_QUANTIZE_AND_DEQUANTIZE_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_QUANTIZE_AND_DEQUANTIZE_H_
#include <memory>
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
namespace tflite {
namespace gpu {
namespace gl {
// Performs the operation: {Quantize, Dequantize} on floating-point data.
// We need this operation to emulate the error introduced by quantization
// on the GPU, which cannot represent int8 tensors.
//
// Implemented as:
// qvalue = floor((min(qmax, max(qmin, src_val)) - qmin) * (1/qscale) + 0.5)
// dq_value = qvalue * qscale + qmin
// Here, qmin, qmax & qscale refer to the quantization parameters as used in
// TensorFlow Lite's 'FakeQuant' kernel. floor(x + 0.5) rounds to the nearest
// integer, with halfway cases rounded away from zero (x is non-negative here).
//
// NOTE: We do not need to nudge min/max values in this op, since they would
// already be adjusted while generating the quantized model.
std::unique_ptr<NodeShader> NewQuantizeAndDequantizeNodeShader();
} // namespace gl
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_QUANTIZE_AND_DEQUANTIZE_H_

View File

@@ -0,0 +1,159 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/test_util.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
using ::testing::FloatNear;
using ::testing::Pointwise;
namespace tflite {
namespace gpu {
namespace gl {
namespace {
TEST(QuantizeAndDequantizeTest, Dim2Bits8) {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 3, 2, 1);
// Unlike TFLite's FakeQuant kernel, we assume that the incoming values are
// pre-nudged, since this should be done during model conversion.
const int num_bits = 8;
const int quant_min = 0;
const int quant_max = (1 << num_bits) - 1;
QuantizeAndDequantizeAttributes attr;
NudgeQuantizationRange(/*original_min=*/0.0, /*original_max=*/1.0,
quant_min, quant_max, &attr.min, &attr.max,
&attr.scale);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 3, 2, 1);
SingleOpModel model({ToString(OperationType::QUANTIZE_AND_DEQUANTIZE), attr},
{input}, {output});
ASSERT_TRUE(
model.PopulateTensor(0, {0.0, 1.0, 0.25, 0.50, 0.4444444, 0.00001}));
ASSERT_OK(model.Invoke(*NewQuantizeAndDequantizeNodeShader()));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6),
{0.0f, 1.0f, 0.25098f, 0.498039f, 0.443137f, 0.0f}));
}
TEST(QuantizeAndDequantizeTest, Dim3Bits8_NegativeRange) {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 3, 1, 2);
// Unlike TFLite's FakeQuant kernel, we assume that the incoming values are
// pre-nudged, since this should be done during model conversion.
const int num_bits = 8;
const int quant_min = 0;
const int quant_max = (1 << num_bits) - 1;
QuantizeAndDequantizeAttributes attr;
NudgeQuantizationRange(/*original_min=*/-0.9, /*original_max=*/0.9,
quant_min, quant_max, &attr.min, &attr.max,
&attr.scale);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 3, 1, 2);
SingleOpModel model({ToString(OperationType::QUANTIZE_AND_DEQUANTIZE), attr},
{input}, {output});
ASSERT_TRUE(
model.PopulateTensor(0, {0.0, -0.9, 0.25, 0.50, 0.4444444, -0.00001}));
ASSERT_OK(model.Invoke(*NewQuantizeAndDequantizeNodeShader()));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0f, -0.896471f, 0.247059f,
0.501176f, 0.444706f, 0.0f}));
}
TEST(QuantizeAndDequantizeTest, Dim3Bits16) {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 3, 1, 2);
// Unlike TFLite's FakeQuant kernel, we assume that the incoming values are
// pre-nudged, since this should be done during model conversion.
const int num_bits = 16;
const int quant_min = 0;
const int quant_max = (1 << num_bits) - 1;
QuantizeAndDequantizeAttributes attr;
NudgeQuantizationRange(/*original_min=*/0.0, /*original_max=*/1.0,
quant_min, quant_max, &attr.min, &attr.max,
&attr.scale);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 3, 1, 2);
SingleOpModel model({ToString(OperationType::QUANTIZE_AND_DEQUANTIZE), attr},
{input}, {output});
ASSERT_TRUE(
model.PopulateTensor(0, {0.0, 1.0, 0.25, 0.50, 0.4444444, 0.00001}));
ASSERT_OK(model.Invoke(*NewQuantizeAndDequantizeNodeShader()));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0f, 1.0f, 0.250004f, 0.500008f,
0.44445f, 1.5259e-05f}));
}
TEST(QuantizeAndDequantizeTest, Dim2Bits16_NegativeRange) {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 3, 2, 1);
// Unlike TFLite's FakeQuant kernel, we assume that the incoming values are
// pre-nudged, since this should be done during model conversion.
const int num_bits = 16;
const int quant_min = 0;
const int quant_max = (1 << num_bits) - 1;
QuantizeAndDequantizeAttributes attr;
NudgeQuantizationRange(/*original_min=*/-0.9, /*original_max=*/0.9,
quant_min, quant_max, &attr.min, &attr.max,
&attr.scale);
TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 1;
output.shape = BHWC(1, 3, 2, 1);
SingleOpModel model({ToString(OperationType::QUANTIZE_AND_DEQUANTIZE), attr},
{input}, {output});
ASSERT_TRUE(
model.PopulateTensor(0, {0.0, -0.9, 0.25, 0.50, 0.4444444, -0.00001}));
ASSERT_OK(model.Invoke(*NewQuantizeAndDequantizeNodeShader()));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0f, -0.900014f, 0.249998f,
0.499995f, 0.444431f, 0.0f}));
}
} // namespace
} // namespace gl
} // namespace gpu
} // namespace tflite

View File

@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/gl/kernels/pad.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/pooling.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/prelu.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/relu.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/reshape.h"
#include "tensorflow/lite/delegates/gpu/gl/kernels/resize.h"
@@ -85,6 +86,8 @@ class Registry : public NodeShader {
insert_op(Type::PAD, NewPadNodeShader);
insert_op(Type::POOLING_2D, NewPoolingNodeShader);
insert_op(Type::PRELU, NewPReLUNodeShader);
insert_op(Type::QUANTIZE_AND_DEQUANTIZE,
NewQuantizeAndDequantizeNodeShader);
insert_op(Type::RELU, NewReLUNodeShader);
insert_op(Type::RESIZE, NewResizeNodeShader);
insert_op(Type::RESHAPE, NewReshapeNodeShader);

View File

@@ -305,6 +305,7 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
case OperationType::BATCH_TO_SPACE:
case OperationType::CONST:
case OperationType::LSTM:
case OperationType::QUANTIZE_AND_DEQUANTIZE:
case OperationType::SPACE_TO_BATCH:
case OperationType::TRANSPOSE:
case OperationType::UNKNOWN: