Fix transform for cyclic graphs.
Improve collection name handling. Add a helper to retrieve the corresponding tensor/op. Change: 144824657
This commit is contained in:
parent
1c7ef3db08
commit
e28da73a92
tensorflow/contrib/graph_editor
@@ -26,13 +26,13 @@ from six import iteritems
from six import iterkeys
from six import string_types
from six import StringIO

from tensorflow.contrib.graph_editor import edit
from tensorflow.contrib.graph_editor import reroute
from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import subgraph
from tensorflow.contrib.graph_editor import util
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.platform import tf_logging as logging

__all__ = [
    "replace_t_with_placeholder_handler",
@@ -87,17 +87,24 @@ def keep_t_if_possible_handler(info, t):
def assign_renamed_collections_handler(info, elem, elem_):
  """Add the transformed elem to the (renamed) collections of elem.

  A collection is renamed only if it is not a known key, as described in
  `tf.GraphKeys`.

  Args:
    info: Transform._Info instance.
    elem: the original element (`tf.Tensor` or `tf.Operation`).
    elem_: the transformed element.
  """
  # TODO(fkp): handle known special cases
  known_collection_names = util.get_predefined_collection_names()
  for name, collection in iteritems(info.collections):
    if elem not in collection:
      continue
    collection_name_ = info.transformer.new_name(name)
    info.graph_.add_to_collection(collection_name_, elem_)

    if name in known_collection_names:
      transformed_name = name
    else:
      transformed_name = info.transformer.new_name(name)
    info.graph_.add_to_collection(transformed_name, elem_)


def transform_op_if_inside_handler(info, op, keep_if_possible=True):
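The handler above now preserves predefined collection names: an element is added to the transformed graph under the original collection name when that name is one of the `tf.GraphKeys` keys, and under a renamed (prefixed) name otherwise. A minimal standalone sketch of that rule, using hypothetical stand-in names instead of the real `tf.GraphKeys` and transformer (an illustration, not part of this diff):

import re

_INTERNAL_RE = re.compile(r"^__\w+__$")


class _FakeGraphKeys(object):
  # Stand-in for tf.GraphKeys: public attributes hold collection names.
  GLOBAL_VARIABLES = "variables"
  TRAINABLE_VARIABLES = "trainable_variables"


def predefined_names(graph_keys=_FakeGraphKeys):
  # Mirrors get_predefined_collection_names(): keep every public attribute.
  return [getattr(graph_keys, key) for key in dir(graph_keys)
          if not _INTERNAL_RE.match(key)]


def transformed_collection_name(name, prefix="copy/"):
  # Known keys keep their name; custom collections are renamed.
  return name if name in predefined_names() else prefix + name


assert transformed_collection_name("trainable_variables") == "trainable_variables"
assert transformed_collection_name("my_summaries") == "copy/my_summaries"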
@@ -150,6 +157,11 @@ def copy_op_handler(info, op, copy_shape=True):
  # Transform inputs:
  inputs_ = [info.transformer._transform_t(t) for t in op.inputs]

  # Leave inputs empty if a graph cycle was found.
  if None in inputs_:
    info.cyclic_ops.append(op)
    inputs_ = []

  # Clone the node def:
  node_def_ = deepcopy(op._node_def)
@@ -239,7 +251,7 @@ class Transformer(object):
      self.transformed_ts = {}
      self.collections = dict((key, self.graph.get_collection(key))
                              for key in self.graph.get_all_collection_keys())

      self.cyclic_ops = []

  class ResultInfo(object):
    """Contains information about the result of a transform operation."""
@@ -452,6 +464,17 @@ class Transformer(object):
    for op in remaining_roots:
      self._transform_op(op)

    # Finalize cyclic ops:
    for op in self._info.cyclic_ops:
      logging.debug("Finalizing cyclic op: %s", op.name)
      op_ = self._info.transformed_ops[op]
      inputs_ = [self._info.transformed_ts[t] for t in op.inputs]
      if None in inputs_:
        raise ValueError("Could not find all the inputs of cyclic op: {}"
                         .format(op_.name))
      for input_id, t_ in enumerate(inputs_):
        op_._update_input(input_id, t_)  # pylint: disable=protected-access

    sgv_ = self._transform_sgv(sgv)

    res_info = Transformer.ResultInfo(self._info)
@@ -506,9 +529,13 @@ class Transformer(object):
    Returns:
      The transformed tensor.
    """
    logging.debug("Transforming tensor: %s", t.name)
    if t in self._info.transformed_ts:
      return self._info.transformed_ts[t]

    # Mark as None to detect cycle.
    self._info.transformed_ts[t] = None

    op, op_index = t.op, t.value_index

    # If op is not in the subgraph:
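Taken together, the changes above implement a two-pass copy: `_transform_t` stores a `None` sentinel before recursing, `copy_op_handler` records any op whose transformed inputs hit that sentinel in `cyclic_ops` and copies it with no inputs, and the transform loop later back-patches those inputs once every op has been copied. A rough sketch of the same pattern on a toy node graph (a hypothetical `Node` class and helper names, not the graph_editor API):

class Node(object):
  def __init__(self, name, inputs=None):
    self.name = name
    self.inputs = list(inputs or [])


def copy_graph(roots):
  copied = {}  # original node -> copy (None while the copy is in flight)
  cyclic = []

  def copy_node(node):
    if node in copied:
      return copied[node]  # None here means we ran into a cycle
    copied[node] = None  # sentinel, mirrors transformed_ts[t] = None
    inputs_ = [copy_node(i) for i in node.inputs]
    if None in inputs_:
      inputs_ = []  # defer the wiring, mirrors info.cyclic_ops
      cyclic.append(node)
    node_ = Node(node.name, inputs_)
    copied[node] = node_
    return node_

  for root in roots:
    copy_node(root)
  # Second pass: back-patch the inputs of the cyclic copies.
  for node in cyclic:
    copied[node].inputs = [copied[i] for i in node.inputs]
  return copied


# Tiny cyclic example: a and b feed each other.
a = Node("a")
b = Node("b", [a])
a.inputs.append(b)
copies = copy_graph([a, b])
assert copies[a].inputs == [copies[b]]
assert copies[b].inputs == [copies[a]]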
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function

import collections
import re
from six import iteritems
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import array_ops as tf_array_ops
@@ -465,3 +466,75 @@ def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
  """
  return tf_array_ops.placeholder(
      dtype=dtype, shape=shape, name=placeholder_name(scope=scope))


_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$")


def get_predefined_collection_names():
  """Return all the predefined collection names."""
  return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys)
          if not _INTERNAL_VARIABLE_RE.match(key)]


def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""):
  """Find corresponding op/tensor in a different graph.

  Args:
    target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph.
    dst_graph: The graph in which the corresponding graph element must be found.
    dst_scope: A scope which is prepended to the name to look for.
    src_scope: A scope which is removed from the original name of `target`.

  Returns:
    The corresponding `tf.Tensor` or `tf.Operation`.

  Raises:
    ValueError: if `src_name` does not start with `src_scope`.
    TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation`.
    KeyError: If the corresponding graph element cannot be found.
  """
  src_name = target.name
  if src_scope:
    src_scope = scope_finalize(src_scope)
    if not src_name.startswith(src_scope):
      raise ValueError("{} does not start with {}".format(src_name, src_scope))
    src_name = src_name[len(src_scope):]

  dst_name = src_name
  if dst_scope:
    dst_scope = scope_finalize(dst_scope)
    dst_name = dst_scope + dst_name

  if isinstance(target, tf_ops.Tensor):
    return dst_graph.get_tensor_by_name(dst_name)
  if isinstance(target, tf_ops.Operation):
    return dst_graph.get_operation_by_name(dst_name)
  raise TypeError("Expected tf.Tensor or tf.Operation, got: {}".format(
      type(target)))


def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""):
  """Find corresponding ops/tensors in a different graph.

  `targets` is a Python tree, that is, a nested structure of iterable
  (list, tuple, dictionary) whose leaves are instances of
  `tf.Tensor` or `tf.Operation`.

  Args:
    targets: A Python tree containing `tf.Tensor` or `tf.Operation`
      belonging to the original graph.
    dst_graph: The graph in which the corresponding graph element must be found.
    dst_scope: A scope which is prepended to the name to look for.
    src_scope: A scope which is removed from the original name of `top`.

  Returns:
    A Python tree containing the corresponding `tf.Tensor` or `tf.Operation`.

  Raises:
    ValueError: if `src_name` does not start with `src_scope`.
    TypeError: if `top` is not a `tf.Tensor` or a `tf.Operation`.
    KeyError: If the corresponding graph element cannot be found.
  """
  def func(top):
    return find_corresponding_elem(top, dst_graph, dst_scope, src_scope)
  return transform_tree(targets, func)
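A rough usage sketch for the new helpers, assuming a TF 1.x environment where this module is importable (the import alias, graph, and names below are illustrative only):

import tensorflow as tf
from tensorflow.contrib.graph_editor import util as ge_util

src_graph = tf.Graph()
with src_graph.as_default():
  x = tf.constant(1.0, name="x")
  y = tf.add(x, x, name="y")

# Re-create the ops in another graph under the "copy" name scope.
dst_graph = tf.Graph()
with dst_graph.as_default():
  tf.import_graph_def(src_graph.as_graph_def(), name="copy")

# Retrieve the tensor in dst_graph that corresponds to `y` in src_graph.
y_ = ge_util.find_corresponding_elem(y, dst_graph, dst_scope="copy")
print(y_.name)  # -> "copy/y:0"

# The tree variant maps a whole structure of tensors/ops at once.
xs_ = ge_util.find_corresponding([x, y], dst_graph, dst_scope="copy")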