diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD index 67ea42aa051..9f4b9f39791 100644 --- a/tensorflow/python/autograph/pyct/BUILD +++ b/tensorflow/python/autograph/pyct/BUILD @@ -147,6 +147,7 @@ py_test( deps = [ ":pyct", "//tensorflow/python:client_testlib", + "@absl_py//absl/testing:parameterized", "@gast_archive//:gast", ], ) diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py index 66393e929e0..3e705dd48a1 100644 --- a/tensorflow/python/autograph/pyct/templates.py +++ b/tensorflow/python/autograph/pyct/templates.py @@ -87,8 +87,9 @@ class ContextAdjuster(gast.NodeTransformer): return self.generic_visit(node) def visit_Subscript(self, node): + self._apply_override(node) + self._ctx_override = gast.Load node.value = self.visit(node.value) - self._ctx_override = None return self.generic_visit(node) def visit_comprehension(self, node): @@ -214,9 +215,10 @@ class ReplaceTransformer(gast.NodeTransformer): def _convert_to_ast(n): """Converts from a known data type to AST.""" + # Note: When generating AST nodes from strings/QNs in isolation, ctx is + # unknown. ctx must be filled in according to the template being used. + # See ReplaceTransformer.visit_Name. if isinstance(n, str): - # Note: the node will receive the ctx value from the template, see - # ReplaceTransformer.visit_Name. return gast.Name(id=n, ctx=None, annotation=None) if isinstance(n, qual_names.QN): return n.ast() diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py index 352067de485..5ed10d9c937 100644 --- a/tensorflow/python/autograph/pyct/templates_test.py +++ b/tensorflow/python/autograph/pyct/templates_test.py @@ -20,15 +20,53 @@ from __future__ import print_function import imp +from absl.testing import parameterized import gast from tensorflow.python.autograph.pyct import compiler from tensorflow.python.autograph.pyct import parser +from tensorflow.python.autograph.pyct import qual_names as qn from tensorflow.python.autograph.pyct import templates from tensorflow.python.platform import test -class TemplatesTest(test.TestCase): +class _CtxClearer(gast.NodeTransformer): + + def visit(self, node): + super(_CtxClearer, self).visit(node) + if hasattr(node, 'ctx'): + node.ctx = None + return node + + +def _parse_with_unset_ctx(expr_source): + ast_node = parser.parse_expression(expr_source) + _CtxClearer().visit(ast_node) + return ast_node + + +class _CtxChecker(gast.NodeTransformer): + + def __init__(self, test_instance, expected_ctx): + self.at_top_level = True + self.test_instance = test_instance + self.expected_ctx = expected_ctx + + def visit(self, node): + if hasattr(node, 'ctx'): + self.test_instance.assertIsInstance(node.ctx, self.expected_ctx) + if self.at_top_level: + self.at_top_level = False + self.expected_ctx = gast.Load + return super(_CtxChecker, self).visit(node) + + +class TemplatesTest(test.TestCase, parameterized.TestCase): + + def assertExpectedCtxSet(self, node, ctx): + """Assert that node has ctx=ctx at top and ctx=gast.Load everywhere else.""" + checker = _CtxChecker(self, ctx) + checker.visit(node) def test_replace_tuple(self): template = """ @@ -39,7 +77,7 @@ class TemplatesTest(test.TestCase): node = templates.replace(template, b=('a', 'c'))[0] result, _, _ = compiler.ast_to_object(node) - self.assertEquals((2, 3), result.test_fn(2, 3)) + self.assertEqual((2, 3), result.test_fn(2, 3)) def test_replace_variable(self): template = """ @@ -51,7 +89,7 @@ class TemplatesTest(test.TestCase): node = templates.replace(template, a='b')[0] result, _, _ = compiler.ast_to_object(node) - self.assertEquals(7, result.test_fn(2)) + self.assertEqual(7, result.test_fn(2)) def test_replace_function_name(self): template = """ @@ -63,7 +101,7 @@ class TemplatesTest(test.TestCase): node = templates.replace(template, fname='test_fn')[0] result, _, _ = compiler.ast_to_object(node) - self.assertEquals(7, result.test_fn(2)) + self.assertEqual(7, result.test_fn(2)) def test_replace_code_block(self): template = """ @@ -80,7 +118,7 @@ class TemplatesTest(test.TestCase): ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))), ] * 2)[0] result, _, _ = compiler.ast_to_object(node) - self.assertEquals(3, result.test_fn(1)) + self.assertEqual(3, result.test_fn(1)) def test_replace_attribute(self): template = """ @@ -92,7 +130,7 @@ class TemplatesTest(test.TestCase): result, _, _ = compiler.ast_to_object(node) mod = imp.new_module('test') mod.b = 3 - self.assertEquals(3, result.test_fn(mod)) + self.assertEqual(3, result.test_fn(mod)) with self.assertRaises(ValueError): templates.replace(template, foo=1) @@ -180,7 +218,7 @@ class TemplatesTest(test.TestCase): source = parser.parse_expression('f(d=3, f=5)') node = templates.replace(template, kws=source.keywords)[0] result, _, _ = compiler.ast_to_object(node) - self.assertEquals(9, result.test_fn()) + self.assertEqual(9, result.test_fn()) with self.assertRaises(ValueError): templates.replace(template, kws=[]) @@ -200,7 +238,7 @@ class TemplatesTest(test.TestCase): source = parser.parse_expression('f()(b)') node = templates.replace(template, foo=source)[0] result, _, _ = compiler.ast_to_object(node) - self.assertEquals(15, result.test_fn()) + self.assertEqual(15, result.test_fn()) def test_replace_name_with_dict(self): template = """ @@ -211,7 +249,7 @@ class TemplatesTest(test.TestCase): source = parser.parse_expression('{\'bar\': 3}') node = templates.replace(template, foo=source)[0] result, _, _ = compiler.ast_to_object(node) - self.assertEquals(3, result.test_fn()) + self.assertEqual(3, result.test_fn()) def test_replace_as_expression(self): template = """ @@ -258,6 +296,31 @@ class TemplatesTest(test.TestCase): self.assertIsInstance(lambda_arg.args.args[0].ctx, gast.Param) self.assertIsInstance(lambda_arg.body.ctx, gast.Load) + def test_replace_name_with_subscript(self): + template = """ + foo = bar + """ + replacement = qn.QN(qn.QN('dictionary'), subscript=qn.QN('key')) + + node = templates.replace(template, foo=replacement)[0].targets[0] + self.assertIsInstance(node.ctx, gast.Store) + self.assertIsInstance(node.value.ctx, gast.Load) + + @parameterized.named_parameters([ + ('mixed_attr_subscript', 'a.b["c"]'), + ('mixed_subscript_attr', 'a[b.c]'), + ('nested_subscript', 'a[b[c]]'), + ('repeated_subscript', 'a[b][c]'), + ]) + def test_replace_name_mixed_attr_subscript(self, expression_source): + template = 'foo = bar' + replacement = _parse_with_unset_ctx(expression_source) + + target_node = templates.replace(template, foo=replacement)[0].targets[0] + self.assertExpectedCtxSet(target_node, gast.Store) + + value_node = templates.replace(template, bar=replacement)[0].value + self.assertExpectedCtxSet(value_node, gast.Load) if __name__ == '__main__': test.main()