Fix a bug in lift_to_graph.py.
For example, if a source tensor is SOME_OP:0, the SOME_OP will be added to source_ops in line 325, and also be added to op_map in line 203. However if SOME_OP has another output SOME_OP:1, it will skip the _copy_non_source(), so SOME_OP:1 won't be added to op_map. As the result, in line 349 if mutation.old_graph_tensor is SOME_OP:1, it will crash. With this cl, all tensors including SOME_OP:1 will be added to op_map. PiperOrigin-RevId: 315861055 Change-Id: I7aec62416051d90c37dec11a7a883452d1960c23
This commit is contained in:
parent
ce4ae2b312
commit
41de1e90a3
tensorflow/python/eager
@ -134,9 +134,7 @@ def _copy_non_source(op, graph, op_map, base_graph):
|
||||
name=op.name)
|
||||
op_map[op] = copied_op
|
||||
for i, o in enumerate(op.outputs):
|
||||
# The tensor should already be in the map if it's a source.
|
||||
if o not in op_map:
|
||||
op_map[o] = copied_op.outputs[i]
|
||||
op_map[o] = copied_op.outputs[i]
|
||||
|
||||
return ([mutation._replace(copied_op=copied_op)
|
||||
for mutation in input_mutations],
|
||||
@ -309,10 +307,12 @@ def lift_to_graph(tensors,
|
||||
with graph.as_default():
|
||||
for i in variable_init_tensors:
|
||||
op_map[i] = i
|
||||
source_ops = set()
|
||||
# Add the sources in the same order as the original graph.
|
||||
for s in internal_captures:
|
||||
if s in sources:
|
||||
sources.remove(s)
|
||||
source_ops.add(s.op)
|
||||
_copy_source(
|
||||
s=s,
|
||||
graph=graph,
|
||||
@ -321,6 +321,7 @@ def lift_to_graph(tensors,
|
||||
inverse_captures=inverse_captures,
|
||||
base_graph=base_graph)
|
||||
for s in sources:
|
||||
source_ops.add(s.op)
|
||||
_copy_source(
|
||||
s=s,
|
||||
graph=graph,
|
||||
@ -332,6 +333,8 @@ def lift_to_graph(tensors,
|
||||
input_mutations = []
|
||||
control_mutations = []
|
||||
for op in reversed(ops_to_copy):
|
||||
if op in source_ops or op in op_map:
|
||||
continue
|
||||
new_input_mutations, new_control_mutations = _copy_non_source(
|
||||
op=op, graph=graph, op_map=op_map, base_graph=base_graph)
|
||||
input_mutations.extend(new_input_mutations)
|
||||
|
@ -24,7 +24,6 @@ from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops as framework_ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.util import compat
|
||||
@ -88,38 +87,6 @@ class LiftToGraphTest(test.TestCase):
|
||||
self.assertItemsEqual(op.colocation_groups(), # Expect default self-ref.
|
||||
[compat.as_bytes('loc:@%s' % op.name)])
|
||||
|
||||
def testMixedSourceAndNonSourceInOpOutputs(self):
|
||||
|
||||
@def_function.function
|
||||
def fn():
|
||||
input_range = math_ops.range(0, 4)
|
||||
split0, split1 = array_ops.split(input_range, 2, name='split')
|
||||
concat = array_ops.concat([split0, split1], axis=0, name='concat')
|
||||
return concat
|
||||
|
||||
concrete_fn = fn.get_concrete_function()
|
||||
out = concrete_fn.graph.outputs[0]
|
||||
|
||||
g = func_graph.FuncGraph('lifted')
|
||||
# split:0 is source, but split:1 is non source.
|
||||
op_map = lift_to_graph.lift_to_graph(
|
||||
[out],
|
||||
g,
|
||||
sources=[concrete_fn.graph.get_operation_by_name('split').outputs[0]])
|
||||
for old, new in op_map.items():
|
||||
if old.name == 'split':
|
||||
self.assertEqual('split_1', new.name)
|
||||
elif old.name == 'split:0':
|
||||
self.assertEqual('split:0', new.name)
|
||||
elif old.name == 'split:1':
|
||||
self.assertEqual('split_1:1', new.name)
|
||||
elif old.name == 'concat':
|
||||
new_concat_inputs = []
|
||||
for input_tensor in new.inputs:
|
||||
new_concat_inputs.append(input_tensor.name)
|
||||
# The inputs of 'concat' used to be ['split:0', 'split:1'].
|
||||
self.assertContainsSubset(['split:0', 'split_1:1'], new_concat_inputs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user