Merge pull request #28886 from frreiss:issue-sparse-import

PiperOrigin-RevId: 249843058
This commit is contained in:
TensorFlower Gardener 2019-05-24 09:15:31 -07:00
commit eb26bc8fb0
2 changed files with 54 additions and 17 deletions

View File

@ -245,39 +245,43 @@ class WrappedFunction(function.ConcreteFunction):
tensor_fetches = [] tensor_fetches = []
tensor_infos = [] tensor_infos = []
def _fetch_preprocesing_callback(f): def _fetch_preprocesing_callback(fetch):
"""Extract out lists of ops, tensors, and tensor type info. """Extract out lists of ops, tensors, and tensor type info.
Turns TensorInfos into Tensors in the original fetches structure. Turns TensorInfos into Tensors in the original `fetches` structure.
Also extracts ops from `fetches`.
Args: Args:
f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or
identifying a Tensor or Operation. string identifying a Tensor or Operation.
Returns: Returns:
`f` converted to a Tensor. `fetch` converted to a Tensor.
""" """
if isinstance(f, ops.Operation): if isinstance(fetch, ops.Operation):
operation_fetches.append(f) operation_fetches.append(fetch)
return f return fetch
elif isinstance(f, meta_graph_pb2.TensorInfo): elif isinstance(fetch, meta_graph_pb2.TensorInfo):
tensor_infos.append(f) tensor_infos.append(fetch)
decoded = _get_element_from_tensor_info(f, self._func_graph) decoded = _get_element_from_tensor_info(fetch, self._func_graph)
if tensor_util.is_tensor(decoded): if tensor_util.is_tensor(decoded):
tensor_fetches.append(decoded) tensor_fetches.append(decoded)
else: else:
operation_fetches.append(decoded) operation_fetches.append(decoded)
return decoded return decoded
elif isinstance(f, ops.Tensor): elif isinstance(fetch, ops.Tensor):
tensor_fetches.append(f) tensor_fetches.append(fetch)
return f return fetch
else: else:
graph_element = self.graph.as_graph_element(f) graph_element = self.graph.as_graph_element(fetch)
return _fetch_preprocesing_callback(graph_element) return _fetch_preprocesing_callback(graph_element)
fetches = nest.map_structure(_fetch_preprocesing_callback, fetches) fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
for f in flat_feeds + tensor_fetches + operation_fetches: # Expand composite tensors into their component dense Tensors.
tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)
for f in (flat_feeds + tensor_fetches + operation_fetches):
if f.graph is not self._func_graph: if f.graph is not self._func_graph:
raise ValueError("Can only prune function whose feeds and fetches " raise ValueError("Can only prune function whose feeds and fetches "
"are from this graph (%s). Input %s is from graph %s" % "are from this graph (%s). Input %s is from graph %s" %
@ -288,6 +292,10 @@ class WrappedFunction(function.ConcreteFunction):
operation_fetches + tensor_fetches, operation_fetches + tensor_fetches,
pruned_graph, pruned_graph,
sources=flat_feeds + internal_captures) sources=flat_feeds + internal_captures)
# Note that we add the component tensors of any composite tensors to the
# returned function's outputs list; the list must contain these component
# tensors, or the function's sparse outputs won't work properly.
pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches) pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
pruned_graph.control_outputs.extend( pruned_graph.control_outputs.extend(
[lift_map[operation] for operation in operation_fetches]) [lift_map[operation] for operation in operation_fetches])
@ -308,13 +316,17 @@ class WrappedFunction(function.ConcreteFunction):
pruned_graph.variables = self.graph.variables pruned_graph.variables = self.graph.variables
def _structured_output_mapping(fetched): def _structured_output_mapping(fetched):
"""callback for `nest.map_structure()`"""
lifted = lift_map[fetched] lifted = lift_map[fetched]
if isinstance(lifted, ops.Operation): if isinstance(lifted, ops.Operation):
return None return None
return lifted return lifted
# expand_composites=True here causes composite tensors to be expanded
# into their component dense Tensors, mapped to the new graph, and then
# reconstituted into their original composite form.
pruned_graph.structured_outputs = nest.map_structure( pruned_graph.structured_outputs = nest.map_structure(
_structured_output_mapping, fetches) _structured_output_mapping, fetches, expand_composites=True)
pruned_graph.structured_input_signature = input_signature pruned_graph.structured_input_signature = input_signature
pruned_fn = WrappedFunction( pruned_fn = WrappedFunction(
pruned_graph, variable_holder=self._variable_holder) pruned_graph, variable_holder=self._variable_holder)

View File

@ -29,6 +29,7 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions from tensorflow.python.framework import versions
@ -489,5 +490,29 @@ class LoadTest(test.TestCase):
root = load.load(path) root = load.load(path)
self.assertFalse(root.variables[0].trainable) self.assertFalse(root.variables[0].trainable)
def _model_with_sparse_output(self):
"""Generate a graph with a SparseTensor output and serialize in V1 format"""
export_graph = ops.Graph()
with export_graph.as_default():
in_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[1])
out_sparse_tensor = sparse_tensor.SparseTensor(
indices=[[0]], values=in_placeholder, dense_shape=[1]) * 2
with session_lib.Session() as session:
path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
simple_save.simple_save(
session,
path,
inputs={"start": in_placeholder},
outputs={"output": out_sparse_tensor})
return path
def test_load_sparse_outputs(self):
path = self._model_with_sparse_output()
imported = load.load(path)
imported_fn = imported.signatures["serving_default"]
forty_two = constant_op.constant([42], dtype=dtypes.int64)
self.assertEqual([84], imported_fn(forty_two)["output"].values.numpy())
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()