Implemented memory swapping heuristics for GPU

PiperOrigin-RevId: 180968225
Benoit Steiner 2018-01-05 13:35:03 -08:00 committed by TensorFlower Gardener
parent 3a3feb207d
commit ca6f0dd19b
4 changed files with 99 additions and 16 deletions

tensorflow/core/grappler/costs/graph_memory.cc

@@ -32,7 +32,17 @@ Status GraphMemory::InferStatically(
     const std::unordered_map<string, DeviceProperties>& devices) {
   VirtualCluster cluster(devices);
   TF_RETURN_IF_ERROR(cluster.Provision());
-  return InferDynamically(&cluster);
+  TF_RETURN_IF_ERROR(cluster.Initialize(item_));
+  RunMetadata metadata;
+  Status s = cluster.Run(item_.graph, item_.feed, item_.fetch, &metadata);
+  // The virtual cluster returns the RESOURCE_EXHAUSTED error when it detects
+  // that the model would run out of memory. We still get the metadata we need
+  // out of the simulation, so we just ignore this error.
+  if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
+    return s;
+  }
+  InferFromTrace(metadata.step_stats());
+  return Status::OK();
 }
 
 Status GraphMemory::InferDynamically(Cluster* cluster) {

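Note on the new flow: InferStatically() now drives a full simulated run on a VirtualCluster and deliberately swallows a simulated RESOURCE_EXHAUSTED, so peak-usage statistics are available even for models that would not fit in device memory. A minimal caller-side sketch, assuming only the API visible in this diff (the EstimatePeakMemory wrapper is illustrative, not part of the commit):

#include <unordered_map>

#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/grappler_item.h"

namespace tensorflow {
namespace grappler {

// Illustrative wrapper: estimate peak memory per device without running the
// graph for real. InferStatically() simulates the execution and keeps its
// statistics even when the simulation reports an OOM.
Status EstimatePeakMemory(
    const GrapplerItem& item,
    const std::unordered_map<string, DeviceProperties>& devices) {
  GraphMemory memory(item);
  TF_RETURN_IF_ERROR(memory.InferStatically(devices));
  for (const auto& device : devices) {
    const GraphMemory::MemoryUsage& usage =
        memory.GetPeakMemoryUsage(device.first);
    VLOG(1) << device.first << ": peak " << usage.used_memory << " bytes, "
            << usage.live_tensors.size() << " live tensors at peak";
  }
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow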
tensorflow/core/grappler/graph_view.h

@@ -29,8 +29,8 @@ namespace grappler {
 class GraphView {
  public:
   struct Port {
-    NodeDef* node;
-    int port_id;
+    NodeDef* node = nullptr;
+    int port_id = -1;
 
     bool operator==(const Port& other) const {
       return node == other.node && port_id == other.port_id;

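The in-class initializers are load-bearing: IdentifySwappingCandidates (next file) default-constructs a GraphView::InputPort and then tests fanout_to_swap.node for null to mean "no suitable fanout found". A standalone illustration of the failure mode the initializers close (the struct mirrors Port; it is not the real header):

#include <iostream>

// Mirrors GraphView::Port from the diff above, with the new in-class
// initializers. Before this change a default-constructed Port held
// indeterminate values, so the null check below would be undefined behavior.
struct Port {
  int* node = nullptr;  // NodeDef* in the real code
  int port_id = -1;
};

int main() {
  Port fanout_to_swap;  // never assigned by the search loop
  if (!fanout_to_swap.node) {
    std::cout << "no fanout selected; skip this tensor\n";  // deterministic
  }
  return 0;
}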
tensorflow/core/grappler/optimizers/memory_optimizer.cc

@@ -568,9 +568,12 @@ static const NodeDef* FindSwapTrigger(
   max_trigger_time -= swap_info.time_to_swap;
 
   std::map<Costs::NanoSeconds, const NodeDef*> candidates;
+  std::set<string> already_processed;
+
   while (!possible_inputs.empty()) {
     const string input_node_name = *possible_inputs.begin();
     possible_inputs.erase(possible_inputs.begin());
+    already_processed.insert(input_node_name);
     auto it1 = name_map.find(input_node_name);
     if (it1 == name_map.end()) {
       return nullptr;
@@ -579,7 +582,7 @@
     // Don't jump over frames, since adding a control dependency from one frame
     // to the next isn't supported. Don't go through branches, since we don't
     // know whether they'll be executed or not.
-    if (IsNextIteration(*input_node) || IsSwitch(*input_node) ||
+    if (ModifiesFrameInfo(*input_node) || IsSwitch(*input_node) ||
         IsMerge(*input_node)) {
       continue;
     }
@@ -591,7 +594,10 @@
       candidates[it2->second] = input_node;
     } else {
       for (const string& fanin : input_node->input()) {
-        possible_inputs.insert(NodeName(fanin));
+        string name = NodeName(fanin);
+        if (already_processed.find(name) == already_processed.end()) {
+          possible_inputs.insert(name);
+        }
       }
     }
   }
@@ -611,7 +617,9 @@ static void IdentifySwappingCandidates(Cluster* cluster,
   GraphMemory memory(item);
   const std::unordered_map<string, DeviceProperties>& devices =
       cluster->GetDevices();
-  if (!memory.InferStatically(devices).ok()) {
+  Status s = memory.InferStatically(devices);
+  if (!s.ok()) {
+    VLOG(1) << "Failed to infer memory usage: " << s.error_message();
     return;
   }
 
@@ -622,24 +630,36 @@ static void IdentifySwappingCandidates(Cluster* cluster,
       continue;
     }
     if (prop.memory_size() <= 0) {
+      VLOG(1) << "Peak memory usage unknown for device " << name;
       continue;
     }
     const GraphMemory::MemoryUsage& mem_usage =
         memory.GetPeakMemoryUsage(name);
+
     if (mem_usage.used_memory <= prop.memory_size()) {
       continue;
     }
     int64 required_savings = mem_usage.used_memory - prop.memory_size();
     // TODO(bsteiner): sort the tensors by how long they're live.
-    std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
-    if (!EstimateEarliestExecutionTimes(item, cluster, &execution_times).ok()) {
-      return;
+    std::unordered_map<string, Costs::NanoSeconds> execution_times;
+    {
+      std::unordered_map<const NodeDef*, Costs::NanoSeconds>
+          tmp_execution_times;
+      if (!EstimateEarliestExecutionTimes(item, cluster, &tmp_execution_times)
+               .ok()) {
+        return;
+      }
+      for (const auto& exec_time : tmp_execution_times) {
+        execution_times.emplace(exec_time.first->name(), exec_time.second);
+      }
     }
+
     GraphView graph(optimized_graph);
     for (const auto& live_tensor : mem_usage.live_tensors) {
       if (live_tensor.deallocation_time - live_tensor.allocation_time <=
           Costs::Duration(1e6)) {
         // Not enough time to swap.
+        VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
         continue;
       }
       if (live_tensor.memory_used <= 1024) {
@@ -651,7 +671,7 @@ static void IdentifySwappingCandidates(Cluster* cluster,
       GraphView::OutputPort port =
           graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
       for (GraphView::InputPort input : graph.GetFanout(port)) {
-        auto it = execution_times.find(input.node);
+        auto it = execution_times.find(input.node->name());
         if (it != execution_times.end()) {
           if (it->second > execution_time) {
             fanout_to_swap = input;
@@ -661,15 +681,23 @@ static void IdentifySwappingCandidates(Cluster* cluster,
       }
       // Annotate the fanout to request the tensor to be swapped if it's not
       // already been done.
-      AttrValue& val = (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
       bool found = false;
-      for (int port_id : val.list().i()) {
-        if (port_id == fanout_to_swap.port_id) {
-          found = true;
-          break;
+      if (!fanout_to_swap.node) {
+        continue;
+      }
+      auto it = fanout_to_swap.node->attr().find("_swap_to_host");
+      if (it != fanout_to_swap.node->attr().end()) {
+        const AttrValue& val = it->second;
+        for (int port_id : val.list().i()) {
+          if (port_id == fanout_to_swap.port_id) {
+            found = true;
+            break;
+          }
         }
       }
       if (!found) {
+        AttrValue& val =
+            (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
         val.mutable_list()->add_i(fanout_to_swap.port_id);
         required_savings -= live_tensor.memory_used;
         if (required_savings < 0) {
@@ -688,7 +716,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                              recomputation_targets_name_prefix_,
                              optimized_graph, item);
 
-  if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS) {
+  if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS &&
+      cluster != nullptr) {
     IdentifySwappingCandidates(cluster, item, optimized_graph);
   }

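The heuristic only annotates the graph: "_swap_to_host" on a consumer node holds the list of its input ports whose tensors should be swapped out to host memory, and a separate rewriting step (exercised by SimpleSwapping in the test below) materializes the swap nodes. A hedged sketch of reading the annotation back; the SwappedInputPorts helper is hypothetical, but the attr layout follows the code above:

#include <set>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

namespace tensorflow {
namespace grappler {

// Hypothetical helper: which input ports of `node` were tagged for swapping.
// "_swap_to_host" is written by IdentifySwappingCandidates as a list of
// input port ids on the fanout node.
std::set<int> SwappedInputPorts(const NodeDef& node) {
  std::set<int> ports;
  auto it = node.attr().find("_swap_to_host");
  if (it != node.attr().end()) {
    for (int64 port_id : it->second.list().i()) {
      ports.insert(port_id);
    }
  }
  return ports;
}

}  // namespace grappler
}  // namespace tensorflow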
tensorflow/core/grappler/optimizers/memory_optimizer_test.cc

@@ -201,8 +201,16 @@ class MemoryOptimizerTest : public ::testing::Test {
     cpu_device.set_frequency(1000);
     cpu_device.set_num_cores(4);
     cpu_device.set_bandwidth(32);
+    DeviceProperties gpu_device;
+    gpu_device.set_type("GPU");
+    gpu_device.set_frequency(1000);
+    gpu_device.set_num_cores(24);
+    gpu_device.set_bandwidth(128);
+    gpu_device.set_memory_size(1024 * 1024);
+    gpu_device.mutable_environment()->insert({"architecture", "6"});
     std::unordered_map<string, DeviceProperties> devices;
     devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device;
+    devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device;
     return std::unique_ptr<VirtualCluster>(new VirtualCluster(devices));
   }
 };
@@ -252,6 +260,42 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) {
   EXPECT_EQ("^c", swap_in.input(1));
 }
 
+TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
+                           {128, 128, 8}, DT_FLOAT);
+  Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
+  Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
+  Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), {a});
+  Output axis = ops::Const(s.WithOpName("axis"), 0);
+  Output e =
+      ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"e"};
+
+  std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
+
+  MemoryOptimizer optimizer(RewriterConfig::SWAPPING_HEURISTICS);
+  GraphDef output;
+  Status status = optimizer.Optimize(cluster.get(), item, &output);
+  TF_EXPECT_OK(status);
+
+  for (const auto& node : output.node()) {
+    if (node.name() == "e") {
+      EXPECT_TRUE(node.attr().count("_swap_to_host") > 0);
+      const AttrValue& val = node.attr().at("_swap_to_host");
+      EXPECT_TRUE(val.has_list());
+      std::set<int> inputs_to_swap;
+      for (int64 input_id : val.list().i()) {
+        inputs_to_swap.insert(input_id);
+      }
+      EXPECT_EQ(std::set<int>({0, 1, 2}), inputs_to_swap);
+    }
+  }
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
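For completeness: outside of tests, this optimizer is normally selected through the RewriterConfig proto rather than constructed directly as the new test does. A minimal sketch, assuming the standard proto setter; the meta-optimizer plumbing that consumes the setting is not shown here:

#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {

// Selecting the new heuristics via RewriterConfig; equivalent to the direct
// MemoryOptimizer(RewriterConfig::SWAPPING_HEURISTICS) used in the test.
RewriterConfig MakeSwappingRewriterConfig() {
  RewriterConfig config;
  config.set_memory_optimization(RewriterConfig::SWAPPING_HEURISTICS);
  return config;
}

}  // namespace tensorflow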