diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py
index eba251aaa16..b85526a4d32 100644
--- a/tensorflow/python/eager/wrap_function.py
+++ b/tensorflow/python/eager/wrap_function.py
@@ -192,28 +192,15 @@ def _lift_unlifted_variables(graph, variable_holder):
         mutable_collection[index] = lifted_variables.get(current, current)
 
 
-def _sparse_to_dense(sparse_tensor_list):
-  """
-  Extract out and return the dense components (elements, indices, shape) of an
-  iterable of `SparseTensor`s.
-  """
-  ret = []
-  for s in sparse_tensor_list:
-    ret.append(s.indices)
-    ret.append(s.values)
-    ret.append(s.dense_shape)
-  return ret
-
-
 def _lift_sparse_tensor(orig_sparse_tensor, lift_map):
   """
   Args:
-    orig_sparse_tensor: SparseTensors object whose underlying dense Tensors 
+    orig_sparse_tensor: SparseTensors object whose underlying dense Tensors
      reside in a different graph
    lift_map: Map (as returned by `lift_to_graph`) from tensors in the other
      graph to tensors in the current graph.
  Returns:
-    A new copy of `orig_sparse_tensor` whose underlying dense tensors are in 
+    A new copy of `orig_sparse_tensor` whose underlying dense tensors are in
    the current graph
  """
  return sparse_tensor.SparseTensor(
@@ -221,7 +208,8 @@ def _lift_sparse_tensor(orig_sparse_tensor, lift_map):
       values=lift_map[orig_sparse_tensor.values],
       dense_shape=lift_map[orig_sparse_tensor.dense_shape]
   )
-
+
+
 # TODO(allenl): make this trackable
 class WrappedFunction(function.ConcreteFunction):
   """Wraps a tf V1 piece of code in a function."""
@@ -273,46 +261,45 @@ class WrappedFunction(function.ConcreteFunction):
 
     operation_fetches = []
     tensor_fetches = []
-    sparse_tensor_fetches = []
     tensor_infos = []
 
-    def _fetch_preprocesing_callback(f):
+    def _fetch_preprocesing_callback(fetch):
       """Extract out lists of ops, tensors, and tensor type info.
 
       Turns TensorInfos into Tensors in the original `fetches` structure.
-      Also extracts sparse tensors and ops from `fetches`.
+      Also extracts ops from `fetches`.
 
       Args:
-        f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string
+        fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string
           identifying a Tensor or Operation.
 
       Returns:
-        `f` converted to a Tensor.
+        `fetch` converted to a Tensor.
       """
-      if isinstance(f, ops.Operation):
-        operation_fetches.append(f)
-        return f
-      elif isinstance(f, meta_graph_pb2.TensorInfo):
-        tensor_infos.append(f)
-        decoded = _get_element_from_tensor_info(f, self._func_graph)
-        if isinstance(decoded, sparse_tensor.SparseTensor):
-          sparse_tensor_fetches.append(decoded)
-        elif tensor_util.is_tensor(decoded):
+      if isinstance(fetch, ops.Operation):
+        operation_fetches.append(fetch)
+        return fetch
+      elif isinstance(fetch, meta_graph_pb2.TensorInfo):
+        tensor_infos.append(fetch)
+        decoded = _get_element_from_tensor_info(fetch, self._func_graph)
+        if tensor_util.is_tensor(decoded):
          tensor_fetches.append(decoded)
        else:
          operation_fetches.append(decoded)
        return decoded
-      elif isinstance(f, ops.Tensor):
-        tensor_fetches.append(f)
-        return f
+      elif isinstance(fetch, ops.Tensor):
+        tensor_fetches.append(fetch)
+        return fetch
      else:
-        graph_element = self.graph.as_graph_element(f)
+        graph_element = self.graph.as_graph_element(fetch)
        return _fetch_preprocesing_callback(graph_element)
 
    fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
 
-    for f in flat_feeds + tensor_fetches + operation_fetches \
-        + sparse_tensor_fetches:
+    # Turn composite/sparse tensors into dense Tensors.
+    tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)
+
+    for f in (flat_feeds + tensor_fetches + operation_fetches):
      if f.graph is not self._func_graph:
        raise ValueError("Can only prune function whose feeds and fetches "
                         "are from this graph (%s). Input %s is from graph %s" %
@@ -320,17 +307,14 @@ class WrappedFunction(function.ConcreteFunction):
     with self._func_graph.as_default():
       pruned_graph = func_graph.FuncGraph(name)
     lift_map = lift_to_graph.lift_to_graph(
-        operation_fetches + tensor_fetches
-        + _sparse_to_dense(sparse_tensor_fetches),
+        operation_fetches + tensor_fetches,
         pruned_graph,
         sources=flat_feeds + internal_captures)
+
+    # Note that we deliberately add the component tensors of any SparseTensors
+    # to the returned function's outputs list; the list must contain these
+    # component tensors, or the function's sparse outputs won't work properly.
     pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
-    for f in sparse_tensor_fetches:
-      # Outputs list can only contain dense tensors, but it must contain any
-      # tensors that are part of an output SparseTensor.
-      f_lifted = _lift_sparse_tensor(f, lift_map)
-      pruned_graph.outputs.extend([f_lifted.indices, f_lifted.values,
-                                   f_lifted.dense_shape])
     pruned_graph.control_outputs.extend(
         [lift_map[operation] for operation in operation_fetches])
     for external_capture, internal_capture in self.graph.captures.items():
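A small standalone sketch (not part of the patch) of the mechanism the new code relies on: nest.flatten with expand_composites=True decomposes a SparseTensor fetch into its dense component tensors (indices, values, dense_shape), which is what the pruned graph's outputs list needs now that _sparse_to_dense and the per-fetch lifting loop are gone. The tensor values below are arbitrary and only for illustration.

    import tensorflow as tf
    from tensorflow.python.util import nest

    # Build a small SparseTensor; indices/values/shape are arbitrary examples.
    sp = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]],
                                values=[1.0, 2.0],
                                dense_shape=[3, 4])

    # expand_composites=True flattens the composite tensor into its dense
    # components, in the order (indices, values, dense_shape).
    components = nest.flatten([sp], expand_composites=True)
    print([c.dtype for c in components])  # [tf.int64, tf.float32, tf.int64]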