When converting the layout of Conv2DBackpropInput, we need to permute one of
its inputs, which is a constant node. We permute a copy of this node instead of
the original node, because the original node may be used as an input to other
nodes. This kind of const-node sharing can arise when the graph has been
pre-optimized by common subexpression elimination, which is part of the L1
optimizations in TensorFlow.

PiperOrigin-RevId: 158037552
parent 88bdb6fca2
commit cc411f9387
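The idea behind the change can be illustrated outside the optimizer. Below is a minimal, self-contained sketch (not the optimizer code itself; the function name, include paths, and the "-LayoutOptimizerCopy" suffix are illustrative only) of the copy-then-modify pattern: the shared Const node is duplicated, the duplicate gets a distinct name, and only the consumer being rewritten is repointed at the copy, so every other consumer keeps seeing the original value. The actual change in Conv2DBackpropInputProcessor below follows the same pattern, using graph_->add_node() and AddPrefixToNodeName.

// Sketch only: duplicate a shared Const node before mutating it.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

namespace {

// Duplicate `shared_const`, give the copy a new name, and repoint input 0 of
// `consumer` at the copy. The original node and all of its other consumers are
// left untouched; the caller may then permute the copy's value attribute.
tensorflow::NodeDef* CopyConstForConsumer(tensorflow::GraphDef* graph,
                                          const tensorflow::NodeDef& shared_const,
                                          tensorflow::NodeDef* consumer) {
  tensorflow::NodeDef* copy = graph->add_node();
  *copy = shared_const;                                   // full proto copy
  copy->set_name(shared_const.name() + "-LayoutOptimizerCopy");
  consumer->set_input(0, copy->name());                   // only this edge moves
  return copy;                                            // now safe to mutate
}

}  // namespace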
@@ -245,20 +245,20 @@ class NodeProcessor {
   virtual Status AddLayoutTransposeToInputs() {
     std::vector<int> input_pos = GetInputPos();
     for (const auto& pos : input_pos) {
-      string node_name_NHWCToNCHW = strings::StrCat(
-          kTransposeNHWCToNCHW, "-", node_->name(), "-", node_->input(pos));
+      string base_name = strings::StrCat(node_->name(), "-", node_->input(pos));
+      string node_name =
+          AddPrefixToNodeName(base_name, kTransposeNHWCToNCHW, "-");
       auto input_node = node_map_->GetNode(node_->input(pos));
       int output_pos = NodePosition(node_->input(pos));
       TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
       TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
       AddNodeTranspose(
-          node_name_NHWCToNCHW, node_->input(pos), node_->attr().at("T").type(),
+          node_name, node_->input(pos), node_->attr().at("T").type(),
           input_node->attr().at("_output_shapes").list().shape(output_pos),
           true);
-      node_map_->UpdateOutput(node_->input(pos), node_->name(),
-                              node_name_NHWCToNCHW);
-      node_map_->AddOutput(node_name_NHWCToNCHW, node_->name());
-      *node_->mutable_input(pos) = node_name_NHWCToNCHW;
+      node_map_->UpdateOutput(node_->input(pos), node_->name(), node_name);
+      node_map_->AddOutput(node_name, node_->name());
+      *node_->mutable_input(pos) = node_name;
     }
     return Status::OK();
   }
@@ -266,9 +266,10 @@ class NodeProcessor {
   virtual Status AddLayoutTransposeToOutputs() {
     auto outputs = node_map_->GetOutputs(node_->name());
     for (const auto& output : outputs) {
-      string node_name_NCHWToNHWC = strings::StrCat(
-          kTransposeNCHWToNHWC, "-", node_->name(), "-", output->name());
-      // TODO (yaozhang): handle the rare case where node A is connected to more
+      string base_name = strings::StrCat(node_->name(), "-", output->name());
+      string node_name =
+          AddPrefixToNodeName(base_name, kTransposeNCHWToNHWC, "-");
+      // TODO(yaozhang): handle the rare case where node A is connected to more
       // than one input of node B.
       auto it = std::find_if(output->mutable_input()->begin(),
                              output->mutable_input()->end(),
@@ -290,13 +291,12 @@ class NodeProcessor {
       }
       TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
       TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
-      AddNodeTranspose(
-          node_name_NCHWToNHWC, node_->name(), node_->attr().at("T").type(),
-          node_->attr().at("_output_shapes").list().shape(0), false);
-      *it = node_name_NCHWToNHWC;
-      node_map_->UpdateOutput(node_->name(), output->name(),
-                              node_name_NCHWToNHWC);
-      node_map_->AddOutput(node_name_NCHWToNHWC, output->name());
+      AddNodeTranspose(node_name, node_->name(), node_->attr().at("T").type(),
+                       node_->attr().at("_output_shapes").list().shape(0),
+                       false);
+      *it = node_name;
+      node_map_->UpdateOutput(node_->name(), output->name(), node_name);
+      node_map_->AddOutput(node_name, output->name());
     }
     return Status::OK();
   }
@@ -468,7 +468,13 @@ class Conv2DBackpropInputProcessor : public Conv2DProcessor {
 
   Status CustomizedProcessing() override {
     NodeDef* node = node_map_->GetNode(node_->input(0));
-    return UpdateAttrValue(node);
+    NodeDef* added_node = graph_->add_node();
+    *added_node = *node;
+    string node_name =
+        AddPrefixToNodeName(node->name(), "LayoutOptimizer", "-");
+    added_node->set_name(node_name);
+    node_map_->AddNode(node_name, added_node);
+    return UpdateAttrValue(added_node);
   }
 };
 
@@ -621,9 +627,11 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
 
   Status CustomizedProcessing() override {
     if (is_4d_with_vector_) {
-      string suffix = strings::StrCat("-", node_->name(), "-", node_->input(1));
-      string reshape_node_name = strings::StrCat(kReshapeNHWCToNCHW, suffix);
-      string shape_const_node_name = strings::StrCat(kReshapeConst, suffix);
+      string base_name = strings::StrCat(node_->name(), "-", node_->input(1));
+      string reshape_node_name =
+          AddPrefixToNodeName(base_name, kReshapeNHWCToNCHW, "-");
+      string shape_const_node_name =
+          AddPrefixToNodeName(base_name, kReshapeConst, "-");
       auto input_node = node_map_->GetNode(node_->input(1));
       TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
       int vector_size =
@@ -710,15 +718,15 @@ class SliceProcessor : public AgnosticNodeProcessor {
   Status CustomizedProcessing() override {
     // Skip the first input, which is the data to be sliced.
     for (int i = 1; i < node_->input_size(); i++) {
-      string node_name_NHWCToNCHW =
-          strings::StrCat(kPermVecNHWCToNCHW, "-", node_->name(), "-input", i);
+      string base_name = strings::StrCat(node_->name(), "-input", i);
+      string node_name =
+          AddPrefixToNodeName(base_name, kPermVecNHWCToNCHW, "-");
       TF_RETURN_IF_ERROR(HasAttribute(*node_, "Index"));
-      AddNodePermVec(node_name_NHWCToNCHW, node_->input(i),
+      AddNodePermVec(node_name, node_->input(i),
                      node_->attr().at("Index").type(), true);
-      node_map_->UpdateOutput(node_->input(i), node_->name(),
-                              node_name_NHWCToNCHW);
-      node_map_->AddOutput(node_name_NHWCToNCHW, node_->name());
-      *node_->mutable_input(i) = node_name_NHWCToNCHW;
+      node_map_->UpdateOutput(node_->input(i), node_->name(), node_name);
+      node_map_->AddOutput(node_name, node_->name());
+      *node_->mutable_input(i) = node_name;
     }
     return Status::OK();
   }
@@ -795,7 +803,7 @@ class SliceProcessorConcatOffset : public AgnosticNodeProcessor {
                            "input 1 of ConcatOffset"));
     }
     // Need to process if the channel is at dimension 3, which indicates the
-    // NHWC format is being used. As mutiple Slice nodes may share the same
+    // NHWC format is being used. As multiple Slice nodes may share the same
     // ConcatOffset node, the NHWC to NCHW conversion may have already
     // been performed when processing other Slice nodes.
    TF_RETURN_IF_ERROR(HasAttribute(*axis_node, "value"));
@@ -36,8 +36,8 @@ void AddOutputShape(Node* node, const TensorShape& shape) {
 
 class LayoutOptimizerTest : public ::testing::Test {
  protected:
-  Output SimpleConv(tensorflow::Scope* s, int input_size, int filter_size,
-                    const string& padding) {
+  Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
+                      const string& padding) {
     int batch_size = 128;
     int input_height = input_size;
     int input_width = input_size;
@@ -65,11 +65,80 @@ class LayoutOptimizerTest : public ::testing::Test {
     AddOutputShape(conv.node(), input_shape);
     return conv;
   }
+
+  Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
+                                   int filter_size, const string& padding) {
+    int batch_size = 128;
+    int input_height = input_size;
+    int input_width = input_size;
+    int input_depth = 3;
+    int filter_count = 2;
+    int stride = 1;
+    TensorShape input_sizes_shape({4});
+    Tensor input_data(DT_INT32, input_sizes_shape);
+    test::FillValues<int>(
+        &input_data, {batch_size, input_height, input_width, input_depth});
+    Output input_sizes =
+        ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));
+    AddOutputShape(input_sizes.node(), input_sizes_shape);
+
+    TensorShape filter_shape(
+        {filter_size, filter_size, input_depth, filter_count});
+    Tensor filter_data(DT_FLOAT, filter_shape);
+    test::FillIota<float>(&filter_data, 1.0f);
+    Output filter =
+        ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
+    AddOutputShape(filter.node(), filter_shape);
+
+    int output_height = input_height;
+    int output_width = input_width;
+    TensorShape output_shape(
+        {batch_size, output_height, output_width, filter_count});
+    Tensor output_data(DT_FLOAT, output_shape);
+    test::FillIota<float>(&output_data, 1.0f);
+    Output output =
+        ops::Const(s->WithOpName("Output"), Input::Initializer(output_data));
+    AddOutputShape(output.node(), output_shape);
+
+    Output conv_backprop_input = ops::Conv2DBackpropInput(
+        s->WithOpName("Conv2DBackpropInput"), input_sizes, filter, output,
+        {1, stride, stride, 1}, padding);
+    TensorShape input_shape(
+        {batch_size, input_height, input_width, input_depth});
+    AddOutputShape(conv_backprop_input.node(), input_shape);
+    return conv_backprop_input;
+  }
+
+  Tensor GetAttrValue(const NodeDef& node) {
+    Tensor tensor;
+    CHECK(tensor.FromProto(node.attr().at({"value"}).tensor()));
+    return tensor;
+  }
 };
 
+TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME");
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  LayoutOptimizer optimizer;
+  optimizer.set_num_gpus(1);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  NodeMap node_map(&output);
+  auto input_sizes_node = node_map.GetNode(
+      AddPrefixToNodeName("InputSizes", "LayoutOptimizer", "-"));
+  CHECK(input_sizes_node);
+  auto input_sizes = GetAttrValue(*input_sizes_node);
+  Tensor input_sizes_expected(DT_INT32, {4});
+  test::FillValues<int>(&input_sizes_expected, {128, 3, 7, 7});
+  test::ExpectTensorEqual<int>(input_sizes_expected, input_sizes);
+}
+
 TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv(&s, 2, 1, "SAME");
+  auto conv = SimpleConv2D(&s, 2, 1, "SAME");
   Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
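The expected tensor in the new test follows directly from permuting the NHWC sizes built by SimpleConv2DBackpropInput. As a quick, optimizer-independent sketch of that permutation (values taken from the test above):

// Sketch only: permute an NHWC size vector to NCHW.
// SimpleConv2DBackpropInput(&s, 7, 2, "SAME") builds
// input_sizes = {batch, height, width, depth} = {128, 7, 7, 3}.
#include <array>
#include <cstdio>

int main() {
  std::array<int, 4> nhwc = {128, 7, 7, 3};                        // {N, H, W, C}
  std::array<int, 4> nchw = {nhwc[0], nhwc[3], nhwc[1], nhwc[2]};  // {N, C, H, W}
  std::printf("%d %d %d %d\n", nchw[0], nchw[1], nchw[2], nchw[3]);
  // Prints: 128 3 7 7 -- the value the test expects in the copied InputSizes
  // node, while the original InputSizes constant keeps {128, 7, 7, 3}.
  return 0;
}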
@@ -84,7 +153,7 @@ TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
 
 TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv(&s, 2, 1, "SAME");
+  auto conv = SimpleConv2D(&s, 2, 1, "SAME");
   Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -99,7 +168,7 @@ TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
 
 TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv(&s, 2, 2, "VALID");
+  auto conv = SimpleConv2D(&s, 2, 2, "VALID");
   Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -114,7 +183,7 @@ TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
 
 TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv(&s, 2, 2, "SAME");
+  auto conv = SimpleConv2D(&s, 2, 2, "SAME");
   Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -129,7 +198,7 @@ TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
 
 TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto conv = SimpleConv(&s, 2, 3, "VALID");
+  auto conv = SimpleConv2D(&s, 2, 3, "VALID");
   Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -98,13 +98,18 @@ int NodePosition(const string& name) {
   return position;
 }
 
-string AddPrefixToNodeName(const string& name, const string& prefix) {
+string AddPrefixToNodeName(const string& name, const string& prefix,
+                           const string& delimiter) {
   if (!name.empty()) {
     if (name[0] == '^') {
-      return strings::StrCat("^", prefix, "/", name.substr(1));
+      return strings::StrCat("^", prefix, delimiter, name.substr(1));
     }
   }
-  return strings::StrCat(prefix, "/", name);
+  return strings::StrCat(prefix, delimiter, name);
 }
 
+string AddPrefixToNodeName(const string& name, const string& prefix) {
+  return AddPrefixToNodeName(name, prefix, "/");
+}
+
 bool ExecuteWithTimeout(std::function<void()> fn, const int64 timeout_in_ms,
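The behavior of the new overload, including the control-input case, can be summarized with a small sketch; the results shown in the comments follow from the implementation above, and the example node names are illustrative only:

// Sketch of AddPrefixToNodeName's behavior, based on the implementation above.
string a = AddPrefixToNodeName("Conv2D-InputSizes", "LayoutOptimizer", "-");
// -> "LayoutOptimizer-Conv2D-InputSizes"

string b = AddPrefixToNodeName("^Conv2D", "LayoutOptimizer", "-");
// -> "^LayoutOptimizer-Conv2D"   (the control-input marker '^' stays in front)

string c = AddPrefixToNodeName("Conv2D", "LayoutOptimizer");
// -> "LayoutOptimizer/Conv2D"    (the two-argument overload keeps the old "/")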
@@ -60,7 +60,11 @@ int NodePosition(const string& name);
 // Returns the node name and position in a single call.
 string ParseNodeName(const string& name, int* position);
 
-// Add a prefix to a node name
+// Add a prefix to a node name with a custom delimiter.
+string AddPrefixToNodeName(const string& name, const string& prefix,
+                           const string& delimiter);
+
+// Add a prefix to a node name.
 string AddPrefixToNodeName(const string& name, const string& prefix);
 
 // Executes a 'fn' in the 'thread_pool'. The method waits for the configured