[convergence_tools]: changing the topological sort function.

PiperOrigin-RevId: 238548383
This commit is contained in:
A. Unique TensorFlower 2019-03-14 16:57:12 -07:00 committed by TensorFlower Gardener
parent f4f8d500c3
commit ab95f9cc90

View File

@ -540,54 +540,56 @@ class TensorTracer(object):
cycle is found) and the second element is either the sorted
list of nodes or the cycle of nodes found.
"""
def _is_loop_edge(op):
"""Returns true if the op is the end of a while-loop creating a cycle."""
return op.type in ['NextIteration']
def visit(op, cycle, permanently_marked_ops,
temporarily_marked_ops, sorted_ops):
"""Recursively visits all Ops in a graph.
def _in_op_degree(op):
"""Returns the number of incoming edges to the given op.
The edge calculation skips the edges that come from 'NextIteration' ops.
NextIteration creates a cycle in the graph. We break cycles by treating
this op as 'sink' and ignoring all outgoing edges from it.
Args:
op: the current Op being visited.
cycle: a cycle of Ops found.
permanently_marked_ops: the set of Ops that were already visited.
temporarily_marked_ops: the set of Ops that we have visited during
the current descent.
sorted_ops: the list of Ops sorted in topological order.
op: Tf.Operation
Returns:
the number of incoming edges.
"""
count = 0
for op in op.control_inputs + [in_tensor.op for in_tensor in op.inputs]:
if not _is_loop_edge(op):
count += 1
return count
if cycle:
return
if op in permanently_marked_ops:
return
if op in temporarily_marked_ops:
cycle = temporarily_marked_ops
return
temporarily_marked_ops.add(op)
for i in range(len(op.outputs)):
out_tensor = op.outputs[i]
for consumer_op in out_tensor.consumers():
visit(consumer_op, cycle, permanently_marked_ops,
temporarily_marked_ops, sorted_ops)
# pylint: disable=protected-access
for ctrl_output_op in op._control_outputs:
# pylint: enable=protected-access
visit(ctrl_output_op, cycle, permanently_marked_ops,
temporarily_marked_ops, sorted_ops)
temporarily_marked_ops.remove(op)
permanently_marked_ops.add(op)
sorted_ops.insert(0, op)
graph_cycle = set([])
sorted_ops = []
permanently_marked_ops = set([])
temporarily_marked_ops = set([])
unsorted_ops = g.get_operations()
for op in unsorted_ops:
visit(op, graph_cycle, permanently_marked_ops,
temporarily_marked_ops, sorted_ops)
if graph_cycle:
return (False, graph_cycle)
op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}
frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
while frontier:
op = frontier.pop()
# Remove the op from graph, and remove its outgoing edges.
sorted_ops.append(op)
if _is_loop_edge(op):
continue
# pylint: disable=protected-access
consumers = list(op._control_outputs)
# pylint: enable=protected-access
for out_tensor in op.outputs:
consumers += [consumer_op for consumer_op in out_tensor.consumers()]
for consumer in consumers:
# For each deleted edge shift the bucket of the vertex.
op_in_degree[consumer] -= 1
if op_in_degree[consumer] == 0:
frontier.append(consumer)
if op_in_degree[consumer] < 0:
raise ValueError('consumer:%s degree mismatch'%consumer.name)
left_ops = set([op for (op, degree) in op_in_degree.items() if degree > 0])
if left_ops:
return (False, left_ops)
else:
assert len(unsorted_ops) == len(sorted_ops)
assert len(g.get_operations()) == len(sorted_ops)
return (True, sorted_ops)
@staticmethod