Add heuristics to trigger swapping

PiperOrigin-RevId: 174376046
Benoit Steiner 2017-11-02 13:56:49 -07:00 committed by TensorFlower Gardener
parent 9dce7b9405
commit ccd413a0d8
9 changed files with 350 additions and 9 deletions

tensorflow/core/grappler/BUILD

@@ -66,6 +66,30 @@ tf_cuda_library(
],
)
cc_library(
name = "graph_view",
srcs = ["graph_view.cc"],
hdrs = ["graph_view.h"],
visibility = ["//visibility:public"],
deps = [
":utils",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
tf_cc_test(
name = "graph_view_test",
srcs = ["graph_view_test.cc"],
deps = [
":graph_view",
":grappler_item",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
],
)
cc_library(
name = "grappler_item",
srcs = [

tensorflow/core/grappler/graph_view.cc

@@ -0,0 +1,93 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
for (int i = 0; i < graph_->node_size(); i++) {
auto node = graph_->mutable_node(i);
auto rslt = nodes_.insert(std::make_pair(node->name(), node));
// Check that the graph doesn't contain multiple nodes with the same name.
CHECK(rslt.second);
}
for (NodeDef& node : *graph_->mutable_node()) {
for (int i = 0; i < node.input_size(); ++i) {
InputPort input;
input.node = &node;
input.port_id = i;
OutputPort fanin;
string fanin_name = ParseNodeName(node.input(i), &fanin.port_id);
fanin.node = nodes_[fanin_name];
fanouts_[fanin].insert(input);
}
}
}
NodeDef* GraphView::GetNode(const string& node_name) const {
auto it = nodes_.find(node_name);
if (it == nodes_.end()) {
return nullptr;
}
return it->second;
}
GraphView::InputPort GraphView::GetInputPort(const string& node_name,
int port_id) const {
InputPort result;
result.node = GetNode(node_name);
// TODO(bsteiner): verify that the node has at least port_id input ports
result.port_id = port_id;
return result;
}
GraphView::OutputPort GraphView::GetOutputPort(const string& node_name,
int port_id) const {
OutputPort result;
result.node = GetNode(node_name);
// TODO(bsteiner): verify that the node has at least port_id output ports
result.port_id = port_id;
return result;
}
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>&
GraphView::GetFanout(const GraphView::OutputPort& port) const {
auto it = fanouts_.find(port);
if (it == fanouts_.end()) {
return empty_set_;
}
return it->second;
}
const GraphView::OutputPort GraphView::GetFanin(
const GraphView::InputPort& port) const {
OutputPort fanin;
string fanin_name =
ParseNodeName(port.node->input(port.port_id), &fanin.port_id);
auto it = nodes_.find(fanin_name);
if (it == nodes_.end()) {
fanin.node = nullptr;
} else {
fanin.node = it->second;
}
return fanin;
}
} // end namespace grappler
} // end namespace tensorflow
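
The constructor and GetFanin above rely on the node-input string conventions parsed by ParseNodeName from tensorflow/core/grappler/utils.h (included at the top of this file). As a quick reference, a hedged sketch of the expected parse results; the -1 port value for control inputs is an assumption based on how the fanout map is keyed:

#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/types.h"

void ParseExamples() {
  using tensorflow::string;
  int port = 0;
  string name;
  name = tensorflow::grappler::ParseNodeName("add", &port);    // name == "add", port == 0
  name = tensorflow::grappler::ParseNodeName("add:2", &port);  // name == "add", port == 2
  name = tensorflow::grappler::ParseNodeName("^add", &port);   // name == "add", port == -1 (control input)
}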

tensorflow/core/grappler/graph_view.h

@@ -0,0 +1,69 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_GRAPPLER_GRAPH_VIEW_H_
#define TENSORFLOW_GRAPPLER_GRAPH_VIEW_H_
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
// A utility class to simplify the traversal of a GraphDef.
class GraphView {
public:
struct Port {
// Initialized so callers can detect a port that was never resolved.
NodeDef* node = nullptr;
int port_id = -1;
bool operator==(const Port& other) const {
return node == other.node && port_id == other.port_id;
}
};
struct InputPort : public Port {};
struct OutputPort : public Port {};
struct HashPort {
std::size_t operator()(const Port& port) const {
return reinterpret_cast<std::size_t>(port.node) + port.port_id;
}
};
explicit GraphView(GraphDef* graph);
NodeDef* GetNode(const string& node_name) const;
InputPort GetInputPort(const string& node_name, int port_id) const;
OutputPort GetOutputPort(const string& node_name, int port_id) const;
const std::unordered_set<InputPort, HashPort>& GetFanout(
const OutputPort& port) const;
const OutputPort GetFanin(const InputPort& port) const;
private:
GraphDef* graph_;
std::unordered_map<string, NodeDef*> nodes_;
std::unordered_set<InputPort, HashPort> empty_set_;
std::unordered_map<OutputPort, std::unordered_set<InputPort, HashPort>,
HashPort>
fanouts_;
};
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_GRAPPLER_GRAPH_VIEW_H_
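
To make the API above concrete, here is a minimal usage sketch on a hand-built two-node graph; the graph itself is illustrative and not part of this commit. Note the HashPort design choice: hashing the node pointer plus the port id is cheap, and occasional collisions across nodes are resolved by Port's operator==.

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/graph_view.h"

void GraphViewExample() {
  tensorflow::GraphDef graph;
  tensorflow::NodeDef* a = graph.add_node();
  a->set_name("a");
  a->set_op("Const");
  tensorflow::NodeDef* b = graph.add_node();
  b->set_name("b");
  b->set_op("Identity");
  b->add_input("a");  // input port 0 of b reads output port 0 of a

  tensorflow::grappler::GraphView view(&graph);
  tensorflow::grappler::GraphView::OutputPort out = view.GetOutputPort("a", 0);
  for (const auto& fanout : view.GetFanout(out)) {
    // Exactly one consumer here: fanout.node->name() == "b", fanout.port_id == 0.
  }
}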

tensorflow/core/grappler/graph_view_test.cc

@@ -0,0 +1,66 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
namespace {
class GraphViewTest : public ::testing::Test {};
TEST_F(GraphViewTest, BasicGraph) {
TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"});
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
std::cout << item.graph.DebugString() << std::endl;
GraphView graph(&item.graph);
GraphView::InputPort input = graph.GetInputPort("AddN", 0);
EXPECT_EQ("AddN", input.node->name());
EXPECT_EQ(0, input.port_id);
GraphView::OutputPort fanin = graph.GetFanin(input);
EXPECT_EQ("Square", fanin.node->name());
EXPECT_EQ(0, fanin.port_id);
input = graph.GetInputPort("AddN", 1);
EXPECT_EQ("AddN", input.node->name());
EXPECT_EQ(1, input.port_id);
fanin = graph.GetFanin(input);
EXPECT_EQ("Square_1", fanin.node->name());
EXPECT_EQ(0, fanin.port_id);
GraphView::OutputPort output = graph.GetOutputPort("AddN", 0);
EXPECT_EQ("AddN", output.node->name());
EXPECT_EQ(0, output.port_id);
EXPECT_EQ(2, graph.GetFanout(output).size());
for (auto fanout : graph.GetFanout(output)) {
if (fanout.node->name() == "AddN_2" || fanout.node->name() == "AddN_3") {
EXPECT_EQ(0, fanout.port_id);
} else {
// Invalid fanout
EXPECT_FALSE(true);
}
}
}
} // namespace
} // namespace grappler
} // namespace tensorflow

tensorflow/core/grappler/optimizers/BUILD

@@ -235,9 +235,11 @@ cc_library(
":graph_rewriter",
":static_schedule",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_memory",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:topological_sort",
],

tensorflow/core/grappler/optimizers/memory_optimizer.cc

@@ -24,7 +24,9 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
@@ -430,14 +432,16 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
[&recomputation_targets_name_prefix](const NodeDef& node) {
// Nodes whose inputs we may want to recompute. Typically targets will
// be gradients (recomputation_targets_name_prefix="gradients/"),
- // although the prefix is configurable since gradients may be created in
- // a name scope.
+ // although the prefix is configurable since gradients may be created
+ // in a name scope.
// TODO(allenl): Use a static schedule
// (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
// whose outputs will sit around for a while.
return node.name().find(recomputation_targets_name_prefix) == 0;
};
- if (optimization_level == RewriterConfig::HEURISTICS) {
+ if (optimization_level == RewriterConfig::RECOMPUTATION_HEURISTICS ||
+     optimization_level == RewriterConfig::HEURISTICS) {
// TODO(allenl): Handle ResNet-like architectures better. Right now all of
// the cheap forward ops get grouped into a single subgraph which must
// execute before gradients start executing (unless layers are manually
@@ -601,6 +605,81 @@ static const NodeDef* FindSwapTrigger(
return nullptr;
}
static void IdentifySwappingCandidates(Cluster* cluster,
const GrapplerItem& item,
GraphDef* optimized_graph) {
GraphMemory memory(item);
const std::unordered_map<string, DeviceProperties>& devices =
cluster->GetDevices();
if (!memory.InferStatically(devices).ok()) {
return;
}
for (const auto& device : devices) {
const string& name = device.first;
const DeviceProperties& prop = device.second;
if (prop.type() != "GPU") {
continue;
}
if (prop.memory_size() <= 0) {
continue;
}
const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
if (mem_usage.used_memory <= prop.memory_size()) {
continue;
}
int64 required_savings = mem_usage.used_memory - prop.memory_size();
// TODO(bsteiner): sort the tensors by how long they're live.
std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
if (!EstimateEarliestExecutionTimes(item, cluster, &execution_times).ok()) {
return;
}
GraphView graph(optimized_graph);
for (const auto& live_tensor : mem_usage.live_tensors) {
if (live_tensor.deallocation_time - live_tensor.allocation_time <=
Costs::Duration(1e6)) {
// Not enough time to swap.
continue;
}
if (live_tensor.memory_used <= 1024) {
// Don't bother with small tensors.
continue;
}
Costs::NanoSeconds execution_time(-1);
GraphView::InputPort fanout_to_swap;
GraphView::OutputPort port =
graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
for (GraphView::InputPort input : graph.GetFanout(port)) {
auto it = execution_times.find(input.node);
if (it != execution_times.end()) {
if (it->second > execution_time) {
fanout_to_swap = input;
execution_time = it->second;
}
}
}
if (fanout_to_swap.node == nullptr) {
// No fanout with a known execution time, so there is nothing to annotate.
continue;
}
// Annotate the fanout to request that the tensor be swapped out, if that
// hasn't already been done.
AttrValue& val = (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
bool found = false;
for (int port_id : val.list().i()) {
if (port_id == fanout_to_swap.port_id) {
found = true;
break;
}
}
if (!found) {
val.mutable_list()->add_i(fanout_to_swap.port_id);
required_savings -= live_tensor.memory_used;
if (required_savings < 0) {
break;
}
}
}
}
}
Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
*optimized_graph = item.graph;
@@ -609,6 +688,10 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
recomputation_targets_name_prefix_,
optimized_graph, item);
if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS) {
IdentifySwappingCandidates(cluster, item, optimized_graph);
}
// Figure out what needs to be swapped.
std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
for (auto& node : *optimized_graph->mutable_node()) {
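
IdentifySwappingCandidates communicates with downstream code purely through the "_swap_to_host" integer-list attribute it sets above. As a hedged sketch of the consuming side of that contract (the helper name is hypothetical, not part of this commit):

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Hypothetical helper: true if input port `port_id` of `node` was already
// marked for swapping by the pass above.
bool IsMarkedForSwap(const tensorflow::NodeDef& node, int port_id) {
  auto it = node.attr().find("_swap_to_host");
  if (it == node.attr().end()) return false;
  for (auto marked : it->second.list().i()) {
    if (marked == port_id) return true;
  }
  return false;
}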

tensorflow/core/grappler/optimizers/memory_optimizer_test.cc

@@ -153,7 +153,7 @@ TEST_F(RecomputeSubgraphTest, MultiNode) {
pre_transform_node_map.GetNode("BN")->set_op("FusedBatchNorm");
pre_transform_node_map.GetNode("ReLU")->set_op("Relu");
- MemoryOptimizer optimizer(RewriterConfig::HEURISTICS);
+ MemoryOptimizer optimizer(RewriterConfig::RECOMPUTATION_HEURISTICS);
GraphDef first_pass_output;
Status first_pass_status =
optimizer.Optimize(nullptr, item, &first_pass_output);

tensorflow/core/protobuf/rewriter_config.proto

@@ -46,9 +46,12 @@ message RewriterConfig {
// Driven by manual op-level annotations.
MANUAL = 2;
// Driven by heuristics. The behavior of these heuristics is subject to
- // change. Currently includes an experimental recomputation
- // heuristic. Manual annotations are respected, but additional nodes are
+ // change. Currently includes experimental recomputation and swapping
+ // heuristics. Manual annotations are respected, but additional nodes are
// selected automatically.
+ SWAPPING_HEURISTICS = 4;
+ RECOMPUTATION_HEURISTICS = 5;
+ // Use any combination of swapping and recomputation heuristics.
HEURISTICS = 3;
}
// Configures memory optimization passes through the meta-optimizer. Has no
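
For reference, a minimal sketch of requesting the new swapping mode from C++; the field and enum names come from the message above, while the surrounding function is purely illustrative:

#include "tensorflow/core/protobuf/rewriter_config.pb.h"

tensorflow::RewriterConfig MakeSwappingConfig() {
  tensorflow::RewriterConfig config;
  // HEURISTICS enables every memory heuristic; SWAPPING_HEURISTICS limits
  // the memory optimizer to the swapping pass introduced in this commit.
  config.set_memory_optimization(tensorflow::RewriterConfig::SWAPPING_HEURISTICS);
  return config;
}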

tensorflow/python/grappler/memory_optimizer_test.py

@@ -129,8 +129,8 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
disable_model_pruning=True,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
- memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS),
- original_metagraph)
+ memory_optimization=rewriter_config_pb2.RewriterConfig.
+ RECOMPUTATION_HEURISTICS), original_metagraph)
self.assertGreater(
len(rewritten_graph_def.node),
len(original_metagraph.graph_def.node))
@@ -152,7 +152,8 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
disable_model_pruning=True,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
- memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS,
+ memory_optimization=rewriter_config_pb2.RewriterConfig.
+ RECOMPUTATION_HEURISTICS,
memory_optimizer_target_node_name_prefix='optimizer/gradients/'),
original_metagraph)
self.assertGreater(