tfdbg bug fix: avoid hanging while watching Enter and NextIteration ops in while loops
Change: 147488620
This commit is contained in:
parent
3e8728dbbc
commit
2b2f26b088
@ -231,10 +231,12 @@ Status DebugNodeInserter::InsertNodes(
|
||||
|
||||
// Add control edges from the debug nodes to the destination node
|
||||
// to ensure that the debug nodes are executed before the destination
|
||||
// node.
|
||||
// node. Skip Enter and NextIteration ops to avoid hanging.
|
||||
for (Node* debug_node : debug_nodes) {
|
||||
graph->AddEdge(debug_node, Graph::kControlSlot, edge->dst(),
|
||||
Graph::kControlSlot);
|
||||
if (!src_node->IsEnter() && !src_node->IsNextIteration()) {
|
||||
graph->AddEdge(debug_node, Graph::kControlSlot, edge->dst(),
|
||||
Graph::kControlSlot);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -410,6 +410,53 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(10, len(dump.watch_key_to_data(watch_keys[0])))
|
||||
self.assertEqual([], dump.watch_key_to_data("foo"))
|
||||
|
||||
def testDebugWhileLoopWatchingWholeGraphWorks(self):
|
||||
with session.Session() as sess:
|
||||
loop_body = lambda i: math_ops.add(i, 2)
|
||||
loop_cond = lambda i: math_ops.less(i, 16)
|
||||
|
||||
i = constant_op.constant(10, name="i")
|
||||
loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])
|
||||
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_utils.watch_graph(run_options,
|
||||
sess.graph,
|
||||
debug_urls=self._debug_urls())
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
self.assertEqual(
|
||||
16, sess.run(loop, options=run_options, run_metadata=run_metadata))
|
||||
|
||||
dump = debug_data.DebugDumpDir(
|
||||
self._dump_root, partition_graphs=run_metadata.partition_graphs)
|
||||
|
||||
self.assertEqual(
|
||||
[[10]], dump.get_tensors("while/Enter", 0, "DebugIdentity"))
|
||||
self.assertEqual(
|
||||
[[12], [14], [16]],
|
||||
dump.get_tensors("while/NextIteration", 0, "DebugIdentity"))
|
||||
|
||||
def testDebugCondWatchingWholeGraphWorks(self):
|
||||
with session.Session() as sess:
|
||||
x = variables.Variable(10.0, name="x")
|
||||
y = variables.Variable(20.0, name="y")
|
||||
cond = control_flow_ops.cond(
|
||||
x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
|
||||
|
||||
sess.run(variables.global_variables_initializer())
|
||||
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_utils.watch_graph(run_options,
|
||||
sess.graph,
|
||||
debug_urls=self._debug_urls())
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
self.assertEqual(
|
||||
21, sess.run(cond, options=run_options, run_metadata=run_metadata))
|
||||
|
||||
dump = debug_data.DebugDumpDir(
|
||||
self._dump_root, partition_graphs=run_metadata.partition_graphs)
|
||||
self.assertAllClose(
|
||||
[21.0], dump.get_tensors("cond/Merge", 0, "DebugIdentity"))
|
||||
|
||||
def testFindNodesWithBadTensorValues(self):
|
||||
with session.Session() as sess:
|
||||
u_name = "testFindNodesWithBadTensorValues/u"
|
||||
|
Loading…
Reference in New Issue
Block a user