From 98f6dbe9fa3b7cab41c01cc0d780d05eecb0c4a2 Mon Sep 17 00:00:00 2001
From: Benoit Steiner
Date: Tue, 4 Apr 2017 13:57:59 -0800
Subject: [PATCH] Added a memory optimizer to grappler.

Change: 152184170
---
 tensorflow/core/grappler/optimizers/BUILD     | 31 +++++++
 .../grappler/optimizers/memory_optimizer.cc   | 83 +++++++++++++++++++
 .../grappler/optimizers/memory_optimizer.h    | 42 ++++++++++
 .../optimizers/memory_optimizer_test.cc       | 74 +++++++++++++++++
 4 files changed, 230 insertions(+)
 create mode 100644 tensorflow/core/grappler/optimizers/memory_optimizer.cc
 create mode 100644 tensorflow/core/grappler/optimizers/memory_optimizer.h
 create mode 100644 tensorflow/core/grappler/optimizers/memory_optimizer_test.cc

diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index d09a3c4e304..4e41c2bb129 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -117,6 +117,37 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "memory_optimizer",  # Grappler pass in memory_optimizer.{cc,h}.
+    srcs = ["memory_optimizer.cc"],
+    hdrs = [
+        "memory_optimizer.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_optimizer",
+        ":graph_rewriter",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:utils",
+    ],
+)
+
+cc_test(
+    name = "memory_optimizer_test",  # Unit test for :memory_optimizer.
+    srcs = ["memory_optimizer_test.cc"],
+    deps = [
+        ":memory_optimizer",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+    ],
+)
+
 cc_library(
     name = "layout_optimizer",
     srcs = ["layout_optimizer.cc"],
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
new file mode 100644
index 00000000000..24c6ab12efc
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
+#include <utility>
+#include <vector>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Appends to `graph` the two Identity nodes used to route input
+// `input_to_swap` of `node` through the host, and returns them as a
+// (swap_out, swap_in) pair:
+//   * swap_out is named "swap_out_" + NODE + "_" + INDEX and has its device
+//     set to "/CPU".
+//   * swap_in is named "swap_in_" + NODE + "_" + INDEX, takes swap_out as its
+//     only input, and is put in colocation group "loc@" + NODE + "_" + INDEX
+//     together with `node`.
+// The input of swap_out and the inputs of `node` are not wired here; the
+// caller does that.
+// NOTE(review): no "T" attribute is set on the new Identity nodes -- confirm
+// whether something downstream is expected to fill it in.
+std::pair<NodeDef*, NodeDef*> BuildSwapPair(NodeDef* node, int input_to_swap,
+                                            GraphDef* graph) {
+  string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);
+
+  // Force the tensor to be copied to cpu.
+  NodeDef* swap_out_node = graph->add_node();
+  swap_out_node->set_name(strings::StrCat("swap_out_", tensor_to_swap));
+  swap_out_node->set_op("Identity");
+  swap_out_node->set_device("/CPU");
+
+  // Force the tensor to be restored to the device.
+  NodeDef* swap_in_node = graph->add_node();
+  swap_in_node->set_name(strings::StrCat("swap_in_", tensor_to_swap));
+  swap_in_node->set_op("Identity");
+  *swap_in_node->add_input() = swap_out_node->name();
+
+  // Colocate the swap_in_ node with the node itself.
+  string coloc_group = strings::StrCat("loc@", tensor_to_swap);
+  (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
+  (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
+
+  return std::make_pair(swap_out_node, swap_in_node);
+}
+
+// Copies item.graph into *optimized_graph, then reroutes every input listed
+// in a node's "swap_to_host" attribute through a BuildSwapPair() node pair:
+// the original producer feeds swap_out, and the node reads from swap_in.
+// Returns INVALID_ARGUMENT if a listed index is not a valid input index of
+// its node; *optimized_graph may be partially rewritten in that case.
+Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+                                 GraphDef* optimized_graph) {
+  *optimized_graph = item.graph;
+
+  // BuildSwapPair() appends nodes to the graph, so visit the original nodes
+  // by index rather than keeping an iterator into the node list alive across
+  // the append. The appended nodes are never marked and need no visit.
+  const int num_original_nodes = optimized_graph->node_size();
+  for (int i = 0; i < num_original_nodes; ++i) {
+    NodeDef* node = optimized_graph->mutable_node(i);
+    const auto swap_attr = node->attr().find("swap_to_host");
+    if (swap_attr == node->attr().end()) {
+      continue;
+    }
+
+    // Snapshot the indices first: BuildSwapPair() adds "_class" to this
+    // node's attribute map, which we would otherwise still be reading from.
+    const auto& marked_inputs = swap_attr->second.list().i();
+    const std::vector<int> inputs_to_swap(marked_inputs.begin(),
+                                          marked_inputs.end());
+
+    // Swap all the tensors that are marked with the 'swap_to_host' attribute.
+    for (int input_id : inputs_to_swap) {
+      if (input_id < 0 || input_id >= node->input_size()) {
+        return Status(
+            error::INVALID_ARGUMENT,
+            strings::StrCat("Node ", node->name(), " has ", node->input_size(),
+                            " inputs; cannot swap input ", input_id));
+      }
+      std::pair<NodeDef*, NodeDef*> swap_nodes =
+          BuildSwapPair(node, input_id, optimized_graph);
+      *swap_nodes.first->add_input() = node->input(input_id);
+      *node->mutable_input(input_id) = swap_nodes.second->name();
+
+      // TODO(bsteiner): Make sure the tensor isn't swapped back in right away
+      // by adding a control dependency to delay the execution of the swap.
+      // string trigger;
+      //*swap_nodes.second->add_input() = strings::StrCat("^", trigger);
+    }
+  }
+
+  return Status::OK();
+}
+
+// Overrides GraphOptimizer::Feedback; this optimizer makes no use of it.
+void MemoryOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
+                               const GraphDef& optimized_graph, double result) {
+  // Nothing to do for MemoryOptimizer.
+}
+
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.h b/tensorflow/core/grappler/optimizers/memory_optimizer.h
new file mode 100644
index 00000000000..463067738b3
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
+#define TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
+
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Graph optimizer that swaps marked tensors out to the host and back in.
+class MemoryOptimizer : public GraphOptimizer {
+ public:
+  MemoryOptimizer() {}
+  ~MemoryOptimizer() override {}
+  // Returns the fixed name of this optimizer.
+  string name() const override { return "memory_optimizer"; }
+  // Writes a copy of item.graph with "swap_to_host" inputs rerouted.
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* optimized_graph) override;
+  // Does nothing for this optimizer.
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimized_graph, double result) override;
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
new file mode 100644
index 00000000000..9defa72cffb
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class MemoryOptimizerTest : public ::testing::Test {};
+// Marks input 0 of node "e" for swapping and checks the rewired graph.
+TEST_F(MemoryOptimizerTest, SimpleSwapping) {
+  // Build a simple graph with an op that's marked for swapping.
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
+  Output b = ops::AddN(s.WithOpName("b"), {a});
+  Output c = ops::AddN(s.WithOpName("c"), {b});
+  Output d = ops::AddN(s.WithOpName("d"), {c});
+  Output e = ops::AddN(s.WithOpName("e"), {b, d});
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  ASSERT_EQ(5, item.graph.node_size());  // node(4) is indexed below.
+  EXPECT_EQ(NodeName(e.name()), item.graph.node(4).name());
+  AttrValue& val =
+      (*item.graph.mutable_node(4)->mutable_attr())["swap_to_host"];
+  val.mutable_list()->add_i(0);
+
+  MemoryOptimizer optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  ASSERT_EQ(7, output.node_size());  // nodes 4..6 are indexed below.
+  const NodeDef& new_e = output.node(4);
+  EXPECT_EQ(NodeName(e.name()), new_e.name());
+
+  ASSERT_EQ(2, new_e.input_size());
+  EXPECT_EQ(NodeName(d.name()), new_e.input(1));
+  EXPECT_EQ("swap_in_e_0", new_e.input(0));
+
+  const NodeDef& swap_out = output.node(5);
+  EXPECT_EQ("swap_out_e_0", swap_out.name());
+  ASSERT_EQ(1, swap_out.input_size());
+  EXPECT_EQ(NodeName(b.name()), swap_out.input(0));
+
+  const NodeDef& swap_in = output.node(6);
+  EXPECT_EQ("swap_in_e_0", swap_in.name());
+  ASSERT_EQ(1, swap_in.input_size());
+  EXPECT_EQ(NodeName(swap_out.name()), swap_in.input(0));
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow