diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index 17a09378d69..ef731fae4a9 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -245,39 +245,43 @@ class WrappedFunction(function.ConcreteFunction): tensor_fetches = [] tensor_infos = [] - def _fetch_preprocesing_callback(f): + def _fetch_preprocesing_callback(fetch): """Extract out lists of ops, tensors, and tensor type info. - Turns TensorInfos into Tensors in the original fetches structure. + Turns TensorInfos into Tensors in the original `fetches` structure. + Also extracts ops from `fetches`. Args: - f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string - identifying a Tensor or Operation. + fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or + string identifying a Tensor or Operation. Returns: - `f` converted to a Tensor. + `fetch` converted to a Tensor. """ - if isinstance(f, ops.Operation): - operation_fetches.append(f) - return f - elif isinstance(f, meta_graph_pb2.TensorInfo): - tensor_infos.append(f) - decoded = _get_element_from_tensor_info(f, self._func_graph) + if isinstance(fetch, ops.Operation): + operation_fetches.append(fetch) + return fetch + elif isinstance(fetch, meta_graph_pb2.TensorInfo): + tensor_infos.append(fetch) + decoded = _get_element_from_tensor_info(fetch, self._func_graph) if tensor_util.is_tensor(decoded): tensor_fetches.append(decoded) else: operation_fetches.append(decoded) return decoded - elif isinstance(f, ops.Tensor): - tensor_fetches.append(f) - return f + elif isinstance(fetch, ops.Tensor): + tensor_fetches.append(fetch) + return fetch else: - graph_element = self.graph.as_graph_element(f) + graph_element = self.graph.as_graph_element(fetch) return _fetch_preprocesing_callback(graph_element) fetches = nest.map_structure(_fetch_preprocesing_callback, fetches) - for f in flat_feeds + tensor_fetches + operation_fetches: + # Expand composite tensors into their component dense Tensors. + tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True) + + for f in (flat_feeds + tensor_fetches + operation_fetches): if f.graph is not self._func_graph: raise ValueError("Can only prune function whose feeds and fetches " "are from this graph (%s). Input %s is from graph %s" % @@ -288,6 +292,10 @@ class WrappedFunction(function.ConcreteFunction): operation_fetches + tensor_fetches, pruned_graph, sources=flat_feeds + internal_captures) + + # Note that we add the component tensors of any composite tensors to the + # returned function's outputs list; the list must contain these component + # tensors, or the function's sparse outputs won't work properly. pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches) pruned_graph.control_outputs.extend( [lift_map[operation] for operation in operation_fetches]) @@ -308,13 +316,17 @@ class WrappedFunction(function.ConcreteFunction): pruned_graph.variables = self.graph.variables def _structured_output_mapping(fetched): + """callback for `nest.map_structure()`""" lifted = lift_map[fetched] if isinstance(lifted, ops.Operation): return None return lifted + # expand_composites=True here causes composite tensors to be expanded + # into their component dense Tensors, mapped to the new graph, and then + # reconstituted into their original composite form. pruned_graph.structured_outputs = nest.map_structure( - _structured_output_mapping, fetches) + _structured_output_mapping, fetches, expand_composites=True) pruned_graph.structured_input_signature = input_signature pruned_fn = WrappedFunction( pruned_graph, variable_holder=self._variable_holder) diff --git a/tensorflow/python/saved_model/load_v1_in_v2_test.py b/tensorflow/python/saved_model/load_v1_in_v2_test.py index 8c64413a42c..387efef5426 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2_test.py +++ b/tensorflow/python/saved_model/load_v1_in_v2_test.py @@ -29,6 +29,7 @@ from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.framework import versions @@ -489,5 +490,29 @@ class LoadTest(test.TestCase): root = load.load(path) self.assertFalse(root.variables[0].trainable) + def _model_with_sparse_output(self): + """Generate a graph with a SparseTensor output and serialize in V1 format""" + export_graph = ops.Graph() + with export_graph.as_default(): + in_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[1]) + out_sparse_tensor = sparse_tensor.SparseTensor( + indices=[[0]], values=in_placeholder, dense_shape=[1]) * 2 + with session_lib.Session() as session: + path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid())) + simple_save.simple_save( + session, + path, + inputs={"start": in_placeholder}, + outputs={"output": out_sparse_tensor}) + return path + + def test_load_sparse_outputs(self): + path = self._model_with_sparse_output() + imported = load.load(path) + imported_fn = imported.signatures["serving_default"] + forty_two = constant_op.constant([42], dtype=dtypes.int64) + self.assertEqual([84], imported_fn(forty_two)["output"].values.numpy()) + + if __name__ == "__main__": test.main()