Explicitly set base_graph in tf.wrap_function.prune().

lift_to_graph.lift_to_graph(...) doesn't always succeed on inferring the base graph, so just set explicitly.

PiperOrigin-RevId: 280476493
Change-Id: I1ecd649d04daa559e47d3d4d036d1a126e8fc9f8
This commit is contained in:
A. Unique TensorFlower 2019-11-14 11:43:13 -08:00 committed by TensorFlower Gardener
parent 1341ad6b0e
commit 08a428f7ea

View File

@ -317,7 +317,8 @@ class WrappedFunction(function.ConcreteFunction):
lift_map = lift_to_graph.lift_to_graph(
operation_fetches + tensor_fetches,
pruned_graph,
sources=flat_feeds + self.graph.internal_captures)
sources=flat_feeds + self.graph.internal_captures,
base_graph=self._func_graph)
# Note that we add the component tensors of any composite tensors to the
# returned function's outputs list; the list must contain these component