Fix load_v1_in_v2 to support sparse tensor outputs.

PiperOrigin-RevId: 268195564
This commit is contained in:
Vojtech Bardiovsky 2019-09-10 04:21:30 -07:00 committed by TensorFlower Gardener
parent 2a7d2c4295
commit a271711455
2 changed files with 50 additions and 3 deletions

View File

@ -26,6 +26,7 @@ from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import function_deserialization
@ -127,12 +128,26 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
signature_functions = {}
for signature_key, signature_def in meta_graph_def.signature_def.items():
if signature_def.inputs:
input_names, input_specs = zip(*signature_def.inputs.items())
original_input_names, input_specs = zip(*signature_def.inputs.items())
else:
input_names = []
original_input_names = []
input_specs = []
# TODO(allenl): Support optional arguments
feeds = [wrapped.graph.as_graph_element(inp.name) for inp in input_specs]
feeds = [
wrap_function._get_element_from_tensor_info(input_spec, wrapped.graph) # pylint: disable=protected-access
for input_spec in input_specs
]
input_names = []
for original_input_name, feed in zip(original_input_names, feeds):
if isinstance(feed, sparse_tensor.SparseTensor):
# We have to give explicit name for SparseTensor arguments, because
# these are not present in the TensorInfo.
indices_name = "%s_indices" % original_input_name
values_name = "%s_values" % original_input_name
dense_shape_name = "%s_dense_shape" % original_input_name
input_names.extend([indices_name, values_name, dense_shape_name])
else:
input_names.append(original_input_name)
fetches = {name: out for name, out in signature_def.outputs.items()}
try:
signature_fn = wrapped.prune(feeds=feeds, fetches=fetches)

View File

@ -529,6 +529,38 @@ class LoadTest(test.TestCase):
forty_two = constant_op.constant([42], dtype=dtypes.int64)
self.assertEqual([84], imported_fn(forty_two)["output"].values.numpy())
def _model_with_sparse_input(self):
"""Generate a graph with a SparseTensor input and serialize in V1 format."""
export_graph = ops.Graph()
with export_graph.as_default():
in_sparse_placeholder = array_ops.sparse_placeholder(
dtype=dtypes.int64, shape=[2, 2])
out_sparse_tensor = sparse_tensor.SparseTensor(
indices=in_sparse_placeholder.indices,
values=in_sparse_placeholder.values,
dense_shape=in_sparse_placeholder.dense_shape) * 2
with session_lib.Session() as session:
path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
simple_save.simple_save(
session,
path,
inputs={"start": in_sparse_placeholder},
outputs={"output": out_sparse_tensor})
return path
def test_load_sparse_inputs(self):
path = self._model_with_sparse_input()
imported = load.load(path)
imported_fn = imported.signatures["serving_default"]
indices = constant_op.constant([[0, 0], [0, 1], [1, 1]], dtype=dtypes.int64)
values = constant_op.constant([42, 43, 44], dtype=dtypes.int64)
dense_shape = constant_op.constant([2, 2], dtype=dtypes.int64)
result = imported_fn(
start_indices=indices,
start_values=values,
start_dense_shape=dense_shape)
self.assertAllEqual([84, 86, 88], result["output"].values.numpy())
def _model_with_defun(self):
"""Generate a graph with a Defun and serialize in V1 format."""
export_graph = ops.Graph()