commit 43df96aeac
tensorflow
contrib
__init__.py
bayesflow
copy_graph
crf
cudnn_rnn
deprecated
distributions
framework
graph_editor
image
input_pipeline
integrate
layers
learn
__init__.py
python/learn
linalg
linear_optimizer
lookup
losses
metrics
nn
opt
rnn
seq2seq
stat_summarizer
tfprof
util
python/framework
@@ -58,3 +58,7 @@ from tensorflow.contrib import training
from tensorflow.contrib import util
from tensorflow.contrib.ndlstm import python as ndlstm
from tensorflow.contrib.specs import python as specs

del absolute_import
del division
del print_function
@@ -30,3 +30,13 @@ from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor
from tensorflow.contrib.bayesflow.python.ops import stochastic_variables
from tensorflow.contrib.bayesflow.python.ops import variational_inference
# pylint: enable=unused-import,line-too-long

from tensorflow.python.util.all_util import remove_undocumented


_allowed_symbols = ['entropy', 'monte_carlo',
                    'special_math', 'stochastic_gradient_estimators',
                    'stochastic_graph', 'stochastic_tensor',
                    'stochastic_variables', 'variational_inference']

remove_undocumented(__name__, _allowed_symbols)
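The hunks above and below all apply the same documentation-whitelisting pattern. A minimal sketch of the pattern for a hypothetical contrib-style package (module and symbol names here are illustrative, not from this commit):

# Hypothetical __init__.py following the pattern added in this commit:
# import the implementation, declare the public names, then prune every
# other module-level symbol from the package namespace.
from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = ['some_public_op']  # hypothetical public API

# Removes module-level names that are neither in _allowed_symbols nor
# referenced by @@name markers in the module docstring.
remove_undocumented(__name__, _allowed_symbols)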
@@ -20,4 +20,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.copy_graph.python.util import copy_elements
from tensorflow.contrib.copy_graph.python.util.copy_elements import *

from tensorflow.python.util.all_util import remove_undocumented

remove_undocumented(__name__, doc_string_modules=[copy_elements])
@@ -37,3 +37,7 @@ from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score
from tensorflow.contrib.crf.python.ops.crf import crf_unary_score
from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell
from tensorflow.contrib.crf.python.ops.crf import viterbi_decode

from tensorflow.python.util.all_util import remove_undocumented

remove_undocumented(__name__)
@@ -23,3 +23,6 @@ from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanh
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import RNNParamsSaveable

from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)
@@ -95,3 +95,10 @@ from tensorflow.python.ops.logging_ops import merge_all_summaries
from tensorflow.python.ops.logging_ops import merge_summary
from tensorflow.python.ops.logging_ops import scalar_summary
# pylint: enable=unused-import,line-too-long

from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['audio_summary', 'histogram_summary',
                    'image_summary', 'merge_all_summaries',
                    'merge_summary', 'scalar_summary']

remove_undocumented(__name__, _allowed_symbols)
@@ -134,3 +134,12 @@ from tensorflow.contrib.distributions.python.ops.uniform import *
from tensorflow.contrib.distributions.python.ops.wishart import *

# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = ['bijector',
                    'ConditionalDistribution',
                    'ConditionalTransformedDistribution',
                    'FULLY_REPARAMETERIZED', 'NOT_REPARAMETERIZED']

remove_undocumented(__name__, _allowed_symbols)
@@ -16,6 +16,7 @@
"""Framework utilities.

@@assert_same_float_dtype
@@assert_scalar
@@assert_scalar_int
@@convert_to_tensor_or_sparse_tensor
@@get_graph_from_inputs
@@ -24,6 +25,7 @@
@@is_strictly_increasing
@@is_tensor
@@reduce_sum_n
@@remove_squeezable_dimensions
@@with_shape
@@with_same_shape

@@ -47,6 +49,7 @@
@@assign_from_values
@@assign_from_values_fn
@@create_global_step
@@filter_variables
@@get_global_step
@@get_or_create_global_step
@@get_local_variables
@@ -74,12 +77,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.framework.python.framework import *
from tensorflow.contrib.framework.python.ops import *
from tensorflow.python.util.all_util import make_all
# pylint: enable=unused-import,wildcard-import

__all__ = make_all(__name__)
from tensorflow.python.util.all_util import remove_undocumented


remove_undocumented(__name__)
@@ -108,111 +108,12 @@ which to operate must always be given explicitly. This is the reason why
  *connect* or *bypass*.
* transform: the Transformer class, which enables transforming
  (or simply copying) a subgraph into another one.

## Module: util

@@make_list_of_op
@@get_tensors
@@make_list_of_t
@@get_generating_ops
@@get_consuming_ops
@@ControlOutputs
@@placeholder_name
@@make_placeholder_from_tensor
@@make_placeholder_from_dtype_and_shape

## Module: select

@@filter_ts
@@filter_ts_from_regex
@@filter_ops
@@filter_ops_from_regex
@@get_name_scope_ops
@@check_cios
@@get_ops_ios
@@compute_boundary_ts
@@get_within_boundary_ops
@@get_forward_walk_ops
@@get_backward_walk_ops
@@get_walks_intersection_ops
@@get_walks_union_ops
@@select_ops
@@select_ts
@@select_ops_and_ts

## Module: subgraph

@@SubGraphView
@@make_view
@@make_view_from_scope

## Module: reroute

@@swap_ts
@@reroute_a2b_ts
@@reroute_b2a_ts
@@swap_inputs
@@reroute_a2b_inputs
@@reroute_b2a_inputs
@@swap_outputs
@@reroute_a2b_outputs
@@reroute_b2a_outputs
@@swap
@@reroute_a2b
@@reroute_b2a
@@remove_control_inputs
@@add_control_inputs

## Module: edit

@@detach_control_inputs
@@detach_control_outputs
@@detach_inputs
@@detach_outputs
@@detach
@@connect
@@bypass

## Module: transform

@@replace_t_with_placeholder_handler
@@keep_t_if_possible_handler
@@assign_renamed_collections_handler
@@transform_op_if_inside_handler
@@copy_op_handler
@@transform_op_in_place
@@Transformer
@@copy
@@copy_with_input_replacements
@@graph_replace

## Module: match

@@op_type
@@OpMatcher

## Useful aliases

@@ph
@@sgv
@@sgv_scope
@@ts
@@ops
@@matcher
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.graph_editor import edit
from tensorflow.contrib.graph_editor import match
from tensorflow.contrib.graph_editor import reroute
from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import subgraph
from tensorflow.contrib.graph_editor import transform
from tensorflow.contrib.graph_editor import util

# pylint: disable=wildcard-import
from tensorflow.contrib.graph_editor.edit import *
from tensorflow.contrib.graph_editor.match import *
@@ -224,9 +125,14 @@ from tensorflow.contrib.graph_editor.util import *
# pylint: enable=wildcard-import

# some useful aliases
ph = util.make_placeholder_from_dtype_and_shape
sgv = subgraph.make_view
sgv_scope = subgraph.make_view_from_scope
ts = select.select_ts
ops = select.select_ops
matcher = match.OpMatcher
# pylint: disable=g-bad-import-order
from tensorflow.contrib.graph_editor import subgraph as _subgraph
from tensorflow.contrib.graph_editor import util as _util
# pylint: enable=g-bad-import-order
ph = _util.make_placeholder_from_dtype_and_shape
sgv = _subgraph.make_view
sgv_scope = _subgraph.make_view_from_scope

del absolute_import
del division
del print_function
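For orientation, the aliases retained above are the short spellings of the library's most common entry points. A minimal usage sketch, assuming the contrib layout at this commit (graph contents are made up):

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

a = tf.constant(1.0, name="a")
b = tf.constant(2.0, name="b")
c = tf.add(a, b, name="c")

sgv = ge.sgv(c.op)                  # alias of subgraph.make_view
ts = ge.ts("^c:0$", graph=c.graph)  # alias of select.select_ts
ph = ge.ph(tf.float32)              # alias of util.make_placeholder_from_dtype_and_shape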
@@ -194,7 +194,7 @@ def connect(sgv0, sgv1, disconnect_first=False):
  if disconnect_first:
    detach_outputs(sgv0)
  sgv0_outputs = subgraph.SubGraphView(passthrough_ts=sgv0.outputs)
  reroute.reroute_a2b_inputs(sgv0_outputs, sgv1)
  reroute.reroute_inputs(sgv0_outputs, sgv1)
  return sgv0, sgv1


@@ -217,5 +217,5 @@ def bypass(sgv):
  sgv = subgraph.make_view(sgv)
  sgv_inputs = list(sgv.inputs)
  sgv, detached_inputs = detach_inputs(sgv)
  reroute.reroute_a2b_ts(sgv_inputs, sgv.outputs)
  reroute.reroute_ts(sgv_inputs, sgv.outputs)
  return sgv, detached_inputs
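The two edits above only swap in the renamed reroute helpers; connect and bypass behave as before. Roughly, per their docstrings (e_op and f_op are placeholder operations):

# Connect the outputs of sgv0 into the (remapped) inputs of sgv1:
ge.connect(sgv0, ge.sgv(e_op).remap_inputs([0]))

# Short-circuit a subgraph: its inputs are rerouted to its outputs and
# the detached inputs are handed back to the caller.
sgv, detached_inputs = ge.bypass(ge.sgv(f_op).remap_inputs([0]))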
@@ -24,17 +24,13 @@ from tensorflow.python.framework import ops as tf_ops

__all__ = [
    "swap_ts",
    "reroute_a2b_ts",
    "reroute_b2a_ts",
    "reroute_ts",
    "swap_inputs",
    "reroute_a2b_inputs",
    "reroute_b2a_inputs",
    "reroute_inputs",
    "swap_outputs",
    "reroute_a2b_outputs",
    "reroute_b2a_outputs",
    "swap",
    "reroute_a2b",
    "reroute_b2a",
    "reroute_outputs",
    "swap_ios",
    "reroute_ios",
    "remove_control_inputs",
    "add_control_inputs",
]
@@ -232,7 +228,7 @@ def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  return _reroute_ts(ts0, ts1, _RerouteMode.swap, can_modify, cannot_modify)


def reroute_a2b_ts(ts0, ts1, can_modify=None, cannot_modify=None):
def reroute_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """For each tensor's pair, replace the end of t1 by the end of t0.

    B0 B1     B0 B1
@@ -258,33 +254,6 @@ def reroute_a2b_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify)


def reroute_b2a_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  r"""For each tensor's pair, replace the end of t0 by the end of t1.

    B0 B1     B0 B1
    |  |  =>   \|
    A0 A1     A0 A1

  The end of the tensors in ts0 are left dangling.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside within_ops will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.b2a, can_modify, cannot_modify)


def _reroute_sgv_remap(sgv0, sgv1, mode):
  """Remap in place the inputs of two subgraph views to mimic the reroute.

@@ -425,46 +394,31 @@ def swap_inputs(sgv0, sgv1):
  return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.swap)


def reroute_a2b_inputs(sgv0, sgv1):
def reroute_inputs(sgv0, sgv1):
  """Re-route all the inputs of sgv0 to sgv1 (see reroute_inputs)."""
  return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.a2b)


def reroute_b2a_inputs(sgv0, sgv1):
  """Re-route all the inputs of sgv1 to sgv0 (see reroute_inputs)."""
  return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.b2a)


def swap_outputs(sgv0, sgv1):
  """Swap all the outputs of sgv0 and sgv1 (see _reroute_outputs)."""
  return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.swap)


def reroute_a2b_outputs(sgv0, sgv1):
def reroute_outputs(sgv0, sgv1):
  """Re-route all the outputs of sgv0 to sgv1 (see _reroute_outputs)."""
  return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.a2b)


def reroute_b2a_outputs(sgv0, sgv1):
  """Re-route all the outputs of sgv1 to sgv0 (see _reroute_outputs)."""
  return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.b2a)


def swap(sgv0, sgv1):
def swap_ios(sgv0, sgv1):
  """Swap the inputs and outputs of sgv1 to sgv0 (see _reroute)."""
  return _reroute_sgv(sgv0, sgv1, _RerouteMode.swap)


def reroute_a2b(sgv0, sgv1):
def reroute_ios(sgv0, sgv1):
  """Re-route the inputs and outputs of sgv0 to sgv1 (see _reroute)."""
  return _reroute_sgv(sgv0, sgv1, _RerouteMode.a2b)


def reroute_b2a(sgv0, sgv1):
  """Re-route the inputs and outputs of sgv1 to sgv0 (see _reroute)."""
  return _reroute_sgv(sgv0, sgv1, _RerouteMode.b2a)


def remove_control_inputs(op, cops):
  """Remove the control inputs cops from op.

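Net effect of the API change above, sketched as a migration table (the a2b/b2a pairs collapse into a single directed function; the reverse direction is obtained by swapping the arguments, as the updated reroute tests below illustrate):

# Before this commit:                After this commit:
#   reroute_a2b_ts(ts0, ts1)     ->  reroute_ts(ts0, ts1)
#   reroute_b2a_ts(ts0, ts1)     ->  reroute_ts(ts1, ts0)
#   reroute_a2b_inputs(s0, s1)   ->  reroute_inputs(s0, s1)
#   reroute_a2b_outputs(s0, s1)  ->  reroute_outputs(s0, s1)
#   swap(s0, s1)                 ->  swap_ios(s0, s1)
#   reroute_a2b(s0, s1)          ->  reroute_ios(s0, s1)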
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various ways of selecting operations and tensors in a graph.
"""
"""Various ways of selecting operations and tensors in a graph."""

from __future__ import absolute_import
from __future__ import division
@@ -28,6 +27,8 @@ from tensorflow.contrib.graph_editor import util
from tensorflow.python.framework import ops as tf_ops

__all__ = [
    "can_be_regex",
    "make_regex",
    "filter_ts",
    "filter_ts_from_regex",
    "filter_ops",
@@ -50,11 +50,11 @@ class EditTest(test.TestCase):
  def test_detach(self):
    """Test for ge.detach."""
    sgv = ge.sgv(self.c.op, self.a.op)
    control_outputs = ge.util.ControlOutputs(self.graph)
    control_outputs = ge.ControlOutputs(self.graph)
    ge.detach(sgv, control_ios=control_outputs)
    # make sure the detached graph is as expected.
    self.assertTrue(
        ge.matcher("^foo/c$").input_ops("a", "geph__b_0")(self.c.op))
        ge.OpMatcher("^foo/c$").input_ops("a", "geph__b_0")(self.c.op))

  def test_connect(self):
    """Test for ge.connect."""
@@ -66,13 +66,14 @@ class EditTest(test.TestCase):
    sgv = ge.sgv(x.op, y.op, z.op)
    ge.connect(sgv, ge.sgv(self.e.op).remap_inputs([0]))
    self.assertTrue(
        ge.matcher("^foo/bar/e$").input_ops("^z$", "foo/d$")(self.e.op))
        ge.OpMatcher("^foo/bar/e$").input_ops("^z$", "foo/d$")(self.e.op))

  def test_bypass(self):
    """Test for ge.bypass."""
    ge.bypass(ge.sgv(self.f.op).remap_inputs([0]))
    self.assertTrue(
        ge.matcher("^foo/bar/h$").input_ops("^foo/c$", "foo/bar/g$")(self.h.op))
        ge.OpMatcher("^foo/bar/h$").input_ops("^foo/c$", "foo/bar/g$")(
            self.h.op))


if __name__ == "__main__":
@@ -42,20 +42,20 @@ class MatchTest(test.TestCase):
    self.h = math_ops.add(self.f, self.g, name="h")

  def test_simple_match(self):
    self.assertTrue(ge.matcher("^.*/f$")(self.f.op))
    self.assertTrue(ge.OpMatcher("^.*/f$")(self.f.op))
    self.assertTrue(
        ge.matcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")(self.f.op))
    self.assertTrue(ge.matcher("^.*/f$").input_ops(True, "^.*/d$")(self.f.op))
        ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")(self.f.op))
    self.assertTrue(ge.OpMatcher("^.*/f$").input_ops(True, "^.*/d$")(self.f.op))
    self.assertTrue(
        ge.matcher("^.*/f$").input_ops(
        ge.OpMatcher("^.*/f$").input_ops(
            ge.match.op_type("Add"), ge.match.op_type("Const"))(self.f.op))
    self.assertTrue(
        ge.matcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")
        .output_ops(ge.matcher("^.*/h$")
        ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")
        .output_ops(ge.OpMatcher("^.*/h$")
                    .control_input_ops("^.*/c$"))(self.f.op))
    self.assertTrue(
        ge.matcher("^.*/f$").input_ops("^.*/c$", "^.*/d$").output_ops(
            ge.matcher("^.*/h$").control_input_ops("^.*/c$")
        ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$").output_ops(
            ge.OpMatcher("^.*/h$").control_input_ops("^.*/c$")
            .output_ops([]))(self.f.op))

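The edits above and below are the mechanical rename of ge.matcher to ge.OpMatcher; the call pattern is unchanged. An OpMatcher is a callable predicate over a tf.Operation, e.g. (f_op stands for an op built as in the surrounding tests):

match_f = ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")
assert match_f(f_op)  # op named ".../f" whose inputs match ".../c", ".../d"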
@@ -40,30 +40,30 @@ class RerouteTest(test.TestCase):
    self.c2 = math_ops.add(self.a2, self.b2, name="c2")

  def test_swap(self):
    ge.reroute.swap_ts([self.a0, self.b0], [self.a1, self.b1])
    self.assertTrue(ge.matcher("c0").input_ops("a1", "b1")(self.c0.op))
    self.assertTrue(ge.matcher("c1").input_ops("a0", "b0")(self.c1.op))
    ge.swap_ts([self.a0, self.b0], [self.a1, self.b1])
    self.assertTrue(ge.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op))
    self.assertTrue(ge.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op))

  def test_multiswap(self):
    with self.graph.as_default():
      a3 = constant_op.constant(3.0, shape=[2], name="a3")
    ge.reroute.swap(
        ge.sgv(a3.op).remap_outputs([0, 0]), ge.sgv(self.a0.op, self.a1.op))
    self.assertTrue(ge.matcher("c0").input_ops("a3", "b0")(self.c0.op))
    self.assertTrue(ge.matcher("c1").input_ops("a3", "b1")(self.c1.op))
    ge.swap_ios(ge.sgv(a3.op).remap_outputs([0, 0]),
                ge.sgv(self.a0.op, self.a1.op))
    self.assertTrue(ge.OpMatcher("c0").input_ops("a3", "b0")(self.c0.op))
    self.assertTrue(ge.OpMatcher("c1").input_ops("a3", "b1")(self.c1.op))

  def test_reroute(self):
    ge.reroute.reroute_a2b_ts([self.a0, self.b0], [self.a1, self.b1])
    self.assertTrue(ge.matcher("c0").input_ops("a0", "b0")(self.c0.op))
    self.assertTrue(ge.matcher("c1").input_ops("a0", "b0")(self.c1.op))
    ge.reroute_ts([self.a0, self.b0], [self.a1, self.b1])
    self.assertTrue(ge.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op))
    self.assertTrue(ge.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op))

    ge.reroute.reroute_b2a_ts([self.a0, self.b0], [self.a1, self.b1])
    self.assertTrue(ge.matcher("c0").input_ops("a1", "b1")(self.c0.op))
    self.assertTrue(ge.matcher("c1").input_ops("a1", "b1")(self.c1.op))
    ge.reroute_ts([self.a1, self.b1], [self.a0, self.b0])
    self.assertTrue(ge.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op))
    self.assertTrue(ge.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op))

  def test_compatibility(self):
    with self.assertRaises(ValueError):
      ge.reroute.reroute_a2b_ts([self.a0, self.b0], [self.a2, self.b2])
      ge.reroute_ts([self.a0, self.b0], [self.a2, self.b2])

  def test_reroute_can_modify(self):
    graph = ops.Graph()
@@ -82,11 +82,11 @@ class RerouteTest(test.TestCase):
    sgv0 = ge.sgv(a.op, b.op, c.op)
    sgv1 = ge.sgv(e.op, f.op)

    ge.reroute.swap_outputs(sgv0, sgv1)
    ge.swap_outputs(sgv0, sgv1)
    self.assertTrue(
        ge.matcher("g").input_ops("a", ge.matcher("c")
                                  .input_ops("a", "b"))(g.op))
    self.assertTrue(ge.matcher("d").input_ops("e", "f")(d.op))
        ge.OpMatcher("g").input_ops("a", ge.OpMatcher("c").input_ops("a", "b"))(
            g.op))
    self.assertTrue(ge.OpMatcher("d").input_ops("e", "f")(d.op))


if __name__ == "__main__":
@@ -44,76 +44,71 @@ class SelectTest(test.TestCase):
    self.h = math_ops.add(self.f, self.g, name="h")

  def test_regex(self):
    """Test for ge.select.can_be_regex and ge.select.make_regex."""
    self.assertTrue(ge.select.can_be_regex("foo"))
    self.assertTrue(ge.select.can_be_regex(re.compile("foo")))
    """Test for ge.can_be_regex and ge.make_regex."""
    self.assertTrue(ge.can_be_regex("foo"))
    self.assertTrue(ge.can_be_regex(re.compile("foo")))
    regex = re.compile("foo")
    self.assertIs(ge.select.make_regex(regex), regex)
    self.assertIs(ge.make_regex(regex), regex)

  def test_get_input_output_ts(self):
    """Test for ge.select._get_input_ts and ge.select._get_output_ts."""
    """Test for ge._get_input_ts and ge._get_output_ts."""
    self.assertEqual(len(ge.select._get_input_ts(self.graph)), 6)
    self.assertEqual(len(ge.select._get_output_ts(self.graph)), 8)

  def test_get_filter(self):
    """Test for various filtering operations on ts ops."""
    # TODO(fkp): parameterise
    self.assertEqual(len(ge.select.filter_ops(self.graph, True)), 8)
    self.assertEqual(len(ge.filter_ops(self.graph, True)), 8)
    self.assertEqual(
        len(
            ge.select.filter_ops(self.graph,
                                 lambda op: op.node_def.op == "Const")), 3)
        len(ge.filter_ops(self.graph, lambda op: op.node_def.op == "Const")), 3)
    self.assertEqual(
        len(
            ge.select.filter_ops(self.graph,
                                 lambda op: op.node_def.op == "Add")), 5)
        len(ge.filter_ops(self.graph, lambda op: op.node_def.op == "Add")), 5)
    self.assertEqual(
        len(ge.select.filter_ops_from_regex(self.graph, r"^.*\b[abc]$")), 3)
        len(ge.filter_ops_from_regex(self.graph, r"^.*\b[abc]$")), 3)

    self.assertEqual(len(ge.select.filter_ts(self.graph, True)), 8)
    self.assertEqual(len(ge.filter_ts(self.graph, True)), 8)
    self.assertEqual(
        len(ge.select.filter_ts_from_regex(self.graph, r"^.*/[fgh]:\d$")), 3)
        len(ge.filter_ts_from_regex(self.graph, r"^.*/[fgh]:\d$")), 3)

    self.assertEqual(len(ge.select.get_name_scope_ops(self.graph, "foo/")), 7)
    self.assertEqual(
        len(ge.select.get_name_scope_ops(self.graph, "foo/bar")), 4)
    self.assertEqual(len(ge.get_name_scope_ops(self.graph, "foo/")), 7)
    self.assertEqual(len(ge.get_name_scope_ops(self.graph, "foo/bar")), 4)

  def test_get_ops_ios(self):
    """Test for ge.select.get_ops_ios."""
    """Test for ge.get_ops_ios."""
    control_outputs = ge.util.ControlOutputs(self.graph)
    self.assertEqual(
        len(ge.select.get_ops_ios(
        len(ge.get_ops_ios(
            self.h.op, control_ios=control_outputs)), 3)
    self.assertEqual(len(ge.select.get_ops_ios(self.h.op)), 2)
    self.assertEqual(len(ge.get_ops_ios(self.h.op)), 2)
    self.assertEqual(
        len(ge.select.get_ops_ios(
        len(ge.get_ops_ios(
            self.c.op, control_ios=control_outputs)), 6)
    self.assertEqual(len(ge.select.get_ops_ios(self.c.op)), 5)
    self.assertEqual(len(ge.get_ops_ios(self.c.op)), 5)

  def test_compute_boundary_ts_0(self):
    """Test for ge.select.compute_boundary_ts."""
    input_ts, output_ts, inside_ts = ge.select.compute_boundary_ts(self.g.op)
    """Test for ge.compute_boundary_ts."""
    input_ts, output_ts, inside_ts = ge.compute_boundary_ts(self.g.op)
    self.assertEqual(list(input_ts), [self.c, self.a])
    self.assertEqual(list(output_ts), [self.g])
    self.assertEqual(list(inside_ts), [])

  def test_compute_boundary_ts_1(self):
    """Test for ge.select.compute_boundary_ts."""
    input_ts, output_ts, inside_ts = ge.select.compute_boundary_ts(
    """Test for ge.compute_boundary_ts."""
    input_ts, output_ts, inside_ts = ge.compute_boundary_ts(
        [self.g.op, self.h.op])
    self.assertEqual(list(input_ts), [self.c, self.a, self.f])
    self.assertEqual(list(output_ts), [self.h])
    self.assertEqual(list(inside_ts), [self.g])

  def test_compute_boundary_ts_2(self):
    """Test for ge.select.compute_boundary_ts."""
    """Test for ge.compute_boundary_ts."""
    graph = ops_lib.Graph()
    with graph.as_default():
      a = constant_op.constant(1, name="a")
      b = constant_op.constant(1, name="b")
      c = math_ops.add(a, b, name="c")
      _ = a + c
    input_ts, output_ts, inside_ts = ge.select.compute_boundary_ts([a.op, c.op])
    input_ts, output_ts, inside_ts = ge.compute_boundary_ts([a.op, c.op])
    self.assertEqual(list(input_ts), [b])
    self.assertEqual(list(output_ts), [a, c])
    self.assertEqual(list(inside_ts), [a])
@@ -121,7 +116,7 @@ class SelectTest(test.TestCase):
  def test_get_within_boundary_ops_0(self):
    """Test for test_get_within_boundary_ops."""
    control_outputs = ge.util.ControlOutputs(self.graph)
    ops = ge.select.get_within_boundary_ops(
    ops = ge.get_within_boundary_ops(
        ops=self.graph,
        seed_ops=self.f.op,
        boundary_ops=[self.c.op, self.h.op],
@@ -130,19 +125,19 @@ class SelectTest(test.TestCase):
    self.assertEqual(len(ops), 3)

  def test_get_within_boundary_ops_1(self):
    """Test for ge.select.test_get_within_boundary_ops."""
    ops = ge.select.get_within_boundary_ops(
    """Test for ge.test_get_within_boundary_ops."""
    ops = ge.get_within_boundary_ops(
        ops=self.graph, seed_ops=self.h.op, boundary_ops=[self.f.op, self.g.op])
    self.assertEqual(len(ops), 3)

  def test_get_walks_intersection(self):
    """Test for ge.select.get_walks_intersection_ops."""
    ops = ge.select.get_walks_intersection_ops([self.c.op], [self.g.op])
    """Test for ge.get_walks_intersection_ops."""
    ops = ge.get_walks_intersection_ops([self.c.op], [self.g.op])
    self.assertEqual(len(ops), 2)

  def test_get_walks_union(self):
    """Test for ge.select.get_walks_union_ops."""
    ops = ge.select.get_walks_union_ops([self.f.op], [self.g.op])
    """Test for ge.get_walks_union_ops."""
    ops = ge.get_walks_union_ops([self.f.op], [self.g.op])
    self.assertEqual(len(ops), 6)

  def test_select_ops(self):
@@ -151,7 +146,7 @@ class SelectTest(test.TestCase):
        (("^foo/bar/",), 4),
        (("^foo/bar/", "a"), 5),)
    for param, length in parameters:
      ops = ge.select.select_ops(*param, graph=self.graph)
      ops = ge.select_ops(*param, graph=self.graph)
      self.assertEqual(len(ops), length)

  def test_select_ts(self):
@@ -159,7 +154,7 @@ class SelectTest(test.TestCase):
        (".*:0", 8),
        (r".*/bar/\w+:0", 4),)
    for regex, length in parameters:
      ts = ge.select.select_ts(regex, graph=self.graph)
      ts = ge.select_ts(regex, graph=self.graph)
      self.assertEqual(len(ts), length)

  def test_select_ops_and_ts(self):
@@ -167,7 +162,7 @@ class SelectTest(test.TestCase):
        (("^foo/.*",), 7, 0),
        (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4),)
    for param, l0, l1 in parameters:
      ops, ts = ge.select.select_ops_and_ts(*param, graph=self.graph)
      ops, ts = ge.select_ops_and_ts(*param, graph=self.graph)
      self.assertEqual(len(ops), l0)
      self.assertEqual(len(ts), l1)

@@ -75,7 +75,7 @@ class TransformTest(test.TestCase):
      _ = math_ops.add(a, b)
    sgv = ge.make_view([assert_op, eq.op, a.op, b.op])
    copier = ge.Transformer()
    copied_sgv, info = copier(sgv, sgv.graph, "", "")
    _, info = copier(sgv, sgv.graph, "", "")
    new_assert_op = info.transformed(assert_op)
    self.assertIsNotNone(new_assert_op)

@@ -84,63 +84,31 @@ class TransformTest(test.TestCase):

    def my_transform_op_handler(info, op):
      add_noise = op.name.startswith("Add")
      op_ = ge.transform.copy_op_handler(info, op)
      if add_noise:
        # add some noise to op
        with info.graph_.as_default():
          t_ = math_ops.add(constant_op.constant(
              1.0, shape=[10], name="Noise"),
              op_.outputs[0],
              name="AddNoise")
        # return the "noisy" op
        return t_.op
      else:
        return op_
      op_, op_outputs_ = ge.transform.copy_op_handler(info, op)
      if not add_noise:
        return op_, op_outputs_
      # add some noise to op
      with info.graph_.as_default():
        t_ = math_ops.add(
            constant_op.constant(1.0, shape=[10], name="Noise"),
            op_.outputs[0],
            name="AddNoise")
      # return the "noisy" op
      return op_, [t_]

    transformer.transform_op_handler = my_transform_op_handler

    graph = ops.Graph()
    transformer(self.graph, graph, "", "")
    matcher0 = ge.matcher("AddNoise").input_ops(
        "Noise", ge.matcher("Add").input_ops("Const", "Input"))
    matcher1 = ge.matcher("AddNoise_1").input_ops(
        "Noise_1", ge.matcher("Add_1").input_ops("Const_1", matcher0))
    matcher2 = ge.matcher("AddNoise_2").input_ops(
        "Noise_2", ge.matcher("Add_2").input_ops("Const_2", matcher1))
    matcher0 = ge.OpMatcher("AddNoise").input_ops(
        "Noise", ge.OpMatcher("Add").input_ops("Const", "Input"))
    matcher1 = ge.OpMatcher("AddNoise_1").input_ops(
        "Noise_1", ge.OpMatcher("Add_1").input_ops("Const_1", matcher0))
    matcher2 = ge.OpMatcher("AddNoise_2").input_ops(
        "Noise_2", ge.OpMatcher("Add_2").input_ops("Const_2", matcher1))
    top = ge.select_ops("^AddNoise_2$", graph=graph)[0]
    self.assertTrue(matcher2(top))

  def test_transform_in_place(self):
    transformer = ge.Transformer()

    def my_transform_op_handler_in_place(info, op):
      add_noise = op.name.startswith("Add")
      op = ge.transform.transform_op_in_place(
          info, op, detach_outputs=add_noise)
      if add_noise:
        # add some noise to op
        with info.graph_.as_default():
          t = math_ops.add(constant_op.constant(
              1.0, shape=[10], name="Noise"),
              op.outputs[0],
              name="AddNoise")
        # return the "noisy" op
        return t.op
      else:
        return op

    transformer.transform_op_handler = my_transform_op_handler_in_place

    transformer(self.graph, self.graph, "", "")
    matcher0 = ge.matcher("AddNoise").input_ops(
        "Noise", ge.matcher("Add").input_ops("Const", "Input"))
    matcher1 = ge.matcher("AddNoise_1").input_ops(
        "Noise_1", ge.matcher("Add_1").input_ops("Const_1", matcher0))
    matcher2 = ge.matcher("AddNoise_2").input_ops(
        "Noise_2", ge.matcher("Add_2").input_ops("Const_2", matcher1))
    top = ge.select_ops("^AddNoise_2$", graph=self.graph)[0]
    self.assertTrue(matcher2(top))

  def test_copy_with_input_replacements(self):
    with self.graph.as_default():
      ten = constant_op.constant(10.0, shape=[10], name="Input")
@@ -21,18 +21,17 @@ from __future__ import print_function

from copy import deepcopy
from functools import partial

from six import iteritems
from six import iterkeys
from six import string_types
from six import StringIO

from tensorflow.contrib.graph_editor import edit
from tensorflow.contrib.graph_editor import reroute
from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import subgraph
from tensorflow.contrib.graph_editor import util
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.platform import tf_logging as logging


__all__ = [
    "replace_t_with_placeholder_handler",
@@ -40,8 +39,8 @@ __all__ = [
    "assign_renamed_collections_handler",
    "transform_op_if_inside_handler",
    "copy_op_handler",
    "transform_op_in_place",
    "Transformer",
    "TransformerInfo",
    "copy",
    "copy_with_input_replacements",
    "graph_replace",
@@ -55,7 +54,7 @@ def replace_t_with_placeholder_handler(info, t):
  placeholder.

  Args:
    info: Transform._Info instance.
    info: Transform._TmpInfo instance.
    t: tensor whose input must be transformed into a place holder.
  Returns:
    The tensor generated by the newly created place holder.
@@ -73,7 +72,7 @@ def keep_t_if_possible_handler(info, t):
  This handler is typically used to transform hidden input tensors.

  Args:
    info: Transform._Info instance.
    info: Transform._TmpInfo instance.
    t: tensor whose input must be transformed into a place holder.
  Returns:
    The tensor generated by the newly created place holder.
@@ -87,17 +86,24 @@ def keep_t_if_possible_handler(info, t):
def assign_renamed_collections_handler(info, elem, elem_):
  """Add the transformed elem to the (renamed) collections of elem.

  A collection is renamed only if it is not a known key, as described in
  `tf.GraphKeys`.

  Args:
    info: Transform._Info instance.
    info: Transform._TmpInfo instance.
    elem: the original element (`tf.Tensor` or `tf.Operation`)
    elem_: the transformed element
  """
  # TODO(fkp): handle known special cases
  known_collection_names = util.get_predefined_collection_names()
  for name, collection in iteritems(info.collections):
    if elem not in collection:
      continue
    collection_name_ = info.transformer.new_name(name)
    info.graph_.add_to_collection(collection_name_, elem_)

    if name in known_collection_names:
      transformed_name = name
    else:
      transformed_name = info.new_name(name)
    info.graph_.add_to_collection(transformed_name, elem_)


def transform_op_if_inside_handler(info, op, keep_if_possible=True):
@@ -107,18 +113,15 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True):
  if they are inside the subgraph, otherwise they are just ignored.

  Args:
    info: Transform._Info instance.
    info: Transform._TmpInfo instance.
    op: the optional op to transform (or ignore).
    keep_if_possible: re-attach to the original op if possible, that is,
      if the source graph and the destination graph are the same.
  Returns:
    The transformed op or None.
  """
  if op is None:
    return None
  if op in info.sgv.ops:
    return info.transformer._transform_op(  # pylint: disable=protected-access
        op)
    return info.transformed_ops[op]
  else:
    if keep_if_possible and info.graph is info.graph_:
      return op
@@ -130,31 +133,19 @@ def copy_op_handler(info, op, copy_shape=True):
  """Copy a `tf.Operation`.

  Args:
    info: Transform._Info instance.
    info: Transform._TmpInfo instance.
    op: the `tf.Operation` to be copied.
    copy_shape: also copy the shape of the tensor
  Returns:
    A copy of op.
    A `(op, op_outputs)` tuple containing the transformed op and its outputs.
  """
  # pylint: disable=protected-access

  # Transform control inputs:
  control_inputs_ = [info.transformer.transform_control_input_handler(info, ci)
                     for ci in op.control_inputs]
  control_inputs_ = [ci for ci in control_inputs_ if ci is not None]

  # Transform it if any:
  original_op_ = info.transformer.transform_original_op_handler(info,
                                                                op._original_op)

  # Transform inputs:
  inputs_ = [info.transformer._transform_t(t) for t in op.inputs]

  # Clone the node def:
  node_def_ = deepcopy(op._node_def)

  # Transform name:
  name_ = info.transformer.new_name(op.name)
  name_ = info.new_name(op.name)
  name_ = info.graph_.unique_name(name_)
  node_def_.name = name_

@@ -167,45 +158,196 @@ def copy_op_handler(info, op, copy_shape=True):
  op_def_ = deepcopy(op._op_def)

  # Initialize a new Operation instance
  op_ = tf_ops.Operation(node_def_, info.graph_, inputs_, output_types_,
                         control_inputs_, input_types_, original_op_, op_def_)
  op_ = tf_ops.Operation(node_def_, info.graph_, [], output_types_,
                         [], input_types_, None, op_def_)

  # copy the shape over
  if copy_shape:
    for t, t_ in zip(op.outputs, op_.outputs):
      t_.set_shape(t.get_shape())

  # Finalize original op.
  if op._original_op:
    original_op = info.transform_original_op_handler(info, op._original_op)
    if original_op is None:
      logging.info("Could not find original op of: %s", op_.name)
    else:
      op_._original_op = original_op

  # Add op to the graph
  info.graph_._add_op(op_)

  # pylint: enable=protected-access
  return op_
  return op_, op_.outputs

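After this hunk, copy_op_handler creates the copied Operation with no inputs or control inputs (those are wired up in a later pass, see _connect_ops below) and returns a (op, op_outputs) pair instead of a bare op. A custom handler under the new contract therefore forwards both values; a minimal sketch (the wrapper name is hypothetical):

def my_op_handler(info, op, copy_shape=True):
  # Delegate the copy, then return the new-style pair. A handler may
  # substitute different output tensors here, as the AddNoise handler
  # in transform_test.py above does.
  op_, op_outputs_ = copy_op_handler(info, op, copy_shape)
  return op_, op_outputs_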
def transform_op_in_place(info, op, detach_outputs=False):
  """Transform a op in-place - experimental!
class TransformerInfo(object):
  """Contains information about the result of a transform operation."""

  Transform an operation in place. It reconnects the inputs if they have been
  modified. if detach_outputs is True, the outputs of op are also detached.
  def __init__(self, info):
    """Constructor.

  Args:
    info: Transform._Info instance.
    op: the op to transform in place.
    detach_outputs: if True, the outputs of op are detached, ready for the user
      to add more operation.
  Returns:
    The transformed op.
    Args:
      info: an instance of Transformer._TmpInfo containing various internal
        information about the transform operation.
    """
    self._graph = info.graph
    self._scope = info.scope
    self._graph_ = info.graph_
    self._scope_ = info.scope_
    self._transformed_ops = info.transformed_ops
    self._transformed_ts = info.transformed_ts

  def _get_transformed_map(self, top):
    """Return the correct container depending on the type of `top`."""
    if isinstance(top, tf_ops.Operation):
      return self._transformed_ops
    elif isinstance(top, tf_ops.Tensor):
      return self._transformed_ts
    else:
      raise TypeError(
          "Expected a tf.Tensor or a tf.Operation, got a {}".format(
              type(top)))

  def _transformed_elem(self, original_top, missing_fn=None):
    """Return the transformed op/tensor corresponding to the original one.

    Args:
      original_top: the original tensor/operation.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the transformed tensor/operation (or None if no match is found).
    """
    transformed_map = self._get_transformed_map(original_top)
    if isinstance(original_top, string_types):
      for original, transformed in iteritems(transformed_map):
        if original.name == original_top:
          return transformed
      return None if missing_fn is None else missing_fn(original_top)
    else:
      if original_top not in transformed_map:
        return None if missing_fn is None else missing_fn(original_top)
      return transformed_map[original_top]

  def _original_elem(self, transformed_top, missing_fn=None):
    """Return the original op/tensor corresponding to the transformed one.

    Args:
      transformed_top: the transformed tensor/operation.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the original tensor/operation (or None if no match is found).
    """
    transformed_map = self._get_transformed_map(transformed_top)
    if isinstance(transformed_top, string_types):
      finder = lambda transformed: transformed.name == transformed_top
    else:
      finder = lambda transformed: transformed == transformed_top
    for original, transformed in iteritems(transformed_map):
      if finder(transformed):
        return original
    return None if missing_fn is None else missing_fn(transformed_top)

  def transformed(self, original, missing_fn=None):
    """Return the transformed op/tensor corresponding to the original one.

    Note that the output of this function mimics the hierarchy
    of its input argument `original`.
    Given an iterable, it returns a list. Given an operation or a tensor,
    it will return an operation or a tensor.

    Args:
      original: the original tensor/operation.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the transformed tensor/operation (or None if no match is found).
    """
    transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn)
    return util.transform_tree(original, transformed_elem)

  def original(self, transformed, missing_fn=None):
    """Return the original op/tensor corresponding to the transformed one.

    Note that the output of this function mimics the hierarchy
    of its input argument `transformed`.
    Given an iterable, it returns a list. Given an operation or a tensor,
    it will return an operation or a tensor.

    Args:
      transformed: the transformed tensor/operation.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the original tensor/operation (or None if no match is found).
    """
    original_elem = partial(self._original_elem, missing_fn=missing_fn)
    return util.transform_tree(transformed, original_elem)

  def __str__(self):
    res = StringIO()
    print("Transform result info:", file=res)
    if self._graph == self._graph_:
      in_place_str = "" if self._scope_ else " IN-PLACE"
      print("  Within graph[{}]{}".format(
          id(self._graph), in_place_str), file=res)
    else:
      print("  graph[{}] => graph[{}]".format(
          id(self._graph), id(self._graph_)), file=res)
    if self._scope:
      print("  Relative to source scope: {}".format(self._scope), file=res)
    if self._scope_:
      print("  Scope destination: {}".format(self._scope_), file=res)
    print("Operations mapping:", file=res)
    for op, op_ in iteritems(self._transformed_ops):
      print("  {} => {}".format(op.name, op_.name), file=res)
    return res.getvalue()

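TransformerInfo (the renamed ResultInfo) is returned from every Transformer call next to the transformed subgraph view; a short usage sketch, mirroring the updated transform tests above:

copier = ge.Transformer()
sgv_, info = copier(sgv, dst_graph, "", "")
new_op = info.transformed(old_op)     # map original -> transformed
old_op_again = info.original(new_op)  # map transformed -> original
print(info)                           # prints the op-name mapping via __str__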
class _TmpInfo(object):
  """Transformer temporary data.

  An instance of this class holds all the information relevant to a call
  to a transformer instance (that is, a call to __call__). An instance
  is created for the life-time of the __call__ function and is passed as
  argument to the handlers.
  """
  # recursive call to the inputs:
  inputs = [info.transformer._transform_t(t)  # pylint: disable=protected-access
            for t in op.inputs]
  # re-connect to the inputs if they have changed:
  if inputs != list(op.inputs):
    reroute.reroute_a2b_ts(inputs, op.inputs)
  # detach op from its consumer first ?
  if detach_outputs:
    edit.detach_outputs(op)
  return op

  def __init__(self, sgv, dst_graph, dst_scope, src_scope):
    self.sgv = sgv
    self.sgv_inputs_set = frozenset(sgv.inputs)
    self.ops = frozenset(sgv.ops)
    self.control_outputs = util.ControlOutputs(sgv.graph)
    self.graph = sgv.graph
    self.scope = src_scope
    self.graph_ = dst_graph
    self.scope_ = dst_scope
    self.transformed_ops = {}
    self.transformed_ts = {}
    self.collections = dict((key, self.graph.get_collection(key))
                            for key in self.graph.get_all_collection_keys())
    self.cyclic_ops = []
    self.transform_original_op_handler = transform_op_if_inside_handler

  def new_name(self, name):
    """Compute a destination name from a source name.

    Args:
      name: the name to be "transformed".
    Returns:
      The transformed name.
    Raises:
      ValueError: if the source scope is used (that is, not an empty string)
        and the source name does not belong to the source scope.
    """
    scope = self.scope
    if not name.startswith(scope):
      raise ValueError("{} does not belong to source scope: {}.".format(
          name, scope))
    rel_name = name[len(scope):]
    name_ = self.scope_ + rel_name
    return name_


class Transformer(object):
@@ -216,155 +358,6 @@ class Transformer(object):
  the handlers.
  """

  class _Info(object):
    """Transformer temporary data.

    An instance of this class holds all the information relevant to a call
    to a transformer instance (that is, a call to __call__). An instance
    is created for the life-time of the __call__ function and is passed as
    argument to the handlers.
    """

    def __init__(self, transformer, sgv, dst_graph, dst_scope, src_scope):
      self.transformer = transformer
      self.sgv = sgv
      self.sgv_inputs_set = frozenset(sgv.inputs)
      self.ops = frozenset(sgv.ops)
      self.control_outputs = util.ControlOutputs(sgv.graph)
      self.graph = sgv.graph
      self.scope = src_scope
      self.graph_ = dst_graph
      self.scope_ = dst_scope
      self.transformed_ops = {}
      self.transformed_ts = {}
      self.collections = dict((key, self.graph.get_collection(key))
                              for key in self.graph.get_all_collection_keys())

  class ResultInfo(object):
    """Contains information about the result of a transform operation."""

    def __init__(self, info):
      """Constructor.

      Args:
        info: an instance of Transformer._Info containing various internal
          information about the transform operation.
      """
      self._graph = info.graph
      self._scope = info.scope
      self._graph_ = info.graph_
      self._scope_ = info.scope_
      self._transformed_ops = info.transformed_ops
      self._transformed_ts = info.transformed_ts

    def _get_transformed_map(self, top):
      """Return the correct container depending on the type of `top`."""
      if isinstance(top, tf_ops.Operation):
        return self._transformed_ops
      elif isinstance(top, tf_ops.Tensor):
        return self._transformed_ts
      else:
        raise TypeError(
            "Expected a tf.Tensor or a tf.Operation, got a {}".format(
                type(top)))

    def _transformed_elem(self, original_top, missing_fn=None):
      """Return the transformed op/tensor corresponding to the original one.

      Args:
        original_top: the original tensor/operation.
        missing_fn: function handling the case where the counterpart
          cannot be found. By default, None is returned.
      Returns:
        the transformed tensor/operation (or None if no match is found).
      """
      transformed_map = self._get_transformed_map(original_top)
      if isinstance(original_top, string_types):
        for original, transformed in iteritems(transformed_map):
          if original.name == original_top:
            return transformed
        return None if missing_fn is None else missing_fn(original_top)
      else:
        if original_top not in transformed_map:
          return None if missing_fn is None else missing_fn(original_top)
        return transformed_map[original_top]

    def _original_elem(self, transformed_top, missing_fn=None):
      """Return the original op/tensor corresponding to the transformed one.

      Args:
        transformed_top: the transformed tensor/operation.
        missing_fn: function handling the case where the counterpart
          cannot be found. By default, None is returned.
      Returns:
        the original tensor/operation (or None if no match is found).
      """
      transformed_map = self._get_transformed_map(transformed_top)
      if isinstance(transformed_top, string_types):
        finder = lambda transformed: transformed.name == transformed_top
      else:
        finder = lambda transformed: transformed == transformed_top
      for original, transformed in iteritems(transformed_map):
        if finder(transformed):
          return original
      return None if missing_fn is None else missing_fn(transformed_top)

    def transformed(self, original, missing_fn=None):
      """Return the transformed op/tensor corresponding to the original one.

      Note that the output of this function mimics the hierarchy
      of its input argument `original`.
      Given an iterable, it returns a list. Given an operation or a tensor,
      it will return an operation or a tensor.

      Args:
        original: the original tensor/operation.
        missing_fn: function handling the case where the counterpart
          cannot be found. By default, None is returned.
      Returns:
        the transformed tensor/operation (or None if no match is found).
      """
      transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn)
      return util.transform_tree(original, transformed_elem)

    def original(self, transformed, missing_fn=None):
      """Return the original op/tensor corresponding to the transformed one.

      Note that the output of this function mimics the hierarchy
      of its input argument `transformed`.
      Given an iterable, it returns a list. Given an operation or a tensor,
      it will return an operation or a tensor.

      Args:
        transformed: the transformed tensor/operation.
        missing_fn: function handling the case where the counterpart
          cannot be found. By default, None is returned.
      Returns:
        the original tensor/operation (or None if no match is found).
      """
      original_elem = partial(self._original_elem, missing_fn=missing_fn)
      return util.transform_tree(transformed, original_elem)

    def __str__(self):
      res = StringIO()
      print("Transform result info:", file=res)
      if self._graph == self._graph_:
        in_place_str = "" if self._scope_ else " IN-PLACE"
        print("  Within graph[{}]{}".format(
            id(self._graph), in_place_str), file=res)
      else:
        print("  graph[{}] => graph[{}]".format(
            id(self._graph), id(self._graph_)), file=res)
      if self._scope:
        print("  Relative to source scope: {}".format(self._scope), file=res)
      if self._scope_:
        print("  Scope destination: {}".format(self._scope_), file=res)
      print("Operations mapping:", file=res)
      for op, op_ in iteritems(self._transformed_ops):
        print("  {} => {}".format(op.name, op_.name), file=res)
      return res.getvalue()

  def __init__(self):
    """Transformer constructor.

@@ -395,9 +388,6 @@ class Transformer(object):
    self.transform_external_hidden_input_handler = keep_t_if_possible_handler
    self.transform_original_op_handler = transform_op_if_inside_handler

    # temporary per-call variable
    self._info = None

  def __call__(self,
               sgv,
               dst_graph,
@@ -420,7 +410,7 @@ class Transformer(object):
    Returns:
      A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
        `info` is an instance of Transformer.ResultInfo containing
        `info` is an instance of TransformerInfo containing
          information about the transform, including mapping between
          original and transformed tensors and operations.
    Raises:
@@ -438,38 +428,68 @@ class Transformer(object):
      dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1]))

    # Create temporary info used during this transform call
    self._info = Transformer._Info(self, sgv, dst_graph, dst_scope, src_scope)
    info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope)
    info.transform_original_op_handler = self.transform_original_op_handler

    # Transform the graph starting from the output tensors.
    for output_t in self._info.sgv.outputs:
      self._transform_t(output_t)
    self._copy_ops(info)
    self._connect_ops(info)

    # Some ops might have been missed by the previous walk, namely, the roots
    # without any outputs. So the walk is now finalized from those roots.
    remaining_ops = [op for op in self._info.sgv.ops
                     if op not in self._info.transformed_ops]
    remaining_roots = [op for op in remaining_ops if not op.outputs]
    for op in remaining_roots:
      self._transform_op(op)

    sgv_ = self._transform_sgv(sgv)

    res_info = Transformer.ResultInfo(self._info)
    self._info = None
    # Compute information about the transformation
    res_info = TransformerInfo(info)
    sgv_ = self._transform_sgv(info, sgv)
    return sgv_, res_info

  def _transform_sgv(self, sgv):
  def _copy_ops(self, info):
    """Copy ops without connecting them."""
    for op in info.sgv.ops:
      logging.info("Copying op: %s", op.name)
      # TODO(fkp): return a subgraph?
      op_, op_outputs_ = self.transform_op_handler(info, op)
      if op is op_:
        raise ValueError("In-place transformation not allowed.")

      # Process op.
      info.transformed_ops[op] = op_
      self.assign_collections_handler(info, op, op_)

      # Process output tensors.
      for op_output, op_output_ in zip(op.outputs, op_outputs_):
        info.transformed_ts[op_output] = op_output_
        self.assign_collections_handler(info, op_output, op_output_)

  def _connect_ops(self, info):
    """Connect the previously copied ops."""
    for op in info.sgv.ops:
      logging.info("Finalizing op: %s", op.name)
      op_ = info.transformed_ops[op]

      # pylint: disable=protected-access
      if op_.inputs:
        raise ValueError("The newly transformed op should not have "
                         "any inputs yet: {}".format(op_.name))
      inputs_ = [self._transformed_t(info, t) for t in op.inputs]
      for t in inputs_:
        op_._add_input(t)

      # Finalize control inputs:
      control_inputs_ = [self.transform_control_input_handler(info, ci)
                         for ci in op.control_inputs]
      control_inputs_ = [ci for ci in control_inputs_ if ci is not None]
      reroute.add_control_inputs(op_, control_inputs_)

  def _transform_sgv(self, info, sgv):
    """Transform a subgraph view.

    For convenience, a transform operation returns a subgraph view of the
    transformed graph.

    Args:
      info: Temporary information for this transform call.
      sgv: the subgraph to be transformed.
    Returns:
      The transformed subgraph.
    """
    ops_ = [op_ for _, op_ in iteritems(self._info.transformed_ops)]
    ops_ = [op_ for _, op_ in iteritems(info.transformed_ops)]
    sgv_ = subgraph.SubGraphView(ops_)
    sgv_inputs_ = sgv_.inputs
    sgv_outputs_ = sgv_.outputs
@ -477,9 +497,9 @@ class Transformer(object):
|
||||
# re-order inputs
|
||||
input_map_ = []
|
||||
for input_t in sgv.inputs:
|
||||
if input_t not in self._info.transformed_ts:
|
||||
if input_t not in info.transformed_ts:
|
||||
continue
|
||||
input_t_ = self._info.transformed_ts[input_t]
|
||||
input_t_ = info.transformed_ts[input_t]
|
||||
if input_t_ not in sgv_inputs_:
|
||||
continue
|
||||
input_t_index_ = sgv_.input_index(input_t_)
|
||||
@ -488,9 +508,9 @@ class Transformer(object):
|
||||
# re-order outputs
|
||||
output_map_ = []
|
||||
for output_t in sgv.outputs:
|
||||
if output_t not in self._info.transformed_ts:
|
||||
if output_t not in info.transformed_ts:
|
||||
continue
|
||||
output_t_ = self._info.transformed_ts[output_t]
|
||||
output_t_ = info.transformed_ts[output_t]
|
||||
if output_t_ not in sgv_outputs_:
|
||||
continue
|
||||
output_t_index_ = sgv_.output_index(output_t_)
|
||||
@ -498,90 +518,19 @@ class Transformer(object):
|
||||
|
||||
return sgv_.remap(input_map_, output_map_)
|
||||
|
||||
def _transform_t(self, t):
|
||||
"""Transform a tf.Tensor.
|
||||
|
||||
Args:
|
||||
t: the tensor to be transformed.
|
||||
Returns:
|
||||
The transformed tensor.
|
||||
"""
|
||||
if t in self._info.transformed_ts:
|
||||
return self._info.transformed_ts[t]
|
||||
|
||||
op, op_index = t.op, t.value_index
|
||||
|
||||
# If op is not in the subgraph:
|
||||
if op not in self._info.ops:
|
||||
# t_ is an input of the subgraph
|
||||
if t in self._info.sgv_inputs_set:
|
||||
t_ = self.transform_external_input_handler(self._info, t)
|
||||
# t_ is a hidden input of the subgraph
|
||||
def _transformed_t(self, info, t):
|
||||
"""Return tre transformed tensor of `t`."""
    if t not in info.transformed_ts:
      # If op is not in the subgraph.
      if t in info.sgv_inputs_set:
        # t is an input of the subgraph.
        return self.transform_external_input_handler(info, t)
      else:
        t_ = self.transform_external_hidden_input_handler(self._info, t)
    # If op is in the subgraph, just transform it:
        # t is a hidden input of the subgraph.
        return self.transform_external_hidden_input_handler(info, t)
    else:
      op_ = self._transform_op(op)
      t_ = op_.outputs[op_index]

    # assign to collection
    if t is not t_:
      self.assign_collections_handler(self._info, t, t_)

    self._info.transformed_ts[t] = t_
    return t_

  def _transform_op(self, op):
    """Transform a tf.Operation.

    Args:
      op: the operation to be transformed.
    Returns:
      The transformed operation.
    """
    if op in self._info.transformed_ops:
      return self._info.transformed_ops[op]

    op_ = self.transform_op_handler(self._info, op)

    # Add to all the active control dependencies
    # pylint: disable=protected-access
    self._info.graph_._record_op_seen_by_control_dependencies(op_)

    # Add to all the active devices
    for device_function in reversed(self._info.graph_._device_function_stack):
      if device_function is None:
        break
      op_._set_device(device_function(op_))
    # pylint: enable=protected-access

    # TODO(fkp): Establish clear policy about what context managers are allowed.

    # assign to collection
    if op is not op_:
      self.assign_collections_handler(self._info, op, op_)

    self._info.transformed_ops[op] = op_
    return op_

  def new_name(self, name):
    """Compute a destination name from a source name.

    Args:
      name: the name to be "transformed".
    Returns:
      The transformed name.
    Raises:
      ValueError: if the source scope is used (that is, not an empty string)
        and the source name does not belong to the source scope.
    """
    scope = self._info.scope
    if not name.startswith(scope):
      raise ValueError("{} does not belong to source scope: {}.".format(name,
                                                                        scope))
    rel_name = name[len(scope):]
    name_ = self._info.scope_ + rel_name
    return name_
    # If op is in the subgraph, just return its transformed tensor.
    return info.transformed_ts[t]


def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
@ -600,7 +549,7 @@ def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
  Returns:
    A tuple `(sgv, info)` where:
      `sgv` is the transformed subgraph view;
      `info` is an instance of Transformer.ResultInfo containing
      `info` is an instance of TransformerInfo containing
      information about the transform, including mapping between
      original and transformed tensors and operations.
  Raises:
@ -642,7 +591,7 @@ def copy_with_input_replacements(sgv, replacement_ts,
  Returns:
    A tuple `(sgv, info)` where:
      `sgv` is the transformed subgraph view;
      `info` is an instance of Transformer.ResultInfo containing
      `info` is an instance of TransformerInfo containing
      information about the transform, including mapping between
      original and transformed tensors and operations.
  Raises:

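For context, this is how the new `(sgv, info)` return value is typically consumed, shown as a minimal sketch assuming the TF 1.x `tf.contrib.graph_editor` interface (the graph contents below are made up):

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

a = tf.constant(1.0, name="a")
b = tf.add(a, a, name="b")

# Copy the subgraph rooted at `b` into a fresh graph under a new scope.
# `info` is the TransformerInfo object described in the docstrings above.
dst = tf.Graph()
sgv_, info = ge.copy(ge.sgv(b.op), dst_graph=dst, dst_scope="copied")
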
@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function

import collections
import re
from six import iteritems
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import array_ops as tf_array_ops
@ -465,3 +466,75 @@ def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
  """
  return tf_array_ops.placeholder(
      dtype=dtype, shape=shape, name=placeholder_name(scope=scope))


_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$")


def get_predefined_collection_names():
  """Return all the predefined collection names."""
  return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys)
          if not _INTERNAL_VARIABLE_RE.match(key)]


def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""):
  """Find corresponding op/tensor in a different graph.

  Args:
    target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph.
    dst_graph: The graph in which the corresponding graph element must be found.
    dst_scope: A scope which is prepended to the name to look for.
    src_scope: A scope which is removed from the original name of `target`.

  Returns:
    The corresponding `tf.Tensor` or `tf.Operation`.

  Raises:
    ValueError: if `src_name` does not start with `src_scope`.
    TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation`.
    KeyError: If the corresponding graph element cannot be found.
  """
  src_name = target.name
  if src_scope:
    src_scope = scope_finalize(src_scope)
    if not src_name.startswith(src_scope):
      raise ValueError("{} does not start with {}".format(src_name, src_scope))
    src_name = src_name[len(src_scope):]

  dst_name = src_name
  if dst_scope:
    dst_scope = scope_finalize(dst_scope)
    dst_name = dst_scope + dst_name

  if isinstance(target, tf_ops.Tensor):
    return dst_graph.get_tensor_by_name(dst_name)
  if isinstance(target, tf_ops.Operation):
    return dst_graph.get_operation_by_name(dst_name)
raise TypeError("Expected tf.Tensor or tf.Operation, got: {}", type(target))


def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""):
  """Find corresponding ops/tensors in a different graph.

  `targets` is a Python tree, that is, a nested structure of iterables
  (list, tuple, dictionary) whose leaves are instances of
  `tf.Tensor` or `tf.Operation`.

  Args:
    targets: A Python tree containing `tf.Tensor` or `tf.Operation`
      belonging to the original graph.
    dst_graph: The graph in which the corresponding graph element must be found.
    dst_scope: A scope which is prepended to the name to look for.
    src_scope: A scope which is removed from the original name of `top`.

  Returns:
    A Python tree containing the corresponding `tf.Tensor` or `tf.Operation`.

  Raises:
    ValueError: if `src_name` does not start with `src_scope`.
    TypeError: if `top` is not a `tf.Tensor` or a `tf.Operation`.
    KeyError: If the corresponding graph element cannot be found.
  """
  def func(top):
    return find_corresponding_elem(top, dst_graph, dst_scope, src_scope)
  return transform_tree(targets, func)

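The scope handling in `find_corresponding_elem` boils down to string surgery on element names: strip `src_scope`, then prepend `dst_scope`. A hedged pure-Python restatement of just that logic (the helper name is hypothetical):

def remap_name(src_name, dst_scope="", src_scope=""):
  # Mirrors find_corresponding_elem above: strip src_scope, prepend dst_scope.
  if src_scope:
    if not src_name.startswith(src_scope):
      raise ValueError("{} does not start with {}".format(src_name, src_scope))
    src_name = src_name[len(src_scope):]
  return dst_scope + src_name

assert remap_name("src/a:0", dst_scope="dst/", src_scope="src/") == "dst/a:0"
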
@ -21,8 +21,8 @@ transforms (including rotation) are supported.

## Image `Ops`

@@ rotate
@@ transform
@@rotate
@@transform
"""
from __future__ import absolute_import
from __future__ import division
@ -31,3 +31,8 @@ from __future__ import print_function
# pylint: disable=line-too-long
from tensorflow.contrib.image.python.ops.image_ops import rotate
from tensorflow.contrib.image.python.ops.image_ops import transform

from tensorflow.python.util.all_util import remove_undocumented


remove_undocumented(__name__)

@ -52,7 +52,7 @@ tf_kernel_library(

py_library(
    name = "input_pipeline_py",
    srcs = glob(["python/ops/*.py"]),
    srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
    data = [":python/ops/_input_pipeline_ops.so"],
    srcs_version = "PY2AND3",
    deps = [

@ -15,11 +15,12 @@
"""Ops and modules related to input_pipeline.

@@obtain_next

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.input_pipeline.python.ops.input_pipeline_ops import obtain_next

from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

@ -59,6 +59,7 @@ from __future__ import print_function

# pylint: disable=wildcard-import
from tensorflow.contrib.integrate.python.ops.odes import *
from tensorflow.python.util.all_util import make_all
from tensorflow.python.util.all_util import remove_undocumented

__all__ = make_all(__name__)

remove_undocumented(__name__)

@ -23,17 +23,27 @@ common machine learning algorithms.
@@avg_pool2d
@@batch_norm
@@convolution2d
@@conv2d_in_plane
@@convolution2d_in_plane
@@conv2d_transpose
@@convolution2d_transpose
@@dropout
@@flatten
@@fully_connected
@@layer_norm
@@linear
@@max_pool2d
@@one_hot_encoding
@@relu
@@relu6
@@repeat
@@safe_embedding_lookup_sparse
@@separable_conv2d
@@separable_convolution2d
@@softmax
@@stack
@@unit_norm
@@embed_sequence

Aliases for fully_connected which set a default activation function are
available: `relu`, `relu6` and `linear`.
@ -95,6 +105,7 @@ Feature columns provide a mechanism to map data to a model.
@@input_from_feature_columns
@@joint_weighted_sum_from_feature_columns
@@make_place_holder_tensors_for_base_features
@@multi_class_target
@@one_hot_column
@@parse_feature_columns_from_examples
@@parse_feature_columns_from_sequence_examples
@ -105,6 +116,8 @@ Feature columns provide a mechanism to map data to a model.
@@sparse_column_with_keys
@@weighted_sparse_column
@@weighted_sum_from_feature_columns
@@infer_real_valued_columns
@@sequence_input_from_feature_columns

"""

@ -112,16 +125,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.layers.python.layers import *
from tensorflow.contrib.layers.python.ops import sparse_ops
from tensorflow.python.util.all_util import make_all
# pylint: enable=unused-import,wildcard-import

from tensorflow.python.util.all_util import remove_undocumented

# Note: `stack` operation is available, just excluded from the document above
# due to collision with tf.stack.
_allowed_symbols = ['bias_add',
                    'conv2d',
                    'feature_column',
                    'legacy_fully_connected',
                    'legacy_linear',
                    'legacy_relu',
                    'OPTIMIZER_CLS_NAMES',
                    'regression_target',
                    'SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY',
                    'summaries']

__all__ = make_all(__name__)
remove_undocumented(__name__, _allowed_symbols)

@ -122,11 +122,11 @@ import math

import six

from tensorflow.contrib import lookup
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.layers.python.ops import bucketization_op
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops
from tensorflow.contrib.lookup import lookup_ops as contrib_lookup_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
from tensorflow.python.ops import array_ops
@ -587,7 +587,7 @@ class _SparseColumnKeys(_SparseColumn):
    """Handles sparse column to id conversion."""
    input_tensor = self._get_input_sparse_tensor(columns_to_tensors)

    table = contrib_lookup_ops.string_to_index_table_from_tensor(
    table = lookup.string_to_index_table_from_tensor(
        mapping=list(self.lookup_config.keys),
        default_value=self.lookup_config.default_value,
        name="lookup")
@ -662,7 +662,7 @@ class _SparseColumnVocabulary(_SparseColumn):
    else:
      sparse_string_tensor = st

    table = contrib_lookup_ops.string_to_index_table_from_file(
    table = lookup.string_to_index_table_from_file(
        vocabulary_file=self.lookup_config.vocabulary_file,
        num_oov_buckets=self.lookup_config.num_oov_buckets,
        vocab_size=self.lookup_config.vocab_size,

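For reference, the `lookup.string_to_index_table_from_tensor` call these columns now use maps strings to vocabulary indices, returning `default_value` for out-of-vocabulary keys. A minimal sketch assuming the TF 1.x contrib API (the vocabulary below is made up):

import tensorflow as tf
from tensorflow.contrib import lookup

table = lookup.string_to_index_table_from_tensor(
    mapping=["cat", "dog"], default_value=-1, name="lookup")
ids = table.lookup(tf.constant(["dog", "fish"]))  # expected: [1, -1]

with tf.Session() as sess:
  sess.run(tf.tables_initializer())
  print(sess.run(ids))
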
@ -23,15 +23,18 @@ import os
import sys
import tempfile

# TODO: #6568 Remove this hack that makes dlopen() not crash.
# pylint: disable=g-bad-todo
# TODO(#6568): Remove this hack that makes dlopen() not crash.
# pylint: enable=g-bad-todo
# pylint: disable=g-import-not-at-top
if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
  import ctypes
  sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)

import numpy as np

from tensorflow.contrib.layers.python.layers import feature_column as fc
from tensorflow.contrib.layers.python.layers import feature_column_ops
import tensorflow.contrib.layers.python.layers.feature_column as fc
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib

@ -24,13 +24,24 @@ Train and evaluate TensorFlow models.
@@Estimator
@@Trainable
@@Evaluable
@@KMeansClustering
@@ModeKeys
@@ModelFnOps
@@MetricSpec
@@PredictionKey
@@DNNClassifier
@@DNNRegressor
@@DNNLinearCombinedRegressor
@@DNNLinearCombinedClassifier
@@LinearClassifier
@@LinearRegressor
@@LogisticRegressor

## Distributed training utilities
@@Experiment
@@ExportStrategy
@@TaskType

## Graph actions

Perform various training, evaluation, and inference actions on a graph.
@ -58,6 +69,10 @@ Queue and read batched input data.
@@read_batch_features
@@read_batch_record_features

Export utilities

@@build_parsing_serving_input_fn
@@ProblemType
"""

from __future__ import absolute_import
@ -67,7 +82,11 @@ from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.learn.python.learn import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import make_all

__all__ = make_all(__name__)
__all__.append('datasets')
from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = ['datasets', 'head', 'io', 'models',
                    'monitors', 'NotFittedError', 'ops', 'preprocessing',
                    'utils', 'graph_actions']

remove_undocumented(__name__, _allowed_symbols)

@ -19,8 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

# pylint: disable=wildcard-import
from tensorflow.contrib.learn.python.learn import basic_session_run_hooks
from tensorflow.contrib.learn.python.learn import datasets
@ -46,4 +44,5 @@ from tensorflow.contrib.learn.python.learn.learn_io import *
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
from tensorflow.contrib.learn.python.learn.monitors import NanLossDuringTrainingError
from tensorflow.contrib.learn.python.learn.trainable import Trainable
from tensorflow.contrib.learn.python.learn.utils import *
# pylint: enable=wildcard-import

@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some common SessionRunHook classes.

@@
"""
"""Some common SessionRunHook classes."""

from __future__ import absolute_import
from __future__ import division

@ -306,6 +306,7 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
from tensorflow.contrib.learn.python.learn.estimators.constants import ProblemType
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor
from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import DNNLinearCombinedClassifier

@ -25,7 +25,10 @@ import os
import sys
import tempfile

# TODO: #6568 Remove this hack that makes dlopen() not crash.
# pylint: disable=g-bad-todo
# TODO(#6568): Remove this hack that makes dlopen() not crash.
# pylint: enable=g-bad-todo
# pylint: disable=g-import-not-at-top
if hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags'):
  import ctypes
  sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
@ -35,6 +38,7 @@ import six
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.contrib import learn
from tensorflow.contrib import lookup
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
from tensorflow.contrib.layers.python.layers import optimizers
@ -48,7 +52,6 @@ from tensorflow.contrib.learn.python.learn.estimators import linear
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
from tensorflow.contrib.lookup import lookup_ops
from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.contrib.testing.python.framework import util_test
from tensorflow.python.client import session as session_lib
@ -221,8 +224,8 @@ def _build_estimator_for_export_tests(tmpdir):
  vocab_file = gfile.GFile(vocab_file_name, mode='w')
  vocab_file.write(VOCAB_FILE_CONTENT)
  vocab_file.close()
  hashtable = lookup_ops.HashTable(
      lookup_ops.TextFileStringTableInitializer(vocab_file_name), 'x')
  hashtable = lookup.HashTable(
      lookup.TextFileStringTableInitializer(vocab_file_name), 'x')
  features['bogus_lookup'] = hashtable.lookup(
      math_ops.to_int64(features['feature']))

@ -878,8 +881,8 @@ class ReplicaDeviceSetterTest(test.TestCase):

    with ops.device(estimator._get_replica_device_setter(config)):
      default_val = constant_op.constant([-1, -1], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)
      input_string = constant_op.constant(['brain', 'salad', 'tank'])
      output = table.lookup(input_string)
      self.assertDeviceEqual('/job:ps/task:0', table._table_ref.device)
@ -889,8 +892,8 @@ class ReplicaDeviceSetterTest(test.TestCase):
    with ops.device(
        estimator._get_replica_device_setter(run_config.RunConfig())):
      default_val = constant_op.constant([-1, -1], dtypes.int64)
      table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                          default_val)
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)
      input_string = constant_op.constant(['brain', 'salad', 'tank'])
      output = table.lookup(input_string)
      self.assertDeviceEqual('', table._table_ref.device)

@ -97,7 +97,7 @@ def numpy_input_fn(x,
    shape_dict_of_x = {k: x[k].shape for k in x.keys()}
    shape_of_y = None if y is None else y.shape
    raise ValueError('Length of tensors in x and y is mismatched. All '
                     'elementson x and y must have the same length.\n'
                     'elements in x and y must have the same length.\n'
                     'Shapes in x: {}\n'
                     'Shape for y: {}\n'.format(shape_dict_of_x, shape_of_y))

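The corrected message fires when the arrays in `x` and `y` disagree on their leading dimension. A hedged pure-Python restatement of the check (the helper name is hypothetical):

import numpy as np

def check_lengths(x, y):
  # Mirrors the validation above: every array in x, and y if present,
  # must share the same leading dimension.
  lengths = set(len(v) for v in x.values())
  if y is not None:
    lengths.add(len(y))
  if len(lengths) > 1:
    shape_dict_of_x = {k: v.shape for k, v in x.items()}
    shape_of_y = None if y is None else y.shape
    raise ValueError('Length of tensors in x and y is mismatched. All '
                     'elements in x and y must have the same length.\n'
                     'Shapes in x: {}\n'
                     'Shape for y: {}\n'.format(shape_dict_of_x, shape_of_y))

check_lengths({'a': np.arange(4)}, np.arange(4))    # passes silently
# check_lengths({'a': np.arange(4)}, np.arange(3))  # raises the ValueError
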
@ -27,7 +27,7 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):

import numpy as np

import tensorflow.contrib.learn.python.learn.dataframe.queues.feeding_functions as ff
from tensorflow.contrib.learn.python.learn.dataframe.queues import feeding_functions as ff
from tensorflow.python.platform import test

# pylint: disable=g-import-not-at-top

@ -27,7 +27,7 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):

import numpy as np

import tensorflow.contrib.learn.python.learn.dataframe.queues.feeding_functions as ff
from tensorflow.contrib.learn.python.learn.dataframe.queues import feeding_functions as ff
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.platform import test

@ -24,7 +24,8 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
  import ctypes
  sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)

import tensorflow.contrib.learn.python.learn.dataframe.transforms.reader_source as rs
# pylint: disable=g-import-not-at-top
from tensorflow.contrib.learn.python.learn.dataframe.transforms import reader_source as rs
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test

@ -20,3 +20,7 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib.learn.python.learn.utils.export import export_estimator
from tensorflow.contrib.learn.python.learn.utils.input_fn_utils import build_default_serving_input_fn
from tensorflow.contrib.learn.python.learn.utils.input_fn_utils import build_parsing_serving_input_fn
from tensorflow.contrib.learn.python.learn.utils.saved_model_export_utils import make_export_strategy

@ -54,3 +54,6 @@ from tensorflow.contrib.linalg.python.ops.linear_operator_matrix import *
from tensorflow.contrib.linalg.python.ops.linear_operator_tril import *

# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member

from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

@ -22,7 +22,7 @@ import abc
import numpy as np
import six

from tensorflow.contrib.framework import tensor_util as contrib_tensor_util
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed

@ -17,6 +17,8 @@
## This package provides optimizers to train linear models.

@@SdcaModel
@@SparseFeatureColumn
@@SDCAOptimizer
"""
from __future__ import absolute_import
from __future__ import division
@ -25,3 +27,6 @@ from __future__ import print_function
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel
from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn
from tensorflow.contrib.linear_optimizer.python.sdca_optimizer import SDCAOptimizer

from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

@ -20,7 +20,7 @@ from __future__ import print_function

from six.moves import range

from tensorflow.contrib.lookup import lookup_ops
from tensorflow.contrib import lookup
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@ -30,7 +30,7 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops


class ShardedMutableDenseHashTable(lookup_ops.LookupInterface):
class ShardedMutableDenseHashTable(lookup.LookupInterface):
  """A sharded version of MutableDenseHashTable.

  It is designed to be interface compatible with LookupInterface and
@ -41,7 +41,7 @@ class ShardedMutableDenseHashTable(lookup_ops.LookupInterface):
  internally. The shard is computed via the modulo operation on the key.
  """

  # TODO(andreasst): consider moving this to lookup_ops
  # TODO(andreasst): consider moving this to lookup module

  def __init__(self,
               key_dtype,
@ -56,7 +56,7 @@ class ShardedMutableDenseHashTable(lookup_ops.LookupInterface):
    table_shards = []
    for i in range(num_shards):
      table_shards.append(
          lookup_ops.MutableDenseHashTable(
          lookup.MutableDenseHashTable(
              key_dtype=key_dtype,
              value_dtype=value_dtype,
              default_value=default_value,

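The docstring's sharding scheme is plain key-modulo routing. A hedged pure-Python illustration (the helper name is hypothetical):

def shard_for_key(key, num_shards):
  # The shard is the integer key modulo the number of table shards,
  # as the docstring above describes.
  return key % num_shards

keys = [3, 7, 12, 18]
shards = {i: [k for k in keys if shard_for_key(k, 3) == i] for i in range(3)}
assert shards == {0: [3, 12, 18], 1: [7], 2: []}
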
@ -25,6 +25,7 @@
@@IdTableWithHashBuckets
@@HashTable
@@MutableHashTable
@@MutableDenseHashTable
@@TableInitializerBase
@@KeyValueTensorInitializer
@@TextFileIndex
@ -32,6 +33,9 @@
@@TextFileIdTableInitializer
@@TextFileStringTableInitializer

@@HasherSpec
@@StrongHashSpec
@@FastHashSpec
"""

from __future__ import absolute_import
@ -41,3 +45,6 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.lookup.lookup_ops import *
# pylint: enable=unused-import,wildcard-import

from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

One file's diff is suppressed because it is too large.

@ -19,8 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.losses.python import losses
from tensorflow.contrib.losses.python.losses import *
# pylint: enable=unused-import,wildcard-import

from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__, doc_string_modules=[losses])

@ -109,6 +109,7 @@ labels and predictions tensors and results in a weighted average of the metric.
@@streaming_mean_iou
@@streaming_mean_relative_error
@@streaming_mean_squared_error
@@streaming_mean_tensor
@@streaming_root_mean_squared_error
@@streaming_covariance
@@streaming_pearson_correlation
@ -137,6 +138,8 @@ labels and predictions tensors and results in a weighted average of the metric.
@@aggregate_metrics
@@aggregate_metric_map

@@confusion_matrix

## Set `Ops`

@@set_difference
@ -193,7 +196,7 @@ from tensorflow.contrib.metrics.python.ops.set_ops import set_difference
from tensorflow.contrib.metrics.python.ops.set_ops import set_intersection
from tensorflow.contrib.metrics.python.ops.set_ops import set_size
from tensorflow.contrib.metrics.python.ops.set_ops import set_union
from tensorflow.python.util.all_util import make_all
# pylint: enable=unused-import,line-too-long

__all__ = make_all(__name__)
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

@ -23,7 +23,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework import tensor_util
from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops

@ -23,7 +23,7 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import tensor_util
from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.metrics.python.ops import set_ops
from tensorflow.python.framework import dtypes

@ -4593,7 +4593,7 @@ class StreamingConcatTest(test.TestCase):
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  def testNextArraySize(self):
    next_array_size = metrics.python.ops.metric_ops._next_array_size
    next_array_size = metric_ops._next_array_size  # pylint: disable=protected-access
    with self.test_session():
      self.assertEqual(next_array_size(2, growth_factor=2).eval(), 2)
      self.assertEqual(next_array_size(3, growth_factor=2).eval(), 4)

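`_next_array_size` picks the next power of the growth factor at or above the requested size. A hedged pure-Python equivalent that matches the test expectations above:

import math

def next_array_size(required, growth_factor=2):
  # Smallest power of growth_factor that is >= required.
  exponent = math.ceil(math.log(required) / math.log(growth_factor))
  return int(growth_factor ** exponent)

assert next_array_size(2) == 2
assert next_array_size(3) == 4
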
@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for deprecated ops in tf.nn."""
"""Module for deprecated ops in tf.nn.

@@deprecated_flipped_softmax_cross_entropy_with_logits
@@deprecated_flipped_sparse_softmax_cross_entropy_with_logits
@@deprecated_flipped_sigmoid_cross_entropy_with_logits
"""

from __future__ import absolute_import
from __future__ import division
@ -21,3 +26,6 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.nn.python.ops.cross_entropy import *
# pylint: enable=unused-import,wildcard-import

from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

@ -23,3 +23,12 @@ from tensorflow.contrib.opt.python.training.external_optimizer import *
from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
# pylint: enable=wildcard-import

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = ['ExternalOptimizerInterface',
                    'MovingAverageOptimizer',
                    'ScipyOptimizerInterface',
                    'VariableClippingOptimizer']

remove_undocumented(__name__, _allowed_symbols)

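Every `__init__.py` in this commit follows the same pattern: `remove_undocumented(__name__, allowed)` prunes public module attributes that are neither in the allowed list nor documented as `@@name` in the module docstring. A hedged pure-Python sketch of the core idea (not the actual implementation):

import sys

def prune_module(module_name, allowed_symbols):
  # Rough equivalent: delete public attributes that were not explicitly
  # allowed, so dir(module) matches the documented API.
  module = sys.modules[module_name]
  for name in list(vars(module)):
    if not name.startswith('_') and name not in allowed_symbols:
      delattr(module, name)
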
@ -24,6 +24,7 @@
@@BasicLSTMCell
@@GRUCell
@@LSTMCell
@@LayerNormBasicLSTMCell

## Classes storing split `RNNCell` state

@ -32,6 +33,7 @@
## RNN Cell wrappers (RNNCells that wrap other RNNCells)

@@MultiRNNCell
@@LSTMBlockWrapper
@@DropoutWrapper
@@EmbeddingWrapper
@@InputProjectionWrapper
@ -86,10 +88,13 @@ from tensorflow.contrib.rnn.python.ops.core_rnn_cell import MultiRNNCell
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import OutputProjectionWrapper
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import RNNCell

# pylint: disable=unused-import,wildcard-import, line-too-long
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.rnn.python.ops.fused_rnn_cell import *
from tensorflow.contrib.rnn.python.ops.gru_ops import *
from tensorflow.contrib.rnn.python.ops.lstm_ops import *
from tensorflow.contrib.rnn.python.ops.rnn import *
from tensorflow.contrib.rnn.python.ops.rnn_cell import *
# pylint: enable=unused-import,wildcard-import,line-too-long

from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__, ['core_rnn_cell'])

@ -19,13 +19,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

# pylint: disable=unused-import,line-too-long
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.seq2seq.python.ops.attention_decoder_fn import attention_decoder_fn_inference
from tensorflow.contrib.seq2seq.python.ops.attention_decoder_fn import attention_decoder_fn_train
from tensorflow.contrib.seq2seq.python.ops.attention_decoder_fn import prepare_attention
from tensorflow.contrib.seq2seq.python.ops.decoder_fn import *
from tensorflow.contrib.seq2seq.python.ops.loss import *
from tensorflow.contrib.seq2seq.python.ops.seq2seq import *
# pylint: enable=unused-import,line-too-long
# pylint: enable=unused-import,wildcard-import,line-too-long

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = ["attention_decoder_fn_inference",
                    "attention_decoder_fn_train",
                    "dynamic_rnn_decoder",
                    "prepare_attention",
                    "sequence_loss",
                    "simple_decoder_fn_train",
                    "simple_decoder_fn_inference"]

remove_undocumented(__name__, _allowed_symbols)

@ -25,3 +25,10 @@ from __future__ import print_function
from tensorflow.python.pywrap_tensorflow import DeleteStatSummarizer
from tensorflow.python.pywrap_tensorflow import NewStatSummarizer
from tensorflow.python.pywrap_tensorflow import StatSummarizer

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = ['DeleteStatSummarizer', 'NewStatSummarizer',
                    'StatSummarizer']

remove_undocumented(__name__, _allowed_symbols)

@ -19,4 +19,3 @@ from __future__ import print_function

from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer
from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger
from tensorflow.python.util.all_util import make_all

@ -35,7 +35,6 @@ from tensorflow.python.framework.meta_graph import stripped_op_list_for_graph
from tensorflow.python.framework.tensor_util import constant_value
from tensorflow.python.framework.tensor_util import make_tensor_proto
from tensorflow.python.framework.tensor_util import MakeNdarray as make_ndarray
from tensorflow.python.util.all_util import make_all


__all__ = make_all(__name__)
# pylint: disable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for loading op libraries."""
"""Utilities for loading op libraries.

@@load_op_library
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

@ -291,10 +291,43 @@ class Library(Document):
  def _generate_signature_for_function(self, func):
    """Given a function, returns a string representing its args."""
    args_list = []
    argspec = inspect.getargspec(func)
    if isinstance(func, functools.partial):
      argspec = inspect.getargspec(func.func)
      # Remove the args from the original function that have been used up.
      first_default_arg = (
          len(argspec.args or []) - len(argspec.defaults or []))
      partial_args = len(func.args)
      if argspec.args:
        argspec_args = list(argspec.args[partial_args:])
      else:
        argspec_args = []
      if argspec.defaults:
        argspec_defaults = list(argspec.defaults[
            max(0, partial_args-first_default_arg):])
      else:
        argspec_defaults = []
      first_default_arg = max(0, first_default_arg - partial_args)
      for kwarg in func.keywords:
        if kwarg in argspec_args:
          i = argspec_args.index(kwarg)
          argspec_args.pop(i)
          if i >= first_default_arg:
            argspec_defaults.pop(i-first_default_arg)
          else:
            first_default_arg -= 1
      argspec_varargs = None
      argspec_keywords = None

    else:
      argspec = inspect.getargspec(func)
      argspec_args = argspec.args
      argspec_defaults = argspec.defaults
      argspec_varargs = argspec.varargs
      argspec_keywords = argspec.keywords

    first_arg_with_default = (
        len(argspec.args or []) - len(argspec.defaults or []))
    for arg in argspec.args[:first_arg_with_default]:
        len(argspec_args or []) - len(argspec_defaults or []))
    for arg in argspec_args[:first_arg_with_default]:
      if arg == "self":
        # Python documentation typically skips `self` when printing method
        # signatures.
@ -306,16 +339,16 @@ class Library(Document):
    # TODO(aselle): This workaround is brittle on TestCase.__call__
    # so we need to wrap this in a try/catch
    # We should do something better.
    if argspec.varargs == "args" and argspec.keywords == "kwds":
    if argspec_varargs == "args" and argspec_keywords == "kwds":
      try:
        original_func = func.__closure__[0].cell_contents
        return self._generate_signature_for_function(original_func)
      except TypeError:
        pass

    if argspec.defaults:
    if argspec_defaults:
      for arg, default in zip(
          argspec.args[first_arg_with_default:], argspec.defaults):
          argspec_args[first_arg_with_default:], argspec_defaults):
        if callable(default):
          if hasattr(default, "__name__"):
            args_list.append("%s=%s" % (arg, default.__name__))
@ -326,10 +359,10 @@ class Library(Document):
          args_list.append("%s=%s()" % (arg, default.__class__.__name__))
        else:
          args_list.append("%s=%r" % (arg, default))
    if argspec.varargs:
      args_list.append("*" + argspec.varargs)
    if argspec.keywords:
      args_list.append("**" + argspec.keywords)
    if argspec_varargs:
      args_list.append("*" + argspec_varargs)
    if argspec_keywords:
      args_list.append("**" + argspec_keywords)
    return "(" + ", ".join(args_list) + ")"

  def _remove_docstring_indent(self, docstring):

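The new `functools.partial` branch must discount both the positional args and the keywords the partial has already bound. A small runnable illustration of what it has to reproduce (the function name below is made up):

import functools
import inspect

def f(a, b, c=1, d=2):
  return a + b + c + d

p = functools.partial(f, 10, c=5)
print(p.args)      # (10,)   -> one positional arg of f is consumed
print(p.keywords)  # {'c': 5} -> c must be dropped from args and defaults
print(inspect.signature(p))  # (b, *, d=2), the signature the code recomputes
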
@ -261,11 +261,18 @@ _hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange",
# TODO(wicke): Remove contrib.layers.relu* after shortnames are
# disabled. These conflict with tf.nn.relu*
EXCLUDE = frozenset(["tf.contrib.learn.monitors.NanLossDuringTrainingError",
                     "tf.contrib.layers.dropout",
                     "tf.contrib.layers.bias_add",
                     "tf.contrib.layers.conv2d",
                     "tf.contrib.layers.conv2d_transpose",
                     "tf.contrib.layers.separable_conv2d",
                     "tf.contrib.layers.softmax",
                     "tf.contrib.layers.relu", "tf.contrib.layers.relu6",
                     "tf.contrib.framework.assert_global_step",
                     "tf.contrib.framework.get_global_step",
                     "tf.contrib.learn.NanLossDuringTrainingError",
                     "tf.contrib.layers.stack",
                     "tf.contrib.layers.ProblemType",
                     "tf.confusion_matrix"])