diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py index b54770cbd28..c3fc879ded5 100644 --- a/tensorflow/python/autograph/converters/control_flow.py +++ b/tensorflow/python/autograph/converters/control_flow.py @@ -24,6 +24,7 @@ from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.lang import directives from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import cfg +from tensorflow.python.autograph.pyct import origin_info from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import qual_names from tensorflow.python.autograph.pyct import templates @@ -243,7 +244,7 @@ class ControlFlowTransformer(converter.Base): (symbol_names,), nouts) """ - return templates.replace( + new_nodes = templates.replace( template, body=node.body, body_name=self.ctx.namer.new_symbol('if_body', reserved), @@ -257,6 +258,8 @@ class ControlFlowTransformer(converter.Base): symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars), test=node.test, undefined_assigns=undefined_assigns) + origin_info.copy_origin(node, new_nodes[-1]) + return new_nodes def visit_While(self, node): node = self.generic_visit(node) @@ -292,7 +295,7 @@ class ControlFlowTransformer(converter.Base): (symbol_names,), opts) """ - return templates.replace( + new_nodes = templates.replace( template, body=node.body, body_name=self.ctx.namer.new_symbol('loop_body', reserved), @@ -305,6 +308,8 @@ class ControlFlowTransformer(converter.Base): test=node.test, test_name=self.ctx.namer.new_symbol('loop_test', reserved), undefined_assigns=undefined_assigns) + origin_info.copy_origin(node, new_nodes[-1]) + return new_nodes def visit_For(self, node): node = self.generic_visit(node) @@ -356,6 +361,7 @@ class ControlFlowTransformer(converter.Base): """ iterate_expansion = templates.replace( template, iterate_arg_name=iterate_arg_name, iterates=node.target) + origin_info.copy_origin(node, iterate_expansion) template = """ state_functions @@ -374,7 +380,7 @@ class ControlFlowTransformer(converter.Base): (symbol_names,), opts) """ - return templates.replace( + new_nodes = templates.replace( template, body=node.body, body_name=self.ctx.namer.new_symbol('loop_body', reserved), @@ -390,6 +396,8 @@ class ControlFlowTransformer(converter.Base): state_getter_name=state_getter_name, state_setter_name=state_setter_name, undefined_assigns=undefined_assigns) + origin_info.copy_origin(node, new_nodes[-1]) + return new_nodes class AnnotatedDef(reaching_definitions.Definition): diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py index ba25d96d2d6..cd909e16364 100644 --- a/tensorflow/python/autograph/pyct/origin_info.py +++ b/tensorflow/python/autograph/pyct/origin_info.py @@ -279,3 +279,15 @@ def resolve_entity(node, source, entity): col_offset = len(definition_line) - len(definition_line.lstrip()) resolve(node, source, filepath, lineno, col_offset) + + +def copy_origin(from_node, to_node): + """Copies the origin info from a node to another, recursively.""" + origin = anno.Basic.ORIGIN.of(from_node, default=None) + if origin is None: + return + if not isinstance(to_node, (list, tuple)): + to_node = (to_node,) + for node in to_node: + for n in gast.walk(node): + anno.setanno(n, anno.Basic.ORIGIN, origin)