tfdbg bug fix: avoid hanging while watching Enter and NextIteration ops in while loops

Change: 147488620
This commit is contained in:
Shanqing Cai 2017-02-14 10:21:02 -08:00 committed by TensorFlower Gardener
parent 3e8728dbbc
commit 2b2f26b088
2 changed files with 52 additions and 3 deletions

View File

@ -231,10 +231,12 @@ Status DebugNodeInserter::InsertNodes(
// Add control edges from the debug nodes to the destination node
// to ensure that the debug nodes are executed before the destination
// node.
// node. Skip Enter and NextIteration ops to avoid hanging.
for (Node* debug_node : debug_nodes) {
graph->AddEdge(debug_node, Graph::kControlSlot, edge->dst(),
Graph::kControlSlot);
if (!src_node->IsEnter() && !src_node->IsNextIteration()) {
graph->AddEdge(debug_node, Graph::kControlSlot, edge->dst(),
Graph::kControlSlot);
}
}
}
}

View File

@ -410,6 +410,53 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertEqual(10, len(dump.watch_key_to_data(watch_keys[0])))
self.assertEqual([], dump.watch_key_to_data("foo"))
def testDebugWhileLoopWatchingWholeGraphWorks(self):
with session.Session() as sess:
loop_body = lambda i: math_ops.add(i, 2)
loop_cond = lambda i: math_ops.less(i, 16)
i = constant_op.constant(10, name="i")
loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_utils.watch_graph(run_options,
sess.graph,
debug_urls=self._debug_urls())
run_metadata = config_pb2.RunMetadata()
self.assertEqual(
16, sess.run(loop, options=run_options, run_metadata=run_metadata))
dump = debug_data.DebugDumpDir(
self._dump_root, partition_graphs=run_metadata.partition_graphs)
self.assertEqual(
[[10]], dump.get_tensors("while/Enter", 0, "DebugIdentity"))
self.assertEqual(
[[12], [14], [16]],
dump.get_tensors("while/NextIteration", 0, "DebugIdentity"))
def testDebugCondWatchingWholeGraphWorks(self):
with session.Session() as sess:
x = variables.Variable(10.0, name="x")
y = variables.Variable(20.0, name="y")
cond = control_flow_ops.cond(
x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
sess.run(variables.global_variables_initializer())
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_utils.watch_graph(run_options,
sess.graph,
debug_urls=self._debug_urls())
run_metadata = config_pb2.RunMetadata()
self.assertEqual(
21, sess.run(cond, options=run_options, run_metadata=run_metadata))
dump = debug_data.DebugDumpDir(
self._dump_root, partition_graphs=run_metadata.partition_graphs)
self.assertAllClose(
[21.0], dump.get_tensors("cond/Merge", 0, "DebugIdentity"))
def testFindNodesWithBadTensorValues(self):
with session.Session() as sess:
u_name = "testFindNodesWithBadTensorValues/u"