Simplify logic per review comments
This commit is contained in:
parent
5060dd212e
commit
088e496ba8
@ -296,7 +296,7 @@ class WrappedFunction(function.ConcreteFunction):
|
|||||||
|
|
||||||
fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
|
fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
|
||||||
|
|
||||||
# Turn composite/sparse tensors into dense Tensors.
|
# Expand composite tensors into their component dense Tensors.
|
||||||
tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)
|
tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)
|
||||||
|
|
||||||
for f in (flat_feeds + tensor_fetches + operation_fetches):
|
for f in (flat_feeds + tensor_fetches + operation_fetches):
|
||||||
@ -311,9 +311,9 @@ class WrappedFunction(function.ConcreteFunction):
|
|||||||
pruned_graph,
|
pruned_graph,
|
||||||
sources=flat_feeds + internal_captures)
|
sources=flat_feeds + internal_captures)
|
||||||
|
|
||||||
# Note that we deliberately add the component tensors of any SparseTensors
|
# Note that we add the component tensors of any composite tensors to the
|
||||||
# to the returned function's outputs list; the list must contain these
|
# returned function's outputs list; the list must contain these component
|
||||||
# component tensors, or the function's sparse outputs won't work properly.
|
# tensors, or the function's sparse outputs won't work properly.
|
||||||
pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
|
pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
|
||||||
pruned_graph.control_outputs.extend(
|
pruned_graph.control_outputs.extend(
|
||||||
[lift_map[operation] for operation in operation_fetches])
|
[lift_map[operation] for operation in operation_fetches])
|
||||||
@ -334,16 +334,17 @@ class WrappedFunction(function.ConcreteFunction):
|
|||||||
pruned_graph.variables = self.graph.variables
|
pruned_graph.variables = self.graph.variables
|
||||||
|
|
||||||
def _structured_output_mapping(fetched):
|
def _structured_output_mapping(fetched):
|
||||||
"""`nest.map_structure()` callback."""
|
"""callback for `nest.map_structure()`"""
|
||||||
if isinstance(fetched, sparse_tensor.SparseTensor):
|
|
||||||
return _lift_sparse_tensor(fetched, lift_map)
|
|
||||||
lifted = lift_map[fetched]
|
lifted = lift_map[fetched]
|
||||||
if isinstance(lifted, ops.Operation):
|
if isinstance(lifted, ops.Operation):
|
||||||
return None
|
return None
|
||||||
return lifted
|
return lifted
|
||||||
|
|
||||||
|
# expand_composites=True here causes composite tensors to be expanded
|
||||||
|
# into their component dense Tensors, mapped to the new graph, and then
|
||||||
|
# reconstituted into their original composite form.
|
||||||
pruned_graph.structured_outputs = nest.map_structure(
|
pruned_graph.structured_outputs = nest.map_structure(
|
||||||
_structured_output_mapping, fetches)
|
_structured_output_mapping, fetches, expand_composites=True)
|
||||||
pruned_graph.structured_input_signature = input_signature
|
pruned_graph.structured_input_signature = input_signature
|
||||||
pruned_fn = WrappedFunction(
|
pruned_fn = WrappedFunction(
|
||||||
pruned_graph, variable_holder=self._variable_holder)
|
pruned_graph, variable_holder=self._variable_holder)
|
||||||
|
Loading…
Reference in New Issue
Block a user