Never recompute Tensors which are feeds in Grappler's memory optimizer.

I don't think anyone has run into this yet, but it would lead to incredibly
hard-to-debug issues (sometimes incorrect gradients with a correct forward pass).

PiperOrigin-RevId: 161094335
Allen Lavoie 2017-07-06 10:08:38 -07:00 committed by TensorFlower Gardener
parent 4de7361c43
commit d1567d35e7
2 changed files with 41 additions and 6 deletions
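
To see why fed nodes must be excluded: a value supplied through a feed at run
time replaces whatever the graph would have computed for that node, so a
recomputed copy of a fed node would use the graph computation instead of the
fed value, and gradients built from it would be wrong. The diff below
implements this by collecting the fed node names from the GrapplerItem and
rejecting them as recomputation candidates. What follows is a minimal,
self-contained sketch of that guard; FakeItem, StripPort, and MayRecompute are
hypothetical stand-ins for GrapplerItem, grappler::NodeName(), and the lambda
predicates in the real code, not the actual Grappler API.

    // Stand-alone sketch (not Grappler code) of the guard added in this
    // commit: nodes named in the feeds are never recomputation candidates.
    #include <string>
    #include <unordered_set>
    #include <vector>

    struct FakeItem {
      // Conceptual mirror of GrapplerItem::feed: names of tensors the caller
      // will supply at run time (the fed values themselves are omitted).
      std::vector<std::string> feed_names;
    };

    // Strip an output-port suffix, e.g. "b:0" -> "b" (simplified NodeName()).
    std::string StripPort(const std::string& tensor_name) {
      const auto colon = tensor_name.rfind(':');
      return colon == std::string::npos ? tensor_name
                                        : tensor_name.substr(0, colon);
    }

    // A node may be recomputed only if it is not fed; a fed node must keep
    // the caller-supplied value, so a recomputed copy would diverge from it.
    bool MayRecompute(const std::string& node_name,
                      const std::unordered_set<std::string>& feeds) {
      return feeds.count(node_name) == 0;
    }

    int main() {
      FakeItem item;
      item.feed_names = {"b:0"};

      std::unordered_set<std::string> feeds;
      for (const auto& name : item.feed_names) {
        feeds.insert(StripPort(name));
      }

      // "b" is fed and must not be recomputed; "c" is still a candidate.
      return (MayRecompute("b", feeds) || !MayRecompute("c", feeds)) ? 1 : 0;
    }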

tensorflow/core/grappler/optimizers/memory_optimizer.cc

@@ -417,7 +417,7 @@ void RecomputeSubgraph(
 }
 
 void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
-                                GraphDef* graph) {
+                                GraphDef* graph, const GrapplerItem& item) {
   // The topological numberings and NodeMap will be stale as soon as we start
   // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
   // looks up nodes which were in the original graph, and preserves the graph
@@ -428,6 +428,12 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
   TopologicalSort(graph);
   NodeMap node_map(graph);
   std::vector<RecomputedSubGraph> recomputed_subgraphs;
+  // Do not recompute nodes which are fed, since the recomputed node would not
+  // take on the fed value (i.e. gradients would be incorrect).
+  std::unordered_set<string> feeds;
+  for (const auto& feed : item.feed) {
+    feeds.insert(NodeName(feed.first));
+  }
   if (optimization_level == RewriterConfig::HEURISTICS) {
     // TODO(allenl): Handle ResNet-like architectures better. Right now all of
     // the cheap forward ops get grouped into a single subgraph which must
@@ -436,15 +442,17 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
     std::unordered_set<string> cheap_to_recompute_ops =
         GetCheapToRecomputeOps();
     recomputed_subgraphs = GetOpGroupsToRecompute(
-        graph, node_map, [&cheap_to_recompute_ops](const NodeDef& node) {
-          return !IsTargetOp(node) &&
+        graph, node_map,
+        [&cheap_to_recompute_ops, &feeds](const NodeDef& node) {
+          return !IsTargetOp(node) && feeds.count(node.name()) == 0 &&
                  (cheap_to_recompute_ops.count(node.op()) > 0 ||
                   node.attr().count(kRecomputeHint) > 0);
         });
   } else {  // optimization_level == RewriterConfig::MANUAL
     recomputed_subgraphs =
-        GetOpGroupsToRecompute(graph, node_map, [](const NodeDef& node) {
-          return !IsTargetOp(node) && node.attr().count(kRecomputeHint) > 0;
+        GetOpGroupsToRecompute(graph, node_map, [&feeds](const NodeDef& node) {
+          return !IsTargetOp(node) && feeds.count(node.name()) == 0 &&
+                 node.attr().count(kRecomputeHint) > 0;
         });
   }
   if (!recomputed_subgraphs.empty()) {
@@ -588,7 +596,7 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                  GraphDef* optimized_graph) {
   *optimized_graph = item.graph;
 
-  RecomputationRewritingPass(optimization_level_, optimized_graph);
+  RecomputationRewritingPass(optimization_level_, optimized_graph, item);
 
   // Figure out what needs to be swapped;
   std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;

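The new test below exercises the MANUAL optimization level, which only rewrites
nodes explicitly tagged with the "_recompute_hint" attribute (kRecomputeHint in
memory_optimizer.cc) and, after this change, skips such a node when it is also
fed. As a rough sketch of how a node is opted in, assuming only the TensorFlow
NodeDef proto and mirroring what the test does through NodeMap, one might
write:

    #include "tensorflow/core/framework/node_def.pb.h"

    // Tag a node for manual recomputation. The MANUAL path only checks that
    // the attribute is present; the test below sets its value to 0.
    void MarkForRecompute(tensorflow::NodeDef* node) {
      (*node->mutable_attr())["_recompute_hint"].set_i(0);
    }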
tensorflow/core/grappler/optimizers/memory_optimizer_test.cc

@@ -69,6 +69,33 @@ TEST_F(RecomputeSubgraphTest, SimpleSubgraph) {
   EXPECT_EQ("^gradients/d", recompute_trigger->input(0));
 }
 
+TEST_F(RecomputeSubgraphTest, NoFeedsRecomputed) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  Output a = ops::Variable(s.WithOpName("a"), {2, 3, 4}, DT_FLOAT);
+  Output b = ops::Identity(s.WithOpName("b"), a);  // Would be recomputed, but
+                                                   // for being fed.
+  Output c = ops::Identity(s.WithOpName("c"), b);
+  Output d = ops::AddN(s.WithOpName("gradients/d"), {c});
+  Output e = ops::AddN(s.WithOpName("gradients/e"), {d, b});
+  Output f = ops::AddN(s.WithOpName("gradients/f"), {e, a});
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.feed.emplace_back("b", Tensor());
+  EXPECT_EQ(6, item.graph.node_size());
+  NodeMap pre_transform_node_map(&item.graph);
+  (*pre_transform_node_map.GetNode("b")->mutable_attr())["_recompute_hint"]
+      .set_i(0);
+
+  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  EXPECT_EQ(6, output.node_size());
+}
+
 TEST_F(RecomputeSubgraphTest, TwoInputSubgraphs) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();