[convergence_tools]: changing the topological sort function.
PiperOrigin-RevId: 238548383
This commit is contained in:
parent
f4f8d500c3
commit
ab95f9cc90
@ -540,54 +540,56 @@ class TensorTracer(object):
|
||||
cycle is found) and the second element is either the sorted
|
||||
list of nodes or the cycle of nodes found.
|
||||
"""
|
||||
def _is_loop_edge(op):
|
||||
"""Returns true if the op is the end of a while-loop creating a cycle."""
|
||||
return op.type in ['NextIteration']
|
||||
|
||||
def visit(op, cycle, permanently_marked_ops,
|
||||
temporarily_marked_ops, sorted_ops):
|
||||
"""Recursively visits all Ops in a graph.
|
||||
def _in_op_degree(op):
|
||||
"""Returns the number of incoming edges to the given op.
|
||||
|
||||
The edge calculation skips the edges that come from 'NextIteration' ops.
|
||||
NextIteration creates a cycle in the graph. We break cycles by treating
|
||||
this op as 'sink' and ignoring all outgoing edges from it.
|
||||
Args:
|
||||
op: the current Op being visited.
|
||||
cycle: a cycle of Ops found.
|
||||
permanently_marked_ops: the set of Ops that were already visited.
|
||||
temporarily_marked_ops: the set of Ops that we have visited during
|
||||
the current descent.
|
||||
sorted_ops: the list of Ops sorted in topological order.
|
||||
op: Tf.Operation
|
||||
Returns:
|
||||
the number of incoming edges.
|
||||
"""
|
||||
count = 0
|
||||
for op in op.control_inputs + [in_tensor.op for in_tensor in op.inputs]:
|
||||
if not _is_loop_edge(op):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
if cycle:
|
||||
return
|
||||
if op in permanently_marked_ops:
|
||||
return
|
||||
if op in temporarily_marked_ops:
|
||||
cycle = temporarily_marked_ops
|
||||
return
|
||||
temporarily_marked_ops.add(op)
|
||||
for i in range(len(op.outputs)):
|
||||
out_tensor = op.outputs[i]
|
||||
for consumer_op in out_tensor.consumers():
|
||||
visit(consumer_op, cycle, permanently_marked_ops,
|
||||
temporarily_marked_ops, sorted_ops)
|
||||
# pylint: disable=protected-access
|
||||
for ctrl_output_op in op._control_outputs:
|
||||
# pylint: enable=protected-access
|
||||
visit(ctrl_output_op, cycle, permanently_marked_ops,
|
||||
temporarily_marked_ops, sorted_ops)
|
||||
temporarily_marked_ops.remove(op)
|
||||
permanently_marked_ops.add(op)
|
||||
sorted_ops.insert(0, op)
|
||||
|
||||
graph_cycle = set([])
|
||||
sorted_ops = []
|
||||
permanently_marked_ops = set([])
|
||||
temporarily_marked_ops = set([])
|
||||
unsorted_ops = g.get_operations()
|
||||
for op in unsorted_ops:
|
||||
visit(op, graph_cycle, permanently_marked_ops,
|
||||
temporarily_marked_ops, sorted_ops)
|
||||
if graph_cycle:
|
||||
return (False, graph_cycle)
|
||||
op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}
|
||||
|
||||
frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
|
||||
while frontier:
|
||||
op = frontier.pop()
|
||||
# Remove the op from graph, and remove its outgoing edges.
|
||||
sorted_ops.append(op)
|
||||
if _is_loop_edge(op):
|
||||
continue
|
||||
# pylint: disable=protected-access
|
||||
consumers = list(op._control_outputs)
|
||||
# pylint: enable=protected-access
|
||||
for out_tensor in op.outputs:
|
||||
consumers += [consumer_op for consumer_op in out_tensor.consumers()]
|
||||
|
||||
for consumer in consumers:
|
||||
# For each deleted edge shift the bucket of the vertex.
|
||||
op_in_degree[consumer] -= 1
|
||||
if op_in_degree[consumer] == 0:
|
||||
frontier.append(consumer)
|
||||
if op_in_degree[consumer] < 0:
|
||||
raise ValueError('consumer:%s degree mismatch'%consumer.name)
|
||||
|
||||
left_ops = set([op for (op, degree) in op_in_degree.items() if degree > 0])
|
||||
if left_ops:
|
||||
return (False, left_ops)
|
||||
else:
|
||||
assert len(unsorted_ops) == len(sorted_ops)
|
||||
assert len(g.get_operations()) == len(sorted_ops)
|
||||
return (True, sorted_ops)
|
||||
|
||||
@staticmethod
|
||||
|
Loading…
Reference in New Issue
Block a user