From d1567d35e78b8e614ee05b90071773f24b251b96 Mon Sep 17 00:00:00 2001
From: Allen Lavoie
Date: Thu, 6 Jul 2017 10:08:38 -0700
Subject: [PATCH] Never recompute Tensors which are feeds in Grappler's memory optimizer.

I don't think anyone has run into this yet, but it would lead to incredibly
hard-to-debug issues (sometimes incorrect gradients with a correct forward
pass).

PiperOrigin-RevId: 161094335
---
 .../grappler/optimizers/memory_optimizer.cc  | 20 +++++++-----
 .../optimizers/memory_optimizer_test.cc      | 27 +++++++++++++++++++
 2 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 462cfb928f6..23479d84b00 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -417,7 +417,7 @@ void RecomputeSubgraph(
 }
 
 void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
-                                GraphDef* graph) {
+                                GraphDef* graph, const GrapplerItem& item) {
   // The topological numberings and NodeMap will be stale as soon as we start
   // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
   // looks up nodes which were in the original graph, and preserves the graph
@@ -428,6 +428,12 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
   TopologicalSort(graph);
   NodeMap node_map(graph);
   std::vector<RecomputedSubGraph> recomputed_subgraphs;
+  // Do not recompute nodes which are fed, since the recomputed node would not
+  // take on the fed value (i.e. gradients would be incorrect).
+  std::unordered_set<string> feeds;
+  for (const auto& feed : item.feed) {
+    feeds.insert(NodeName(feed.first));
+  }
   if (optimization_level == RewriterConfig::HEURISTICS) {
     // TODO(allenl): Handle ResNet-like architectures better. Right now all of
     // the cheap forward ops get grouped into a single subgraph which must
@@ -436,15 +442,17 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
     std::unordered_set<string> cheap_to_recompute_ops =
         GetCheapToRecomputeOps();
     recomputed_subgraphs = GetOpGroupsToRecompute(
-        graph, node_map, [&cheap_to_recompute_ops](const NodeDef& node) {
-          return !IsTargetOp(node) &&
+        graph, node_map,
+        [&cheap_to_recompute_ops, &feeds](const NodeDef& node) {
+          return !IsTargetOp(node) && feeds.count(node.name()) == 0 &&
                  (cheap_to_recompute_ops.count(node.op()) > 0 ||
                   node.attr().count(kRecomputeHint) > 0);
         });
   } else {  // optimization_level == RewriterConfig::MANUAL
     recomputed_subgraphs =
-        GetOpGroupsToRecompute(graph, node_map, [](const NodeDef& node) {
-          return !IsTargetOp(node) && node.attr().count(kRecomputeHint) > 0;
+        GetOpGroupsToRecompute(graph, node_map, [&feeds](const NodeDef& node) {
+          return !IsTargetOp(node) && feeds.count(node.name()) == 0 &&
+                 node.attr().count(kRecomputeHint) > 0;
         });
   }
   if (!recomputed_subgraphs.empty()) {
@@ -588,7 +596,7 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                  GraphDef* optimized_graph) {
   *optimized_graph = item.graph;
 
-  RecomputationRewritingPass(optimization_level_, optimized_graph);
+  RecomputationRewritingPass(optimization_level_, optimized_graph, item);
 
   // Figure out what needs to be swapped;
   std::unordered_map<string, SwapInfo> nodes_to_swap;
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
index eb7a8eb343a..0d5d302f4ad 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
@@ -69,6 +69,33 @@ TEST_F(RecomputeSubgraphTest, SimpleSubgraph) {
   EXPECT_EQ("^gradients/d", recompute_trigger->input(0));
 }
 
+TEST_F(RecomputeSubgraphTest, NoFeedsRecomputed) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  Output a = ops::Variable(s.WithOpName("a"), {2, 3, 4}, DT_FLOAT);
+  Output b = ops::Identity(s.WithOpName("b"), a);  // Would be recomputed, but
+                                                   // for being fed
+  Output c = ops::Identity(s.WithOpName("c"), b);
+  Output d = ops::AddN(s.WithOpName("gradients/d"), {c});
+  Output e = ops::AddN(s.WithOpName("gradients/e"), {d, b});
+  Output f = ops::AddN(s.WithOpName("gradients/f"), {e, a});
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.feed.emplace_back("b", Tensor());
+  EXPECT_EQ(6, item.graph.node_size());
+  NodeMap pre_transform_node_map(&item.graph);
+  (*pre_transform_node_map.GetNode("b")->mutable_attr())["_recompute_hint"]
+      .set_i(0);
+
+  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+
+  TF_EXPECT_OK(status);
+  EXPECT_EQ(6, output.node_size());
+}
+
 TEST_F(RecomputeSubgraphTest, TwoInputSubgraphs) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
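For context, the check this patch adds boils down to: collect the names of all fed tensors once, then refuse to mark any fed node as recomputable, because a recomputed copy would produce the graph-computed value rather than the fed one, silently corrupting gradients while the forward pass stays correct. Below is a minimal standalone sketch of that predicate. The NodeDef and GrapplerItem structs and the string feed value are simplified stand-ins for the real Grappler types, and this NodeName only strips a ":port" suffix (TensorFlow's helper also handles the '^' control-input prefix):

// Standalone sketch of the feed-exclusion predicate; the types below are
// simplified stand-ins for TensorFlow's NodeDef/GrapplerItem, not the real API.
#include <iostream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

struct NodeDef {
  std::string name;
};

struct GrapplerItem {
  // Mirrors item.feed: (tensor name, fed value) pairs. The value type is
  // irrelevant to the check, so a plain string stands in for Tensor.
  std::vector<std::pair<std::string, std::string>> feed;
};

// Approximates NodeName(): "b:0" -> "b". If there is no colon, find()
// returns npos and substr() keeps the whole string.
std::string NodeName(const std::string& tensor_name) {
  return tensor_name.substr(0, tensor_name.find(':'));
}

bool SafeToRecompute(const NodeDef& node,
                     const std::unordered_set<std::string>& feeds) {
  // A fed node must never be recomputed: the recomputed copy would take on
  // the graph's value, not the fed one.
  return feeds.count(node.name) == 0;
}

int main() {
  GrapplerItem item;
  item.feed.emplace_back("b:0", "fed-value");

  // Build the set of fed node names once, as the patch does.
  std::unordered_set<std::string> feeds;
  for (const auto& feed : item.feed) {
    feeds.insert(NodeName(feed.first));
  }

  std::cout << SafeToRecompute(NodeDef{"b"}, feeds) << "\n";  // 0: fed, skip
  std::cout << SafeToRecompute(NodeDef{"c"}, feeds) << "\n";  // 1: eligible
  return 0;
}

In the patch itself this predicate is folded into the lambdas passed to GetOpGroupsToRecompute, alongside the existing !IsTargetOp and kRecomputeHint conditions, rather than living in a named helper like the hypothetical SafeToRecompute above.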