Added transformation of global average pooling to mean.
PiperOrigin-RevId: 341696703 Change-Id: Iee80d46d4781952850510fd09e3775eb4024f226
This commit is contained in:
parent
8175ff32ab
commit
9f04e7773f
@ -408,6 +408,7 @@ cc_library(
|
|||||||
"//tensorflow/lite/delegates/gpu/common/task:tensor_linear_desc",
|
"//tensorflow/lite/delegates/gpu/common/task:tensor_linear_desc",
|
||||||
"//tensorflow/lite/delegates/gpu/common/task:texture2d_desc",
|
"//tensorflow/lite/delegates/gpu/common/task:texture2d_desc",
|
||||||
"//tensorflow/lite/delegates/gpu/common/transformations:add_bias",
|
"//tensorflow/lite/delegates/gpu/common/transformations:add_bias",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common/transformations:global_pooling_to_reduce_op",
|
||||||
"//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with",
|
"//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
|
@ -41,6 +41,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
|
#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
|
#include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
|
#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||||
@ -727,6 +728,7 @@ absl::Status InferenceContext::GetOutputTensor(ValueId id,
|
|||||||
absl::Status RunGraphTransforms(GraphFloat32* graph) {
|
absl::Status RunGraphTransforms(GraphFloat32* graph) {
|
||||||
auto merge_padding_transform = NewMergePaddingWithAdd();
|
auto merge_padding_transform = NewMergePaddingWithAdd();
|
||||||
auto add_bias_transform = NewAddBias();
|
auto add_bias_transform = NewAddBias();
|
||||||
|
auto pooling_to_reduce_op = NewGlobalPoolingToReduceOp();
|
||||||
ModelTransformer transformer(graph, /*reporter=*/nullptr);
|
ModelTransformer transformer(graph, /*reporter=*/nullptr);
|
||||||
if (!transformer.Apply("add_bias", add_bias_transform.get())) {
|
if (!transformer.Apply("add_bias", add_bias_transform.get())) {
|
||||||
return absl::InternalError("Invalid add_bias transform");
|
return absl::InternalError("Invalid add_bias transform");
|
||||||
@ -734,6 +736,10 @@ absl::Status RunGraphTransforms(GraphFloat32* graph) {
|
|||||||
if (!transformer.Apply("merge_padding", merge_padding_transform.get())) {
|
if (!transformer.Apply("merge_padding", merge_padding_transform.get())) {
|
||||||
return absl::InternalError("Invalid merge_padding transform");
|
return absl::InternalError("Invalid merge_padding transform");
|
||||||
}
|
}
|
||||||
|
if (!transformer.Apply("global pooling to mean",
|
||||||
|
pooling_to_reduce_op.get())) {
|
||||||
|
return absl::InternalError("Invalid global pooling to mean transform");
|
||||||
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,19 +120,34 @@ cc_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "model_transformations",
|
name = "global_pooling_to_reduce_op",
|
||||||
srcs = ["model_transformations.cc"],
|
srcs = ["global_pooling_to_reduce_op.cc"],
|
||||||
hdrs = ["model_transformations.h"],
|
hdrs = ["global_pooling_to_reduce_op.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":add_quant_adjustments",
|
"//tensorflow/lite/delegates/gpu/common:model",
|
||||||
":fuse_add_to_conv",
|
|
||||||
":fuse_mul_to_conv",
|
|
||||||
":make_fully_connected",
|
|
||||||
":make_padding",
|
|
||||||
":merge_padding_with",
|
|
||||||
":remove_noop",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||||
] + tf_platform_alias("custom_transformations", "//tensorflow/lite/delegates/gpu/common/"),
|
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:tensor",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/types:any",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "global_pooling_to_reduce_op_test",
|
||||||
|
srcs = ["global_pooling_to_reduce_op_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":global_pooling_to_reduce_op",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:model",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:tensor",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/types:any",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
@ -240,6 +255,22 @@ cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "model_transformations",
|
||||||
|
srcs = ["model_transformations.cc"],
|
||||||
|
hdrs = ["model_transformations.h"],
|
||||||
|
deps = [
|
||||||
|
":add_quant_adjustments",
|
||||||
|
":fuse_add_to_conv",
|
||||||
|
":fuse_mul_to_conv",
|
||||||
|
":make_fully_connected",
|
||||||
|
":make_padding",
|
||||||
|
":merge_padding_with",
|
||||||
|
":remove_noop",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||||
|
] + tf_platform_alias("custom_transformations", "//tensorflow/lite/delegates/gpu/common/"),
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "remove_noop",
|
name = "remove_noop",
|
||||||
srcs = ["remove_noop.cc"],
|
srcs = ["remove_noop.cc"],
|
||||||
|
@ -0,0 +1,78 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/types/any.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
bool IsGlobalPooling(const Pooling2DAttributes& attr, const BHWC& src_shape) {
|
||||||
|
return attr.strides.w == src_shape.w && attr.strides.h == src_shape.h &&
|
||||||
|
attr.kernel.w == src_shape.w && attr.kernel.h == src_shape.h &&
|
||||||
|
attr.padding.appended.w == 0 && attr.padding.appended.h == 0 &&
|
||||||
|
attr.padding.prepended.w == 0 && attr.padding.prepended.h == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsGlobalAveragePooling(const Pooling2DAttributes& attr,
|
||||||
|
const BHWC& src_shape) {
|
||||||
|
return attr.type == tflite::gpu::PoolingType::AVERAGE &&
|
||||||
|
attr.output_indices == false && IsGlobalPooling(attr, src_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
class GlobalPoolingToReduceOp : public NodeTransformation {
|
||||||
|
public:
|
||||||
|
TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
|
||||||
|
if (node->operation.type != ToString(OperationType::POOLING_2D)) {
|
||||||
|
return {TransformStatus::SKIPPED, ""};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inputs = graph->FindInputs(node->id);
|
||||||
|
const auto& pool_attr =
|
||||||
|
absl::any_cast<const Pooling2DAttributes&>(node->operation.attributes);
|
||||||
|
if (!IsGlobalAveragePooling(pool_attr, inputs[0]->tensor.shape)) {
|
||||||
|
return {TransformStatus::SKIPPED, ""};
|
||||||
|
}
|
||||||
|
|
||||||
|
MeanAttributes mean_attr;
|
||||||
|
mean_attr.dims = {Axis::WIDTH, Axis::HEIGHT};
|
||||||
|
|
||||||
|
node->operation.attributes = mean_attr;
|
||||||
|
node->operation.type = ToString(OperationType::MEAN);
|
||||||
|
return {TransformStatus::APPLIED,
|
||||||
|
"Replaced global average pooling with mean."};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<NodeTransformation> NewGlobalPoolingToReduceOp() {
|
||||||
|
return absl::make_unique<GlobalPoolingToReduceOp>();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
@ -0,0 +1,33 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GLOBAL_POOLING_TO_REDUCE_OP_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GLOBAL_POOLING_TO_REDUCE_OP_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
// Turns global pooling to reduce operation
|
||||||
|
// currently can convert average pooling into mean.
|
||||||
|
std::unique_ptr<NodeTransformation> NewGlobalPoolingToReduceOp();
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GLOBAL_POOLING_TO_REDUCE_OP_H_
|
@ -0,0 +1,72 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/types/any.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(MakeMeanFromGlobalAveragePooling, Smoke) {
|
||||||
|
GraphFloat32 graph;
|
||||||
|
auto input = graph.NewValue();
|
||||||
|
input->tensor.shape = BHWC(1, 4, 4, 8);
|
||||||
|
|
||||||
|
Pooling2DAttributes attr;
|
||||||
|
attr.padding.prepended = tflite::gpu::HW(0, 0);
|
||||||
|
attr.padding.appended = tflite::gpu::HW(0, 0);
|
||||||
|
attr.strides = tflite::gpu::HW(4, 4);
|
||||||
|
attr.kernel = tflite::gpu::HW(4, 4);
|
||||||
|
attr.type = tflite::gpu::PoolingType::AVERAGE;
|
||||||
|
attr.output_indices = false;
|
||||||
|
|
||||||
|
auto pool_node = graph.NewNode();
|
||||||
|
pool_node->operation.type = ToString(OperationType::POOLING_2D);
|
||||||
|
pool_node->operation.attributes = attr;
|
||||||
|
|
||||||
|
ASSERT_TRUE(graph.AddConsumer(pool_node->id, input->id).ok());
|
||||||
|
|
||||||
|
Value* output = nullptr;
|
||||||
|
ASSERT_TRUE(AddOutput(&graph, pool_node, &output).ok());
|
||||||
|
output->tensor.shape = BHWC(1, 1, 1, 8);
|
||||||
|
|
||||||
|
ASSERT_EQ(1, graph.nodes().size());
|
||||||
|
ASSERT_EQ(2, graph.values().size());
|
||||||
|
|
||||||
|
auto transformation = NewGlobalPoolingToReduceOp();
|
||||||
|
ModelTransformer transformer(&graph, nullptr);
|
||||||
|
transformer.Apply("global_average_pooling_to_mean", transformation.get());
|
||||||
|
|
||||||
|
ASSERT_EQ(1, graph.nodes().size());
|
||||||
|
ASSERT_EQ(2, graph.values().size());
|
||||||
|
ASSERT_EQ(ToString(OperationType::MEAN), graph.nodes()[0]->operation.type);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
Loading…
x
Reference in New Issue
Block a user