Added transformation of global average pooling to mean.
PiperOrigin-RevId: 341696703 Change-Id: Iee80d46d4781952850510fd09e3775eb4024f226
This commit is contained in:
parent
8175ff32ab
commit
9f04e7773f
@ -408,6 +408,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common/task:tensor_linear_desc",
|
||||
"//tensorflow/lite/delegates/gpu/common/task:texture2d_desc",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:add_bias",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:global_pooling_to_reduce_op",
|
||||
"//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
|
@ -41,6 +41,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
@ -727,6 +728,7 @@ absl::Status InferenceContext::GetOutputTensor(ValueId id,
|
||||
absl::Status RunGraphTransforms(GraphFloat32* graph) {
|
||||
auto merge_padding_transform = NewMergePaddingWithAdd();
|
||||
auto add_bias_transform = NewAddBias();
|
||||
auto pooling_to_reduce_op = NewGlobalPoolingToReduceOp();
|
||||
ModelTransformer transformer(graph, /*reporter=*/nullptr);
|
||||
if (!transformer.Apply("add_bias", add_bias_transform.get())) {
|
||||
return absl::InternalError("Invalid add_bias transform");
|
||||
@ -734,6 +736,10 @@ absl::Status RunGraphTransforms(GraphFloat32* graph) {
|
||||
if (!transformer.Apply("merge_padding", merge_padding_transform.get())) {
|
||||
return absl::InternalError("Invalid merge_padding transform");
|
||||
}
|
||||
if (!transformer.Apply("global pooling to mean",
|
||||
pooling_to_reduce_op.get())) {
|
||||
return absl::InternalError("Invalid global pooling to mean transform");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -120,19 +120,34 @@ cc_test(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "model_transformations",
|
||||
srcs = ["model_transformations.cc"],
|
||||
hdrs = ["model_transformations.h"],
|
||||
name = "global_pooling_to_reduce_op",
|
||||
srcs = ["global_pooling_to_reduce_op.cc"],
|
||||
hdrs = ["global_pooling_to_reduce_op.h"],
|
||||
deps = [
|
||||
":add_quant_adjustments",
|
||||
":fuse_add_to_conv",
|
||||
":fuse_mul_to_conv",
|
||||
":make_fully_connected",
|
||||
":make_padding",
|
||||
":merge_padding_with",
|
||||
":remove_noop",
|
||||
"//tensorflow/lite/delegates/gpu/common:model",
|
||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||
] + tf_platform_alias("custom_transformations", "//tensorflow/lite/delegates/gpu/common/"),
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
"//tensorflow/lite/delegates/gpu/common:tensor",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/types:any",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "global_pooling_to_reduce_op_test",
|
||||
srcs = ["global_pooling_to_reduce_op_test.cc"],
|
||||
deps = [
|
||||
":global_pooling_to_reduce_op",
|
||||
"//tensorflow/lite/delegates/gpu/common:model",
|
||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
"//tensorflow/lite/delegates/gpu/common:tensor",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/types:any",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -240,6 +255,22 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "model_transformations",
|
||||
srcs = ["model_transformations.cc"],
|
||||
hdrs = ["model_transformations.h"],
|
||||
deps = [
|
||||
":add_quant_adjustments",
|
||||
":fuse_add_to_conv",
|
||||
":fuse_mul_to_conv",
|
||||
":make_fully_connected",
|
||||
":make_padding",
|
||||
":merge_padding_with",
|
||||
":remove_noop",
|
||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||
] + tf_platform_alias("custom_transformations", "//tensorflow/lite/delegates/gpu/common/"),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "remove_noop",
|
||||
srcs = ["remove_noop.cc"],
|
||||
|
@ -0,0 +1,78 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/types/any.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace {
|
||||
|
||||
bool IsGlobalPooling(const Pooling2DAttributes& attr, const BHWC& src_shape) {
|
||||
return attr.strides.w == src_shape.w && attr.strides.h == src_shape.h &&
|
||||
attr.kernel.w == src_shape.w && attr.kernel.h == src_shape.h &&
|
||||
attr.padding.appended.w == 0 && attr.padding.appended.h == 0 &&
|
||||
attr.padding.prepended.w == 0 && attr.padding.prepended.h == 0;
|
||||
}
|
||||
|
||||
bool IsGlobalAveragePooling(const Pooling2DAttributes& attr,
|
||||
const BHWC& src_shape) {
|
||||
return attr.type == tflite::gpu::PoolingType::AVERAGE &&
|
||||
attr.output_indices == false && IsGlobalPooling(attr, src_shape);
|
||||
}
|
||||
|
||||
class GlobalPoolingToReduceOp : public NodeTransformation {
|
||||
public:
|
||||
TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
|
||||
if (node->operation.type != ToString(OperationType::POOLING_2D)) {
|
||||
return {TransformStatus::SKIPPED, ""};
|
||||
}
|
||||
|
||||
auto inputs = graph->FindInputs(node->id);
|
||||
const auto& pool_attr =
|
||||
absl::any_cast<const Pooling2DAttributes&>(node->operation.attributes);
|
||||
if (!IsGlobalAveragePooling(pool_attr, inputs[0]->tensor.shape)) {
|
||||
return {TransformStatus::SKIPPED, ""};
|
||||
}
|
||||
|
||||
MeanAttributes mean_attr;
|
||||
mean_attr.dims = {Axis::WIDTH, Axis::HEIGHT};
|
||||
|
||||
node->operation.attributes = mean_attr;
|
||||
node->operation.type = ToString(OperationType::MEAN);
|
||||
return {TransformStatus::APPLIED,
|
||||
"Replaced global average pooling with mean."};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<NodeTransformation> NewGlobalPoolingToReduceOp() {
|
||||
return absl::make_unique<GlobalPoolingToReduceOp>();
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GLOBAL_POOLING_TO_REDUCE_OP_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GLOBAL_POOLING_TO_REDUCE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
||||
// Turns global pooling to reduce operation
|
||||
// currently can convert average pooling into mean.
|
||||
std::unique_ptr<NodeTransformation> NewGlobalPoolingToReduceOp();
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GLOBAL_POOLING_TO_REDUCE_OP_H_
|
@ -0,0 +1,72 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/types/any.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace {
|
||||
|
||||
TEST(MakeMeanFromGlobalAveragePooling, Smoke) {
|
||||
GraphFloat32 graph;
|
||||
auto input = graph.NewValue();
|
||||
input->tensor.shape = BHWC(1, 4, 4, 8);
|
||||
|
||||
Pooling2DAttributes attr;
|
||||
attr.padding.prepended = tflite::gpu::HW(0, 0);
|
||||
attr.padding.appended = tflite::gpu::HW(0, 0);
|
||||
attr.strides = tflite::gpu::HW(4, 4);
|
||||
attr.kernel = tflite::gpu::HW(4, 4);
|
||||
attr.type = tflite::gpu::PoolingType::AVERAGE;
|
||||
attr.output_indices = false;
|
||||
|
||||
auto pool_node = graph.NewNode();
|
||||
pool_node->operation.type = ToString(OperationType::POOLING_2D);
|
||||
pool_node->operation.attributes = attr;
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(pool_node->id, input->id).ok());
|
||||
|
||||
Value* output = nullptr;
|
||||
ASSERT_TRUE(AddOutput(&graph, pool_node, &output).ok());
|
||||
output->tensor.shape = BHWC(1, 1, 1, 8);
|
||||
|
||||
ASSERT_EQ(1, graph.nodes().size());
|
||||
ASSERT_EQ(2, graph.values().size());
|
||||
|
||||
auto transformation = NewGlobalPoolingToReduceOp();
|
||||
ModelTransformer transformer(&graph, nullptr);
|
||||
transformer.Apply("global_average_pooling_to_mean", transformation.get());
|
||||
|
||||
ASSERT_EQ(1, graph.nodes().size());
|
||||
ASSERT_EQ(2, graph.values().size());
|
||||
ASSERT_EQ(ToString(OperationType::MEAN), graph.nodes()[0]->operation.type);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
Loading…
x
Reference in New Issue
Block a user