Fix context setting in pyct.templates.replace for subscript expressions

PiperOrigin-RevId: 241952938
This commit is contained in:
Brian Lee 2019-04-04 10:15:25 -07:00 committed by TensorFlower Gardener
parent 87ea41d023
commit cae849ba7f
3 changed files with 78 additions and 12 deletions
tensorflow/python/autograph/pyct

View File

@ -147,6 +147,7 @@ py_test(
deps = [ deps = [
":pyct", ":pyct",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
"@gast_archive//:gast", "@gast_archive//:gast",
], ],
) )

View File

@ -87,8 +87,9 @@ class ContextAdjuster(gast.NodeTransformer):
return self.generic_visit(node) return self.generic_visit(node)
def visit_Subscript(self, node): def visit_Subscript(self, node):
self._apply_override(node)
self._ctx_override = gast.Load
node.value = self.visit(node.value) node.value = self.visit(node.value)
self._ctx_override = None
return self.generic_visit(node) return self.generic_visit(node)
def visit_comprehension(self, node): def visit_comprehension(self, node):
@ -214,9 +215,10 @@ class ReplaceTransformer(gast.NodeTransformer):
def _convert_to_ast(n): def _convert_to_ast(n):
"""Converts from a known data type to AST.""" """Converts from a known data type to AST."""
# Note: When generating AST nodes from strings/QNs in isolation, ctx is
# unknown. ctx must be filled in according to the template being used.
# See ReplaceTransformer.visit_Name.
if isinstance(n, str): if isinstance(n, str):
# Note: the node will receive the ctx value from the template, see
# ReplaceTransformer.visit_Name.
return gast.Name(id=n, ctx=None, annotation=None) return gast.Name(id=n, ctx=None, annotation=None)
if isinstance(n, qual_names.QN): if isinstance(n, qual_names.QN):
return n.ast() return n.ast()

View File

