Avoid hardcoding the control variable to Tensor in break canonicalization.

PiperOrigin-RevId: 219185092
This commit is contained in:
Dan Moldovan 2018-10-29 13:41:51 -07:00 committed by TensorFlower Gardener
parent 2967a9a9e7
commit 4f09245c1f
2 changed files with 16 additions and 21 deletions

View File

@ -42,7 +42,7 @@ class BreakTransformer(converter.Base):
var_name = self.state[_Break].control_var_name
# TODO(mdan): This will fail when expanded inside a top-level else block.
template = """
var_name = tf.constant(True)
var_name = True
continue
"""
return templates.replace(template, var_name=var_name)
@ -85,7 +85,7 @@ class BreakTransformer(converter.Base):
guarded_orelse = self._guard_if_present(node.orelse, break_var)
template = """
var_name = tf.constant(False)
var_name = False
while test and not var_name:
body
else:
@ -122,7 +122,7 @@ class BreakTransformer(converter.Base):
# the control variable is marked as used.
# TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
template = """
var_name = tf.constant(False)
var_name = False
for target in iter_:
(var_name,)
body

View File

@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.eager import context as tfe_ctx
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
@ -43,10 +42,9 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
with tfe_ctx.eager_mode():
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 4)
def test_for_loop(self):
@ -82,10 +80,9 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u, w
with tfe_ctx.eager_mode():
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 11)
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 11)
def test_nested_loops(self):
@ -105,11 +102,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u
with tfe_ctx.eager_mode():
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 5)
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 5)
def test_loop_orelse(self):
@ -127,10 +123,9 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u
with tfe_ctx.eager_mode():
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, 3)
if __name__ == '__main__':