diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index 8c1634f9c48..d23c812e738 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -296,7 +296,7 @@ class WrappedFunction(function.ConcreteFunction): fetches = nest.map_structure(_fetch_preprocesing_callback, fetches) - # Turn composite/sparse tensors into dense Tensors. + # Expand composite tensors into their component dense Tensors. tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True) for f in (flat_feeds + tensor_fetches + operation_fetches): @@ -311,9 +311,9 @@ class WrappedFunction(function.ConcreteFunction): pruned_graph, sources=flat_feeds + internal_captures) - # Note that we deliberately add the component tensors of any SparseTensors - # to the returned function's outputs list; the list must contain these - # component tensors, or the function's sparse outputs won't work properly. + # Note that we add the component tensors of any composite tensors to the + # returned function's outputs list; the list must contain these component + # tensors, or the function's sparse outputs won't work properly. pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches) pruned_graph.control_outputs.extend( [lift_map[operation] for operation in operation_fetches]) @@ -334,16 +334,17 @@ class WrappedFunction(function.ConcreteFunction): pruned_graph.variables = self.graph.variables def _structured_output_mapping(fetched): - """`nest.map_structure()` callback.""" - if isinstance(fetched, sparse_tensor.SparseTensor): - return _lift_sparse_tensor(fetched, lift_map) + """callback for `nest.map_structure()`""" lifted = lift_map[fetched] if isinstance(lifted, ops.Operation): return None return lifted + # expand_composites=True here causes composite tensors to be expanded + # into their component dense Tensors, mapped to the new graph, and then + # reconstituted into their original composite form. pruned_graph.structured_outputs = nest.map_structure( - _structured_output_mapping, fetches) + _structured_output_mapping, fetches, expand_composites=True) pruned_graph.structured_input_signature = input_signature pruned_fn = WrappedFunction( pruned_graph, variable_holder=self._variable_holder)