Allow importing of V1 models that output SparseTensors
Initial implementation; fix typo; fix minor bugs and finish up the test case.

This commit is contained in:
parent d102214520
commit 65aacb43a6
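For context, a minimal usage sketch (not part of the commit) of what this change enables: calling a serving signature that returns a SparseTensor after loading a V1 SavedModel in TF 2.x. The model path and the "start"/"output" key names here are illustrative and simply mirror the test case added further down.

import tensorflow as tf  # TF 2.x

# Load a SavedModel exported from TF 1.x graph-mode code (hypothetical path).
imported = tf.saved_model.load("/tmp/sparse_model")
fn = imported.signatures["serving_default"]

# Before this change, pruning the wrapped V1 graph only handled dense Tensor
# and Operation fetches; a sparse signature output now comes back as a
# tf.sparse.SparseTensor in the result dict.
out = fn(start=tf.constant([42], dtype=tf.int64))["output"]
print(out.indices.numpy(), out.values.numpy(), out.dense_shape.numpy())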
@@ -192,6 +192,36 @@ def _lift_unlifted_variables(graph, variable_holder):
       mutable_collection[index] = lifted_variables.get(current, current)
 
 
+def _sparse_to_dense(sparse_tensor_list):
+  """
+  Extract out and return the dense components (indices, values, dense shape)
+  of an iterable of `SparseTensor`s.
+  """
+  ret = []
+  for s in sparse_tensor_list:
+    ret.append(s.indices)
+    ret.append(s.values)
+    ret.append(s.dense_shape)
+  return ret
+
+
+def _lift_sparse_tensor(orig_sparse_tensor, lift_map):
+  """
+  Args:
+    orig_sparse_tensor: SparseTensor object whose underlying dense Tensors
+      reside in a different graph.
+    lift_map: Map (as returned by `lift_to_graph`) from tensors in the other
+      graph to tensors in the current graph.
+  Returns:
+    A new copy of `orig_sparse_tensor` whose underlying dense tensors are in
+    the current graph.
+  """
+  return sparse_tensor.SparseTensor(
+      indices=lift_map[orig_sparse_tensor.indices],
+      values=lift_map[orig_sparse_tensor.values],
+      dense_shape=lift_map[orig_sparse_tensor.dense_shape]
+  )
+
 # TODO(allenl): make this trackable
 class WrappedFunction(function.ConcreteFunction):
   """Wraps a tf V1 piece of code in a function."""
@@ -243,12 +273,14 @@ class WrappedFunction(function.ConcreteFunction):
 
     operation_fetches = []
     tensor_fetches = []
+    sparse_tensor_fetches = []
    tensor_infos = []
 
    def _fetch_preprocesing_callback(f):
      """Extract out lists of ops, tensors, and tensor type info.
 
-      Turns TensorInfos into Tensors in the original fetches structure.
+      Turns TensorInfos into Tensors in the original `fetches` structure.
+      Also extracts sparse tensors and ops from `fetches`.
 
      Args:
        f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string
@@ -263,7 +295,9 @@ class WrappedFunction(function.ConcreteFunction):
       elif isinstance(f, meta_graph_pb2.TensorInfo):
         tensor_infos.append(f)
         decoded = _get_element_from_tensor_info(f, self._func_graph)
-        if tensor_util.is_tensor(decoded):
+        if isinstance(decoded, sparse_tensor.SparseTensor):
+          sparse_tensor_fetches.append(decoded)
+        elif tensor_util.is_tensor(decoded):
           tensor_fetches.append(decoded)
         else:
           operation_fetches.append(decoded)
@@ -277,7 +311,8 @@ class WrappedFunction(function.ConcreteFunction):
 
     fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
 
-    for f in flat_feeds + tensor_fetches + operation_fetches:
+    for f in flat_feeds + tensor_fetches + operation_fetches \
+        + sparse_tensor_fetches:
       if f.graph is not self._func_graph:
         raise ValueError("Can only prune function whose feeds and fetches "
                          "are from this graph (%s). Input %s is from graph %s" %
@@ -285,10 +320,17 @@ class WrappedFunction(function.ConcreteFunction):
     with self._func_graph.as_default():
       pruned_graph = func_graph.FuncGraph(name)
     lift_map = lift_to_graph.lift_to_graph(
-        operation_fetches + tensor_fetches,
+        operation_fetches + tensor_fetches
+        + _sparse_to_dense(sparse_tensor_fetches),
         pruned_graph,
         sources=flat_feeds + internal_captures)
     pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
+    for f in sparse_tensor_fetches:
+      # Outputs list can only contain dense tensors, but it must contain any
+      # tensors that are part of an output SparseTensor.
+      f_lifted = _lift_sparse_tensor(f, lift_map)
+      pruned_graph.outputs.extend([f_lifted.indices, f_lifted.values,
+                                   f_lifted.dense_shape])
     pruned_graph.control_outputs.extend(
         [lift_map[operation] for operation in operation_fetches])
     for external_capture, internal_capture in self.graph.captures.items():
@@ -308,6 +350,9 @@ class WrappedFunction(function.ConcreteFunction):
     pruned_graph.variables = self.graph.variables
 
     def _structured_output_mapping(fetched):
+      """`nest.map_structure()` callback."""
+      if isinstance(fetched, sparse_tensor.SparseTensor):
+        return _lift_sparse_tensor(fetched, lift_map)
       lifted = lift_map[fetched]
       if isinstance(lifted, ops.Operation):
         return None
@@ -29,6 +29,7 @@ from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.framework import versions
@@ -489,5 +490,29 @@ class LoadTest(test.TestCase):
     root = load.load(path)
     self.assertFalse(root.variables[0].trainable)
 
+  def _model_with_sparse_output(self):
+    """Generate a graph with a SparseTensor output and serialize in V1 format"""
+    export_graph = ops.Graph()
+    with export_graph.as_default():
+      in_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[1])
+      out_sparse_tensor = sparse_tensor.SparseTensor(indices=[[0]],
+                                                     values=in_placeholder,
+                                                     dense_shape=[1]) * 2
+      with session_lib.Session() as session:
+        path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
+        simple_save.simple_save(
+            session,
+            path,
+            inputs={"start": in_placeholder},
+            outputs={"output": out_sparse_tensor})
+    return path
+
+  def test_load_sparse_outputs(self):
+    path = self._model_with_sparse_output()
+    imported = load.load(path)
+    imported_fn = imported.signatures["serving_default"]
+    forty_two = constant_op.constant([42], dtype=dtypes.int64)
+    self.assertEqual([84], imported_fn(forty_two)["output"].values.numpy())
+
 if __name__ == "__main__":
   test.main()