Rewrite del to treat undefinedness in a consistent manner.

PiperOrigin-RevId: 312947175
Change-Id: Ida4cb8c97ff280cb1011e33edc20a5c524fb8f8a
This commit is contained in:
Dan Moldovan 2020-05-23 20:28:54 -07:00 committed by TensorFlower Gardener
parent 144b3dc790
commit c76a8d14b1
2 changed files with 109 additions and 0 deletions

View File

@ -60,6 +60,31 @@ class VariableAccessTransformer(converter.Base):
node = templates.replace_as_expression('ag__.ld(var_)', var_=node)
return node
def visit_Delete(self, node):
node = self.generic_visit(node)
rewrite_targets = []
for tgt in node.targets:
# Don't rewrite composites like `del a[0]`.
if isinstance(tgt, gast.Name):
rewrite_targets.append(tgt)
if not rewrite_targets:
return node
results = []
for tgt in rewrite_targets:
template = """
var_ = ag__.Undefined(var_name)
"""
results.extend(templates.replace(
template, var_=tgt, var_name=gast.Constant(tgt.id, kind=None)))
remaining_targets = [n for n in node.targets if n not in rewrite_targets]
if remaining_targets:
results.append(gast.Delete(targets=remaining_targets))
return results
def visit_AugAssign(self, node):
if isinstance(node.target, gast.Name):
template = """

View File

@ -51,6 +51,90 @@ class VariablesTest(converter_testing.TestCase):
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(1), (1 + 1) * 10 + 1) # two reads
def test_del(self):
def test_fn(l):
del l
return l
with self.converted(test_fn, variables, {}) as result:
with self.assertRaisesRegex(
NameError, "'l' is used before assignment"):
result.test_fn(1)
def test_del_getitem_ignored(self):
def basic_slice(l):
del l[0]
return l
with self.converted(basic_slice, variables, {}) as result:
self.assertListEqual([2], result.basic_slice([1, 2]))
def range_slice(l):
del l[0:2]
return l
with self.converted(range_slice, variables, {}) as result:
self.assertListEqual([], result.range_slice([1, 2]))
def test_del_getattr_ignored(self):
def test_fn(l):
del l.a
return l
class TestClass(object):
def __init__(self):
self.a = 1
self.b = 2
with self.converted(test_fn, variables, {}) as result:
self.assertFalse(hasattr(result.test_fn(TestClass()), 'a'))
self.assertEqual(result.test_fn(TestClass()).b, 2)
def test_del_packing_ignored(self):
# Note: test for UnboundLocalError, not NameError because in this case we
# don't rewrite the del.
def list_(a, b):
del [a, b]
return a
with self.converted(list_, variables, {}) as result:
with self.assertRaises(UnboundLocalError):
result.list_(1, 2)
def nested(a, b, c):
del [a, (b, c)]
return c
with self.converted(nested, variables, {}) as result:
with self.assertRaises(UnboundLocalError):
result.nested(1, 2, 3)
def test_del_item_multiple_mixed(self):
def test_fn_failing(a, b, c):
del a, b, c[0]
a = 1
return a, b, c
with self.converted(test_fn_failing, variables, {}) as result:
with self.assertRaisesRegex(
NameError, "'b' is used before assignment"):
result.test_fn_failing(1, 2, [1, 2])
def test_fn_passing(a, b, c):
del a, b, c[0]
a = 1
b = 2
return c
with self.converted(test_fn_passing, variables, {}) as result:
self.assertListEqual([2], result.test_fn_passing(1, 2, [1, 2]))
def test_attribute(self):
class TestClass(object):