Allow importing of V1 models that output SparseTensors

Initial implementation

Fix typo

Fix minor bugs and finish up test case
This commit is contained in:
frreiss 2019-05-16 16:40:43 -07:00
parent d102214520
commit 65aacb43a6
2 changed files with 74 additions and 4 deletions

View File

@ -192,6 +192,36 @@ def _lift_unlifted_variables(graph, variable_holder):
mutable_collection[index] = lifted_variables.get(current, current)
def _sparse_to_dense(sparse_tensor_list):
"""
Extract out and return the dense components (elements, indices, shape) of an
iterable of `SparseTensor`s.
"""
ret = []
for s in sparse_tensor_list:
ret.append(s.indices)
ret.append(s.values)
ret.append(s.dense_shape)
return ret
def _lift_sparse_tensor(orig_sparse_tensor, lift_map):
"""
Args:
orig_sparse_tensor: SparseTensors object whose underlying dense Tensors
reside in a different graph
lift_map: Map (as returned by `lift_to_graph`) from tensors in the other
graph to tensors in the current graph.
Returns:
A new copy of `orig_sparse_tensor` whose underlying dense tensors are in
the current graph
"""
return sparse_tensor.SparseTensor(
indices=lift_map[orig_sparse_tensor.indices],
values=lift_map[orig_sparse_tensor.values],
dense_shape=lift_map[orig_sparse_tensor.dense_shape]
)
# TODO(allenl): make this trackable
class WrappedFunction(function.ConcreteFunction):
"""Wraps a tf V1 piece of code in a function."""
@ -243,12 +273,14 @@ class WrappedFunction(function.ConcreteFunction):
operation_fetches = []
tensor_fetches = []
sparse_tensor_fetches = []
tensor_infos = []
def _fetch_preprocesing_callback(f):
"""Extract out lists of ops, tensors, and tensor type info.
Turns TensorInfos into Tensors in the original fetches structure.
Turns TensorInfos into Tensors in the original `fetches` structure.
Also extracts sparse tensors and ops from `fetches`.
Args:
f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string
@ -263,7 +295,9 @@ class WrappedFunction(function.ConcreteFunction):
elif isinstance(f, meta_graph_pb2.TensorInfo):
tensor_infos.append(f)
decoded = _get_element_from_tensor_info(f, self._func_graph)
if tensor_util.is_tensor(decoded):
if isinstance(decoded, sparse_tensor.SparseTensor):
sparse_tensor_fetches.append(decoded)
elif tensor_util.is_tensor(decoded):
tensor_fetches.append(decoded)
else:
operation_fetches.append(decoded)
@ -277,7 +311,8 @@ class WrappedFunction(function.ConcreteFunction):
fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
for f in flat_feeds + tensor_fetches + operation_fetches:
for f in flat_feeds + tensor_fetches + operation_fetches \
+ sparse_tensor_fetches:
if f.graph is not self._func_graph:
raise ValueError("Can only prune function whose feeds and fetches "
"are from this graph (%s). Input %s is from graph %s" %
@ -285,10 +320,17 @@ class WrappedFunction(function.ConcreteFunction):
with self._func_graph.as_default():
pruned_graph = func_graph.FuncGraph(name)
lift_map = lift_to_graph.lift_to_graph(
operation_fetches + tensor_fetches,
operation_fetches + tensor_fetches
+ _sparse_to_dense(sparse_tensor_fetches),
pruned_graph,
sources=flat_feeds + internal_captures)
pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
for f in sparse_tensor_fetches:
# Outputs list can only contain dense tensors, but it must contain any
# tensors that are part of an output SparseTensor.
f_lifted = _lift_sparse_tensor(f, lift_map)
pruned_graph.outputs.extend([f_lifted.indices, f_lifted.values,
f_lifted.dense_shape])
pruned_graph.control_outputs.extend(
[lift_map[operation] for operation in operation_fetches])
for external_capture, internal_capture in self.graph.captures.items():
@ -308,6 +350,9 @@ class WrappedFunction(function.ConcreteFunction):
pruned_graph.variables = self.graph.variables
def _structured_output_mapping(fetched):
"""`nest.map_structure()` callback."""
if isinstance(fetched, sparse_tensor.SparseTensor):
return _lift_sparse_tensor(fetched, lift_map)
lifted = lift_map[fetched]
if isinstance(lifted, ops.Operation):
return None

View File

@ -29,6 +29,7 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
@ -489,5 +490,29 @@ class LoadTest(test.TestCase):
root = load.load(path)
self.assertFalse(root.variables[0].trainable)
def _model_with_sparse_output(self):
"""Generate a graph with a SparseTensor output and serialize in V1 format"""
export_graph = ops.Graph()
with export_graph.as_default():
in_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[1])
out_sparse_tensor = sparse_tensor.SparseTensor(indices=[[0]],
values=in_placeholder,
dense_shape=[1]) * 2
with session_lib.Session() as session:
path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
simple_save.simple_save(
session,
path,
inputs={"start": in_placeholder},
outputs={"output": out_sparse_tensor})
return path
def test_load_sparse_outputs(self):
path = self._model_with_sparse_output()
imported = load.load(path)
imported_fn = imported.signatures["serving_default"]
forty_two = constant_op.constant([42], dtype=dtypes.int64)
self.assertEqual([84], imported_fn(forty_two)["output"].values.numpy())
if __name__ == "__main__":
test.main()