Fix load_v1_in_v2 to support sparse tensor inputs and outputs.
PiperOrigin-RevId: 268195564
This commit is contained in:
parent
2a7d2c4295
commit
a271711455
|
@ -26,6 +26,7 @@ from tensorflow.python.eager import wrap_function
|
|||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import function_deserialization
|
||||
|
@ -127,12 +128,26 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
|||
signature_functions = {}
|
||||
for signature_key, signature_def in meta_graph_def.signature_def.items():
|
||||
if signature_def.inputs:
|
||||
input_names, input_specs = zip(*signature_def.inputs.items())
|
||||
original_input_names, input_specs = zip(*signature_def.inputs.items())
|
||||
else:
|
||||
input_names = []
|
||||
original_input_names = []
|
||||
input_specs = []
|
||||
# TODO(allenl): Support optional arguments
|
||||
feeds = [wrapped.graph.as_graph_element(inp.name) for inp in input_specs]
|
||||
feeds = [
|
||||
wrap_function._get_element_from_tensor_info(input_spec, wrapped.graph) # pylint: disable=protected-access
|
||||
for input_spec in input_specs
|
||||
]
|
||||
input_names = []
|
||||
for original_input_name, feed in zip(original_input_names, feeds):
|
||||
if isinstance(feed, sparse_tensor.SparseTensor):
|
||||
# We have to give explicit name for SparseTensor arguments, because
|
||||
# these are not present in the TensorInfo.
|
||||
indices_name = "%s_indices" % original_input_name
|
||||
values_name = "%s_values" % original_input_name
|
||||
dense_shape_name = "%s_dense_shape" % original_input_name
|
||||
input_names.extend([indices_name, values_name, dense_shape_name])
|
||||
else:
|
||||
input_names.append(original_input_name)
|
||||
fetches = {name: out for name, out in signature_def.outputs.items()}
|
||||
try:
|
||||
signature_fn = wrapped.prune(feeds=feeds, fetches=fetches)
|
||||
|
|
|
@ -529,6 +529,38 @@ class LoadTest(test.TestCase):
|
|||
forty_two = constant_op.constant([42], dtype=dtypes.int64)
|
||||
self.assertEqual([84], imported_fn(forty_two)["output"].values.numpy())
|
||||
|
||||
def _model_with_sparse_input(self):
  """Serialize a graph with a SparseTensor input in V1 SavedModel format.

  Builds a graph whose single input is a sparse placeholder and whose
  output is that sparse tensor with every value doubled, saves it via
  `simple_save`, and returns the export directory path.
  """
  graph = ops.Graph()
  with graph.as_default():
    sparse_in = array_ops.sparse_placeholder(dtype=dtypes.int64, shape=[2, 2])
    # Rebuild the SparseTensor from its components and double the values.
    doubled = sparse_tensor.SparseTensor(
        indices=sparse_in.indices,
        values=sparse_in.values,
        dense_shape=sparse_in.dense_shape) * 2
    with session_lib.Session() as sess:
      save_path = os.path.join(
          self.get_temp_dir(), "saved_model", str(ops.uid()))
      simple_save.simple_save(
          sess,
          save_path,
          inputs={"start": sparse_in},
          outputs={"output": doubled})
  return save_path
|
||||
|
||||
def test_load_sparse_inputs(self):
  """Load the sparse-input V1 model and call it with component tensors.

  The V1 signature exposes the SparseTensor input as three dense
  arguments named `<input>_indices`, `<input>_values` and
  `<input>_dense_shape`; feeding them should yield doubled values.
  """
  saved_path = self._model_with_sparse_input()
  root = load.load(saved_path)
  serving_fn = root.signatures["serving_default"]
  result = serving_fn(
      start_indices=constant_op.constant(
          [[0, 0], [0, 1], [1, 1]], dtype=dtypes.int64),
      start_values=constant_op.constant([42, 43, 44], dtype=dtypes.int64),
      start_dense_shape=constant_op.constant([2, 2], dtype=dtypes.int64))
  self.assertAllEqual([84, 86, 88], result["output"].values.numpy())
|
||||
|
||||
def _model_with_defun(self):
|
||||
"""Generate a graph with a Defun and serialize in V1 format."""
|
||||
export_graph = ops.Graph()
|
||||
|
|
Loading…
Reference in New Issue