Allow importing of V1 models that output SparseTensors

Initial implementation

Fix typo

Fix minor bugs and finish up test case
This commit is contained in:
frreiss 2019-05-16 16:40:43 -07:00
parent d102214520
commit 65aacb43a6
2 changed files with 74 additions and 4 deletions

View File

@ -192,6 +192,36 @@ def _lift_unlifted_variables(graph, variable_holder):
mutable_collection[index] = lifted_variables.get(current, current) mutable_collection[index] = lifted_variables.get(current, current)
def _sparse_to_dense(sparse_tensor_list):
"""
Extract out and return the dense components (elements, indices, shape) of an
iterable of `SparseTensor`s.
"""
ret = []
for s in sparse_tensor_list:
ret.append(s.indices)
ret.append(s.values)
ret.append(s.dense_shape)
return ret
def _lift_sparse_tensor(orig_sparse_tensor, lift_map):
"""
Args:
orig_sparse_tensor: SparseTensors object whose underlying dense Tensors
reside in a different graph
lift_map: Map (as returned by `lift_to_graph`) from tensors in the other
graph to tensors in the current graph.
Returns:
A new copy of `orig_sparse_tensor` whose underlying dense tensors are in
the current graph
"""
return sparse_tensor.SparseTensor(
indices=lift_map[orig_sparse_tensor.indices],
values=lift_map[orig_sparse_tensor.values],
dense_shape=lift_map[orig_sparse_tensor.dense_shape]
)
# TODO(allenl): make this trackable # TODO(allenl): make this trackable
class WrappedFunction(function.ConcreteFunction): class WrappedFunction(function.ConcreteFunction):
"""Wraps a tf V1 piece of code in a function.""" """Wraps a tf V1 piece of code in a function."""
@ -243,12 +273,14 @@ class WrappedFunction(function.ConcreteFunction):
operation_fetches = [] operation_fetches = []
tensor_fetches = [] tensor_fetches = []
sparse_tensor_fetches = []
tensor_infos = [] tensor_infos = []
def _fetch_preprocesing_callback(f): def _fetch_preprocesing_callback(f):
"""Extract out lists of ops, tensors, and tensor type info. """Extract out lists of ops, tensors, and tensor type info.
Turns TensorInfos into Tensors in the original fetches structure. Turns TensorInfos into Tensors in the original `fetches` structure.
Also extracts sparse tensors and ops from `fetches`.
Args: Args:
f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string
@ -263,7 +295,9 @@ class WrappedFunction(function.ConcreteFunction):
elif isinstance(f, meta_graph_pb2.TensorInfo): elif isinstance(f, meta_graph_pb2.TensorInfo):
tensor_infos.append(f) tensor_infos.append(f)
decoded = _get_element_from_tensor_info(f, self._func_graph) decoded = _get_element_from_tensor_info(f, self._func_graph)
if tensor_util.is_tensor(decoded): if isinstance(decoded, sparse_tensor.SparseTensor):
sparse_tensor_fetches.append(decoded)
elif tensor_util.is_tensor(decoded):
tensor_fetches.append(decoded) tensor_fetches.append(decoded)
else: else:
operation_fetches.append(decoded) operation_fetches.append(decoded)
@ -277,7 +311,8 @@ class WrappedFunction(function.ConcreteFunction):
fetches = nest.map_structure(_fetch_preprocesing_callback, fetches) fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
for f in flat_feeds + tensor_fetches + operation_fetches: for f in flat_feeds + tensor_fetches + operation_fetches \
+ sparse_tensor_fetches:
if f.graph is not self._func_graph: if f.graph is not self._func_graph:
raise ValueError("Can only prune function whose feeds and fetches " raise ValueError("Can only prune function whose feeds and fetches "
"are from this graph (%s). Input %s is from graph %s" % "are from this graph (%s). Input %s is from graph %s" %
@ -285,10 +320,17 @@ class WrappedFunction(function.ConcreteFunction):
with self._func_graph.as_default(): with self._func_graph.as_default():
pruned_graph = func_graph.FuncGraph(name) pruned_graph = func_graph.FuncGraph(name)
lift_map = lift_to_graph.lift_to_graph( lift_map = lift_to_graph.lift_to_graph(
operation_fetches + tensor_fetches, operation_fetches + tensor_fetches
+ _sparse_to_dense(sparse_tensor_fetches),
pruned_graph, pruned_graph,
sources=flat_feeds + internal_captures) sources=flat_feeds + internal_captures)
pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches) pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
for f in sparse_tensor_fetches:
# Outputs list can only contain dense tensors, but it must contain any
# tensors that are part of an output SparseTensor.
f_lifted = _lift_sparse_tensor(f, lift_map)
pruned_graph.outputs.extend([f_lifted.indices, f_lifted.values,
f_lifted.dense_shape])
pruned_graph.control_outputs.extend( pruned_graph.control_outputs.extend(
[lift_map[operation] for operation in operation_fetches]) [lift_map[operation] for operation in operation_fetches])
for external_capture, internal_capture in self.graph.captures.items(): for external_capture, internal_capture in self.graph.captures.items():
@ -308,6 +350,9 @@ class WrappedFunction(function.ConcreteFunction):
pruned_graph.variables = self.graph.variables pruned_graph.variables = self.graph.variables
def _structured_output_mapping(fetched): def _structured_output_mapping(fetched):
"""`nest.map_structure()` callback."""
if isinstance(fetched, sparse_tensor.SparseTensor):
return _lift_sparse_tensor(fetched, lift_map)
lifted = lift_map[fetched] lifted = lift_map[fetched]
if isinstance(lifted, ops.Operation): if isinstance(lifted, ops.Operation):
return None return None

View File

@ -29,6 +29,7 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions from tensorflow.python.framework import versions
@ -489,5 +490,29 @@ class LoadTest(test.TestCase):
root = load.load(path) root = load.load(path)
self.assertFalse(root.variables[0].trainable) self.assertFalse(root.variables[0].trainable)
def _model_with_sparse_output(self):
"""Generate a graph with a SparseTensor output and serialize in V1 format"""
export_graph = ops.Graph()
with export_graph.as_default():
in_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[1])
out_sparse_tensor = sparse_tensor.SparseTensor(indices=[[0]],
values=in_placeholder,
dense_shape=[1]) * 2
with session_lib.Session() as session:
path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
simple_save.simple_save(
session,
path,
inputs={"start": in_placeholder},
outputs={"output": out_sparse_tensor})
return path
def test_load_sparse_outputs(self):
path = self._model_with_sparse_output()
imported = load.load(path)
imported_fn = imported.signatures["serving_default"]
forty_two = constant_op.constant([42], dtype=dtypes.int64)
self.assertEqual([84], imported_fn(forty_two)["output"].values.numpy())
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()