Graph transform to flatten atrous (dilated) convolutions (i.e., a sequence of SpaceToBatchND-Conv-BatchToSpaceND ops) to a regular Conv op with upsampled filters.
PiperOrigin-RevId: 168414124
This commit is contained in:
parent
3438981ca7
commit
86211d5545
@ -94,6 +94,7 @@ cc_library(
|
|||||||
"add_default_attributes.cc",
|
"add_default_attributes.cc",
|
||||||
"backports.cc",
|
"backports.cc",
|
||||||
"fake_quantize_training.cc",
|
"fake_quantize_training.cc",
|
||||||
|
"flatten_atrous.cc",
|
||||||
"fold_batch_norms.cc",
|
"fold_batch_norms.cc",
|
||||||
"fold_constants_lib.cc",
|
"fold_constants_lib.cc",
|
||||||
"fold_old_batch_norms.cc",
|
"fold_old_batch_norms.cc",
|
||||||
@ -145,6 +146,7 @@ tf_cc_test(
|
|||||||
"add_default_attributes_test.cc",
|
"add_default_attributes_test.cc",
|
||||||
"backports_test.cc",
|
"backports_test.cc",
|
||||||
"fake_quantize_training_test.cc",
|
"fake_quantize_training_test.cc",
|
||||||
|
"flatten_atrous_test.cc",
|
||||||
"fold_batch_norms_test.cc",
|
"fold_batch_norms_test.cc",
|
||||||
"fold_constants_test.cc",
|
"fold_constants_test.cc",
|
||||||
"fold_old_batch_norms_test.cc",
|
"fold_old_batch_norms_test.cc",
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
* [Transform Reference](#transform-reference)
|
* [Transform Reference](#transform-reference)
|
||||||
* [add_default_attributes](#add_default_attributes)
|
* [add_default_attributes](#add_default_attributes)
|
||||||
* [backport_concatv2](#backport_concatv2)
|
* [backport_concatv2](#backport_concatv2)
|
||||||
|
* [flatten_atrous_conv](#flatten_atrous_conv)
|
||||||
* [fold_batch_norms](#fold_batch_norms)
|
* [fold_batch_norms](#fold_batch_norms)
|
||||||
* [fold_constants](#fold_constants)
|
* [fold_constants](#fold_constants)
|
||||||
* [fold_old_batch_norms](#fold_old_batch_norms)
|
* [fold_old_batch_norms](#fold_old_batch_norms)
|
||||||
@ -354,6 +355,20 @@ TensorFlow framework and includes ConcatV2, and you want to run it on an older
|
|||||||
version that only supports Concat, this transform will take care of converting
|
version that only supports Concat, this transform will take care of converting
|
||||||
those newer ops to the equivalent older form.
|
those newer ops to the equivalent older form.
|
||||||
|
|
||||||
|
### flatten_atrous_conv
|
||||||
|
|
||||||
|
Args: None \
|
||||||
|
Prerequisites: [fold_constants](#fold_constants)
|
||||||
|
|
||||||
|
This transform flattens atrous convolution, corresponding to a sequence of
|
||||||
|
SpaceToBatchND-Conv2D-BatchToSpaceND operations, converting it to a regular
|
||||||
|
Conv2D op with upsampled filters. This transforms should only be used in order
|
||||||
|
to run graphs having atrous convolution on platforms that do not yet natively
|
||||||
|
support SpaceToBatchND and BatchToSpaceND operations. You will need to make
|
||||||
|
sure you run [fold_constants](#fold_constants) after this transform. If
|
||||||
|
applicable, you should run this transform before
|
||||||
|
[fold_batch_norms](#fold_batch_norms).
|
||||||
|
|
||||||
### fold_batch_norms
|
### fold_batch_norms
|
||||||
|
|
||||||
Args: None \
|
Args: None \
|
||||||
|
141
tensorflow/tools/graph_transforms/flatten_atrous.cc
Normal file
141
tensorflow/tools/graph_transforms/flatten_atrous.cc
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
Status FlattenAtrousConv(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
GraphDef replaced_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
|
input_graph_def, // clang-format off
|
||||||
|
{"BatchToSpaceND",
|
||||||
|
{
|
||||||
|
{"Conv2D|DepthwiseConv2dNative",
|
||||||
|
{
|
||||||
|
{"SpaceToBatchND",
|
||||||
|
{
|
||||||
|
{"*"}, // Input to the flattened op.
|
||||||
|
{"*"}, // block_shape
|
||||||
|
{"*"} // paddings
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"*"} // filter
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"*"}, // block_shape
|
||||||
|
{"*"} // crops
|
||||||
|
}
|
||||||
|
}, // clang-format on
|
||||||
|
[](const NodeMatch& match, const std::set<string>& input_nodes,
|
||||||
|
const std::set<string>& output_nodes,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
// Find all the nodes we expect in the subgraph.
|
||||||
|
const NodeDef& batch_to_space_node = match.node;
|
||||||
|
const NodeDef& conv_node = match.inputs[0].node;
|
||||||
|
const NodeDef& filter_node = match.inputs[0].inputs[1].node;
|
||||||
|
const NodeDef& input_node = match.inputs[0].inputs[0].inputs[0].node;
|
||||||
|
const NodeDef& space_to_batch_block_shape_node =
|
||||||
|
match.inputs[0].inputs[0].inputs[1].node;
|
||||||
|
|
||||||
|
// The atrous rate value is inferred from the block shape.
|
||||||
|
Tensor block_shape =
|
||||||
|
GetNodeTensorAttr(space_to_batch_block_shape_node, "value");
|
||||||
|
const int32 block_height = block_shape.flat<int32>()(0);
|
||||||
|
const int32 block_width = block_shape.flat<int32>()(1);
|
||||||
|
|
||||||
|
// Compute the upsampled filter.
|
||||||
|
const Tensor& filter = GetNodeTensorAttr(filter_node, "value");
|
||||||
|
const int32 filter_height = filter.dim_size(0);
|
||||||
|
const int32 filter_width = filter.dim_size(1);
|
||||||
|
const int32 in_channels = filter.dim_size(2);
|
||||||
|
const int32 out_channels = filter.dim_size(3);
|
||||||
|
|
||||||
|
const int32 upsampled_filter_height =
|
||||||
|
(filter_height - 1) * block_height + 1;
|
||||||
|
const int32 upsampled_filter_width =
|
||||||
|
(filter_width - 1) * block_width + 1;
|
||||||
|
Tensor upsampled_filter(
|
||||||
|
DT_FLOAT,
|
||||||
|
TensorShape({upsampled_filter_height, upsampled_filter_width,
|
||||||
|
in_channels, out_channels}));
|
||||||
|
|
||||||
|
auto filter_eigen = filter.tensor<float, 4>();
|
||||||
|
auto upsampled_filter_eigen = upsampled_filter.tensor<float, 4>();
|
||||||
|
|
||||||
|
upsampled_filter_eigen.setZero();
|
||||||
|
for (int h = 0; h < filter_height; ++h) {
|
||||||
|
for (int w = 0; w < filter_width; ++w) {
|
||||||
|
for (int c_in = 0; c_in < in_channels; ++c_in) {
|
||||||
|
for (int c_out = 0; c_out < out_channels; ++c_out) {
|
||||||
|
upsampled_filter_eigen(block_height * h, block_width * w, c_in,
|
||||||
|
c_out) = filter_eigen(h, w, c_in, c_out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NodeDef upsampled_filter_node;
|
||||||
|
upsampled_filter_node.set_op("Const");
|
||||||
|
upsampled_filter_node.set_name(filter_node.name());
|
||||||
|
SetNodeAttr("dtype", DT_FLOAT, &upsampled_filter_node);
|
||||||
|
SetNodeTensorAttr<float>("value", upsampled_filter,
|
||||||
|
&upsampled_filter_node);
|
||||||
|
|
||||||
|
// Set up the new flattened version of the convolution op.
|
||||||
|
NodeDef flattened_conv_node;
|
||||||
|
|
||||||
|
flattened_conv_node.set_name(batch_to_space_node.name());
|
||||||
|
flattened_conv_node.set_op(conv_node.op());
|
||||||
|
flattened_conv_node.set_device(conv_node.device());
|
||||||
|
|
||||||
|
AddNodeInput(input_node.name(), &flattened_conv_node);
|
||||||
|
AddNodeInput(upsampled_filter_node.name(), &flattened_conv_node);
|
||||||
|
|
||||||
|
CopyNodeAttr(conv_node, "T", "T", &flattened_conv_node);
|
||||||
|
CopyNodeAttr(conv_node, "strides", "strides", &flattened_conv_node);
|
||||||
|
SetNodeAttr("padding", "SAME", &flattened_conv_node);
|
||||||
|
CopyNodeAttr(conv_node, "data_format", "data_format",
|
||||||
|
&flattened_conv_node);
|
||||||
|
|
||||||
|
if (conv_node.op() == "Conv2D") {
|
||||||
|
CopyNodeAttr(conv_node, "use_cudnn_on_gpu", "use_cudnn_on_gpu",
|
||||||
|
&flattened_conv_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
new_nodes->push_back(input_node);
|
||||||
|
new_nodes->push_back(upsampled_filter_node);
|
||||||
|
new_nodes->push_back(flattened_conv_node);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
},
|
||||||
|
{}, &replaced_graph_def));
|
||||||
|
*output_graph_def = replaced_graph_def;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("flatten_atrous_conv", FlattenAtrousConv);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
121
tensorflow/tools/graph_transforms/flatten_atrous_test.cc
Normal file
121
tensorflow/tools/graph_transforms/flatten_atrous_test.cc
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/array_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status FlattenAtrousConv(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
class FlattenAtrousConvTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
template <class TConvOp>
|
||||||
|
void TestFlattenAtrousConv() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
Tensor input_data(DT_FLOAT, TensorShape({1, 3, 3, 2}));
|
||||||
|
test::FillValues<float>(
|
||||||
|
&input_data, {.1f, .4f, .2f, .5f, .3f, .6f, -1.0f, -.4f, -.2f, -.5f,
|
||||||
|
-.3f, -.6f, .1f, .4f, .2f, .5f, .3f, .6f});
|
||||||
|
Output input_op =
|
||||||
|
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
|
||||||
|
|
||||||
|
Tensor block_shape_data(DT_INT32, TensorShape({2}));
|
||||||
|
test::FillValues<int>(&block_shape_data, {2, 2});
|
||||||
|
Output block_shape_op = Const(root.WithOpName("block_shape_op"),
|
||||||
|
Input::Initializer(block_shape_data));
|
||||||
|
|
||||||
|
Tensor paddings_data(DT_INT32, TensorShape({2, 2}));
|
||||||
|
test::FillValues<int>(&paddings_data, {1, 2, 1, 2});
|
||||||
|
Output paddings_op = Const(root.WithOpName("paddings_op"),
|
||||||
|
Input::Initializer(paddings_data));
|
||||||
|
|
||||||
|
Output space_to_batch_op =
|
||||||
|
SpaceToBatchND(root.WithOpName("space_to_batch_op"), input_op,
|
||||||
|
block_shape_op, paddings_op);
|
||||||
|
|
||||||
|
Tensor weights_data(DT_FLOAT, TensorShape({2, 2, 2, 1}));
|
||||||
|
test::FillValues<float>(&weights_data,
|
||||||
|
{.1f, .2f, .3f, .4f, .1f, .2f, .3f, .4f});
|
||||||
|
Output weights_op =
|
||||||
|
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
|
||||||
|
|
||||||
|
Output conv_op = TConvOp(root.WithOpName("conv_op"), space_to_batch_op,
|
||||||
|
weights_op, {1, 1, 1, 1}, "VALID");
|
||||||
|
|
||||||
|
Tensor crops_data(DT_INT32, TensorShape({2, 2}));
|
||||||
|
test::FillValues<int>(&crops_data, {0, 1, 0, 1});
|
||||||
|
Output crops_op =
|
||||||
|
Const(root.WithOpName("crops_op"), Input::Initializer(crops_data));
|
||||||
|
|
||||||
|
Output batch_to_space_op = BatchToSpaceND(
|
||||||
|
root.WithOpName("output"), conv_op, block_shape_op, crops_op);
|
||||||
|
|
||||||
|
GraphDef original_graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
|
||||||
|
|
||||||
|
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(original_session->Create(original_graph_def));
|
||||||
|
std::vector<Tensor> original_outputs;
|
||||||
|
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
|
||||||
|
|
||||||
|
GraphDef modified_graph_def;
|
||||||
|
TF_ASSERT_OK(FlattenAtrousConv(original_graph_def, {{}, {"output"}},
|
||||||
|
&modified_graph_def));
|
||||||
|
|
||||||
|
std::unique_ptr<Session> modified_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(modified_session->Create(modified_graph_def));
|
||||||
|
std::vector<Tensor> modified_outputs;
|
||||||
|
TF_ASSERT_OK(modified_session->Run({}, {"output"}, {}, &modified_outputs));
|
||||||
|
|
||||||
|
EXPECT_EQ(3, modified_graph_def.node_size());
|
||||||
|
|
||||||
|
EXPECT_EQ("input_op", modified_graph_def.node(0).name());
|
||||||
|
EXPECT_EQ("weights_op", modified_graph_def.node(1).name());
|
||||||
|
EXPECT_EQ("output", modified_graph_def.node(2).name());
|
||||||
|
|
||||||
|
EXPECT_EQ("Const", modified_graph_def.node(0).op());
|
||||||
|
EXPECT_EQ("Const", modified_graph_def.node(1).op());
|
||||||
|
EXPECT_EQ(conv_op.node()->type_string(), modified_graph_def.node(2).op());
|
||||||
|
|
||||||
|
test::ExpectTensorNear<float>(original_outputs[0], modified_outputs[0],
|
||||||
|
1e-6);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(FlattenAtrousConvTest, TestFlattenAtrousConv2D) {
|
||||||
|
TestFlattenAtrousConv<::tensorflow::ops::Conv2D>();
|
||||||
|
}
|
||||||
|
TEST_F(FlattenAtrousConvTest, TestFlattenAtrousDepthwiseConv2dNative) {
|
||||||
|
TestFlattenAtrousConv<::tensorflow::ops::DepthwiseConv2dNative>();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user