Added a memory optimizer to grappler.
Change: 152184170
This commit is contained in:
parent
7284cd8615
commit
98f6dbe9fa
@ -117,6 +117,37 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "memory_optimizer",
|
||||
srcs = ["memory_optimizer.cc"],
|
||||
hdrs = [
|
||||
"memory_optimizer.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_optimizer",
|
||||
":graph_rewriter",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "memory_optimizer_test",
|
||||
srcs = ["memory_optimizer_test.cc"],
|
||||
deps = [
|
||||
":memory_optimizer",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "layout_optimizer",
|
||||
srcs = ["layout_optimizer.cc"],
|
||||
|
83
tensorflow/core/grappler/optimizers/memory_optimizer.cc
Normal file
83
tensorflow/core/grappler/optimizers/memory_optimizer.cc
Normal file
@ -0,0 +1,83 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
|
||||
#include <unordered_set>
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
std::pair<NodeDef*, NodeDef*> BuildSwapPair(NodeDef* node, int input_to_swap,
|
||||
GraphDef* graph) {
|
||||
string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);
|
||||
|
||||
// Force the tensor to be copied to cpu.
|
||||
NodeDef* swap_out_node = graph->add_node();
|
||||
swap_out_node->set_name(strings::StrCat("swap_out_", tensor_to_swap));
|
||||
swap_out_node->set_op("Identity");
|
||||
swap_out_node->set_device("/CPU");
|
||||
|
||||
// Force the tensor to be restored to the device.
|
||||
NodeDef* swap_in_node = graph->add_node();
|
||||
swap_in_node->set_name(strings::StrCat("swap_in_", tensor_to_swap));
|
||||
swap_in_node->set_op("Identity");
|
||||
*swap_in_node->add_input() = swap_out_node->name();
|
||||
|
||||
// Colocate the swap_in_ node with the node itself.
|
||||
string coloc_group = strings::StrCat("loc@", tensor_to_swap);
|
||||
(*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
|
||||
(*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
|
||||
|
||||
return std::make_pair(swap_out_node, swap_in_node);
|
||||
}
|
||||
|
||||
Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) {
|
||||
*optimized_graph = item.graph;
|
||||
|
||||
for (auto& node : *optimized_graph->mutable_node()) {
|
||||
if (node.attr().count("swap_to_host") == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Swap all the tensors that are marked with the 'swap_to_host' attribute.
|
||||
for (int input_id : node.attr().at("swap_to_host").list().i()) {
|
||||
std::pair<NodeDef*, NodeDef*> swap_nodes =
|
||||
BuildSwapPair(&node, input_id, optimized_graph);
|
||||
*swap_nodes.first->add_input() = node.input(input_id);
|
||||
*node.mutable_input(input_id) = swap_nodes.second->name();
|
||||
|
||||
// TODO(bsteiner): Make sure the tensor isn't swapped back in right away
|
||||
// by adding a control dependency to delay the execution of the swap.
|
||||
// string trigger;
|
||||
//*swap_nodes.second->add_input() = strings::StrCat("^", trigger);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void MemoryOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimized_graph, double result) {
|
||||
// Nothing to do for MemoryOptimizer.
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
42
tensorflow/core/grappler/optimizers/memory_optimizer.h
Normal file
42
tensorflow/core/grappler/optimizers/memory_optimizer.h
Normal file
@ -0,0 +1,42 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
|
||||
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Swap tensors in and out of device memory.
|
||||
class MemoryOptimizer : public GraphOptimizer {
|
||||
public:
|
||||
MemoryOptimizer() {}
|
||||
~MemoryOptimizer() override {}
|
||||
|
||||
string name() const override { return "memory_optimizer"; };
|
||||
|
||||
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* pruned_graph) override;
|
||||
|
||||
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& pruned_graph, double result) override;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
|
74
tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
Normal file
74
tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
Normal file
@ -0,0 +1,74 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
class MemoryOptimizerTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(MemoryOptimizerTest, SimpleSwapping) {
|
||||
// Build a simple graph with an op that's marked for swapping.
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
|
||||
Output b = ops::AddN(s.WithOpName("b"), {a});
|
||||
Output c = ops::AddN(s.WithOpName("c"), {b});
|
||||
Output d = ops::AddN(s.WithOpName("d"), {c});
|
||||
Output e = ops::AddN(s.WithOpName("e"), {b, d});
|
||||
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
EXPECT_EQ(5, item.graph.node_size());
|
||||
EXPECT_EQ(NodeName(e.name()), item.graph.node(4).name());
|
||||
AttrValue& val =
|
||||
(*item.graph.mutable_node(4)->mutable_attr())["swap_to_host"];
|
||||
val.mutable_list()->add_i(0);
|
||||
|
||||
MemoryOptimizer optimizer;
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||
TF_EXPECT_OK(status);
|
||||
|
||||
EXPECT_EQ(7, output.node_size());
|
||||
const NodeDef& new_e = output.node(4);
|
||||
EXPECT_EQ(NodeName(e.name()), new_e.name());
|
||||
|
||||
EXPECT_EQ(2, new_e.input_size());
|
||||
EXPECT_EQ(NodeName(d.name()), new_e.input(1));
|
||||
EXPECT_EQ("swap_in_e_0", new_e.input(0));
|
||||
|
||||
const NodeDef& swap_out = output.node(5);
|
||||
EXPECT_EQ("swap_out_e_0", swap_out.name());
|
||||
|
||||
const NodeDef& swap_in = output.node(6);
|
||||
EXPECT_EQ("swap_in_e_0", swap_in.name());
|
||||
|
||||
EXPECT_EQ(NodeName(b.name()), swap_out.input(0));
|
||||
EXPECT_EQ(NodeName(swap_out.name()), swap_in.input(0));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user