Use TF constants for the break/continue control variables, to ensure control dependencies get created correctly. This renders break cond continue incompatible with Python inputs, but that's an extremely very unlikely use case.
PiperOrigin-RevId: 206738877
This commit is contained in:
parent
3a1df26a25
commit
3bec2640dc
@ -42,7 +42,7 @@ class BreakTransformer(converter.Base):
|
||||
var_name = self.state[_Break].control_var_name
|
||||
# TODO(mdan): This will fail when expanded inside a top-level else block.
|
||||
template = """
|
||||
var_name = True
|
||||
var_name = tf.constant(True)
|
||||
continue
|
||||
"""
|
||||
return templates.replace(template, var_name=var_name)
|
||||
@ -85,7 +85,7 @@ class BreakTransformer(converter.Base):
|
||||
guarded_orelse = self._guard_if_present(node.orelse, break_var)
|
||||
|
||||
template = """
|
||||
var_name = False
|
||||
var_name = tf.constant(False)
|
||||
while test and not var_name:
|
||||
body
|
||||
else:
|
||||
@ -122,7 +122,7 @@ class BreakTransformer(converter.Base):
|
||||
# the control variable is marked as used.
|
||||
# TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
|
||||
template = """
|
||||
var_name = False
|
||||
var_name = tf.constant(False)
|
||||
for target in iter_:
|
||||
(var_name,)
|
||||
body
|
||||
|
@ -20,13 +20,16 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.autograph.converters import break_statements
|
||||
from tensorflow.contrib.autograph.core import converter_testing
|
||||
from tensorflow.python.eager import context as tfe_ctx
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
|
||||
def assertTransformedEquivalent(self, test_fn, *inputs):
|
||||
with self.converted(test_fn, break_statements, {}) as result:
|
||||
with self.converted(test_fn, break_statements, {},
|
||||
constant_op.constant) as result:
|
||||
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
|
||||
|
||||
def test_while_loop(self):
|
||||
@ -40,9 +43,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
v.append(x)
|
||||
return v
|
||||
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 1)
|
||||
self.assertTransformedEquivalent(test_fn, 4)
|
||||
with tfe_ctx.eager_mode():
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 1)
|
||||
self.assertTransformedEquivalent(test_fn, 4)
|
||||
|
||||
def test_for_loop(self):
|
||||
|
||||
@ -55,7 +59,8 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
v.append(x)
|
||||
return v
|
||||
|
||||
with self.converted(test_fn, break_statements, {}) as result:
|
||||
with self.converted(test_fn, break_statements, {},
|
||||
constant_op.constant) as result:
|
||||
# The break is incompletely canonicalized. The loop will not interrupt,
|
||||
# but the section following the break will be skipped.
|
||||
self.assertEqual([3], result.test_fn([5, 4]))
|
||||
@ -77,9 +82,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
v.append(x)
|
||||
return v, u, w
|
||||
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 11)
|
||||
with tfe_ctx.eager_mode():
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 11)
|
||||
|
||||
def test_nested_loops(self):
|
||||
|
||||
@ -99,10 +105,11 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
v.append(x)
|
||||
return v, u
|
||||
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 2)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 5)
|
||||
with tfe_ctx.eager_mode():
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 2)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 5)
|
||||
|
||||
def test_loop_orelse(self):
|
||||
|
||||
@ -120,9 +127,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
v.append(x)
|
||||
return v, u
|
||||
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 2)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
with tfe_ctx.eager_mode():
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 2)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -37,7 +37,7 @@ class ContinueCanonicalizationTransformer(converter.Base):
|
||||
def visit_Continue(self, node):
|
||||
self.set_local(CONTINUE_USED, True)
|
||||
template = """
|
||||
var_name = True
|
||||
var_name = tf.constant(True)
|
||||
"""
|
||||
return templates.replace(
|
||||
template, var_name=self.get_local(CONTROL_VAR_NAME))
|
||||
@ -92,7 +92,7 @@ class ContinueCanonicalizationTransformer(converter.Base):
|
||||
|
||||
if self.get_local(CONTINUE_USED, False):
|
||||
template = """
|
||||
var_name = False
|
||||
var_name = tf.constant(False)
|
||||
"""
|
||||
control_var_init = templates.replace(template, var_name=continue_var)
|
||||
nodes = control_var_init + nodes
|
||||
|
@ -20,13 +20,16 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.autograph.converters import continue_statements
|
||||
from tensorflow.contrib.autograph.core import converter_testing
|
||||
from tensorflow.python.eager import context as tfe_ctx
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ContinueCanonicalizationTest(converter_testing.TestCase):
|
||||
|
||||
def assertTransformedEquivalent(self, test_fn, *inputs):
|
||||
with self.converted(test_fn, continue_statements, {}) as result:
|
||||
with self.converted(test_fn, continue_statements, {},
|
||||
constant_op.constant) as result:
|
||||
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
|
||||
|
||||
def test_basic(self):
|
||||
@ -40,10 +43,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
||||
v.append(x)
|
||||
return v
|
||||
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 1)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 4)
|
||||
with tfe_ctx.eager_mode():
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 1)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 4)
|
||||
|
||||
def test_for_loop(self):
|
||||
|
||||
@ -56,10 +60,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
||||
v.append(x)
|
||||
return v
|
||||
|
||||
self.assertTransformedEquivalent(test_fn, [])
|
||||
self.assertTransformedEquivalent(test_fn, [1])
|
||||
self.assertTransformedEquivalent(test_fn, [2])
|
||||
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
|
||||
with tfe_ctx.eager_mode():
|
||||
self.assertTransformedEquivalent(test_fn, [])
|
||||
self.assertTransformedEquivalent(test_fn, [1])
|
||||
self.assertTransformedEquivalent(test_fn, [2])
|
||||
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
|
||||
|
||||
def test_nested(self):
|
||||
|
||||
@ -78,10 +83,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
||||
v.append(x)
|
||||
return v, u, w
|
||||
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 1)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 4)
|
||||
with tfe_ctx.eager_mode():
|
||||
self.assertTransformedEquivalent(test_fn, 0)
|
||||
self.assertTransformedEquivalent(test_fn, 1)
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user