Add support for list literals in template replacement values.

PiperOrigin-RevId: 212337233
This commit is contained in:
Dan Moldovan 2018-09-10 14:40:21 -07:00 committed by TensorFlower Gardener
parent b828f89263
commit 6d3af1df20
2 changed files with 39 additions and 3 deletions

View File

@ -113,7 +113,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if isinstance(node, gast.Attribute):
self._check_inner_children_have_context(node.value)
self._check_has_context(node)
elif isinstance(node, gast.Tuple):
elif isinstance(node, (gast.Tuple, gast.List)):
for e in node.elts:
self._check_inner_children_have_context(e)
self._check_has_context(node)
@ -142,7 +142,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if isinstance(node, gast.Attribute):
self._set_inner_child_context(node.value, gast.Load())
node.ctx = ctx
elif isinstance(node, gast.Tuple):
elif isinstance(node, (gast.Tuple, gast.List)):
for e in node.elts:
self._set_inner_child_context(e, ctx)
node.ctx = ctx
@ -191,7 +191,7 @@ class ReplaceTransformer(gast.NodeTransformer):
# Preserve the target context.
for n in new_nodes:
if isinstance(n, gast.Tuple):
if isinstance(n, (gast.Tuple, gast.List)):
for e in n.elts:
self._set_inner_child_context(e, node.ctx)
if isinstance(n, gast.Attribute):

View File

@ -110,6 +110,42 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load)
self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load)
def test_replace_list_context(self):
template = """
def test_fn(foo):
foo = 0
"""
node = templates.replace(template, foo=parser.parse_expression('[a, b]'))[0]
self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
def test_replace_tuple_context(self):
template = """
def test_fn(foo):
foo = 0
"""
node = templates.replace(template, foo=parser.parse_expression('(a, b)'))[0]
self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
def test_replace_complex_context(self):
template = """
def test_fn(foo):
foo = 0
"""
node = templates.replace(
template, foo=parser.parse_expression('bar(([a, b],)).baz'))[0]
self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
function_call_arg = node.body[0].targets[0].value.args[0]
self.assertIsInstance(function_call_arg.elts[0].ctx, gast.Load)
self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)
def test_replace_call_keyword(self):
template = """
def test_fn():