Fix documentation and guide for graph_editor
-Seal and expose reroute -Remove unavailable aliases from guide Change: 147429891
This commit is contained in:
parent
4a215b750b
commit
d96c3b7d40
tensorflow
@ -18,11 +18,13 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib.graph_editor import subgraph
|
from tensorflow.contrib.graph_editor import subgraph as _subgraph
|
||||||
from tensorflow.contrib.graph_editor import util
|
from tensorflow.contrib.graph_editor import util as _util
|
||||||
from tensorflow.python.framework import ops as tf_ops
|
from tensorflow.python.framework import ops as _tf_ops
|
||||||
|
|
||||||
__all__ = [
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
|
||||||
|
_allowed_symbols = [
|
||||||
"swap_ts",
|
"swap_ts",
|
||||||
"reroute_ts",
|
"reroute_ts",
|
||||||
"swap_inputs",
|
"swap_inputs",
|
||||||
@ -46,8 +48,8 @@ def _check_ts_compatibility(ts0, ts1):
|
|||||||
ValueError: if any pair of tensors (same index in ts0 and ts1) have
|
ValueError: if any pair of tensors (same index in ts0 and ts1) have
|
||||||
a dtype or a shape which is not compatible.
|
a dtype or a shape which is not compatible.
|
||||||
"""
|
"""
|
||||||
ts0 = util.make_list_of_t(ts0)
|
ts0 = _util.make_list_of_t(ts0)
|
||||||
ts1 = util.make_list_of_t(ts1)
|
ts1 = _util.make_list_of_t(ts1)
|
||||||
if len(ts0) != len(ts1):
|
if len(ts0) != len(ts1):
|
||||||
raise ValueError("ts0 and ts1 have different sizes: {} != {}".format(
|
raise ValueError("ts0 and ts1 have different sizes: {} != {}".format(
|
||||||
len(ts0), len(ts1)))
|
len(ts0), len(ts1)))
|
||||||
@ -176,13 +178,13 @@ def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None):
|
|||||||
converted to a list of `tf.Operation`.
|
converted to a list of `tf.Operation`.
|
||||||
"""
|
"""
|
||||||
a2b, b2a = _RerouteMode.check(mode)
|
a2b, b2a = _RerouteMode.check(mode)
|
||||||
ts0 = util.make_list_of_t(ts0)
|
ts0 = _util.make_list_of_t(ts0)
|
||||||
ts1 = util.make_list_of_t(ts1)
|
ts1 = _util.make_list_of_t(ts1)
|
||||||
_check_ts_compatibility(ts0, ts1)
|
_check_ts_compatibility(ts0, ts1)
|
||||||
if cannot_modify is not None:
|
if cannot_modify is not None:
|
||||||
cannot_modify = frozenset(util.make_list_of_op(cannot_modify))
|
cannot_modify = frozenset(_util.make_list_of_op(cannot_modify))
|
||||||
if can_modify is not None:
|
if can_modify is not None:
|
||||||
can_modify = frozenset(util.make_list_of_op(can_modify))
|
can_modify = frozenset(_util.make_list_of_op(can_modify))
|
||||||
nb_update_inputs = 0
|
nb_update_inputs = 0
|
||||||
precomputed_consumers = []
|
precomputed_consumers = []
|
||||||
# precompute consumers to avoid issue with repeated tensors:
|
# precompute consumers to avoid issue with repeated tensors:
|
||||||
@ -268,11 +270,11 @@ def _reroute_sgv_remap(sgv0, sgv1, mode):
|
|||||||
ValueError: if sgv0 and sgv1 do not belong to the same graph.
|
ValueError: if sgv0 and sgv1 do not belong to the same graph.
|
||||||
"""
|
"""
|
||||||
a2b, b2a = _RerouteMode.check(mode)
|
a2b, b2a = _RerouteMode.check(mode)
|
||||||
if not isinstance(sgv0, subgraph.SubGraphView):
|
if not isinstance(sgv0, _subgraph.SubGraphView):
|
||||||
raise TypeError("Expected a SubGraphView, got {}".format(type(sgv0)))
|
raise TypeError("Expected a SubGraphView, got {}".format(type(sgv0)))
|
||||||
if not isinstance(sgv1, subgraph.SubGraphView):
|
if not isinstance(sgv1, _subgraph.SubGraphView):
|
||||||
raise TypeError("Expected a SubGraphView, got {}".format(type(sgv1)))
|
raise TypeError("Expected a SubGraphView, got {}".format(type(sgv1)))
|
||||||
util.check_graphs(sgv0, sgv1)
|
_util.check_graphs(sgv0, sgv1)
|
||||||
sgv0_ = sgv0.copy()
|
sgv0_ = sgv0.copy()
|
||||||
sgv1_ = sgv1.copy()
|
sgv1_ = sgv1.copy()
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
@ -327,13 +329,13 @@ def _reroute_sgv_inputs(sgv0, sgv1, mode):
|
|||||||
StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
|
StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
|
||||||
the same rules than the function subgraph.make_view.
|
the same rules than the function subgraph.make_view.
|
||||||
"""
|
"""
|
||||||
sgv0 = subgraph.make_view(sgv0)
|
sgv0 = _subgraph.make_view(sgv0)
|
||||||
sgv1 = subgraph.make_view(sgv1)
|
sgv1 = _subgraph.make_view(sgv1)
|
||||||
util.check_graphs(sgv0, sgv1)
|
_util.check_graphs(sgv0, sgv1)
|
||||||
can_modify = sgv0.ops + sgv1.ops
|
can_modify = sgv0.ops + sgv1.ops
|
||||||
# also allow consumers of passthrough to be modified:
|
# also allow consumers of passthrough to be modified:
|
||||||
can_modify += util.get_consuming_ops(sgv0.passthroughs)
|
can_modify += _util.get_consuming_ops(sgv0.passthroughs)
|
||||||
can_modify += util.get_consuming_ops(sgv1.passthroughs)
|
can_modify += _util.get_consuming_ops(sgv1.passthroughs)
|
||||||
_reroute_ts(sgv0.inputs, sgv1.inputs, mode, can_modify=can_modify)
|
_reroute_ts(sgv0.inputs, sgv1.inputs, mode, can_modify=can_modify)
|
||||||
_reroute_sgv_remap(sgv0, sgv1, mode)
|
_reroute_sgv_remap(sgv0, sgv1, mode)
|
||||||
return sgv0, sgv1
|
return sgv0, sgv1
|
||||||
@ -357,9 +359,9 @@ def _reroute_sgv_outputs(sgv0, sgv1, mode):
|
|||||||
StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
|
StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
|
||||||
the same rules than the function subgraph.make_view.
|
the same rules than the function subgraph.make_view.
|
||||||
"""
|
"""
|
||||||
sgv0 = subgraph.make_view(sgv0)
|
sgv0 = _subgraph.make_view(sgv0)
|
||||||
sgv1 = subgraph.make_view(sgv1)
|
sgv1 = _subgraph.make_view(sgv1)
|
||||||
util.check_graphs(sgv0, sgv1)
|
_util.check_graphs(sgv0, sgv1)
|
||||||
cannot_modify = sgv0.ops + sgv1.ops
|
cannot_modify = sgv0.ops + sgv1.ops
|
||||||
_reroute_ts(sgv0.outputs, sgv1.outputs, mode, cannot_modify=cannot_modify)
|
_reroute_ts(sgv0.outputs, sgv1.outputs, mode, cannot_modify=cannot_modify)
|
||||||
return sgv0, sgv1
|
return sgv0, sgv1
|
||||||
@ -432,9 +434,9 @@ def remove_control_inputs(op, cops):
|
|||||||
TypeError: if op is not a `tf.Operation`.
|
TypeError: if op is not a `tf.Operation`.
|
||||||
ValueError: if any cop in cops is not a control input of op.
|
ValueError: if any cop in cops is not a control input of op.
|
||||||
"""
|
"""
|
||||||
if not isinstance(op, tf_ops.Operation):
|
if not isinstance(op, _tf_ops.Operation):
|
||||||
raise TypeError("Expected a tf.Operation, got: {}", type(op))
|
raise TypeError("Expected a tf.Operation, got: {}", type(op))
|
||||||
cops = util.make_list_of_op(cops, allow_graph=False)
|
cops = _util.make_list_of_op(cops, allow_graph=False)
|
||||||
for cop in cops:
|
for cop in cops:
|
||||||
if cop not in op.control_inputs:
|
if cop not in op.control_inputs:
|
||||||
raise ValueError("{} is not a control_input of {}".format(op.name,
|
raise ValueError("{} is not a control_input of {}".format(op.name,
|
||||||
@ -457,9 +459,9 @@ def add_control_inputs(op, cops):
|
|||||||
TypeError: if op is not a tf.Operation
|
TypeError: if op is not a tf.Operation
|
||||||
ValueError: if any cop in cops is already a control input of op.
|
ValueError: if any cop in cops is already a control input of op.
|
||||||
"""
|
"""
|
||||||
if not isinstance(op, tf_ops.Operation):
|
if not isinstance(op, _tf_ops.Operation):
|
||||||
raise TypeError("Expected a tf.Operation, got: {}", type(op))
|
raise TypeError("Expected a tf.Operation, got: {}", type(op))
|
||||||
cops = util.make_list_of_op(cops, allow_graph=False)
|
cops = _util.make_list_of_op(cops, allow_graph=False)
|
||||||
for cop in cops:
|
for cop in cops:
|
||||||
if cop in op.control_inputs:
|
if cop in op.control_inputs:
|
||||||
raise ValueError("{} is already a control_input of {}".format(op.name,
|
raise ValueError("{} is already a control_input of {}".format(op.name,
|
||||||
@ -468,3 +470,5 @@ def add_control_inputs(op, cops):
|
|||||||
op._control_inputs += cops
|
op._control_inputs += cops
|
||||||
op._recompute_node_def()
|
op._recompute_node_def()
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
remove_undocumented(__name__, _allowed_symbols)
|
||||||
|
@ -157,7 +157,6 @@ def extract():
|
|||||||
'contrib.graph_editor': [
|
'contrib.graph_editor': [
|
||||||
'edit',
|
'edit',
|
||||||
'match',
|
'match',
|
||||||
'reroute',
|
|
||||||
'subgraph',
|
'subgraph',
|
||||||
'transform',
|
'transform',
|
||||||
'select',
|
'select',
|
||||||
|
Loading…
Reference in New Issue
Block a user