Merge pull request #28886 from frreiss:issue-sparse-import

PiperOrigin-RevId: 249843058
TensorFlower Gardener 2019-05-24 09:15:31 -07:00
commit eb26bc8fb0
2 changed files with 54 additions and 17 deletions


@@ -245,39 +245,43 @@ class WrappedFunction(function.ConcreteFunction):
     tensor_fetches = []
     tensor_infos = []
 
-    def _fetch_preprocesing_callback(f):
+    def _fetch_preprocesing_callback(fetch):
       """Extract out lists of ops, tensors, and tensor type info.
 
-      Turns TensorInfos into Tensors in the original fetches structure.
+      Turns TensorInfos into Tensors in the original `fetches` structure.
+      Also extracts ops from `fetches`.
 
       Args:
-        f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string
-          identifying a Tensor or Operation.
+        fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or
+          string identifying a Tensor or Operation.
 
       Returns:
-        `f` converted to a Tensor.
+        `fetch` converted to a Tensor.
       """
-      if isinstance(f, ops.Operation):
-        operation_fetches.append(f)
-        return f
-      elif isinstance(f, meta_graph_pb2.TensorInfo):
-        tensor_infos.append(f)
-        decoded = _get_element_from_tensor_info(f, self._func_graph)
+      if isinstance(fetch, ops.Operation):
+        operation_fetches.append(fetch)
+        return fetch
+      elif isinstance(fetch, meta_graph_pb2.TensorInfo):
+        tensor_infos.append(fetch)
+        decoded = _get_element_from_tensor_info(fetch, self._func_graph)
         if tensor_util.is_tensor(decoded):
           tensor_fetches.append(decoded)
         else:
           operation_fetches.append(decoded)
         return decoded
-      elif isinstance(f, ops.Tensor):
-        tensor_fetches.append(f)
-        return f
+      elif isinstance(fetch, ops.Tensor):
+        tensor_fetches.append(fetch)
+        return fetch
       else:
-        graph_element = self.graph.as_graph_element(f)
+        graph_element = self.graph.as_graph_element(fetch)
         return _fetch_preprocesing_callback(graph_element)
 
     fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
 
-    for f in flat_feeds + tensor_fetches + operation_fetches:
+    # Expand composite tensors into their component dense Tensors.
+    tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)
+
+    for f in (flat_feeds + tensor_fetches + operation_fetches):
       if f.graph is not self._func_graph:
         raise ValueError("Can only prune function whose feeds and fetches "
                          "are from this graph (%s). Input %s is from graph %s" %
@@ -288,6 +292,10 @@ class WrappedFunction(function.ConcreteFunction):
         operation_fetches + tensor_fetches,
         pruned_graph,
         sources=flat_feeds + internal_captures)
+    # Note that we add the component tensors of any composite tensors to the
+    # returned function's outputs list; the list must contain these component
+    # tensors, or the function's sparse outputs won't work properly.
     pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
     pruned_graph.control_outputs.extend(
         [lift_map[operation] for operation in operation_fetches])
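The comment added above is the core of the fix: the pruned graph's flat `outputs` list must carry the dense component tensors, while the composite form survives only in the structured outputs. A minimal sketch of that split on an ordinary `tf.function` (illustrative only; `make_sparse` is a name invented here, not something from this change):

import tensorflow as tf

@tf.function
def make_sparse(v):
  return tf.sparse.SparseTensor(indices=[[0]], values=v, dense_shape=[1])

concrete = make_sparse.get_concrete_function(tf.TensorSpec([1], tf.int64))

# The FuncGraph's flat outputs are the component dense tensors; the composite
# SparseTensor itself only appears in structured_outputs.
print(len(concrete.graph.outputs))              # 3
print(type(concrete.graph.structured_outputs))  # SparseTensor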
@@ -308,13 +316,17 @@ class WrappedFunction(function.ConcreteFunction):
     pruned_graph.variables = self.graph.variables
 
     def _structured_output_mapping(fetched):
       """callback for `nest.map_structure()`"""
       lifted = lift_map[fetched]
       if isinstance(lifted, ops.Operation):
         return None
       return lifted
 
+    # expand_composites=True here causes composite tensors to be expanded
+    # into their component dense Tensors, mapped to the new graph, and then
+    # reconstituted into their original composite form.
     pruned_graph.structured_outputs = nest.map_structure(
-        _structured_output_mapping, fetches)
+        _structured_output_mapping, fetches, expand_composites=True)
     pruned_graph.structured_input_signature = input_signature
 
     pruned_fn = WrappedFunction(
         pruned_graph, variable_holder=self._variable_holder)
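The `expand_composites=True` passed to `nest.map_structure` above is what makes the round trip work: each composite in `fetches` is unpacked into its component tensors, every component is mapped (in the change, looked up in `lift_map`), and the composite is rebuilt around the mapped results. A standalone illustration of that unpack/repack behavior, using `tf.identity` as a stand-in for the mapping (not part of the change):

import tensorflow as tf

sp = tf.sparse.SparseTensor(indices=[[0]], values=tf.constant([21], tf.int64),
                            dense_shape=[1])

# map_structure applies the function to every component tensor and then packs
# the results back into the original structure, composites included.
mapped = tf.nest.map_structure(tf.identity, {"output": sp}, expand_composites=True)
print(type(mapped["output"]))  # a SparseTensor again, rebuilt from mapped components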


@@ -29,6 +29,7 @@ from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.framework import versions
@@ -489,5 +490,29 @@ class LoadTest(test.TestCase):
     root = load.load(path)
     self.assertFalse(root.variables[0].trainable)
 
+  def _model_with_sparse_output(self):
+    """Generate a graph with a SparseTensor output and serialize in V1 format"""
+    export_graph = ops.Graph()
+    with export_graph.as_default():
+      in_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[1])
+      out_sparse_tensor = sparse_tensor.SparseTensor(
+          indices=[[0]], values=in_placeholder, dense_shape=[1]) * 2
+      with session_lib.Session() as session:
+        path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
+        simple_save.simple_save(
+            session,
+            path,
+            inputs={"start": in_placeholder},
+            outputs={"output": out_sparse_tensor})
+    return path
+
+  def test_load_sparse_outputs(self):
+    path = self._model_with_sparse_output()
+    imported = load.load(path)
+    imported_fn = imported.signatures["serving_default"]
+    forty_two = constant_op.constant([42], dtype=dtypes.int64)
+    self.assertEqual([84], imported_fn(forty_two)["output"].values.numpy())
+
 
 if __name__ == "__main__":
   test.main()
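With the change above, the `serving_default` signature of a V1 SavedModel whose output is sparse hands back a real `SparseTensor` when loaded in TF 2.x, so the usual sparse ops apply to the result. A hedged usage sketch, assuming a model exported the same way as `_model_with_sparse_output` in the test (the path below is hypothetical):

import tensorflow as tf

imported = tf.saved_model.load("/tmp/saved_model/123")  # hypothetical export path
serving_fn = imported.signatures["serving_default"]

out = serving_fn(tf.constant([42], dtype=tf.int64))["output"]  # a tf.SparseTensor
print(out.values.numpy())               # [84]
print(tf.sparse.to_dense(out).numpy())  # [84], densified from the recovered components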