From 0aa35b554d16df0aaf0d513ae37b92d6f4b48e42 Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Mon, 23 Mar 2020 09:52:20 +0800 Subject: [PATCH 1/5] fix loop else in autograph and add test --- .../autograph/converters/break_statements.py | 30 ++++- .../converters/loop_integration_test.py | 124 ++++++++++++++++++ 2 files changed, 149 insertions(+), 5 deletions(-) create mode 100644 tensorflow/python/autograph/converters/loop_integration_test.py diff --git a/tensorflow/python/autograph/converters/break_statements.py b/tensorflow/python/autograph/converters/break_statements.py index 718c5bd3ca5..8be0e10c783 100644 --- a/tensorflow/python/autograph/converters/break_statements.py +++ b/tensorflow/python/autograph/converters/break_statements.py @@ -89,8 +89,7 @@ class BreakTransformer(converter.Base): var_name = False while ag__.and_(lambda: test, lambda: ag__.not_(var_name)): body - else: - orelse + orelse """ node = templates.replace( template, @@ -101,6 +100,17 @@ class BreakTransformer(converter.Base): new_while_node = node[1] anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES) + else: + template = """ + while test: + body + orelse + """ + node = templates.replace( + template, + test=node.test, + body=node.body, + orelse=node.orelse) return node @@ -131,8 +141,7 @@ class BreakTransformer(converter.Base): for target in iter_: (var_name,) body - else: - orelse + orelse """ node = templates.replace( template, @@ -145,7 +154,18 @@ class BreakTransformer(converter.Base): new_for_node = node[1] anno.setanno(new_for_node, 'extra_test', extra_test) anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES) - + else: + template = """ + for target in iter_: + body + orelse + """ + node = templates.replace( + template, + iter_=node.iter, + target=node.target, + body=node.body, + orelse=node.orelse) return node diff --git a/tensorflow/python/autograph/converters/loop_integration_test.py b/tensorflow/python/autograph/converters/loop_integration_test.py new file mode 100644 index 00000000000..977263e2826 --- /dev/null +++ b/tensorflow/python/autograph/converters/loop_integration_test.py @@ -0,0 +1,124 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Integration Tests for loop.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.autograph.converters import break_statements +from tensorflow.python.autograph.converters import continue_statements +from tensorflow.python.autograph.converters import control_flow +from tensorflow.python.autograph.core import converter_testing +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test + + +class LoopIntegrationTest(converter_testing.TestCase): + + def assertTransformedEquivalent(self, test_fn, *inputs): + with self.converted(test_fn, [break_statements, + continue_statements, + control_flow], + {}, (constant_op.constant,)) as result: + self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + + def test_while_loop(self): + + def test_fn(x): + v = [] + while x > 0: + x -= 1 + if x % 2 == 0: + break + v.append(x) + return v + + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 4) + + def test_for_loop(self): + + def test_fn(a): + v = [] + for x in a: + x -= 1 + if x % 2 == 0: + break + v.append(x) + return v + + with self.converted(test_fn, break_statements, {}, + (constant_op.constant,)) as result: + # The break is incompletely canonicalized. The loop will not interrupt, + # but the section following the break will be skipped. + self.assertEqual([3], result.test_fn([5, 4])) + + def test_while_loop_with_else(self): + def test_fn(x): + while x > 2: + x /= 2 + else: + x += 1 + return x + + self.assertTransformedEquivalent(test_fn, 4) + self.assertTransformedEquivalent(test_fn, 2) + + def test_while_loop_with_else_and_break(self): + def test_fn(cond1): + x = 8 + while x > 2: + x /= 2 + if cond1: + break + else: + x += 1 + return x + + self.assertTransformedEquivalent(test_fn, True) + self.assertTransformedEquivalent(test_fn, False) + + def test_for_loop_with_else(self): + def test_fn(l): + res = 0 + for x in l: + res += x + else: + res += 1 + return res + + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [1, 2]) + + def test_for_loop_with_else_and_break(self): + def test_fn(flag): + l = [1, 2, 3] + res = 0 + for x in l: + res += x + if flag: + break + else: + res += 1 + return res + + self.assertTransformedEquivalent(test_fn, True) + self.assertTransformedEquivalent(test_fn, False) + + +if __name__ == '__main__': + test.main() From f2643a298573e30b96c4405e28494f123b9440fc Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Mon, 23 Mar 2020 10:41:06 +0800 Subject: [PATCH 2/5] correct copyright year --- tensorflow/python/autograph/converters/loop_integration_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/autograph/converters/loop_integration_test.py b/tensorflow/python/autograph/converters/loop_integration_test.py index 977263e2826..1798132449a 100644 --- a/tensorflow/python/autograph/converters/loop_integration_test.py +++ b/tensorflow/python/autograph/converters/loop_integration_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 9179a9df3ab3318a96a372a28e0523aa0b28174b Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Mon, 23 Mar 2020 22:01:59 +0800 Subject: [PATCH 3/5] add anno and delete unnecessary tests --- .../autograph/converters/break_statements.py | 113 ++++++++++-------- .../converters/loop_integration_test.py | 32 ----- 2 files changed, 62 insertions(+), 83 deletions(-) diff --git a/tensorflow/python/autograph/converters/break_statements.py b/tensorflow/python/autograph/converters/break_statements.py index 8be0e10c783..9474dbfde89 100644 --- a/tensorflow/python/autograph/converters/break_statements.py +++ b/tensorflow/python/autograph/converters/break_statements.py @@ -80,27 +80,7 @@ class BreakTransformer(converter.Base): # A break in the else clause applies to the containing scope. node.orelse = self.visit_block(node.orelse) - if break_used: - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). - guarded_orelse = self._guard_if_present(node.orelse, break_var) - - template = """ - var_name = False - while ag__.and_(lambda: test, lambda: ag__.not_(var_name)): - body - orelse - """ - node = templates.replace( - template, - var_name=break_var, - test=node.test, - body=node.body, - orelse=guarded_orelse) - - new_while_node = node[1] - anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES) - else: + if not break_used: template = """ while test: body @@ -112,6 +92,31 @@ class BreakTransformer(converter.Base): body=node.body, orelse=node.orelse) + new_while_node = node[0] + anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES) + + return node + + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + + template = """ + var_name = False + while ag__.and_(lambda: test, lambda: ag__.not_(var_name)): + body + orelse + """ + node = templates.replace( + template, + var_name=break_var, + test=node.test, + body=node.body, + orelse=guarded_orelse) + + new_while_node = node[1] + anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES) + return node def visit_For(self, node): @@ -125,36 +130,7 @@ class BreakTransformer(converter.Base): # A break in the else clause applies to the containing scope. node.orelse = self.visit_block(node.orelse) - if break_used: - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). - guarded_orelse = self._guard_if_present(node.orelse, break_var) - extra_test = templates.replace_as_expression( - 'ag__.not_(var_name)', var_name=break_var) - - # The extra test is hidden in the AST, which will confuse the static - # analysis. To mitigate that, we insert a no-op statement that ensures - # the control variable is marked as used. - # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) - template = """ - var_name = False - for target in iter_: - (var_name,) - body - orelse - """ - node = templates.replace( - template, - var_name=break_var, - iter_=node.iter, - target=node.target, - body=node.body, - orelse=guarded_orelse) - - new_for_node = node[1] - anno.setanno(new_for_node, 'extra_test', extra_test) - anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES) - else: + if not break_used: template = """ for target in iter_: body @@ -166,6 +142,41 @@ class BreakTransformer(converter.Base): target=node.target, body=node.body, orelse=node.orelse) + + new_for_node = node[0] + anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES) + + return node + + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + extra_test = templates.replace_as_expression( + 'ag__.not_(var_name)', var_name=break_var) + + # The extra test is hidden in the AST, which will confuse the static + # analysis. To mitigate that, we insert a no-op statement that ensures + # the control variable is marked as used. + # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) + template = """ + var_name = False + for target in iter_: + (var_name,) + body + orelse + """ + node = templates.replace( + template, + var_name=break_var, + iter_=node.iter, + target=node.target, + body=node.body, + orelse=guarded_orelse) + + new_for_node = node[1] + anno.setanno(new_for_node, 'extra_test', extra_test) + anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES) + return node diff --git a/tensorflow/python/autograph/converters/loop_integration_test.py b/tensorflow/python/autograph/converters/loop_integration_test.py index 1798132449a..8e3a8871a6f 100644 --- a/tensorflow/python/autograph/converters/loop_integration_test.py +++ b/tensorflow/python/autograph/converters/loop_integration_test.py @@ -35,38 +35,6 @@ class LoopIntegrationTest(converter_testing.TestCase): {}, (constant_op.constant,)) as result: self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) - def test_while_loop(self): - - def test_fn(x): - v = [] - while x > 0: - x -= 1 - if x % 2 == 0: - break - v.append(x) - return v - - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 4) - - def test_for_loop(self): - - def test_fn(a): - v = [] - for x in a: - x -= 1 - if x % 2 == 0: - break - v.append(x) - return v - - with self.converted(test_fn, break_statements, {}, - (constant_op.constant,)) as result: - # The break is incompletely canonicalized. The loop will not interrupt, - # but the section following the break will be skipped. - self.assertEqual([3], result.test_fn([5, 4])) - def test_while_loop_with_else(self): def test_fn(x): while x > 2: From 1a1b07304bd61c18b1f947a1463e62ac3fe5e201 Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Mon, 23 Mar 2020 22:58:01 +0800 Subject: [PATCH 4/5] add copyanno for safety --- tensorflow/python/autograph/converters/break_statements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/autograph/converters/break_statements.py b/tensorflow/python/autograph/converters/break_statements.py index 9474dbfde89..cce0f1c91ee 100644 --- a/tensorflow/python/autograph/converters/break_statements.py +++ b/tensorflow/python/autograph/converters/break_statements.py @@ -144,6 +144,7 @@ class BreakTransformer(converter.Base): orelse=node.orelse) new_for_node = node[0] + anno.copyanno(original_node, new_for_node, 'extra_test') anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES) return node From beea13b22361f108d3c13256a27e38906bdb6a73 Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Tue, 31 Mar 2020 21:53:58 +0800 Subject: [PATCH 5/5] fix merge conflict --- tensorflow/python/autograph/converters/break_statements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/autograph/converters/break_statements.py b/tensorflow/python/autograph/converters/break_statements.py index cce0f1c91ee..635b1804593 100644 --- a/tensorflow/python/autograph/converters/break_statements.py +++ b/tensorflow/python/autograph/converters/break_statements.py @@ -144,7 +144,7 @@ class BreakTransformer(converter.Base): orelse=node.orelse) new_for_node = node[0] - anno.copyanno(original_node, new_for_node, 'extra_test') + anno.copyanno(original_node, new_for_node, anno.Basic.EXTRA_LOOP_TEST) anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES) return node @@ -175,7 +175,7 @@ class BreakTransformer(converter.Base): orelse=guarded_orelse) new_for_node = node[1] - anno.setanno(new_for_node, 'extra_test', extra_test) + anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test) anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES) return node