Added a memory optimizer to grappler.

Change: 152184170
This commit is contained in:
Benoit Steiner 2017-04-04 13:57:59 -08:00 committed by TensorFlower Gardener
parent 7284cd8615
commit 98f6dbe9fa
4 changed files with 230 additions and 0 deletions

View File

@ -117,6 +117,37 @@ cc_test(
], ],
) )
# Library target for the grappler memory optimizer
# (memory_optimizer.h / memory_optimizer.cc).
# NOTE(review): memory_optimizer.cc includes graph_rewriter.h, but the code
# in that file does not appear to use anything from it -- confirm whether the
# ":graph_rewriter" dependency is actually needed.
cc_library(
name = "memory_optimizer",
srcs = ["memory_optimizer.cc"],
hdrs = [
"memory_optimizer.h",
],
visibility = ["//visibility:public"],
deps = [
":graph_optimizer",
":graph_rewriter",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
],
)
# Unit test target for the memory optimizer (memory_optimizer_test.cc).
# NOTE(review): memory_optimizer_test.cc does not include
# trivial_test_graph_input_yielder.h; the last dependency below may be
# unnecessary -- confirm before removing.
cc_test(
name = "memory_optimizer_test",
srcs = ["memory_optimizer_test.cc"],
deps = [
":memory_optimizer",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
],
)
cc_library( cc_library(
name = "layout_optimizer", name = "layout_optimizer",
srcs = ["layout_optimizer.cc"], srcs = ["layout_optimizer.cc"],

View File

@ -0,0 +1,83 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include <unordered_set>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
// Appends a swap-out / swap-in pair of Identity nodes to `graph` for input
// number `input_to_swap` of `node`.
//
// The swap-out node is assigned to device "/CPU"; the swap-in node takes the
// swap-out node as its only input and is put in the same "_class" colocation
// group as `node`. This function does not connect the swap-out node to the
// original tensor, nor does it rewire `node`'s input: the caller does both.
//
// Returns {swap_out_node, swap_in_node}. Both pointers are owned by `graph`.
std::pair<NodeDef*, NodeDef*> BuildSwapPair(NodeDef* node, int input_to_swap,
                                            GraphDef* graph) {
  // Tag shared by the names of both new nodes and by the colocation group.
  const string swap_tag = strings::StrCat(node->name(), "_", input_to_swap);

  // Host side: an Identity pinned to the CPU forces a copy of the tensor to
  // the cpu.
  NodeDef* host_copy = graph->add_node();
  host_copy->set_op("Identity");
  host_copy->set_device("/CPU");
  host_copy->set_name(strings::StrCat("swap_out_", swap_tag));

  // Device side: an Identity fed by the host copy restores the tensor.
  NodeDef* device_copy = graph->add_node();
  device_copy->set_op("Identity");
  device_copy->set_name(strings::StrCat("swap_in_", swap_tag));
  device_copy->add_input(host_copy->name());

  // Put the swap-in node and the consumer into the same colocation group.
  const string colocation_group = strings::StrCat("loc@", swap_tag);
  (*device_copy->mutable_attr())["_class"].mutable_list()->add_s(
      colocation_group);
  (*node->mutable_attr())["_class"].mutable_list()->add_s(colocation_group);

  return {host_copy, device_copy};
}
// Copies item.graph into *optimized_graph and, for every node that carries a
// 'swap_to_host' attribute, reroutes each listed input through a
// swap-out / swap-in pair of Identity nodes created by BuildSwapPair():
//
//   producer -> swap_out_<node>_<i> -> swap_in_<node>_<i> -> node
//
// `cluster` is not used.
//
// Returns OK on success. Returns INVALID_ARGUMENT if a 'swap_to_host' entry
// is not a valid input index of its node; in that case *optimized_graph may
// contain a partially rewritten graph and should be discarded.
Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                 GraphDef* optimized_graph) {
  *optimized_graph = item.graph;

  // BuildSwapPair() appends nodes to the graph, which may invalidate iterators
  // over the node list. Walk the original nodes by index instead; this also
  // guarantees the freshly added swap nodes are never visited.
  const int num_original_nodes = optimized_graph->node_size();
  for (int node_idx = 0; node_idx < num_original_nodes; ++node_idx) {
    NodeDef* node = optimized_graph->mutable_node(node_idx);
    if (node->attr().count("swap_to_host") == 0) {
      continue;
    }

    // Take a copy of the indices: BuildSwapPair() inserts a "_class" entry
    // into this node's attr map, so avoid iterating over data owned by that
    // map while it is being modified.
    const auto inputs_to_swap = node->attr().at("swap_to_host").list().i();

    // Swap all the tensors that are marked with the 'swap_to_host' attribute.
    for (const auto requested_input : inputs_to_swap) {
      if (requested_input < 0 || requested_input >= node->input_size()) {
        return Status(
            error::INVALID_ARGUMENT,
            strings::StrCat("Cannot swap input ", requested_input,
                            " of node ", node->name(), ": the node has ",
                            node->input_size(), " inputs"));
      }
      // Safe to narrow: the value was range-checked against input_size().
      const int input_id = static_cast<int>(requested_input);

      std::pair<NodeDef*, NodeDef*> swap_nodes =
          BuildSwapPair(node, input_id, optimized_graph);
      // Feed the original tensor into the swap-out node, then make the
      // consumer read from the swap-in node instead.
      *swap_nodes.first->add_input() = node->input(input_id);
      *node->mutable_input(input_id) = swap_nodes.second->name();

      // TODO(bsteiner): Make sure the tensor isn't swapped back in right away
      // by adding a control dependency to delay the execution of the swap.
      // string trigger;
      //*swap_nodes.second->add_input() = strings::StrCat("^", trigger);
    }
  }
  return Status::OK();
}
void MemoryOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimized_graph, double result) {
// Nothing to do for MemoryOptimizer.
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -0,0 +1,42 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
namespace tensorflow {
namespace grappler {
// Swap tensors in and out of device memory.
//
// Inputs to swap are selected per node through the 'swap_to_host' attribute,
// which lists the indices of the node inputs to route through host memory.
class MemoryOptimizer : public GraphOptimizer {
 public:
  MemoryOptimizer() {}
  ~MemoryOptimizer() override {}

  // Name identifying this optimizer.
  string name() const override { return "memory_optimizer"; }

  // Copies item.graph into *optimized_graph and, for each node carrying a
  // 'swap_to_host' attribute, reroutes the listed inputs through a pair of
  // swap-out / swap-in Identity nodes.
  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override;

  // No-op: this optimizer does not make use of feedback.
  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimized_graph, double result) override;
};
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_

View File

@ -0,0 +1,74 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
namespace {
// Test fixture for the MemoryOptimizer tests; it adds no shared state or
// setup of its own.
class MemoryOptimizerTest : public ::testing::Test {};
// Marks input 0 of node "e" for swapping and checks that the optimizer
// inserts a swap_out/swap_in pair between "b" and "e" while leaving the other
// input of "e" untouched.
TEST_F(MemoryOptimizerTest, SimpleSwapping) {
  // Build a simple graph with an op that's marked for swapping.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
  Output b = ops::AddN(s.WithOpName("b"), {a});
  Output c = ops::AddN(s.WithOpName("c"), {b});
  Output d = ops::AddN(s.WithOpName("d"), {c});
  Output e = ops::AddN(s.WithOpName("e"), {b, d});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // The size checks below guard subsequent indexing, so they must abort the
  // test on failure (ASSERT) rather than merely record it (EXPECT).
  ASSERT_EQ(5, item.graph.node_size());
  EXPECT_EQ(NodeName(e.name()), item.graph.node(4).name());
  AttrValue& val =
      (*item.graph.mutable_node(4)->mutable_attr())["swap_to_host"];
  val.mutable_list()->add_i(0);

  MemoryOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_ASSERT_OK(status);

  // Two nodes (swap out + swap in) are appended after the original five.
  ASSERT_EQ(7, output.node_size());

  const NodeDef& new_e = output.node(4);
  EXPECT_EQ(NodeName(e.name()), new_e.name());
  ASSERT_EQ(2, new_e.input_size());
  EXPECT_EQ(NodeName(d.name()), new_e.input(1));
  EXPECT_EQ("swap_in_e_0", new_e.input(0));

  const NodeDef& swap_out = output.node(5);
  EXPECT_EQ("swap_out_e_0", swap_out.name());
  EXPECT_EQ("/CPU", swap_out.device());

  const NodeDef& swap_in = output.node(6);
  EXPECT_EQ("swap_in_e_0", swap_in.name());

  ASSERT_EQ(1, swap_out.input_size());
  EXPECT_EQ(NodeName(b.name()), swap_out.input(0));
  ASSERT_EQ(1, swap_in.input_size());
  EXPECT_EQ(NodeName(swap_out.name()), swap_in.input(0));
}
} // namespace
} // namespace grappler
} // namespace tensorflow