Graph transform to flatten atrous (dilated) convolutions (i.e., a sequence of SpaceToBatchND-Conv-BatchToSpaceND ops) to a regular Conv op with upsampled filters.
PiperOrigin-RevId: 168414124
This commit is contained in:
parent
3438981ca7
commit
86211d5545
@ -94,6 +94,7 @@ cc_library(
|
||||
"add_default_attributes.cc",
|
||||
"backports.cc",
|
||||
"fake_quantize_training.cc",
|
||||
"flatten_atrous.cc",
|
||||
"fold_batch_norms.cc",
|
||||
"fold_constants_lib.cc",
|
||||
"fold_old_batch_norms.cc",
|
||||
@ -145,6 +146,7 @@ tf_cc_test(
|
||||
"add_default_attributes_test.cc",
|
||||
"backports_test.cc",
|
||||
"fake_quantize_training_test.cc",
|
||||
"flatten_atrous_test.cc",
|
||||
"fold_batch_norms_test.cc",
|
||||
"fold_constants_test.cc",
|
||||
"fold_old_batch_norms_test.cc",
|
||||
|
@ -14,6 +14,7 @@
|
||||
* [Transform Reference](#transform-reference)
|
||||
* [add_default_attributes](#add_default_attributes)
|
||||
* [backport_concatv2](#backport_concatv2)
|
||||
* [flatten_atrous_conv](#flatten_atrous_conv)
|
||||
* [fold_batch_norms](#fold_batch_norms)
|
||||
* [fold_constants](#fold_constants)
|
||||
* [fold_old_batch_norms](#fold_old_batch_norms)
|
||||
@ -354,6 +355,20 @@ TensorFlow framework and includes ConcatV2, and you want to run it on an older
|
||||
version that only supports Concat, this transform will take care of converting
|
||||
those newer ops to the equivalent older form.
|
||||
|
||||
### flatten_atrous_conv
|
||||
|
||||
Args: None \
|
||||
Prerequisites: [fold_constants](#fold_constants)
|
||||
|
||||
This transform flattens atrous convolution, corresponding to a sequence of
|
||||
SpaceToBatchND-Conv2D-BatchToSpaceND operations, converting it to a regular
|
||||
Conv2D op with upsampled filters. This transforms should only be used in order
|
||||
to run graphs having atrous convolution on platforms that do not yet natively
|
||||
support SpaceToBatchND and BatchToSpaceND operations. You will need to make
|
||||
sure you run [fold_constants](#fold_constants) after this transform. If
|
||||
applicable, you should run this transform before
|
||||
[fold_batch_norms](#fold_batch_norms).
|
||||
|
||||
### fold_batch_norms
|
||||
|
||||
Args: None \
|
||||
|
141
tensorflow/tools/graph_transforms/flatten_atrous.cc
Normal file
141
tensorflow/tools/graph_transforms/flatten_atrous.cc
Normal file
@ -0,0 +1,141 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/subgraph.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
Status FlattenAtrousConv(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def) {
|
||||
GraphDef replaced_graph_def;
|
||||
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||
input_graph_def, // clang-format off
|
||||
{"BatchToSpaceND",
|
||||
{
|
||||
{"Conv2D|DepthwiseConv2dNative",
|
||||
{
|
||||
{"SpaceToBatchND",
|
||||
{
|
||||
{"*"}, // Input to the flattened op.
|
||||
{"*"}, // block_shape
|
||||
{"*"} // paddings
|
||||
}
|
||||
},
|
||||
{"*"} // filter
|
||||
}
|
||||
},
|
||||
{"*"}, // block_shape
|
||||
{"*"} // crops
|
||||
}
|
||||
}, // clang-format on
|
||||
[](const NodeMatch& match, const std::set<string>& input_nodes,
|
||||
const std::set<string>& output_nodes,
|
||||
std::vector<NodeDef>* new_nodes) {
|
||||
// Find all the nodes we expect in the subgraph.
|
||||
const NodeDef& batch_to_space_node = match.node;
|
||||
const NodeDef& conv_node = match.inputs[0].node;
|
||||
const NodeDef& filter_node = match.inputs[0].inputs[1].node;
|
||||
const NodeDef& input_node = match.inputs[0].inputs[0].inputs[0].node;
|
||||
const NodeDef& space_to_batch_block_shape_node =
|
||||
match.inputs[0].inputs[0].inputs[1].node;
|
||||
|
||||
// The atrous rate value is inferred from the block shape.
|
||||
Tensor block_shape =
|
||||
GetNodeTensorAttr(space_to_batch_block_shape_node, "value");
|
||||
const int32 block_height = block_shape.flat<int32>()(0);
|
||||
const int32 block_width = block_shape.flat<int32>()(1);
|
||||
|
||||
// Compute the upsampled filter.
|
||||
const Tensor& filter = GetNodeTensorAttr(filter_node, "value");
|
||||
const int32 filter_height = filter.dim_size(0);
|
||||
const int32 filter_width = filter.dim_size(1);
|
||||
const int32 in_channels = filter.dim_size(2);
|
||||
const int32 out_channels = filter.dim_size(3);
|
||||
|
||||
const int32 upsampled_filter_height =
|
||||
(filter_height - 1) * block_height + 1;
|
||||
const int32 upsampled_filter_width =
|
||||
(filter_width - 1) * block_width + 1;
|
||||
Tensor upsampled_filter(
|
||||
DT_FLOAT,
|
||||
TensorShape({upsampled_filter_height, upsampled_filter_width,
|
||||
in_channels, out_channels}));
|
||||
|
||||
auto filter_eigen = filter.tensor<float, 4>();
|
||||
auto upsampled_filter_eigen = upsampled_filter.tensor<float, 4>();
|
||||
|
||||
upsampled_filter_eigen.setZero();
|
||||
for (int h = 0; h < filter_height; ++h) {
|
||||
for (int w = 0; w < filter_width; ++w) {
|
||||
for (int c_in = 0; c_in < in_channels; ++c_in) {
|
||||
for (int c_out = 0; c_out < out_channels; ++c_out) {
|
||||
upsampled_filter_eigen(block_height * h, block_width * w, c_in,
|
||||
c_out) = filter_eigen(h, w, c_in, c_out);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
NodeDef upsampled_filter_node;
|
||||
upsampled_filter_node.set_op("Const");
|
||||
upsampled_filter_node.set_name(filter_node.name());
|
||||
SetNodeAttr("dtype", DT_FLOAT, &upsampled_filter_node);
|
||||
SetNodeTensorAttr<float>("value", upsampled_filter,
|
||||
&upsampled_filter_node);
|
||||
|
||||
// Set up the new flattened version of the convolution op.
|
||||
NodeDef flattened_conv_node;
|
||||
|
||||
flattened_conv_node.set_name(batch_to_space_node.name());
|
||||
flattened_conv_node.set_op(conv_node.op());
|
||||
flattened_conv_node.set_device(conv_node.device());
|
||||
|
||||
AddNodeInput(input_node.name(), &flattened_conv_node);
|
||||
AddNodeInput(upsampled_filter_node.name(), &flattened_conv_node);
|
||||
|
||||
CopyNodeAttr(conv_node, "T", "T", &flattened_conv_node);
|
||||
CopyNodeAttr(conv_node, "strides", "strides", &flattened_conv_node);
|
||||
SetNodeAttr("padding", "SAME", &flattened_conv_node);
|
||||
CopyNodeAttr(conv_node, "data_format", "data_format",
|
||||
&flattened_conv_node);
|
||||
|
||||
if (conv_node.op() == "Conv2D") {
|
||||
CopyNodeAttr(conv_node, "use_cudnn_on_gpu", "use_cudnn_on_gpu",
|
||||
&flattened_conv_node);
|
||||
}
|
||||
|
||||
new_nodes->push_back(input_node);
|
||||
new_nodes->push_back(upsampled_filter_node);
|
||||
new_nodes->push_back(flattened_conv_node);
|
||||
|
||||
return Status::OK();
|
||||
},
|
||||
{}, &replaced_graph_def));
|
||||
*output_graph_def = replaced_graph_def;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_TRANSFORM("flatten_atrous_conv", FlattenAtrousConv);
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
121
tensorflow/tools/graph_transforms/flatten_atrous_test.cc
Normal file
121
tensorflow/tools/graph_transforms/flatten_atrous_test.cc
Normal file
@ -0,0 +1,121 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// Declare here, so we don't need a public header.
|
||||
Status FlattenAtrousConv(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def);
|
||||
|
||||
class FlattenAtrousConvTest : public ::testing::Test {
|
||||
protected:
|
||||
template <class TConvOp>
|
||||
void TestFlattenAtrousConv() {
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
Tensor input_data(DT_FLOAT, TensorShape({1, 3, 3, 2}));
|
||||
test::FillValues<float>(
|
||||
&input_data, {.1f, .4f, .2f, .5f, .3f, .6f, -1.0f, -.4f, -.2f, -.5f,
|
||||
-.3f, -.6f, .1f, .4f, .2f, .5f, .3f, .6f});
|
||||
Output input_op =
|
||||
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
|
||||
|
||||
Tensor block_shape_data(DT_INT32, TensorShape({2}));
|
||||
test::FillValues<int>(&block_shape_data, {2, 2});
|
||||
Output block_shape_op = Const(root.WithOpName("block_shape_op"),
|
||||
Input::Initializer(block_shape_data));
|
||||
|
||||
Tensor paddings_data(DT_INT32, TensorShape({2, 2}));
|
||||
test::FillValues<int>(&paddings_data, {1, 2, 1, 2});
|
||||
Output paddings_op = Const(root.WithOpName("paddings_op"),
|
||||
Input::Initializer(paddings_data));
|
||||
|
||||
Output space_to_batch_op =
|
||||
SpaceToBatchND(root.WithOpName("space_to_batch_op"), input_op,
|
||||
block_shape_op, paddings_op);
|
||||
|
||||
Tensor weights_data(DT_FLOAT, TensorShape({2, 2, 2, 1}));
|
||||
test::FillValues<float>(&weights_data,
|
||||
{.1f, .2f, .3f, .4f, .1f, .2f, .3f, .4f});
|
||||
Output weights_op =
|
||||
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
|
||||
|
||||
Output conv_op = TConvOp(root.WithOpName("conv_op"), space_to_batch_op,
|
||||
weights_op, {1, 1, 1, 1}, "VALID");
|
||||
|
||||
Tensor crops_data(DT_INT32, TensorShape({2, 2}));
|
||||
test::FillValues<int>(&crops_data, {0, 1, 0, 1});
|
||||
Output crops_op =
|
||||
Const(root.WithOpName("crops_op"), Input::Initializer(crops_data));
|
||||
|
||||
Output batch_to_space_op = BatchToSpaceND(
|
||||
root.WithOpName("output"), conv_op, block_shape_op, crops_op);
|
||||
|
||||
GraphDef original_graph_def;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
|
||||
|
||||
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
|
||||
TF_ASSERT_OK(original_session->Create(original_graph_def));
|
||||
std::vector<Tensor> original_outputs;
|
||||
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
|
||||
|
||||
GraphDef modified_graph_def;
|
||||
TF_ASSERT_OK(FlattenAtrousConv(original_graph_def, {{}, {"output"}},
|
||||
&modified_graph_def));
|
||||
|
||||
std::unique_ptr<Session> modified_session(NewSession(SessionOptions()));
|
||||
TF_ASSERT_OK(modified_session->Create(modified_graph_def));
|
||||
std::vector<Tensor> modified_outputs;
|
||||
TF_ASSERT_OK(modified_session->Run({}, {"output"}, {}, &modified_outputs));
|
||||
|
||||
EXPECT_EQ(3, modified_graph_def.node_size());
|
||||
|
||||
EXPECT_EQ("input_op", modified_graph_def.node(0).name());
|
||||
EXPECT_EQ("weights_op", modified_graph_def.node(1).name());
|
||||
EXPECT_EQ("output", modified_graph_def.node(2).name());
|
||||
|
||||
EXPECT_EQ("Const", modified_graph_def.node(0).op());
|
||||
EXPECT_EQ("Const", modified_graph_def.node(1).op());
|
||||
EXPECT_EQ(conv_op.node()->type_string(), modified_graph_def.node(2).op());
|
||||
|
||||
test::ExpectTensorNear<float>(original_outputs[0], modified_outputs[0],
|
||||
1e-6);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(FlattenAtrousConvTest, TestFlattenAtrousConv2D) {
|
||||
TestFlattenAtrousConv<::tensorflow::ops::Conv2D>();
|
||||
}
|
||||
TEST_F(FlattenAtrousConvTest, TestFlattenAtrousDepthwiseConv2dNative) {
|
||||
TestFlattenAtrousConv<::tensorflow::ops::DepthwiseConv2dNative>();
|
||||
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user