tensor tracer: making the topological sort deterministic.
PiperOrigin-RevId: 243330039
This commit is contained in:
parent
75d6568883
commit
6c6e798a15
@ -598,6 +598,7 @@ class TensorTracer(object):
|
||||
op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}
|
||||
|
||||
frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
|
||||
frontier.sort(key=lambda op: op.name)
|
||||
while frontier:
|
||||
op = frontier.pop()
|
||||
# Remove the op from graph, and remove its outgoing edges.
|
||||
@ -609,7 +610,7 @@ class TensorTracer(object):
|
||||
# pylint: enable=protected-access
|
||||
for out_tensor in op.outputs:
|
||||
consumers += [consumer_op for consumer_op in out_tensor.consumers()]
|
||||
|
||||
consumers.sort(key=lambda op: op.name)
|
||||
for consumer in consumers:
|
||||
# For each deleted edge shift the bucket of the vertex.
|
||||
op_in_degree[consumer] -= 1
|
||||
@ -928,7 +929,9 @@ class TensorTracer(object):
|
||||
for i in range(0, len(tensor_list)):
|
||||
tensor = tensor_list[i]
|
||||
line = '%d "%s"'%(i, tensor.name)
|
||||
for consumer_op in tensor.consumers():
|
||||
consumers = tensor.consumers()
|
||||
consumers.sort(key=lambda op: op.name)
|
||||
for consumer_op in consumers:
|
||||
if consumer_op.name not in opname_idx_map:
|
||||
raise ValueError(
|
||||
'consumer_op %s is not in opname_idx_map'%consumer_op.name)
|
||||
|
Loading…
Reference in New Issue
Block a user