From a2717114558ea58ed7ff710a2257a535e3fdadf4 Mon Sep 17 00:00:00 2001 From: Vojtech Bardiovsky Date: Tue, 10 Sep 2019 04:21:30 -0700 Subject: [PATCH] Fix load_v1_in_v2 to support sparse tensor outputs. PiperOrigin-RevId: 268195564 --- .../python/saved_model/load_v1_in_v2.py | 21 ++++++++++-- .../python/saved_model/load_v1_in_v2_test.py | 32 +++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/saved_model/load_v1_in_v2.py b/tensorflow/python/saved_model/load_v1_in_v2.py index 4ddd18bc6f3..ec1647cb949 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2.py +++ b/tensorflow/python/saved_model/load_v1_in_v2.py @@ -26,6 +26,7 @@ from tensorflow.python.eager import wrap_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import function_deserialization @@ -127,12 +128,26 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader): signature_functions = {} for signature_key, signature_def in meta_graph_def.signature_def.items(): if signature_def.inputs: - input_names, input_specs = zip(*signature_def.inputs.items()) + original_input_names, input_specs = zip(*signature_def.inputs.items()) else: - input_names = [] + original_input_names = [] input_specs = [] # TODO(allenl): Support optional arguments - feeds = [wrapped.graph.as_graph_element(inp.name) for inp in input_specs] + feeds = [ + wrap_function._get_element_from_tensor_info(input_spec, wrapped.graph) # pylint: disable=protected-access + for input_spec in input_specs + ] + input_names = [] + for original_input_name, feed in zip(original_input_names, feeds): + if isinstance(feed, sparse_tensor.SparseTensor): + # We have to give explicit name for SparseTensor arguments, because + # these are not present in the TensorInfo. + indices_name = "%s_indices" % original_input_name + values_name = "%s_values" % original_input_name + dense_shape_name = "%s_dense_shape" % original_input_name + input_names.extend([indices_name, values_name, dense_shape_name]) + else: + input_names.append(original_input_name) fetches = {name: out for name, out in signature_def.outputs.items()} try: signature_fn = wrapped.prune(feeds=feeds, fetches=fetches) diff --git a/tensorflow/python/saved_model/load_v1_in_v2_test.py b/tensorflow/python/saved_model/load_v1_in_v2_test.py index 906b8198335..f02ab14b21c 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2_test.py +++ b/tensorflow/python/saved_model/load_v1_in_v2_test.py @@ -529,6 +529,38 @@ class LoadTest(test.TestCase): forty_two = constant_op.constant([42], dtype=dtypes.int64) self.assertEqual([84], imported_fn(forty_two)["output"].values.numpy()) + def _model_with_sparse_input(self): + """Generate a graph with a SparseTensor input and serialize in V1 format.""" + export_graph = ops.Graph() + with export_graph.as_default(): + in_sparse_placeholder = array_ops.sparse_placeholder( + dtype=dtypes.int64, shape=[2, 2]) + out_sparse_tensor = sparse_tensor.SparseTensor( + indices=in_sparse_placeholder.indices, + values=in_sparse_placeholder.values, + dense_shape=in_sparse_placeholder.dense_shape) * 2 + with session_lib.Session() as session: + path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid())) + simple_save.simple_save( + session, + path, + inputs={"start": in_sparse_placeholder}, + outputs={"output": out_sparse_tensor}) + return path + + def test_load_sparse_inputs(self): + path = self._model_with_sparse_input() + imported = load.load(path) + imported_fn = imported.signatures["serving_default"] + indices = constant_op.constant([[0, 0], [0, 1], [1, 1]], dtype=dtypes.int64) + values = constant_op.constant([42, 43, 44], dtype=dtypes.int64) + dense_shape = constant_op.constant([2, 2], dtype=dtypes.int64) + result = imported_fn( + start_indices=indices, + start_values=values, + start_dense_shape=dense_shape) + self.assertAllEqual([84, 86, 88], result["output"].values.numpy()) + def _model_with_defun(self): """Generate a graph with a Defun and serialize in V1 format.""" export_graph = ops.Graph()