From c76a8d14b1465710618e3262ef7c84bc4677b152 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Sat, 23 May 2020 20:28:54 -0700 Subject: [PATCH] Rewrite `del` to treat undefinedness in a consistent manner. PiperOrigin-RevId: 312947175 Change-Id: Ida4cb8c97ff280cb1011e33edc20a5c524fb8f8a --- .../python/autograph/converters/variables.py | 25 ++++++ .../autograph/converters/variables_test.py | 84 +++++++++++++++++++ 2 files changed, 109 insertions(+) diff --git a/tensorflow/python/autograph/converters/variables.py b/tensorflow/python/autograph/converters/variables.py index 3028a65a69b..9784f50ed56 100644 --- a/tensorflow/python/autograph/converters/variables.py +++ b/tensorflow/python/autograph/converters/variables.py @@ -60,6 +60,31 @@ class VariableAccessTransformer(converter.Base): node = templates.replace_as_expression('ag__.ld(var_)', var_=node) return node + def visit_Delete(self, node): + node = self.generic_visit(node) + + rewrite_targets = [] + for tgt in node.targets: + # Don't rewrite composites like `del a[0]`. + if isinstance(tgt, gast.Name): + rewrite_targets.append(tgt) + + if not rewrite_targets: + return node + + results = [] + for tgt in rewrite_targets: + template = """ + var_ = ag__.Undefined(var_name) + """ + results.extend(templates.replace( + template, var_=tgt, var_name=gast.Constant(tgt.id, kind=None))) + remaining_targets = [n for n in node.targets if n not in rewrite_targets] + if remaining_targets: + results.append(gast.Delete(targets=remaining_targets)) + + return results + def visit_AugAssign(self, node): if isinstance(node.target, gast.Name): template = """ diff --git a/tensorflow/python/autograph/converters/variables_test.py b/tensorflow/python/autograph/converters/variables_test.py index 556dafbaa8a..93a31e63de3 100644 --- a/tensorflow/python/autograph/converters/variables_test.py +++ b/tensorflow/python/autograph/converters/variables_test.py @@ -51,6 +51,90 @@ class VariablesTest(converter_testing.TestCase): with self.apply_add_one_conversion(test_fn) as result: self.assertEqual(result.test_fn(1), (1 + 1) * 10 + 1) # two reads + def test_del(self): + + def test_fn(l): + del l + return l + + with self.converted(test_fn, variables, {}) as result: + with self.assertRaisesRegex( + NameError, "'l' is used before assignment"): + result.test_fn(1) + + def test_del_getitem_ignored(self): + + def basic_slice(l): + del l[0] + return l + + with self.converted(basic_slice, variables, {}) as result: + self.assertListEqual([2], result.basic_slice([1, 2])) + + def range_slice(l): + del l[0:2] + return l + + with self.converted(range_slice, variables, {}) as result: + self.assertListEqual([], result.range_slice([1, 2])) + + def test_del_getattr_ignored(self): + + def test_fn(l): + del l.a + return l + + class TestClass(object): + + def __init__(self): + self.a = 1 + self.b = 2 + + with self.converted(test_fn, variables, {}) as result: + self.assertFalse(hasattr(result.test_fn(TestClass()), 'a')) + self.assertEqual(result.test_fn(TestClass()).b, 2) + + def test_del_packing_ignored(self): + # Note: test for UnboundLocalError, not NameError because in this case we + # don't rewrite the del. + + def list_(a, b): + del [a, b] + return a + + with self.converted(list_, variables, {}) as result: + with self.assertRaises(UnboundLocalError): + result.list_(1, 2) + + def nested(a, b, c): + del [a, (b, c)] + return c + + with self.converted(nested, variables, {}) as result: + with self.assertRaises(UnboundLocalError): + result.nested(1, 2, 3) + + def test_del_item_multiple_mixed(self): + + def test_fn_failing(a, b, c): + del a, b, c[0] + a = 1 + return a, b, c + + with self.converted(test_fn_failing, variables, {}) as result: + with self.assertRaisesRegex( + NameError, "'b' is used before assignment"): + result.test_fn_failing(1, 2, [1, 2]) + + def test_fn_passing(a, b, c): + del a, b, c[0] + a = 1 + b = 2 + return c + + with self.converted(test_fn_passing, variables, {}) as result: + self.assertListEqual([2], result.test_fn_passing(1, 2, [1, 2])) + def test_attribute(self): class TestClass(object):