Explicitly set base_graph in tf.wrap_function.prune().
lift_to_graph.lift_to_graph(...) doesn't always succeed on inferring the base graph, so just set explicitly. PiperOrigin-RevId: 280476493 Change-Id: I1ecd649d04daa559e47d3d4d036d1a126e8fc9f8
This commit is contained in:
parent
1341ad6b0e
commit
08a428f7ea
@ -317,7 +317,8 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
lift_map = lift_to_graph.lift_to_graph(
|
||||
operation_fetches + tensor_fetches,
|
||||
pruned_graph,
|
||||
sources=flat_feeds + self.graph.internal_captures)
|
||||
sources=flat_feeds + self.graph.internal_captures,
|
||||
base_graph=self._func_graph)
|
||||
|
||||
# Note that we add the component tensors of any composite tensors to the
|
||||
# returned function's outputs list; the list must contain these component
|
||||
|
Loading…
Reference in New Issue
Block a user