Set control outputs in wrap_function.prune to make sure they run

PiperOrigin-RevId: 239094610
This commit is contained in:
Allen Lavoie 2019-03-18 17:21:34 -07:00 committed by TensorFlower Gardener
parent ba3e07c51e
commit 9c234792c0
2 changed files with 20 additions and 0 deletions

View File

@ -188,6 +188,8 @@ class WrappedFunction(function.ConcreteFunction):
lift_map[original_fetch] = lift_map[identity_fetch]
pruned_graph.outputs.extend(
lift_map[x] for x in flat_fetches if isinstance(x, ops.Tensor))
pruned_graph.control_outputs.extend(
[lift_map[operation] for operation in operation_fetches])
if not tensor_fetches:
pruned_graph.outputs.append(lift_map[sink_tensor])
for external_capture, internal_capture in self.graph.captures.items():

View File

@ -245,6 +245,24 @@ class WrapFunctionTest(test.TestCase):
self.assertEqual(0, v0.numpy())
self.assertEqual(0, v1.numpy())
def test_operation_returned(self):
v = variables.Variable(0)
def f():
v.assign(1, read_value=False, name='assign_to_v')
f_wrapped = wrap_function.wrap_function(f, [])
operation_to_fetch = f_wrapped.graph.get_operation_by_name('assign_to_v')
f_pruned = f_wrapped.prune(
[], operation_to_fetch)
self.assertEqual(
['assign_to_v'],
[operation.name for operation in f_pruned.graph.control_outputs])
self.assertEqual(0, v.numpy())
f_pruned()
self.assertEqual(1, v.numpy())
def test_function_from_graph_def(self):
@def_function.function
def make_graph_def(x):