From 65aacb43a6665d99c1226c345531fb795cf8503d Mon Sep 17 00:00:00 2001 From: frreiss Date: Thu, 16 May 2019 16:40:43 -0700 Subject: [PATCH 1/5] Allow importing of V1 models that output SparseTensors Initial implementation Fix typo Fix minor bugs and finish up test case --- tensorflow/python/eager/wrap_function.py | 53 +++++++++++++++++-- .../python/saved_model/load_v1_in_v2_test.py | 25 +++++++++ 2 files changed, 74 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index 17a09378d69..eba251aaa16 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -192,6 +192,36 @@ def _lift_unlifted_variables(graph, variable_holder): mutable_collection[index] = lifted_variables.get(current, current) +def _sparse_to_dense(sparse_tensor_list): + """ + Extract out and return the dense components (elements, indices, shape) of an + iterable of `SparseTensor`s. + """ + ret = [] + for s in sparse_tensor_list: + ret.append(s.indices) + ret.append(s.values) + ret.append(s.dense_shape) + return ret + + +def _lift_sparse_tensor(orig_sparse_tensor, lift_map): + """ + Args: + orig_sparse_tensor: SparseTensors object whose underlying dense Tensors + reside in a different graph + lift_map: Map (as returned by `lift_to_graph`) from tensors in the other + graph to tensors in the current graph. + Returns: + A new copy of `orig_sparse_tensor` whose underlying dense tensors are in + the current graph + """ + return sparse_tensor.SparseTensor( + indices=lift_map[orig_sparse_tensor.indices], + values=lift_map[orig_sparse_tensor.values], + dense_shape=lift_map[orig_sparse_tensor.dense_shape] + ) + # TODO(allenl): make this trackable class WrappedFunction(function.ConcreteFunction): """Wraps a tf V1 piece of code in a function.""" @@ -243,12 +273,14 @@ class WrappedFunction(function.ConcreteFunction): operation_fetches = [] tensor_fetches = [] + sparse_tensor_fetches = [] tensor_infos = [] def _fetch_preprocesing_callback(f): """Extract out lists of ops, tensors, and tensor type info. - Turns TensorInfos into Tensors in the original fetches structure. + Turns TensorInfos into Tensors in the original `fetches` structure. + Also extracts sparse tensors and ops from `fetches`. Args: f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string @@ -263,7 +295,9 @@ class WrappedFunction(function.ConcreteFunction): elif isinstance(f, meta_graph_pb2.TensorInfo): tensor_infos.append(f) decoded = _get_element_from_tensor_info(f, self._func_graph) - if tensor_util.is_tensor(decoded): + if isinstance(decoded, sparse_tensor.SparseTensor): + sparse_tensor_fetches.append(decoded) + elif tensor_util.is_tensor(decoded): tensor_fetches.append(decoded) else: operation_fetches.append(decoded) @@ -277,7 +311,8 @@ class WrappedFunction(function.ConcreteFunction): fetches = nest.map_structure(_fetch_preprocesing_callback, fetches) - for f in flat_feeds + tensor_fetches + operation_fetches: + for f in flat_feeds + tensor_fetches + operation_fetches \ + + sparse_tensor_fetches: if f.graph is not self._func_graph: raise ValueError("Can only prune function whose feeds and fetches " "are from this graph (%s). 
Input %s is from graph %s" % @@ -285,10 +320,17 @@ class WrappedFunction(function.ConcreteFunction): with self._func_graph.as_default(): pruned_graph = func_graph.FuncGraph(name) lift_map = lift_to_graph.lift_to_graph( - operation_fetches + tensor_fetches, + operation_fetches + tensor_fetches + + _sparse_to_dense(sparse_tensor_fetches), pruned_graph, sources=flat_feeds + internal_captures) pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches) + for f in sparse_tensor_fetches: + # Outputs list can only contain dense tensors, but it must contain any + # tensors that are part of an output SparseTensor. + f_lifted = _lift_sparse_tensor(f, lift_map) + pruned_graph.outputs.extend([f_lifted.indices, f_lifted.values, + f_lifted.dense_shape]) pruned_graph.control_outputs.extend( [lift_map[operation] for operation in operation_fetches]) for external_capture, internal_capture in self.graph.captures.items(): @@ -308,6 +350,9 @@ class WrappedFunction(function.ConcreteFunction): pruned_graph.variables = self.graph.variables def _structured_output_mapping(fetched): + """`nest.map_structure()` callback.""" + if isinstance(fetched, sparse_tensor.SparseTensor): + return _lift_sparse_tensor(fetched, lift_map) lifted = lift_map[fetched] if isinstance(lifted, ops.Operation): return None diff --git a/tensorflow/python/saved_model/load_v1_in_v2_test.py b/tensorflow/python/saved_model/load_v1_in_v2_test.py index 8c64413a42c..9ccef5f5e3f 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2_test.py +++ b/tensorflow/python/saved_model/load_v1_in_v2_test.py @@ -29,6 +29,7 @@ from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.framework import versions @@ -489,5 +490,29 @@ class LoadTest(test.TestCase): root = load.load(path) self.assertFalse(root.variables[0].trainable) + def _model_with_sparse_output(self): + """Generate a graph with a SparseTensor output and serialize in V1 format""" + export_graph = ops.Graph() + with export_graph.as_default(): + in_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[1]) + out_sparse_tensor = sparse_tensor.SparseTensor(indices=[[0]], + values=in_placeholder, + dense_shape=[1]) * 2 + with session_lib.Session() as session: + path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid())) + simple_save.simple_save( + session, + path, + inputs={"start": in_placeholder}, + outputs={"output": out_sparse_tensor}) + return path + + def test_load_sparse_outputs(self): + path = self._model_with_sparse_output() + imported = load.load(path) + imported_fn = imported.signatures["serving_default"] + forty_two = constant_op.constant([42], dtype=dtypes.int64) + self.assertEqual([84], imported_fn(forty_two)["output"].values.numpy()) + if __name__ == "__main__": test.main() From 4c7cf10b6cc888d5ee821a992ecdc0b12be4a6dc Mon Sep 17 00:00:00 2001 From: frreiss Date: Wed, 22 May 2019 16:34:28 -0700 Subject: [PATCH 2/5] Address review comments an d fix a few pep8 linter warnings --- tensorflow/python/eager/wrap_function.py | 72 +++++++++--------------- 1 file changed, 28 insertions(+), 44 deletions(-) diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index eba251aaa16..b85526a4d32 100644 --- 
a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -192,28 +192,15 @@ def _lift_unlifted_variables(graph, variable_holder): mutable_collection[index] = lifted_variables.get(current, current) -def _sparse_to_dense(sparse_tensor_list): - """ - Extract out and return the dense components (elements, indices, shape) of an - iterable of `SparseTensor`s. - """ - ret = [] - for s in sparse_tensor_list: - ret.append(s.indices) - ret.append(s.values) - ret.append(s.dense_shape) - return ret - - def _lift_sparse_tensor(orig_sparse_tensor, lift_map): """ Args: - orig_sparse_tensor: SparseTensors object whose underlying dense Tensors + orig_sparse_tensor: SparseTensors object whose underlying dense Tensors reside in a different graph lift_map: Map (as returned by `lift_to_graph`) from tensors in the other graph to tensors in the current graph. Returns: - A new copy of `orig_sparse_tensor` whose underlying dense tensors are in + A new copy of `orig_sparse_tensor` whose underlying dense tensors are in the current graph """ return sparse_tensor.SparseTensor( @@ -221,7 +208,8 @@ def _lift_sparse_tensor(orig_sparse_tensor, lift_map): values=lift_map[orig_sparse_tensor.values], dense_shape=lift_map[orig_sparse_tensor.dense_shape] ) - + + # TODO(allenl): make this trackable class WrappedFunction(function.ConcreteFunction): """Wraps a tf V1 piece of code in a function.""" @@ -273,46 +261,45 @@ class WrappedFunction(function.ConcreteFunction): operation_fetches = [] tensor_fetches = [] - sparse_tensor_fetches = [] tensor_infos = [] - def _fetch_preprocesing_callback(f): + def _fetch_preprocesing_callback(fetch): """Extract out lists of ops, tensors, and tensor type info. Turns TensorInfos into Tensors in the original `fetches` structure. - Also extracts sparse tensors and ops from `fetches`. + Also extracts ops from `fetches`. Args: - f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string + fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string identifying a Tensor or Operation. Returns: - `f` converted to a Tensor. + `fetch` converted to a Tensor. """ - if isinstance(f, ops.Operation): - operation_fetches.append(f) - return f - elif isinstance(f, meta_graph_pb2.TensorInfo): - tensor_infos.append(f) - decoded = _get_element_from_tensor_info(f, self._func_graph) - if isinstance(decoded, sparse_tensor.SparseTensor): - sparse_tensor_fetches.append(decoded) - elif tensor_util.is_tensor(decoded): + if isinstance(fetch, ops.Operation): + operation_fetches.append(fetch) + return fetch + elif isinstance(fetch, meta_graph_pb2.TensorInfo): + tensor_infos.append(fetch) + decoded = _get_element_from_tensor_info(fetch, self._func_graph) + if tensor_util.is_tensor(decoded): tensor_fetches.append(decoded) else: operation_fetches.append(decoded) return decoded - elif isinstance(f, ops.Tensor): - tensor_fetches.append(f) - return f + elif isinstance(fetch, ops.Tensor): + tensor_fetches.append(fetch) + return fetch else: - graph_element = self.graph.as_graph_element(f) + graph_element = self.graph.as_graph_element(fetch) return _fetch_preprocesing_callback(graph_element) fetches = nest.map_structure(_fetch_preprocesing_callback, fetches) - for f in flat_feeds + tensor_fetches + operation_fetches \ - + sparse_tensor_fetches: + # Turn composite/sparse tensors into dense Tensors. 
+ tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True) + + for f in (flat_feeds + tensor_fetches + operation_fetches): if f.graph is not self._func_graph: raise ValueError("Can only prune function whose feeds and fetches " "are from this graph (%s). Input %s is from graph %s" % @@ -320,17 +307,14 @@ class WrappedFunction(function.ConcreteFunction): with self._func_graph.as_default(): pruned_graph = func_graph.FuncGraph(name) lift_map = lift_to_graph.lift_to_graph( - operation_fetches + tensor_fetches - + _sparse_to_dense(sparse_tensor_fetches), + operation_fetches + tensor_fetches, pruned_graph, sources=flat_feeds + internal_captures) + + # Note that we deliberately add the component tensors of any SparseTensors + # to the returned function's outputs list; the list must contain these + # component tensors, or the function's sparse outputs won't work properly. pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches) - for f in sparse_tensor_fetches: - # Outputs list can only contain dense tensors, but it must contain any - # tensors that are part of an output SparseTensor. - f_lifted = _lift_sparse_tensor(f, lift_map) - pruned_graph.outputs.extend([f_lifted.indices, f_lifted.values, - f_lifted.dense_shape]) pruned_graph.control_outputs.extend( [lift_map[operation] for operation in operation_fetches]) for external_capture, internal_capture in self.graph.captures.items(): From 5060dd212ef82465e24e0de6da3735913be55100 Mon Sep 17 00:00:00 2001 From: frreiss Date: Wed, 22 May 2019 17:13:36 -0700 Subject: [PATCH 3/5] Fix pylint warning --- tensorflow/python/eager/wrap_function.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index b85526a4d32..8c1634f9c48 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -270,8 +270,8 @@ class WrappedFunction(function.ConcreteFunction): Also extracts ops from `fetches`. Args: - fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string - identifying a Tensor or Operation. + fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or + string identifying a Tensor or Operation. Returns: `fetch` converted to a Tensor. From 088e496ba8e0fedba232c1c180e9ade79dac57cb Mon Sep 17 00:00:00 2001 From: frreiss Date: Thu, 23 May 2019 17:19:13 -0700 Subject: [PATCH 4/5] Simplify logic per review comments --- tensorflow/python/eager/wrap_function.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index 8c1634f9c48..d23c812e738 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -296,7 +296,7 @@ class WrappedFunction(function.ConcreteFunction): fetches = nest.map_structure(_fetch_preprocesing_callback, fetches) - # Turn composite/sparse tensors into dense Tensors. + # Expand composite tensors into their component dense Tensors. tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True) for f in (flat_feeds + tensor_fetches + operation_fetches): @@ -311,9 +311,9 @@ class WrappedFunction(function.ConcreteFunction): pruned_graph, sources=flat_feeds + internal_captures) - # Note that we deliberately add the component tensors of any SparseTensors - # to the returned function's outputs list; the list must contain these - # component tensors, or the function's sparse outputs won't work properly. 
+ # Note that we add the component tensors of any composite tensors to the + # returned function's outputs list; the list must contain these component + # tensors, or the function's sparse outputs won't work properly. pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches) pruned_graph.control_outputs.extend( [lift_map[operation] for operation in operation_fetches]) @@ -334,16 +334,17 @@ class WrappedFunction(function.ConcreteFunction): pruned_graph.variables = self.graph.variables def _structured_output_mapping(fetched): - """`nest.map_structure()` callback.""" - if isinstance(fetched, sparse_tensor.SparseTensor): - return _lift_sparse_tensor(fetched, lift_map) + """callback for `nest.map_structure()`""" lifted = lift_map[fetched] if isinstance(lifted, ops.Operation): return None return lifted + # expand_composites=True here causes composite tensors to be expanded + # into their component dense Tensors, mapped to the new graph, and then + # reconstituted into their original composite form. pruned_graph.structured_outputs = nest.map_structure( - _structured_output_mapping, fetches) + _structured_output_mapping, fetches, expand_composites=True) pruned_graph.structured_input_signature = input_signature pruned_fn = WrappedFunction( pruned_graph, variable_holder=self._variable_holder) From 3986d0e353e446474be9383b88ec356ef8b192b9 Mon Sep 17 00:00:00 2001 From: frreiss Date: Thu, 23 May 2019 17:22:44 -0700 Subject: [PATCH 5/5] Remove dead code --- tensorflow/python/eager/wrap_function.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index d23c812e738..ef731fae4a9 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -192,24 +192,6 @@ def _lift_unlifted_variables(graph, variable_holder): mutable_collection[index] = lifted_variables.get(current, current) -def _lift_sparse_tensor(orig_sparse_tensor, lift_map): - """ - Args: - orig_sparse_tensor: SparseTensors object whose underlying dense Tensors - reside in a different graph - lift_map: Map (as returned by `lift_to_graph`) from tensors in the other - graph to tensors in the current graph. - Returns: - A new copy of `orig_sparse_tensor` whose underlying dense tensors are in - the current graph - """ - return sparse_tensor.SparseTensor( - indices=lift_map[orig_sparse_tensor.indices], - values=lift_map[orig_sparse_tensor.values], - dense_shape=lift_map[orig_sparse_tensor.dense_shape] - ) - - # TODO(allenl): make this trackable class WrappedFunction(function.ConcreteFunction): """Wraps a tf V1 piece of code in a function."""
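
A note on the final shape of the change (patches 2-5): the SparseTensor-specific helpers introduced in patch 1 (`_sparse_to_dense`, `_lift_sparse_tensor`) end up replaced by `nest`'s generic composite-tensor support via `expand_composites=True`. The standalone sketch below is not part of the patches; it illustrates that mechanism under the assumption of a TF 2.x eager environment, using the public `tf.nest`/`tf.sparse` APIs rather than the internal `nest`/`sparse_tensor` modules the patch touches, with variable names invented for the example.

import tensorflow as tf

# A fetch structure containing a SparseTensor, similar in shape to the
# `fetches` argument that WrappedFunction.prune() receives.
st = tf.sparse.SparseTensor(indices=[[0, 0]],
                            values=tf.constant([42], dtype=tf.int64),
                            dense_shape=[1, 1])
fetches = {"output": st}

# flatten(..., expand_composites=True) expands the SparseTensor into its
# dense component tensors (indices, values, dense_shape); components like
# these are what get lifted and appended to pruned_graph.outputs.
flat = tf.nest.flatten(fetches, expand_composites=True)
print(len(flat))  # 3

# map_structure(..., expand_composites=True) maps a function over the
# components and then reassembles the composite structure, mirroring how
# structured_outputs is rebuilt from the lifted tensors.
rebuilt = tf.nest.map_structure(lambda t: t, fetches, expand_composites=True)
print(isinstance(rebuilt["output"], tf.sparse.SparseTensor))  # True

Because the expansion goes through the generic composite-tensor machinery rather than SparseTensor-specific code, the same pruning path should also handle other composite types such as `RaggedTensor`, which is presumably why review steered the implementation in this direction.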