Use TF constants for the break/continue control variables, to ensure control dependencies get created correctly. This renders break cond continue incompatible with Python inputs, but that's an extremely very unlikely use case.

PiperOrigin-RevId: 206738877
This commit is contained in:
Dan Moldovan 2018-07-31 04:28:28 -07:00 committed by TensorFlower Gardener
parent 3a1df26a25
commit 3bec2640dc
4 changed files with 47 additions and 33 deletions

View File

@ -42,7 +42,7 @@ class BreakTransformer(converter.Base):
var_name = self.state[_Break].control_var_name
# TODO(mdan): This will fail when expanded inside a top-level else block.
template = """
var_name = True
var_name = tf.constant(True)
continue
"""
return templates.replace(template, var_name=var_name)
@ -85,7 +85,7 @@ class BreakTransformer(converter.Base):
guarded_orelse = self._guard_if_present(node.orelse, break_var)
template = """
var_name = False
var_name = tf.constant(False)
while test and not var_name:
body
else:
@ -122,7 +122,7 @@ class BreakTransformer(converter.Base):
# the control variable is marked as used.
# TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
template = """
var_name = False
var_name = tf.constant(False)
for target in iter_:
(var_name,)
body

View File

@ -20,13 +20,16 @@ from __future__ import print_function
from tensorflow.contrib.autograph.converters import break_statements
from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.eager import context as tfe_ctx
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
class BreakCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
with self.converted(test_fn, break_statements, {}) as result:
with self.converted(test_fn, break_statements, {},
constant_op.constant) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_while_loop(self):
@ -40,9 +43,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 4)
with tfe_ctx.eager_mode():
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 4)
def test_for_loop(self):
@ -55,7 +59,8 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
with self.converted(test_fn, break_statements, {}) as result:
with self.converted(test_fn, break_statements, {},
constant_op.constant) as result:
# The break is incompletely canonicalized. The loop will not interrupt,
# but the section following the break will be skipped.
self.assertEqual([3], result.test_fn([5, 4]))
@ -77,9 +82,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u, w
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 11)
with tfe_ctx.eager_mode():
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 11)
def test_nested_loops(self):
@ -99,10 +105,11 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 5)
with tfe_ctx.eager_mode():
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 5)
def test_loop_orelse(self):
@ -120,9 +127,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, 3)
with tfe_ctx.eager_mode():
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, 3)
if __name__ == '__main__':

View File

@ -37,7 +37,7 @@ class ContinueCanonicalizationTransformer(converter.Base):
def visit_Continue(self, node):
self.set_local(CONTINUE_USED, True)
template = """
var_name = True
var_name = tf.constant(True)
"""
return templates.replace(
template, var_name=self.get_local(CONTROL_VAR_NAME))
@ -92,7 +92,7 @@ class ContinueCanonicalizationTransformer(converter.Base):
if self.get_local(CONTINUE_USED, False):
template = """
var_name = False
var_name = tf.constant(False)
"""
control_var_init = templates.replace(template, var_name=continue_var)
nodes = control_var_init + nodes

View File

@ -20,13 +20,16 @@ from __future__ import print_function
from tensorflow.contrib.autograph.converters import continue_statements
from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.eager import context as tfe_ctx
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
class ContinueCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
with self.converted(test_fn, continue_statements, {}) as result:
with self.converted(test_fn, continue_statements, {},
constant_op.constant) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_basic(self):
@ -40,10 +43,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 4)
with tfe_ctx.eager_mode():
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 4)
def test_for_loop(self):
@ -56,10 +60,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, [])
self.assertTransformedEquivalent(test_fn, [1])
self.assertTransformedEquivalent(test_fn, [2])
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
with tfe_ctx.eager_mode():
self.assertTransformedEquivalent(test_fn, [])
self.assertTransformedEquivalent(test_fn, [1])
self.assertTransformedEquivalent(test_fn, [2])
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
def test_nested(self):
@ -78,10 +83,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u, w
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 4)
with tfe_ctx.eager_mode():
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 4)
if __name__ == '__main__':