Merge pull request #28886 from frreiss:issue-sparse-import
PiperOrigin-RevId: 249843058
This commit is contained in:
commit
eb26bc8fb0
@ -245,39 +245,43 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
tensor_fetches = []
|
||||
tensor_infos = []
|
||||
|
||||
def _fetch_preprocesing_callback(f):
|
||||
def _fetch_preprocesing_callback(fetch):
|
||||
"""Extract out lists of ops, tensors, and tensor type info.
|
||||
|
||||
Turns TensorInfos into Tensors in the original fetches structure.
|
||||
Turns TensorInfos into Tensors in the original `fetches` structure.
|
||||
Also extracts ops from `fetches`.
|
||||
|
||||
Args:
|
||||
f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string
|
||||
identifying a Tensor or Operation.
|
||||
fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or
|
||||
string identifying a Tensor or Operation.
|
||||
|
||||
Returns:
|
||||
`f` converted to a Tensor.
|
||||
`fetch` converted to a Tensor.
|
||||
"""
|
||||
if isinstance(f, ops.Operation):
|
||||
operation_fetches.append(f)
|
||||
return f
|
||||
elif isinstance(f, meta_graph_pb2.TensorInfo):
|
||||
tensor_infos.append(f)
|
||||
decoded = _get_element_from_tensor_info(f, self._func_graph)
|
||||
if isinstance(fetch, ops.Operation):
|
||||
operation_fetches.append(fetch)
|
||||
return fetch
|
||||
elif isinstance(fetch, meta_graph_pb2.TensorInfo):
|
||||
tensor_infos.append(fetch)
|
||||
decoded = _get_element_from_tensor_info(fetch, self._func_graph)
|
||||
if tensor_util.is_tensor(decoded):
|
||||
tensor_fetches.append(decoded)
|
||||
else:
|
||||
operation_fetches.append(decoded)
|
||||
return decoded
|
||||
elif isinstance(f, ops.Tensor):
|
||||
tensor_fetches.append(f)
|
||||
return f
|
||||
elif isinstance(fetch, ops.Tensor):
|
||||
tensor_fetches.append(fetch)
|
||||
return fetch
|
||||
else:
|
||||
graph_element = self.graph.as_graph_element(f)
|
||||
graph_element = self.graph.as_graph_element(fetch)
|
||||
return _fetch_preprocesing_callback(graph_element)
|
||||
|
||||
fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
|
||||
|
||||
for f in flat_feeds + tensor_fetches + operation_fetches:
|
||||
# Expand composite tensors into their component dense Tensors.
|
||||
tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)
|
||||
|
||||
for f in (flat_feeds + tensor_fetches + operation_fetches):
|
||||
if f.graph is not self._func_graph:
|
||||
raise ValueError("Can only prune function whose feeds and fetches "
|
||||
"are from this graph (%s). Input %s is from graph %s" %
|
||||
@ -288,6 +292,10 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
operation_fetches + tensor_fetches,
|
||||
pruned_graph,
|
||||
sources=flat_feeds + internal_captures)
|
||||
|
||||
# Note that we add the component tensors of any composite tensors to the
|
||||
# returned function's outputs list; the list must contain these component
|
||||
# tensors, or the function's sparse outputs won't work properly.
|
||||
pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
|
||||
pruned_graph.control_outputs.extend(
|
||||
[lift_map[operation] for operation in operation_fetches])
|
||||
@ -308,13 +316,17 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
pruned_graph.variables = self.graph.variables
|
||||
|
||||
def _structured_output_mapping(fetched):
|
||||
"""callback for `nest.map_structure()`"""
|
||||
lifted = lift_map[fetched]
|
||||
if isinstance(lifted, ops.Operation):
|
||||
return None
|
||||
return lifted
|
||||
|
||||
# expand_composites=True here causes composite tensors to be expanded
|
||||
# into their component dense Tensors, mapped to the new graph, and then
|
||||
# reconstituted into their original composite form.
|
||||
pruned_graph.structured_outputs = nest.map_structure(
|
||||
_structured_output_mapping, fetches)
|
||||
_structured_output_mapping, fetches, expand_composites=True)
|
||||
pruned_graph.structured_input_signature = input_signature
|
||||
pruned_fn = WrappedFunction(
|
||||
pruned_graph, variable_holder=self._variable_holder)
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import versions
|
||||
@ -489,5 +490,29 @@ class LoadTest(test.TestCase):
|
||||
root = load.load(path)
|
||||
self.assertFalse(root.variables[0].trainable)
|
||||
|
||||
def _model_with_sparse_output(self):
|
||||
"""Generate a graph with a SparseTensor output and serialize in V1 format"""
|
||||
export_graph = ops.Graph()
|
||||
with export_graph.as_default():
|
||||
in_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[1])
|
||||
out_sparse_tensor = sparse_tensor.SparseTensor(
|
||||
indices=[[0]], values=in_placeholder, dense_shape=[1]) * 2
|
||||
with session_lib.Session() as session:
|
||||
path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
|
||||
simple_save.simple_save(
|
||||
session,
|
||||
path,
|
||||
inputs={"start": in_placeholder},
|
||||
outputs={"output": out_sparse_tensor})
|
||||
return path
|
||||
|
||||
def test_load_sparse_outputs(self):
|
||||
path = self._model_with_sparse_output()
|
||||
imported = load.load(path)
|
||||
imported_fn = imported.signatures["serving_default"]
|
||||
forty_two = constant_op.constant([42], dtype=dtypes.int64)
|
||||
self.assertEqual([84], imported_fn(forty_two)["output"].values.numpy())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user