diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 96a28e3d484..a8fcbf1570c 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -408,6 +408,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common/task:tensor_linear_desc", "//tensorflow/lite/delegates/gpu/common/task:texture2d_desc", "//tensorflow/lite/delegates/gpu/common/transformations:add_bias", + "//tensorflow/lite/delegates/gpu/common/transformations:global_pooling_to_reduce_op", "//tensorflow/lite/delegates/gpu/common/transformations:merge_padding_with", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc index 0b16ff247c8..cb26dc24426 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" #include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h" +#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h" #include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/common/util.h" @@ -727,6 +728,7 @@ absl::Status InferenceContext::GetOutputTensor(ValueId id, absl::Status RunGraphTransforms(GraphFloat32* graph) { auto merge_padding_transform = NewMergePaddingWithAdd(); auto add_bias_transform = NewAddBias(); + auto pooling_to_reduce_op = NewGlobalPoolingToReduceOp(); ModelTransformer transformer(graph, /*reporter=*/nullptr); if (!transformer.Apply("add_bias", add_bias_transform.get())) { return absl::InternalError("Invalid add_bias transform"); @@ -734,6 +736,10 @@ absl::Status RunGraphTransforms(GraphFloat32* graph) { if (!transformer.Apply("merge_padding", merge_padding_transform.get())) { return absl::InternalError("Invalid merge_padding transform"); } + if (!transformer.Apply("global pooling to mean", + pooling_to_reduce_op.get())) { + return absl::InternalError("Invalid global pooling to mean transform"); + } return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/common/transformations/BUILD b/tensorflow/lite/delegates/gpu/common/transformations/BUILD index 6cb358bcc93..b9e332c5476 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/BUILD +++ b/tensorflow/lite/delegates/gpu/common/transformations/BUILD @@ -120,19 +120,34 @@ cc_test( ) cc_library( - name = "model_transformations", - srcs = ["model_transformations.cc"], - hdrs = ["model_transformations.h"], + name = "global_pooling_to_reduce_op", + srcs = ["global_pooling_to_reduce_op.cc"], + hdrs = ["global_pooling_to_reduce_op.h"], deps = [ - ":add_quant_adjustments", - ":fuse_add_to_conv", - ":fuse_mul_to_conv", - ":make_fully_connected", - ":make_padding", - ":merge_padding_with", - ":remove_noop", + "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_transformer", - ] + tf_platform_alias("custom_transformations", "//tensorflow/lite/delegates/gpu/common/"), + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:any", + ], +) + +cc_test( + name = "global_pooling_to_reduce_op_test", + srcs = ["global_pooling_to_reduce_op_test.cc"], + deps = [ + ":global_pooling_to_reduce_op", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:any", + "@com_google_googletest//:gtest_main", + ], ) cc_library( @@ -240,6 +255,22 @@ cc_test( ], ) +cc_library( + name = "model_transformations", + srcs = ["model_transformations.cc"], + hdrs = ["model_transformations.h"], + deps = [ + ":add_quant_adjustments", + ":fuse_add_to_conv", + ":fuse_mul_to_conv", + ":make_fully_connected", + ":make_padding", + ":merge_padding_with", + ":remove_noop", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + ] + tf_platform_alias("custom_transformations", "//tensorflow/lite/delegates/gpu/common/"), +) + cc_library( name = "remove_noop", srcs = ["remove_noop.cc"], diff --git a/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.cc b/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.cc new file mode 100644 index 00000000000..377fe752001 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.cc @@ -0,0 +1,78 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" + +namespace tflite { +namespace gpu { +namespace { + +bool IsGlobalPooling(const Pooling2DAttributes& attr, const BHWC& src_shape) { + return attr.strides.w == src_shape.w && attr.strides.h == src_shape.h && + attr.kernel.w == src_shape.w && attr.kernel.h == src_shape.h && + attr.padding.appended.w == 0 && attr.padding.appended.h == 0 && + attr.padding.prepended.w == 0 && attr.padding.prepended.h == 0; +} + +bool IsGlobalAveragePooling(const Pooling2DAttributes& attr, + const BHWC& src_shape) { + return attr.type == tflite::gpu::PoolingType::AVERAGE && + attr.output_indices == false && IsGlobalPooling(attr, src_shape); +} + +class GlobalPoolingToReduceOp : public NodeTransformation { + public: + TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final { + if (node->operation.type != ToString(OperationType::POOLING_2D)) { + return {TransformStatus::SKIPPED, ""}; + } + + auto inputs = graph->FindInputs(node->id); + const auto& pool_attr = + absl::any_cast(node->operation.attributes); + if (!IsGlobalAveragePooling(pool_attr, inputs[0]->tensor.shape)) { + return {TransformStatus::SKIPPED, ""}; + } + + MeanAttributes mean_attr; + mean_attr.dims = {Axis::WIDTH, Axis::HEIGHT}; + + node->operation.attributes = mean_attr; + node->operation.type = ToString(OperationType::MEAN); + return {TransformStatus::APPLIED, + "Replaced global average pooling with mean."}; + } +}; + +} // namespace + +std::unique_ptr NewGlobalPoolingToReduceOp() { + return absl::make_unique(); +} + +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h b/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h new file mode 100644 index 00000000000..d2eba5d9fe9 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GLOBAL_POOLING_TO_REDUCE_OP_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GLOBAL_POOLING_TO_REDUCE_OP_H_ + +#include + +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" + +namespace tflite { +namespace gpu { + +// Turns global pooling to reduce operation +// currently can convert average pooling into mean. +std::unique_ptr NewGlobalPoolingToReduceOp(); + +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_GLOBAL_POOLING_TO_REDUCE_OP_H_ diff --git a/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op_test.cc new file mode 100644 index 00000000000..4751c84ed98 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h" + +#include +#include +#include + +#include +#include "absl/status/status.h" +#include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" + +namespace tflite { +namespace gpu { +namespace { + +TEST(MakeMeanFromGlobalAveragePooling, Smoke) { + GraphFloat32 graph; + auto input = graph.NewValue(); + input->tensor.shape = BHWC(1, 4, 4, 8); + + Pooling2DAttributes attr; + attr.padding.prepended = tflite::gpu::HW(0, 0); + attr.padding.appended = tflite::gpu::HW(0, 0); + attr.strides = tflite::gpu::HW(4, 4); + attr.kernel = tflite::gpu::HW(4, 4); + attr.type = tflite::gpu::PoolingType::AVERAGE; + attr.output_indices = false; + + auto pool_node = graph.NewNode(); + pool_node->operation.type = ToString(OperationType::POOLING_2D); + pool_node->operation.attributes = attr; + + ASSERT_TRUE(graph.AddConsumer(pool_node->id, input->id).ok()); + + Value* output = nullptr; + ASSERT_TRUE(AddOutput(&graph, pool_node, &output).ok()); + output->tensor.shape = BHWC(1, 1, 1, 8); + + ASSERT_EQ(1, graph.nodes().size()); + ASSERT_EQ(2, graph.values().size()); + + auto transformation = NewGlobalPoolingToReduceOp(); + ModelTransformer transformer(&graph, nullptr); + transformer.Apply("global_average_pooling_to_mean", transformation.get()); + + ASSERT_EQ(1, graph.nodes().size()); + ASSERT_EQ(2, graph.values().size()); + ASSERT_EQ(ToString(OperationType::MEAN), graph.nodes()[0]->operation.type); +} + +} // namespace +} // namespace gpu +} // namespace tflite