Never recompute Tensors which are feeds in Grappler's memory optimizer.
I don't think anyone has run into this yet, but it would lead to incredibly hard-to-debug issues (sometimes incorrect gradients with a correct forward pass).

PiperOrigin-RevId: 161094335
parent 4de7361c43 · commit d1567d35e7
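Why the guard is needed: when a graph is run with a feed for an intermediate tensor, the runtime substitutes the fed value for that node's output, but a recomputed clone of the node would re-execute the op and produce the un-fed value. Any gradient that consumes the clone is then computed against the wrong tensor, while the forward pass (which reads the fed original) still looks correct. The sketch below is a minimal standalone illustration of the new exclusion logic, not the Grappler code itself: MockNodeName is a hypothetical stand-in for Grappler's NodeName(), which reduces a feed name to the plain NodeDef name so the check in the predicate below compares like with like.

#include <iostream>
#include <string>
#include <unordered_set>

// Hypothetical stand-in for Grappler's NodeName(): reduce a feed name such as
// "b:0" (output-port suffix) or "^b" (control prefix) to the node name "b".
std::string MockNodeName(const std::string& feed_name) {
  std::string name = feed_name;
  if (!name.empty() && name[0] == '^') name.erase(0, 1);
  const auto colon = name.rfind(':');
  if (colon != std::string::npos) name.resize(colon);
  return name;
}

int main() {
  // Feeds as they arrive in GrapplerItem::feed: (tensor name, Tensor) pairs.
  // Only the names matter for the exclusion check.
  std::unordered_set<std::string> feeds;
  for (const char* feed_name : {"b:0"}) {
    feeds.insert(MockNodeName(feed_name));
  }
  // Mirror of the new predicate clause: a node stays a recompute candidate
  // only if it is not fed. "b" is fed, so it is skipped even if it would
  // otherwise be cheap to recompute or carry a recompute hint.
  for (const char* node_name : {"a", "b", "c"}) {
    const bool may_recompute = feeds.count(node_name) == 0;
    std::cout << node_name << ": "
              << (may_recompute ? "candidate" : "fed, skipped") << "\n";
  }
  return 0;
}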
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -417,7 +417,7 @@ void RecomputeSubgraph(
 }
 
 void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
-                                GraphDef* graph) {
+                                GraphDef* graph, const GrapplerItem& item) {
   // The topological numberings and NodeMap will be stale as soon as we start
   // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only
   // looks up nodes which were in the original graph, and preserves the graph
@@ -428,6 +428,12 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
   TopologicalSort(graph);
   NodeMap node_map(graph);
   std::vector<RecomputedSubGraph> recomputed_subgraphs;
+  // Do not recompute nodes which are fed, since the recomputed node would not
+  // take on the fed value (i.e. gradients would be incorrect).
+  std::unordered_set<string> feeds;
+  for (const auto& feed : item.feed) {
+    feeds.insert(NodeName(feed.first));
+  }
   if (optimization_level == RewriterConfig::HEURISTICS) {
     // TODO(allenl): Handle ResNet-like architectures better. Right now all of
     // the cheap forward ops get grouped into a single subgraph which must
@@ -436,15 +442,17 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
     std::unordered_set<string> cheap_to_recompute_ops =
         GetCheapToRecomputeOps();
     recomputed_subgraphs = GetOpGroupsToRecompute(
-        graph, node_map, [&cheap_to_recompute_ops](const NodeDef& node) {
-          return !IsTargetOp(node) &&
+        graph, node_map,
+        [&cheap_to_recompute_ops, &feeds](const NodeDef& node) {
+          return !IsTargetOp(node) && feeds.count(node.name()) == 0 &&
                  (cheap_to_recompute_ops.count(node.op()) > 0 ||
                   node.attr().count(kRecomputeHint) > 0);
         });
   } else {  // optimization_level == RewriterConfig::MANUAL
     recomputed_subgraphs =
-        GetOpGroupsToRecompute(graph, node_map, [](const NodeDef& node) {
-          return !IsTargetOp(node) && node.attr().count(kRecomputeHint) > 0;
+        GetOpGroupsToRecompute(graph, node_map, [&feeds](const NodeDef& node) {
+          return !IsTargetOp(node) && feeds.count(node.name()) == 0 &&
+                 node.attr().count(kRecomputeHint) > 0;
         });
   }
   if (!recomputed_subgraphs.empty()) {
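Note that the same guard lands in both branches: under HEURISTICS the lambda rejects fed nodes before consulting the cheap-op set or the kRecomputeHint attribute, and under MANUAL even a node explicitly marked with the hint is skipped when it is fed. Placing feeds.count(node.name()) == 0 early in the short-circuited conjunction keeps the per-node cost at a single hash lookup.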
@@ -588,7 +596,7 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                  GraphDef* optimized_graph) {
   *optimized_graph = item.graph;
 
-  RecomputationRewritingPass(optimization_level_, optimized_graph);
+  RecomputationRewritingPass(optimization_level_, optimized_graph, item);
 
   // Figure out what needs to be swapped;
   std::unordered_map<NodeDef*, SwapInfo> nodes_to_swap;
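The call-site change is pure plumbing: MemoryOptimizer::Optimize already receives the GrapplerItem, so the rewriting pass gains access to item.feed simply by taking the item as an extra parameter.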
--- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
@@ -69,6 +69,33 @@ TEST_F(RecomputeSubgraphTest, SimpleSubgraph) {
   EXPECT_EQ("^gradients/d", recompute_trigger->input(0));
 }
 
+TEST_F(RecomputeSubgraphTest, NoFeedsRecomputed) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  Output a = ops::Variable(s.WithOpName("a"), {2, 3, 4}, DT_FLOAT);
+  Output b = ops::Identity(s.WithOpName("b"), a);  // Would be recomputed, but
+                                                   // for being fed
+  Output c = ops::Identity(s.WithOpName("c"), b);
+  Output d = ops::AddN(s.WithOpName("gradients/d"), {c});
+  Output e = ops::AddN(s.WithOpName("gradients/e"), {d, b});
+  Output f = ops::AddN(s.WithOpName("gradients/f"), {e, a});
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.feed.emplace_back("b", Tensor());
+  EXPECT_EQ(6, item.graph.node_size());
+  NodeMap pre_transform_node_map(&item.graph);
+  (*pre_transform_node_map.GetNode("b")->mutable_attr())["_recompute_hint"]
+      .set_i(0);
+
+  MemoryOptimizer optimizer(RewriterConfig::MANUAL);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+
+  TF_EXPECT_OK(status);
+  EXPECT_EQ(6, output.node_size());
+}
+
 TEST_F(RecomputeSubgraphTest, TwoInputSubgraphs) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
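The final assertion carries the test: node "b" is both fed and marked with _recompute_hint, and output.node_size() staying at 6 means the optimizer added neither a recomputed copy of "b" nor a recompute-trigger node. Removing the item.feed entry would presumably flip the outcome, growing the graph as in the SimpleSubgraph test above.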