Add heuristics to trigger swapping
PiperOrigin-RevId: 174376046
This commit is contained in:
parent
9dce7b9405
commit
ccd413a0d8
@ -66,6 +66,30 @@ tf_cuda_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graph_view",
|
||||
srcs = ["graph_view.cc"],
|
||||
hdrs = ["graph_view.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "graph_view_test",
|
||||
srcs = ["graph_view_test.cc"],
|
||||
deps = [
|
||||
":graph_view",
|
||||
":grappler_item",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grappler_item",
|
||||
srcs = [
|
||||
|
93
tensorflow/core/grappler/graph_view.cc
Normal file
93
tensorflow/core/grappler/graph_view.cc
Normal file
@ -0,0 +1,93 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Builds the name->node index and the reverse (fanout) edge index over the
// given graph. The graph must outlive this view.
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
  for (int i = 0; i < graph_->node_size(); ++i) {
    NodeDef* node = graph_->mutable_node(i);
    auto rslt = nodes_.insert(std::make_pair(node->name(), node));
    // Check that the graph doesn't contain multiple nodes with the same name.
    CHECK(rslt.second);
  }
  // For every input of every node, record which (consumer node, input port)
  // reads each (producer node, output port).
  for (NodeDef& node : *graph_->mutable_node()) {
    for (int i = 0; i < node.input_size(); ++i) {
      InputPort input;
      input.node = &node;
      input.port_id = i;

      OutputPort fanin;
      string fanin_name = ParseNodeName(node.input(i), &fanin.port_id);
      // Use find() rather than operator[]: a dangling input must not insert
      // a spurious null entry into the node index.
      auto it = nodes_.find(fanin_name);
      fanin.node = (it == nodes_.end()) ? nullptr : it->second;
      fanouts_[fanin].insert(input);
    }
  }
}
|
||||
|
||||
// Returns the node with the given name, or nullptr if no such node exists.
NodeDef* GraphView::GetNode(const string& node_name) const {
  const auto it = nodes_.find(node_name);
  return it != nodes_.end() ? it->second : nullptr;
}
|
||||
|
||||
// Builds an input port descriptor for the named node. The returned port's
// node field is nullptr when node_name is unknown.
GraphView::InputPort GraphView::GetInputPort(const string& node_name,
                                             int port_id) const {
  InputPort result;
  // TODO(bsteiner): verify that the node has at least port_id input ports
  result.port_id = port_id;
  result.node = GetNode(node_name);
  return result;
}
|
||||
|
||||
// Builds an output port descriptor for the named node. The returned port's
// node field is nullptr when node_name is unknown.
GraphView::OutputPort GraphView::GetOutputPort(const string& node_name,
                                               int port_id) const {
  OutputPort result;
  // TODO(bsteiner): verify that the node has at least port_id output ports
  result.port_id = port_id;
  result.node = GetNode(node_name);
  return result;
}
|
||||
|
||||
// Returns the set of input ports reading from the given output port, or a
// shared empty set when the port has no recorded consumers.
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>&
GraphView::GetFanout(const GraphView::OutputPort& port) const {
  const auto it = fanouts_.find(port);
  return it != fanouts_.end() ? it->second : empty_set_;
}
|
||||
|
||||
// Returns the output port feeding the given input port. Parses the producer
// name (and output index) out of the consumer's input string, then resolves
// it through the node index; fanin.node is nullptr when the producer isn't
// present in the graph.
const GraphView::OutputPort GraphView::GetFanin(
    const GraphView::InputPort& port) const {
  OutputPort fanin;
  const string fanin_name =
      ParseNodeName(port.node->input(port.port_id), &fanin.port_id);
  // Reuse GetNode rather than duplicating the lookup-or-nullptr logic.
  fanin.node = GetNode(fanin_name);
  return fanin;
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
69
tensorflow/core/grappler/graph_view.h
Normal file
69
tensorflow/core/grappler/graph_view.h
Normal file
@ -0,0 +1,69 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_GRAPPLER_GRAPH_VIEW_H_
|
||||
#define TENSORFLOW_GRAPPLER_GRAPH_VIEW_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// A utility class to simplify the traversal of a GraphDef.
|
||||
class GraphView {
|
||||
public:
|
||||
struct Port {
|
||||
NodeDef* node;
|
||||
int port_id;
|
||||
|
||||
bool operator==(const Port& other) const {
|
||||
return node == other.node && port_id == other.port_id;
|
||||
}
|
||||
};
|
||||
struct InputPort : public Port {};
|
||||
struct OutputPort : public Port {};
|
||||
|
||||
struct HashPort {
|
||||
std::size_t operator()(const Port& port) const {
|
||||
return reinterpret_cast<std::size_t>(port.node) + port.port_id;
|
||||
}
|
||||
};
|
||||
|
||||
explicit GraphView(GraphDef* graph);
|
||||
NodeDef* GetNode(const string& node_name) const;
|
||||
InputPort GetInputPort(const string& node_name, int port_id) const;
|
||||
OutputPort GetOutputPort(const string& node_name, int port_id) const;
|
||||
|
||||
const std::unordered_set<InputPort, HashPort>& GetFanout(
|
||||
const OutputPort& port) const;
|
||||
const OutputPort GetFanin(const InputPort& port) const;
|
||||
|
||||
private:
|
||||
GraphDef* graph_;
|
||||
std::unordered_map<string, NodeDef*> nodes_;
|
||||
std::unordered_set<InputPort, HashPort> empty_set_;
|
||||
std::unordered_map<OutputPort, std::unordered_set<InputPort, HashPort>,
|
||||
HashPort>
|
||||
fanouts_;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_GRAPPLER_GRAPH_VIEW_H_
|
66
tensorflow/core/grappler/graph_view_test.cc
Normal file
66
tensorflow/core/grappler/graph_view_test.cc
Normal file
@ -0,0 +1,66 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
class GraphViewTest : public ::testing::Test {};
|
||||
|
||||
// Exercises node/port lookup, fanin resolution, and fanout enumeration on a
// small generated graph.
TEST_F(GraphViewTest, BasicGraph) {
  TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  GraphView graph(&item.graph);

  // Each AddN input should resolve to the matching Square node, output port 0.
  GraphView::InputPort input = graph.GetInputPort("AddN", 0);
  EXPECT_EQ("AddN", input.node->name());
  EXPECT_EQ(0, input.port_id);
  GraphView::OutputPort fanin = graph.GetFanin(input);
  EXPECT_EQ("Square", fanin.node->name());
  EXPECT_EQ(0, fanin.port_id);

  input = graph.GetInputPort("AddN", 1);
  EXPECT_EQ("AddN", input.node->name());
  EXPECT_EQ(1, input.port_id);
  fanin = graph.GetFanin(input);
  EXPECT_EQ("Square_1", fanin.node->name());
  EXPECT_EQ(0, fanin.port_id);

  // AddN's output feeds exactly the two downstream AddN nodes, both on their
  // first input port.
  GraphView::OutputPort output = graph.GetOutputPort("AddN", 0);
  EXPECT_EQ("AddN", output.node->name());
  EXPECT_EQ(0, output.port_id);
  EXPECT_EQ(2, graph.GetFanout(output).size());
  // Iterate by const reference to avoid copying each port.
  for (const auto& fanout : graph.GetFanout(output)) {
    if (fanout.node->name() == "AddN_2" || fanout.node->name() == "AddN_3") {
      EXPECT_EQ(0, fanout.port_id);
    } else {
      // Report the offending node instead of the opaque EXPECT_FALSE(true).
      ADD_FAILURE() << "Invalid fanout: " << fanout.node->name();
    }
  }
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -235,9 +235,11 @@ cc_library(
|
||||
":graph_rewriter",
|
||||
":static_schedule",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:graph_view",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/costs:graph_memory",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/utils:topological_sort",
|
||||
],
|
||||
|
@ -24,7 +24,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_memory.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
|
||||
@ -430,14 +432,16 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
|
||||
[&recomputation_targets_name_prefix](const NodeDef& node) {
|
||||
// Nodes whose inputs we may want to recompute. Typically targets will
|
||||
// be gradients (recomputation_targets_name_prefix="gradients/"),
|
||||
// although the prefix is configurable since gradients may be created in
|
||||
// a name scope.
|
||||
// although the prefix is configurable since gradients may be created
|
||||
// in a name scope.
|
||||
// TODO(allenl): Use a static schedule
|
||||
// (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
|
||||
// whose outputs will sit around for a while.
|
||||
return node.name().find(recomputation_targets_name_prefix) == 0;
|
||||
};
|
||||
if (optimization_level == RewriterConfig::HEURISTICS) {
|
||||
|
||||
if (optimization_level == RewriterConfig::RECOMPUTATION_HEURISTICS ||
|
||||
optimization_level == RewriterConfig::HEURISTICS) {
|
||||
// TODO(allenl): Handle ResNet-like architectures better. Right now all of
|
||||
// the cheap forward ops get grouped into a single subgraph which must
|
||||
// execute before gradients start executing (unless layers are manually
|
||||
@ -601,6 +605,81 @@ static const NodeDef* FindSwapTrigger(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Annotates the fanouts of large, long-lived tensors on memory-constrained
// GPUs with the "_swap_to_host" attribute so the swapping pass can move them
// to host memory. Best effort: bails out silently when memory usage or
// execution times can't be estimated.
static void IdentifySwappingCandidates(Cluster* cluster,
                                       const GrapplerItem& item,
                                       GraphDef* optimized_graph) {
  GraphMemory memory(item);
  const std::unordered_map<string, DeviceProperties>& devices =
      cluster->GetDevices();
  if (!memory.InferStatically(devices).ok()) {
    return;
  }

  for (const auto& device : devices) {
    const string& name = device.first;
    const DeviceProperties& prop = device.second;
    // Only GPUs with a known memory budget are candidates for swapping.
    if (prop.type() != "GPU") {
      continue;
    }
    if (prop.memory_size() <= 0) {
      continue;
    }
    const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
    if (mem_usage.used_memory <= prop.memory_size()) {
      // Peak usage fits on the device: nothing to swap.
      continue;
    }
    int64 required_savings = mem_usage.used_memory - prop.memory_size();
    // TODO(bsteiner): sort the tensors by how long they're live.

    std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
    if (!EstimateEarliestExecutionTimes(item, cluster, &execution_times).ok()) {
      return;
    }
    GraphView graph(optimized_graph);
    for (const auto& live_tensor : mem_usage.live_tensors) {
      if (live_tensor.deallocation_time - live_tensor.allocation_time <=
          Costs::Duration(1e6)) {
        // Not enough time to swap.
        continue;
      }
      if (live_tensor.memory_used <= 1024) {
        // Don't bother with small tensors.
        continue;
      }
      // Pick the consumer that executes latest, to maximize the time the
      // tensor spends swapped out in host memory.
      Costs::NanoSeconds execution_time(-1);
      GraphView::InputPort fanout_to_swap;
      fanout_to_swap.node = nullptr;
      GraphView::OutputPort port =
          graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
      for (GraphView::InputPort input : graph.GetFanout(port)) {
        auto it = execution_times.find(input.node);
        if (it != execution_times.end()) {
          if (it->second > execution_time) {
            fanout_to_swap = input;
            execution_time = it->second;
          }
        }
      }
      if (fanout_to_swap.node == nullptr) {
        // No fanout with a known execution time: dereferencing the
        // uninitialized port below would be undefined behavior.
        continue;
      }
      // Annotate the fanout to request the tensor to be swapped if it's not
      // already been done.
      AttrValue& val = (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
      bool found = false;
      for (int port_id : val.list().i()) {
        if (port_id == fanout_to_swap.port_id) {
          found = true;
          break;
        }
      }
      if (!found) {
        val.mutable_list()->add_i(fanout_to_swap.port_id);
        required_savings -= live_tensor.memory_used;
        if (required_savings < 0) {
          break;
        }
      }
    }
  }
}
|
||||
|
||||
Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) {
|
||||
*optimized_graph = item.graph;
|
||||
@ -609,6 +688,10 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
recomputation_targets_name_prefix_,
|
||||
optimized_graph, item);
|
||||
|
||||
if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS) {
|
||||
IdentifySwappingCandidates(cluster, item, optimized_graph);
|
||||
}
|
||||
|
||||
// Figure out what needs to be swapped;
|
||||
std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
|
||||
for (auto& node : *optimized_graph->mutable_node()) {
|
||||
|
@ -153,7 +153,7 @@ TEST_F(RecomputeSubgraphTest, MultiNode) {
|
||||
pre_transform_node_map.GetNode("BN")->set_op("FusedBatchNorm");
|
||||
pre_transform_node_map.GetNode("ReLU")->set_op("Relu");
|
||||
|
||||
MemoryOptimizer optimizer(RewriterConfig::HEURISTICS);
|
||||
MemoryOptimizer optimizer(RewriterConfig::RECOMPUTATION_HEURISTICS);
|
||||
GraphDef first_pass_output;
|
||||
Status first_pass_status =
|
||||
optimizer.Optimize(nullptr, item, &first_pass_output);
|
||||
|
@ -46,9 +46,12 @@ message RewriterConfig {
|
||||
// Driven by manual op-level annotations.
|
||||
MANUAL = 2;
|
||||
// Driven by heuristics. The behavior of these heuristics is subject to
|
||||
// change. Currently includes an experimental recomputation
|
||||
// heuristic. Manual annotations are respected, but additional nodes are
|
||||
// change. Currently includes an experimental recomputation and swapping
|
||||
// heuristics. Manual annotations are respected, but additional nodes are
|
||||
// selected automatically.
|
||||
SWAPPING_HEURISTICS = 4;
|
||||
RECOMPUTATION_HEURISTICS = 5;
|
||||
// Use any combination of swapping and recomputation heuristics.
|
||||
HEURISTICS = 3;
|
||||
}
|
||||
// Configures memory optimization passes through the meta-optimizer. Has no
|
||||
|
@ -129,8 +129,8 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
|
||||
disable_model_pruning=True,
|
||||
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
|
||||
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
|
||||
memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS),
|
||||
original_metagraph)
|
||||
memory_optimization=rewriter_config_pb2.RewriterConfig.
|
||||
RECOMPUTATION_HEURISTICS), original_metagraph)
|
||||
self.assertGreater(
|
||||
len(rewritten_graph_def.node),
|
||||
len(original_metagraph.graph_def.node))
|
||||
@ -152,7 +152,8 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
|
||||
disable_model_pruning=True,
|
||||
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
|
||||
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
|
||||
memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS,
|
||||
memory_optimization=rewriter_config_pb2.RewriterConfig.
|
||||
RECOMPUTATION_HEURISTICS,
|
||||
memory_optimizer_target_node_name_prefix='optimizer/gradients/'),
|
||||
original_metagraph)
|
||||
self.assertGreater(
|
||||
|
Loading…
Reference in New Issue
Block a user