@ -20,15 +20,53 @@ from __future__ import print_function
import imp import imp
from absl.testing import parameterized
import gast import gast
from tensorflow.python.autograph.pyct import compiler from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names as qn
from tensorflow.python.autograph.pyct import templates from tensorflow.python.autograph.pyct import templates
from tensorflow.python.platform import test from tensorflow.python.platform import test
class TemplatesTest(test.TestCase): class _CtxClearer(gast.NodeTransformer):
def visit(self, node):
super(_CtxClearer, self).visit(node)
if hasattr(node, 'ctx'):
node.ctx = None
return node
def _parse_with_unset_ctx(expr_source):
ast_node = parser.parse_expression(expr_source)
_CtxClearer().visit(ast_node)
return ast_node
class _CtxChecker(gast.NodeTransformer):
def __init__(self, test_instance, expected_ctx):
self.at_top_level = True
self.test_instance = test_instance
self.expected_ctx = expected_ctx
def visit(self, node):
if hasattr(node, 'ctx'):
self.test_instance.assertIsInstance(node.ctx, self.expected_ctx)
if self.at_top_level:
self.at_top_level = False
self.expected_ctx = gast.Load
return super(_CtxChecker, self).visit(node)
class TemplatesTest(test.TestCase, parameterized.TestCase):
def assertExpectedCtxSet(self, node, ctx):
"""Assert that node has ctx=ctx at top and ctx=gast.Load everywhere else."""
checker = _CtxChecker(self, ctx)
checker.visit(node)
def test_replace_tuple(self): def test_replace_tuple(self):
template = """ template = """
@ -39,7 +77,7 @@ class TemplatesTest(test.TestCase):
node = templates.replace(template, b=('a', 'c'))[0] node = templates.replace(template, b=('a', 'c'))[0]
result, _, _ = compiler.ast_to_object(node) result, _, _ = compiler.ast_to_object(node)
self.assertEquals((2, 3), result.test_fn(2, 3)) self.assertEqual((2, 3), result.test_fn(2, 3))
def test_replace_variable(self): def test_replace_variable(self):
template = """ template = """
@ -51,7 +89,7 @@ class TemplatesTest(test.TestCase):
node = templates.replace(template, a='b')[0] node = templates.replace(template, a='b')[0]
result, _, _ = compiler.ast_to_object(node) result, _, _ = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2)) self.assertEqual(7, result.test_fn(2))
def test_replace_function_name(self): def test_replace_function_name(self):
template = """ template = """
@ -63,7 +101,7 @@ class TemplatesTest(test.TestCase):
node = templates.replace(template, fname='test_fn')[0] node = templates.replace(template, fname='test_fn')[0]
result, _, _ = compiler.ast_to_object(node) result, _, _ = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2)) self.assertEqual(7, result.test_fn(2))
def test_replace_code_block(self): def test_replace_code_block(self):
template = """ template = """
@ -80,7 +118,7 @@ class TemplatesTest(test.TestCase):
], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))), ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
] * 2)[0] ] * 2)[0]
result, _, _ = compiler.ast_to_object(node) result, _, _ = compiler.ast_to_object(node)
self.assertEquals(3, result.test_fn(1)) self.assertEqual(3, result.test_fn(1))
def test_replace_attribute(self): def test_replace_attribute(self):
template = """ template = """
@ -92,7 +130,7 @@ class TemplatesTest(test.TestCase):
result, _, _ = compiler.ast_to_object(node) result, _, _ = compiler.ast_to_object(node)
mod = imp.new_module('test') mod = imp.new_module('test')
mod.b = 3 mod.b = 3
self.assertEquals(3, result.test_fn(mod)) self.assertEqual(3, result.test_fn(mod))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
templates.replace(template, foo=1) templates.replace(template, foo=1)
@ -180,7 +218,7 @@ class TemplatesTest(test.TestCase):
source = parser.parse_expression('f(d=3, f=5)') source = parser.parse_expression('f(d=3, f=5)')
node = templates.replace(template, kws=source.keywords)[0] node = templates.replace(template, kws=source.keywords)[0]
result, _, _ = compiler.ast_to_object(node) result, _, _ = compiler.ast_to_object(node)
self.assertEquals(9, result.test_fn()) self.assertEqual(9, result.test_fn())
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
templates.replace(template, kws=[]) templates.replace(template, kws=[])
@ -200,7 +238,7 @@ class TemplatesTest(test.TestCase):
source = parser.parse_expression('f()(b)') source = parser.parse_expression('f()(b)')
node = templates.replace(template, foo=source)[0] node = templates.replace(template, foo=source)[0]
result, _, _ = compiler.ast_to_object(node) result, _, _ = compiler.ast_to_object(node)
self.assertEquals(15, result.test_fn()) self.assertEqual(15, result.test_fn())
def test_replace_name_with_dict(self): def test_replace_name_with_dict(self):
template = """ template = """
@ -211,7 +249,7 @@ class TemplatesTest(test.TestCase):
source = parser.parse_expression('{\'bar\': 3}') source = parser.parse_expression('{\'bar\': 3}')
node = templates.replace(template, foo=source)[0] node = templates.replace(template, foo=source)[0]
result, _, _ = compiler.ast_to_object(node) result, _, _ = compiler.ast_to_object(node)
self.assertEquals(3, result.test_fn()) self.assertEqual(3, result.test_fn())
def test_replace_as_expression(self): def test_replace_as_expression(self):
template = """ template = """
@ -258,6 +296,31 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(lambda_arg.args.args[0].ctx, gast.Param) self.assertIsInstance(lambda_arg.args.args[0].ctx, gast.Param)
self.assertIsInstance(lambda_arg.body.ctx, gast.Load) self.assertIsInstance(lambda_arg.body.ctx, gast.Load)
def test_replace_name_with_subscript(self):
template = """
foo = bar
"""
replacement = qn.QN(qn.QN('dictionary'), subscript=qn.QN('key'))
node = templates.replace(template, foo=replacement)[0].targets[0]
self.assertIsInstance(node.ctx, gast.Store)
self.assertIsInstance(node.value.ctx, gast.Load)
@parameterized.named_parameters([
('mixed_attr_subscript', 'a.b["c"]'),
('mixed_subscript_attr', 'a[b.c]'),
('nested_subscript', 'a[b[c]]'),
('repeated_subscript', 'a[b][c]'),
])
def test_replace_name_mixed_attr_subscript(self, expression_source):
template = 'foo = bar'
replacement = _parse_with_unset_ctx(expression_source)
target_node = templates.replace(template, foo=replacement)[0].targets[0]
self.assertExpectedCtxSet(target_node, gast.Store)
value_node = templates.replace(template, bar=replacement)[0].value
self.assertExpectedCtxSet(value_node, gast.Load)
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()