Fix context setting in pyct.templates.replace for subscript expressions

PiperOrigin-RevId: 241952938
This commit is contained in:
Brian Lee 2019-04-04 10:15:25 -07:00 committed by TensorFlower Gardener
parent 87ea41d023
commit cae849ba7f
3 changed files with 78 additions and 12 deletions

View File

@ -147,6 +147,7 @@ py_test(
deps = [
":pyct",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
"@gast_archive//:gast",
],
)

View File

@ -87,8 +87,9 @@ class ContextAdjuster(gast.NodeTransformer):
return self.generic_visit(node)
def visit_Subscript(self, node):
self._apply_override(node)
self._ctx_override = gast.Load
node.value = self.visit(node.value)
self._ctx_override = None
return self.generic_visit(node)
def visit_comprehension(self, node):
@ -214,9 +215,10 @@ class ReplaceTransformer(gast.NodeTransformer):
def _convert_to_ast(n):
"""Converts from a known data type to AST."""
# Note: When generating AST nodes from strings/QNs in isolation, ctx is
# unknown. ctx must be filled in according to the template being used.
# See ReplaceTransformer.visit_Name.
if isinstance(n, str):
# Note: the node will receive the ctx value from the template, see
# ReplaceTransformer.visit_Name.
return gast.Name(id=n, ctx=None, annotation=None)
if isinstance(n, qual_names.QN):
return n.ast()

View File

@ -20,15 +20,53 @@ from __future__ import print_function
import imp
from absl.testing import parameterized
import gast
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names as qn
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.platform import test
class TemplatesTest(test.TestCase):
class _CtxClearer(gast.NodeTransformer):
def visit(self, node):
super(_CtxClearer, self).visit(node)
if hasattr(node, 'ctx'):
node.ctx = None
return node
def _parse_with_unset_ctx(expr_source):
ast_node = parser.parse_expression(expr_source)
_CtxClearer().visit(ast_node)
return ast_node
class _CtxChecker(gast.NodeTransformer):
def __init__(self, test_instance, expected_ctx):
self.at_top_level = True
self.test_instance = test_instance
self.expected_ctx = expected_ctx
def visit(self, node):
if hasattr(node, 'ctx'):
self.test_instance.assertIsInstance(node.ctx, self.expected_ctx)
if self.at_top_level:
self.at_top_level = False
self.expected_ctx = gast.Load
return super(_CtxChecker, self).visit(node)
class TemplatesTest(test.TestCase, parameterized.TestCase):
def assertExpectedCtxSet(self, node, ctx):
"""Assert that node has ctx=ctx at top and ctx=gast.Load everywhere else."""
checker = _CtxChecker(self, ctx)
checker.visit(node)
def test_replace_tuple(self):
template = """
@ -39,7 +77,7 @@ class TemplatesTest(test.TestCase):
node = templates.replace(template, b=('a', 'c'))[0]
result, _, _ = compiler.ast_to_object(node)
self.assertEquals((2, 3), result.test_fn(2, 3))
self.assertEqual((2, 3), result.test_fn(2, 3))
def test_replace_variable(self):
template = """
@ -51,7 +89,7 @@ class TemplatesTest(test.TestCase):
node = templates.replace(template, a='b')[0]
result, _, _ = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2))
self.assertEqual(7, result.test_fn(2))
def test_replace_function_name(self):
template = """
@ -63,7 +101,7 @@ class TemplatesTest(test.TestCase):
node = templates.replace(template, fname='test_fn')[0]
result, _, _ = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2))
self.assertEqual(7, result.test_fn(2))
def test_replace_code_block(self):
template = """
@ -80,7 +118,7 @@ class TemplatesTest(test.TestCase):
], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
] * 2)[0]
result, _, _ = compiler.ast_to_object(node)
self.assertEquals(3, result.test_fn(1))
self.assertEqual(3, result.test_fn(1))
def test_replace_attribute(self):
template = """
@ -92,7 +130,7 @@ class TemplatesTest(test.TestCase):
result, _, _ = compiler.ast_to_object(node)
mod = imp.new_module('test')
mod.b = 3
self.assertEquals(3, result.test_fn(mod))
self.assertEqual(3, result.test_fn(mod))
with self.assertRaises(ValueError):
templates.replace(template, foo=1)
@ -180,7 +218,7 @@ class TemplatesTest(test.TestCase):
source = parser.parse_expression('f(d=3, f=5)')
node = templates.replace(template, kws=source.keywords)[0]
result, _, _ = compiler.ast_to_object(node)
self.assertEquals(9, result.test_fn())
self.assertEqual(9, result.test_fn())
with self.assertRaises(ValueError):
templates.replace(template, kws=[])
@ -200,7 +238,7 @@ class TemplatesTest(test.TestCase):
source = parser.parse_expression('f()(b)')
node = templates.replace(template, foo=source)[0]
result, _, _ = compiler.ast_to_object(node)
self.assertEquals(15, result.test_fn())
self.assertEqual(15, result.test_fn())
def test_replace_name_with_dict(self):
template = """
@ -211,7 +249,7 @@ class TemplatesTest(test.TestCase):
source = parser.parse_expression('{\'bar\': 3}')
node = templates.replace(template, foo=source)[0]
result, _, _ = compiler.ast_to_object(node)
self.assertEquals(3, result.test_fn())
self.assertEqual(3, result.test_fn())
def test_replace_as_expression(self):
template = """
@ -258,6 +296,31 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(lambda_arg.args.args[0].ctx, gast.Param)
self.assertIsInstance(lambda_arg.body.ctx, gast.Load)
def test_replace_name_with_subscript(self):
template = """
foo = bar
"""
replacement = qn.QN(qn.QN('dictionary'), subscript=qn.QN('key'))
node = templates.replace(template, foo=replacement)[0].targets[0]
self.assertIsInstance(node.ctx, gast.Store)
self.assertIsInstance(node.value.ctx, gast.Load)
@parameterized.named_parameters([
('mixed_attr_subscript', 'a.b["c"]'),
('mixed_subscript_attr', 'a[b.c]'),
('nested_subscript', 'a[b[c]]'),
('repeated_subscript', 'a[b][c]'),
])
def test_replace_name_mixed_attr_subscript(self, expression_source):
template = 'foo = bar'
replacement = _parse_with_unset_ctx(expression_source)
target_node = templates.replace(template, foo=replacement)[0].targets[0]
self.assertExpectedCtxSet(target_node, gast.Store)
value_node = templates.replace(template, bar=replacement)[0].value
self.assertExpectedCtxSet(value_node, gast.Load)
if __name__ == '__main__':
test.main()