Convert SB->Conv->BS subgraph into the dilated convolution.
PiperOrigin-RevId: 246045525
This commit is contained in:
parent
0a82623b59
commit
f43d0034ec
@ -77,6 +77,7 @@ cc_library(
|
||||
":fuse_mul_to_conv",
|
||||
":make_fully_connected",
|
||||
":make_padding",
|
||||
":match_dilated_convolution",
|
||||
":merge_padding_with",
|
||||
":remove_noop",
|
||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||
@ -144,6 +145,34 @@ cc_library(
|
||||
deps = ["//tensorflow/lite/delegates/gpu/common:model"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "match_dilated_convolution",
|
||||
srcs = ["match_dilated_convolution.cc"],
|
||||
hdrs = ["match_dilated_convolution.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/common:model",
|
||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/types:any",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "match_dilated_convolution_test",
|
||||
srcs = ["match_dilated_convolution_test.cc"],
|
||||
deps = [
|
||||
":match_dilated_convolution",
|
||||
"//tensorflow/lite/delegates/gpu/common:model",
|
||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
"@com_google_absl//absl/types:any",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "merge_padding_with",
|
||||
srcs = ["merge_padding_with.cc"],
|
||||
|
@ -0,0 +1,97 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/types/any.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace {
|
||||
|
||||
class MatchDilatedConvolution : public SequenceTransformation {
|
||||
public:
|
||||
int ExpectedSequenceLength() const final { return 3; }
|
||||
|
||||
// TODO(eignasheva): use span instead of const reference b/131628066.
|
||||
TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
|
||||
GraphFloat32* graph) final {
|
||||
auto& sb_node = *sequence[0];
|
||||
auto& conv_node = *sequence[1];
|
||||
auto& bs_node = *sequence[2];
|
||||
if (sb_node.operation.type != ToString(OperationType::SPACE_TO_BATCH) &&
|
||||
bs_node.operation.type != ToString(OperationType::BATCH_TO_SPACE)) {
|
||||
return {TransformStatus::SKIPPED, ""};
|
||||
}
|
||||
if (conv_node.operation.type !=
|
||||
ToString(OperationType::DEPTHWISE_CONVOLUTION) &&
|
||||
conv_node.operation.type != ToString(OperationType::CONVOLUTION_2D)) {
|
||||
return {TransformStatus::SKIPPED, ""};
|
||||
}
|
||||
|
||||
auto sb_attr =
|
||||
absl::any_cast<SpaceToBatchAttributes>(sb_node.operation.attributes);
|
||||
|
||||
auto bs_attr =
|
||||
absl::any_cast<BatchToSpaceAttributes>(bs_node.operation.attributes);
|
||||
|
||||
if (sb_attr.block != bs_attr.block) {
|
||||
return {TransformStatus::INVALID, "Invalid block size"};
|
||||
}
|
||||
|
||||
if (conv_node.operation.type ==
|
||||
ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
|
||||
auto dw_attr = absl::any_cast<DepthwiseConvolution2DAttributes>(
|
||||
conv_node.operation.attributes);
|
||||
dw_attr.padding = sb_attr.padding - bs_attr.crop;
|
||||
dw_attr.dilations = sb_attr.block;
|
||||
conv_node.operation.attributes = std::move(dw_attr);
|
||||
} else {
|
||||
auto conv2d_attr = absl::any_cast<Convolution2DAttributes>(
|
||||
conv_node.operation.attributes);
|
||||
conv2d_attr.padding = sb_attr.padding - bs_attr.crop;
|
||||
conv2d_attr.dilations = sb_attr.block;
|
||||
conv_node.operation.attributes = std::move(conv2d_attr);
|
||||
}
|
||||
|
||||
Status status = RemoveFollowingNode(graph, &bs_node, &conv_node);
|
||||
if (!status.ok()) {
|
||||
return {TransformStatus::INVALID,
|
||||
"Unable to remove batch_to_space node after convolution."};
|
||||
}
|
||||
status = RemovePrecedingNode(graph, &sb_node, &conv_node);
|
||||
if (!status.ok()) {
|
||||
return {TransformStatus::INVALID,
|
||||
"Unable to remove space_to_batch node before convolution."};
|
||||
}
|
||||
|
||||
return {TransformStatus::APPLIED, ""};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<SequenceTransformation> NewMatchDilatedConvolution() {
|
||||
return absl::make_unique<MatchDilatedConvolution>();
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
@ -0,0 +1,35 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCH_DILATED_CONVOLUTION_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCH_DILATED_CONVOLUTION_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
|
||||
// TF->TFLite converter converts convolution with dilation into the chain of
|
||||
// SpaceToBatch->Convolution->BatchToSpace. Our GPU backend natively supports
|
||||
// dilation in convolutions, so we try to skip this inefficiency. For more
|
||||
// information see b/131436214.
|
||||
std::unique_ptr<SequenceTransformation> NewMatchDilatedConvolution();
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_MATCH_DILATED_CONVOLUTION_H_
|
@ -0,0 +1,98 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/transformations/match_dilated_convolution.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/types/any.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace {
|
||||
|
||||
TEST(MatchDilatedConvolutionTest, MakesDilatedConvolution) {
|
||||
GraphFloat32 graph;
|
||||
auto input = graph.NewValue();
|
||||
input->tensor.shape = BHWC(1, 95, 1, 17);
|
||||
|
||||
SpaceToBatchAttributes sb_attr;
|
||||
sb_attr.block = HW(128, 1);
|
||||
sb_attr.padding.prepended = HW(128, 0);
|
||||
sb_attr.padding.appended = HW(161, 0);
|
||||
|
||||
DepthwiseConvolution2DAttributes dw_attr;
|
||||
dw_attr.padding.prepended = HW(0, 0);
|
||||
dw_attr.padding.appended = HW(0, 0);
|
||||
dw_attr.strides = HW(1, 1);
|
||||
dw_attr.dilations = HW(1, 1);
|
||||
dw_attr.weights.shape = OHWI(1, 3, 1, 17);
|
||||
dw_attr.bias.shape = Linear(96);
|
||||
|
||||
BatchToSpaceAttributes bs_attr;
|
||||
bs_attr.block = HW(128, 1);
|
||||
bs_attr.crop.prepended = HW(0, 0);
|
||||
bs_attr.crop.appended = HW(33, 0);
|
||||
|
||||
auto sb_node = graph.NewNode();
|
||||
sb_node->operation.type = ToString(OperationType::SPACE_TO_BATCH);
|
||||
sb_node->operation.attributes = sb_attr;
|
||||
auto dw_node = graph.NewNode();
|
||||
dw_node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION);
|
||||
dw_node->operation.attributes = dw_attr;
|
||||
auto bs_node = graph.NewNode();
|
||||
bs_node->operation.type = ToString(OperationType::BATCH_TO_SPACE);
|
||||
bs_node->operation.attributes = bs_attr;
|
||||
|
||||
ASSERT_TRUE(graph.AddConsumer(sb_node->id, input->id).ok());
|
||||
|
||||
Value<TensorRefFloat32>* output;
|
||||
ASSERT_TRUE(AddOutput(&graph, bs_node, &output).ok());
|
||||
output->tensor.shape = BHWC(1, 95, 1, 17);
|
||||
|
||||
Value<TensorRefFloat32>* sb_link;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, sb_node, dw_node, &sb_link).ok());
|
||||
sb_link->tensor.shape = BHWC(21, 128, 1, 17);
|
||||
|
||||
Value<TensorRefFloat32>* bs_link;
|
||||
ASSERT_TRUE(ConnectTwoNodes(&graph, dw_node, bs_node, &bs_link).ok());
|
||||
bs_link->tensor.shape = BHWC(1, 95, 1, 17);
|
||||
|
||||
ASSERT_EQ(graph.nodes().size(), 3);
|
||||
ASSERT_EQ(graph.values().size(), 4);
|
||||
|
||||
auto transformation = NewMatchDilatedConvolution();
|
||||
ModelTransformer transformer(&graph, nullptr);
|
||||
transformer.Apply("match_dilated_convolution", transformation.get());
|
||||
|
||||
ASSERT_EQ(graph.nodes().size(), 1);
|
||||
ASSERT_EQ(graph.values().size(), 2);
|
||||
ASSERT_EQ(graph.nodes()[0]->operation.type,
|
||||
ToString(OperationType::DEPTHWISE_CONVOLUTION));
|
||||
|
||||
auto updated_dw_attr = absl::any_cast<DepthwiseConvolution2DAttributes>(
|
||||
graph.nodes()[0]->operation.attributes);
|
||||
EXPECT_EQ(updated_dw_attr.padding.prepended, HW(128, 0));
|
||||
EXPECT_EQ(updated_dw_attr.padding.appended, HW(128, 0));
|
||||
EXPECT_EQ(updated_dw_attr.dilations, HW(128, 1));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
Loading…
Reference in New Issue
Block a user