Added a memory optimizer to grappler.
Change: 152184170
This commit is contained in:
parent
7284cd8615
commit
98f6dbe9fa
@ -117,6 +117,37 @@ cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Grappler optimizer that swaps tensors in and out of device memory.
cc_library(
    name = "memory_optimizer",
    srcs = ["memory_optimizer.cc"],
    hdrs = ["memory_optimizer.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":graph_optimizer",
        ":graph_rewriter",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:utils",
    ],
)
|
||||||
|
|
||||||
|
# Unit test for the :memory_optimizer library.
cc_test(
    name = "memory_optimizer_test",
    srcs = ["memory_optimizer_test.cc"],
    deps = [
        ":memory_optimizer",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:utils",
        "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
    ],
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "layout_optimizer",
|
name = "layout_optimizer",
|
||||||
srcs = ["layout_optimizer.cc"],
|
srcs = ["layout_optimizer.cc"],
|
||||||
|
83
tensorflow/core/grappler/optimizers/memory_optimizer.cc
Normal file
83
tensorflow/core/grappler/optimizers/memory_optimizer.cc
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
|
||||||
|
#include <unordered_set>
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
|
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
|
||||||
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace grappler {
|
||||||
|
|
||||||
|
// Appends to `graph` the two Identity nodes used to swap input number
// `input_to_swap` of `node`:
//   * "swap_out_<node>_<input>", assigned to device "/CPU" so that the tensor
//     is forced to be copied to the cpu;
//   * "swap_in_<node>_<input>", which reads from the swap-out node and is put
//     in the same colocation group as `node` so the tensor is restored next to
//     its consumer.
// `node` itself is only modified by gaining the colocation ("_class") entry;
// wiring the swap-out node's input and redirecting `node`'s input is left to
// the caller.
// Returns {swap_out_node, swap_in_node}; both are owned by `graph`.
std::pair<NodeDef*, NodeDef*> BuildSwapPair(NodeDef* node, int input_to_swap,
                                            GraphDef* graph) {
  // Suffix shared by both new node names and by the colocation group name.
  const string swap_suffix = strings::StrCat(node->name(), "_", input_to_swap);

  // Host side of the pair: an Identity assigned to the cpu.
  NodeDef* to_host = graph->add_node();
  to_host->set_op("Identity");
  to_host->set_device("/CPU");
  to_host->set_name(strings::StrCat("swap_out_", swap_suffix));

  // Device side of the pair: an Identity fed by the host-side node.
  NodeDef* to_device = graph->add_node();
  to_device->set_op("Identity");
  to_device->set_name(strings::StrCat("swap_in_", swap_suffix));
  to_device->add_input(to_host->name());

  // Put the swap-in node and the consumer in one colocation group.
  const string colocation_class = strings::StrCat("loc@", swap_suffix);
  (*node->mutable_attr())["_class"].mutable_list()->add_s(colocation_class);
  (*to_device->mutable_attr())["_class"].mutable_list()->add_s(
      colocation_class);

  return {to_host, to_device};
}
|
||||||
|
|
||||||
|
// Copies `item.graph` into `optimized_graph` and, for every node carrying a
// 'swap_to_host' attribute, routes each listed input through a
// swap-out/swap-in pair of Identity nodes built by BuildSwapPair().
//
// The 'swap_to_host' attribute is read as a list of integer input indices.
// Returns INVALID_ARGUMENT if an index does not designate an existing input of
// its node; in that case `optimized_graph` may already contain the rewrites
// performed for earlier nodes/inputs. `cluster` is not used.
Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                 GraphDef* optimized_graph) {
  *optimized_graph = item.graph;

  // BuildSwapPair() appends nodes to the graph, which may reallocate the
  // repeated field's storage. Walk the nodes by index, and only up to the
  // original node count, instead of holding an iterator across the appends.
  const int num_original_nodes = optimized_graph->node_size();
  for (int node_idx = 0; node_idx < num_original_nodes; ++node_idx) {
    NodeDef* node = optimized_graph->mutable_node(node_idx);

    const auto swap_attr = node->attr().find("swap_to_host");
    if (swap_attr == node->attr().end()) {
      continue;
    }

    // Take a copy of the index list: BuildSwapPair() inserts a "_class" entry
    // into this node's attribute map, so we must not keep iterating over data
    // that lives inside that map.
    const auto inputs_to_swap = swap_attr->second.list().i();

    // Swap all the tensors that are marked with the 'swap_to_host' attribute.
    for (const auto requested_input : inputs_to_swap) {
      if (requested_input < 0 || requested_input >= node->input_size()) {
        return Status(
            error::INVALID_ARGUMENT,
            strings::StrCat("Node '", node->name(),
                            "' requests swapping of input ", requested_input,
                            " but only has ", node->input_size(),
                            " input(s)"));
      }
      const int input_id = static_cast<int>(requested_input);

      std::pair<NodeDef*, NodeDef*> swap_nodes =
          BuildSwapPair(node, input_id, optimized_graph);
      // The swap-out node consumes the original tensor...
      *swap_nodes.first->add_input() = node->input(input_id);
      // ...and the node now reads it back through the swap-in node.
      *node->mutable_input(input_id) = swap_nodes.second->name();

      // TODO(bsteiner): Make sure the tensor isn't swapped back in right away
      // by adding a control dependency to delay the execution of the swap.
      // string trigger;
      //*swap_nodes.second->add_input() = strings::StrCat("^", trigger);
    }
  }

  return Status::OK();
}
|
||||||
|
|
||||||
|
// Intentionally empty: none of the arguments are used and no state is updated.
void MemoryOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
                               const GraphDef& optimized_graph,
                               double result) {
  // MemoryOptimizer does not act on feedback.
}
|
||||||
|
|
||||||
|
} // end namespace grappler
|
||||||
|
} // end namespace tensorflow
|
42
tensorflow/core/grappler/optimizers/memory_optimizer.h
Normal file
42
tensorflow/core/grappler/optimizers/memory_optimizer.h
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
|
||||||
|
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace grappler {
|
||||||
|
|
||||||
|
// Swap tensors in and out of device memory.
|
||||||
|
class MemoryOptimizer : public GraphOptimizer {
|
||||||
|
public:
|
||||||
|
MemoryOptimizer() {}
|
||||||
|
~MemoryOptimizer() override {}
|
||||||
|
|
||||||
|
string name() const override { return "memory_optimizer"; };
|
||||||
|
|
||||||
|
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||||
|
GraphDef* pruned_graph) override;
|
||||||
|
|
||||||
|
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||||
|
const GraphDef& pruned_graph, double result) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace grappler
|
||||||
|
} // end namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
|
74
tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
Normal file
74
tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace grappler {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class MemoryOptimizerTest : public ::testing::Test {};
|
||||||
|
|
||||||
|
TEST_F(MemoryOptimizerTest, SimpleSwapping) {
|
||||||
|
// Build a simple graph with an op that's marked for swapping.
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
|
||||||
|
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
|
||||||
|
Output b = ops::AddN(s.WithOpName("b"), {a});
|
||||||
|
Output c = ops::AddN(s.WithOpName("c"), {b});
|
||||||
|
Output d = ops::AddN(s.WithOpName("d"), {c});
|
||||||
|
Output e = ops::AddN(s.WithOpName("e"), {b, d});
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
|
|
||||||
|
EXPECT_EQ(5, item.graph.node_size());
|
||||||
|
EXPECT_EQ(NodeName(e.name()), item.graph.node(4).name());
|
||||||
|
AttrValue& val =
|
||||||
|
(*item.graph.mutable_node(4)->mutable_attr())["swap_to_host"];
|
||||||
|
val.mutable_list()->add_i(0);
|
||||||
|
|
||||||
|
MemoryOptimizer optimizer;
|
||||||
|
GraphDef output;
|
||||||
|
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||||
|
TF_EXPECT_OK(status);
|
||||||
|
|
||||||
|
EXPECT_EQ(7, output.node_size());
|
||||||
|
const NodeDef& new_e = output.node(4);
|
||||||
|
EXPECT_EQ(NodeName(e.name()), new_e.name());
|
||||||
|
|
||||||
|
EXPECT_EQ(2, new_e.input_size());
|
||||||
|
EXPECT_EQ(NodeName(d.name()), new_e.input(1));
|
||||||
|
EXPECT_EQ("swap_in_e_0", new_e.input(0));
|
||||||
|
|
||||||
|
const NodeDef& swap_out = output.node(5);
|
||||||
|
EXPECT_EQ("swap_out_e_0", swap_out.name());
|
||||||
|
|
||||||
|
const NodeDef& swap_in = output.node(6);
|
||||||
|
EXPECT_EQ("swap_in_e_0", swap_in.name());
|
||||||
|
|
||||||
|
EXPECT_EQ(NodeName(b.name()), swap_out.input(0));
|
||||||
|
EXPECT_EQ(NodeName(swap_out.name()), swap_in.input(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace grappler
|
||||||
|
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user