diff --git a/tensorflow/python/tpu/tensor_tracer.py b/tensorflow/python/tpu/tensor_tracer.py index 6d41d8768a3..1d37f1b64b8 100644 --- a/tensorflow/python/tpu/tensor_tracer.py +++ b/tensorflow/python/tpu/tensor_tracer.py @@ -540,54 +540,56 @@ class TensorTracer(object): cycle is found) and the second element is either the sorted list of nodes or the cycle of nodes found. """ + def _is_loop_edge(op): + """Returns true if the op is the end of a while-loop creating a cycle.""" + return op.type in ['NextIteration'] - def visit(op, cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops): - """Recursively visits all Ops in a graph. + def _in_op_degree(op): + """Returns the number of incoming edges to the given op. + The edge calculation skips the edges that come from 'NextIteration' ops. + NextIteration creates a cycle in the graph. We break cycles by treating + this op as 'sink' and ignoring all outgoing edges from it. Args: - op: the current Op being visited. - cycle: a cycle of Ops found. - permanently_marked_ops: the set of Ops that were already visited. - temporarily_marked_ops: the set of Ops that we have visited during - the current descent. - sorted_ops: the list of Ops sorted in topological order. + op: Tf.Operation + Returns: + the number of incoming edges. """ + count = 0 + for op in op.control_inputs + [in_tensor.op for in_tensor in op.inputs]: + if not _is_loop_edge(op): + count += 1 + return count - if cycle: - return - if op in permanently_marked_ops: - return - if op in temporarily_marked_ops: - cycle = temporarily_marked_ops - return - temporarily_marked_ops.add(op) - for i in range(len(op.outputs)): - out_tensor = op.outputs[i] - for consumer_op in out_tensor.consumers(): - visit(consumer_op, cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops) - # pylint: disable=protected-access - for ctrl_output_op in op._control_outputs: - # pylint: enable=protected-access - visit(ctrl_output_op, cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops) - temporarily_marked_ops.remove(op) - permanently_marked_ops.add(op) - sorted_ops.insert(0, op) - - graph_cycle = set([]) sorted_ops = [] - permanently_marked_ops = set([]) - temporarily_marked_ops = set([]) - unsorted_ops = g.get_operations() - for op in unsorted_ops: - visit(op, graph_cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops) - if graph_cycle: - return (False, graph_cycle) + op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()} + + frontier = [op for (op, degree) in op_in_degree.items() if degree == 0] + while frontier: + op = frontier.pop() + # Remove the op from graph, and remove its outgoing edges. + sorted_ops.append(op) + if _is_loop_edge(op): + continue + # pylint: disable=protected-access + consumers = list(op._control_outputs) + # pylint: enable=protected-access + for out_tensor in op.outputs: + consumers += [consumer_op for consumer_op in out_tensor.consumers()] + + for consumer in consumers: + # For each deleted edge shift the bucket of the vertex. + op_in_degree[consumer] -= 1 + if op_in_degree[consumer] == 0: + frontier.append(consumer) + if op_in_degree[consumer] < 0: + raise ValueError('consumer:%s degree mismatch'%consumer.name) + + left_ops = set([op for (op, degree) in op_in_degree.items() if degree > 0]) + if left_ops: + return (False, left_ops) else: - assert len(unsorted_ops) == len(sorted_ops) + assert len(g.get_operations()) == len(sorted_ops) return (True, sorted_ops) @staticmethod