Avoid hardcoding the control variable to Tensor in break canonicalization.
PiperOrigin-RevId: 219185092
This commit is contained in:
parent
2967a9a9e7
commit
4f09245c1f
@ -42,7 +42,7 @@ class BreakTransformer(converter.Base):
|
||||
var_name = self.state[_Break].control_var_name
|
||||
# TODO(mdan): This will fail when expanded inside a top-level else block.
|
||||
template = """
|
||||
var_name = tf.constant(True)
|
||||
var_name = True
|
||||
continue
|
||||
"""
|
||||
return templates.replace(template, var_name=var_name)
|
||||
@ -85,7 +85,7 @@ class BreakTransformer(converter.Base):
|
||||
guarded_orelse = self._guard_if_present(node.orelse, break_var)
|
||||
|
||||
template = """
|
||||
var_name = tf.constant(False)
|
||||
var_name = False
|
||||
while test and not var_name:
|
||||
body
|
||||
else:
|
||||
@ -122,7 +122,7 @@ class BreakTransformer(converter.Base):
|
||||
# the control variable is marked as used.
|
||||
# TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
|
||||
template = """
|
||||
var_name = tf.constant(False)
|
||||
var_name = False
|
||||
for target in iter_:
|
||||
(var_name,)
|
||||
body
|
||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.converters import break_statements
|
||||
from tensorflow.python.autograph.core import converter_testing
|
||||
from tensorflow.python.eager import context as tfe_ctx
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -43,10 +42,9 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
v.append(x)
|
||||
return v
|
||||
|
||||
with tfe_ctx.eager_mode():
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 1)
|
||||
self.assertTransformedEquivalent(test_fn, 4)
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 1)
|
||||
self.assertTransformedEquivalent(test_fn, 4)
|
||||
|
||||
def test_for_loop(self):
|
||||
|
||||
@ -82,10 +80,9 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
v.append(x)
|
||||
return v, u, w
|
||||
|
||||
with tfe_ctx.eager_mode():
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 11)
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 11)
|
||||
|
||||
def test_nested_loops(self):
|
||||
|
||||
@ -105,11 +102,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
v.append(x)
|
||||
return v, u
|
||||
|
||||
with tfe_ctx.eager_mode():
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 2)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 5)
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 2)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 5)
|
||||
|
||||
def test_loop_orelse(self):
|
||||
|
||||
@ -127,10 +123,9 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
v.append(x)
|
||||
return v, u
|
||||
|
||||
with tfe_ctx.eager_mode():
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 2)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 2)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user