Fix a bug in lift_to_graph.py.

For example, if a source tensor is SOME_OP:0, SOME_OP will be added to source_ops in line 325 and also to op_map in line 203. However, if SOME_OP has another output SOME_OP:1, _copy_non_source() will be skipped, so SOME_OP:1 won't be added to op_map. As a result, line 349 will crash if mutation.old_graph_tensor is SOME_OP:1.

With this CL, all tensors, including SOME_OP:1, are added to op_map.

PiperOrigin-RevId: 315861055
Change-Id: I7aec62416051d90c37dec11a7a883452d1960c23
This commit is contained in:
A. Unique TensorFlower 2020-06-11 02:28:15 -07:00 committed by TensorFlower Gardener
parent ce4ae2b312
commit 41de1e90a3
2 changed files with 6 additions and 36 deletions

View File

@ -134,9 +134,7 @@ def _copy_non_source(op, graph, op_map, base_graph):
name=op.name)
op_map[op] = copied_op
for i, o in enumerate(op.outputs):
# The tensor should already be in the map if it's a source.
if o not in op_map:
op_map[o] = copied_op.outputs[i]
op_map[o] = copied_op.outputs[i]
return ([mutation._replace(copied_op=copied_op)
for mutation in input_mutations],
@ -309,10 +307,12 @@ def lift_to_graph(tensors,
with graph.as_default():
for i in variable_init_tensors:
op_map[i] = i
source_ops = set()
# Add the sources in the same order as the original graph.
for s in internal_captures:
if s in sources:
sources.remove(s)
source_ops.add(s.op)
_copy_source(
s=s,
graph=graph,
@ -321,6 +321,7 @@ def lift_to_graph(tensors,
inverse_captures=inverse_captures,
base_graph=base_graph)
for s in sources:
source_ops.add(s.op)
_copy_source(
s=s,
graph=graph,
@ -332,6 +333,8 @@ def lift_to_graph(tensors,
input_mutations = []
control_mutations = []
for op in reversed(ops_to_copy):
if op in source_ops or op in op_map:
continue
new_input_mutations, new_control_mutations = _copy_non_source(
op=op, graph=graph, op_map=op_map, base_graph=base_graph)
input_mutations.extend(new_input_mutations)

View File

@ -24,7 +24,6 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops as framework_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
@ -88,38 +87,6 @@ class LiftToGraphTest(test.TestCase):
self.assertItemsEqual(op.colocation_groups(), # Expect default self-ref.
[compat.as_bytes('loc:@%s' % op.name)])
def testMixedSourceAndNonSourceInOpOutputs(self):
  """Regression test for lifting an op with mixed source/non-source outputs.

  `split` produces two output tensors; only `split:0` is passed as a
  source, so the lifting pass must still map the non-source output
  `split:1` into op_map (per the commit description, this previously
  crashed because `split:1` never reached op_map).
  """

  @def_function.function
  def fn():
    input_range = math_ops.range(0, 4)
    # 'split' is the op with two outputs: split:0 and split:1.
    split0, split1 = array_ops.split(input_range, 2, name='split')
    concat = array_ops.concat([split0, split1], axis=0, name='concat')
    return concat

  concrete_fn = fn.get_concrete_function()
  out = concrete_fn.graph.outputs[0]
  g = func_graph.FuncGraph('lifted')
  # split:0 is source, but split:1 is non source.
  op_map = lift_to_graph.lift_to_graph(
      [out],
      g,
      sources=[concrete_fn.graph.get_operation_by_name('split').outputs[0]])
  # Verify every lifted name: the non-source op is copied (renamed
  # 'split_1'), the source tensor keeps its name, and the non-source
  # output maps to the copied op's output.
  for old, new in op_map.items():
    if old.name == 'split':
      self.assertEqual('split_1', new.name)
    elif old.name == 'split:0':
      self.assertEqual('split:0', new.name)
    elif old.name == 'split:1':
      self.assertEqual('split_1:1', new.name)
    elif old.name == 'concat':
      new_concat_inputs = []
      for input_tensor in new.inputs:
        new_concat_inputs.append(input_tensor.name)
      # The inputs of 'concat' used to be ['split:0', 'split:1'].
      self.assertContainsSubset(['split:0', 'split_1:1'], new_concat_inputs)
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  test.main()