Set the origin info more accurately on certain lines generated during control flow translation. This ensures the stack trace is correctly translated when errors occur in the generated code for the respective statements.

PiperOrigin-RevId: 326319490
Change-Id: I1e421b1718eab537fddcc175d43e7565e00a0405
This commit is contained in:
Dan Moldovan 2020-08-12 14:58:17 -07:00 committed by TensorFlower Gardener
parent cdc891452a
commit 9ec80ae8bd
2 changed files with 23 additions and 3 deletions

View File

@ -24,6 +24,7 @@ from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
@ -243,7 +244,7 @@ class ControlFlowTransformer(converter.Base):
(symbol_names,),
nouts)
"""
return templates.replace(
new_nodes = templates.replace(
template,
body=node.body,
body_name=self.ctx.namer.new_symbol('if_body', reserved),
@ -257,6 +258,8 @@ class ControlFlowTransformer(converter.Base):
symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars),
test=node.test,
undefined_assigns=undefined_assigns)
origin_info.copy_origin(node, new_nodes[-1])
return new_nodes
def visit_While(self, node):
node = self.generic_visit(node)
@ -292,7 +295,7 @@ class ControlFlowTransformer(converter.Base):
(symbol_names,),
opts)
"""
return templates.replace(
new_nodes = templates.replace(
template,
body=node.body,
body_name=self.ctx.namer.new_symbol('loop_body', reserved),
@ -305,6 +308,8 @@ class ControlFlowTransformer(converter.Base):
test=node.test,
test_name=self.ctx.namer.new_symbol('loop_test', reserved),
undefined_assigns=undefined_assigns)
origin_info.copy_origin(node, new_nodes[-1])
return new_nodes
def visit_For(self, node):
node = self.generic_visit(node)
@ -356,6 +361,7 @@ class ControlFlowTransformer(converter.Base):
"""
iterate_expansion = templates.replace(
template, iterate_arg_name=iterate_arg_name, iterates=node.target)
origin_info.copy_origin(node, iterate_expansion)
template = """
state_functions
@ -374,7 +380,7 @@ class ControlFlowTransformer(converter.Base):
(symbol_names,),
opts)
"""
return templates.replace(
new_nodes = templates.replace(
template,
body=node.body,
body_name=self.ctx.namer.new_symbol('loop_body', reserved),
@ -390,6 +396,8 @@ class ControlFlowTransformer(converter.Base):
state_getter_name=state_getter_name,
state_setter_name=state_setter_name,
undefined_assigns=undefined_assigns)
origin_info.copy_origin(node, new_nodes[-1])
return new_nodes
class AnnotatedDef(reaching_definitions.Definition):

View File

@ -279,3 +279,15 @@ def resolve_entity(node, source, entity):
col_offset = len(definition_line) - len(definition_line.lstrip())
resolve(node, source, filepath, lineno, col_offset)
def copy_origin(from_node, to_node):
"""Copies the origin info from a node to another, recursively."""
origin = anno.Basic.ORIGIN.of(from_node, default=None)
if origin is None:
return
if not isinstance(to_node, (list, tuple)):
to_node = (to_node,)
for node in to_node:
for n in gast.walk(node):
anno.setanno(n, anno.Basic.ORIGIN, origin)