diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py
index 17a09378d69..eba251aaa16 100644
--- a/tensorflow/python/eager/wrap_function.py
+++ b/tensorflow/python/eager/wrap_function.py
@@ -192,6 +192,36 @@ def _lift_unlifted_variables(graph, variable_holder):
       mutable_collection[index] = lifted_variables.get(current, current)
 
 
+def _sparse_to_dense(sparse_tensor_list):
+  """Extracts the dense components of an iterable of `SparseTensor`s.
+
+  Returns a flat list of the indices, values, and dense_shape of each input.
+  """
+  ret = []
+  for s in sparse_tensor_list:
+    ret.append(s.indices)
+    ret.append(s.values)
+    ret.append(s.dense_shape)
+  return ret
+
+
+def _lift_sparse_tensor(orig_sparse_tensor, lift_map):
+  """Creates a copy of `orig_sparse_tensor` in the current graph.
+
+  Args:
+    orig_sparse_tensor: A `SparseTensor` whose underlying dense tensors reside
+      in a different graph.
+    lift_map: Map (as returned by `lift_to_graph`) from tensors in the other
+      graph to tensors in the current graph.
+  Returns:
+    A new `SparseTensor` whose underlying dense tensors are in the current
+    graph.
+  """
+  return sparse_tensor.SparseTensor(
+      indices=lift_map[orig_sparse_tensor.indices],
+      values=lift_map[orig_sparse_tensor.values],
+      dense_shape=lift_map[orig_sparse_tensor.dense_shape])
+
+
 # TODO(allenl): make this trackable
 class WrappedFunction(function.ConcreteFunction):
   """Wraps a tf V1 piece of code in a function."""
@@ -243,12 +273,14 @@ class WrappedFunction(function.ConcreteFunction):
 
     operation_fetches = []
    tensor_fetches = []
+    sparse_tensor_fetches = []
     tensor_infos = []
 
     def _fetch_preprocesing_callback(f):
       """Extract out lists of ops, tensors, and tensor type info.
 
-      Turns TensorInfos into Tensors in the original fetches structure.
+      Turns TensorInfos into Tensors in the original `fetches` structure.
+      Also extracts sparse tensors and ops from `fetches`.
 
       Args:
         f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string
@@ -263,7 +295,9 @@ class WrappedFunction(function.ConcreteFunction):
       elif isinstance(f, meta_graph_pb2.TensorInfo):
         tensor_infos.append(f)
         decoded = _get_element_from_tensor_info(f, self._func_graph)
-        if tensor_util.is_tensor(decoded):
+        if isinstance(decoded, sparse_tensor.SparseTensor):
+          sparse_tensor_fetches.append(decoded)
+        elif tensor_util.is_tensor(decoded):
           tensor_fetches.append(decoded)
         else:
           operation_fetches.append(decoded)
@@ -277,7 +311,8 @@ class WrappedFunction(function.ConcreteFunction):
 
     fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
 
-    for f in flat_feeds + tensor_fetches + operation_fetches:
+    for f in (flat_feeds + tensor_fetches + operation_fetches +
+              sparse_tensor_fetches):
       if f.graph is not self._func_graph:
         raise ValueError("Can only prune function whose feeds and fetches "
                          "are from this graph (%s). Input %s is from graph %s" %
@@ -285,10 +320,17 @@ class WrappedFunction(function.ConcreteFunction):
     with self._func_graph.as_default():
       pruned_graph = func_graph.FuncGraph(name)
     lift_map = lift_to_graph.lift_to_graph(
-        operation_fetches + tensor_fetches,
+        operation_fetches + tensor_fetches +
+        _sparse_to_dense(sparse_tensor_fetches),
         pruned_graph,
         sources=flat_feeds + internal_captures)
     pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
+    # The outputs list may only contain dense tensors, but it must include
+    # every tensor that is part of an output SparseTensor.
+    for f in sparse_tensor_fetches:
+      f_lifted = _lift_sparse_tensor(f, lift_map)
+      pruned_graph.outputs.extend(
+          [f_lifted.indices, f_lifted.values, f_lifted.dense_shape])
     pruned_graph.control_outputs.extend(
         [lift_map[operation] for operation in operation_fetches])
     for external_capture, internal_capture in self.graph.captures.items():
@@ -308,6 +350,9 @@ class WrappedFunction(function.ConcreteFunction):
     pruned_graph.variables = self.graph.variables
 
     def _structured_output_mapping(fetched):
+      """Maps a fetch from the original graph to its lifted counterpart."""
+      if isinstance(fetched, sparse_tensor.SparseTensor):
+        return _lift_sparse_tensor(fetched, lift_map)
       lifted = lift_map[fetched]
       if isinstance(lifted, ops.Operation):
         return None
diff --git a/tensorflow/python/saved_model/load_v1_in_v2_test.py b/tensorflow/python/saved_model/load_v1_in_v2_test.py
index 8c64413a42c..9ccef5f5e3f 100644
--- a/tensorflow/python/saved_model/load_v1_in_v2_test.py
+++ b/tensorflow/python/saved_model/load_v1_in_v2_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.framework import versions
@@ -489,5 +490,29 @@ class LoadTest(test.TestCase):
     root = load.load(path)
     self.assertFalse(root.variables[0].trainable)
 
+  def _model_with_sparse_output(self):
+    """Builds a graph with a SparseTensor output and saves it in V1 format."""
+    export_graph = ops.Graph()
+    with export_graph.as_default():
+      in_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[1])
+      out_sparse_tensor = sparse_tensor.SparseTensor(indices=[[0]],
+                                                     values=in_placeholder,
+                                                     dense_shape=[1]) * 2
+      with session_lib.Session() as session:
+        path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
+        simple_save.simple_save(
+            session,
+            path,
+            inputs={"start": in_placeholder},
+            outputs={"output": out_sparse_tensor})
+    return path
+
+  def test_load_sparse_outputs(self):
+    path = self._model_with_sparse_output()
+    imported = load.load(path)
+    imported_fn = imported.signatures["serving_default"]
+    forty_two = constant_op.constant([42], dtype=dtypes.int64)
+    self.assertEqual([84], imported_fn(forty_two)["output"].values.numpy())
+
 
 if __name__ == "__main__":
   test.main()
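Note: the sketch below is illustration only and is not part of the patch. It
shows, in plain eager TF2, the mechanism the change relies on: a FuncGraph's
outputs can hold only dense tensors, so each SparseTensor fetch is flattened
into its indices/values/dense_shape components (as `_sparse_to_dense` does),
the components are what get lifted into the pruned graph, and a SparseTensor
is reassembled from the lifted pieces (as `_lift_sparse_tensor` does). Here
`tf.identity` stands in for the `lift_to_graph` lifting step.

    import tensorflow as tf

    sp = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1.0, 2.0],
                         dense_shape=[3, 4])

    # Flatten into dense components, mirroring _sparse_to_dense.
    components = [sp.indices, sp.values, sp.dense_shape]

    # Stand-in for the lifting step: in the patch, lift_map maps each of these
    # dense tensors to its copy in the pruned graph.
    lifted = [tf.identity(t) for t in components]

    # Reassemble a SparseTensor from the lifted components, mirroring
    # _lift_sparse_tensor.
    rebuilt = tf.SparseTensor(indices=lifted[0], values=lifted[1],
                              dense_shape=lifted[2])
    print(tf.sparse.to_dense(rebuilt).numpy())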