From 59ad9529bab849bfa10ab9ccfdda935be4d68e42 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 28 Jan 2017 03:55:07 -0800 Subject: [PATCH 1/6] Remove access to individual sub-modules: only use the symbols imported at the root of the library. Renamed some functions in the reroute module: removed a2b/b2a distinction and added _ios to swap and reroute. Change: 145878770 --- tensorflow/contrib/graph_editor/__init__.py | 112 ++---------------- tensorflow/contrib/graph_editor/edit.py | 4 +- tensorflow/contrib/graph_editor/reroute.py | 66 ++--------- tensorflow/contrib/graph_editor/select.py | 2 + .../contrib/graph_editor/tests/edit_test.py | 9 +- .../contrib/graph_editor/tests/match_test.py | 16 +-- .../graph_editor/tests/reroute_test.py | 36 +++--- .../contrib/graph_editor/tests/select_test.py | 73 ++++++------ .../graph_editor/tests/transform_test.py | 12 +- 9 files changed, 92 insertions(+), 238 deletions(-) diff --git a/tensorflow/contrib/graph_editor/__init__.py b/tensorflow/contrib/graph_editor/__init__.py index c59aa2520c5..04a4cbb8198 100644 --- a/tensorflow/contrib/graph_editor/__init__.py +++ b/tensorflow/contrib/graph_editor/__init__.py @@ -108,111 +108,12 @@ which to operate must always be given explicitly. This is the reason why *connect* or *bypass*. * transform: the Transformer class, which enables transforming (or simply copying) a subgraph into another one. - -## Module: util - -@@make_list_of_op -@@get_tensors -@@make_list_of_t -@@get_generating_ops -@@get_consuming_ops -@@ControlOutputs -@@placeholder_name -@@make_placeholder_from_tensor -@@make_placeholder_from_dtype_and_shape - -## Module: select - -@@filter_ts -@@filter_ts_from_regex -@@filter_ops -@@filter_ops_from_regex -@@get_name_scope_ops -@@check_cios -@@get_ops_ios -@@compute_boundary_ts -@@get_within_boundary_ops -@@get_forward_walk_ops -@@get_backward_walk_ops -@@get_walks_intersection_ops -@@get_walks_union_ops -@@select_ops -@@select_ts -@@select_ops_and_ts - -## Module: subgraph - -@@SubGraphView -@@make_view -@@make_view_from_scope - -## Module: reroute - -@@swap_ts -@@reroute_a2b_ts -@@reroute_b2a_ts -@@swap_inputs -@@reroute_a2b_inputs -@@reroute_b2a_inputs -@@swap_outputs -@@reroute_a2b_outputs -@@reroute_b2a_outputs -@@swap -@@reroute_a2b -@@reroute_b2a -@@remove_control_inputs -@@add_control_inputs - -## Module: edit - -@@detach_control_inputs -@@detach_control_outputs -@@detach_inputs -@@detach_outputs -@@detach -@@connect -@@bypass - -## Module: transform - -@@replace_t_with_placeholder_handler -@@keep_t_if_possible_handler -@@assign_renamed_collections_handler -@@transform_op_if_inside_handler -@@copy_op_handler -@@transform_op_in_place -@@Transformer -@@copy -@@copy_with_input_replacements -@@graph_replace - -## Module: match - -@@op_type -@@OpMatcher - -## Useful aliases - -@@ph -@@sgv -@@sgv_scope -@@ts -@@ops -@@matcher """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.graph_editor import edit -from tensorflow.contrib.graph_editor import match -from tensorflow.contrib.graph_editor import reroute -from tensorflow.contrib.graph_editor import select -from tensorflow.contrib.graph_editor import subgraph -from tensorflow.contrib.graph_editor import transform -from tensorflow.contrib.graph_editor import util - # pylint: disable=wildcard-import from tensorflow.contrib.graph_editor.edit import * from tensorflow.contrib.graph_editor.match import * @@ -224,9 +125,10 @@ from 
tensorflow.contrib.graph_editor.util import * # pylint: enable=wildcard-import # some useful aliases -ph = util.make_placeholder_from_dtype_and_shape -sgv = subgraph.make_view -sgv_scope = subgraph.make_view_from_scope -ts = select.select_ts -ops = select.select_ops -matcher = match.OpMatcher +# pylint: disable=g-bad-import-order +from tensorflow.contrib.graph_editor import subgraph as _subgraph +from tensorflow.contrib.graph_editor import util as _util +# pylint: enable=g-bad-import-order +ph = _util.make_placeholder_from_dtype_and_shape +sgv = _subgraph.make_view +sgv_scope = _subgraph.make_view_from_scope diff --git a/tensorflow/contrib/graph_editor/edit.py b/tensorflow/contrib/graph_editor/edit.py index 0ac79b6dd3b..3037eef4894 100644 --- a/tensorflow/contrib/graph_editor/edit.py +++ b/tensorflow/contrib/graph_editor/edit.py @@ -194,7 +194,7 @@ def connect(sgv0, sgv1, disconnect_first=False): if disconnect_first: detach_outputs(sgv0) sgv0_outputs = subgraph.SubGraphView(passthrough_ts=sgv0.outputs) - reroute.reroute_a2b_inputs(sgv0_outputs, sgv1) + reroute.reroute_inputs(sgv0_outputs, sgv1) return sgv0, sgv1 @@ -217,5 +217,5 @@ def bypass(sgv): sgv = subgraph.make_view(sgv) sgv_inputs = list(sgv.inputs) sgv, detached_inputs = detach_inputs(sgv) - reroute.reroute_a2b_ts(sgv_inputs, sgv.outputs) + reroute.reroute_ts(sgv_inputs, sgv.outputs) return sgv, detached_inputs diff --git a/tensorflow/contrib/graph_editor/reroute.py b/tensorflow/contrib/graph_editor/reroute.py index 57284351132..4c5f281badd 100644 --- a/tensorflow/contrib/graph_editor/reroute.py +++ b/tensorflow/contrib/graph_editor/reroute.py @@ -24,17 +24,13 @@ from tensorflow.python.framework import ops as tf_ops __all__ = [ "swap_ts", - "reroute_a2b_ts", - "reroute_b2a_ts", + "reroute_ts", "swap_inputs", - "reroute_a2b_inputs", - "reroute_b2a_inputs", + "reroute_inputs", "swap_outputs", - "reroute_a2b_outputs", - "reroute_b2a_outputs", - "swap", - "reroute_a2b", - "reroute_b2a", + "reroute_outputs", + "swap_ios", + "reroute_ios", "remove_control_inputs", "add_control_inputs", ] @@ -232,7 +228,7 @@ def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None): return _reroute_ts(ts0, ts1, _RerouteMode.swap, can_modify, cannot_modify) -def reroute_a2b_ts(ts0, ts1, can_modify=None, cannot_modify=None): +def reroute_ts(ts0, ts1, can_modify=None, cannot_modify=None): """For each tensor's pair, replace the end of t1 by the end of t0. B0 B1 B0 B1 @@ -258,33 +254,6 @@ def reroute_a2b_ts(ts0, ts1, can_modify=None, cannot_modify=None): return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify) -def reroute_b2a_ts(ts0, ts1, can_modify=None, cannot_modify=None): - r"""For each tensor's pair, replace the end of t0 by the end of t1. - - B0 B1 B0 B1 - | | => \| - A0 A1 A0 A1 - - The end of the tensors in ts0 are left dangling. - - Args: - ts0: an object convertible to a list of `tf.Tensor`. - ts1: an object convertible to a list of `tf.Tensor`. - can_modify: iterable of operations which can be modified. Any operation - outside within_ops will be left untouched by this function. - cannot_modify: iterable of operations which cannot be modified. - Any operation within cannot_modify will be left untouched by this - function. - Returns: - The number of individual modifications made by the function. - Raises: - TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor. - TypeError: if can_modify or cannot_modify is not None and cannot be - converted to a list of tf.Operation. 
- """ - return _reroute_ts(ts0, ts1, _RerouteMode.b2a, can_modify, cannot_modify) - - def _reroute_sgv_remap(sgv0, sgv1, mode): """Remap in place the inputs of two subgraph views to mimic the reroute. @@ -425,46 +394,31 @@ def swap_inputs(sgv0, sgv1): return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.swap) -def reroute_a2b_inputs(sgv0, sgv1): +def reroute_inputs(sgv0, sgv1): """Re-route all the inputs of sgv0 to sgv1 (see reroute_inputs).""" return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.a2b) -def reroute_b2a_inputs(sgv0, sgv1): - """Re-route all the inputs of sgv1 to sgv0 (see reroute_inputs).""" - return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.b2a) - - def swap_outputs(sgv0, sgv1): """Swap all the outputs of sgv0 and sgv1 (see _reroute_outputs).""" return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.swap) -def reroute_a2b_outputs(sgv0, sgv1): +def reroute_outputs(sgv0, sgv1): """Re-route all the outputs of sgv0 to sgv1 (see _reroute_outputs).""" return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.a2b) -def reroute_b2a_outputs(sgv0, sgv1): - """Re-route all the outputs of sgv1 to sgv0 (see _reroute_outputs).""" - return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.b2a) - - -def swap(sgv0, sgv1): +def swap_ios(sgv0, sgv1): """Swap the inputs and outputs of sgv1 to sgv0 (see _reroute).""" return _reroute_sgv(sgv0, sgv1, _RerouteMode.swap) -def reroute_a2b(sgv0, sgv1): +def reroute_ios(sgv0, sgv1): """Re-route the inputs and outputs of sgv0 to sgv1 (see _reroute).""" return _reroute_sgv(sgv0, sgv1, _RerouteMode.a2b) -def reroute_b2a(sgv0, sgv1): - """Re-route the inputs and outputs of sgv1 to sgv0 (see _reroute).""" - return _reroute_sgv(sgv0, sgv1, _RerouteMode.b2a) - - def remove_control_inputs(op, cops): """Remove the control inputs cops from co. diff --git a/tensorflow/contrib/graph_editor/select.py b/tensorflow/contrib/graph_editor/select.py index 1125b90a9e3..2401a3a1578 100644 --- a/tensorflow/contrib/graph_editor/select.py +++ b/tensorflow/contrib/graph_editor/select.py @@ -28,6 +28,8 @@ from tensorflow.contrib.graph_editor import util from tensorflow.python.framework import ops as tf_ops __all__ = [ + "can_be_regex", + "make_regex", "filter_ts", "filter_ts_from_regex", "filter_ops", diff --git a/tensorflow/contrib/graph_editor/tests/edit_test.py b/tensorflow/contrib/graph_editor/tests/edit_test.py index 371e6cdc8bc..8adaf84b42b 100644 --- a/tensorflow/contrib/graph_editor/tests/edit_test.py +++ b/tensorflow/contrib/graph_editor/tests/edit_test.py @@ -50,11 +50,11 @@ class EditTest(test.TestCase): def test_detach(self): """Test for ge.detach.""" sgv = ge.sgv(self.c.op, self.a.op) - control_outputs = ge.util.ControlOutputs(self.graph) + control_outputs = ge.ControlOutputs(self.graph) ge.detach(sgv, control_ios=control_outputs) # make sure the detached graph is as expected. 
self.assertTrue( - ge.matcher("^foo/c$").input_ops("a", "geph__b_0")(self.c.op)) + ge.OpMatcher("^foo/c$").input_ops("a", "geph__b_0")(self.c.op)) def test_connect(self): """Test for ge.connect.""" @@ -66,13 +66,14 @@ class EditTest(test.TestCase): sgv = ge.sgv(x.op, y.op, z.op) ge.connect(sgv, ge.sgv(self.e.op).remap_inputs([0])) self.assertTrue( - ge.matcher("^foo/bar/e$").input_ops("^z$", "foo/d$")(self.e.op)) + ge.OpMatcher("^foo/bar/e$").input_ops("^z$", "foo/d$")(self.e.op)) def test_bypass(self): """Test for ge.bypass.""" ge.bypass(ge.sgv(self.f.op).remap_inputs([0])) self.assertTrue( - ge.matcher("^foo/bar/h$").input_ops("^foo/c$", "foo/bar/g$")(self.h.op)) + ge.OpMatcher("^foo/bar/h$").input_ops("^foo/c$", "foo/bar/g$")( + self.h.op)) if __name__ == "__main__": diff --git a/tensorflow/contrib/graph_editor/tests/match_test.py b/tensorflow/contrib/graph_editor/tests/match_test.py index 6eb71ea510e..bcb8f3f0e3b 100644 --- a/tensorflow/contrib/graph_editor/tests/match_test.py +++ b/tensorflow/contrib/graph_editor/tests/match_test.py @@ -42,20 +42,20 @@ class MatchTest(test.TestCase): self.h = math_ops.add(self.f, self.g, name="h") def test_simple_match(self): - self.assertTrue(ge.matcher("^.*/f$")(self.f.op)) + self.assertTrue(ge.OpMatcher("^.*/f$")(self.f.op)) self.assertTrue( - ge.matcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")(self.f.op)) - self.assertTrue(ge.matcher("^.*/f$").input_ops(True, "^.*/d$")(self.f.op)) + ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")(self.f.op)) + self.assertTrue(ge.OpMatcher("^.*/f$").input_ops(True, "^.*/d$")(self.f.op)) self.assertTrue( - ge.matcher("^.*/f$").input_ops( + ge.OpMatcher("^.*/f$").input_ops( ge.match.op_type("Add"), ge.match.op_type("Const"))(self.f.op)) self.assertTrue( - ge.matcher("^.*/f$").input_ops("^.*/c$", "^.*/d$") - .output_ops(ge.matcher("^.*/h$") + ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$") + .output_ops(ge.OpMatcher("^.*/h$") .control_input_ops("^.*/c$"))(self.f.op)) self.assertTrue( - ge.matcher("^.*/f$").input_ops("^.*/c$", "^.*/d$").output_ops( - ge.matcher("^.*/h$").control_input_ops("^.*/c$") + ge.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$").output_ops( + ge.OpMatcher("^.*/h$").control_input_ops("^.*/c$") .output_ops([]))(self.f.op)) diff --git a/tensorflow/contrib/graph_editor/tests/reroute_test.py b/tensorflow/contrib/graph_editor/tests/reroute_test.py index e116f93d600..d663c8839da 100644 --- a/tensorflow/contrib/graph_editor/tests/reroute_test.py +++ b/tensorflow/contrib/graph_editor/tests/reroute_test.py @@ -40,30 +40,30 @@ class RerouteTest(test.TestCase): self.c2 = math_ops.add(self.a2, self.b2, name="c2") def test_swap(self): - ge.reroute.swap_ts([self.a0, self.b0], [self.a1, self.b1]) - self.assertTrue(ge.matcher("c0").input_ops("a1", "b1")(self.c0.op)) - self.assertTrue(ge.matcher("c1").input_ops("a0", "b0")(self.c1.op)) + ge.swap_ts([self.a0, self.b0], [self.a1, self.b1]) + self.assertTrue(ge.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op)) + self.assertTrue(ge.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op)) def test_multiswap(self): with self.graph.as_default(): a3 = constant_op.constant(3.0, shape=[2], name="a3") - ge.reroute.swap( - ge.sgv(a3.op).remap_outputs([0, 0]), ge.sgv(self.a0.op, self.a1.op)) - self.assertTrue(ge.matcher("c0").input_ops("a3", "b0")(self.c0.op)) - self.assertTrue(ge.matcher("c1").input_ops("a3", "b1")(self.c1.op)) + ge.swap_ios(ge.sgv(a3.op).remap_outputs([0, 0]), + ge.sgv(self.a0.op, self.a1.op)) + self.assertTrue(ge.OpMatcher("c0").input_ops("a3", 
"b0")(self.c0.op)) + self.assertTrue(ge.OpMatcher("c1").input_ops("a3", "b1")(self.c1.op)) def test_reroute(self): - ge.reroute.reroute_a2b_ts([self.a0, self.b0], [self.a1, self.b1]) - self.assertTrue(ge.matcher("c0").input_ops("a0", "b0")(self.c0.op)) - self.assertTrue(ge.matcher("c1").input_ops("a0", "b0")(self.c1.op)) + ge.reroute_ts([self.a0, self.b0], [self.a1, self.b1]) + self.assertTrue(ge.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op)) + self.assertTrue(ge.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op)) - ge.reroute.reroute_b2a_ts([self.a0, self.b0], [self.a1, self.b1]) - self.assertTrue(ge.matcher("c0").input_ops("a1", "b1")(self.c0.op)) - self.assertTrue(ge.matcher("c1").input_ops("a1", "b1")(self.c1.op)) + ge.reroute_ts([self.a1, self.b1], [self.a0, self.b0]) + self.assertTrue(ge.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op)) + self.assertTrue(ge.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op)) def test_compatibility(self): with self.assertRaises(ValueError): - ge.reroute.reroute_a2b_ts([self.a0, self.b0], [self.a2, self.b2]) + ge.reroute_ts([self.a0, self.b0], [self.a2, self.b2]) def test_reroute_can_modify(self): graph = ops.Graph() @@ -82,11 +82,11 @@ class RerouteTest(test.TestCase): sgv0 = ge.sgv(a.op, b.op, c.op) sgv1 = ge.sgv(e.op, f.op) - ge.reroute.swap_outputs(sgv0, sgv1) + ge.swap_outputs(sgv0, sgv1) self.assertTrue( - ge.matcher("g").input_ops("a", ge.matcher("c") - .input_ops("a", "b"))(g.op)) - self.assertTrue(ge.matcher("d").input_ops("e", "f")(d.op)) + ge.OpMatcher("g").input_ops("a", ge.OpMatcher("c").input_ops("a", "b"))( + g.op)) + self.assertTrue(ge.OpMatcher("d").input_ops("e", "f")(d.op)) if __name__ == "__main__": diff --git a/tensorflow/contrib/graph_editor/tests/select_test.py b/tensorflow/contrib/graph_editor/tests/select_test.py index 68f38e31391..82f999637d0 100644 --- a/tensorflow/contrib/graph_editor/tests/select_test.py +++ b/tensorflow/contrib/graph_editor/tests/select_test.py @@ -44,76 +44,71 @@ class SelectTest(test.TestCase): self.h = math_ops.add(self.f, self.g, name="h") def test_regex(self): - """Test for ge.select.can_be_regex and ge.select.make_regex.""" - self.assertTrue(ge.select.can_be_regex("foo")) - self.assertTrue(ge.select.can_be_regex(re.compile("foo"))) + """Test for ge.can_be_regex and ge.make_regex.""" + self.assertTrue(ge.can_be_regex("foo")) + self.assertTrue(ge.can_be_regex(re.compile("foo"))) regex = re.compile("foo") - self.assertIs(ge.select.make_regex(regex), regex) + self.assertIs(ge.make_regex(regex), regex) def test_get_input_output_ts(self): - """Test for ge.select._get_input_ts abd ge.select._get_output_ts.""" + """Test for ge._get_input_ts abd ge._get_output_ts.""" self.assertEqual(len(ge.select._get_input_ts(self.graph)), 6) self.assertEqual(len(ge.select._get_output_ts(self.graph)), 8) def test_get_filter(self): """Test for various filtering operations on ts ops.""" # TODO(fkp): parameterise - self.assertEqual(len(ge.select.filter_ops(self.graph, True)), 8) + self.assertEqual(len(ge.filter_ops(self.graph, True)), 8) self.assertEqual( - len( - ge.select.filter_ops(self.graph, - lambda op: op.node_def.op == "Const")), 3) + len(ge.filter_ops(self.graph, lambda op: op.node_def.op == "Const")), 3) self.assertEqual( - len( - ge.select.filter_ops(self.graph, - lambda op: op.node_def.op == "Add")), 5) + len(ge.filter_ops(self.graph, lambda op: op.node_def.op == "Add")), 5) self.assertEqual( - len(ge.select.filter_ops_from_regex(self.graph, r"^.*\b[abc]$")), 3) + len(ge.filter_ops_from_regex(self.graph, 
r"^.*\b[abc]$")), 3) - self.assertEqual(len(ge.select.filter_ts(self.graph, True)), 8) + self.assertEqual(len(ge.filter_ts(self.graph, True)), 8) self.assertEqual( - len(ge.select.filter_ts_from_regex(self.graph, r"^.*/[fgh]:\d$")), 3) + len(ge.filter_ts_from_regex(self.graph, r"^.*/[fgh]:\d$")), 3) - self.assertEqual(len(ge.select.get_name_scope_ops(self.graph, "foo/")), 7) - self.assertEqual( - len(ge.select.get_name_scope_ops(self.graph, "foo/bar")), 4) + self.assertEqual(len(ge.get_name_scope_ops(self.graph, "foo/")), 7) + self.assertEqual(len(ge.get_name_scope_ops(self.graph, "foo/bar")), 4) def test_get_ops_ios(self): - """Test for ge.select.get_ops_ios.""" + """Test for ge.get_ops_ios.""" control_outputs = ge.util.ControlOutputs(self.graph) self.assertEqual( - len(ge.select.get_ops_ios( + len(ge.get_ops_ios( self.h.op, control_ios=control_outputs)), 3) - self.assertEqual(len(ge.select.get_ops_ios(self.h.op)), 2) + self.assertEqual(len(ge.get_ops_ios(self.h.op)), 2) self.assertEqual( - len(ge.select.get_ops_ios( + len(ge.get_ops_ios( self.c.op, control_ios=control_outputs)), 6) - self.assertEqual(len(ge.select.get_ops_ios(self.c.op)), 5) + self.assertEqual(len(ge.get_ops_ios(self.c.op)), 5) def test_compute_boundary_ts_0(self): - """Test for ge.select.compute_boundary_ts.""" - input_ts, output_ts, inside_ts = ge.select.compute_boundary_ts(self.g.op) + """Test for ge.compute_boundary_ts.""" + input_ts, output_ts, inside_ts = ge.compute_boundary_ts(self.g.op) self.assertEqual(list(input_ts), [self.c, self.a]) self.assertEqual(list(output_ts), [self.g]) self.assertEqual(list(inside_ts), []) def test_compute_boundary_ts_1(self): - """Test for ge.select.compute_boundary_ts.""" - input_ts, output_ts, inside_ts = ge.select.compute_boundary_ts( + """Test for ge.compute_boundary_ts.""" + input_ts, output_ts, inside_ts = ge.compute_boundary_ts( [self.g.op, self.h.op]) self.assertEqual(list(input_ts), [self.c, self.a, self.f]) self.assertEqual(list(output_ts), [self.h]) self.assertEqual(list(inside_ts), [self.g]) def test_compute_boundary_ts_2(self): - """Test for ge.select.compute_boundary_ts.""" + """Test for ge.compute_boundary_ts.""" graph = ops_lib.Graph() with graph.as_default(): a = constant_op.constant(1, name="a") b = constant_op.constant(1, name="b") c = math_ops.add(a, b, name="c") _ = a + c - input_ts, output_ts, inside_ts = ge.select.compute_boundary_ts([a.op, c.op]) + input_ts, output_ts, inside_ts = ge.compute_boundary_ts([a.op, c.op]) self.assertEqual(list(input_ts), [b]) self.assertEqual(list(output_ts), [a, c]) self.assertEqual(list(inside_ts), [a]) @@ -121,7 +116,7 @@ class SelectTest(test.TestCase): def test_get_within_boundary_ops_0(self): """Test for test_get_within_boundary_ops.""" control_outputs = ge.util.ControlOutputs(self.graph) - ops = ge.select.get_within_boundary_ops( + ops = ge.get_within_boundary_ops( ops=self.graph, seed_ops=self.f.op, boundary_ops=[self.c.op, self.h.op], @@ -130,19 +125,19 @@ class SelectTest(test.TestCase): self.assertEqual(len(ops), 3) def test_get_within_boundary_ops_1(self): - """Test for ge.select.test_get_within_boundary_ops.""" - ops = ge.select.get_within_boundary_ops( + """Test for ge.test_get_within_boundary_ops.""" + ops = ge.get_within_boundary_ops( ops=self.graph, seed_ops=self.h.op, boundary_ops=[self.f.op, self.g.op]) self.assertEqual(len(ops), 3) def test_get_walks_intersection(self): - """Test for ge.select.get_walks_intersection_ops.""" - ops = ge.select.get_walks_intersection_ops([self.c.op], [self.g.op]) + """Test for 
ge.get_walks_intersection_ops."""
+    ops = ge.get_walks_intersection_ops([self.c.op], [self.g.op])
     self.assertEqual(len(ops), 2)
 
   def test_get_walks_union(self):
-    """Test for ge.select.get_walks_union_ops."""
-    ops = ge.select.get_walks_union_ops([self.f.op], [self.g.op])
+    """Test for ge.get_walks_union_ops."""
+    ops = ge.get_walks_union_ops([self.f.op], [self.g.op])
     self.assertEqual(len(ops), 6)
 
   def test_select_ops(self):
     parameters = (
         (("^foo/bar/",), 4),
         (("^foo/bar/", "a"), 5),)
     for param, length in parameters:
-      ops = ge.select.select_ops(*param, graph=self.graph)
+      ops = ge.select_ops(*param, graph=self.graph)
       self.assertEqual(len(ops), length)
 
   def test_select_ts(self):
     parameters = (
         (".*:0", 8),
         (r".*/bar/\w+:0", 4),)
     for regex, length in parameters:
-      ts = ge.select.select_ts(regex, graph=self.graph)
+      ts = ge.select_ts(regex, graph=self.graph)
       self.assertEqual(len(ts), length)
 
   def test_select_ops_and_ts(self):
     parameters = (
         (("^foo/.*",), 7, 0),
         (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4),)
     for param, l0, l1 in parameters:
-      ops, ts = ge.select.select_ops_and_ts(*param, graph=self.graph)
+      ops, ts = ge.select_ops_and_ts(*param, graph=self.graph)
       self.assertEqual(len(ops), l0)
       self.assertEqual(len(ts), l1)

diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py
index 2c80c04ce61..88fea8bbace 100644
--- a/tensorflow/contrib/graph_editor/tests/transform_test.py
+++ b/tensorflow/contrib/graph_editor/tests/transform_test.py
@@ -101,12 +101,12 @@ class TransformTest(test.TestCase):
     graph = ops.Graph()
     transformer(self.graph, graph, "", "")
 
-    matcher0 = ge.matcher("AddNoise").input_ops(
-        "Noise", ge.matcher("Add").input_ops("Const", "Input"))
-    matcher1 = ge.matcher("AddNoise_1").input_ops(
-        "Noise_1", ge.matcher("Add_1").input_ops("Const_1", matcher0))
-    matcher2 = ge.matcher("AddNoise_2").input_ops(
-        "Noise_2", ge.matcher("Add_2").input_ops("Const_2", matcher1))
+    matcher0 = ge.OpMatcher("AddNoise").input_ops(
+        "Noise", ge.OpMatcher("Add").input_ops("Const", "Input"))
+    matcher1 = ge.OpMatcher("AddNoise_1").input_ops(
+        "Noise_1", ge.OpMatcher("Add_1").input_ops("Const_1", matcher0))
+    matcher2 = ge.OpMatcher("AddNoise_2").input_ops(
+        "Noise_2", ge.OpMatcher("Add_2").input_ops("Const_2", matcher1))
 
     top = ge.select_ops("^AddNoise_2$", graph=graph)[0]
     self.assertTrue(matcher2(top))

From daa030ed68f19d98633d8ef6e45bb15419242a3b Mon Sep 17 00:00:00 2001
From: Martin Wicke
Date: Sun, 29 Jan 2017 17:50:23 -0800
Subject: [PATCH 2/6] Seal contrib interfaces (as much as feasible).

If you were using a symbol which is now hidden, it should be added to the
_allowed_symbols list in the appropriate __init__.py file.
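For example, a sealed package __init__.py now typically ends with the
pattern below. This is a minimal sketch of the convention applied
throughout this patch; 'some_allowed_symbol' is an illustrative
placeholder, and each package lists its own names:

    # Tail of a sealed contrib __init__.py (symbol names are illustrative).
    # remove_undocumented() deletes every module-level symbol that is
    # neither @@-documented in the module docstring nor listed here.
    from tensorflow.python.util.all_util import remove_undocumented

    _allowed_symbols = ['some_allowed_symbol']

    remove_undocumented(__name__, _allowed_symbols)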
Change: 145943844 --- tensorflow/contrib/__init__.py | 4 + tensorflow/contrib/bayesflow/__init__.py | 10 + tensorflow/contrib/copy_graph/__init__.py | 5 + tensorflow/contrib/crf/__init__.py | 4 + tensorflow/contrib/cudnn_rnn/__init__.py | 3 + tensorflow/contrib/deprecated/__init__.py | 7 + tensorflow/contrib/distributions/__init__.py | 9 + tensorflow/contrib/graph_editor/__init__.py | 4 + tensorflow/contrib/graph_editor/select.py | 3 +- tensorflow/contrib/image/__init__.py | 9 +- tensorflow/contrib/input_pipeline/BUILD | 2 +- tensorflow/contrib/input_pipeline/__init__.py | 5 +- tensorflow/contrib/integrate/__init__.py | 5 +- tensorflow/contrib/layers/__init__.py | 32 +- .../layers/python/layers/feature_column.py | 6 +- .../python/layers/feature_column_test.py | 7 +- tensorflow/contrib/learn/__init__.py | 25 +- .../contrib/learn/python/learn/__init__.py | 2 - .../python/learn/basic_session_run_hooks.py | 5 +- .../learn/python/learn/estimators/__init__.py | 1 + .../python/learn/estimators/estimator_test.py | 19 +- .../learn/python/learn/learn_io/numpy_io.py | 2 +- .../tests/dataframe/feeding_functions_test.py | 2 +- .../dataframe/feeding_queue_runner_test.py | 2 +- .../tests/dataframe/reader_source_test.py | 3 +- tensorflow/contrib/linalg/__init__.py | 3 + .../python/ops/linear_operator_test_util.py | 2 +- .../contrib/linear_optimizer/__init__.py | 5 + .../ops/sharded_mutable_dense_hashtable.py | 8 +- tensorflow/contrib/lookup/__init__.py | 7 + tensorflow/contrib/lookup/lookup_ops_test.py | 464 +++++++++--------- tensorflow/contrib/losses/__init__.py | 6 +- tensorflow/contrib/metrics/__init__.py | 7 +- .../metrics/python/ops/histogram_ops.py | 2 +- .../contrib/metrics/python/ops/metric_ops.py | 2 +- .../metrics/python/ops/metric_ops_test.py | 2 +- tensorflow/contrib/nn/__init__.py | 10 +- tensorflow/contrib/opt/__init__.py | 9 + tensorflow/contrib/rnn/__init__.py | 7 +- tensorflow/contrib/seq2seq/__init__.py | 18 +- .../contrib/stat_summarizer/__init__.py | 7 + tensorflow/contrib/tfprof/__init__.py | 1 - tensorflow/contrib/util/__init__.py | 7 +- tensorflow/contrib/util/loader.py | 5 +- tensorflow/python/framework/docs.py | 53 +- .../python/framework/gen_docs_combined.py | 7 + 46 files changed, 500 insertions(+), 308 deletions(-) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index b0022423109..9404b7a1463 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -58,3 +58,7 @@ from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.ndlstm import python as ndlstm from tensorflow.contrib.specs import python as specs + +del absolute_import +del division +del print_function diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index 53dac356750..baa5748eb62 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -30,3 +30,13 @@ from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor from tensorflow.contrib.bayesflow.python.ops import stochastic_variables from tensorflow.contrib.bayesflow.python.ops import variational_inference # pylint: enable=unused-import,line-too-long + +from tensorflow.python.util.all_util import remove_undocumented + + +_allowed_symbols = ['entropy', 'monte_carlo', + 'special_math', 'stochastic_gradient_estimators', + 'stochastic_graph', 'stochastic_tensor', + 'stochastic_variables', 'variational_inference'] + +remove_undocumented(__name__, _allowed_symbols) diff --git 
a/tensorflow/contrib/copy_graph/__init__.py b/tensorflow/contrib/copy_graph/__init__.py index 4e9f34ca8c1..96dc0d7df2d 100644 --- a/tensorflow/contrib/copy_graph/__init__.py +++ b/tensorflow/contrib/copy_graph/__init__.py @@ -20,4 +20,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.copy_graph.python.util import copy_elements from tensorflow.contrib.copy_graph.python.util.copy_elements import * + +from tensorflow.python.util.all_util import remove_undocumented + +remove_undocumented(__name__, doc_string_modules=[copy_elements]) diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py index 195e8cd7171..7f7818c845d 100644 --- a/tensorflow/contrib/crf/__init__.py +++ b/tensorflow/contrib/crf/__init__.py @@ -37,3 +37,7 @@ from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score from tensorflow.contrib.crf.python.ops.crf import crf_unary_score from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell from tensorflow.contrib.crf.python.ops.crf import viterbi_decode + +from tensorflow.python.util.all_util import remove_undocumented + +remove_undocumented(__name__) diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py index b7ac5e7146f..7a8224fa5eb 100644 --- a/tensorflow/contrib/cudnn_rnn/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -23,3 +23,6 @@ from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanh from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import RNNParamsSaveable + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/deprecated/__init__.py b/tensorflow/contrib/deprecated/__init__.py index 2c94882cd75..befb8e6198b 100644 --- a/tensorflow/contrib/deprecated/__init__.py +++ b/tensorflow/contrib/deprecated/__init__.py @@ -95,3 +95,10 @@ from tensorflow.python.ops.logging_ops import merge_all_summaries from tensorflow.python.ops.logging_ops import merge_summary from tensorflow.python.ops.logging_ops import scalar_summary # pylint: enable=unused-import,line-too-long + +from tensorflow.python.util.all_util import remove_undocumented +_allowed_symbols = ['audio_summary', 'histogram_summary', + 'image_summary', 'merge_all_summaries', + 'merge_summary', 'scalar_summary'] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 01896c52440..f822f723eb0 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -134,3 +134,12 @@ from tensorflow.contrib.distributions.python.ops.uniform import * from tensorflow.contrib.distributions.python.ops.wishart import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['bijector', + 'ConditionalDistribution', + 'ConditionalTransformedDistribution', + 'FULLY_REPARAMETERIZED', 'NOT_REPARAMETERIZED'] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/graph_editor/__init__.py b/tensorflow/contrib/graph_editor/__init__.py index 04a4cbb8198..47905cc9927 100644 --- 
a/tensorflow/contrib/graph_editor/__init__.py +++ b/tensorflow/contrib/graph_editor/__init__.py @@ -132,3 +132,7 @@ from tensorflow.contrib.graph_editor import util as _util ph = _util.make_placeholder_from_dtype_and_shape sgv = _subgraph.make_view sgv_scope = _subgraph.make_view_from_scope + +del absolute_import +del division +del print_function diff --git a/tensorflow/contrib/graph_editor/select.py b/tensorflow/contrib/graph_editor/select.py index 2401a3a1578..706c4091189 100644 --- a/tensorflow/contrib/graph_editor/select.py +++ b/tensorflow/contrib/graph_editor/select.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Various ways of selecting operations and tensors in a graph. -""" +"""Various ways of selecting operations and tensors in a graph.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index 03ca71e3f56..9d3ea508683 100644 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -21,8 +21,8 @@ transforms (including rotation) are supported. ## Image `Ops` -@@ rotate -@@ transform +@@rotate +@@transform """ from __future__ import absolute_import from __future__ import division @@ -31,3 +31,8 @@ from __future__ import print_function # pylint: disable=line-too-long from tensorflow.contrib.image.python.ops.image_ops import rotate from tensorflow.contrib.image.python.ops.image_ops import transform + +from tensorflow.python.util.all_util import remove_undocumented + + +remove_undocumented(__name__) diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD index 8eb8201f08f..a6bca863899 100644 --- a/tensorflow/contrib/input_pipeline/BUILD +++ b/tensorflow/contrib/input_pipeline/BUILD @@ -52,7 +52,7 @@ tf_kernel_library( py_library( name = "input_pipeline_py", - srcs = glob(["python/ops/*.py"]), + srcs = glob(["python/ops/*.py"]) + ["__init__.py"], data = [":python/ops/_input_pipeline_ops.so"], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/input_pipeline/__init__.py b/tensorflow/contrib/input_pipeline/__init__.py index d1219883c95..02fd4135bfa 100644 --- a/tensorflow/contrib/input_pipeline/__init__.py +++ b/tensorflow/contrib/input_pipeline/__init__.py @@ -15,11 +15,12 @@ """Ops and modules related to input_pipeline. 
@@obtain_next - """ - from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.contrib.input_pipeline.python.ops.input_pipeline_ops import obtain_next + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/integrate/__init__.py b/tensorflow/contrib/integrate/__init__.py index 953dc6c55ae..c951efd3d9f 100644 --- a/tensorflow/contrib/integrate/__init__.py +++ b/tensorflow/contrib/integrate/__init__.py @@ -59,6 +59,7 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.integrate.python.ops.odes import * -from tensorflow.python.util.all_util import make_all +from tensorflow.python.util.all_util import remove_undocumented -__all__ = make_all(__name__) + +remove_undocumented(__name__) diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index b7832be73fc..c563b29de90 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -23,17 +23,27 @@ common machine learning algorithms. @@avg_pool2d @@batch_norm @@convolution2d +@@conv2d_in_plane @@convolution2d_in_plane +@@conv2d_transpose @@convolution2d_transpose +@@dropout @@flatten @@fully_connected @@layer_norm +@@linear @@max_pool2d @@one_hot_encoding +@@relu +@@relu6 @@repeat @@safe_embedding_lookup_sparse +@@separable_conv2d @@separable_convolution2d +@@softmax +@@stack @@unit_norm +@@embed_sequence Aliases for fully_connected which set a default activation function are available: `relu`, `relu6` and `linear`. @@ -95,6 +105,7 @@ Feature columns provide a mechanism to map data to a model. @@input_from_feature_columns @@joint_weighted_sum_from_feature_columns @@make_place_holder_tensors_for_base_features +@@multi_class_target @@one_hot_column @@parse_feature_columns_from_examples @@parse_feature_columns_from_sequence_examples @@ -105,6 +116,8 @@ Feature columns provide a mechanism to map data to a model. @@sparse_column_with_keys @@weighted_sparse_column @@weighted_sum_from_feature_columns +@@infer_real_valued_columns +@@sequence_input_from_feature_columns """ @@ -112,16 +125,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys - # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.layers.python.layers import * -from tensorflow.contrib.layers.python.ops import sparse_ops -from tensorflow.python.util.all_util import make_all # pylint: enable=unused-import,wildcard-import +from tensorflow.python.util.all_util import remove_undocumented -# Note: `stack` operation is available, just excluded from the document above -# due to collision with tf.stack. 
+_allowed_symbols = ['bias_add', + 'conv2d', + 'feature_column', + 'legacy_fully_connected', + 'legacy_linear', + 'legacy_relu', + 'OPTIMIZER_CLS_NAMES', + 'regression_target', + 'SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY', + 'summaries'] -__all__ = make_all(__name__) +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index ba6e70bfa2c..0db53f9af98 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -122,11 +122,11 @@ import math import six +from tensorflow.contrib import lookup from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.layers.python.ops import bucketization_op from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops -from tensorflow.contrib.lookup import lookup_ops as contrib_lookup_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor as sparse_tensor_py from tensorflow.python.ops import array_ops @@ -587,7 +587,7 @@ class _SparseColumnKeys(_SparseColumn): """Handles sparse column to id conversion.""" input_tensor = self._get_input_sparse_tensor(columns_to_tensors) - table = contrib_lookup_ops.string_to_index_table_from_tensor( + table = lookup.string_to_index_table_from_tensor( mapping=list(self.lookup_config.keys), default_value=self.lookup_config.default_value, name="lookup") @@ -662,7 +662,7 @@ class _SparseColumnVocabulary(_SparseColumn): else: sparse_string_tensor = st - table = contrib_lookup_ops.string_to_index_table_from_file( + table = lookup.string_to_index_table_from_file( vocabulary_file=self.lookup_config.vocabulary_file, num_oov_buckets=self.lookup_config.num_oov_buckets, vocab_size=self.lookup_config.vocab_size, diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py index d166069bd6e..a3b2c98c807 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py @@ -23,15 +23,18 @@ import os import sys import tempfile -# TODO: #6568 Remove this hack that makes dlopen() not crash. +# pylint: disable=g-bad-todo +# TODO(#6568): Remove this hack that makes dlopen() not crash. +# pylint: enable=g-bad-todo +# pylint: disable=g-import-not-at-top if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"): import ctypes sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL) import numpy as np +from tensorflow.contrib.layers.python.layers import feature_column as fc from tensorflow.contrib.layers.python.layers import feature_column_ops -import tensorflow.contrib.layers.python.layers.feature_column as fc from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib diff --git a/tensorflow/contrib/learn/__init__.py b/tensorflow/contrib/learn/__init__.py index 811b89e8845..2cc38fbbec7 100644 --- a/tensorflow/contrib/learn/__init__.py +++ b/tensorflow/contrib/learn/__init__.py @@ -24,13 +24,24 @@ Train and evaluate TensorFlow models. 
@@Estimator @@Trainable @@Evaluable +@@KMeansClustering @@ModeKeys +@@ModelFnOps +@@MetricSpec +@@PredictionKey @@DNNClassifier @@DNNRegressor +@@DNNLinearCombinedRegressor +@@DNNLinearCombinedClassifier @@LinearClassifier @@LinearRegressor @@LogisticRegressor +## Distributed training utilities +@@Experiment +@@ExportStrategy +@@TaskType + ## Graph actions Perform various training, evaluation, and inference actions on a graph. @@ -58,6 +69,10 @@ Queue and read batched input data. @@read_batch_features @@read_batch_record_features +Export utilities + +@@build_parsing_serving_input_fn +@@ProblemType """ from __future__ import absolute_import @@ -67,7 +82,11 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.learn.python.learn import * # pylint: enable=wildcard-import -from tensorflow.python.util.all_util import make_all -__all__ = make_all(__name__) -__all__.append('datasets') +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['datasets', 'head', 'io', 'models', + 'monitors', 'NotFittedError', 'ops', 'preprocessing', + 'utils', 'graph_actions'] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/learn/python/learn/__init__.py b/tensorflow/contrib/learn/python/learn/__init__.py index d7b9aaffd4c..3c6d05dc0dc 100644 --- a/tensorflow/contrib/learn/python/learn/__init__.py +++ b/tensorflow/contrib/learn/python/learn/__init__.py @@ -19,8 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - # pylint: disable=wildcard-import from tensorflow.contrib.learn.python.learn import basic_session_run_hooks from tensorflow.contrib.learn.python.learn import datasets diff --git a/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py b/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py index d239201efe2..2284ec46e97 100644 --- a/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py +++ b/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Some common SessionRunHook classes. 
- -@@ -""" +"""Some common SessionRunHook classes.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py index 04be0ac7fa0..2c1c0e6dd5b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py +++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py @@ -306,6 +306,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError +from tensorflow.contrib.learn.python.learn.estimators.constants import ProblemType from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import DNNLinearCombinedClassifier diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index 349072976c3..c4e49fa4078 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -25,7 +25,10 @@ import os import sys import tempfile -# TODO: #6568 Remove this hack that makes dlopen() not crash. +# pylint: disable=g-bad-todo +# TODO(#6568): Remove this hack that makes dlopen() not crash. +# pylint: enable=g-bad-todo +# pylint: disable=g-import-not-at-top if hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags'): import ctypes sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL) @@ -35,6 +38,7 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import learn +from tensorflow.contrib import lookup from tensorflow.contrib.framework.python.ops import variables from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.layers.python.layers import optimizers @@ -48,7 +52,6 @@ from tensorflow.contrib.learn.python.learn.estimators import linear from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.learn.python.learn.utils import input_fn_utils -from tensorflow.contrib.lookup import lookup_ops from tensorflow.contrib.metrics.python.ops import metric_ops from tensorflow.contrib.testing.python.framework import util_test from tensorflow.python.client import session as session_lib @@ -221,8 +224,8 @@ def _build_estimator_for_export_tests(tmpdir): vocab_file = gfile.GFile(vocab_file_name, mode='w') vocab_file.write(VOCAB_FILE_CONTENT) vocab_file.close() - hashtable = lookup_ops.HashTable( - lookup_ops.TextFileStringTableInitializer(vocab_file_name), 'x') + hashtable = lookup.HashTable( + lookup.TextFileStringTableInitializer(vocab_file_name), 'x') features['bogus_lookup'] = hashtable.lookup( math_ops.to_int64(features['feature'])) @@ -878,8 +881,8 @@ class ReplicaDeviceSetterTest(test.TestCase): with ops.device(estimator._get_replica_device_setter(config)): default_val = constant_op.constant([-1, -1], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) input_string = constant_op.constant(['brain', 'salad', 'tank']) output = table.lookup(input_string) 
self.assertDeviceEqual('/job:ps/task:0', table._table_ref.device) @@ -889,8 +892,8 @@ class ReplicaDeviceSetterTest(test.TestCase): with ops.device( estimator._get_replica_device_setter(run_config.RunConfig())): default_val = constant_op.constant([-1, -1], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) input_string = constant_op.constant(['brain', 'salad', 'tank']) output = table.lookup(input_string) self.assertDeviceEqual('', table._table_ref.device) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py index 0b27e2ae0f9..69610018390 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py @@ -97,7 +97,7 @@ def numpy_input_fn(x, shape_dict_of_x = {k: x[k].shape for k in x.keys()} shape_of_y = None if y is None else y.shape raise ValueError('Length of tensors in x and y is mismatched. All ' - 'elementson x and y must have the same length.\n' + 'elements in x and y must have the same length.\n' 'Shapes in x: {}\n' 'Shape for y: {}\n'.format(shape_dict_of_x, shape_of_y)) diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/feeding_functions_test.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/feeding_functions_test.py index a58a185f5cf..2b6a300193b 100644 --- a/tensorflow/contrib/learn/python/learn/tests/dataframe/feeding_functions_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/feeding_functions_test.py @@ -27,7 +27,7 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"): import numpy as np -import tensorflow.contrib.learn.python.learn.dataframe.queues.feeding_functions as ff +from tensorflow.contrib.learn.python.learn.dataframe.queues import feeding_functions as ff from tensorflow.python.platform import test # pylint: disable=g-import-not-at-top diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/feeding_queue_runner_test.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/feeding_queue_runner_test.py index a4c19147b6d..125a3e13d52 100644 --- a/tensorflow/contrib/learn/python/learn/tests/dataframe/feeding_queue_runner_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/feeding_queue_runner_test.py @@ -27,7 +27,7 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"): import numpy as np -import tensorflow.contrib.learn.python.learn.dataframe.queues.feeding_functions as ff +from tensorflow.contrib.learn.python.learn.dataframe.queues import feeding_functions as ff from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/reader_source_test.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/reader_source_test.py index 94eae51a99e..74f6bfd5c69 100644 --- a/tensorflow/contrib/learn/python/learn/tests/dataframe/reader_source_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/reader_source_test.py @@ -24,7 +24,8 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"): import ctypes sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL) -import tensorflow.contrib.learn.python.learn.dataframe.transforms.reader_source as rs +# pylint: disable=g-import-not-at-top +from 
tensorflow.contrib.learn.python.learn.dataframe.transforms import reader_source as rs from tensorflow.python.ops import io_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py index f4e1c6d7197..3fe0c5f761b 100644 --- a/tensorflow/contrib/linalg/__init__.py +++ b/tensorflow/contrib/linalg/__init__.py @@ -54,3 +54,6 @@ from tensorflow.contrib.linalg.python.ops.linear_operator_matrix import * from tensorflow.contrib.linalg.python.ops.linear_operator_tril import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py index 5de9bb5d775..18dec3dd9ef 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py @@ -22,7 +22,7 @@ import abc import numpy as np import six -from tensorflow.contrib.framework import tensor_util as contrib_tensor_util +from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed diff --git a/tensorflow/contrib/linear_optimizer/__init__.py b/tensorflow/contrib/linear_optimizer/__init__.py index 83bd8b5fcf0..d447487b4a5 100644 --- a/tensorflow/contrib/linear_optimizer/__init__.py +++ b/tensorflow/contrib/linear_optimizer/__init__.py @@ -17,6 +17,8 @@ ## This package provides optimizers to train linear models. @@SdcaModel +@@SparseFeatureColumn +@@SDCAOptimizer """ from __future__ import absolute_import from __future__ import division @@ -25,3 +27,6 @@ from __future__ import print_function from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn from tensorflow.contrib.linear_optimizer.python.sdca_optimizer import SDCAOptimizer + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index 494dfb6c990..7e214905b13 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -20,7 +20,7 @@ from __future__ import print_function from six.moves import range -from tensorflow.contrib.lookup import lookup_ops +from tensorflow.contrib import lookup from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -30,7 +30,7 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops -class ShardedMutableDenseHashTable(lookup_ops.LookupInterface): +class ShardedMutableDenseHashTable(lookup.LookupInterface): """A sharded version of MutableDenseHashTable. It is designed to be interface compatible with LookupInterface and @@ -41,7 +41,7 @@ class ShardedMutableDenseHashTable(lookup_ops.LookupInterface): internally. 
The shard is computed via the modulo operation on the key. """ - # TODO(andreasst): consider moving this to lookup_ops + # TODO(andreasst): consider moving this to lookup module def __init__(self, key_dtype, @@ -56,7 +56,7 @@ class ShardedMutableDenseHashTable(lookup_ops.LookupInterface): table_shards = [] for i in range(num_shards): table_shards.append( - lookup_ops.MutableDenseHashTable( + lookup.MutableDenseHashTable( key_dtype=key_dtype, value_dtype=value_dtype, default_value=default_value, diff --git a/tensorflow/contrib/lookup/__init__.py b/tensorflow/contrib/lookup/__init__.py index fadf42000fe..e743832e807 100644 --- a/tensorflow/contrib/lookup/__init__.py +++ b/tensorflow/contrib/lookup/__init__.py @@ -25,6 +25,7 @@ @@IdTableWithHashBuckets @@HashTable @@MutableHashTable +@@MutableDenseHashTable @@TableInitializerBase @@KeyValueTensorInitializer @@TextFileIndex @@ -32,6 +33,9 @@ @@TextFileIdTableInitializer @@TextFileStringTableInitializer +@@HasherSpec +@@StrongHashSpec +@@FastHashSpec """ from __future__ import absolute_import @@ -41,3 +45,6 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.lookup.lookup_ops import * # pylint: enable=unused-import,wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 15c318b6ef7..b46db38770b 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tf.contrib.lookup.lookup_ops.""" +"""Tests for tf.contrib.lookup.lookup.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -22,7 +22,7 @@ import tempfile import numpy as np import six -from tensorflow.contrib.lookup import lookup_ops +from tensorflow.contrib import lookup from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -45,8 +45,8 @@ class HashTableOpTest(test.TestCase): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), default_val) + table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) table.init.run() self.assertAllEqual(3, table.size().eval()) @@ -63,8 +63,8 @@ class HashTableOpTest(test.TestCase): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), default_val) + table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) table.init.run() self.assertAllEqual(3, table.size().eval()) @@ -81,8 +81,8 @@ class HashTableOpTest(test.TestCase): default_val = -1 keys = ["brain", "salad", "surgery"] values = [0, 1, 2] - table = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer( + table = lookup.HashTable( + lookup.KeyValueTensorInitializer( keys, values, value_dtype=dtypes.int64), default_val) table.init.run() @@ -100,8 +100,8 @@ class 
HashTableOpTest(test.TestCase): default_val = -1 keys = np.array(["brain", "salad", "surgery"], dtype=np.str) values = np.array([0, 1, 2], dtype=np.int64) - table = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), default_val) + table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) table.init.run() self.assertAllEqual(3, table.size().eval()) @@ -118,12 +118,12 @@ class HashTableOpTest(test.TestCase): keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) - table1 = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - table2 = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - table3 = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), default_val) + table1 = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) + table2 = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) + table3 = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) data_flow_ops.tables_initializer().run() self.assertAllEqual(3, table1.size().eval()) @@ -145,8 +145,8 @@ class HashTableOpTest(test.TestCase): default_val = constant_op.constant(-1, dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), default_val) + table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) table.init.run() input_string = constant_op.constant(["brain", "salad", "tank"]) @@ -160,8 +160,8 @@ class HashTableOpTest(test.TestCase): default_val = constant_op.constant(-1, dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), default_val) + table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) table.init.run() sp_indices = [[0, 0], [0, 1], [1, 0]] @@ -183,8 +183,8 @@ class HashTableOpTest(test.TestCase): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), default_val) + table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) table.init.run() input_string = constant_op.constant([1, 2, 3], dtypes.int64) @@ -192,22 +192,22 @@ class HashTableOpTest(test.TestCase): table.lookup(input_string) with self.assertRaises(TypeError): - lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), "UNK") + lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), "UNK") def testDTypes(self): with self.test_session(): default_val = -1 with self.assertRaises(TypeError): - lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(["a"], [1], [dtypes.string], - dtypes.int64), default_val) + lookup.HashTable( + lookup.KeyValueTensorInitializer(["a"], [1], [dtypes.string], + dtypes.int64), default_val) def testNotInitialized(self): with self.test_session(): default_val = -1 - table = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer( + table = lookup.HashTable( + lookup.KeyValueTensorInitializer( ["a"], [1], value_dtype=dtypes.int64), default_val) @@ 
-222,8 +222,8 @@ class HashTableOpTest(test.TestCase): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), default_val) + table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) table.init.run() with self.assertRaisesOpError("Table already initialized"): @@ -236,8 +236,8 @@ class HashTableOpTest(test.TestCase): values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64) with self.assertRaises(ValueError): - lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), default_val) + lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) def testMultipleSessions(self): # Start a server @@ -252,8 +252,8 @@ class HashTableOpTest(test.TestCase): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), + table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val, name="t1") @@ -276,8 +276,8 @@ class MutableHashTableOpTest(test.TestCase): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() @@ -311,7 +311,7 @@ class MutableHashTableOpTest(test.TestCase): default_val = -1 keys = constant_op.constant(["b", "c", "d"], dtypes.string) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.MutableHashTable( + table = lookup.MutableHashTable( dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) save = saver.Saver() @@ -333,7 +333,7 @@ class MutableHashTableOpTest(test.TestCase): v0 = variables.Variable(-1.0, name="v0") v1 = variables.Variable(-1.0, name="v1") default_val = -1 - table = lookup_ops.MutableHashTable( + table = lookup.MutableHashTable( dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) table.insert( constant_op.constant(["a", "c"], dtypes.string), @@ -365,7 +365,7 @@ class MutableHashTableOpTest(test.TestCase): session1 = session.Session(server.target) session2 = session.Session(server.target) - table = lookup_ops.MutableHashTable( + table = lookup.MutableHashTable( dtypes.int64, dtypes.string, "-", name="t1") # Populate the table in the first session @@ -392,8 +392,8 @@ class MutableHashTableOpTest(test.TestCase): default_val = constant_op.constant([-1, -1], dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() @@ -420,8 +420,8 @@ class MutableHashTableOpTest(test.TestCase): default_val = constant_op.constant([-1, -1], dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) - table1 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table1 = 
lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) self.assertAllEqual(0, table1.size().eval()) table1.insert(keys, values).run() self.assertAllEqual(3, table1.size().eval()) @@ -436,8 +436,8 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual(6, exported_values.eval().size) # Populate a second table from the exported data - table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table2 = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) self.assertAllEqual(0, table2.size().eval()) table2.insert(exported_keys, exported_values).run() self.assertAllEqual(3, table2.size().eval()) @@ -450,8 +450,8 @@ class MutableHashTableOpTest(test.TestCase): with self.test_session(): default_val = constant_op.constant([-1, -1], dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) # Shape [6] instead of [3, 2] values = constant_op.constant([0, 1, 2, 3, 4, 5], dtypes.int64) @@ -481,8 +481,8 @@ class MutableHashTableOpTest(test.TestCase): def testMutableHashTableInvalidDefaultValue(self): with self.test_session(): default_val = constant_op.constant([[-1, -1]], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) with self.assertRaisesOpError("Default value must be a vector"): self.assertAllEqual(0, table.size().eval()) @@ -491,8 +491,8 @@ class MutableHashTableOpTest(test.TestCase): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery", "brain"]) values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() @@ -509,8 +509,8 @@ class MutableHashTableOpTest(test.TestCase): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) table.insert(keys, values).run() self.assertAllEqual(3, table.size().eval()) @@ -528,8 +528,8 @@ class MutableHashTableOpTest(test.TestCase): default_val = -1 keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) table.insert(keys, values).run() self.assertAllEqual(4, table.size().eval()) @@ -546,8 +546,8 @@ class MutableHashTableOpTest(test.TestCase): keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) table.insert(keys, values).run() self.assertAllEqual(3, table.size().eval()) @@ -567,12 +567,12 @@ class MutableHashTableOpTest(test.TestCase): keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) - 
table1 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - table3 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table1 = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + table2 = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + table3 = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) table1.insert(keys, values).run() table2.insert(keys, values).run() table3.insert(keys, values).run() @@ -596,8 +596,8 @@ class MutableHashTableOpTest(test.TestCase): default_val = constant_op.constant(-1, dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) table.insert(keys, values).run() self.assertAllEqual(3, table.size().eval()) @@ -613,8 +613,8 @@ class MutableHashTableOpTest(test.TestCase): default_val = -1 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, + default_val) # insert with keys of the wrong type with self.assertRaises(TypeError): @@ -636,15 +636,15 @@ class MutableHashTableOpTest(test.TestCase): # default value of the wrong type with self.assertRaises(TypeError): - lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, "UNK") + lookup.MutableHashTable(dtypes.string, dtypes.int64, "UNK") def testMutableHashTableStringFloat(self): with self.test_session(): default_val = -1.5 keys = constant_op.constant(["brain", "salad", "surgery"]) values = constant_op.constant([0, 1.1, 2.2], dtypes.float32) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.float32, + default_val) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() @@ -661,8 +661,8 @@ class MutableHashTableOpTest(test.TestCase): default_val = "n/a" keys = constant_op.constant([0, 1, 2], dtypes.int64) values = constant_op.constant(["brain", "salad", "surgery"]) - table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.string, - default_val) + table = lookup.MutableHashTable(dtypes.int64, dtypes.string, + default_val) self.assertAllEqual(0, table.size().eval()) table.insert(keys, values).run() @@ -681,7 +681,7 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.test_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, empty_key=0) self.assertAllEqual(0, table.size().eval()) @@ -699,7 +699,7 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.test_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, empty_key=0) table.insert(keys, values).run() @@ -716,7 +716,7 @@ class MutableDenseHashTableOpTest(test.TestCase): keys = constant_op.constant(["a", "b", "c"], dtypes.string) values = 
constant_op.constant([0.0, 1.1, 2.2], dtypes.float32) default_value = constant_op.constant(-1.5, dtypes.float32) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.string, dtypes.float32, default_value=default_value, @@ -739,7 +739,7 @@ class MutableDenseHashTableOpTest(test.TestCase): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0.0, 1.1, 2.2], float_dtype) default_value = constant_op.constant(-1.5, float_dtype) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, float_dtype, default_value=default_value, empty_key=0) self.assertAllEqual(0, table.size().eval()) @@ -759,7 +759,7 @@ class MutableDenseHashTableOpTest(test.TestCase): values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]], dtypes.int64) default_value = constant_op.constant([-1, -2, -3, -4], dtypes.int64) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, @@ -791,7 +791,7 @@ class MutableDenseHashTableOpTest(test.TestCase): values = constant_op.constant([10, 11, 12], dtypes.int64) empty_key = constant_op.constant([0, 3], dtypes.int64) default_value = constant_op.constant(-1, dtypes.int64) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, @@ -820,7 +820,7 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.test_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, @@ -848,7 +848,7 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.test_session(): keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([1, 2, 3], dtypes.int64) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, @@ -885,7 +885,7 @@ class MutableDenseHashTableOpTest(test.TestCase): empty_key = 0 keys = constant_op.constant([11, 12, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, @@ -906,7 +906,7 @@ class MutableDenseHashTableOpTest(test.TestCase): self.assertEqual(save_path, val) with self.test_session(graph=ops.Graph()) as sess: - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, @@ -941,7 +941,7 @@ class MutableDenseHashTableOpTest(test.TestCase): default_value = constant_op.constant([-1, -2], dtypes.int64) keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64) values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, @@ -964,7 +964,7 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.test_session(graph=ops.Graph()) as sess: empty_key = constant_op.constant([11, 13], dtypes.int64) default_value = constant_op.constant([-1, -2], dtypes.int64) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, @@ -1001,7 
+1001,7 @@ class MutableDenseHashTableOpTest(test.TestCase): default_value = constant_op.constant(-1, dtypes.int64) keys = constant_op.constant([[11, 12], [11, 14], [13, 14]], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, @@ -1024,7 +1024,7 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.test_session(graph=ops.Graph()) as sess: empty_key = constant_op.constant([11, 13], dtypes.int64) default_value = constant_op.constant(-1, dtypes.int64) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=default_value, @@ -1057,7 +1057,7 @@ class MutableDenseHashTableOpTest(test.TestCase): # The values are chosen to make sure collisions occur when using GCC STL keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64) values = constant_op.constant([51, 52, 53, 54, 55, 56], dtypes.int64) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, @@ -1080,7 +1080,7 @@ class MutableDenseHashTableOpTest(test.TestCase): with self.test_session(): keys = constant_op.constant([11, 0, 13], dtypes.int64) values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, empty_key=12) self.assertAllEqual(0, table.size().eval()) @@ -1096,7 +1096,7 @@ class MutableDenseHashTableOpTest(test.TestCase): def testErrors(self): with self.test_session(): - table = lookup_ops.MutableDenseHashTable( + table = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, empty_key=0) # Inserting the empty key returns an error @@ -1121,7 +1121,7 @@ class MutableDenseHashTableOpTest(test.TestCase): "Expected key shape"): table.insert(keys, values).run() - table2 = lookup_ops.MutableDenseHashTable( + table2 = lookup.MutableDenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, @@ -1143,7 +1143,7 @@ class StringToIndexTableFromFile(test.TestCase): def test_string_to_index_table_from_file(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") with self.test_session(): - table = lookup_ops.string_to_index_table_from_file( + table = lookup.string_to_index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1155,7 +1155,7 @@ class StringToIndexTableFromFile(test.TestCase): default_value = -42 vocabulary_file = self._createVocabFile("f2i_vocab2.txt") with self.test_session(): - table = lookup_ops.string_to_index_table_from_file( + table = lookup.string_to_index_table_from_file( vocabulary_file=vocabulary_file, default_value=default_value) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1166,7 +1166,7 @@ class StringToIndexTableFromFile(test.TestCase): def test_string_to_index_table_from_file_with_oov_buckets(self): vocabulary_file = self._createVocabFile("f2i_vocab3.txt") with self.test_session(): - table = lookup_ops.string_to_index_table_from_file( + table = lookup.string_to_index_table_from_file( vocabulary_file=vocabulary_file, num_oov_buckets=1000) ids = table.lookup( constant_op.constant(["salad", "surgery", "tarkus", "toccata"])) @@ -1184,13 +1184,13 @@ class StringToIndexTableFromFile(test.TestCase): def 
test_string_to_index_table_from_file_with_only_oov_buckets(self): self.assertRaises( ValueError, - lookup_ops.string_to_index_table_from_file, + lookup.string_to_index_table_from_file, vocabulary_file=None) def test_string_to_index_table_from_file_with_vocab_size_too_small(self): vocabulary_file = self._createVocabFile("f2i_vocab5.txt") with self.test_session(): - table = lookup_ops.string_to_index_table_from_file( + table = lookup.string_to_index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=2) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1202,7 +1202,7 @@ class StringToIndexTableFromFile(test.TestCase): def test_string_to_index_table_from_file_with_vocab_size_too_large(self): vocabulary_file = self._createVocabFile("f2i_vocab6.txt") with self.test_session(): - table = lookup_ops.string_to_index_table_from_file( + table = lookup.string_to_index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=4) self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "Invalid vocab_size", table.init.run) @@ -1212,12 +1212,12 @@ class StringToIndexTableFromFile(test.TestCase): self.assertRaises( ValueError, - lookup_ops.string_to_index_table_from_file, + lookup.string_to_index_table_from_file, vocabulary_file=vocabulary_file, vocab_size=0) with self.test_session(): - table = lookup_ops.string_to_index_table_from_file( + table = lookup.string_to_index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1230,17 +1230,17 @@ class StringToIndexTableFromFile(test.TestCase): vocabulary_file = self._createVocabFile("invalid_hasher.txt") with self.test_session(): with self.assertRaises(TypeError): - lookup_ops.string_to_index_table_from_file( + lookup.string_to_index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3, num_oov_buckets=1, hasher_spec=1) - table = lookup_ops.string_to_index_table_from_file( + table = lookup.string_to_index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3, num_oov_buckets=1, - hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None)) + hasher_spec=lookup.HasherSpec("my-awesome-hash", None)) self.assertRaises(ValueError, table.lookup, constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1250,7 +1250,7 @@ class StringToIndexTableFromTensor(test.TestCase): def test_string_to_index_table_from_tensor_with_tensor_init(self): with self.test_session(): - table = lookup_ops.string_to_index_table_from_tensor( + table = lookup.string_to_index_table_from_tensor( mapping=["brain", "salad", "surgery"], num_oov_buckets=1) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1261,7 +1261,7 @@ class StringToIndexTableFromTensor(test.TestCase): def test_string_to_index_table_from_tensor_with_default_value(self): default_value = -42 with self.test_session(): - table = lookup_ops.string_to_index_table_from_tensor( + table = lookup.string_to_index_table_from_tensor( mapping=["brain", "salad", "surgery"], default_value=default_value) ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1272,21 +1272,21 @@ class StringToIndexTableFromTensor(test.TestCase): def test_string_to_index_table_from_tensor_with_only_oov_buckets(self): with self.test_session(): with self.assertRaises(ValueError): - lookup_ops.string_to_index_table_from_tensor( + lookup.string_to_index_table_from_tensor( mapping=None, num_oov_buckets=1) def test_string_to_index_table_from_tensor_with_invalid_hashers(self): 
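The invalid-hasher tests above exercise the hasher specs this patch exports: `FastHashSpec` names the stateless fast hash, `StrongHashSpec` requires a two-part key, and an unknown spec name only fails once `lookup` is called. A short sketch (key values hypothetical):

from tensorflow.contrib import lookup

fast = lookup.FastHashSpec                # no key needed
strong = lookup.StrongHashSpec((1, 2))    # exactly two uint64 key parts
assert strong.key == (1, 2)
try:
    lookup.StrongHashSpec((1, 2, 3))      # wrong key length
except ValueError:
    pass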
with self.test_session(): with self.assertRaises(TypeError): - lookup_ops.string_to_index_table_from_tensor( + lookup.string_to_index_table_from_tensor( mapping=["brain", "salad", "surgery"], num_oov_buckets=1, hasher_spec=1) - table = lookup_ops.string_to_index_table_from_tensor( + table = lookup.string_to_index_table_from_tensor( mapping=["brain", "salad", "surgery"], num_oov_buckets=1, - hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None)) + hasher_spec=lookup.HasherSpec("my-awesome-hash", None)) self.assertRaises(ValueError, table.lookup, constant_op.constant(["salad", "surgery", "tarkus"])) @@ -1298,7 +1298,7 @@ class StringToIndexTest(test.TestCase): with self.test_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) feats = constant_op.constant(["salad", "surgery", "tarkus"]) - indices = lookup_ops.string_to_index(feats, mapping=mapping_strings) + indices = lookup.string_to_index(feats, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, indices.eval) data_flow_ops.tables_initializer().run() @@ -1309,7 +1309,7 @@ class StringToIndexTest(test.TestCase): with self.test_session(): mapping_strings = constant_op.constant(["hello", "hello"]) feats = constant_op.constant(["hello", "hola"]) - indices = lookup_ops.string_to_index(feats, mapping=mapping_strings) + _ = lookup.string_to_index(feats, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, data_flow_ops.tables_initializer().run) @@ -1319,7 +1319,7 @@ class StringToIndexTest(test.TestCase): with self.test_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) feats = constant_op.constant(["salad", "surgery", "tarkus"]) - indices = lookup_ops.string_to_index( + indices = lookup.string_to_index( feats, mapping=mapping_strings, default_value=default_value) self.assertRaises(errors_impl.OpError, indices.eval) @@ -1338,7 +1338,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table(self): vocabulary_file = self._createVocabFile("i2f_vocab1.txt") with self.test_session(): - table = lookup_ops.index_to_string_table_from_file( + table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file) features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) @@ -1350,7 +1350,7 @@ class IndexToStringTableFromFileTest(test.TestCase): default_value = b"NONE" vocabulary_file = self._createVocabFile("f2i_vocab2.txt") with self.test_session(): - table = lookup_ops.index_to_string_table_from_file( + table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file, default_value=default_value) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) self.assertRaises(errors_impl.OpError, features.eval) @@ -1362,7 +1362,7 @@ class IndexToStringTableFromFileTest(test.TestCase): default_value = b"NONE" vocabulary_file = self._createVocabFile("f2i_vocab2.txt") with self.test_session(): - table = lookup_ops.index_to_string_table_from_file( + table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file, vocab_size=2, default_value=default_value) @@ -1375,7 +1375,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_vocab_size_too_large(self): vocabulary_file = self._createVocabFile("f2i_vocab6.txt") with self.test_session(): - table = lookup_ops.index_to_string_table_from_file( + table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file, vocab_size=4) 
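The reverse (id to string) tables exercised above follow the same pattern as the forward ones; a hedged sketch with a hypothetical vocabulary file, where out-of-vocabulary ids resolve to `default_value`:

import tensorflow as tf
from tensorflow.contrib import lookup

table = lookup.index_to_string_table_from_file(
    "vocab.txt", default_value="UNK")  # hypothetical path
with tf.Session():
    tf.tables_initializer().run()
    # ids past the vocabulary resolve to b"UNK"
    print(table.lookup(tf.constant([0, 1, 99], tf.int64)).eval())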
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) @@ -1387,7 +1387,7 @@ class IndexToStringTableFromFileTest(test.TestCase): def test_index_to_string_table_with_vocab_size(self): vocabulary_file = self._createVocabFile("f2i_vocab7.txt") with self.test_session(): - table = lookup_ops.index_to_string_table_from_file( + table = lookup.index_to_string_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3) features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) @@ -1401,7 +1401,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): def test_index_to_string_table_from_tensor(self): with self.test_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) - table = lookup_ops.index_to_string_table_from_tensor( + table = lookup.index_to_string_table_from_tensor( mapping=mapping_strings) indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) @@ -1415,7 +1415,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): def test_duplicate_entries(self): with self.test_session(): mapping_strings = constant_op.constant(["hello", "hello"]) - table = lookup_ops.index_to_string_table_from_tensor( + table = lookup.index_to_string_table_from_tensor( mapping=mapping_strings) indices = constant_op.constant([0, 1, 4], dtypes.int64) features = table.lookup(indices) @@ -1426,7 +1426,7 @@ class IndexToStringTableFromTensorTest(test.TestCase): default_value = b"NONE" with self.test_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) - table = lookup_ops.index_to_string_table_from_tensor( + table = lookup.index_to_string_table_from_tensor( mapping=mapping_strings, default_value=default_value) indices = constant_op.constant([1, 2, 4], dtypes.int64) features = table.lookup(indices) @@ -1443,7 +1443,7 @@ class IndexToStringTest(test.TestCase): with self.test_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) - feats = lookup_ops.index_to_string(indices, mapping=mapping_strings) + feats = lookup.index_to_string(indices, mapping=mapping_strings) self.assertRaises(errors_impl.OpError, feats.eval) data_flow_ops.tables_initializer().run() @@ -1455,7 +1455,7 @@ class IndexToStringTest(test.TestCase): with self.test_session(): mapping_strings = constant_op.constant(["hello", "hello"]) indices = constant_op.constant([0, 1, 4], dtypes.int64) - feats = lookup_ops.index_to_string(indices, mapping=mapping_strings) + feats = lookup.index_to_string(indices, mapping=mapping_strings) data_flow_ops.tables_initializer().run() self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval()) @@ -1467,7 +1467,7 @@ class IndexToStringTest(test.TestCase): with self.test_session(): mapping_strings = constant_op.constant(["brain", "salad", "surgery"]) indices = constant_op.constant([1, 2, 4], dtypes.int64) - feats = lookup_ops.index_to_string( + feats = lookup.index_to_string( indices, mapping=mapping_strings, default_value=default_value) self.assertRaises(errors_impl.OpError, feats.eval) @@ -1488,11 +1488,11 @@ class InitializeTableFromFileOpTest(test.TestCase): with self.test_session(): default_value = -1 - table = lookup_ops.HashTable( - lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string, - lookup_ops.TextFileIndex.WHOLE_LINE, - dtypes.int64, - lookup_ops.TextFileIndex.LINE_NUMBER), + table = lookup.HashTable( + lookup.TextFileInitializer(vocabulary_file, dtypes.string, + lookup.TextFileIndex.WHOLE_LINE, + dtypes.int64, + 
lookup.TextFileIndex.LINE_NUMBER), default_value) table.init.run() @@ -1507,11 +1507,11 @@ class InitializeTableFromFileOpTest(test.TestCase): with self.test_session(): default_value = "UNK" - key_index = lookup_ops.TextFileIndex.LINE_NUMBER - value_index = lookup_ops.TextFileIndex.WHOLE_LINE - table = lookup_ops.HashTable( - lookup_ops.TextFileInitializer(vocabulary_file, dtypes.int64, - key_index, dtypes.string, value_index), + key_index = lookup.TextFileIndex.LINE_NUMBER + value_index = lookup.TextFileIndex.WHOLE_LINE + table = lookup.HashTable( + lookup.TextFileInitializer(vocabulary_file, dtypes.int64, + key_index, dtypes.string, value_index), default_value) table.init.run() @@ -1531,9 +1531,9 @@ class InitializeTableFromFileOpTest(test.TestCase): key_index = 1 value_index = 2 - table = lookup_ops.HashTable( - lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string, - key_index, dtypes.int64, value_index), + table = lookup.HashTable( + lookup.TextFileInitializer(vocabulary_file, dtypes.string, + key_index, dtypes.int64, value_index), default_value) table.init.run() @@ -1552,9 +1552,9 @@ class InitializeTableFromFileOpTest(test.TestCase): default_value = -1 key_index = 2 value_index = 1 - table = lookup_ops.HashTable( - lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string, - key_index, dtypes.int64, value_index), + table = lookup.HashTable( + lookup.TextFileInitializer(vocabulary_file, dtypes.string, + key_index, dtypes.int64, value_index), default_value) with self.assertRaisesOpError("is not a valid"): table.init.run() @@ -1564,24 +1564,24 @@ class InitializeTableFromFileOpTest(test.TestCase): with self.test_session(): default_value = "UNK" - key_index = lookup_ops.TextFileIndex.WHOLE_LINE - value_index = lookup_ops.TextFileIndex.LINE_NUMBER + key_index = lookup.TextFileIndex.WHOLE_LINE + value_index = lookup.TextFileIndex.LINE_NUMBER with self.assertRaises(ValueError): - lookup_ops.HashTable( - lookup_ops.TextFileInitializer(vocabulary_file, dtypes.int64, - key_index, dtypes.string, - value_index), default_value) + lookup.HashTable( + lookup.TextFileInitializer(vocabulary_file, dtypes.int64, + key_index, dtypes.string, + value_index), default_value) def testInvalidIndex(self): vocabulary_file = self._createVocabFile("one_column_4.txt") with self.test_session(): default_value = -1 key_index = 1 # second column of the line - value_index = lookup_ops.TextFileIndex.LINE_NUMBER - table = lookup_ops.HashTable( - lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string, - key_index, dtypes.int64, value_index), + value_index = lookup.TextFileIndex.LINE_NUMBER + table = lookup.HashTable( + lookup.TextFileInitializer(vocabulary_file, dtypes.string, + key_index, dtypes.int64, value_index), default_value) with self.assertRaisesOpError("Invalid number of columns"): @@ -1593,25 +1593,25 @@ class InitializeTableFromFileOpTest(test.TestCase): with self.test_session() as sess: shared_name = "shared-one-columm" default_value = -1 - table1 = lookup_ops.HashTable( - lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string, - lookup_ops.TextFileIndex.WHOLE_LINE, - dtypes.int64, - lookup_ops.TextFileIndex.LINE_NUMBER), + table1 = lookup.HashTable( + lookup.TextFileInitializer(vocabulary_file, dtypes.string, + lookup.TextFileIndex.WHOLE_LINE, + dtypes.int64, + lookup.TextFileIndex.LINE_NUMBER), default_value, shared_name=shared_name) - table2 = lookup_ops.HashTable( - lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string, - lookup_ops.TextFileIndex.WHOLE_LINE, - 
dtypes.int64, - lookup_ops.TextFileIndex.LINE_NUMBER), + table2 = lookup.HashTable( + lookup.TextFileInitializer(vocabulary_file, dtypes.string, + lookup.TextFileIndex.WHOLE_LINE, + dtypes.int64, + lookup.TextFileIndex.LINE_NUMBER), default_value, shared_name=shared_name) - table3 = lookup_ops.HashTable( - lookup_ops.TextFileInitializer(vocabulary_file, dtypes.string, - lookup_ops.TextFileIndex.WHOLE_LINE, - dtypes.int64, - lookup_ops.TextFileIndex.LINE_NUMBER), + table3 = lookup.HashTable( + lookup.TextFileInitializer(vocabulary_file, dtypes.string, + lookup.TextFileIndex.WHOLE_LINE, + dtypes.int64, + lookup.TextFileIndex.LINE_NUMBER), default_value, shared_name=shared_name) @@ -1632,10 +1632,10 @@ class InitializeTableFromFileOpTest(test.TestCase): with self.test_session(): default_value = -1 with self.assertRaises(ValueError): - lookup_ops.HashTable( - lookup_ops.TextFileInitializer( - "", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, - dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), + lookup.HashTable( + lookup.TextFileInitializer( + "", dtypes.string, lookup.TextFileIndex.WHOLE_LINE, + dtypes.int64, lookup.TextFileIndex.LINE_NUMBER), default_value) def testInitializeWithVocabSize(self): @@ -1643,13 +1643,13 @@ class InitializeTableFromFileOpTest(test.TestCase): default_value = -1 vocab_size = 3 vocabulary_file1 = self._createVocabFile("one_column6.txt") - table1 = lookup_ops.HashTable( - lookup_ops.TextFileInitializer( + table1 = lookup.HashTable( + lookup.TextFileInitializer( vocabulary_file1, dtypes.string, - lookup_ops.TextFileIndex.WHOLE_LINE, + lookup.TextFileIndex.WHOLE_LINE, dtypes.int64, - lookup_ops.TextFileIndex.LINE_NUMBER, + lookup.TextFileIndex.LINE_NUMBER, vocab_size=vocab_size), default_value) @@ -1659,13 +1659,13 @@ class InitializeTableFromFileOpTest(test.TestCase): vocabulary_file2 = self._createVocabFile("one_column7.txt") vocab_size = 5 - table2 = lookup_ops.HashTable( - lookup_ops.TextFileInitializer( + table2 = lookup.HashTable( + lookup.TextFileInitializer( vocabulary_file2, dtypes.string, - lookup_ops.TextFileIndex.WHOLE_LINE, + lookup.TextFileIndex.WHOLE_LINE, dtypes.int64, - lookup_ops.TextFileIndex.LINE_NUMBER, + lookup.TextFileIndex.LINE_NUMBER, vocab_size=vocab_size), default_value) with self.assertRaisesOpError("Invalid vocab_size"): @@ -1673,13 +1673,13 @@ class InitializeTableFromFileOpTest(test.TestCase): vocab_size = 1 vocabulary_file3 = self._createVocabFile("one_column3.txt") - table3 = lookup_ops.HashTable( - lookup_ops.TextFileInitializer( + table3 = lookup.HashTable( + lookup.TextFileInitializer( vocabulary_file3, dtypes.string, - lookup_ops.TextFileIndex.WHOLE_LINE, + lookup.TextFileIndex.WHOLE_LINE, dtypes.int64, - lookup_ops.TextFileIndex.LINE_NUMBER, + lookup.TextFileIndex.LINE_NUMBER, vocab_size=vocab_size), default_value) @@ -1692,11 +1692,11 @@ class InitializeTableFromFileOpTest(test.TestCase): with self.test_session(): default_value = -1 - table = lookup_ops.HashTable( - lookup_ops.TextFileInitializer("old_file.txt", dtypes.string, - lookup_ops.TextFileIndex.WHOLE_LINE, - dtypes.int64, - lookup_ops.TextFileIndex.LINE_NUMBER), + table = lookup.HashTable( + lookup.TextFileInitializer("old_file.txt", dtypes.string, + lookup.TextFileIndex.WHOLE_LINE, + dtypes.int64, + lookup.TextFileIndex.LINE_NUMBER), default_value) # Initialize with non existing file (old_file.txt) should fail. 
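The `TextFileInitializer` cases above map a chosen column (or the whole line, or the line number) of a vocabulary file to keys and values. A minimal sketch under the sealed `lookup` namespace, with a hypothetical file path:

import tensorflow as tf
from tensorflow.contrib import lookup

table = lookup.HashTable(
    lookup.TextFileInitializer(
        "vocab.txt", tf.string, lookup.TextFileIndex.WHOLE_LINE,
        tf.int64, lookup.TextFileIndex.LINE_NUMBER, vocab_size=3),
    default_value=-1)
with tf.Session():
    table.init.run()  # raises if vocab.txt has fewer than vocab_size lines
    print(table.lookup(tf.constant(["brain", "salad"])).eval())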
@@ -1723,19 +1723,19 @@ class InitializeTableFromFileOpTest(test.TestCase): # Invalid data type other_type = constant_op.constant(1) with self.assertRaises(ValueError): - lookup_ops.HashTable( - lookup_ops.TextFileInitializer( - other_type, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, - dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), + lookup.HashTable( + lookup.TextFileInitializer( + other_type, dtypes.string, lookup.TextFileIndex.WHOLE_LINE, + dtypes.int64, lookup.TextFileIndex.LINE_NUMBER), default_value) # Non-scalar filename filenames = constant_op.constant([vocabulary_file, vocabulary_file]) with self.assertRaises(ValueError): - lookup_ops.HashTable( - lookup_ops.TextFileInitializer( - filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, - dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), + lookup.HashTable( + lookup.TextFileInitializer( + filenames, dtypes.string, lookup.TextFileIndex.WHOLE_LINE, + dtypes.int64, lookup.TextFileIndex.LINE_NUMBER), default_value) def testIdToStringTable(self): @@ -1743,8 +1743,8 @@ class InitializeTableFromFileOpTest(test.TestCase): with self.test_session(): default_value = "UNK" vocab_size = 3 - table = lookup_ops.HashTable( - lookup_ops.TextFileStringTableInitializer( + table = lookup.HashTable( + lookup.TextFileStringTableInitializer( vocab_file, vocab_size=vocab_size), default_value) @@ -1761,8 +1761,8 @@ class InitializeTableFromFileOpTest(test.TestCase): with self.test_session(): default_value = -1 vocab_size = 3 - table = lookup_ops.HashTable( - lookup_ops.TextFileIdTableInitializer( + table = lookup.HashTable( + lookup.TextFileIdTableInitializer( vocab_file, vocab_size=vocab_size), default_value) table.init.run() @@ -1788,9 +1788,9 @@ class IdTableWithHashBucketsTest(test.TestCase): default_value = -1 vocab_size = 3 oov_buckets = 1 - table = lookup_ops.IdTableWithHashBuckets( - lookup_ops.HashTable( - lookup_ops.TextFileIdTableInitializer( + table = lookup.IdTableWithHashBuckets( + lookup.HashTable( + lookup.TextFileIdTableInitializer( vocab_file, vocab_size=vocab_size), default_value), oov_buckets) @@ -1809,7 +1809,7 @@ class IdTableWithHashBucketsTest(test.TestCase): # Set a table that only uses hash buckets, for each input value returns # an id calculated by fingerprint("input") mod oov_buckets. 
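The comment above is the whole contract of the buckets-only mode; a sketch of the id computation, using Python's `hash` as a stand-in for the Fingerprint64 function TensorFlow actually applies:

def oov_bucket_id(word, oov_buckets, fingerprint=hash):
    # With no vocabulary table, every input falls into one of
    # `oov_buckets` buckets chosen by fingerprint(word) % oov_buckets.
    return fingerprint(word) % oov_buckets

assert 0 <= oov_bucket_id("brain", 5) < 5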
- table = lookup_ops.IdTableWithHashBuckets(None, oov_buckets) + table = lookup.IdTableWithHashBuckets(None, oov_buckets) table.init.run() input_string = constant_op.constant(["brain", "salad", "surgery"]) @@ -1831,20 +1831,20 @@ class IdTableWithHashBucketsTest(test.TestCase): vocab_size = 3 oov_buckets = 3 - vocab_table = lookup_ops.HashTable( - lookup_ops.TextFileIdTableInitializer( + vocab_table = lookup.HashTable( + lookup.TextFileIdTableInitializer( vocab_file, vocab_size=vocab_size), default_value) - table1 = lookup_ops.IdTableWithHashBuckets( + table1 = lookup.IdTableWithHashBuckets( vocab_table, oov_buckets, - hasher_spec=lookup_ops.FastHashSpec, + hasher_spec=lookup.FastHashSpec, name="table1") - table2 = lookup_ops.IdTableWithHashBuckets( + table2 = lookup.IdTableWithHashBuckets( vocab_table, oov_buckets, - hasher_spec=lookup_ops.StrongHashSpec((1, 2)), + hasher_spec=lookup.StrongHashSpec((1, 2)), name="table2") data_flow_ops.tables_initializer().run() @@ -1872,9 +1872,9 @@ class IdTableWithHashBucketsTest(test.TestCase): default_value = -1 vocab_size = 3 oov_buckets = 1 - table1 = lookup_ops.IdTableWithHashBuckets( - lookup_ops.HashTable( - lookup_ops.TextFileIdTableInitializer( + table1 = lookup.IdTableWithHashBuckets( + lookup.HashTable( + lookup.TextFileIdTableInitializer( vocab_file, vocab_size=vocab_size), default_value, shared_name=shared_name), @@ -1897,9 +1897,9 @@ class IdTableWithHashBucketsTest(test.TestCase): # Underlying lookup table already initialized in previous session. # No need to call table2.init.run() - table2 = lookup_ops.IdTableWithHashBuckets( - lookup_ops.HashTable( - lookup_ops.TextFileIdTableInitializer( + table2 = lookup.IdTableWithHashBuckets( + lookup.HashTable( + lookup.TextFileIdTableInitializer( vocab_file, vocab_size=vocab_size), default_value, shared_name=shared_name), @@ -1918,17 +1918,17 @@ class IdTableWithHashBucketsTest(test.TestCase): default_value1 = -1 vocab_size = 3 oov_buckets = 0 - table1 = lookup_ops.IdTableWithHashBuckets( - lookup_ops.HashTable( - lookup_ops.TextFileIdTableInitializer( + table1 = lookup.IdTableWithHashBuckets( + lookup.HashTable( + lookup.TextFileIdTableInitializer( vocab_file, vocab_size=vocab_size), default_value1), oov_buckets) default_value2 = -2 - table2 = lookup_ops.IdTableWithHashBuckets( - lookup_ops.HashTable( - lookup_ops.TextFileIdTableInitializer( + table2 = lookup.IdTableWithHashBuckets( + lookup.HashTable( + lookup.TextFileIdTableInitializer( vocab_file, vocab_size=vocab_size), default_value2), oov_buckets) @@ -1959,9 +1959,9 @@ class IdTableWithHashBucketsTest(test.TestCase): dtypes.string), constant_op.constant(input_shape, dtypes.int64)) - table = lookup_ops.IdTableWithHashBuckets( - lookup_ops.HashTable( - lookup_ops.TextFileIdTableInitializer( + table = lookup.IdTableWithHashBuckets( + lookup.HashTable( + lookup.TextFileIdTableInitializer( vocab_file, vocab_size=3), -1), 1) @@ -1984,19 +1984,19 @@ class IdTableWithHashBucketsTest(test.TestCase): default_value = -1 vocab_size = 3 oov_buckets = 1 - lookup_table = lookup_ops.HashTable( - lookup_ops.TextFileIdTableInitializer( + lookup_table = lookup.HashTable( + lookup.TextFileIdTableInitializer( vocab_file, vocab_size=vocab_size), default_value) with self.assertRaises(TypeError): - lookup_ops.IdTableWithHashBuckets( + lookup.IdTableWithHashBuckets( lookup_table, oov_buckets, hasher_spec=1) - table = lookup_ops.IdTableWithHashBuckets( + table = lookup.IdTableWithHashBuckets( lookup_table, oov_buckets, - 
hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None)) + hasher_spec=lookup.HasherSpec("my-awesome-hash", None)) input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) @@ -2004,22 +2004,22 @@ class IdTableWithHashBucketsTest(test.TestCase): table.lookup(input_string) with self.assertRaises(ValueError): - table = lookup_ops.IdTableWithHashBuckets( + table = lookup.IdTableWithHashBuckets( lookup_table, oov_buckets, - hasher_spec=lookup_ops.StrongHashSpec([])) + hasher_spec=lookup.StrongHashSpec([])) with self.assertRaises(ValueError): - table = lookup_ops.IdTableWithHashBuckets( + table = lookup.IdTableWithHashBuckets( lookup_table, oov_buckets, - hasher_spec=lookup_ops.StrongHashSpec([1, 2, 3])) + hasher_spec=lookup.StrongHashSpec([1, 2, 3])) with self.assertRaises(TypeError): - table = lookup_ops.IdTableWithHashBuckets( + table = lookup.IdTableWithHashBuckets( lookup_table, oov_buckets, - hasher_spec=lookup_ops.StrongHashSpec([None, 2])) + hasher_spec=lookup.StrongHashSpec([None, 2])) if __name__ == "__main__": diff --git a/tensorflow/contrib/losses/__init__.py b/tensorflow/contrib/losses/__init__.py index 14a4d531529..a405e11c22b 100644 --- a/tensorflow/contrib/losses/__init__.py +++ b/tensorflow/contrib/losses/__init__.py @@ -19,8 +19,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys - # pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.losses.python import losses from tensorflow.contrib.losses.python.losses import * # pylint: enable=unused-import,wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__, doc_string_modules=[losses]) diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py index aaa1b62d5f7..b5ad8fb8b5d 100644 --- a/tensorflow/contrib/metrics/__init__.py +++ b/tensorflow/contrib/metrics/__init__.py @@ -109,6 +109,7 @@ labels and predictions tensors and results in a weighted average of the metric. @@streaming_mean_iou @@streaming_mean_relative_error @@streaming_mean_squared_error +@@streaming_mean_tensor @@streaming_root_mean_squared_error @@streaming_covariance @@streaming_pearson_correlation @@ -137,6 +138,8 @@ labels and predictions tensors and results in a weighted average of the metric. 
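The `remove_undocumented` call adopted in the losses change above (and in most modules of this series) seals a module: every public name that is neither referenced as @@name in the module docstring nor explicitly whitelisted is deleted. A rough toy analogue, not the real implementation, which lives in tensorflow.python.util.all_util:

import re
import sys

def seal_module(module_name, allowed=()):
    # Keep names documented as @@name in the docstring, plus `allowed`.
    mod = sys.modules[module_name]
    documented = set(re.findall(r"@@([\w.]+)", mod.__doc__ or ""))
    keep = documented | set(allowed)
    for name in list(vars(mod)):
        if not name.startswith("_") and name not in keep:
            delattr(mod, name)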
@@aggregate_metrics @@aggregate_metric_map +@@confusion_matrix + ## Set `Ops` @@set_difference @@ -193,7 +196,7 @@ from tensorflow.contrib.metrics.python.ops.set_ops import set_difference from tensorflow.contrib.metrics.python.ops.set_ops import set_intersection from tensorflow.contrib.metrics.python.ops.set_ops import set_size from tensorflow.contrib.metrics.python.ops.set_ops import set_union -from tensorflow.python.util.all_util import make_all # pylint: enable=unused-import,line-too-long -__all__ = make_all(__name__) +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/metrics/python/ops/histogram_ops.py b/tensorflow/contrib/metrics/python/ops/histogram_ops.py index 68d9bb5b7a4..d3d74d28a3a 100644 --- a/tensorflow/contrib/metrics/python/ops/histogram_ops.py +++ b/tensorflow/contrib/metrics/python/ops/histogram_ops.py @@ -23,7 +23,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.framework import tensor_util +from tensorflow.contrib.framework.python.framework import tensor_util from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 7ac337732a7..0e07c1f47ac 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -23,7 +23,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework import deprecated -from tensorflow.contrib.framework import tensor_util +from tensorflow.contrib.framework.python.framework import tensor_util from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.metrics.python.ops import set_ops from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 4fb244e3d44..af6b365a2a8 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -4593,7 +4593,7 @@ class StreamingConcatTest(test.TestCase): self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) def testNextArraySize(self): - next_array_size = metrics.python.ops.metric_ops._next_array_size + next_array_size = metric_ops._next_array_size # pylint: disable=protected-access with self.test_session(): self.assertEqual(next_array_size(2, growth_factor=2).eval(), 2) self.assertEqual(next_array_size(3, growth_factor=2).eval(), 4) diff --git a/tensorflow/contrib/nn/__init__.py b/tensorflow/contrib/nn/__init__.py index c2fe913b595..73757a6696e 100644 --- a/tensorflow/contrib/nn/__init__.py +++ b/tensorflow/contrib/nn/__init__.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Module for deprecated ops in tf.nn.""" +"""Module for deprecated ops in tf.nn. 
+ +@@deprecated_flipped_softmax_cross_entropy_with_logits +@@deprecated_flipped_sparse_softmax_cross_entropy_with_logits +@@deprecated_flipped_sigmoid_cross_entropy_with_logits +""" from __future__ import absolute_import from __future__ import division @@ -21,3 +26,6 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.nn.python.ops.cross_entropy import * # pylint: enable=unused-import,wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index ec54c9b3c98..8ef90095965 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -23,3 +23,12 @@ from tensorflow.contrib.opt.python.training.external_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * # pylint: enable=wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['ExternalOptimizerInterface', + 'MovingAverageOptimizer', + 'ScipyOptimizerInterface', + 'VariableClippingOptimizer'] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index 0966fc0cf0d..56c394c028d 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -24,6 +24,7 @@ @@BasicLSTMCell @@GRUCell @@LSTMCell +@@LayerNormBasicLSTMCell ## Classes storing split `RNNCell` state @@ -32,6 +33,7 @@ ## RNN Cell wrappers (RNNCells that wrap other RNNCells) @@MultiRNNCell +@@LSTMBlockWrapper @@DropoutWrapper @@EmbeddingWrapper @@InputProjectionWrapper @@ -86,10 +88,13 @@ from tensorflow.contrib.rnn.python.ops.core_rnn_cell import MultiRNNCell from tensorflow.contrib.rnn.python.ops.core_rnn_cell import OutputProjectionWrapper from tensorflow.contrib.rnn.python.ops.core_rnn_cell import RNNCell -# pylint: disable=unused-import,wildcard-import, line-too-long +# pylint: disable=unused-import,wildcard-import,line-too-long from tensorflow.contrib.rnn.python.ops.fused_rnn_cell import * from tensorflow.contrib.rnn.python.ops.gru_ops import * from tensorflow.contrib.rnn.python.ops.lstm_ops import * from tensorflow.contrib.rnn.python.ops.rnn import * from tensorflow.contrib.rnn.python.ops.rnn_cell import * # pylint: enable=unused-import,wildcard-import,line-too-long + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__, ['core_rnn_cell']) diff --git a/tensorflow/contrib/seq2seq/__init__.py b/tensorflow/contrib/seq2seq/__init__.py index a7e272984b9..f8b35a1cbb3 100644 --- a/tensorflow/contrib/seq2seq/__init__.py +++ b/tensorflow/contrib/seq2seq/__init__.py @@ -19,13 +19,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys - -# pylint: disable=unused-import,line-too-long +# pylint: disable=unused-import,wildcard-import,line-too-long from tensorflow.contrib.seq2seq.python.ops.attention_decoder_fn import attention_decoder_fn_inference from tensorflow.contrib.seq2seq.python.ops.attention_decoder_fn import attention_decoder_fn_train from tensorflow.contrib.seq2seq.python.ops.attention_decoder_fn import prepare_attention from tensorflow.contrib.seq2seq.python.ops.decoder_fn import * from tensorflow.contrib.seq2seq.python.ops.loss import * from 
tensorflow.contrib.seq2seq.python.ops.seq2seq import * -# pylint: enable=unused-import,line-too-long +# pylint: enable=unused-import,wildcard-import,line-too-long + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ["attention_decoder_fn_inference", + "attention_decoder_fn_train", + "dynamic_rnn_decoder", + "prepare_attention", + "sequence_loss", + "simple_decoder_fn_train", + "simple_decoder_fn_inference"] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/stat_summarizer/__init__.py b/tensorflow/contrib/stat_summarizer/__init__.py index 32feb7edb97..53d5548863a 100644 --- a/tensorflow/contrib/stat_summarizer/__init__.py +++ b/tensorflow/contrib/stat_summarizer/__init__.py @@ -25,3 +25,10 @@ from __future__ import print_function from tensorflow.python.pywrap_tensorflow import DeleteStatSummarizer from tensorflow.python.pywrap_tensorflow import NewStatSummarizer from tensorflow.python.pywrap_tensorflow import StatSummarizer + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['DeleteStatSummarizer', 'NewStatSummarizer', + 'StatSummarizer'] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/tfprof/__init__.py b/tensorflow/contrib/tfprof/__init__.py index 129dad2726c..f3952f6cb5c 100644 --- a/tensorflow/contrib/tfprof/__init__.py +++ b/tensorflow/contrib/tfprof/__init__.py @@ -19,4 +19,3 @@ from __future__ import print_function from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer from tensorflow.contrib.tfprof.python.tools.tfprof import tfprof_logger -from tensorflow.python.util.all_util import make_all diff --git a/tensorflow/contrib/util/__init__.py b/tensorflow/contrib/util/__init__.py index cdaafef1f7d..45efdc20c80 100644 --- a/tensorflow/contrib/util/__init__.py +++ b/tensorflow/contrib/util/__init__.py @@ -35,7 +35,6 @@ from tensorflow.python.framework.meta_graph import stripped_op_list_for_graph from tensorflow.python.framework.tensor_util import constant_value from tensorflow.python.framework.tensor_util import make_tensor_proto from tensorflow.python.framework.tensor_util import MakeNdarray as make_ndarray -from tensorflow.python.util.all_util import make_all - - -__all__ = make_all(__name__) +# pylint: disable=unused-import +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/util/loader.py b/tensorflow/contrib/util/loader.py index 6f690f414a4..95657217a00 100644 --- a/tensorflow/contrib/util/loader.py +++ b/tensorflow/contrib/util/loader.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities for loading op libraries.""" +"""Utilities for loading op libraries.
+ +@@load_op_library +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/tensorflow/python/framework/docs.py b/tensorflow/python/framework/docs.py index 94e658d8c09..4ae0046117b 100644 --- a/tensorflow/python/framework/docs.py +++ b/tensorflow/python/framework/docs.py @@ -291,10 +291,43 @@ class Library(Document): def _generate_signature_for_function(self, func): """Given a function, returns a string representing its args.""" args_list = [] - argspec = inspect.getargspec(func) + if isinstance(func, functools.partial): + argspec = inspect.getargspec(func.func) + # Remove the args from the original function that have been used up. + first_default_arg = ( + len(argspec.args or []) - len(argspec.defaults or [])) + partial_args = len(func.args) + if argspec.args: + argspec_args = list(argspec.args[partial_args:]) + else: + argspec_args = [] + if argspec.defaults: + argspec_defaults = list(argspec.defaults[ + max(0, partial_args-first_default_arg):]) + else: + argspec_defaults = [] + first_default_arg = max(0, first_default_arg - partial_args) + for kwarg in func.keywords: + if kwarg in argspec_args: + i = argspec_args.index(kwarg) + argspec_args.pop(i) + if i >= first_default_arg: + argspec_defaults.pop(i-first_default_arg) + else: + first_default_arg -= 1 + argspec_varargs = None + argspec_keywords = None + + else: + argspec = inspect.getargspec(func) + argspec_args = argspec.args + argspec_defaults = argspec.defaults + argspec_varargs = argspec.varargs + argspec_keywords = argspec.keywords + first_arg_with_default = ( - len(argspec.args or []) - len(argspec.defaults or [])) - for arg in argspec.args[:first_arg_with_default]: + len(argspec_args or []) - len(argspec_defaults or [])) + for arg in argspec_args[:first_arg_with_default]: if arg == "self": # Python documentation typically skips `self` when printing method # signatures. @@ -306,16 +339,16 @@ class Library(Document): # TODO(aselle): This workaround is brittle on TestCase.__call__ # so we need to wrap this in a try/catch # We should do something better. 
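The `functools.partial` branch above reconstructs the still-open signature: positional args already bound by the partial are dropped from the front, and keywords fixed by the partial are removed along with their defaults. A worked example of that arithmetic (uses the Python-2-era `inspect.getargspec` this file targets; removed in Python 3.11):

import functools
import inspect

def f(a, b, c=3, d=4):
    return a + b + c + d

p = functools.partial(f, 1, d=40)      # binds a=1 and d=40
spec = inspect.getargspec(p.func)
args = list(spec.args[len(p.args):])   # drop bound positionals -> ['b', 'c', 'd']
for kw in p.keywords:                  # drop keywords fixed by the partial
    args.remove(kw)                    # -> ['b', 'c']
assert args == ["b", "c"]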
- if argspec.varargs == "args" and argspec.keywords == "kwds": + if argspec_varargs == "args" and argspec_keywords == "kwds": try: original_func = func.__closure__[0].cell_contents return self._generate_signature_for_function(original_func) except TypeError: pass - if argspec.defaults: + if argspec_defaults: for arg, default in zip( - argspec.args[first_arg_with_default:], argspec.defaults): + argspec_args[first_arg_with_default:], argspec_defaults): if callable(default): if hasattr(default, "__name__"): args_list.append("%s=%s" % (arg, default.__name__)) @@ -326,10 +359,10 @@ class Library(Document): args_list.append("%s=%s()" % (arg, default.__class__.__name__)) else: args_list.append("%s=%r" % (arg, default)) - if argspec.varargs: - args_list.append("*" + argspec.varargs) - if argspec.keywords: - args_list.append("**" + argspec.keywords) + if argspec_varargs: + args_list.append("*" + argspec_varargs) + if argspec_keywords: + args_list.append("**" + argspec_keywords) return "(" + ", ".join(args_list) + ")" def _remove_docstring_indent(self, docstring): diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py index 7c387e1da44..fb0a4dda3d2 100644 --- a/tensorflow/python/framework/gen_docs_combined.py +++ b/tensorflow/python/framework/gen_docs_combined.py @@ -261,11 +261,18 @@ _hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange", # TODO(wicke): Remove contrib.layers.relu* after shortnames are # disabled. These conflict with tf.nn.relu* EXCLUDE = frozenset(["tf.contrib.learn.monitors.NanLossDuringTrainingError", + "tf.contrib.layers.dropout", + "tf.contrib.layers.bias_add", + "tf.contrib.layers.conv2d", + "tf.contrib.layers.conv2d_transpose", + "tf.contrib.layers.separable_conv2d", + "tf.contrib.layers.softmax", "tf.contrib.layers.relu", "tf.contrib.layers.relu6", "tf.contrib.framework.assert_global_step", "tf.contrib.framework.get_global_step", "tf.contrib.learn.NanLossDuringTrainingError", "tf.contrib.layers.stack", + "tf.contrib.layers.ProblemType", "tf.confusion_matrix"]) From b2215c82e03f6518dc8d0c605ce940842a465764 Mon Sep 17 00:00:00 2001 From: Martin Wicke Date: Tue, 31 Jan 2017 15:26:57 -0800 Subject: [PATCH 3/6] Seal tf.contrib.framework. Change: 146174108 --- tensorflow/contrib/framework/__init__.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 6aea18a8ac0..0765ee33ea2 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -16,6 +16,7 @@ """Framework utilities. 
 @@assert_same_float_dtype
+@@assert_scalar
 @@assert_scalar_int
 @@convert_to_tensor_or_sparse_tensor
 @@get_graph_from_inputs
@@ -24,6 +25,7 @@
 @@is_strictly_increasing
 @@is_tensor
 @@reduce_sum_n
+@@remove_squeezable_dimensions
 @@with_shape
 @@with_same_shape
@@ -47,6 +49,7 @@
 @@assign_from_values
 @@assign_from_values_fn
 @@create_global_step
+@@filter_variables
 @@get_global_step
 @@get_or_create_global_step
 @@get_local_variables
@@ -74,12 +77,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import sys
-
 # pylint: disable=unused-import,wildcard-import
 from tensorflow.contrib.framework.python.framework import *
 from tensorflow.contrib.framework.python.ops import *
-from tensorflow.python.util.all_util import make_all
 # pylint: enable=unused-import,wildcard-import
 
-__all__ = make_all(__name__)
+from tensorflow.python.util.all_util import remove_undocumented
+
+
+remove_undocumented(__name__)

From 4e20b473c35994e507a4675fd1c5a2350559e0d2 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 18 Jan 2017 05:53:24 -0800
Subject: [PATCH 4/6] Fix transform for cyclic graph. Improve collection name
 handling. Added helper to retrieve corresponding tensor/op.

Change: 144824657
---
 tensorflow/contrib/graph_editor/transform.py | 37 ++++++++--
 tensorflow/contrib/graph_editor/util.py      | 73 ++++++++++++++++++++
 2 files changed, 105 insertions(+), 5 deletions(-)

diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 6fb347c8349..832698b8a07 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -26,13 +26,13 @@ from six import iteritems
 from six import iterkeys
 from six import string_types
 from six import StringIO
-
 from tensorflow.contrib.graph_editor import edit
 from tensorflow.contrib.graph_editor import reroute
 from tensorflow.contrib.graph_editor import select
 from tensorflow.contrib.graph_editor import subgraph
 from tensorflow.contrib.graph_editor import util
 from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.platform import tf_logging as logging
 
 __all__ = [
     "replace_t_with_placeholder_handler",
@@ -87,17 +87,24 @@ def keep_t_if_possible_handler(info, t):
 def assign_renamed_collections_handler(info, elem, elem_):
   """Add the transformed elem to the (renamed) collections of elem.
 
+  A collection is renamed only if it is not a known key, as described in
+  `tf.GraphKeys`.
+
   Args:
     info: Transform._Info instance.
     elem: the original element (`tf.Tensor` or `tf.Operation`)
     elem_: the transformed element
   """
-  # TODO(fkp): handle known special cases
+  known_collection_names = util.get_predefined_collection_names()
   for name, collection in iteritems(info.collections):
     if elem not in collection:
       continue
-    collection_name_ = info.transformer.new_name(name)
-    info.graph_.add_to_collection(collection_name_, elem_)
+
+    if name in known_collection_names:
+      transformed_name = name
+    else:
+      transformed_name = info.transformer.new_name(name)
+    info.graph_.add_to_collection(transformed_name, elem_)
 
 
 def transform_op_if_inside_handler(info, op, keep_if_possible=True):
@@ -150,6 +157,11 @@ def copy_op_handler(info, op, copy_shape=True):
   # Transform inputs:
   inputs_ = [info.transformer._transform_t(t) for t in op.inputs]
 
+  # Leave inputs empty if a graph cycle was found.
+  if None in inputs_:
+    info.cyclic_ops.append(op)
+    inputs_ = []
+
   # Clone the node def:
   node_def_ = deepcopy(op._node_def)
 
@@ -239,7 +251,7 @@ class Transformer(object):
       self.transformed_ts = {}
       self.collections = dict((key, self.graph.get_collection(key))
                               for key in self.graph.get_all_collection_keys())
-
+      self.cyclic_ops = []
 
   class ResultInfo(object):
     """"Contains information about the result of a transform operation."""
@@ -452,6 +464,17 @@ class Transformer(object):
     for op in remaining_roots:
       self._transform_op(op)
 
+    # Finalize cyclic ops:
+    for op in self._info.cyclic_ops:
+      logging.debug("Finalizing cyclic op: %s", op.name)
+      op_ = self._info.transformed_ops[op]
+      inputs_ = [self._info.transformed_ts[t] for t in op.inputs]
+      if None in inputs_:
+        raise ValueError("Could not find all the inputs of cyclic op: {}"
+                         .format(op_.name))
+      for input_id, t_ in enumerate(inputs_):
+        op_._update_input(input_id, t_)  # pylint: disable=protected-access
+
     sgv_ = self._transform_sgv(sgv)
 
     res_info = Transformer.ResultInfo(self._info)
@@ -506,9 +529,13 @@ class Transformer(object):
     Returns:
       The transformed tensor.
     """
+    logging.debug("Transforming tensor: %s", t.name)
     if t in self._info.transformed_ts:
       return self._info.transformed_ts[t]
 
+    # Mark as None to detect cycle.
+    self._info.transformed_ts[t] = None
+
     op, op_index = t.op, t.value_index
 
     # If op is not in the subgraph:
diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py
index 11ee2435c9c..d8824f67923 100644
--- a/tensorflow/contrib/graph_editor/util.py
+++ b/tensorflow/contrib/graph_editor/util.py
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import re
 from six import iteritems
 from tensorflow.python.framework import ops as tf_ops
 from tensorflow.python.ops import array_ops as tf_array_ops
@@ -465,3 +466,75 @@ def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
   """
   return tf_array_ops.placeholder(
       dtype=dtype, shape=shape, name=placeholder_name(scope=scope))
+
+
+_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$")
+
+
+def get_predefined_collection_names():
+  """Return all the predefined collection names."""
+  return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys)
+          if not _INTERNAL_VARIABLE_RE.match(key)]
+
+
+def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""):
+  """Find corresponding op/tensor in a different graph.
+
+  Args:
+    target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph.
+    dst_graph: The graph in which the corresponding graph element must be found.
+    dst_scope: A scope which is prepended to the name to look for.
+    src_scope: A scope which is removed from the original name of `target`.
+
+  Returns:
+    The corresponding `tf.Tensor` or `tf.Operation`.
+
+  Raises:
+    ValueError: if `src_name` does not start with `src_scope`.
+    TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation`.
+    KeyError: If the corresponding graph element cannot be found.
+  """
+  src_name = target.name
+  if src_scope:
+    src_scope = scope_finalize(src_scope)
+    if not src_name.startswith(src_scope):
+      raise ValueError("{} does not start with {}".format(src_name, src_scope))
+    src_name = src_name[len(src_scope):]
+
+  dst_name = src_name
+  if dst_scope:
+    dst_scope = scope_finalize(dst_scope)
+    dst_name = dst_scope + dst_name
+
+  if isinstance(target, tf_ops.Tensor):
+    return dst_graph.get_tensor_by_name(dst_name)
+  if isinstance(target, tf_ops.Operation):
+    return dst_graph.get_operation_by_name(dst_name)
+  raise TypeError("Expected tf.Tensor or tf.Operation, got: {}".format(
+      type(target)))
+
+
+def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""):
+  """Find corresponding ops/tensors in a different graph.
+
+  `targets` is a Python tree, that is, a nested structure of iterables
+  (lists, tuples, dictionaries) whose leaves are instances of
+  `tf.Tensor` or `tf.Operation`.
+
+  Args:
+    targets: A Python tree containing `tf.Tensor` or `tf.Operation`
+      belonging to the original graph.
+    dst_graph: The graph in which the corresponding graph element must be found.
+    dst_scope: A scope which is prepended to the name to look for.
+    src_scope: A scope which is removed from the original name of `top`.
+
+  Returns:
+    A Python tree containing the corresponding `tf.Tensor` or `tf.Operation`.
+
+  Raises:
+    ValueError: if `src_name` does not start with `src_scope`.
+    TypeError: if `top` is not a `tf.Tensor` or a `tf.Operation`.
+    KeyError: If the corresponding graph element cannot be found.
+  """
+  def func(top):
+    return find_corresponding_elem(top, dst_graph, dst_scope, src_scope)
+  return transform_tree(targets, func)

From f3979c09fb96be09738890bc38286085c6f52394 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 26 Jan 2017 01:56:20 -0800
Subject: [PATCH 5/6] Fix transform for cyclic graph (second try). Deprecate
 in-place transform.
Change: 145649225 --- .../graph_editor/tests/transform_test.py | 56 +- tensorflow/contrib/graph_editor/transform.py | 584 ++++++++---------- 2 files changed, 265 insertions(+), 375 deletions(-) diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index 88fea8bbace..33f1217412c 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -75,7 +75,7 @@ class TransformTest(test.TestCase): _ = math_ops.add(a, b) sgv = ge.make_view([assert_op, eq.op, a.op, b.op]) copier = ge.Transformer() - copied_sgv, info = copier(sgv, sgv.graph, "", "") + _, info = copier(sgv, sgv.graph, "", "") new_assert_op = info.transformed(assert_op) self.assertIsNotNone(new_assert_op) @@ -84,18 +84,17 @@ class TransformTest(test.TestCase): def my_transform_op_handler(info, op): add_noise = op.name.startswith("Add") - op_ = ge.transform.copy_op_handler(info, op) - if add_noise: - # add some noise to op - with info.graph_.as_default(): - t_ = math_ops.add(constant_op.constant( - 1.0, shape=[10], name="Noise"), - op_.outputs[0], - name="AddNoise") - # return the "noisy" op - return t_.op - else: - return op_ + op_, op_outputs_ = ge.transform.copy_op_handler(info, op) + if not add_noise: + return op_, op_outputs_ + # add some noise to op + with info.graph_.as_default(): + t_ = math_ops.add( + constant_op.constant(1.0, shape=[10], name="Noise"), + op_.outputs[0], + name="AddNoise") + # return the "noisy" op + return op_, [t_] transformer.transform_op_handler = my_transform_op_handler @@ -110,37 +109,6 @@ class TransformTest(test.TestCase): top = ge.select_ops("^AddNoise_2$", graph=graph)[0] self.assertTrue(matcher2(top)) - def test_transform_in_place(self): - transformer = ge.Transformer() - - def my_transform_op_handler_in_place(info, op): - add_noise = op.name.startswith("Add") - op = ge.transform.transform_op_in_place( - info, op, detach_outputs=add_noise) - if add_noise: - # add some noise to op - with info.graph_.as_default(): - t = math_ops.add(constant_op.constant( - 1.0, shape=[10], name="Noise"), - op.outputs[0], - name="AddNoise") - # return the "noisy" op - return t.op - else: - return op - - transformer.transform_op_handler = my_transform_op_handler_in_place - - transformer(self.graph, self.graph, "", "") - matcher0 = ge.matcher("AddNoise").input_ops( - "Noise", ge.matcher("Add").input_ops("Const", "Input")) - matcher1 = ge.matcher("AddNoise_1").input_ops( - "Noise_1", ge.matcher("Add_1").input_ops("Const_1", matcher0)) - matcher2 = ge.matcher("AddNoise_2").input_ops( - "Noise_2", ge.matcher("Add_2").input_ops("Const_2", matcher1)) - top = ge.select_ops("^AddNoise_2$", graph=self.graph)[0] - self.assertTrue(matcher2(top)) - def test_copy_with_input_replacements(self): with self.graph.as_default(): ten = constant_op.constant(10.0, shape=[10], name="Input") diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 832698b8a07..4d272e108f3 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -21,12 +21,10 @@ from __future__ import print_function from copy import deepcopy from functools import partial - from six import iteritems from six import iterkeys from six import string_types from six import StringIO -from tensorflow.contrib.graph_editor import edit from tensorflow.contrib.graph_editor import reroute from tensorflow.contrib.graph_editor import select from 
tensorflow.contrib.graph_editor import subgraph
@@ -34,14 +32,15 @@ from tensorflow.contrib.graph_editor import util
 from tensorflow.python.framework import ops as tf_ops
 from tensorflow.python.platform import tf_logging as logging
 
+
 __all__ = [
     "replace_t_with_placeholder_handler",
     "keep_t_if_possible_handler",
     "assign_renamed_collections_handler",
     "transform_op_if_inside_handler",
     "copy_op_handler",
-    "transform_op_in_place",
     "Transformer",
+    "TransformerInfo",
     "copy",
     "copy_with_input_replacements",
     "graph_replace",
@@ -55,7 +54,7 @@ def replace_t_with_placeholder_handler(info, t):
   placeholder.
 
   Args:
-    info: Transform._Info instance.
+    info: Transform._TmpInfo instance.
     t: tensor whose input must be transformed into a place holder.
   Returns:
     The tensor generated by the newly created place holder.
@@ -73,7 +72,7 @@ def keep_t_if_possible_handler(info, t):
   This handler is typically used to transform a hidden input tensors.
 
   Args:
-    info: Transform._Info instance.
+    info: Transform._TmpInfo instance.
    t: tensor whose input must be transformed into a place holder.
   Returns:
     The tensor generated by the newly created place holder.
@@ -91,7 +90,7 @@ def assign_renamed_collections_handler(info, elem, elem_):
   `tf.GraphKeys`.
 
   Args:
-    info: Transform._Info instance.
+    info: Transform._TmpInfo instance.
     elem: the original element (`tf.Tensor` or `tf.Operation`)
     elem_: the transformed element
   """
@@ -103,7 +102,7 @@
     if name in known_collection_names:
       transformed_name = name
     else:
-      transformed_name = info.transformer.new_name(name)
+      transformed_name = info.new_name(name)
     info.graph_.add_to_collection(transformed_name, elem_)
@@ -114,18 +113,15 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True):
   if they are inside the subgraph, otherwise they are just ignored.
 
   Args:
-    info: Transform._Info instance.
+    info: Transform._TmpInfo instance.
     op: the optional op to transform (or ignore).
     keep_if_possible: re-attach to the original op if possible, that is,
       if the source graph and the destination graph are the same.
   Returns:
     The transformed op or None.
   """
-  if op is None:
-    return None
   if op in info.sgv.ops:
-    return info.transformer._transform_op(  # pylint: disable=protected-access
-        op)
+    return info.transformed_ops[op]
   else:
     if keep_if_possible and info.graph is info.graph_:
       return op
@@ -137,36 +133,19 @@ def copy_op_handler(info, op, copy_shape=True):
   """Copy a `tf.Operation`.
 
   Args:
-    info: Transform._Info instance.
+    info: Transform._TmpInfo instance.
     op: the `tf.Operation` to be copied.
     copy_shape: also copy the shape of the tensor
   Returns:
-    A copy of op.
+    A `(op, op_outputs)` tuple containing the transformed op and its outputs.
   """
   # pylint: disable=protected-access
 
-  # Transform control inputs:
-  control_inputs_ = [info.transformer.transform_control_input_handler(info, ci)
-                     for ci in op.control_inputs]
-  control_inputs_ = [ci for ci in control_inputs_ if ci is not None]
-
-  # Transform it if any:
-  original_op_ = info.transformer.transform_original_op_handler(info,
-                                                                op._original_op)
-
-  # Transform inputs:
-  inputs_ = [info.transformer._transform_t(t) for t in op.inputs]
-
-  # Leave inputs empty if a graph cycle was found.
-  if None in inputs_:
-    info.cyclic_ops.append(op)
-    inputs_ = []
-
   # Clone the node def:
   node_def_ = deepcopy(op._node_def)
 
   # Transform name:
-  name_ = info.transformer.new_name(op.name)
+  name_ = info.new_name(op.name)
   name_ = info.graph_.unique_name(name_)
   node_def_.name = name_
@@ -179,45 +158,196 @@
   op_def_ = deepcopy(op._op_def)
 
   # Initialize a new Operation instance
-  op_ = tf_ops.Operation(node_def_, info.graph_, inputs_, output_types_,
-                         control_inputs_, input_types_, original_op_, op_def_)
+  op_ = tf_ops.Operation(node_def_, info.graph_, [], output_types_,
+                         [], input_types_, None, op_def_)
 
   # copy the shape over
   if copy_shape:
     for t, t_ in zip(op.outputs, op_.outputs):
       t_.set_shape(t.get_shape())
 
+  # Finalize original op.
+  if op._original_op:
+    original_op = info.transform_original_op_handler(info, op._original_op)
+    if original_op is None:
+      logging.info("Could not find original op of: %s", op_.name)
+    else:
+      op_._original_op = original_op
+
   # Add op to the graph
   info.graph_._add_op(op_)
 
-  # pylint: enable=protected-access
-  return op_
+  return op_, op_.outputs
 
-def transform_op_in_place(info, op, detach_outputs=False):
-  """Transform a op in-place - experimental!
+class TransformerInfo(object):
+  """Contains information about the result of a transform operation."""
 
-  Transform an operation in place. It reconnects the inputs if they have been
-  modified. if detach_outputs is True, the outputs of op are also detached.
+  def __init__(self, info):
+    """Constructor.
 
-  Args:
-    info: Transform._Info instance.
-    op: the op to transform in place.
-    detach_outputs: if True, the outputs of op are detached, ready for the user
-      to add more operation.
-  Returns:
-    The transformed op.
+    Args:
+      info: an instance of Transformer._TmpInfo containing various internal
+        information about the transform operation.
+    """
+    self._graph = info.graph
+    self._scope = info.scope
+    self._graph_ = info.graph_
+    self._scope_ = info.scope_
+    self._transformed_ops = info.transformed_ops
+    self._transformed_ts = info.transformed_ts
+
+  def _get_transformed_map(self, top):
+    """Return the correct container depending on the type of `top`."""
+    if isinstance(top, tf_ops.Operation):
+      return self._transformed_ops
+    elif isinstance(top, tf_ops.Tensor):
+      return self._transformed_ts
+    else:
+      raise TypeError(
+          "Expected a tf.Tensor or a tf.Operation, got a {}".format(
+              type(top)))
+
+  def _transformed_elem(self, original_top, missing_fn=None):
+    """Return the transformed op/tensor corresponding to the original one.
+
+    Args:
+      original_top: the original tensor/operation.
+      missing_fn: function handling the case where the counterpart
+        cannot be found. By default, None is returned.
+    Returns:
+      the transformed tensor/operation (or None if no match is found).
+    """
+    transformed_map = self._get_transformed_map(original_top)
+    if isinstance(original_top, string_types):
+      for original, transformed in iteritems(transformed_map):
+        if original.name == original_top:
+          return transformed
+      return None if missing_fn is None else missing_fn(original_top)
+    else:
+      if original_top not in transformed_map:
+        return None if missing_fn is None else missing_fn(original_top)
+      return transformed_map[original_top]
+
+  def _original_elem(self, transformed_top, missing_fn=None):
+    """Return the original op/tensor corresponding to the transformed one.
+
+    Args:
+      transformed_top: the transformed tensor/operation.
+ missing_fn: function handling the case where the counterpart + cannot be found. By default, None is returned. + Returns: + the original tensor/operation (or None if no match is found). + """ + transformed_map = self._get_transformed_map(transformed_top) + if isinstance(transformed_top, string_types): + finder = lambda transformed: transformed.name == transformed_top + else: + finder = lambda transformed: transformed == transformed_top + for original, transformed in iteritems(transformed_map): + if finder(transformed): + return original + return None if missing_fn is None else missing_fn(transformed_top) + + def transformed(self, original, missing_fn=None): + """Return the transformed op/tensor corresponding to the original one. + + Note that the output of this function mimics the hierarchy + of its input argument `original`. + Given an iterable, it returns a list. Given an operation or a tensor, + it will return an operation or a tensor. + + Args: + original: the original tensor/operation. + missing_fn: function handling the case where the counterpart + cannot be found. By default, None is returned. + Returns: + the transformed tensor/operation (or None if no match is found). + """ + transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn) + return util.transform_tree(original, transformed_elem) + + def original(self, transformed, missing_fn=None): + """Return the original op/tensor corresponding to the transformed one. + + Note that the output of this function mimics the hierarchy + of its input argument `transformed`. + Given an iterable, it returns a list. Given an operation or a tensor, + it will return an operation or a tensor. + + Args: + transformed: the transformed tensor/operation. + missing_fn: function handling the case where the counterpart + cannot be found. By default, None is returned. + Returns: + the original tensor/operation (or None if no match is found). + """ + original_elem = partial(self._original_elem, missing_fn=missing_fn) + return util.transform_tree(transformed, original_elem) + + def __str__(self): + res = StringIO() + print("Transform result info:", file=res) + if self._graph == self._graph_: + in_place_str = "" if self._scope_ else " IN-PLACE" + print(" Within graph[{}]{}".format( + id(self._graph), in_place_str), file=res) + else: + print(" graph[{}] => graph[{}]".format( + id(self._graph), id(self._graph_)), file=res) + if self._scope: + print(" Relative to source scope: {}".format(self._scope), file=res) + if self._scope_: + print(" Scope destination: {}".format(self._scope_), file=res) + print("Operations mapping:", file=res) + for op, op_ in iteritems(self._transformed_ops): + print(" {} => {}".format(op.name, op_.name), file=res) + return res.getvalue() + + +class _TmpInfo(object): + """Transformer temporary data. + + An instance of this class holds all the information relevant to a call + to a transformer instance (that is, a call to __call__). An instance + is created for the life-time of the __call__ function and is passed as + argument to the handlers. """ - # recursive call to the inputs: - inputs = [info.transformer._transform_t(t) # pylint: disable=protected-access - for t in op.inputs] - # re-connect to the inputs if they have changed: - if inputs != list(op.inputs): - reroute.reroute_a2b_ts(inputs, op.inputs) - # detach op from its consumer first ? 
- if detach_outputs: - edit.detach_outputs(op) - return op + + def __init__(self, sgv, dst_graph, dst_scope, src_scope): + self.sgv = sgv + self.sgv_inputs_set = frozenset(sgv.inputs) + self.ops = frozenset(sgv.ops) + self.control_outputs = util.ControlOutputs(sgv.graph) + self.graph = sgv.graph + self.scope = src_scope + self.graph_ = dst_graph + self.scope_ = dst_scope + self.transformed_ops = {} + self.transformed_ts = {} + self.collections = dict((key, self.graph.get_collection(key)) + for key in self.graph.get_all_collection_keys()) + self.cyclic_ops = [] + self.transform_original_op_handler = transform_op_if_inside_handler + + def new_name(self, name): + """Compute a destination name from a source name. + + Args: + name: the name to be "transformed". + Returns: + The transformed name. + Raises: + ValueError: if the source scope is used (that is, not an empty string) + and the source name does not belong to the source scope. + """ + scope = self.scope + if not name.startswith(scope): + raise ValueError("{} does not belong to source scope: {}.".format( + name, scope)) + rel_name = name[len(scope):] + name_ = self.scope_ + rel_name + return name_ class Transformer(object): @@ -228,155 +358,6 @@ class Transformer(object): the handlers. """ - class _Info(object): - """Transformer temporary data. - - An instance of this class holds all the information relevant to a call - to a transformer instance (that is, a call to __call__). An instance - is created for the life-time of the __call__ function and is passed as - argument to the handlers. - """ - - def __init__(self, transformer, sgv, dst_graph, dst_scope, src_scope): - self.transformer = transformer - self.sgv = sgv - self.sgv_inputs_set = frozenset(sgv.inputs) - self.ops = frozenset(sgv.ops) - self.control_outputs = util.ControlOutputs(sgv.graph) - self.graph = sgv.graph - self.scope = src_scope - self.graph_ = dst_graph - self.scope_ = dst_scope - self.transformed_ops = {} - self.transformed_ts = {} - self.collections = dict((key, self.graph.get_collection(key)) - for key in self.graph.get_all_collection_keys()) - self.cyclic_ops = [] - - class ResultInfo(object): - """"Contains information about the result of a transform operation.""" - - def __init__(self, info): - """Constructor. - - Args: - info: an instance of Transformer._Info containing various internal - information about the transform operation. - """ - self._graph = info.graph - self._scope = info.scope - self._graph_ = info.graph_ - self._scope_ = info.scope_ - self._transformed_ops = info.transformed_ops - self._transformed_ts = info.transformed_ts - - def _get_transformed_map(self, top): - """Return the correct container depending on the type of `top`.""" - if isinstance(top, tf_ops.Operation): - return self._transformed_ops - elif isinstance(top, tf_ops.Tensor): - return self._transformed_ts - else: - raise TypeError( - "Expected a tf.Tensor or a tf.Operation, got a {}".format( - type(top))) - - def _transformed_elem(self, original_top, missing_fn=None): - """Return the transformed op/tensor corresponding to the original one. - - Args: - original_top: the original tensor/operation. - missing_fn: function handling the case where the counterpart - cannot be found. By default, None is returned. - Returns: - the transformed tensor/operation (or None if no match is found). 
- """ - transformed_map = self._get_transformed_map(original_top) - if isinstance(original_top, string_types): - for original, transformed in iteritems(transformed_map): - if original.name == original_top: - return transformed - return None if missing_fn is None else missing_fn(original_top) - else: - if original_top not in transformed_map: - return None if missing_fn is None else missing_fn(original_top) - return transformed_map[original_top] - - def _original_elem(self, transformed_top, missing_fn=None): - """Return the original op/tensor corresponding to the transformed one. - - Args: - transformed_top: the transformed tensor/operation. - missing_fn: function handling the case where the counterpart - cannot be found. By default, None is returned. - Returns: - the original tensor/operation (or None if no match is found). - """ - transformed_map = self._get_transformed_map(transformed_top) - if isinstance(transformed_top, string_types): - finder = lambda transformed: transformed.name == transformed_top - else: - finder = lambda transformed: transformed == transformed_top - for original, transformed in iteritems(transformed_map): - if finder(transformed): - return original - return None if missing_fn is None else missing_fn(transformed_top) - - def transformed(self, original, missing_fn=None): - """Return the transformed op/tensor corresponding to the original one. - - Note that the output of this function mimics the hierarchy - of its input argument `original`. - Given an iterable, it returns a list. Given an operation or a tensor, - it will return an operation or a tensor. - - Args: - original: the original tensor/operation. - missing_fn: function handling the case where the counterpart - cannot be found. By default, None is returned. - Returns: - the transformed tensor/operation (or None if no match is found). - """ - transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn) - return util.transform_tree(original, transformed_elem) - - def original(self, transformed, missing_fn=None): - """Return the original op/tensor corresponding to the transformed one. - - Note that the output of this function mimics the hierarchy - of its input argument `transformed`. - Given an iterable, it returns a list. Given an operation or a tensor, - it will return an operation or a tensor. - - Args: - transformed: the transformed tensor/operation. - missing_fn: function handling the case where the counterpart - cannot be found. By default, None is returned. - Returns: - the original tensor/operation (or None if no match is found). - """ - original_elem = partial(self._original_elem, missing_fn=missing_fn) - return util.transform_tree(transformed, original_elem) - - def __str__(self): - res = StringIO() - print("Transform result info:", file=res) - if self._graph == self._graph_: - in_place_str = "" if self._scope_ else " IN-PLACE" - print(" Within graph[{}]{}".format( - id(self._graph), in_place_str), file=res) - else: - print(" graph[{}] => graph[{}]".format( - id(self._graph), id(self._graph_)), file=res) - if self._scope: - print(" Relative to source scope: {}".format(self._scope), file=res) - if self._scope_: - print(" Scope destination: {}".format(self._scope_), file=res) - print("Operations mapping:", file=res) - for op, op_ in iteritems(self._transformed_ops): - print(" {} => {}".format(op.name, op_.name), file=res) - return res.getvalue() - def __init__(self): """Transformer constructor. 
@@ -407,9 +388,6 @@
     self.transform_external_hidden_input_handler = keep_t_if_possible_handler
     self.transform_original_op_handler = transform_op_if_inside_handler
 
-    # temporary per-call variable
-    self._info = None
-
   def __call__(self,
                sgv,
                dst_graph,
@@ -432,7 +410,7 @@
     Returns:
       A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
-       `info` is an instance of Transformer.ResultInfo containing
+       `info` is an instance of TransformerInfo containing
        information about the transform, including mapping between
        original and transformed tensors and operations.
     Raises:
@@ -450,49 +428,68 @@
       dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1]))
 
     # Create temporary info used during this transform call
-    self._info = Transformer._Info(self, sgv, dst_graph, dst_scope, src_scope)
+    info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope)
+    info.transform_original_op_handler = self.transform_original_op_handler
 
-    # Transform the graph starting from the output tensors.
-    for output_t in self._info.sgv.outputs:
-      self._transform_t(output_t)
+    self._copy_ops(info)
+    self._connect_ops(info)
 
-    # Some ops might have been missed by the previous walk, namely, the roots
-    # without any outputs. So the walk is now finalized from those roots.
-    remaining_ops = [op for op in self._info.sgv.ops
-                     if op not in self._info.transformed_ops]
-    remaining_roots = [op for op in remaining_ops if not op.outputs]
-    for op in remaining_roots:
-      self._transform_op(op)
-
-    # Finalize cyclic ops:
-    for op in self._info.cyclic_ops:
-      logging.debug("Finalizing cyclic op: %s", op.name)
-      op_ = self._info.transformed_ops[op]
-      inputs_ = [self._info.transformed_ts[t] for t in op.inputs]
-      if None in inputs_:
-        raise ValueError("Could not find all the inputs of cyclic op: {}"
-                         .format(op_.name))
-      for input_id, t_ in enumerate(inputs_):
-        op_._update_input(input_id, t_)  # pylint: disable=protected-access
-
-    sgv_ = self._transform_sgv(sgv)
-
-    res_info = Transformer.ResultInfo(self._info)
-    self._info = None
+    # Compute information about the transformation
+    res_info = TransformerInfo(info)
+    sgv_ = self._transform_sgv(info, sgv)
 
     return sgv_, res_info
 
-  def _transform_sgv(self, sgv):
+  def _copy_ops(self, info):
+    """Copy ops without connecting them."""
+    for op in info.sgv.ops:
+      logging.info("Copying op: %s", op.name)
+      # TODO(fkp): return a subgraph?
+      op_, op_outputs_ = self.transform_op_handler(info, op)
+      if op is op_:
+        raise ValueError("In-place transformation not allowed.")
+
+      # Process op.
+      info.transformed_ops[op] = op_
+      self.assign_collections_handler(info, op, op_)
+
+      # Process output tensors.
+      for op_output, op_output_ in zip(op.outputs, op_outputs_):
+        info.transformed_ts[op_output] = op_output_
+        self.assign_collections_handler(info, op_output, op_output_)
+
+  def _connect_ops(self, info):
+    """Connect the previously copied ops."""
+    for op in info.sgv.ops:
+      logging.info("Finalizing op: %s", op.name)
+      op_ = info.transformed_ops[op]
+
+      # pylint: disable=protected-access
+      if op_.inputs:
+        raise ValueError("The newly transformed op should not have "
+                         "any inputs yet: {}".format(op_.name))
+      inputs_ = [self._transformed_t(info, t) for t in op.inputs]
+      for t in inputs_:
+        op_._add_input(t)
+
+      # Finalize control inputs:
+      control_inputs_ = [self.transform_control_input_handler(info, ci)
+                         for ci in op.control_inputs]
+      control_inputs_ = [ci for ci in control_inputs_ if ci is not None]
+      reroute.add_control_inputs(op_, control_inputs_)
+
+  def _transform_sgv(self, info, sgv):
     """Transform a subgraph view.
 
     For convenience, a transform operation returns a subgraph view of the
     transformed graph.
 
     Args:
+      info: Temporary information for this transform call.
      sgv: the subgraph to be transformed.
     Returns:
       The transformed subgraph.
     """
-    ops_ = [op_ for _, op_ in iteritems(self._info.transformed_ops)]
+    ops_ = [op_ for _, op_ in iteritems(info.transformed_ops)]
     sgv_ = subgraph.SubGraphView(ops_)
     sgv_inputs_ = sgv_.inputs
     sgv_outputs_ = sgv_.outputs
@@ -500,9 +497,9 @@
     # re-order inputs
     input_map_ = []
     for input_t in sgv.inputs:
-      if input_t not in self._info.transformed_ts:
+      if input_t not in info.transformed_ts:
         continue
-      input_t_ = self._info.transformed_ts[input_t]
+      input_t_ = info.transformed_ts[input_t]
      if input_t_ not in sgv_inputs_:
        continue
      input_t_index_ = sgv_.input_index(input_t_)
@@ -511,9 +508,9 @@
     # re-order outputs
     output_map_ = []
     for output_t in sgv.outputs:
-      if output_t not in self._info.transformed_ts:
+      if output_t not in info.transformed_ts:
         continue
-      output_t_ = self._info.transformed_ts[output_t]
+      output_t_ = info.transformed_ts[output_t]
      if output_t_ not in sgv_outputs_:
        continue
      output_t_index_ = sgv_.output_index(output_t_)
@@ -521,94 +518,19 @@
 
     return sgv_.remap(input_map_, output_map_)
 
-  def _transform_t(self, t):
-    """Transform a tf.Tensor.
-
-    Args:
-      t: the tensor to be transformed.
-    Returns:
-      The transformed tensor.
-    """
-    logging.debug("Transforming tensor: %s", t.name)
-    if t in self._info.transformed_ts:
-      return self._info.transformed_ts[t]
-
-    # Mark as None to detect cycle.
-    self._info.transformed_ts[t] = None
-
-    op, op_index = t.op, t.value_index
-
-    # If op is not in the subgraph:
-    if op not in self._info.ops:
-      # t_ is an input of the subgraph
-      if t in self._info.sgv_inputs_set:
-        t_ = self.transform_external_input_handler(self._info, t)
-      # t_ is a hidden input of the subgraph
+  def _transformed_t(self, info, t):
+    """Return the transformed tensor of `t`."""
+    if t not in info.transformed_ts:
+      # If op is not in the subgraph.
+      if t in info.sgv_inputs_set:
+        # t is an input of the subgraph.
+        return self.transform_external_input_handler(info, t)
       else:
-        t_ = self.transform_external_hidden_input_handler(self._info, t)
-      # If op is in the subgraph, just transform it:
+        # t is a hidden input of the subgraph.
+        return self.transform_external_hidden_input_handler(info, t)
     else:
-      op_ = self._transform_op(op)
-      t_ = op_.outputs[op_index]
-
-    # assign to collection
-    if t is not t_:
-      self.assign_collections_handler(self._info, t, t_)
-
-    self._info.transformed_ts[t] = t_
-    return t_
-
-  def _transform_op(self, op):
-    """Transform a tf.Operation.
-
-    Args:
-      op: the operation to be transformed.
-    Returns:
-      The transformed operation.
-    """
-    if op in self._info.transformed_ops:
-      return self._info.transformed_ops[op]
-
-    op_ = self.transform_op_handler(self._info, op)
-
-    # Add to all the active control dependencies
-    # pylint: disable=protected-access
-    self._info.graph_._record_op_seen_by_control_dependencies(op_)
-
-    # All to all the active devices
-    for device_function in reversed(self._info.graph_._device_function_stack):
-      if device_function is None:
-        break
-      op_._set_device(device_function(op_))
-    # pylint: enable=protected-access
-
-    # TODO(fkp): Establish clear policy about what context managers are allowed.
-
-    # assign to collection
-    if op is not op_:
-      self.assign_collections_handler(self._info, op, op_)
-
-    self._info.transformed_ops[op] = op_
-    return op_
-
-  def new_name(self, name):
-    """Compute a destination name from a source name.
-
-    Args:
-      name: the name to be "transformed".
-    Returns:
-      The transformed name.
-    Raises:
-      ValueError: if the source scope is used (that is, not an empty string)
-        and the source name does not belong to the source scope.
-    """
-    scope = self._info.scope
-    if not name.startswith(scope):
-      raise ValueError("{} does not belong to source scope: {}.".format(name,
-                                                                        scope))
-    rel_name = name[len(scope):]
-    name_ = self._info.scope_ + rel_name
-    return name_
+      # If op is in the subgraph, just return its transformed counterpart.
+      return info.transformed_ts[t]
 
 def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
@@ -627,7 +549,7 @@ def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
   Returns:
     A tuple `(sgv, info)` where:
      `sgv` is the transformed subgraph view;
-      `info` is an instance of Transformer.ResultInfo containing
+      `info` is an instance of TransformerInfo containing
      information about the transform, including mapping between
      original and transformed tensors and operations.
   Raises:
@@ -669,7 +591,7 @@ def copy_with_input_replacements(sgv, replacement_ts,
   Returns:
     A tuple `(sgv, info)` where:
      `sgv` is the transformed subgraph view;
-      `info` is an instance of Transformer.ResultInfo containing
+      `info` is an instance of TransformerInfo containing
      information about the transform, including mapping between
      original and transformed tensors and operations.
Raises: From 810a4760882646017b1a418e85d13ac62114ab55 Mon Sep 17 00:00:00 2001 From: David Soergel Date: Wed, 18 Jan 2017 22:33:05 -0800 Subject: [PATCH 6/6] Expose export-related utility functions under tf.contrib.learn Change: 144924507 --- tensorflow/contrib/learn/python/learn/__init__.py | 1 + tensorflow/contrib/learn/python/learn/utils/__init__.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/tensorflow/contrib/learn/python/learn/__init__.py b/tensorflow/contrib/learn/python/learn/__init__.py index 3c6d05dc0dc..269dd2a6613 100644 --- a/tensorflow/contrib/learn/python/learn/__init__.py +++ b/tensorflow/contrib/learn/python/learn/__init__.py @@ -44,4 +44,5 @@ from tensorflow.contrib.learn.python.learn.learn_io import * from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec from tensorflow.contrib.learn.python.learn.monitors import NanLossDuringTrainingError from tensorflow.contrib.learn.python.learn.trainable import Trainable +from tensorflow.contrib.learn.python.learn.utils import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/learn/python/learn/utils/__init__.py b/tensorflow/contrib/learn/python/learn/utils/__init__.py index f313699c148..74236da9790 100644 --- a/tensorflow/contrib/learn/python/learn/utils/__init__.py +++ b/tensorflow/contrib/learn/python/learn/utils/__init__.py @@ -20,3 +20,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.learn.python.learn.utils.export import export_estimator +from tensorflow.contrib.learn.python.learn.utils.input_fn_utils import build_default_serving_input_fn +from tensorflow.contrib.learn.python.learn.utils.input_fn_utils import build_parsing_serving_input_fn +from tensorflow.contrib.learn.python.learn.utils.saved_model_export_utils import make_export_strategy +
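A brief usage sketch of the utilities this last patch exposes. The feature spec below is an illustrative assumption, and only the two entry points imported above are relied on:

import tensorflow as tf
from tensorflow.contrib import learn

# Parsing spec for the serialized tf.Example protos the served model
# will receive; any dict of parsing specs works here.
feature_spec = {"x": tf.FixedLenFeature(shape=[1], dtype=tf.float32)}

# Build an input_fn that parses tf.Examples at serving time.
serving_input_fn = learn.build_parsing_serving_input_fn(feature_spec)

# Wrap it in an ExportStrategy, e.g. for an Experiment's
# export_strategies argument.
export_strategy = learn.make_export_strategy(serving_input_fn)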