Set the origin info more accurately on certain lines generated during control flow translation. This ensures the stack trace is correctly translated when errors occur in the generated code for the respective statements.
PiperOrigin-RevId: 326319490 Change-Id: I1e421b1718eab537fddcc175d43e7565e00a0405
This commit is contained in:
parent
cdc891452a
commit
9ec80ae8bd
@ -24,6 +24,7 @@ from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.lang import directives
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import origin_info
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
@ -243,7 +244,7 @@ class ControlFlowTransformer(converter.Base):
|
||||
(symbol_names,),
|
||||
nouts)
|
||||
"""
|
||||
return templates.replace(
|
||||
new_nodes = templates.replace(
|
||||
template,
|
||||
body=node.body,
|
||||
body_name=self.ctx.namer.new_symbol('if_body', reserved),
|
||||
@ -257,6 +258,8 @@ class ControlFlowTransformer(converter.Base):
|
||||
symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars),
|
||||
test=node.test,
|
||||
undefined_assigns=undefined_assigns)
|
||||
origin_info.copy_origin(node, new_nodes[-1])
|
||||
return new_nodes
|
||||
|
||||
def visit_While(self, node):
|
||||
node = self.generic_visit(node)
|
||||
@ -292,7 +295,7 @@ class ControlFlowTransformer(converter.Base):
|
||||
(symbol_names,),
|
||||
opts)
|
||||
"""
|
||||
return templates.replace(
|
||||
new_nodes = templates.replace(
|
||||
template,
|
||||
body=node.body,
|
||||
body_name=self.ctx.namer.new_symbol('loop_body', reserved),
|
||||
@ -305,6 +308,8 @@ class ControlFlowTransformer(converter.Base):
|
||||
test=node.test,
|
||||
test_name=self.ctx.namer.new_symbol('loop_test', reserved),
|
||||
undefined_assigns=undefined_assigns)
|
||||
origin_info.copy_origin(node, new_nodes[-1])
|
||||
return new_nodes
|
||||
|
||||
def visit_For(self, node):
|
||||
node = self.generic_visit(node)
|
||||
@ -356,6 +361,7 @@ class ControlFlowTransformer(converter.Base):
|
||||
"""
|
||||
iterate_expansion = templates.replace(
|
||||
template, iterate_arg_name=iterate_arg_name, iterates=node.target)
|
||||
origin_info.copy_origin(node, iterate_expansion)
|
||||
|
||||
template = """
|
||||
state_functions
|
||||
@ -374,7 +380,7 @@ class ControlFlowTransformer(converter.Base):
|
||||
(symbol_names,),
|
||||
opts)
|
||||
"""
|
||||
return templates.replace(
|
||||
new_nodes = templates.replace(
|
||||
template,
|
||||
body=node.body,
|
||||
body_name=self.ctx.namer.new_symbol('loop_body', reserved),
|
||||
@ -390,6 +396,8 @@ class ControlFlowTransformer(converter.Base):
|
||||
state_getter_name=state_getter_name,
|
||||
state_setter_name=state_setter_name,
|
||||
undefined_assigns=undefined_assigns)
|
||||
origin_info.copy_origin(node, new_nodes[-1])
|
||||
return new_nodes
|
||||
|
||||
|
||||
class AnnotatedDef(reaching_definitions.Definition):
|
||||
|
@ -279,3 +279,15 @@ def resolve_entity(node, source, entity):
|
||||
col_offset = len(definition_line) - len(definition_line.lstrip())
|
||||
|
||||
resolve(node, source, filepath, lineno, col_offset)
|
||||
|
||||
|
||||
def copy_origin(from_node, to_node):
|
||||
"""Copies the origin info from a node to another, recursively."""
|
||||
origin = anno.Basic.ORIGIN.of(from_node, default=None)
|
||||
if origin is None:
|
||||
return
|
||||
if not isinstance(to_node, (list, tuple)):
|
||||
to_node = (to_node,)
|
||||
for node in to_node:
|
||||
for n in gast.walk(node):
|
||||
anno.setanno(n, anno.Basic.ORIGIN, origin)
|
||||
|
Loading…
x
Reference in New Issue
Block a user