Convert SB->Conv->BS subgraph into the dilated convolution.
PiperOrigin-RevId: 305154803 Change-Id: If7e127b4f55e6b0fa23dde7fdd9560e2b65fb586
This commit is contained in:
parent
ee27301273
commit
cfc7eab980
@ -109,7 +109,6 @@ cc_library(
|
|||||||
":fuse_mul_to_conv",
|
":fuse_mul_to_conv",
|
||||||
":make_fully_connected",
|
":make_fully_connected",
|
||||||
":make_padding",
|
":make_padding",
|
||||||
":match_dilated_convolution",
|
|
||||||
":merge_padding_with",
|
":merge_padding_with",
|
||||||
":remove_noop",
|
":remove_noop",
|
||||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||||
@ -177,34 +176,6 @@ cc_library(
|
|||||||
deps = ["//tensorflow/lite/delegates/gpu/common:model"],
|
deps = ["//tensorflow/lite/delegates/gpu/common:model"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "match_dilated_convolution",
|
|
||||||
srcs = ["match_dilated_convolution.cc"],
|
|
||||||
hdrs = ["match_dilated_convolution.h"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:model",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:status",
|
|
||||||
"@com_google_absl//absl/memory",
|
|
||||||
"@com_google_absl//absl/types:any",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_test(
|
|
||||||
name = "match_dilated_convolution_test",
|
|
||||||
srcs = ["match_dilated_convolution_test.cc"],
|
|
||||||
deps = [
|
|
||||||
":match_dilated_convolution",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:model",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
|
||||||
"@com_google_absl//absl/types:any",
|
|
||||||
"@com_google_googletest//:gtest_main",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "merge_padding_with",
|
name = "merge_padding_with",
|
||||||
srcs = ["merge_padding_with.cc"],
|
srcs = ["merge_padding_with.cc"],
|
||||||
|
@ -1,97 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.h"
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
|
||||||
#include "absl/types/any.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace gpu {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
class MatchDilatedConvolution : public SequenceTransformation {
|
|
||||||
public:
|
|
||||||
int ExpectedSequenceLength() const final { return 3; }
|
|
||||||
|
|
||||||
// TODO(eignasheva): use span instead of const reference b/131628066.
|
|
||||||
TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
|
|
||||||
GraphFloat32* graph) final {
|
|
||||||
auto& sb_node = *sequence[0];
|
|
||||||
auto& conv_node = *sequence[1];
|
|
||||||
auto& bs_node = *sequence[2];
|
|
||||||
if (sb_node.operation.type != ToString(OperationType::SPACE_TO_BATCH) &&
|
|
||||||
bs_node.operation.type != ToString(OperationType::BATCH_TO_SPACE)) {
|
|
||||||
return {TransformStatus::SKIPPED, ""};
|
|
||||||
}
|
|
||||||
if (conv_node.operation.type !=
|
|
||||||
ToString(OperationType::DEPTHWISE_CONVOLUTION) &&
|
|
||||||
conv_node.operation.type != ToString(OperationType::CONVOLUTION_2D)) {
|
|
||||||
return {TransformStatus::SKIPPED, ""};
|
|
||||||
}
|
|
||||||
|
|
||||||
auto sb_attr =
|
|
||||||
absl::any_cast<SpaceToBatchAttributes>(sb_node.operation.attributes);
|
|
||||||
|
|
||||||
auto bs_attr =
|
|
||||||
absl::any_cast<BatchToSpaceAttributes>(bs_node.operation.attributes);
|
|
||||||
|
|
||||||
if (sb_attr.block != bs_attr.block) {
|
|
||||||
return {TransformStatus::INVALID, "Invalid block size"};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (conv_node.operation.type ==
|
|
||||||
ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
|
|
||||||
auto dw_attr = absl::any_cast<DepthwiseConvolution2DAttributes>(
|
|
||||||
conv_node.operation.attributes);
|
|
||||||
dw_attr.padding = sb_attr.padding - bs_attr.crop;
|
|
||||||
dw_attr.dilations = sb_attr.block;
|
|
||||||
conv_node.operation.attributes = std::move(dw_attr);
|
|
||||||
} else {
|
|
||||||
auto conv2d_attr = absl::any_cast<Convolution2DAttributes>(
|
|
||||||
conv_node.operation.attributes);
|
|
||||||
conv2d_attr.padding = sb_attr.padding - bs_attr.crop;
|
|
||||||
conv2d_attr.dilations = sb_attr.block;
|
|
||||||
conv_node.operation.attributes = std::move(conv2d_attr);
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status status = RemoveFollowingNode(graph, &bs_node, &conv_node);
|
|
||||||
if (!status.ok()) {
|
|
||||||
return {TransformStatus::INVALID,
|
|
||||||
"Unable to remove batch_to_space node after convolution."};
|
|
||||||
}
|
|
||||||
status = RemovePrecedingNode(graph, &sb_node, &conv_node);
|
|
||||||
if (!status.ok()) {
|
|
||||||
return {TransformStatus::INVALID,
|
|
||||||
"Unable to remove space_to_batch node before convolution."};
|
|
||||||
}
|
|
||||||
|
|
||||||
return {TransformStatus::APPLIED, ""};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<SequenceTransformation> NewMatchDilatedConvolution() {
|
|
||||||
return absl::make_unique<MatchDilatedConvolution>();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gpu
|
|
||||||
} // namespace tflite
|
|
@ -1,35 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCH_DILATED_CONVOLUTION_H_
|
|
||||||
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCH_DILATED_CONVOLUTION_H_
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace gpu {
|
|
||||||
|
|
||||||
// TF->TFLite converter converts convolution with dilation into the chain of
|
|
||||||
// SpaceToBatch->Convolution->BatchToSpace. Our GPU backend natively supports
|
|
||||||
// dilation in convolutions, so we try to skip this inefficiency. For more
|
|
||||||
// information see b/131436214.
|
|
||||||
std::unique_ptr<SequenceTransformation> NewMatchDilatedConvolution();
|
|
||||||
|
|
||||||
} // namespace gpu
|
|
||||||
} // namespace tflite
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCH_DILATED_CONVOLUTION_H_
|
|
@ -1,98 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.h"
|
|
||||||
|
|
||||||
#include <gmock/gmock.h>
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
#include "absl/types/any.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace gpu {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
TEST(MatchDilatedConvolutionTest, MakesDilatedConvolution) {
|
|
||||||
GraphFloat32 graph;
|
|
||||||
auto input = graph.NewValue();
|
|
||||||
input->tensor.shape = BHWC(1, 95, 1, 17);
|
|
||||||
|
|
||||||
SpaceToBatchAttributes sb_attr;
|
|
||||||
sb_attr.block = HW(128, 1);
|
|
||||||
sb_attr.padding.prepended = HW(128, 0);
|
|
||||||
sb_attr.padding.appended = HW(161, 0);
|
|
||||||
|
|
||||||
DepthwiseConvolution2DAttributes dw_attr;
|
|
||||||
dw_attr.padding.prepended = HW(0, 0);
|
|
||||||
dw_attr.padding.appended = HW(0, 0);
|
|
||||||
dw_attr.strides = HW(1, 1);
|
|
||||||
dw_attr.dilations = HW(1, 1);
|
|
||||||
dw_attr.weights.shape = OHWI(1, 3, 1, 17);
|
|
||||||
dw_attr.bias.shape = Linear(96);
|
|
||||||
|
|
||||||
BatchToSpaceAttributes bs_attr;
|
|
||||||
bs_attr.block = HW(128, 1);
|
|
||||||
bs_attr.crop.prepended = HW(0, 0);
|
|
||||||
bs_attr.crop.appended = HW(33, 0);
|
|
||||||
|
|
||||||
auto sb_node = graph.NewNode();
|
|
||||||
sb_node->operation.type = ToString(OperationType::SPACE_TO_BATCH);
|
|
||||||
sb_node->operation.attributes = sb_attr;
|
|
||||||
auto dw_node = graph.NewNode();
|
|
||||||
dw_node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION);
|
|
||||||
dw_node->operation.attributes = dw_attr;
|
|
||||||
auto bs_node = graph.NewNode();
|
|
||||||
bs_node->operation.type = ToString(OperationType::BATCH_TO_SPACE);
|
|
||||||
bs_node->operation.attributes = bs_attr;
|
|
||||||
|
|
||||||
ASSERT_TRUE(graph.AddConsumer(sb_node->id, input->id).ok());
|
|
||||||
|
|
||||||
Value<TensorRef<BHWC>>* output;
|
|
||||||
ASSERT_TRUE(AddOutput(&graph, bs_node, &output).ok());
|
|
||||||
output->tensor.shape = BHWC(1, 95, 1, 17);
|
|
||||||
|
|
||||||
Value<TensorRef<BHWC>>* sb_link;
|
|
||||||
ASSERT_TRUE(ConnectTwoNodes(&graph, sb_node, dw_node, &sb_link).ok());
|
|
||||||
sb_link->tensor.shape = BHWC(21, 128, 1, 17);
|
|
||||||
|
|
||||||
Value<TensorRef<BHWC>>* bs_link;
|
|
||||||
ASSERT_TRUE(ConnectTwoNodes(&graph, dw_node, bs_node, &bs_link).ok());
|
|
||||||
bs_link->tensor.shape = BHWC(1, 95, 1, 17);
|
|
||||||
|
|
||||||
ASSERT_EQ(graph.nodes().size(), 3);
|
|
||||||
ASSERT_EQ(graph.values().size(), 4);
|
|
||||||
|
|
||||||
auto transformation = NewMatchDilatedConvolution();
|
|
||||||
ModelTransformer transformer(&graph, nullptr);
|
|
||||||
transformer.Apply("match_dilated_convolution", transformation.get());
|
|
||||||
|
|
||||||
ASSERT_EQ(graph.nodes().size(), 1);
|
|
||||||
ASSERT_EQ(graph.values().size(), 2);
|
|
||||||
ASSERT_EQ(graph.nodes()[0]->operation.type,
|
|
||||||
ToString(OperationType::DEPTHWISE_CONVOLUTION));
|
|
||||||
|
|
||||||
auto updated_dw_attr = absl::any_cast<DepthwiseConvolution2DAttributes>(
|
|
||||||
graph.nodes()[0]->operation.attributes);
|
|
||||||
EXPECT_EQ(updated_dw_attr.padding.prepended, HW(128, 0));
|
|
||||||
EXPECT_EQ(updated_dw_attr.padding.appended, HW(128, 0));
|
|
||||||
EXPECT_EQ(updated_dw_attr.dilations, HW(128, 1));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace gpu
|
|
||||||
} // namespace tflite
|
|
Loading…
Reference in New Issue
Block a user