Fix transform for cyclic graph.

Improve collection name handling.
Added helper to retrieve corresponding tensor/op.
Change: 144824657
This commit is contained in:
A. Unique TensorFlower 2017-01-18 05:53:24 -08:00 committed by TensorFlower Gardener
parent 1c7ef3db08
commit e28da73a92
2 changed files with 105 additions and 5 deletions
tensorflow/contrib/graph_editor

View File

@ -26,13 +26,13 @@ from six import iteritems
from six import iterkeys
from six import string_types
from six import StringIO
from tensorflow.contrib.graph_editor import edit
from tensorflow.contrib.graph_editor import reroute
from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import subgraph
from tensorflow.contrib.graph_editor import util
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.platform import tf_logging as logging
__all__ = [
"replace_t_with_placeholder_handler",
@ -87,17 +87,24 @@ def keep_t_if_possible_handler(info, t):
def assign_renamed_collections_handler(info, elem, elem_):
"""Add the transformed elem to the (renamed) collections of elem.
A collection is renamed only if is not a known key, as described in
`tf.GraphKeys`.
Args:
info: Transform._Info instance.
elem: the original element (`tf.Tensor` or `tf.Operation`)
elem_: the transformed element
"""
# TODO(fkp): handle known special cases
known_collection_names = util.get_predefined_collection_names()
for name, collection in iteritems(info.collections):
if elem not in collection:
continue
collection_name_ = info.transformer.new_name(name)
info.graph_.add_to_collection(collection_name_, elem_)
if name in known_collection_names:
transformed_name = name
else:
transformed_name = info.transformer.new_name(name)
info.graph_.add_to_collection(transformed_name, elem_)
def transform_op_if_inside_handler(info, op, keep_if_possible=True):
@ -150,6 +157,11 @@ def copy_op_handler(info, op, copy_shape=True):
# Transform inputs:
inputs_ = [info.transformer._transform_t(t) for t in op.inputs]
# Leave inputs empty if a graph cycle was found.
if None in inputs_:
info.cyclic_ops.append(op)
inputs_ = []
# Clone the node def:
node_def_ = deepcopy(op._node_def)
@ -239,7 +251,7 @@ class Transformer(object):
self.transformed_ts = {}
self.collections = dict((key, self.graph.get_collection(key))
for key in self.graph.get_all_collection_keys())
self.cyclic_ops = []
class ResultInfo(object):
""""Contains information about the result of a transform operation."""
@ -452,6 +464,17 @@ class Transformer(object):
for op in remaining_roots:
self._transform_op(op)
# Finalize cyclic ops:
for op in self._info.cyclic_ops:
logging.debug("Finalizing cyclic op: %s", op.name)
op_ = self._info.transformed_ops[op]
inputs_ = [self._info.transformed_ts[t] for t in op.inputs]
if None in inputs_:
raise ValueError("Could not find all the inputs of cyclic op: {}"
.format(op_.name))
for input_id, t_ in enumerate(inputs_):
op_._update_input(input_id, t_) # pylint: disable=protected-access
sgv_ = self._transform_sgv(sgv)
res_info = Transformer.ResultInfo(self._info)
@ -506,9 +529,13 @@ class Transformer(object):
Returns:
The transformed tensor.
"""
logging.debug("Transforming tensor: %s", t.name)
if t in self._info.transformed_ts:
return self._info.transformed_ts[t]
# Mark as None to detect cycle.
self._info.transformed_ts[t] = None
op, op_index = t.op, t.value_index
# If op is not in the subgraph:

View File

@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import collections
import re
from six import iteritems
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import array_ops as tf_array_ops
@ -465,3 +466,75 @@ def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
"""
return tf_array_ops.placeholder(
dtype=dtype, shape=shape, name=placeholder_name(scope=scope))
_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$")
def get_predefined_collection_names():
"""Return all the predefined collection names."""
return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys)
if not _INTERNAL_VARIABLE_RE.match(key)]
def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""):
"""Find corresponding op/tensor in a different graph.
Args:
target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph.
dst_graph: The graph in which the corresponding graph element must be found.
dst_scope: A scope which is prepended to the name to look for.
src_scope: A scope which is removed from the original of `target` name.
Returns:
The corresponding tf.Tensor` or a `tf.Operation`.
Raises:
ValueError: if `src_name` does not start with `src_scope`.
TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation`
KeyError: If the corresponding graph element cannot be found.
"""
src_name = target.name
if src_scope:
src_scope = scope_finalize(src_scope)
if not src_name.startswidth(src_scope):
raise ValueError("{} does not start with {}".format(src_name, src_scope))
src_name = src_name[len(src_scope):]
dst_name = src_name
if dst_scope:
dst_scope = scope_finalize(dst_scope)
dst_name = dst_scope + dst_name
if isinstance(target, tf_ops.Tensor):
return dst_graph.get_tensor_by_name(dst_name)
if isinstance(target, tf_ops.Operation):
return dst_graph.get_operation_by_name(dst_name)
raise TypeError("Expected tf.Tensor or tf.Operation, got: {}", type(target))
def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""):
"""Find corresponding ops/tensors in a different graph.
`targets` is a Python tree, that is, a nested structure of iterable
(list, tupple, dictionary) whose leaves are instances of
`tf.Tensor` or `tf.Operation`
Args:
targets: A Python tree containing `tf.Tensor` or `tf.Operation`
belonging to the original graph.
dst_graph: The graph in which the corresponding graph element must be found.
dst_scope: A scope which is prepended to the name to look for.
src_scope: A scope which is removed from the original of `top` name.
Returns:
A Python tree containin the corresponding tf.Tensor` or a `tf.Operation`.
Raises:
ValueError: if `src_name` does not start with `src_scope`.
TypeError: if `top` is not a `tf.Tensor` or a `tf.Operation`
KeyError: If the corresponding graph element cannot be found.
"""
def func(top):
return find_corresponding_elem(top, dst_graph, dst_scope, src_scope)
return transform_tree(targets, func)