diff --git a/tensorflow/python/tpu/tensor_tracer.py b/tensorflow/python/tpu/tensor_tracer.py
index 6d41d8768a3..1d37f1b64b8 100644
--- a/tensorflow/python/tpu/tensor_tracer.py
+++ b/tensorflow/python/tpu/tensor_tracer.py
@@ -540,54 +540,56 @@ class TensorTracer(object):
        cycle is found) and the second element is either the sorted
        list of nodes or the cycle of nodes found.
     """
+    def _is_loop_edge(op):
+      """Returns True for NextIteration ops, which close while-loop cycles."""
+      return op.type == 'NextIteration'
 
-    def visit(op, cycle, permanently_marked_ops,
-              temporarily_marked_ops, sorted_ops):
-      """Recursively visits all Ops in a graph.
+    def _in_op_degree(op):
+      """Returns the number of incoming edges to the given op.
 
+      Edges coming from 'NextIteration' ops are not counted. NextIteration
+      closes a while-loop cycle, so it is treated as a sink without out-edges.
+
       Args:
-         op: the current Op being visited.
-         cycle: a cycle of Ops found.
-         permanently_marked_ops: the set of Ops that were already visited.
-         temporarily_marked_ops: the set of Ops that we have visited during
-                                 the current descent.
-         sorted_ops: the list of Ops sorted in topological order.
+        op: tf.Operation whose incoming edges are counted.
+      Returns:
+        The number of control and data input edges, excluding loop edges.
       """
+      count = 0
+      for in_op in op.control_inputs + [tensor.op for tensor in op.inputs]:
+        if not _is_loop_edge(in_op):
+          count += 1
+      return count
 
-      if cycle:
-        return
-      if op in permanently_marked_ops:
-        return
-      if op in temporarily_marked_ops:
-        cycle = temporarily_marked_ops
-        return
-      temporarily_marked_ops.add(op)
-      for i in range(len(op.outputs)):
-        out_tensor = op.outputs[i]
-        for consumer_op in out_tensor.consumers():
-          visit(consumer_op, cycle, permanently_marked_ops,
-                temporarily_marked_ops, sorted_ops)
-      # pylint: disable=protected-access
-      for ctrl_output_op in op._control_outputs:
-        # pylint: enable=protected-access
-        visit(ctrl_output_op, cycle, permanently_marked_ops,
-              temporarily_marked_ops, sorted_ops)
-      temporarily_marked_ops.remove(op)
-      permanently_marked_ops.add(op)
-      sorted_ops.insert(0, op)
-
-    graph_cycle = set([])
     sorted_ops = []
-    permanently_marked_ops = set([])
-    temporarily_marked_ops = set([])
-    unsorted_ops = g.get_operations()
-    for op in unsorted_ops:
-      visit(op, graph_cycle, permanently_marked_ops,
-            temporarily_marked_ops, sorted_ops)
-    if graph_cycle:
-      return (False, graph_cycle)
+    in_degree = {op: _in_op_degree(op) for op in g.get_operations()}
+
+    ready = [op for op, count in in_degree.items() if count == 0]
+    while ready:
+      current = ready.pop()
+      # Emit the op, then delete its outgoing edges (skipped for loop edges).
+      sorted_ops.append(current)
+      if _is_loop_edge(current):
+        continue
+      # pylint: disable=protected-access
+      consumers = list(current._control_outputs)
+      # pylint: enable=protected-access
+      for out_tensor in current.outputs:
+        consumers.extend(out_tensor.consumers())
+
+      for consumer in consumers:
+        # Each deleted edge lowers the consumer's remaining in-degree by one.
+        in_degree[consumer] -= 1
+        if in_degree[consumer] == 0:
+          ready.append(consumer)
+        if in_degree[consumer] < 0:
+          raise ValueError('consumer:%s degree mismatch'%consumer.name)
+
+    left_ops = {op for op, count in in_degree.items() if count > 0}
+    if left_ops:
+      return (False, left_ops)
     else:
-      assert len(unsorted_ops) == len(sorted_ops)
+      assert len(g.get_operations()) == len(sorted_ops)
       return (True, sorted_ops)
 
   @staticmethod