Fix context setting in pyct.templates.replace for subscript expressions
PiperOrigin-RevId: 241952938
This commit is contained in:
parent
87ea41d023
commit
cae849ba7f
@ -147,6 +147,7 @@ py_test(
|
||||
deps = [
|
||||
":pyct",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"@gast_archive//:gast",
|
||||
],
|
||||
)
|
||||
|
@ -87,8 +87,9 @@ class ContextAdjuster(gast.NodeTransformer):
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_Subscript(self, node):
|
||||
self._apply_override(node)
|
||||
self._ctx_override = gast.Load
|
||||
node.value = self.visit(node.value)
|
||||
self._ctx_override = None
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_comprehension(self, node):
|
||||
@ -214,9 +215,10 @@ class ReplaceTransformer(gast.NodeTransformer):
|
||||
|
||||
def _convert_to_ast(n):
|
||||
"""Converts from a known data type to AST."""
|
||||
# Note: When generating AST nodes from strings/QNs in isolation, ctx is
|
||||
# unknown. ctx must be filled in according to the template being used.
|
||||
# See ReplaceTransformer.visit_Name.
|
||||
if isinstance(n, str):
|
||||
# Note: the node will receive the ctx value from the template, see
|
||||
# ReplaceTransformer.visit_Name.
|
||||
return gast.Name(id=n, ctx=None, annotation=None)
|
||||
if isinstance(n, qual_names.QN):
|
||||
return n.ast()
|
||||
|
@ -20,15 +20,53 @@ from __future__ import print_function
|
||||
|
||||
import imp
|
||||
|
||||
from absl.testing import parameterized
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph.pyct import compiler
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names as qn
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class TemplatesTest(test.TestCase):
|
||||
class _CtxClearer(gast.NodeTransformer):
|
||||
|
||||
def visit(self, node):
|
||||
super(_CtxClearer, self).visit(node)
|
||||
if hasattr(node, 'ctx'):
|
||||
node.ctx = None
|
||||
return node
|
||||
|
||||
|
||||
def _parse_with_unset_ctx(expr_source):
|
||||
ast_node = parser.parse_expression(expr_source)
|
||||
_CtxClearer().visit(ast_node)
|
||||
return ast_node
|
||||
|
||||
|
||||
class _CtxChecker(gast.NodeTransformer):
|
||||
|
||||
def __init__(self, test_instance, expected_ctx):
|
||||
self.at_top_level = True
|
||||
self.test_instance = test_instance
|
||||
self.expected_ctx = expected_ctx
|
||||
|
||||
def visit(self, node):
|
||||
if hasattr(node, 'ctx'):
|
||||
self.test_instance.assertIsInstance(node.ctx, self.expected_ctx)
|
||||
if self.at_top_level:
|
||||
self.at_top_level = False
|
||||
self.expected_ctx = gast.Load
|
||||
return super(_CtxChecker, self).visit(node)
|
||||
|
||||
|
||||
class TemplatesTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def assertExpectedCtxSet(self, node, ctx):
|
||||
"""Assert that node has ctx=ctx at top and ctx=gast.Load everywhere else."""
|
||||
checker = _CtxChecker(self, ctx)
|
||||
checker.visit(node)
|
||||
|
||||
def test_replace_tuple(self):
|
||||
template = """
|
||||
@ -39,7 +77,7 @@ class TemplatesTest(test.TestCase):
|
||||
node = templates.replace(template, b=('a', 'c'))[0]
|
||||
result, _, _ = compiler.ast_to_object(node)
|
||||
|
||||
self.assertEquals((2, 3), result.test_fn(2, 3))
|
||||
self.assertEqual((2, 3), result.test_fn(2, 3))
|
||||
|
||||
def test_replace_variable(self):
|
||||
template = """
|
||||
@ -51,7 +89,7 @@ class TemplatesTest(test.TestCase):
|
||||
|
||||
node = templates.replace(template, a='b')[0]
|
||||
result, _, _ = compiler.ast_to_object(node)
|
||||
self.assertEquals(7, result.test_fn(2))
|
||||
self.assertEqual(7, result.test_fn(2))
|
||||
|
||||
def test_replace_function_name(self):
|
||||
template = """
|
||||
@ -63,7 +101,7 @@ class TemplatesTest(test.TestCase):
|
||||
|
||||
node = templates.replace(template, fname='test_fn')[0]
|
||||
result, _, _ = compiler.ast_to_object(node)
|
||||
self.assertEquals(7, result.test_fn(2))
|
||||
self.assertEqual(7, result.test_fn(2))
|
||||
|
||||
def test_replace_code_block(self):
|
||||
template = """
|
||||
@ -80,7 +118,7 @@ class TemplatesTest(test.TestCase):
|
||||
], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
|
||||
] * 2)[0]
|
||||
result, _, _ = compiler.ast_to_object(node)
|
||||
self.assertEquals(3, result.test_fn(1))
|
||||
self.assertEqual(3, result.test_fn(1))
|
||||
|
||||
def test_replace_attribute(self):
|
||||
template = """
|
||||
@ -92,7 +130,7 @@ class TemplatesTest(test.TestCase):
|
||||
result, _, _ = compiler.ast_to_object(node)
|
||||
mod = imp.new_module('test')
|
||||
mod.b = 3
|
||||
self.assertEquals(3, result.test_fn(mod))
|
||||
self.assertEqual(3, result.test_fn(mod))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
templates.replace(template, foo=1)
|
||||
@ -180,7 +218,7 @@ class TemplatesTest(test.TestCase):
|
||||
source = parser.parse_expression('f(d=3, f=5)')
|
||||
node = templates.replace(template, kws=source.keywords)[0]
|
||||
result, _, _ = compiler.ast_to_object(node)
|
||||
self.assertEquals(9, result.test_fn())
|
||||
self.assertEqual(9, result.test_fn())
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
templates.replace(template, kws=[])
|
||||
@ -200,7 +238,7 @@ class TemplatesTest(test.TestCase):
|
||||
source = parser.parse_expression('f()(b)')
|
||||
node = templates.replace(template, foo=source)[0]
|
||||
result, _, _ = compiler.ast_to_object(node)
|
||||
self.assertEquals(15, result.test_fn())
|
||||
self.assertEqual(15, result.test_fn())
|
||||
|
||||
def test_replace_name_with_dict(self):
|
||||
template = """
|
||||
@ -211,7 +249,7 @@ class TemplatesTest(test.TestCase):
|
||||
source = parser.parse_expression('{\'bar\': 3}')
|
||||
node = templates.replace(template, foo=source)[0]
|
||||
result, _, _ = compiler.ast_to_object(node)
|
||||
self.assertEquals(3, result.test_fn())
|
||||
self.assertEqual(3, result.test_fn())
|
||||
|
||||
def test_replace_as_expression(self):
|
||||
template = """
|
||||
@ -258,6 +296,31 @@ class TemplatesTest(test.TestCase):
|
||||
self.assertIsInstance(lambda_arg.args.args[0].ctx, gast.Param)
|
||||
self.assertIsInstance(lambda_arg.body.ctx, gast.Load)
|
||||
|
||||
def test_replace_name_with_subscript(self):
|
||||
template = """
|
||||
foo = bar
|
||||
"""
|
||||
replacement = qn.QN(qn.QN('dictionary'), subscript=qn.QN('key'))
|
||||
|
||||
node = templates.replace(template, foo=replacement)[0].targets[0]
|
||||
self.assertIsInstance(node.ctx, gast.Store)
|
||||
self.assertIsInstance(node.value.ctx, gast.Load)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
('mixed_attr_subscript', 'a.b["c"]'),
|
||||
('mixed_subscript_attr', 'a[b.c]'),
|
||||
('nested_subscript', 'a[b[c]]'),
|
||||
('repeated_subscript', 'a[b][c]'),
|
||||
])
|
||||
def test_replace_name_mixed_attr_subscript(self, expression_source):
|
||||
template = 'foo = bar'
|
||||
replacement = _parse_with_unset_ctx(expression_source)
|
||||
|
||||
target_node = templates.replace(template, foo=replacement)[0].targets[0]
|
||||
self.assertExpectedCtxSet(target_node, gast.Store)
|
||||
|
||||
value_node = templates.replace(template, bar=replacement)[0].value
|
||||
self.assertExpectedCtxSet(value_node, gast.Load)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user