tensor tracer: making the topological sort deterministic.

PiperOrigin-RevId: 243330039
This commit is contained in:
A. Unique TensorFlower 2019-04-12 14:10:44 -07:00 committed by TensorFlower Gardener
parent 75d6568883
commit 6c6e798a15

View File

@ -598,6 +598,7 @@ class TensorTracer(object):
op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}
frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
frontier.sort(key=lambda op: op.name)
while frontier:
op = frontier.pop()
# Remove the op from graph, and remove its outgoing edges.
@ -609,7 +610,7 @@ class TensorTracer(object):
# pylint: enable=protected-access
for out_tensor in op.outputs:
consumers += [consumer_op for consumer_op in out_tensor.consumers()]
consumers.sort(key=lambda op: op.name)
for consumer in consumers:
# For each deleted edge shift the bucket of the vertex.
op_in_degree[consumer] -= 1
@ -928,7 +929,9 @@ class TensorTracer(object):
for i in range(0, len(tensor_list)):
tensor = tensor_list[i]
line = '%d "%s"'%(i, tensor.name)
for consumer_op in tensor.consumers():
consumers = tensor.consumers()
consumers.sort(key=lambda op: op.name)
for consumer_op in consumers:
if consumer_op.name not in opname_idx_map:
raise ValueError(
'consumer_op %s is not in opname_idx_map'%consumer_op.name)