Address review comments an d fix a few pep8 linter warnings
This commit is contained in:
parent
65aacb43a6
commit
4c7cf10b6c
@ -192,19 +192,6 @@ def _lift_unlifted_variables(graph, variable_holder):
|
||||
mutable_collection[index] = lifted_variables.get(current, current)
|
||||
|
||||
|
||||
def _sparse_to_dense(sparse_tensor_list):
|
||||
"""
|
||||
Extract out and return the dense components (elements, indices, shape) of an
|
||||
iterable of `SparseTensor`s.
|
||||
"""
|
||||
ret = []
|
||||
for s in sparse_tensor_list:
|
||||
ret.append(s.indices)
|
||||
ret.append(s.values)
|
||||
ret.append(s.dense_shape)
|
||||
return ret
|
||||
|
||||
|
||||
def _lift_sparse_tensor(orig_sparse_tensor, lift_map):
|
||||
"""
|
||||
Args:
|
||||
@ -222,6 +209,7 @@ def _lift_sparse_tensor(orig_sparse_tensor, lift_map):
|
||||
dense_shape=lift_map[orig_sparse_tensor.dense_shape]
|
||||
)
|
||||
|
||||
|
||||
# TODO(allenl): make this trackable
|
||||
class WrappedFunction(function.ConcreteFunction):
|
||||
"""Wraps a tf V1 piece of code in a function."""
|
||||
@ -273,46 +261,45 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
|
||||
operation_fetches = []
|
||||
tensor_fetches = []
|
||||
sparse_tensor_fetches = []
|
||||
tensor_infos = []
|
||||
|
||||
def _fetch_preprocesing_callback(f):
|
||||
def _fetch_preprocesing_callback(fetch):
|
||||
"""Extract out lists of ops, tensors, and tensor type info.
|
||||
|
||||
Turns TensorInfos into Tensors in the original `fetches` structure.
|
||||
Also extracts sparse tensors and ops from `fetches`.
|
||||
Also extracts ops from `fetches`.
|
||||
|
||||
Args:
|
||||
f: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string
|
||||
fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string
|
||||
identifying a Tensor or Operation.
|
||||
|
||||
Returns:
|
||||
`f` converted to a Tensor.
|
||||
`fetch` converted to a Tensor.
|
||||
"""
|
||||
if isinstance(f, ops.Operation):
|
||||
operation_fetches.append(f)
|
||||
return f
|
||||
elif isinstance(f, meta_graph_pb2.TensorInfo):
|
||||
tensor_infos.append(f)
|
||||
decoded = _get_element_from_tensor_info(f, self._func_graph)
|
||||
if isinstance(decoded, sparse_tensor.SparseTensor):
|
||||
sparse_tensor_fetches.append(decoded)
|
||||
elif tensor_util.is_tensor(decoded):
|
||||
if isinstance(fetch, ops.Operation):
|
||||
operation_fetches.append(fetch)
|
||||
return fetch
|
||||
elif isinstance(fetch, meta_graph_pb2.TensorInfo):
|
||||
tensor_infos.append(fetch)
|
||||
decoded = _get_element_from_tensor_info(fetch, self._func_graph)
|
||||
if tensor_util.is_tensor(decoded):
|
||||
tensor_fetches.append(decoded)
|
||||
else:
|
||||
operation_fetches.append(decoded)
|
||||
return decoded
|
||||
elif isinstance(f, ops.Tensor):
|
||||
tensor_fetches.append(f)
|
||||
return f
|
||||
elif isinstance(fetch, ops.Tensor):
|
||||
tensor_fetches.append(fetch)
|
||||
return fetch
|
||||
else:
|
||||
graph_element = self.graph.as_graph_element(f)
|
||||
graph_element = self.graph.as_graph_element(fetch)
|
||||
return _fetch_preprocesing_callback(graph_element)
|
||||
|
||||
fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
|
||||
|
||||
for f in flat_feeds + tensor_fetches + operation_fetches \
|
||||
+ sparse_tensor_fetches:
|
||||
# Turn composite/sparse tensors into dense Tensors.
|
||||
tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)
|
||||
|
||||
for f in (flat_feeds + tensor_fetches + operation_fetches):
|
||||
if f.graph is not self._func_graph:
|
||||
raise ValueError("Can only prune function whose feeds and fetches "
|
||||
"are from this graph (%s). Input %s is from graph %s" %
|
||||
@ -320,17 +307,14 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
with self._func_graph.as_default():
|
||||
pruned_graph = func_graph.FuncGraph(name)
|
||||
lift_map = lift_to_graph.lift_to_graph(
|
||||
operation_fetches + tensor_fetches
|
||||
+ _sparse_to_dense(sparse_tensor_fetches),
|
||||
operation_fetches + tensor_fetches,
|
||||
pruned_graph,
|
||||
sources=flat_feeds + internal_captures)
|
||||
|
||||
# Note that we deliberately add the component tensors of any SparseTensors
|
||||
# to the returned function's outputs list; the list must contain these
|
||||
# component tensors, or the function's sparse outputs won't work properly.
|
||||
pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
|
||||
for f in sparse_tensor_fetches:
|
||||
# Outputs list can only contain dense tensors, but it must contain any
|
||||
# tensors that are part of an output SparseTensor.
|
||||
f_lifted = _lift_sparse_tensor(f, lift_map)
|
||||
pruned_graph.outputs.extend([f_lifted.indices, f_lifted.values,
|
||||
f_lifted.dense_shape])
|
||||
pruned_graph.control_outputs.extend(
|
||||
[lift_map[operation] for operation in operation_fetches])
|
||||
for external_capture, internal_capture in self.graph.captures.items():
|
||||
|
Loading…
Reference in New Issue
Block a user