diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 2a60750bdae..180779670d9 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -42,7 +42,7 @@ class BreakTransformer(converter.Base): var_name = self.state[_Break].control_var_name # TODO(mdan): This will fail when expanded inside a top-level else block. template = """ - var_name = True + var_name = tf.constant(True) continue """ return templates.replace(template, var_name=var_name) @@ -85,7 +85,7 @@ class BreakTransformer(converter.Base): guarded_orelse = self._guard_if_present(node.orelse, break_var) template = """ - var_name = False + var_name = tf.constant(False) while test and not var_name: body else: @@ -122,7 +122,7 @@ class BreakTransformer(converter.Base): # the control variable is marked as used. # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) template = """ - var_name = False + var_name = tf.constant(False) for target in iter_: (var_name,) body diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py index c26ca2946ce..fcae7d68c0f 100644 --- a/tensorflow/contrib/autograph/converters/break_statements_test.py +++ b/tensorflow/contrib/autograph/converters/break_statements_test.py @@ -20,13 +20,16 @@ from __future__ import print_function from tensorflow.contrib.autograph.converters import break_statements from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.eager import context as tfe_ctx +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class BreakCanonicalizationTest(converter_testing.TestCase): def assertTransformedEquivalent(self, test_fn, *inputs): - with self.converted(test_fn, break_statements, {}) as result: + with self.converted(test_fn, break_statements, {}, + constant_op.constant) as result: self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) def test_while_loop(self): @@ -40,9 +43,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 4) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 4) def test_for_loop(self): @@ -55,7 +59,8 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - with self.converted(test_fn, break_statements, {}) as result: + with self.converted(test_fn, break_statements, {}, + constant_op.constant) as result: # The break is incompletely canonicalized. The loop will not interrupt, # but the section following the break will be skipped. self.assertEqual([3], result.test_fn([5, 4])) @@ -77,9 +82,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u, w - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 11) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 11) def test_nested_loops(self): @@ -99,10 +105,11 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 5) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 5) def test_loop_orelse(self): @@ -120,9 +127,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, 3) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, 3) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py index 958bde0a587..0476e97c15e 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -37,7 +37,7 @@ class ContinueCanonicalizationTransformer(converter.Base): def visit_Continue(self, node): self.set_local(CONTINUE_USED, True) template = """ - var_name = True + var_name = tf.constant(True) """ return templates.replace( template, var_name=self.get_local(CONTROL_VAR_NAME)) @@ -92,7 +92,7 @@ class ContinueCanonicalizationTransformer(converter.Base): if self.get_local(CONTINUE_USED, False): template = """ - var_name = False + var_name = tf.constant(False) """ control_var_init = templates.replace(template, var_name=continue_var) nodes = control_var_init + nodes diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py index 3a7c7d1486d..37c15211b4f 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements_test.py +++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py @@ -20,13 +20,16 @@ from __future__ import print_function from tensorflow.contrib.autograph.converters import continue_statements from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.eager import context as tfe_ctx +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class ContinueCanonicalizationTest(converter_testing.TestCase): def assertTransformedEquivalent(self, test_fn, *inputs): - with self.converted(test_fn, continue_statements, {}) as result: + with self.converted(test_fn, continue_statements, {}, + constant_op.constant) as result: self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) def test_basic(self): @@ -40,10 +43,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 4) def test_for_loop(self): @@ -56,10 +60,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, []) - self.assertTransformedEquivalent(test_fn, [1]) - self.assertTransformedEquivalent(test_fn, [2]) - self.assertTransformedEquivalent(test_fn, [1, 2, 3]) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [1]) + self.assertTransformedEquivalent(test_fn, [2]) + self.assertTransformedEquivalent(test_fn, [1, 2, 3]) def test_nested(self): @@ -78,10 +83,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u, w - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 4) if __name__ == '__main__':