Added transformation of global average pooling to mean.

PiperOrigin-RevId: 341696703
Change-Id: Iee80d46d4781952850510fd09e3775eb4024f226
This commit is contained in:
Raman Sarokin 2020-11-10 14:20:56 -08:00 committed by TensorFlower Gardener
parent 8175ff32ab
commit 9f04e7773f
6 changed files with 232 additions and 11 deletions

View File

@ -408,6 +408,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common/task:tensor_linear_desc",
"//tensorflow/lite/delegates/gpu/common/task:texture2d_desc",
"//tensorflow/lite/delegates/gpu/common/transformations:add_bias",
"//tensorflow/lite/delegates/gpu/common/transformations:global_pooling_to_reduce_op",
"//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",

View File

@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
@ -727,6 +728,7 @@ absl::Status InferenceContext::GetOutputTensor(ValueId id,
absl::Status RunGraphTransforms(GraphFloat32* graph) {
auto merge_padding_transform = NewMergePaddingWithAdd();
auto add_bias_transform = NewAddBias();
auto pooling_to_reduce_op = NewGlobalPoolingToReduceOp();
ModelTransformer transformer(graph, /*reporter=*/nullptr);
if (!transformer.Apply("add_bias", add_bias_transform.get())) {
return absl::InternalError("Invalid add_bias transform");
@ -734,6 +736,10 @@ absl::Status RunGraphTransforms(GraphFloat32* graph) {
if (!transformer.Apply("merge_padding", merge_padding_transform.get())) {
return absl::InternalError("Invalid merge_padding transform");
}
if (!transformer.Apply("global pooling to mean",
pooling_to_reduce_op.get())) {
return absl::InternalError("Invalid global pooling to mean transform");
}
return absl::OkStatus();
}

View File

@ -120,19 +120,34 @@ cc_test(
)
cc_library(
name = "model_transformations",
srcs = ["model_transformations.cc"],
hdrs = ["model_transformations.h"],
name = "global_pooling_to_reduce_op",
srcs = ["global_pooling_to_reduce_op.cc"],
hdrs = ["global_pooling_to_reduce_op.h"],
deps = [
":add_quant_adjustments",
":fuse_add_to_conv",
":fuse_mul_to_conv",
":make_fully_connected",
":make_padding",
":merge_padding_with",
":remove_noop",
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:model_transformer",
] + tf_platform_alias("custom_transformations", "//tensorflow/lite/delegates/gpu/common/"),
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:tensor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:any",
],
)
cc_test(
name = "global_pooling_to_reduce_op_test",
srcs = ["global_pooling_to_reduce_op_test.cc"],
deps = [
":global_pooling_to_reduce_op",
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:model_transformer",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:tensor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:any",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
@ -240,6 +255,22 @@ cc_test(
],
)
cc_library(
name = "model_transformations",
srcs = ["model_transformations.cc"],
hdrs = ["model_transformations.h"],
deps = [
":add_quant_adjustments",
":fuse_add_to_conv",
":fuse_mul_to_conv",
":make_fully_connected",
":make_padding",
":merge_padding_with",
":remove_noop",
"//tensorflow/lite/delegates/gpu/common:model_transformer",
] + tf_platform_alias("custom_transformations", "//tensorflow/lite/delegates/gpu/common/"),
)
cc_library(
name = "remove_noop",
srcs = ["remove_noop.cc"],

View File

@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
#include <memory>
#include <string>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
namespace tflite {
namespace gpu {
namespace {
bool IsGlobalPooling(const Pooling2DAttributes& attr, const BHWC& src_shape) {
return attr.strides.w == src_shape.w && attr.strides.h == src_shape.h &&
attr.kernel.w == src_shape.w && attr.kernel.h == src_shape.h &&
attr.padding.appended.w == 0 && attr.padding.appended.h == 0 &&
attr.padding.prepended.w == 0 && attr.padding.prepended.h == 0;
}
bool IsGlobalAveragePooling(const Pooling2DAttributes& attr,
const BHWC& src_shape) {
return attr.type == tflite::gpu::PoolingType::AVERAGE &&
attr.output_indices == false && IsGlobalPooling(attr, src_shape);
}
class GlobalPoolingToReduceOp : public NodeTransformation {
public:
TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
if (node->operation.type != ToString(OperationType::POOLING_2D)) {
return {TransformStatus::SKIPPED, ""};
}
auto inputs = graph->FindInputs(node->id);
const auto& pool_attr =
absl::any_cast<const Pooling2DAttributes&>(node->operation.attributes);
if (!IsGlobalAveragePooling(pool_attr, inputs[0]->tensor.shape)) {
return {TransformStatus::SKIPPED, ""};
}
MeanAttributes mean_attr;
mean_attr.dims = {Axis::WIDTH, Axis::HEIGHT};
node->operation.attributes = mean_attr;
node->operation.type = ToString(OperationType::MEAN);
return {TransformStatus::APPLIED,
"Replaced global average pooling with mean."};
}
};
} // namespace
std::unique_ptr<NodeTransformation> NewGlobalPoolingToReduceOp() {
return absl::make_unique<GlobalPoolingToReduceOp>();
}
} // namespace gpu
} // namespace tflite

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GLOBAL_POOLING_TO_REDUCE_OP_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GLOBAL_POOLING_TO_REDUCE_OP_H_
#include <memory>
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
namespace tflite {
namespace gpu {
// Turns global pooling to reduce operation
// currently can convert average pooling into mean.
std::unique_ptr<NodeTransformation> NewGlobalPoolingToReduceOp();
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GLOBAL_POOLING_TO_REDUCE_OP_H_

View File

@ -0,0 +1,72 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
#include <memory>
#include <string>
#include <vector>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
namespace tflite {
namespace gpu {
namespace {
TEST(MakeMeanFromGlobalAveragePooling, Smoke) {
GraphFloat32 graph;
auto input = graph.NewValue();
input->tensor.shape = BHWC(1, 4, 4, 8);
Pooling2DAttributes attr;
attr.padding.prepended = tflite::gpu::HW(0, 0);
attr.padding.appended = tflite::gpu::HW(0, 0);
attr.strides = tflite::gpu::HW(4, 4);
attr.kernel = tflite::gpu::HW(4, 4);
attr.type = tflite::gpu::PoolingType::AVERAGE;
attr.output_indices = false;
auto pool_node = graph.NewNode();
pool_node->operation.type = ToString(OperationType::POOLING_2D);
pool_node->operation.attributes = attr;
ASSERT_TRUE(graph.AddConsumer(pool_node->id, input->id).ok());
Value* output = nullptr;
ASSERT_TRUE(AddOutput(&graph, pool_node, &output).ok());
output->tensor.shape = BHWC(1, 1, 1, 8);
ASSERT_EQ(1, graph.nodes().size());
ASSERT_EQ(2, graph.values().size());
auto transformation = NewGlobalPoolingToReduceOp();
ModelTransformer transformer(&graph, nullptr);
transformer.Apply("global_average_pooling_to_mean", transformation.get());
ASSERT_EQ(1, graph.nodes().size());
ASSERT_EQ(2, graph.values().size());
ASSERT_EQ(ToString(OperationType::MEAN), graph.nodes()[0]->operation.type);
}
} // namespace
} // namespace gpu
} // namespace tflite