Implemented memory swapping heuristics for GPU
PiperOrigin-RevId: 180968225
This commit is contained in:
parent
3a3feb207d
commit
ca6f0dd19b
tensorflow/core/grappler
@ -32,7 +32,17 @@ Status GraphMemory::InferStatically(
|
||||
const std::unordered_map<string, DeviceProperties>& devices) {
|
||||
VirtualCluster cluster(devices);
|
||||
TF_RETURN_IF_ERROR(cluster.Provision());
|
||||
return InferDynamically(&cluster);
|
||||
TF_RETURN_IF_ERROR(cluster.Initialize(item_));
|
||||
RunMetadata metadata;
|
||||
Status s = cluster.Run(item_.graph, item_.feed, item_.fetch, &metadata);
|
||||
// The virtual cluster returns the RESOURCE_EXHAUSTED error when it detects
|
||||
// that the model would run out of memory. We still get the metadata we need
|
||||
// out of the simulation, so we just ignore this error.
|
||||
if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
|
||||
return s;
|
||||
}
|
||||
InferFromTrace(metadata.step_stats());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphMemory::InferDynamically(Cluster* cluster) {
|
||||
|
@ -29,8 +29,8 @@ namespace grappler {
|
||||
class GraphView {
|
||||
public:
|
||||
struct Port {
|
||||
NodeDef* node;
|
||||
int port_id;
|
||||
NodeDef* node = nullptr;
|
||||
int port_id = -1;
|
||||
|
||||
bool operator==(const Port& other) const {
|
||||
return node == other.node && port_id == other.port_id;
|
||||
|
@ -568,9 +568,12 @@ static const NodeDef* FindSwapTrigger(
|
||||
max_trigger_time -= swap_info.time_to_swap;
|
||||
|
||||
std::map<Costs::NanoSeconds, const NodeDef*> candidates;
|
||||
std::set<string> already_processed;
|
||||
|
||||
while (!possible_inputs.empty()) {
|
||||
const string input_node_name = *possible_inputs.begin();
|
||||
possible_inputs.erase(possible_inputs.begin());
|
||||
already_processed.insert(input_node_name);
|
||||
auto it1 = name_map.find(input_node_name);
|
||||
if (it1 == name_map.end()) {
|
||||
return nullptr;
|
||||
@ -579,7 +582,7 @@ static const NodeDef* FindSwapTrigger(
|
||||
// Don't jump over frames, since adding a control dependency from one frame
|
||||
// to the next isn't supported. Don't go through branches, since we don't
|
||||
// know whether they'll be executed or not.
|
||||
if (IsNextIteration(*input_node) || IsSwitch(*input_node) ||
|
||||
if (ModifiesFrameInfo(*input_node) || IsSwitch(*input_node) ||
|
||||
IsMerge(*input_node)) {
|
||||
continue;
|
||||
}
|
||||
@ -591,7 +594,10 @@ static const NodeDef* FindSwapTrigger(
|
||||
candidates[it2->second] = input_node;
|
||||
} else {
|
||||
for (const string& fanin : input_node->input()) {
|
||||
possible_inputs.insert(NodeName(fanin));
|
||||
string name = NodeName(fanin);
|
||||
if (already_processed.find(name) == already_processed.end()) {
|
||||
possible_inputs.insert(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -611,7 +617,9 @@ static void IdentifySwappingCandidates(Cluster* cluster,
|
||||
GraphMemory memory(item);
|
||||
const std::unordered_map<string, DeviceProperties>& devices =
|
||||
cluster->GetDevices();
|
||||
if (!memory.InferStatically(devices).ok()) {
|
||||
Status s = memory.InferStatically(devices);
|
||||
if (!s.ok()) {
|
||||
VLOG(1) << "Failed to infer memory usage: " << s.error_message();
|
||||
return;
|
||||
}
|
||||
|
||||
@ -622,24 +630,36 @@ static void IdentifySwappingCandidates(Cluster* cluster,
|
||||
continue;
|
||||
}
|
||||
if (prop.memory_size() <= 0) {
|
||||
VLOG(1) << "Peak memory usage unknown for device " << name;
|
||||
continue;
|
||||
}
|
||||
const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
|
||||
|
||||
if (mem_usage.used_memory <= prop.memory_size()) {
|
||||
continue;
|
||||
}
|
||||
int64 required_savings = mem_usage.used_memory - prop.memory_size();
|
||||
// TODO(bsteiner): sort the tensors by how long they're live.
|
||||
|
||||
std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
|
||||
if (!EstimateEarliestExecutionTimes(item, cluster, &execution_times).ok()) {
|
||||
return;
|
||||
std::unordered_map<string, Costs::NanoSeconds> execution_times;
|
||||
{
|
||||
std::unordered_map<const NodeDef*, Costs::NanoSeconds>
|
||||
tmp_execution_times;
|
||||
if (!EstimateEarliestExecutionTimes(item, cluster, &tmp_execution_times)
|
||||
.ok()) {
|
||||
return;
|
||||
}
|
||||
for (const auto& exec_time : tmp_execution_times) {
|
||||
execution_times.emplace(exec_time.first->name(), exec_time.second);
|
||||
}
|
||||
}
|
||||
|
||||
GraphView graph(optimized_graph);
|
||||
for (const auto& live_tensor : mem_usage.live_tensors) {
|
||||
if (live_tensor.deallocation_time - live_tensor.allocation_time <=
|
||||
Costs::Duration(1e6)) {
|
||||
// Not enough time to swap.
|
||||
VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
|
||||
continue;
|
||||
}
|
||||
if (live_tensor.memory_used <= 1024) {
|
||||
@ -651,7 +671,7 @@ static void IdentifySwappingCandidates(Cluster* cluster,
|
||||
GraphView::OutputPort port =
|
||||
graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
|
||||
for (GraphView::InputPort input : graph.GetFanout(port)) {
|
||||
auto it = execution_times.find(input.node);
|
||||
auto it = execution_times.find(input.node->name());
|
||||
if (it != execution_times.end()) {
|
||||
if (it->second > execution_time) {
|
||||
fanout_to_swap = input;
|
||||
@ -661,15 +681,23 @@ static void IdentifySwappingCandidates(Cluster* cluster,
|
||||
}
|
||||
// Annotate the fanout to request the tensor to be swapped if it's not
|
||||
// already been done.
|
||||
AttrValue& val = (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
|
||||
bool found = false;
|
||||
for (int port_id : val.list().i()) {
|
||||
if (port_id == fanout_to_swap.port_id) {
|
||||
found = true;
|
||||
break;
|
||||
if (!fanout_to_swap.node) {
|
||||
continue;
|
||||
}
|
||||
auto it = fanout_to_swap.node->attr().find("_swap_to_host");
|
||||
if (it != fanout_to_swap.node->attr().end()) {
|
||||
const AttrValue& val = it->second;
|
||||
for (int port_id : val.list().i()) {
|
||||
if (port_id == fanout_to_swap.port_id) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
AttrValue& val =
|
||||
(*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
|
||||
val.mutable_list()->add_i(fanout_to_swap.port_id);
|
||||
required_savings -= live_tensor.memory_used;
|
||||
if (required_savings < 0) {
|
||||
@ -688,7 +716,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
recomputation_targets_name_prefix_,
|
||||
optimized_graph, item);
|
||||
|
||||
if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS) {
|
||||
if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS &&
|
||||
cluster != nullptr) {
|
||||
IdentifySwappingCandidates(cluster, item, optimized_graph);
|
||||
}
|
||||
|
||||
|
@ -201,8 +201,16 @@ class MemoryOptimizerTest : public ::testing::Test {
|
||||
cpu_device.set_frequency(1000);
|
||||
cpu_device.set_num_cores(4);
|
||||
cpu_device.set_bandwidth(32);
|
||||
DeviceProperties gpu_device;
|
||||
gpu_device.set_type("GPU");
|
||||
gpu_device.set_frequency(1000);
|
||||
gpu_device.set_num_cores(24);
|
||||
gpu_device.set_bandwidth(128);
|
||||
gpu_device.set_memory_size(1024 * 1024);
|
||||
gpu_device.mutable_environment()->insert({"architecture", "6"});
|
||||
std::unordered_map<string, DeviceProperties> devices;
|
||||
devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device;
|
||||
devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device;
|
||||
return std::unique_ptr<VirtualCluster>(new VirtualCluster(devices));
|
||||
}
|
||||
};
|
||||
@ -252,6 +260,42 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) {
|
||||
EXPECT_EQ("^c", swap_in.input(1));
|
||||
}
|
||||
|
||||
TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
|
||||
{128, 128, 8}, DT_FLOAT);
|
||||
Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
|
||||
Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
|
||||
Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), {a});
|
||||
Output axis = ops::Const(s.WithOpName("axis"), 0);
|
||||
Output e =
|
||||
ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);
|
||||
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
item.fetch = {"e"};
|
||||
|
||||
std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
|
||||
|
||||
MemoryOptimizer optimizer(RewriterConfig::SWAPPING_HEURISTICS);
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(cluster.get(), item, &output);
|
||||
TF_EXPECT_OK(status);
|
||||
|
||||
for (const auto& node : output.node()) {
|
||||
if (node.name() == "e") {
|
||||
EXPECT_TRUE(node.attr().count("_swap_to_host") > 0);
|
||||
const AttrValue& val = node.attr().at("_swap_to_host");
|
||||
EXPECT_TRUE(val.has_list());
|
||||
std::set<int> inputs_to_swap;
|
||||
for (int64 input_id : val.list().i()) {
|
||||
inputs_to_swap.insert(input_id);
|
||||
}
|
||||
EXPECT_EQ(std::set<int>({0, 1, 2}), inputs_to_swap);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user