Set control outputs in wrap_function.prune to make sure they run
PiperOrigin-RevId: 239094610
This commit is contained in:
parent
ba3e07c51e
commit
9c234792c0
@ -188,6 +188,8 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
lift_map[original_fetch] = lift_map[identity_fetch]
|
||||
pruned_graph.outputs.extend(
|
||||
lift_map[x] for x in flat_fetches if isinstance(x, ops.Tensor))
|
||||
pruned_graph.control_outputs.extend(
|
||||
[lift_map[operation] for operation in operation_fetches])
|
||||
if not tensor_fetches:
|
||||
pruned_graph.outputs.append(lift_map[sink_tensor])
|
||||
for external_capture, internal_capture in self.graph.captures.items():
|
||||
|
@ -245,6 +245,24 @@ class WrapFunctionTest(test.TestCase):
|
||||
self.assertEqual(0, v0.numpy())
|
||||
self.assertEqual(0, v1.numpy())
|
||||
|
||||
def test_operation_returned(self):
|
||||
|
||||
v = variables.Variable(0)
|
||||
|
||||
def f():
|
||||
v.assign(1, read_value=False, name='assign_to_v')
|
||||
|
||||
f_wrapped = wrap_function.wrap_function(f, [])
|
||||
operation_to_fetch = f_wrapped.graph.get_operation_by_name('assign_to_v')
|
||||
f_pruned = f_wrapped.prune(
|
||||
[], operation_to_fetch)
|
||||
self.assertEqual(
|
||||
['assign_to_v'],
|
||||
[operation.name for operation in f_pruned.graph.control_outputs])
|
||||
self.assertEqual(0, v.numpy())
|
||||
f_pruned()
|
||||
self.assertEqual(1, v.numpy())
|
||||
|
||||
def test_function_from_graph_def(self):
|
||||
@def_function.function
|
||||
def make_graph_def(x):
|
||||
|
Loading…
Reference in New Issue
Block a user