Adds GraphTransformation to add QuantizeAndDequantize nodes in GPU graph
PiperOrigin-RevId: 302038856
Change-Id: I009684ea5b611a3bfc05c88b4fd8a40c570cfd86
parent bb97495f77
commit d46aa971be
tensorflow/lite/delegates/gpu/common/BUILD
@@ -92,6 +92,7 @@ cc_library(
         ":tensor",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/types:any",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
tensorflow/lite/delegates/gpu/common/model.h
@@ -24,6 +24,7 @@ limitations under the License.
 
 #include "absl/memory/memory.h"
 #include "absl/types/any.h"
+#include "absl/types/optional.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
@@ -39,6 +40,13 @@ using ValueId = uint32_t;
 
 using NodeId = uint32_t;
 
+// Used to emulate quantized behavior.
+struct QuantizationParams {
+  float min = 0;
+  float max = 0;
+  float scale = 0;
+};
+
 // Connects tensor's producer and operation that depends on this tensor.
 template <typename TensorT>
 struct Value {
@@ -47,6 +55,8 @@ struct Value {
   const ValueId id;
 
   TensorType tensor;
+
+  absl::optional<QuantizationParams> quant_params;
 };
 
 struct Operation {
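For context, the quant_params field added above is the hook that the new transformation keys on. The following minimal sketch (hypothetical code, not part of this commit) shows how a value might be annotated; it assumes a GraphFloat32 named graph and the common 8-bit convention scale = (max - min) / 255, which is consistent with the values used in the new test:

    // Hypothetical example: annotate a graph value with quantization
    // parameters so that AddQuantAdjustments inserts a QuantizeAndDequantize
    // node between this value's producer and its consumers.
    Value<TensorRef<BHWC>>* value = graph.NewValue();
    value->tensor.shape = BHWC(1, 4, 4, 8);
    QuantizationParams params;
    params.min = 0.0f;
    params.max = 1.0f;
    params.scale = (params.max - params.min) / 255.0f;  // assumed 8-bit scale
    value->quant_params = params;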
tensorflow/lite/delegates/gpu/common/transformations/BUILD
@@ -19,6 +19,37 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "add_quant_adjustments",
+    srcs = ["add_quant_adjustments.cc"],
+    hdrs = ["add_quant_adjustments.h"],
+    deps = [
+        "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:model_transformer",
+        "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:any",
+    ],
+)
+
+cc_test(
+    name = "add_quant_adjustments_test",
+    srcs = ["add_quant_adjustments_test.cc"],
+    deps = [
+        ":add_quant_adjustments",
+        "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:model_transformer",
+        "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:shape",
+        "@com_google_absl//absl/types:any",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 cc_library(
     name = "fuse_add_to_conv",
     srcs = ["fuse_add_to_conv.cc"],
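Assuming a standard TensorFlow source checkout, the new test target declared above can be run with the usual Bazel invocation, for example:

    bazel test //tensorflow/lite/delegates/gpu/common/transformations:add_quant_adjustments_test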
tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.cc (new file)
@@ -0,0 +1,110 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h"
+
+#include <string>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/any.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+namespace tflite {
+namespace gpu {
+
+class AddQuantAdjustments : public NodeTransformation {
+ public:
+  TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
+    if (node->operation.type ==
+        ToString(OperationType::QUANTIZE_AND_DEQUANTIZE)) {
+      return {TransformStatus::SKIPPED, ""};
+    }
+
+    bool transform_applied = false;
+    auto node_outputs = graph->FindOutputs(node->id);
+    for (auto output_value : node_outputs) {
+      // Skip if quantization doesn't apply.
+      if (!output_value->quant_params) continue;
+      auto consumers = graph->FindConsumers(output_value->id);
+      // No need to do anything if this isn't consumed by another node.
+      if (consumers.empty()) {
+        continue;
+      }
+
+      // Add a new QuantizeAndDequantize node.
+      auto* quant_and_dequant_node = graph->NewNode();
+      quant_and_dequant_node->operation.type =
+          ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
+      QuantizeAndDequantizeAttributes attr;
+      attr.min = output_value->quant_params.value().min;
+      attr.max = output_value->quant_params.value().max;
+      attr.scale = output_value->quant_params.value().scale;
+      quant_and_dequant_node->operation.attributes = attr;
+
+      // Add one output Value for the new node.
+      // The tensor information should remain the same.
+      Value<TensorRef<BHWC>>* adjusted_value = graph->NewValue();
+      adjusted_value->tensor = output_value->tensor;
+      Status status =
+          graph->SetProducer(quant_and_dequant_node->id, adjusted_value->id);
+      if (!status.ok()) {
+        return {TransformStatus::INVALID,
+                "Could not create QuantizeAndDequantize node."};
+      }
+
+      // Replace output_value with adjusted_value on all consumers.
+      for (auto& consumer : consumers) {
+        status = graph->ReplaceInput(consumer->id, output_value->id,
+                                     adjusted_value->id);
+        if (!status.ok()) {
+          return {TransformStatus::INVALID,
+                  absl::StrCat(
+                      "Failed to associate quant-adjusted value for consumer: ",
+                      status.message())};
+        }
+      }
+
+      // Add QuantizeAndDequantize node as a consumer of output_value.
+      status = graph->AddConsumer(quant_and_dequant_node->id, output_value->id);
+      if (!status.ok()) {
+        return {TransformStatus::INVALID,
+                absl::StrCat(
+                    "Could not associate output to QuantizeAndDequantize: ",
+                    status.message())};
+      }
+
+      // Remove quant params on output_value, to make the transformation
+      // idempotent.
+      output_value->quant_params.reset();
+      transform_applied = true;
+    }
+
+    if (transform_applied) {
+      return {TransformStatus::APPLIED, ""};
+    }
+    return {TransformStatus::SKIPPED, ""};
+  }
+};
+
+std::unique_ptr<NodeTransformation> NewAddQuantAdjustments() {
+  return absl::make_unique<AddQuantAdjustments>();
+}
+
+}  // namespace gpu
+}  // namespace tflite
tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h (new file)
@@ -0,0 +1,45 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_ADD_QUANT_ADJUSTMENTS_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_ADD_QUANT_ADJUSTMENTS_H_
+
+#include <memory>
+
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
+
+namespace tflite {
+namespace gpu {
+
+// This pass is used to support inference on quantized models with the GPU
+// delegate.
+//
+// When delegating quantized models, we still run floating-point inference on
+// the GPU under the hood. This is done by dequantizing inputs (at runtime) &
+// constants (during delegation).
+// However, intermediate tensors can still deviate from the original quantized
+// inference, since activations may not follow the attributes set by the
+// original quantization parameters.
+// To prevent this, we add "QuantizeAndDequantize" nodes for each node-output
+// that was originally fixed-point:
+// op1 -> op2
+// becomes
+// op1 -> QuantizeAndDequantize -> op2
+std::unique_ptr<NodeTransformation> NewAddQuantAdjustments();
+
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_ADD_QUANT_ADJUSTMENTS_H_
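The header above declares the factory function; the new test below exercises it through ModelTransformer. A condensed usage sketch (assuming an already-built GraphFloat32 named graph, mirroring the calls made in add_quant_adjustments_test.cc):

    // Apply the transformation to every node in the graph, as the test does.
    auto transformation = NewAddQuantAdjustments();
    ModelTransformer transformer(&graph, nullptr);
    transformer.Apply("add_quant_adjustments", transformation.get());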
tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc (new file)
@@ -0,0 +1,166 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/types/any.h"
+#include "absl/types/optional.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+
+namespace tflite {
+namespace gpu {
+namespace {
+
+void AddQuantParams(absl::optional<QuantizationParams>* params, float min,
+                    float max, float scale) {
+  params->emplace();
+  params->value().min = min;
+  params->value().max = max;
+  params->value().scale = scale;
+}
+
+// Scenario:
+// -> Add ->
+//
+// Since there is only one node output with no consumers, no new node should be
+// added.
+TEST(AddQuantAdjustments, OneNode) {
+  GraphFloat32 graph;
+  auto input = graph.NewValue();
+  input->tensor.shape = BHWC(1, 4, 4, 8);
+  AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/1.0,
+                 /*scale=*/0.004);
+
+  Tensor<Linear, DataType::FLOAT32> add_tensor;
+  add_tensor.shape = Linear(8);
+  add_tensor.data.resize(8);
+  AddAttributes add_attr;
+  add_attr.param = add_tensor;
+  auto add_node = graph.NewNode();
+  add_node->operation.type = ToString(OperationType::ADD);
+  add_node->operation.attributes = add_attr;
+
+  ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());
+
+  Value<TensorRef<BHWC>>* output;
+  AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/2.0,
+                 /*scale=*/0.008);
+  ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
+  output->tensor.shape = BHWC(1, 4, 4, 8);
+
+  ASSERT_EQ(1, graph.nodes().size());
+  ASSERT_EQ(2, graph.values().size());
+
+  auto transformation = NewAddQuantAdjustments();
+  ModelTransformer transformer(&graph, nullptr);
+  transformer.Apply("add_quant_adjustments", transformation.get());
+
+  EXPECT_EQ(1, graph.nodes().size());
+  EXPECT_EQ(2, graph.values().size());
+}
+
+// Scenario:
+// -> Add -> QuantizeAndDequantize -> Add ->
+//     |                               ^
+//     |                               |
+//     ---------------------------------
+//
+// A new QuantizeAndDequantize should only be added after the left/first 'Add'
+// op, and it should connect to both its consumers.
+TEST(AddQuantAdjustments, GeneralCase) {
+  GraphFloat32 graph;
+  auto input = graph.NewValue();
+  input->tensor.shape = BHWC(1, 4, 4, 8);
+  AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/1.0,
+                 /*scale=*/0.004);
+
+  // First Add.
+  Tensor<Linear, DataType::FLOAT32> add_tensor;
+  add_tensor.shape = Linear(8);
+  add_tensor.data.resize(8);
+  AddAttributes add_attr;
+  add_attr.param = add_tensor;
+  auto add1_node = graph.NewNode();
+  add1_node->operation.type = ToString(OperationType::ADD);
+  add1_node->operation.attributes = add_attr;
+  // QuantizeAndDequantize.
+  QuantizeAndDequantizeAttributes quant_attr;
+  quant_attr.min = -1.0;
+  quant_attr.max = 1.0;
+  quant_attr.scale = 0.008;
+  auto quant_node = graph.NewNode();
+  quant_node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
+  quant_node->operation.attributes = quant_attr;
+  // Second Add.
+  auto add2_node = graph.NewNode();
+  add2_node->operation.type = ToString(OperationType::ADD);
+
+  // Connections.
+  ASSERT_TRUE(graph.AddConsumer(add1_node->id, input->id).ok());
+  Value<TensorRef<BHWC>>* link1;
+  ASSERT_TRUE(ConnectTwoNodes(&graph, add1_node, quant_node, &link1).ok());
+  AddQuantParams(&link1->quant_params, /*min=*/0.0, /*max=*/2.0,
+                 /*scale=*/0.008);
+  link1->tensor.shape = BHWC(1, 4, 4, 8);
+  ASSERT_TRUE(graph.AddConsumer(add2_node->id, link1->id).ok());
+  Value<TensorRef<BHWC>>* link2;
+  ASSERT_TRUE(ConnectTwoNodes(&graph, quant_node, add2_node, &link2).ok());
+  AddQuantParams(&link2->quant_params, /*min=*/-1.0, /*max=*/1.0,
+                 /*scale=*/0.008);
+  link2->tensor.shape = BHWC(1, 4, 4, 8);
+  Value<TensorRef<BHWC>>* output;
+  ASSERT_TRUE(AddOutput(&graph, add2_node, &output).ok());
+  AddQuantParams(&output->quant_params, /*min=*/-1.0, /*max=*/1.0,
+                 /*scale=*/0.008);
+  output->tensor.shape = BHWC(1, 4, 4, 8);
+
+  ASSERT_EQ(3, graph.nodes().size());
+  ASSERT_EQ(4, graph.values().size());
+
+  auto transformation = NewAddQuantAdjustments();
+  ModelTransformer transformer(&graph, nullptr);
+  transformer.Apply("add_quant_adjustments", transformation.get());
+
+  EXPECT_EQ(4, graph.nodes().size());
+  EXPECT_EQ(5, graph.values().size());
+  EXPECT_EQ(ToString(OperationType::ADD), graph.nodes()[0]->operation.type);
+  EXPECT_EQ(ToString(OperationType::QUANTIZE_AND_DEQUANTIZE),
+            graph.nodes()[1]->operation.type);
+  EXPECT_EQ(ToString(OperationType::ADD), graph.nodes()[2]->operation.type);
+  EXPECT_EQ(ToString(OperationType::QUANTIZE_AND_DEQUANTIZE),
+            graph.nodes()[3]->operation.type);
+  auto new_quant_attr = absl::any_cast<QuantizeAndDequantizeAttributes>(
+      graph.nodes()[3]->operation.attributes);
+  EXPECT_EQ(0.0, new_quant_attr.min);
+  EXPECT_EQ(2.0, new_quant_attr.max);
+  const auto& new_quant_consumers = graph.FindConsumers(graph.values()[4]->id);
+  EXPECT_EQ(2, new_quant_consumers.size());
+  EXPECT_EQ(quant_node, new_quant_consumers[0]);
+  EXPECT_EQ(add2_node, new_quant_consumers[1]);
+
+  // Transformation should be idempotent.
+  transformer.Apply("add_quant_adjustments", transformation.get());
+  EXPECT_EQ(4, graph.nodes().size());
+  EXPECT_EQ(5, graph.values().size());
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace tflite