diff --git a/tensorflow/contrib/graph_editor/reroute.py b/tensorflow/contrib/graph_editor/reroute.py index 4c5f281badd..c14bcac3be0 100644 --- a/tensorflow/contrib/graph_editor/reroute.py +++ b/tensorflow/contrib/graph_editor/reroute.py @@ -18,11 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.graph_editor import subgraph -from tensorflow.contrib.graph_editor import util -from tensorflow.python.framework import ops as tf_ops +from tensorflow.contrib.graph_editor import subgraph as _subgraph +from tensorflow.contrib.graph_editor import util as _util +from tensorflow.python.framework import ops as _tf_ops -__all__ = [ +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ "swap_ts", "reroute_ts", "swap_inputs", @@ -46,8 +48,8 @@ def _check_ts_compatibility(ts0, ts1): ValueError: if any pair of tensors (same index in ts0 and ts1) have a dtype or a shape which is not compatible. """ - ts0 = util.make_list_of_t(ts0) - ts1 = util.make_list_of_t(ts1) + ts0 = _util.make_list_of_t(ts0) + ts1 = _util.make_list_of_t(ts1) if len(ts0) != len(ts1): raise ValueError("ts0 and ts1 have different sizes: {} != {}".format( len(ts0), len(ts1))) @@ -176,13 +178,13 @@ def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None): converted to a list of `tf.Operation`. """ a2b, b2a = _RerouteMode.check(mode) - ts0 = util.make_list_of_t(ts0) - ts1 = util.make_list_of_t(ts1) + ts0 = _util.make_list_of_t(ts0) + ts1 = _util.make_list_of_t(ts1) _check_ts_compatibility(ts0, ts1) if cannot_modify is not None: - cannot_modify = frozenset(util.make_list_of_op(cannot_modify)) + cannot_modify = frozenset(_util.make_list_of_op(cannot_modify)) if can_modify is not None: - can_modify = frozenset(util.make_list_of_op(can_modify)) + can_modify = frozenset(_util.make_list_of_op(can_modify)) nb_update_inputs = 0 precomputed_consumers = [] # precompute consumers to avoid issue with repeated tensors: @@ -268,11 +270,11 @@ def _reroute_sgv_remap(sgv0, sgv1, mode): ValueError: if sgv0 and sgv1 do not belong to the same graph. """ a2b, b2a = _RerouteMode.check(mode) - if not isinstance(sgv0, subgraph.SubGraphView): + if not isinstance(sgv0, _subgraph.SubGraphView): raise TypeError("Expected a SubGraphView, got {}".format(type(sgv0))) - if not isinstance(sgv1, subgraph.SubGraphView): + if not isinstance(sgv1, _subgraph.SubGraphView): raise TypeError("Expected a SubGraphView, got {}".format(type(sgv1))) - util.check_graphs(sgv0, sgv1) + _util.check_graphs(sgv0, sgv1) sgv0_ = sgv0.copy() sgv1_ = sgv1.copy() # pylint: disable=protected-access @@ -327,13 +329,13 @@ def _reroute_sgv_inputs(sgv0, sgv1, mode): StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using the same rules than the function subgraph.make_view. """ - sgv0 = subgraph.make_view(sgv0) - sgv1 = subgraph.make_view(sgv1) - util.check_graphs(sgv0, sgv1) + sgv0 = _subgraph.make_view(sgv0) + sgv1 = _subgraph.make_view(sgv1) + _util.check_graphs(sgv0, sgv1) can_modify = sgv0.ops + sgv1.ops # also allow consumers of passthrough to be modified: - can_modify += util.get_consuming_ops(sgv0.passthroughs) - can_modify += util.get_consuming_ops(sgv1.passthroughs) + can_modify += _util.get_consuming_ops(sgv0.passthroughs) + can_modify += _util.get_consuming_ops(sgv1.passthroughs) _reroute_ts(sgv0.inputs, sgv1.inputs, mode, can_modify=can_modify) _reroute_sgv_remap(sgv0, sgv1, mode) return sgv0, sgv1 @@ -357,9 +359,9 @@ def _reroute_sgv_outputs(sgv0, sgv1, mode): StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using the same rules than the function subgraph.make_view. """ - sgv0 = subgraph.make_view(sgv0) - sgv1 = subgraph.make_view(sgv1) - util.check_graphs(sgv0, sgv1) + sgv0 = _subgraph.make_view(sgv0) + sgv1 = _subgraph.make_view(sgv1) + _util.check_graphs(sgv0, sgv1) cannot_modify = sgv0.ops + sgv1.ops _reroute_ts(sgv0.outputs, sgv1.outputs, mode, cannot_modify=cannot_modify) return sgv0, sgv1 @@ -432,9 +434,9 @@ def remove_control_inputs(op, cops): TypeError: if op is not a `tf.Operation`. ValueError: if any cop in cops is not a control input of op. """ - if not isinstance(op, tf_ops.Operation): + if not isinstance(op, _tf_ops.Operation): raise TypeError("Expected a tf.Operation, got: {}", type(op)) - cops = util.make_list_of_op(cops, allow_graph=False) + cops = _util.make_list_of_op(cops, allow_graph=False) for cop in cops: if cop not in op.control_inputs: raise ValueError("{} is not a control_input of {}".format(op.name, @@ -457,9 +459,9 @@ def add_control_inputs(op, cops): TypeError: if op is not a tf.Operation ValueError: if any cop in cops is already a control input of op. """ - if not isinstance(op, tf_ops.Operation): + if not isinstance(op, _tf_ops.Operation): raise TypeError("Expected a tf.Operation, got: {}", type(op)) - cops = util.make_list_of_op(cops, allow_graph=False) + cops = _util.make_list_of_op(cops, allow_graph=False) for cop in cops: if cop in op.control_inputs: raise ValueError("{} is already a control_input of {}".format(op.name, @@ -468,3 +470,5 @@ def add_control_inputs(op, cops): op._control_inputs += cops op._recompute_node_def() # pylint: enable=protected-access + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/tools/docs/generate.py b/tensorflow/tools/docs/generate.py index 0a41fa473da..7107a285b5b 100644 --- a/tensorflow/tools/docs/generate.py +++ b/tensorflow/tools/docs/generate.py @@ -157,7 +157,6 @@ def extract(): 'contrib.graph_editor': [ 'edit', 'match', - 'reroute', 'subgraph', 'transform', 'select',