Simplify logic per review comments

This commit is contained in:
frreiss 2019-05-23 17:19:13 -07:00
parent 5060dd212e
commit 088e496ba8

View File

@ -296,7 +296,7 @@ class WrappedFunction(function.ConcreteFunction):
fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
# Turn composite/sparse tensors into dense Tensors.
# Expand composite tensors into their component dense Tensors.
tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)
for f in (flat_feeds + tensor_fetches + operation_fetches):
@ -311,9 +311,9 @@ class WrappedFunction(function.ConcreteFunction):
pruned_graph,
sources=flat_feeds + internal_captures)
# Note that we deliberately add the component tensors of any SparseTensors
# to the returned function's outputs list; the list must contain these
# component tensors, or the function's sparse outputs won't work properly.
# Note that we add the component tensors of any composite tensors to the
# returned function's outputs list; the list must contain these component
# tensors, or the function's sparse outputs won't work properly.
pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
pruned_graph.control_outputs.extend(
[lift_map[operation] for operation in operation_fetches])
@ -334,16 +334,17 @@ class WrappedFunction(function.ConcreteFunction):
pruned_graph.variables = self.graph.variables
def _structured_output_mapping(fetched):
"""`nest.map_structure()` callback."""
if isinstance(fetched, sparse_tensor.SparseTensor):
return _lift_sparse_tensor(fetched, lift_map)
"""callback for `nest.map_structure()`"""
lifted = lift_map[fetched]
if isinstance(lifted, ops.Operation):
return None
return lifted
# expand_composites=True here causes composite tensors to be expanded
# into their component dense Tensors, mapped to the new graph, and then
# reconstituted into their original composite form.
pruned_graph.structured_outputs = nest.map_structure(
_structured_output_mapping, fetches)
_structured_output_mapping, fetches, expand_composites=True)
pruned_graph.structured_input_signature = input_signature
pruned_fn = WrappedFunction(
pruned_graph, variable_holder=self._variable_holder)