Add a heuristic to Grappler's memory optimizer to recompute elementwise ops

The current heuristic saves memory in simple conv->BN->relu->conv setups, but for ResNet-like architectures it wastes computation without saving memory (everything gets grouped into a single subgraph and recomputed just before the gradient ops execute).

It also uses a very simple list of ops to recompute. At the moment there is no advantage to this over just wrapping each layer in a Defun. However, there is a bit of infrastructure that will be re-used once smarter heuristics come along (namely finding trigger control dependencies and doing the re-writing).

And in the short term, even a few dumb heuristics should make things better for many networks (I just don't want to make this CL any more complicated than it already is).
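For reference, a minimal sketch of how a client would request the new heuristic pass once this lands, mirroring the Python test added at the end of this CL (the enum and field names come from the rewriter_config.proto change below):

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2

# Ask the meta-optimizer to run the memory optimizer with the new heuristics.
rewrite_options = rewriter_config_pb2.RewriterConfig(
    memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS)
graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
config = config_pb2.ConfigProto(graph_options=graph_options)
# Pass `config` to tf.Session(config=config); the recomputation rewrite then
# happens as a graph transformation before execution.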

PiperOrigin-RevId: 159026716
A. Unique TensorFlower 2017-06-14 14:33:41 -07:00 committed by TensorFlower Gardener
parent a7c36173ca
commit f0a8bd95c7
8 changed files with 540 additions and 67 deletions


@@ -200,6 +200,7 @@ cc_library(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:topological_sort",
],
)


@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include <algorithm>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@@ -25,11 +28,325 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
#include "tensorflow/core/grappler/optimizers/static_schedule.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
// Prefix added to nodes which are recomputed.
const char* kRecomputedNodePrefix = "Recomputed";
const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger";
// Attribute which may be added to nodes to manually allow them to be
// recomputed.
const char* kRecomputeHint = "_recompute_hint";
const char* kRecomputationTargetNamePrefix = "gradients/";
// Ops which we wouldn't mind recomputing to save memory.
// TODO(allenl): Replace this list with a cost model.
std::unordered_set<string> GetCheapToRecomputeOps() {
std::unordered_set<string> cheap_ops = {
"Add", "AddN", "BiasAdd",
"Cast", "Fill", "FloorDiv",
"FloorMod", "FusedBatchNorm", "Mul",
"Neg", "RealDiv", "Reciprocal",
"Relu", "Reshape", "Rsqrt",
"Sqrt", "Square", "SquaredDifference",
"Sub", "Tile", "Transpose"};
return cheap_ops;
}
// Nodes whose inputs we may want to recompute (i.e. gradients).
// TODO(allenl): Rather than blindly recomputing gradient inputs, use a static
// schedule (grappler::EstimateEarliestExecutionTimes) to recompute only nodes
// whose outputs will sit around for a while.
bool IsTargetOp(const NodeDef& node) {
return node.name().find(kRecomputationTargetNamePrefix) == 0;
}
// Find recomputable ops which feed into target nodes.
std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
const NodeMap& node_map, const GraphDef* graph,
const std::function<bool(const NodeDef&)>& is_candidate) {
std::unordered_set<const NodeDef*> candidate_recompute_nodes;
for (const auto& node : graph->node()) {
if (!is_candidate(node)) {
continue;
}
bool has_target_output = false;
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
// It only makes sense to recompute this if it feeds into a target
// node. We expand this to dependencies in GetOpGroupsToRecompute.
if (IsTargetOp(*output)) {
has_target_output = true;
break;
}
}
if (!has_target_output) {
continue;
}
bool has_target_input = false;
for (const string& input_name : node.input()) {
// Don't recompute nodes which depend on target nodes.
const NodeDef* input_node = node_map.GetNode(input_name);
if (IsTargetOp(*input_node)) {
has_target_input = true;
break;
}
}
if (has_target_input) {
continue;
}
candidate_recompute_nodes.insert(&node);
}
return candidate_recompute_nodes;
}
void connected_subgraph(const NodeMap& node_map, bool collect_inputs,
bool collect_outputs,
const std::function<bool(const NodeDef&)>& is_candidate,
std::unordered_set<const NodeDef*>* expanded_nodes) {
std::queue<const NodeDef*> to_visit;
for (const NodeDef* starting_node : *expanded_nodes) {
to_visit.push(starting_node);
}
expanded_nodes->clear();
while (!to_visit.empty()) {
const NodeDef* current_node = to_visit.front();
to_visit.pop();
if (!expanded_nodes->insert(current_node).second) {
// We already visited this node
continue;
}
if (collect_inputs) {
// Add inputs and outputs to this subgraph if they are candidates
for (const string& input_name_raw : current_node->input()) {
const NodeDef* input_node = node_map.GetNode(input_name_raw);
if (expanded_nodes->count(input_node) == 0 &&
is_candidate(*input_node)) {
to_visit.push(input_node);
}
}
}
if (collect_outputs) {
for (const NodeDef* output : node_map.GetOutputs(current_node->name())) {
if (expanded_nodes->count(output) == 0 && is_candidate(*output)) {
to_visit.push(output);
}
}
}
}
}
struct RecomputedSubGraph {
std::unordered_set<const NodeDef*> recomputed_source_nodes;
std::unordered_set<NodeDef*> target_nodes;
};
// Find groups of ops to recompute together based on `should_recompute`.
std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
const GraphDef* graph, const NodeMap& node_map,
const std::function<bool(const NodeDef&)>& should_recompute) {
std::unordered_set<const NodeDef*> visited_nodes;
std::vector<RecomputedSubGraph> subgraphs_to_recompute;
std::unordered_set<const NodeDef*> candidate_recompute_nodes =
FindCandidateRecomputeNodes(node_map, graph, should_recompute);
for (const NodeDef* recompute_node : candidate_recompute_nodes) {
if (visited_nodes.count(recompute_node) > 0) {
continue;
}
RecomputedSubGraph current_recomputation;
// Build out recomputation groups by expanding to inexpensive-to-recompute
// nodes which do not feed target nodes. The goal is to capture some
// intermediate activations within this graph.
std::unordered_set<const NodeDef*> unpruned_recompute_nodes;
unpruned_recompute_nodes.insert(recompute_node);
connected_subgraph(node_map,
true, // Collect inputs
true, // Collect outputs
should_recompute, &unpruned_recompute_nodes);
visited_nodes.insert(unpruned_recompute_nodes.begin(),
unpruned_recompute_nodes.end());
for (const NodeDef* recompute_node : unpruned_recompute_nodes) {
bool inserted_feed = false;
for (NodeDef* output : node_map.GetOutputs(recompute_node->name())) {
if (IsTargetOp(*output)) {
current_recomputation.target_nodes.insert(output);
if (!inserted_feed) {
// Keep track of nodes which feed directly into a target node. These
// and nodes which feed into them will define the recomputed
// subgraph.
current_recomputation.recomputed_source_nodes.insert(
recompute_node);
inserted_feed = true;
}
}
}
}
// Recompute only nodes which eventually feed into a target node.
connected_subgraph(node_map,
true, // Collect inputs
false, // Collect outputs
[&unpruned_recompute_nodes](const NodeDef& node) {
return unpruned_recompute_nodes.count(&node) != 0;
},
&current_recomputation.recomputed_source_nodes);
if (current_recomputation.target_nodes.empty()) {
continue;
}
subgraphs_to_recompute.push_back(current_recomputation);
}
return subgraphs_to_recompute;
}
// Computes the maximum topological numbers of (1) target node components
// (gradient nodes being fed by the recomputation), and (2) child recompute node
// components for each recomputed node. We will not attach any control
// dependencies to a recomputation unless they have component numbers greater
// than this value (to prevent cycles).
std::unordered_map<const NodeDef*, int> GetMaxDownstreamComponents(
const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
const std::unordered_map<const NodeDef*, int>& components) {
std::unordered_map<const NodeDef*, int> recomputed_node_components;
// Start by setting component numbers to the maximum among target nodes.
for (const NodeDef* original_recompute_node : recomputed_source_nodes) {
int max_target_component = -1;
for (NodeDef* output :
node_map.GetOutputs(original_recompute_node->name())) {
if (target_nodes.count(output) != 0) {
int current_target_component = components.find(output)->second;
if (current_target_component > max_target_component) {
max_target_component = current_target_component;
}
}
}
if (max_target_component > -1) {
recomputed_node_components[original_recompute_node] =
max_target_component;
}
}
// Sort recomputed nodes topologically (based on the original graph) so we can
// efficiently assign to each node the maximum of its recomputed child
// components and its own targets.
std::vector<const NodeDef*> recomputed_source_nodes_topological(
recomputed_source_nodes.begin(), recomputed_source_nodes.end());
std::sort(recomputed_source_nodes_topological.begin(),
recomputed_source_nodes_topological.end(),
[&components](const NodeDef* first, const NodeDef* second) {
return components.find(first)->second <
components.find(second)->second;
});
for (const NodeDef* original_recompute_node :
recomputed_source_nodes_topological) {
int max_component;
auto recomputed_component_iterator =
recomputed_node_components.find(original_recompute_node);
if (recomputed_component_iterator != recomputed_node_components.end()) {
max_component = recomputed_component_iterator->second;
} else {
max_component = -1;
}
for (NodeDef* output :
node_map.GetOutputs(original_recompute_node->name())) {
if (recomputed_source_nodes.count(output) == 0) {
continue;
}
auto child_component_iterator = recomputed_node_components.find(output);
CHECK(child_component_iterator != recomputed_node_components.end());
int child_component = child_component_iterator->second;
if (child_component > max_component) {
max_component = child_component;
}
}
CHECK_GE(max_component, 0);
recomputed_node_components[original_recompute_node] = max_component;
}
return recomputed_node_components;
}
// Modifies `graph`, adding trigger nodes and returning a mapping from
// `recomputed_source_nodes` to trigger nodes which will not create loops in the
// graph (using the component numberings in `components` and
// `recomputed_node_max_feed_components`). The copied nodes (not the nodes in
// recomputed_source_nodes, which are the originals) eventually get these
// control dependencies.
std::unordered_map<const NodeDef*, const NodeDef*>
AddRecomputeControlDependencyNodes(
const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
const std::unordered_map<const NodeDef*, int>& components,
const std::unordered_map<const NodeDef*, int>&
recomputed_node_max_feed_components,
GraphDef* graph) {
// Sort recomputed nodes based on max downstream components.
std::vector<const NodeDef*> recomputed_source_nodes_topological(
recomputed_source_nodes.begin(), recomputed_source_nodes.end());
std::sort(recomputed_source_nodes_topological.begin(),
recomputed_source_nodes_topological.end(),
[&recomputed_node_max_feed_components](const NodeDef* first,
const NodeDef* second) {
int first_component =
recomputed_node_max_feed_components.find(first)->second;
int second_component =
recomputed_node_max_feed_components.find(second)->second;
return first_component > second_component
// Ensure a consistent ordering. This is necessary because
// we're working not with node component numbers (which are
// unique) but with the maximum across nodes they feed into
// (very much not unique).
|| (first_component == second_component &&
first->name() > second->name());
});
// Create merged control dependency nodes by sorting target inputs
// topologically and zipper merging with the sorted recomputed nodes.
std::vector<const NodeDef*> target_inputs_topological;
for (const NodeDef* target_node : target_nodes) {
for (const string& target_input_name_raw : target_node->input()) {
const NodeDef* target_input = node_map.GetNode(target_input_name_raw);
if (recomputed_source_nodes.count(target_input) != 0 ||
components.find(target_node)->second ==
components.find(target_input)->second) {
continue;
}
target_inputs_topological.push_back(target_input);
}
}
std::sort(target_inputs_topological.begin(), target_inputs_topological.end(),
[&components](const NodeDef* first, const NodeDef* second) {
return components.find(first)->second >
components.find(second)->second;
});
auto target_input_iterator = target_inputs_topological.begin();
NodeDef* current_trigger_node = nullptr;
std::unordered_map<const NodeDef*, const NodeDef*> triggers;
for (const NodeDef* original_recomputed_node :
recomputed_source_nodes_topological) {
NodeDef* new_trigger_node = graph->add_node();
new_trigger_node->set_name(AddPrefixToNodeName(
original_recomputed_node->name(), kRecomputeTriggerNodePrefix));
new_trigger_node->set_op("NoOp");
new_trigger_node->set_device(original_recomputed_node->device());
if (current_trigger_node != nullptr) {
*new_trigger_node->add_input() =
strings::StrCat("^", current_trigger_node->name());
}
current_trigger_node = new_trigger_node;
triggers[original_recomputed_node] = current_trigger_node;
for (;
target_input_iterator != target_inputs_topological.end() &&
components.find(*target_input_iterator)->second >
recomputed_node_max_feed_components.find(original_recomputed_node)
->second;
++target_input_iterator) {
*current_trigger_node->add_input() =
strings::StrCat("^", (*target_input_iterator)->name());
VLOG(2) << " Recomputation trigger " << current_trigger_node->name()
<< " depends on " << (*target_input_iterator)->name();
}
}
return triggers;
}
string RecomputedOrOriginalNodeName(
const std::unordered_set<string>& recomputed_node_names,
@@ -42,14 +359,28 @@ string RecomputedOrOriginalNodeName(
}
}
// Helper function to recompute a sub-graph (recomputed_source_nodes). Edges
// from recomputed_source_nodes to target_nodes are changed to start from the
// recomputed nodes.
void RecomputeSubgraph(
const std::vector<const NodeDef*>& recomputed_source_nodes,
const string& recompute_trigger_node_name,
const std::vector<NodeDef*>& target_nodes, GraphDef* graph) {
const std::unordered_set<const NodeDef*>& recomputed_source_nodes,
const std::unordered_set<NodeDef*>& target_nodes, const NodeMap& node_map,
const std::unordered_map<const NodeDef*, int>& components,
GraphDef* graph) {
std::unordered_set<string> recomputed_node_names;
for (const NodeDef* to_recompute : recomputed_source_nodes) {
recomputed_node_names.insert(to_recompute->name());
VLOG(1) << "Recomputing a " << recomputed_source_nodes.size()
<< " node subgraph";
std::unordered_map<const NodeDef*, int> recomputed_node_components =
GetMaxDownstreamComponents(recomputed_source_nodes, target_nodes,
node_map, components);
for (const NodeDef* original_node : recomputed_source_nodes) {
VLOG(2) << " " << original_node->name();
recomputed_node_names.insert(original_node->name());
}
std::unordered_map<const NodeDef*, const NodeDef*> triggers =
AddRecomputeControlDependencyNodes(recomputed_source_nodes, target_nodes,
node_map, components,
recomputed_node_components, graph);
// Create the recomputed sub-graph
for (const NodeDef* original_node : recomputed_source_nodes) {
NodeDef* copied_node = graph->add_node();
@@ -64,10 +395,10 @@ void RecomputeSubgraph(
*copied_node->add_input() = RecomputedOrOriginalNodeName(
recomputed_node_names, original_input_name);
}
// Set control dependencies on the recomputed nodes so that they are not run
// until the specified trigger runs.
// Each recomputed node gets a control dependency to prevent it from being
// recomputed immediately.
*copied_node->add_input() =
strings::StrCat("^", recompute_trigger_node_name);
strings::StrCat("^", triggers[original_node]->name());
}
// Set the inputs of nodes in the target subgraph to the recomputed nodes
// where applicable.
@@ -79,6 +410,52 @@ void RecomputeSubgraph(
}
}
void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
GraphDef* graph) {
// The topological numberings and NodeMap will be stale as soon as we start
// modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
// looks up nodes which were in the original graph, and preserves the graph
// topology it's interested in.
// We don't use the results of this topological sort until later, but this
// call invalidates all NodeDef pointers, so it needs to be done before we
// start collecting those.
TopologicalSort(graph);
NodeMap node_map(graph);
std::vector<RecomputedSubGraph> recomputed_subgraphs;
if (optimization_level == RewriterConfig::HEURISTICS) {
// TODO(allenl): Handle ResNet-like architectures better. Right now all of
// the cheap forward ops get grouped into a single subgraph which must
// execute before gradients start executing (unless layers are manually
// separated by identity ops).
std::unordered_set<string> cheap_to_recompute_ops =
GetCheapToRecomputeOps();
recomputed_subgraphs = GetOpGroupsToRecompute(
graph, node_map, [&cheap_to_recompute_ops](const NodeDef& node) {
return !IsTargetOp(node) &&
(cheap_to_recompute_ops.count(node.op()) > 0 ||
node.attr().count(kRecomputeHint) > 0);
});
} else { // optimization_level == RewriterConfig::MANUAL
recomputed_subgraphs =
GetOpGroupsToRecompute(graph, node_map, [](const NodeDef& node) {
return !IsTargetOp(node) && node.attr().count(kRecomputeHint) > 0;
});
}
if (!recomputed_subgraphs.empty()) {
std::unordered_map<const NodeDef*, int> topological_numbering;
for (int node_number = 0; node_number < graph->node().size();
++node_number) {
topological_numbering[graph->mutable_node(node_number)] =
graph->node().size() - node_number - 1;
}
// Duplicate the indicated sub-graphs and set up control dependencies
for (const RecomputedSubGraph& subgraph : recomputed_subgraphs) {
RecomputeSubgraph(subgraph.recomputed_source_nodes, subgraph.target_nodes,
node_map, topological_numbering, graph);
}
}
}
std::pair<NodeDef*, NodeDef*> BuildSwapPair(NodeDef* node, int input_to_swap,
GraphDef* graph) {
string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);
@@ -205,6 +582,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
*optimized_graph = item.graph;
RecomputationRewritingPass(optimization_level_, optimized_graph);
// Figure out what needs to be swapped.
std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
for (auto& node : *optimized_graph->mutable_node()) {


@@ -16,9 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
#include <vector>
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
@@ -26,7 +25,8 @@ namespace grappler {
// Swap tensors in and out of device memory.
class MemoryOptimizer : public GraphOptimizer {
public:
MemoryOptimizer() {}
explicit MemoryOptimizer(RewriterConfig::MemOptType optimization_level)
: optimization_level_(optimization_level) {}
~MemoryOptimizer() override {}
string name() const override { return "memory_optimizer"; };
@@ -36,15 +36,10 @@ class MemoryOptimizer : public GraphOptimizer {
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& pruned_graph, double result) override;
};
// Helper function to recompute a sub-graph (recomputed_source_nodes) on a
// trigger. Edges from recomputed_source_nodes to target_nodes are changed to
// start from the recomputed nodes.
void RecomputeSubgraph(
const std::vector<const NodeDef*>& recomputed_source_nodes,
const string& recompute_trigger_node_name,
const std::vector<NodeDef*>& target_nodes, GraphDef* graph);
private:
RewriterConfig::MemOptType optimization_level_;
};
} // end namespace grappler
} // end namespace tensorflow


@@ -35,84 +35,103 @@ TEST_F(RecomputeSubgraphTest, SimpleSubgraph) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 1.f, {2, 3, 4});
Output b = ops::AddN(s.WithOpName("b"), {a}); // Recomputed
Output c = ops::AddN(s.WithOpName("c"), {b});
Output d = ops::AddN(s.WithOpName("d"), {c});
Output e = ops::AddN(s.WithOpName("e"), {d, b});
Output f = ops::AddN(s.WithOpName("f"), {e, a});
Output b = ops::Identity(s.WithOpName("b"), a); // Recomputed
Output c = ops::Identity(s.WithOpName("c"), b);
Output d = ops::AddN(s.WithOpName("gradients/d"), {c});
Output e = ops::AddN(s.WithOpName("gradients/e"), {d, b});
Output f = ops::AddN(s.WithOpName("gradients/f"), {e, a});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
EXPECT_EQ(6, item.graph.node_size());
NodeMap pre_transform_node_map(&item.graph);
std::vector<const NodeDef*> recomputed_source_nodes;
recomputed_source_nodes.push_back(pre_transform_node_map.GetNode(b.name()));
std::vector<NodeDef*> target_nodes;
target_nodes.push_back(pre_transform_node_map.GetNode(e.name()));
RecomputeSubgraph(recomputed_source_nodes, d.name(), target_nodes,
&item.graph);
NodeMap post_transform_node_map(&item.graph);
EXPECT_EQ(7, item.graph.node_size());
(*pre_transform_node_map.GetNode("b")->mutable_attr())["_recompute_hint"]
.set_i(0);
MemoryOptimizer optimizer(RewriterConfig::MANUAL);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
NodeMap post_transform_node_map(&output);
EXPECT_EQ(8, output.node_size());
NodeDef* transformed_e = post_transform_node_map.GetNode(e.name());
EXPECT_EQ(2, transformed_e->input_size());
EXPECT_EQ("d", transformed_e->input(0));
EXPECT_EQ("gradients/d", transformed_e->input(0));
EXPECT_EQ("Recomputed/b", transformed_e->input(1));
NodeDef* recomputed_b = post_transform_node_map.GetNode("Recomputed/b");
EXPECT_EQ(2, recomputed_b->input_size());
EXPECT_EQ("a", recomputed_b->input(0));
EXPECT_EQ("^d", recomputed_b->input(1).substr(0, 2));
EXPECT_EQ("^RecomputeTrigger/b", recomputed_b->input(1));
NodeDef* recompute_trigger =
post_transform_node_map.GetNode("RecomputeTrigger/b");
EXPECT_EQ(1, recompute_trigger->input_size());
EXPECT_EQ("^gradients/d", recompute_trigger->input(0));
}
TEST_F(RecomputeSubgraphTest, MultiNode) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("Conv"), 1.f, {2, 3, 4});
Output b = ops::AddN(s.WithOpName("BN"), {a}); // Recomputed
Output c = ops::AddN(s.WithOpName("ReLU"), {b}); // Recomputed
Output d = ops::AddN(s.WithOpName("Conv1"), {c});
Output b = ops::Identity(s.WithOpName("BN"), a); // Recomputed
Output c = ops::Identity(s.WithOpName("ReLU"), b); // Recomputed
Output d = ops::Identity(s.WithOpName("Conv1"), c);
Output trigger = ops::Const(s.WithOpName("BN1Grad"), 0.f, {2, 3, 4});
Output e = ops::AddN(s.WithOpName("Conv1Grad"), {trigger, c});
Output f = ops::AddN(s.WithOpName("ReLUGrad"), {e, c});
Output g = ops::AddN(s.WithOpName("BNGrad"), {f, a});
Output h = ops::AddN(s.WithOpName("ConvGrad"), {g});
// The "gradients/" prefix means the heuristic will pick these up as
// candidates to have their inputs recomputed.
Output trigger = ops::AddN(s.WithOpName("gradients/BN1Grad"), {d});
Output e = ops::AddN(s.WithOpName("gradients/Conv1Grad"), {trigger, c});
Output f = ops::AddN(s.WithOpName("gradients/ReLUGrad"), {e, c});
Output g = ops::AddN(s.WithOpName("gradients/BNGrad"), {f, a});
Output h = ops::AddN(s.WithOpName("gradients/ConvGrad"), {g});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
EXPECT_EQ(9, item.graph.node_size());
NodeMap pre_transform_node_map(&item.graph);
std::vector<const NodeDef*> recomputed_source_nodes;
recomputed_source_nodes.push_back(pre_transform_node_map.GetNode(b.name()));
recomputed_source_nodes.push_back(pre_transform_node_map.GetNode(c.name()));
std::vector<NodeDef*> target_nodes;
target_nodes.push_back(pre_transform_node_map.GetNode(e.name()));
target_nodes.push_back(pre_transform_node_map.GetNode(f.name()));
target_nodes.push_back(pre_transform_node_map.GetNode(g.name()));
RecomputeSubgraph(recomputed_source_nodes, trigger.name(), target_nodes,
&item.graph);
NodeMap post_transform_node_map(&item.graph);
EXPECT_EQ(11, item.graph.node_size());
// Set op types so that the heuristic will pick these nodes up to be
// recomputed
pre_transform_node_map.GetNode("BN")->set_op("FusedBatchNorm");
pre_transform_node_map.GetNode("ReLU")->set_op("Relu");
MemoryOptimizer optimizer(RewriterConfig::HEURISTICS);
GraphDef first_pass_output;
Status first_pass_status =
optimizer.Optimize(nullptr, item, &first_pass_output);
TF_EXPECT_OK(first_pass_status);
NodeMap post_transform_node_map(&first_pass_output);
EXPECT_EQ(13, first_pass_output.node_size());
NodeDef* transformed_e = post_transform_node_map.GetNode(e.name());
EXPECT_EQ(2, transformed_e->input_size());
EXPECT_EQ("BN1Grad", transformed_e->input(0));
EXPECT_EQ("gradients/BN1Grad", transformed_e->input(0));
EXPECT_EQ("Recomputed/ReLU", transformed_e->input(1));
NodeDef* transformed_f = post_transform_node_map.GetNode(f.name());
EXPECT_EQ(2, transformed_f->input_size());
EXPECT_EQ("Conv1Grad", transformed_f->input(0));
EXPECT_EQ("gradients/Conv1Grad", transformed_f->input(0));
EXPECT_EQ("Recomputed/ReLU", transformed_f->input(1));
NodeDef* transformed_g = post_transform_node_map.GetNode(g.name());
EXPECT_EQ(2, transformed_g->input_size());
EXPECT_EQ("ReLUGrad", transformed_g->input(0));
EXPECT_EQ("gradients/ReLUGrad", transformed_g->input(0));
EXPECT_EQ("Conv", transformed_g->input(1));
NodeDef* recomputed_b = post_transform_node_map.GetNode("Recomputed/BN");
EXPECT_EQ(2, recomputed_b->input_size());
EXPECT_EQ("Conv", recomputed_b->input(0));
EXPECT_EQ("^BN1Grad", recomputed_b->input(1).substr(0, 8));
EXPECT_EQ("^RecomputeTrigger/BN", recomputed_b->input(1));
NodeDef* recompute_trigger_b =
post_transform_node_map.GetNode("RecomputeTrigger/BN");
EXPECT_EQ(1, recompute_trigger_b->input_size());
EXPECT_EQ("^RecomputeTrigger/ReLU", recompute_trigger_b->input(0));
NodeDef* recomputed_c = post_transform_node_map.GetNode("Recomputed/ReLU");
EXPECT_EQ(2, recomputed_c->input_size());
EXPECT_EQ("Recomputed/BN", recomputed_c->input(0));
EXPECT_EQ("^BN1Grad", recomputed_c->input(1).substr(0, 8));
EXPECT_EQ("^RecomputeTrigger/ReLU", recomputed_c->input(1));
NodeDef* recompute_trigger_c =
post_transform_node_map.GetNode("RecomputeTrigger/ReLU");
EXPECT_EQ(1, recompute_trigger_c->input_size());
EXPECT_EQ("^gradients/BN1Grad", recompute_trigger_c->input(0));
}
class MemoryOptimizerTest : public ::testing::Test {
@@ -150,7 +169,7 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) {
VirtualCluster cluster(CreateVirtualCluster());
MemoryOptimizer optimizer;
MemoryOptimizer optimizer(RewriterConfig::MANUAL);
GraphDef output;
Status status = optimizer.Optimize(&cluster, item, &output);
TF_EXPECT_OK(status);


@@ -41,7 +41,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
graph_optimizer.reset(new LayoutOptimizer());
}
if (optimizer == "memory") {
graph_optimizer.reset(new MemoryOptimizer());
graph_optimizer.reset(new MemoryOptimizer(RewriterConfig::MANUAL));
}
if (optimizer == "autoparallel") {
graph_optimizer.reset(
@@ -66,8 +66,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
}
if (cfg_.memory_optimization() > 0) {
optimizers.push_back(
std::unique_ptr<GraphOptimizer>(new MemoryOptimizer()));
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
new MemoryOptimizer(cfg_.memory_optimization())));
}
if (cfg_.auto_parallel().enable()) {
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
@@ -114,7 +114,8 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
return cfg.optimize_tensor_layout() || cfg.constant_folding() ||
cfg.auto_parallel().enable() || !cfg.optimizers().empty();
cfg.auto_parallel().enable() || cfg.memory_optimization() > 0 ||
!cfg.optimizers().empty();
}
Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg,


@@ -15,21 +15,40 @@ message RewriterConfig {
// Graph rewriting is experimental and subject to change, not covered by any
// API stability guarantees.
// Configuration options for the meta-optimizer. Unless otherwise noted, these
// configuration options do not apply to explicitly triggered optimization
// passes in the optimizers field.
bool optimize_tensor_layout = 1;
bool disable_model_pruning = 2;
bool constant_folding = 3;
enum MemOptType {
// Fully disabled
// Disabled in the meta-optimizer.
NO_MEM_OPT = 0;
// Driven by manual annotations
// Driven by manual op-level annotations.
MANUAL = 1;
// Driven by heuristics. The behavior of these heuristics is subject to
// change. Currently includes an experimental recomputation heuristic.
HEURISTICS = 2;
}
// Configures memory optimization passes through the meta-optimizer. Has no
// effect on manually requested memory optimization passes in the optimizers
// field.
MemOptType memory_optimization = 4;
// Configures AutoParallel optimization passes either through the
// meta-optimizer or when manually specified through the optimizers field.
AutoParallelOptions auto_parallel = 5;
// If non-empty, will use this as an alternative way to specify a list of
// optimizations to turn on and the order of the optimizations.
// optimizations to turn on and the order of the optimizations (replacing the
// meta-optimizer).
//
// Of the RewriterConfig options, only the AutoParallel configuration options
// (the auto_parallel field) apply to manually requested optimization passes
// ("autoparallel"). Memory optimization passes ("memory") invoked here are
// not configurable (in contrast to memory optimization passes through the
// meta-optimizer) and act only on manual op annotations.
repeated string optimizers = 100;
}
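A small sketch of the two ways the memory pass can now be requested, assuming the proto fields above: through the meta-optimizer (configurable via memory_optimization) or as an explicit entry in the optimizers list, which per the comment above acts only on manual op annotations:

from tensorflow.core.protobuf import rewriter_config_pb2

# Meta-optimizer path: recomputation driven by heuristics (or MANUAL hints).
meta_config = rewriter_config_pb2.RewriterConfig(
    memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS)

# Explicit optimizer list: the "memory" pass runs, but only honors manual
# _recompute_hint annotations and ignores memory_optimization.
manual_config = rewriter_config_pb2.RewriterConfig(optimizers=['memory'])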


@@ -3800,7 +3800,13 @@ py_test(
":client_testlib",
":framework_for_generated_wrappers",
":math_ops",
":nn",
":random_seed",
":session",
":tf_optimizer",
":training",
":variable_scope",
":variables",
"//tensorflow/core:protos_all_py",
"//third_party/py/numpy",
],


@@ -18,16 +18,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training as train
class MemoryOptimizerTest(test.TestCase):
class MemoryOptimizerSwapTest(test.TestCase):
"""Tests the Grappler memory optimizer."""
def testNoSwapping(self):
@@ -85,5 +92,51 @@ class MemoryOptimizerTest(test.TestCase):
self.assertEqual('c', node.input[1])
class MemoryOptimizerRecomputeTest(test.TestCase):
def _RunGraphWithConfig(self, config, batch_size=14, image_dim=12):
"""Run a simple layered graph with conv, an intermediate op, and a ReLU."""
graph = ops.Graph()
with graph.as_default():
random_seed.set_random_seed(1)
current_activation = variable_scope.get_variable(
name='start', shape=[batch_size, image_dim, image_dim, 5])
conv_filter = variable_scope.get_variable(
name='filter', shape=[5, 5, 5, 5])
for layer_number in range(10):
with variable_scope.variable_scope('layer_{}'.format(layer_number)):
after_conv = nn.conv2d(current_activation, conv_filter, [1, 1, 1, 1],
'SAME')
current_activation = 2. * after_conv
current_activation = nn.relu(current_activation)
loss = math_ops.reduce_mean(current_activation)
optimizer = train.AdamOptimizer(0.001)
train_op = optimizer.minimize(loss)
init_op = variables.global_variables_initializer()
with session.Session(config=config, graph=graph) as sess:
sess.run(init_op)
sess.run(train_op)
sess.run(train_op)
return sess.run(loss)
def _GetMemoryOptimizerConfig(self):
rewrite_options = rewriter_config_pb2.RewriterConfig(
memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS)
graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
return config_pb2.ConfigProto(graph_options=graph_options)
def testRecomputationRewritingNoErrors(self):
"""Tests that there are no errors when we request a memory optimizer pass.
Does not test that the memory optimizer actually runs. See
core/grappler/optimizers/memory_optimizer_test.cc for a functional test of
the graph rewriting.
"""
original_loss = self._RunGraphWithConfig(config_pb2.ConfigProto())
memory_optimized_loss = self._RunGraphWithConfig(
config=self._GetMemoryOptimizerConfig())
self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-4)
if __name__ == '__main__':
test.main()