Implemented memory swapping heuristics for GPU
PiperOrigin-RevId: 180968225
commit ca6f0dd19b
parent 3a3feb207d
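For reference, the heuristic is driven by RewriterConfig::SWAPPING_HEURISTICS: given a cluster whose devices report a memory_size, the optimizer annotates memory-hungry fanouts with a "_swap_to_host" attribute listing the input ports to swap. A minimal sketch of the call pattern, mirroring the test added below (CreateVirtualCluster is the test-fixture helper from this change; the fetch node "e" is illustrative):

    GrapplerItem item;    // graph plus feed/fetch nodes, built elsewhere
    item.fetch = {"e"};   // illustrative fetch node

    // Cluster whose DeviceProperties report a memory_size, so peak usage can
    // be compared against device capacity.
    std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());

    MemoryOptimizer optimizer(RewriterConfig::SWAPPING_HEURISTICS);
    GraphDef output;
    TF_EXPECT_OK(optimizer.Optimize(cluster.get(), item, &output));

    // Annotated nodes now carry "_swap_to_host"; the attribute's int list
    // names the input ports whose tensors should be swapped out to host memory.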
@@ -32,7 +32,17 @@ Status GraphMemory::InferStatically(
     const std::unordered_map<string, DeviceProperties>& devices) {
   VirtualCluster cluster(devices);
   TF_RETURN_IF_ERROR(cluster.Provision());
-  return InferDynamically(&cluster);
+  TF_RETURN_IF_ERROR(cluster.Initialize(item_));
+  RunMetadata metadata;
+  Status s = cluster.Run(item_.graph, item_.feed, item_.fetch, &metadata);
+  // The virtual cluster returns the RESOURCE_EXHAUSTED error when it detects
+  // that the model would run out of memory. We still get the metadata we need
+  // out of the simulation, so we just ignore this error.
+  if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
+    return s;
+  }
+  InferFromTrace(metadata.step_stats());
+  return Status::OK();
 }
 
 Status GraphMemory::InferDynamically(Cluster* cluster) {
@@ -29,8 +29,8 @@ namespace grappler {
 class GraphView {
  public:
   struct Port {
-    NodeDef* node;
-    int port_id;
+    NodeDef* node = nullptr;
+    int port_id = -1;
 
     bool operator==(const Port& other) const {
       return node == other.node && port_id == other.port_id;
@@ -568,9 +568,12 @@ static const NodeDef* FindSwapTrigger(
   max_trigger_time -= swap_info.time_to_swap;
 
   std::map<Costs::NanoSeconds, const NodeDef*> candidates;
+  std::set<string> already_processed;
+
   while (!possible_inputs.empty()) {
     const string input_node_name = *possible_inputs.begin();
     possible_inputs.erase(possible_inputs.begin());
+    already_processed.insert(input_node_name);
     auto it1 = name_map.find(input_node_name);
     if (it1 == name_map.end()) {
       return nullptr;
@@ -579,7 +582,7 @@ static const NodeDef* FindSwapTrigger(
     // Don't jump over frames, since adding a control dependency from one frame
     // to the next isn't supported. Don't go through branches, since we don't
     // know whether they'll be executed or not.
-    if (IsNextIteration(*input_node) || IsSwitch(*input_node) ||
+    if (ModifiesFrameInfo(*input_node) || IsSwitch(*input_node) ||
         IsMerge(*input_node)) {
       continue;
     }
@@ -591,7 +594,10 @@ static const NodeDef* FindSwapTrigger(
       candidates[it2->second] = input_node;
     } else {
       for (const string& fanin : input_node->input()) {
-        possible_inputs.insert(NodeName(fanin));
+        string name = NodeName(fanin);
+        if (already_processed.find(name) == already_processed.end()) {
+          possible_inputs.insert(name);
+        }
       }
     }
   }
@@ -611,7 +617,9 @@ static void IdentifySwappingCandidates(Cluster* cluster,
   GraphMemory memory(item);
   const std::unordered_map<string, DeviceProperties>& devices =
       cluster->GetDevices();
-  if (!memory.InferStatically(devices).ok()) {
+  Status s = memory.InferStatically(devices);
+  if (!s.ok()) {
+    VLOG(1) << "Failed to infer memory usage: " << s.error_message();
     return;
   }
 
@@ -622,24 +630,36 @@ static void IdentifySwappingCandidates(Cluster* cluster,
       continue;
     }
     if (prop.memory_size() <= 0) {
+      VLOG(1) << "Peak memory usage unknown for device " << name;
       continue;
     }
     const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
+
     if (mem_usage.used_memory <= prop.memory_size()) {
       continue;
     }
     int64 required_savings = mem_usage.used_memory - prop.memory_size();
     // TODO(bsteiner): sort the tensors by how long they're live.
-    std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
-    if (!EstimateEarliestExecutionTimes(item, cluster, &execution_times).ok()) {
-      return;
+
+    std::unordered_map<string, Costs::NanoSeconds> execution_times;
+    {
+      std::unordered_map<const NodeDef*, Costs::NanoSeconds>
+          tmp_execution_times;
+      if (!EstimateEarliestExecutionTimes(item, cluster, &tmp_execution_times)
+               .ok()) {
+        return;
+      }
+      for (const auto& exec_time : tmp_execution_times) {
+        execution_times.emplace(exec_time.first->name(), exec_time.second);
+      }
     }
 
     GraphView graph(optimized_graph);
     for (const auto& live_tensor : mem_usage.live_tensors) {
       if (live_tensor.deallocation_time - live_tensor.allocation_time <=
           Costs::Duration(1e6)) {
         // Not enough time to swap.
+        VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
         continue;
       }
       if (live_tensor.memory_used <= 1024) {
@@ -651,7 +671,7 @@ static void IdentifySwappingCandidates(Cluster* cluster,
       GraphView::OutputPort port =
           graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
       for (GraphView::InputPort input : graph.GetFanout(port)) {
-        auto it = execution_times.find(input.node);
+        auto it = execution_times.find(input.node->name());
         if (it != execution_times.end()) {
           if (it->second > execution_time) {
             fanout_to_swap = input;
@@ -661,15 +681,23 @@ static void IdentifySwappingCandidates(Cluster* cluster,
       }
       // Annotate the fanout to request the tensor to be swapped if it's not
       // already been done.
-      AttrValue& val = (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
       bool found = false;
-      for (int port_id : val.list().i()) {
-        if (port_id == fanout_to_swap.port_id) {
-          found = true;
-          break;
+      if (!fanout_to_swap.node) {
+        continue;
+      }
+      auto it = fanout_to_swap.node->attr().find("_swap_to_host");
+      if (it != fanout_to_swap.node->attr().end()) {
+        const AttrValue& val = it->second;
+        for (int port_id : val.list().i()) {
+          if (port_id == fanout_to_swap.port_id) {
+            found = true;
+            break;
+          }
         }
       }
       if (!found) {
+        AttrValue& val =
+            (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
         val.mutable_list()->add_i(fanout_to_swap.port_id);
         required_savings -= live_tensor.memory_used;
         if (required_savings < 0) {
@@ -688,7 +716,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                              recomputation_targets_name_prefix_,
                              optimized_graph, item);
 
-  if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS) {
+  if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS &&
+      cluster != nullptr) {
     IdentifySwappingCandidates(cluster, item, optimized_graph);
   }
 
@@ -201,8 +201,16 @@ class MemoryOptimizerTest : public ::testing::Test {
     cpu_device.set_frequency(1000);
     cpu_device.set_num_cores(4);
     cpu_device.set_bandwidth(32);
+    DeviceProperties gpu_device;
+    gpu_device.set_type("GPU");
+    gpu_device.set_frequency(1000);
+    gpu_device.set_num_cores(24);
+    gpu_device.set_bandwidth(128);
+    gpu_device.set_memory_size(1024 * 1024);
+    gpu_device.mutable_environment()->insert({"architecture", "6"});
     std::unordered_map<string, DeviceProperties> devices;
     devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device;
+    devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device;
     return std::unique_ptr<VirtualCluster>(new VirtualCluster(devices));
   }
 };
@@ -252,6 +260,42 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) {
   EXPECT_EQ("^c", swap_in.input(1));
 }
 
+TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
+                           {128, 128, 8}, DT_FLOAT);
+  Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
+  Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
+  Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), {a});
+  Output axis = ops::Const(s.WithOpName("axis"), 0);
+  Output e =
+      ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"e"};
+
+  std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
+
+  MemoryOptimizer optimizer(RewriterConfig::SWAPPING_HEURISTICS);
+  GraphDef output;
+  Status status = optimizer.Optimize(cluster.get(), item, &output);
+  TF_EXPECT_OK(status);
+
+  for (const auto& node : output.node()) {
+    if (node.name() == "e") {
+      EXPECT_TRUE(node.attr().count("_swap_to_host") > 0);
+      const AttrValue& val = node.attr().at("_swap_to_host");
+      EXPECT_TRUE(val.has_list());
+      std::set<int> inputs_to_swap;
+      for (int64 input_id : val.list().i()) {
+        inputs_to_swap.insert(input_id);
+      }
+      EXPECT_EQ(std::set<int>({0, 1, 2}), inputs_to_swap);
+    }
+  }
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow