Address review comments an d fix a few pep8 linter warnings

This commit is contained in:
frreiss 2019-05-22 16:34:28 -07:00
parent 65aacb43a6
commit 4c7cf10b6c

View File

@ -192,28 +192,15 @@ def _lift_unlifted_variables(graph, variable_holder):
mutable_collection[index] = lifted_variables.get(current, current) mutable_collection[index] = lifted_variables.get(current, current)
def _sparse_to_dense(sparse_tensor_list):
"""
Extract out and return the dense components (elements, indices, shape) of an
iterable of `SparseTensor`s.
"""
ret = []
for s in sparse_tensor_list:
ret.append(s.indices)
ret.append(s.values)
ret.append(s.dense_shape)
return ret
def _lift_sparse_tensor(orig_sparse_tensor, lift_map): def _lift_sparse_tensor(orig_sparse_tensor, lift_map):
""" """
Args: Args:
orig_sparse_tensor: SparseTensors object whose underlying dense Tensors orig_sparse_tensor: SparseTensors object whose underlying dense Tensors
reside in a different graph reside in a different graph
lift_map: Map (as returned by `lift_to_graph`) from tensors in the other lift_map: Map (as returned by `lift_to_graph`) from tensors in the other
graph to tensors in the current graph. graph to tensors in the current graph.
Returns: Returns:
A new copy of `orig_sparse_tensor` whose underlying dense tensors are in A new copy of `orig_sparse_tensor` whose underlying dense tensors are in
the current graph the current graph
""" """
return sparse_tensor.SparseTensor( return sparse_tensor.SparseTensor(
@ -221,7 +208,8 @@ def _lift_sparse_tensor(orig_sparse_tensor, lift_map):
values=lift_map[orig_sparse_tensor.values], values=lift_map[orig_sparse_tensor.values],
dense_shape=lift_map[orig_sparse_tensor.dense_shape] dense_shape=lift_map[orig_sparse_tensor.dense_shape]
) )
# TODO(allenl): make this trackable # TODO(allenl): make this trackable
class WrappedFunction(function.ConcreteFunction): class WrappedFunction(function.ConcreteFunction):
"""Wraps a tf V1 piece of code in a function.""" """Wraps a tf V1 piece of code in a function."""
@ -273,46 +261,45 @@ class WrappedFunction(function.ConcreteFunction):
operation_fetches = [] operation_fetches = []
tensor_fetches = [] tensor_fetches = []
sparse_tensor_fetches = []
tensor_infos = [] tensor_infos = []
def _fetch_preprocesing_callback(f): def _fetch_preprocesing_callback(fetch):
"""Extract out lists of ops, tensors, and tensor type info. """Extract out lists of ops, tensors, and tensor type info.
Turns TensorInfos into Tensors in the original `fetches` structure. Turns TensorInfos into Tensors in the original `fetches` structure.
Also extracts sparse tensors and ops from `fetches`. Also extracts ops from `fetches`.
Args: Args:
f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string
identifying a Tensor or Operation. identifying a Tensor or Operation.
Returns: Returns:
`f` converted to a Tensor. `fetch` converted to a Tensor.
""" """
if isinstance(f, ops.Operation): if isinstance(fetch, ops.Operation):
operation_fetches.append(f) operation_fetches.append(fetch)
return f return fetch
elif isinstance(f, meta_graph_pb2.TensorInfo): elif isinstance(fetch, meta_graph_pb2.TensorInfo):
tensor_infos.append(f) tensor_infos.append(fetch)
decoded = _get_element_from_tensor_info(f, self._func_graph) decoded = _get_element_from_tensor_info(fetch, self._func_graph)
if isinstance(decoded, sparse_tensor.SparseTensor): if tensor_util.is_tensor(decoded):
sparse_tensor_fetches.append(decoded)
elif tensor_util.is_tensor(decoded):
tensor_fetches.append(decoded) tensor_fetches.append(decoded)
else: else:
operation_fetches.append(decoded) operation_fetches.append(decoded)
return decoded return decoded
elif isinstance(f, ops.Tensor): elif isinstance(fetch, ops.Tensor):
tensor_fetches.append(f) tensor_fetches.append(fetch)
return f return fetch
else: else:
graph_element = self.graph.as_graph_element(f) graph_element = self.graph.as_graph_element(fetch)
return _fetch_preprocesing_callback(graph_element) return _fetch_preprocesing_callback(graph_element)
fetches = nest.map_structure(_fetch_preprocesing_callback, fetches) fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
for f in flat_feeds + tensor_fetches + operation_fetches \ # Turn composite/sparse tensors into dense Tensors.
+ sparse_tensor_fetches: tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)
for f in (flat_feeds + tensor_fetches + operation_fetches):
if f.graph is not self._func_graph: if f.graph is not self._func_graph:
raise ValueError("Can only prune function whose feeds and fetches " raise ValueError("Can only prune function whose feeds and fetches "
"are from this graph (%s). Input %s is from graph %s" % "are from this graph (%s). Input %s is from graph %s" %
@ -320,17 +307,14 @@ class WrappedFunction(function.ConcreteFunction):
with self._func_graph.as_default(): with self._func_graph.as_default():
pruned_graph = func_graph.FuncGraph(name) pruned_graph = func_graph.FuncGraph(name)
lift_map = lift_to_graph.lift_to_graph( lift_map = lift_to_graph.lift_to_graph(
operation_fetches + tensor_fetches operation_fetches + tensor_fetches,
+ _sparse_to_dense(sparse_tensor_fetches),
pruned_graph, pruned_graph,
sources=flat_feeds + internal_captures) sources=flat_feeds + internal_captures)
# Note that we deliberately add the component tensors of any SparseTensors
# to the returned function's outputs list; the list must contain these
# component tensors, or the function's sparse outputs won't work properly.
pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches) pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
for f in sparse_tensor_fetches:
# Outputs list can only contain dense tensors, but it must contain any
# tensors that are part of an output SparseTensor.
f_lifted = _lift_sparse_tensor(f, lift_map)
pruned_graph.outputs.extend([f_lifted.indices, f_lifted.values,
f_lifted.dense_shape])
pruned_graph.control_outputs.extend( pruned_graph.control_outputs.extend(
[lift_map[operation] for operation in operation_fetches]) [lift_map[operation] for operation in operation_fetches])
for external_capture, internal_capture in self.graph.captures.items(): for external_capture, internal_capture in self.graph.captures.items():