Add a heuristic to Grappler's memory optimizer to recompute elementwise ops

The current heuristic saves memory in simple conv->BN->relu->conv setups. It wastes computation and does not save memory for ResNet-like architectures (everything gets grouped together and recomputed just before gradients are executed). It's also using a very simple list of ops to recompute. At the moment there is no advantage to this over just wrapping each layer in a Defun. However, there is a bit of infrastructure which will be re-used once smarter heuristics come around (namely finding trigger control dependencies and doing the re-writing). And in the short term, even a few dumb heuristics should make things better for many networks (I just don't want to make this CL any more complicated than it already is).

PiperOrigin-RevId: 159026716
Parent: a7c36173ca
Commit: f0a8bd95c7
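For context, the recomputation pass added below is enabled through the RewriterConfig changes in this diff. A minimal sketch of how a user might turn the heuristic on through the session config, mirroring the Python test added at the end of this change (the field names and enum values come from rewriter_config.proto in this diff):

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2

# Ask the meta-optimizer to run the memory optimizer with the new
# recomputation heuristic (RewriterConfig.HEURISTICS).
rewrite_options = rewriter_config_pb2.RewriterConfig(
    memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS)
graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
config = config_pb2.ConfigProto(graph_options=graph_options)
# Pass `config` to tf.Session(config=config); cheap forward ops feeding nodes
# under the "gradients/" name scope become candidates for recomputation.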
@@ -200,6 +200,7 @@ cc_library(
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:utils",
        "//tensorflow/core/grappler/costs:graph_properties",
        "//tensorflow/core/grappler/utils:topological_sort",
    ],
)

@@ -15,6 +15,9 @@ limitations under the License.

#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"

#include <algorithm>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>

@@ -25,11 +28,325 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
#include "tensorflow/core/grappler/optimizers/static_schedule.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace grappler {

// Prefix added to nodes which are recomputed.
const char* kRecomputedNodePrefix = "Recomputed";
const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger";
// Attribute which may be added to nodes to manually allow them to be
// recomputed.
const char* kRecomputeHint = "_recompute_hint";
const char* kRecomputationTargetNamePrefix = "gradients/";

// Ops which we wouldn't mind recomputing to save memory.
// TODO(allenl): Replace this list with a cost model.
std::unordered_set<string> GetCheapToRecomputeOps() {
  std::unordered_set<string> cheap_ops = {
      "Add",      "AddN",           "BiasAdd",
      "Cast",     "Fill",           "FloorDiv",
      "FloorMod", "FusedBatchNorm", "Mul",
      "Neg",      "RealDiv",        "Reciprocal",
      "Relu",     "Reshape",        "Rsqrt",
      "Sqrt",     "Square",         "SquaredDifference",
      "Sub",      "Tile",           "Transpose"};
  return cheap_ops;
}

// Nodes whose inputs we may want to recompute (i.e. gradients).
// TODO(allenl): Rather than blindly recomputing gradient inputs, use a static
// schedule (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
// whose outputs will sit around for a while.
bool IsTargetOp(const NodeDef& node) {
  return node.name().find(kRecomputationTargetNamePrefix) == 0;
}

// Find recomputable ops which feed into target nodes.
std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
    const NodeMap& node_map, const GraphDef* graph,
    const std::function<bool(const NodeDef&)>& is_candidate) {
  std::unordered_set<const NodeDef*> candidate_recompute_nodes;
  for (const auto& node : graph->node()) {
    if (!is_candidate(node)) {
      continue;
    }
    bool has_target_output = false;
    for (const NodeDef* output : node_map.GetOutputs(node.name())) {
      // It only makes sense to recompute this if it feeds into a target
      // node. We expand this to dependencies in GetOpGroupsToRecompute.
      if (IsTargetOp(*output)) {
        has_target_output = true;
        break;
      }
    }
    if (!has_target_output) {
      continue;
    }
    bool has_target_input = false;
    for (const string& input_name : node.input()) {
      // Don't recompute nodes which depend on target nodes.
      const NodeDef* input_node = node_map.GetNode(input_name);
      if (IsTargetOp(*input_node)) {
        has_target_input = true;
        break;
      }
    }
    if (has_target_input) {
      continue;
    }
    candidate_recompute_nodes.insert(&node);
  }
  return candidate_recompute_nodes;
}

void connected_subgraph(const NodeMap& node_map, bool collect_inputs,
                        bool collect_outputs,
                        const std::function<bool(const NodeDef&)>& is_candidate,
                        std::unordered_set<const NodeDef*>* expanded_nodes) {
  std::queue<const NodeDef*> to_visit;
  for (const NodeDef* starting_node : *expanded_nodes) {
    to_visit.push(starting_node);
  }
  expanded_nodes->clear();
  while (!to_visit.empty()) {
    const NodeDef* current_node = to_visit.front();
    to_visit.pop();
    if (!expanded_nodes->insert(current_node).second) {
      // We already visited this node
      continue;
    }
    if (collect_inputs) {
      // Add inputs and outputs to this subgraph if they are candidates
      for (const string& input_name_raw : current_node->input()) {
        const NodeDef* input_node = node_map.GetNode(input_name_raw);
        if (expanded_nodes->count(input_node) == 0 &&
            is_candidate(*input_node)) {
          to_visit.push(input_node);
        }
      }
    }
    if (collect_outputs) {
      for (const NodeDef* output : node_map.GetOutputs(current_node->name())) {
        if (expanded_nodes->count(output) == 0 && is_candidate(*output)) {
          to_visit.push(output);
        }
      }
    }
  }
}

struct RecomputedSubGraph {
  std::unordered_set<const NodeDef*> recomputed_source_nodes;
  std::unordered_set<NodeDef*> target_nodes;
};

// Find groups of ops to recompute together based on `should_recompute`.
std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
    const GraphDef* graph, const NodeMap& node_map,
    const std::function<bool(const NodeDef&)>& should_recompute) {
  std::unordered_set<const NodeDef*> visited_nodes;
  std::vector<RecomputedSubGraph> subgraphs_to_recompute;
  std::unordered_set<const NodeDef*> candidate_recompute_nodes =
      FindCandidateRecomputeNodes(node_map, graph, should_recompute);
  for (const NodeDef* recompute_node : candidate_recompute_nodes) {
    if (visited_nodes.count(recompute_node) > 0) {
      continue;
    }
    RecomputedSubGraph current_recomputation;
    // Build out recomputation groups by expanding to inexpensive-to-recompute
    // nodes which do not feed target nodes. The goal is to capture some
    // intermediate activations within this graph.
    std::unordered_set<const NodeDef*> unpruned_recompute_nodes;
    unpruned_recompute_nodes.insert(recompute_node);
    connected_subgraph(node_map,
                       true,  // Collect inputs
                       true,  // Collect outputs
                       should_recompute, &unpruned_recompute_nodes);
    visited_nodes.insert(unpruned_recompute_nodes.begin(),
                         unpruned_recompute_nodes.end());
    for (const NodeDef* recompute_node : unpruned_recompute_nodes) {
      bool inserted_feed = false;
      for (NodeDef* output : node_map.GetOutputs(recompute_node->name())) {
        if (IsTargetOp(*output)) {
          current_recomputation.target_nodes.insert(output);
          if (!inserted_feed) {
            // Keep track of nodes which feed directly into a target node. These
            // and nodes which feed into them will define the recomputed
            // subgraph.
            current_recomputation.recomputed_source_nodes.insert(
                recompute_node);
            inserted_feed = true;
          }
        }
      }
    }
    // Recompute only nodes which eventually feed into a target node.
    connected_subgraph(node_map,
                       true,   // Collect inputs
                       false,  // Collect outputs
                       [&unpruned_recompute_nodes](const NodeDef& node) {
                         return unpruned_recompute_nodes.count(&node) != 0;
                       },
                       &current_recomputation.recomputed_source_nodes);
    if (current_recomputation.target_nodes.empty()) {
      continue;
    }
    subgraphs_to_recompute.push_back(current_recomputation);
  }
  return subgraphs_to_recompute;
}

// Computes the maximum topological numbers of (1) target node components
// (gradient nodes being fed by the recomputation), and (2) child recompute node
// components for each recomputed node. We will not attach any control
// dependencies to a recomputation unless they have component numbers greater
// than this value (to prevent cycles).
std::unordered_map<const NodeDef*, int> GetMaxDownstreamComponents(
    const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
    const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
    const std::unordered_map<const NodeDef*, int>& components) {
  std::unordered_map<const NodeDef*, int> recomputed_node_components;
  // Start by setting component numbers to the maximum among target nodes.
  for (const NodeDef* original_recompute_node : recomputed_source_nodes) {
    int max_target_component = -1;
    for (NodeDef* output :
         node_map.GetOutputs(original_recompute_node->name())) {
      if (target_nodes.count(output) != 0) {
        int current_target_component = components.find(output)->second;
        if (current_target_component > max_target_component) {
          max_target_component = current_target_component;
        }
      }
    }
    if (max_target_component > -1) {
      recomputed_node_components[original_recompute_node] =
          max_target_component;
    }
  }
  // Sort recomputed nodes topologically (based on the original graph) so we can
  // efficiently assign to each node the maximum of its recomputed child
  // components and its own targets.
  std::vector<const NodeDef*> recomputed_source_nodes_topological(
      recomputed_source_nodes.begin(), recomputed_source_nodes.end());
  std::sort(recomputed_source_nodes_topological.begin(),
            recomputed_source_nodes_topological.end(),
            [&components](const NodeDef* first, const NodeDef* second) {
              return components.find(first)->second <
                     components.find(second)->second;
            });
  for (const NodeDef* original_recompute_node :
       recomputed_source_nodes_topological) {
    int max_component;
    auto recomputed_component_iterator =
        recomputed_node_components.find(original_recompute_node);
    if (recomputed_component_iterator != recomputed_node_components.end()) {
      max_component = recomputed_component_iterator->second;
    } else {
      max_component = -1;
    }
    for (NodeDef* output :
         node_map.GetOutputs(original_recompute_node->name())) {
      if (recomputed_source_nodes.count(output) == 0) {
        continue;
      }
      auto child_component_iterator = recomputed_node_components.find(output);
      CHECK(child_component_iterator != recomputed_node_components.end());
      int child_component = child_component_iterator->second;
      if (child_component > max_component) {
        max_component = child_component;
      }
    }
    CHECK_GE(max_component, 0);
    recomputed_node_components[original_recompute_node] = max_component;
  }
  return recomputed_node_components;
}

// Modifies `graph`, adding trigger nodes and returning a mapping from
// `recomputed_source_nodes` to trigger nodes which will not create loops in the
// graph (using the component numberings in `components` and
// `recomputed_node_max_feed_components`). The copied nodes (not the nodes in
// recomputed_source_nodes, which are the originals) eventually get these
// control dependencies.
std::unordered_map<const NodeDef*, const NodeDef*>
AddRecomputeControlDependencyNodes(
    const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
    const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
    const std::unordered_map<const NodeDef*, int>& components,
    const std::unordered_map<const NodeDef*, int>&
        recomputed_node_max_feed_components,
    GraphDef* graph) {
  // Sort recomputed nodes based on max downstream components.
  std::vector<const NodeDef*> recomputed_source_nodes_topological(
      recomputed_source_nodes.begin(), recomputed_source_nodes.end());
  std::sort(recomputed_source_nodes_topological.begin(),
            recomputed_source_nodes_topological.end(),
            [&recomputed_node_max_feed_components](const NodeDef* first,
                                                   const NodeDef* second) {
              int first_component =
                  recomputed_node_max_feed_components.find(first)->second;
              int second_component =
                  recomputed_node_max_feed_components.find(second)->second;
              return first_component > second_component
                     // Ensure a consistent ordering. This is necessary because
                     // we're working not with node component numbers (which are
                     // unique) but with the maximum across nodes they feed into
                     // (very much not unique).
                     || (first_component == second_component &&
                         first->name() > second->name());
            });
  // Create merged control dependency nodes by sorting target inputs
  // topologically and zipper merging with the sorted recomputed nodes.
  std::vector<const NodeDef*> target_inputs_topological;
  for (const NodeDef* target_node : target_nodes) {
    for (const string& target_input_name_raw : target_node->input()) {
      const NodeDef* target_input = node_map.GetNode(target_input_name_raw);
      if (recomputed_source_nodes.count(target_input) != 0 ||
          components.find(target_node)->second ==
              components.find(target_input)->second) {
        continue;
      }
      target_inputs_topological.push_back(target_input);
    }
  }
  std::sort(target_inputs_topological.begin(), target_inputs_topological.end(),
            [&components](const NodeDef* first, const NodeDef* second) {
              return components.find(first)->second >
                     components.find(second)->second;
            });
  auto target_input_iterator = target_inputs_topological.begin();
  NodeDef* current_trigger_node = nullptr;
  std::unordered_map<const NodeDef*, const NodeDef*> triggers;
  for (const NodeDef* original_recomputed_node :
       recomputed_source_nodes_topological) {
    NodeDef* new_trigger_node = graph->add_node();
    new_trigger_node->set_name(AddPrefixToNodeName(
        original_recomputed_node->name(), kRecomputeTriggerNodePrefix));
    new_trigger_node->set_op("NoOp");
    new_trigger_node->set_device(original_recomputed_node->device());
    if (current_trigger_node != nullptr) {
      *new_trigger_node->add_input() =
          strings::StrCat("^", current_trigger_node->name());
    }
    current_trigger_node = new_trigger_node;
    triggers[original_recomputed_node] = current_trigger_node;
    for (;
         target_input_iterator != target_inputs_topological.end() &&
         components.find(*target_input_iterator)->second >
             recomputed_node_max_feed_components.find(original_recomputed_node)
                 ->second;
         ++target_input_iterator) {
      *current_trigger_node->add_input() =
          strings::StrCat("^", (*target_input_iterator)->name());
      VLOG(2) << " Recomputation trigger " << current_trigger_node->name()
              << " depends on " << (*target_input_iterator)->name();
    }
  }
  return triggers;
}

string RecomputedOrOriginalNodeName(
    const std::unordered_set<string>& recomputed_node_names,
@@ -42,14 +359,28 @@ string RecomputedOrOriginalNodeName(
  }
}

// Helper function to recompute a sub-graph (recomputed_source_nodes). Edges
// from recomputed_source_nodes to target_nodes are changed to start from the
// recomputed nodes.
void RecomputeSubgraph(
    const std::vector<const NodeDef*>& recomputed_source_nodes,
    const string& recompute_trigger_node_name,
    const std::vector<NodeDef*>& target_nodes, GraphDef* graph) {
    const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
    const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
    const std::unordered_map<const NodeDef*, int>& components,
    GraphDef* graph) {
  std::unordered_set<string> recomputed_node_names;
  for (const NodeDef* to_recompute : recomputed_source_nodes) {
    recomputed_node_names.insert(to_recompute->name());
  VLOG(1) << "Recomputing a " << recomputed_source_nodes.size()
          << " node subgraph";
  std::unordered_map<const NodeDef*, int> recomputed_node_components =
      GetMaxDownstreamComponents(recomputed_source_nodes, target_nodes,
                                 node_map, components);
  for (const NodeDef* original_node : recomputed_source_nodes) {
    VLOG(2) << " " << original_node->name();
    recomputed_node_names.insert(original_node->name());
  }
  std::unordered_map<const NodeDef*, const NodeDef*> triggers =
      AddRecomputeControlDependencyNodes(recomputed_source_nodes, target_nodes,
                                         node_map, components,
                                         recomputed_node_components, graph);
  // Create the recomputed sub-graph
  for (const NodeDef* original_node : recomputed_source_nodes) {
    NodeDef* copied_node = graph->add_node();
@@ -64,10 +395,10 @@ void RecomputeSubgraph(
      *copied_node->add_input() = RecomputedOrOriginalNodeName(
          recomputed_node_names, original_input_name);
    }
    // Set control dependencies on the recomputed nodes so that they are not run
    // until the specified trigger runs.
    // Each recomputed node gets a control dependency to prevent it from being
    // recomputed immediately.
    *copied_node->add_input() =
        strings::StrCat("^", recompute_trigger_node_name);
        strings::StrCat("^", triggers[original_node]->name());
  }
  // Set the inputs of nodes in the target subgraph to the recomputed nodes
  // where applicable.
@@ -79,6 +410,52 @@ void RecomputeSubgraph(
  }
}

void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
                                GraphDef* graph) {
  // The topological numberings and NodeMap will be stale as soon as we start
  // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
  // looks up nodes which were in the original graph, and preserves the graph
  // topology it's interested in.
  // We don't use the results of this topological sort until later, but this
  // call invalidates all NodeDef pointers, so it needs to be done before we
  // start collecting those.
  TopologicalSort(graph);
  NodeMap node_map(graph);
  std::vector<RecomputedSubGraph> recomputed_subgraphs;
  if (optimization_level == RewriterConfig::HEURISTICS) {
    // TODO(allenl): Handle ResNet-like architectures better. Right now all of
    // the cheap forward ops get grouped into a single subgraph which must
    // execute before gradients start executing (unless layers are manually
    // separated by identity ops).
    std::unordered_set<string> cheap_to_recompute_ops =
        GetCheapToRecomputeOps();
    recomputed_subgraphs = GetOpGroupsToRecompute(
        graph, node_map, [&cheap_to_recompute_ops](const NodeDef& node) {
          return !IsTargetOp(node) &&
                 (cheap_to_recompute_ops.count(node.op()) > 0 ||
                  node.attr().count(kRecomputeHint) > 0);
        });
  } else {  // optimization_level == RewriterConfig::MANUAL
    recomputed_subgraphs =
        GetOpGroupsToRecompute(graph, node_map, [](const NodeDef& node) {
          return !IsTargetOp(node) && node.attr().count(kRecomputeHint) > 0;
        });
  }
  if (!recomputed_subgraphs.empty()) {
    std::unordered_map<const NodeDef*, int> topological_numbering;
    for (int node_number = 0; node_number < graph->node().size();
         ++node_number) {
      topological_numbering[graph->mutable_node(node_number)] =
          graph->node().size() - node_number - 1;
    }
    // Duplicate the indicated sub-graphs and set up control dependencies
    for (const RecomputedSubGraph& subgraph : recomputed_subgraphs) {
      RecomputeSubgraph(subgraph.recomputed_source_nodes, subgraph.target_nodes,
                        node_map, topological_numbering, graph);
    }
  }
}

std::pair<NodeDef*, NodeDef*> BuildSwapPair(NodeDef* node, int input_to_swap,
                                            GraphDef* graph) {
  string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);

@@ -205,6 +582,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                 GraphDef* optimized_graph) {
  *optimized_graph = item.graph;

  RecomputationRewritingPass(optimization_level_, optimized_graph);

  // Figure out what needs to be swapped;
  std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
  for (auto& node : *optimized_graph->mutable_node()) {

@@ -16,9 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_

#include <vector>

#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace grappler {

@@ -26,7 +25,8 @@ namespace grappler {
// Swap tensors in and out of device memory.
class MemoryOptimizer : public GraphOptimizer {
 public:
  MemoryOptimizer() {}
  explicit MemoryOptimizer(RewriterConfig::MemOptType optimization_level)
      : optimization_level_(optimization_level) {}
  ~MemoryOptimizer() override {}

  string name() const override { return "memory_optimizer"; };

@@ -36,15 +36,10 @@ class MemoryOptimizer : public GraphOptimizer {

  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& pruned_graph, double result) override;
};

// Helper function to recompute a sub-graph (recomputed_source_nodes) on a
// trigger. Edges from recomputed_source_nodes to target_nodes are changed to
// start from the recomputed nodes.
void RecomputeSubgraph(
    const std::vector<const NodeDef*>& recomputed_source_nodes,
    const string& recompute_trigger_node_name,
    const std::vector<NodeDef*>& target_nodes, GraphDef* graph);

 private:
  RewriterConfig::MemOptType optimization_level_;
};

}  // end namespace grappler
}  // end namespace tensorflow

@@ -35,84 +35,103 @@ TEST_F(RecomputeSubgraphTest, SimpleSubgraph) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 1.f, {2, 3, 4});
  Output b = ops::AddN(s.WithOpName("b"), {a});  // Recomputed
  Output c = ops::AddN(s.WithOpName("c"), {b});
  Output d = ops::AddN(s.WithOpName("d"), {c});
  Output e = ops::AddN(s.WithOpName("e"), {d, b});
  Output f = ops::AddN(s.WithOpName("f"), {e, a});
  Output b = ops::Identity(s.WithOpName("b"), a);  // Recomputed
  Output c = ops::Identity(s.WithOpName("c"), b);
  Output d = ops::AddN(s.WithOpName("gradients/d"), {c});
  Output e = ops::AddN(s.WithOpName("gradients/e"), {d, b});
  Output f = ops::AddN(s.WithOpName("gradients/f"), {e, a});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  EXPECT_EQ(6, item.graph.node_size());
  NodeMap pre_transform_node_map(&item.graph);
  std::vector<const NodeDef*> recomputed_source_nodes;
  recomputed_source_nodes.push_back(pre_transform_node_map.GetNode(b.name()));
  std::vector<NodeDef*> target_nodes;
  target_nodes.push_back(pre_transform_node_map.GetNode(e.name()));
  RecomputeSubgraph(recomputed_source_nodes, d.name(), target_nodes,
                    &item.graph);
  NodeMap post_transform_node_map(&item.graph);
  EXPECT_EQ(7, item.graph.node_size());
  (*pre_transform_node_map.GetNode("b")->mutable_attr())["_recompute_hint"]
      .set_i(0);

  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);

  TF_EXPECT_OK(status);
  NodeMap post_transform_node_map(&output);
  EXPECT_EQ(8, output.node_size());
  NodeDef* transformed_e = post_transform_node_map.GetNode(e.name());
  EXPECT_EQ(2, transformed_e->input_size());
  EXPECT_EQ("d", transformed_e->input(0));
  EXPECT_EQ("gradients/d", transformed_e->input(0));
  EXPECT_EQ("Recomputed/b", transformed_e->input(1));
  NodeDef* recomputed_b = post_transform_node_map.GetNode("Recomputed/b");
  EXPECT_EQ(2, recomputed_b->input_size());
  EXPECT_EQ("a", recomputed_b->input(0));
  EXPECT_EQ("^d", recomputed_b->input(1).substr(0, 2));
  EXPECT_EQ("^RecomputeTrigger/b", recomputed_b->input(1));
  NodeDef* recompute_trigger =
      post_transform_node_map.GetNode("RecomputeTrigger/b");
  EXPECT_EQ(1, recompute_trigger->input_size());
  EXPECT_EQ("^gradients/d", recompute_trigger->input(0));
}

TEST_F(RecomputeSubgraphTest, MultiNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("Conv"), 1.f, {2, 3, 4});
  Output b = ops::AddN(s.WithOpName("BN"), {a});  // Recomputed
  Output c = ops::AddN(s.WithOpName("ReLU"), {b});  // Recomputed
  Output d = ops::AddN(s.WithOpName("Conv1"), {c});
  Output b = ops::Identity(s.WithOpName("BN"), a);  // Recomputed
  Output c = ops::Identity(s.WithOpName("ReLU"), b);  // Recomputed
  Output d = ops::Identity(s.WithOpName("Conv1"), c);

  Output trigger = ops::Const(s.WithOpName("BN1Grad"), 0.f, {2, 3, 4});
  Output e = ops::AddN(s.WithOpName("Conv1Grad"), {trigger, c});
  Output f = ops::AddN(s.WithOpName("ReLUGrad"), {e, c});
  Output g = ops::AddN(s.WithOpName("BNGrad"), {f, a});
  Output h = ops::AddN(s.WithOpName("ConvGrad"), {g});
  // The "gradients/" prefix means the heuristic will pick these up as
  // candidates to have their inputs recomputed.
  Output trigger = ops::AddN(s.WithOpName("gradients/BN1Grad"), {d});
  Output e = ops::AddN(s.WithOpName("gradients/Conv1Grad"), {trigger, c});
  Output f = ops::AddN(s.WithOpName("gradients/ReLUGrad"), {e, c});
  Output g = ops::AddN(s.WithOpName("gradients/BNGrad"), {f, a});
  Output h = ops::AddN(s.WithOpName("gradients/ConvGrad"), {g});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  EXPECT_EQ(9, item.graph.node_size());
  NodeMap pre_transform_node_map(&item.graph);
  std::vector<const NodeDef*> recomputed_source_nodes;
  recomputed_source_nodes.push_back(pre_transform_node_map.GetNode(b.name()));
  recomputed_source_nodes.push_back(pre_transform_node_map.GetNode(c.name()));
  std::vector<NodeDef*> target_nodes;
  target_nodes.push_back(pre_transform_node_map.GetNode(e.name()));
  target_nodes.push_back(pre_transform_node_map.GetNode(f.name()));
  target_nodes.push_back(pre_transform_node_map.GetNode(g.name()));
  RecomputeSubgraph(recomputed_source_nodes, trigger.name(), target_nodes,
                    &item.graph);
  NodeMap post_transform_node_map(&item.graph);
  EXPECT_EQ(11, item.graph.node_size());
  // Set op types so that the heuristic will pick these nodes up to be
  // recomputed
  pre_transform_node_map.GetNode("BN")->set_op("FusedBatchNorm");
  pre_transform_node_map.GetNode("ReLU")->set_op("Relu");

  MemoryOptimizer optimizer(RewriterConfig::HEURISTICS);
  GraphDef first_pass_output;
  Status first_pass_status =
      optimizer.Optimize(nullptr, item, &first_pass_output);
  TF_EXPECT_OK(first_pass_status);

  NodeMap post_transform_node_map(&first_pass_output);
  EXPECT_EQ(13, first_pass_output.node_size());
  NodeDef* transformed_e = post_transform_node_map.GetNode(e.name());
  EXPECT_EQ(2, transformed_e->input_size());
  EXPECT_EQ("BN1Grad", transformed_e->input(0));
  EXPECT_EQ("gradients/BN1Grad", transformed_e->input(0));
  EXPECT_EQ("Recomputed/ReLU", transformed_e->input(1));
  NodeDef* transformed_f = post_transform_node_map.GetNode(f.name());
  EXPECT_EQ(2, transformed_f->input_size());
  EXPECT_EQ("Conv1Grad", transformed_f->input(0));
  EXPECT_EQ("gradients/Conv1Grad", transformed_f->input(0));
  EXPECT_EQ("Recomputed/ReLU", transformed_f->input(1));
  NodeDef* transformed_g = post_transform_node_map.GetNode(g.name());
  EXPECT_EQ(2, transformed_g->input_size());
  EXPECT_EQ("ReLUGrad", transformed_g->input(0));
  EXPECT_EQ("gradients/ReLUGrad", transformed_g->input(0));
  EXPECT_EQ("Conv", transformed_g->input(1));

  NodeDef* recomputed_b = post_transform_node_map.GetNode("Recomputed/BN");
  EXPECT_EQ(2, recomputed_b->input_size());
  EXPECT_EQ("Conv", recomputed_b->input(0));
  EXPECT_EQ("^BN1Grad", recomputed_b->input(1).substr(0, 8));
  EXPECT_EQ("^RecomputeTrigger/BN", recomputed_b->input(1));
  NodeDef* recompute_trigger_b =
      post_transform_node_map.GetNode("RecomputeTrigger/BN");
  EXPECT_EQ(1, recompute_trigger_b->input_size());
  EXPECT_EQ("^RecomputeTrigger/ReLU", recompute_trigger_b->input(0));

  NodeDef* recomputed_c = post_transform_node_map.GetNode("Recomputed/ReLU");
  EXPECT_EQ(2, recomputed_c->input_size());
  EXPECT_EQ("Recomputed/BN", recomputed_c->input(0));
  EXPECT_EQ("^BN1Grad", recomputed_c->input(1).substr(0, 8));
  EXPECT_EQ("^RecomputeTrigger/ReLU", recomputed_c->input(1));
  NodeDef* recompute_trigger_c =
      post_transform_node_map.GetNode("RecomputeTrigger/ReLU");
  EXPECT_EQ(1, recompute_trigger_c->input_size());
  EXPECT_EQ("^gradients/BN1Grad", recompute_trigger_c->input(0));
}

class MemoryOptimizerTest : public ::testing::Test {

@@ -150,7 +169,7 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) {

  VirtualCluster cluster(CreateVirtualCluster());

  MemoryOptimizer optimizer;
  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
  GraphDef output;
  Status status = optimizer.Optimize(&cluster, item, &output);
  TF_EXPECT_OK(status);

@@ -41,7 +41,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
    graph_optimizer.reset(new LayoutOptimizer());
  }
  if (optimizer == "memory") {
    graph_optimizer.reset(new MemoryOptimizer());
    graph_optimizer.reset(new MemoryOptimizer(RewriterConfig::MANUAL));
  }
  if (optimizer == "autoparallel") {
    graph_optimizer.reset(

@@ -66,8 +66,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
        std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
  }
  if (cfg_.memory_optimization() > 0) {
    optimizers.push_back(
        std::unique_ptr<GraphOptimizer>(new MemoryOptimizer()));
    optimizers.push_back(std::unique_ptr<GraphOptimizer>(
        new MemoryOptimizer(cfg_.memory_optimization())));
  }
  if (cfg_.auto_parallel().enable()) {
    optimizers.push_back(std::unique_ptr<GraphOptimizer>(

@@ -114,7 +114,8 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,

bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
  return cfg.optimize_tensor_layout() || cfg.constant_folding() ||
         cfg.auto_parallel().enable() || !cfg.optimizers().empty();
         cfg.auto_parallel().enable() || cfg.memory_optimization() > 0 ||
         !cfg.optimizers().empty();
}

Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,

@@ -15,21 +15,40 @@ message RewriterConfig {
  // Graph rewriting is experimental and subject to change, not covered by any
  // API stability guarantees.

  // Configuration options for the meta-optimizer. Unless otherwise noted, these
  // configuration options do not apply to explicitly triggered optimization
  // passes in the optimizers field.

  bool optimize_tensor_layout = 1;
  bool disable_model_pruning = 2;
  bool constant_folding = 3;

  enum MemOptType {
    // Fully disabled
    // Disabled in the meta-optimizer.
    NO_MEM_OPT = 0;
    // Driven by manual annotations
    // Driven by manual op-level annotations.
    MANUAL = 1;
    // Driven by heuristics. The behavior of these heuristics is subject to
    // change. Currently includes an experimental recomputation heuristic.
    HEURISTICS = 2;
  }
  // Configures memory optimization passes through the meta-optimizer. Has no
  // effect on manually requested memory optimization passes in the optimizers
  // field.
  MemOptType memory_optimization = 4;

  // Configures AutoParallel optimization passes either through the
  // meta-optimizer or when manually specified through the optimizers field.
  AutoParallelOptions auto_parallel = 5;

  // If non-empty, will use this as an alternative way to specify a list of
  // optimizations to turn on and the order of the optimizations.
  // optimizations to turn on and the order of the optimizations (replacing the
  // meta-optimizer).
  //
  // Of the RewriterConfig options, only the AutoParallel configuration options
  // (the auto_parallel field) apply to manually requested optimization passes
  // ("autoparallel"). Memory optimization passes ("memory") invoked here are
  // not configurable (in contrast to memory optimization passes through the
  // meta-optimizer) and act only on manual op annotations.
  repeated string optimizers = 100;
}
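The MANUAL mode above relies on the `_recompute_hint` node attribute that memory_optimizer.cc checks (kRecomputeHint), together with the "gradients/" prefix on the consuming nodes' names. A rough sketch of how a serialized graph could be annotated from Python under those assumptions; the `mark_for_recompute` helper is hypothetical and only illustrates setting the attribute on a GraphDef node:

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2

def mark_for_recompute(graph_def, node_name):
  # Attach the attribute the memory optimizer's MANUAL mode looks for. The
  # hinted node must feed a node whose name starts with "gradients/" to be
  # picked up as a recomputation source.
  for node in graph_def.node:
    if node.name == node_name:
      node.attr["_recompute_hint"].i = 0
      return
  raise ValueError("No node named %s" % node_name)

manual_config = config_pb2.ConfigProto(
    graph_options=config_pb2.GraphOptions(
        rewrite_options=rewriter_config_pb2.RewriterConfig(
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)))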

@@ -3800,7 +3800,13 @@ py_test(
        ":client_testlib",
        ":framework_for_generated_wrappers",
        ":math_ops",
        ":nn",
        ":random_seed",
        ":session",
        ":tf_optimizer",
        ":training",
        ":variable_scope",
        ":variables",
        "//tensorflow/core:protos_all_py",
        "//third_party/py/numpy",
    ],

@@ -18,16 +18,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training as train


class MemoryOptimizerTest(test.TestCase):
class MemoryOptimizerSwapTest(test.TestCase):
  """Tests the Grappler memory optimizer."""

  def testNoSwapping(self):

@@ -85,5 +92,51 @@ class MemoryOptimizerTest(test.TestCase):
    self.assertEqual('c', node.input[1])


class MemoryOptimizerRecomputeTest(test.TestCase):

  def _RunGraphWithConfig(self, config, batch_size=14, image_dim=12):
    """Run a simple layered graph with conv, an intermediate op, and a ReLU."""
    graph = ops.Graph()
    with graph.as_default():
      random_seed.set_random_seed(1)
      current_activation = variable_scope.get_variable(
          name='start', shape=[batch_size, image_dim, image_dim, 5])
      conv_filter = variable_scope.get_variable(
          name='filter', shape=[5, 5, 5, 5])
      for layer_number in range(10):
        with variable_scope.variable_scope('layer_{}'.format(layer_number)):
          after_conv = nn.conv2d(current_activation, conv_filter, [1, 1, 1, 1],
                                 'SAME')
          current_activation = 2. * after_conv
        current_activation = nn.relu(current_activation)
      loss = math_ops.reduce_mean(current_activation)
      optimizer = train.AdamOptimizer(0.001)
      train_op = optimizer.minimize(loss)
      init_op = variables.global_variables_initializer()
    with session.Session(config=config, graph=graph) as sess:
      sess.run(init_op)
      sess.run(train_op)
      sess.run(train_op)
      return sess.run(loss)

  def _GetMemoryOptimizerConfig(self):
    rewrite_options = rewriter_config_pb2.RewriterConfig(
        memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
    return config_pb2.ConfigProto(graph_options=graph_options)

  def testRecomputationRewritingNoErrors(self):
    """Tests that there are no errors when we request a memory optimizer pass.

    Does not test that the memory optimizer actually runs. See
    core/grappler/optimizers/memory_optimizer_test.cc for a functional test of
    the graph rewriting.
    """
    original_loss = self._RunGraphWithConfig(config_pb2.ConfigProto())
    memory_optimized_loss = self._RunGraphWithConfig(
        config=self._GetMemoryOptimizerConfig())
    self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-4)


if __name__ == '__main__':
  test.main()