From 9ec80ae8bd8e3b1753cda19973977578c7525dbf Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Wed, 12 Aug 2020 14:58:17 -0700 Subject: [PATCH] Set the origin info more accurately on certain lines generated during control flow translation. This ensures the stack trace is correctly translated when errors occur in the generated code for the respective statements. PiperOrigin-RevId: 326319490 Change-Id: I1e421b1718eab537fddcc175d43e7565e00a0405 --- .../python/autograph/converters/control_flow.py | 14 +++++++++++--- tensorflow/python/autograph/pyct/origin_info.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py index b54770cbd28..c3fc879ded5 100644 --- a/tensorflow/python/autograph/converters/control_flow.py +++ b/tensorflow/python/autograph/converters/control_flow.py @@ -24,6 +24,7 @@ from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.lang import directives from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import cfg +from tensorflow.python.autograph.pyct import origin_info from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import qual_names from tensorflow.python.autograph.pyct import templates @@ -243,7 +244,7 @@ class ControlFlowTransformer(converter.Base): (symbol_names,), nouts) """ - return templates.replace( + new_nodes = templates.replace( template, body=node.body, body_name=self.ctx.namer.new_symbol('if_body', reserved), @@ -257,6 +258,8 @@ class ControlFlowTransformer(converter.Base): symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars), test=node.test, undefined_assigns=undefined_assigns) + origin_info.copy_origin(node, new_nodes[-1]) + return new_nodes def visit_While(self, node): node = self.generic_visit(node) @@ -292,7 +295,7 @@ class ControlFlowTransformer(converter.Base): (symbol_names,), opts) """ - return templates.replace( + new_nodes = templates.replace( template, body=node.body, body_name=self.ctx.namer.new_symbol('loop_body', reserved), @@ -305,6 +308,8 @@ class ControlFlowTransformer(converter.Base): test=node.test, test_name=self.ctx.namer.new_symbol('loop_test', reserved), undefined_assigns=undefined_assigns) + origin_info.copy_origin(node, new_nodes[-1]) + return new_nodes def visit_For(self, node): node = self.generic_visit(node) @@ -356,6 +361,7 @@ class ControlFlowTransformer(converter.Base): """ iterate_expansion = templates.replace( template, iterate_arg_name=iterate_arg_name, iterates=node.target) + origin_info.copy_origin(node, iterate_expansion) template = """ state_functions @@ -374,7 +380,7 @@ class ControlFlowTransformer(converter.Base): (symbol_names,), opts) """ - return templates.replace( + new_nodes = templates.replace( template, body=node.body, body_name=self.ctx.namer.new_symbol('loop_body', reserved), @@ -390,6 +396,8 @@ class ControlFlowTransformer(converter.Base): state_getter_name=state_getter_name, state_setter_name=state_setter_name, undefined_assigns=undefined_assigns) + origin_info.copy_origin(node, new_nodes[-1]) + return new_nodes class AnnotatedDef(reaching_definitions.Definition): diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py index ba25d96d2d6..cd909e16364 100644 --- a/tensorflow/python/autograph/pyct/origin_info.py +++ b/tensorflow/python/autograph/pyct/origin_info.py @@ -279,3 +279,15 @@ def resolve_entity(node, source, entity): col_offset = len(definition_line) - len(definition_line.lstrip()) resolve(node, source, filepath, lineno, col_offset) + + +def copy_origin(from_node, to_node): + """Copies the origin info from a node to another, recursively.""" + origin = anno.Basic.ORIGIN.of(from_node, default=None) + if origin is None: + return + if not isinstance(to_node, (list, tuple)): + to_node = (to_node,) + for node in to_node: + for n in gast.walk(node): + anno.setanno(n, anno.Basic.ORIGIN, origin)