Exclude frames corresponding to function calls found in the main autograph module (api.py) from the stack trace attached to errors that occur in converted code.
### Example Given the code: ``` d = {} @tf.function def f(d): d.pop('a') f(d) ``` Error message before: ``` KeyError: in converted code: <file>:<line> f * d.pop('a') <internal TF file>:<line> converted_call return py_builtins.overload_of(f)(*args) KeyError: 'a' ``` Error message after: ``` KeyError: in converted code: <file>:<line> f ** d.pop('a') KeyError: 'a' ``` PiperOrigin-RevId: 285178695 Change-Id: Ic6c3d4c57b5478dd560b17c7708f5e502f62a41b
This commit is contained in:
parent
b963d6a436
commit
61f2288b4e
tensorflow/python
autograph
keras/engine
@ -38,6 +38,10 @@ Among the distinctive features of the re-raised exception:
|
||||
the `@tf.function`
|
||||
* the references corresponding to converted code are marked with an
|
||||
asterisk (`*`)
|
||||
* the references corresponding to code which AutoGraph reached, but decided not
|
||||
to convert, are marked with a double asterisk (`**`)
|
||||
* the references corresponding to code that AutoGraph didn't reach at all have
|
||||
no marking
|
||||
|
||||
For example, the code below triggers an exception in the Python runtime, at
|
||||
graph construction time:
|
||||
@ -62,7 +66,7 @@ TypeError: in converted code:
|
||||
|
||||
<ipython-input-9-002fa22f79df>:8 f *
|
||||
tf.constant(1) + tf.constant(1.0)
|
||||
tensorflow/python/ops/math_ops.py:900 binary_op_wrapper
|
||||
tensorflow/python/ops/math_ops.py:900 binary_op_wrapper **
|
||||
return func(x, y, name=name)
|
||||
... more TensorFlow internal frames ...
|
||||
|
||||
|
@ -312,23 +312,24 @@ def do_not_convert(func=None):
|
||||
return autograph_artifact(wrapper)
|
||||
|
||||
|
||||
def _attach_metadata(e, f, converted):
|
||||
def _attach_metadata(e, f):
|
||||
"""Augments an error with the metadata necessary for rewrite."""
|
||||
if hasattr(e, 'ag_pass_through'):
|
||||
return
|
||||
|
||||
metadata = getattr(e, 'ag_error_metadata', None)
|
||||
source_map = f.ag_source_map if converted else {}
|
||||
source_map = f.ag_source_map
|
||||
|
||||
if metadata is None:
|
||||
logging.log(
|
||||
1, 'Caught error in %s (converted=%s)', f, converted, exc_info=True)
|
||||
logging.log(1, 'Caught error in user callable %s', f, exc_info=True)
|
||||
message = '{}: {}'.format(e.__class__.__name__, e)
|
||||
else:
|
||||
message = None
|
||||
|
||||
cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:]
|
||||
e.ag_error_metadata = _ErrorMetadata(cause_tb, metadata, message, source_map)
|
||||
|
||||
e.ag_error_metadata = _ErrorMetadata(
|
||||
cause_tb, metadata, message, source_map, __file__)
|
||||
|
||||
|
||||
def _call_unconverted(f, args, kwargs, options, update_cache=True):
|
||||
@ -339,14 +340,10 @@ def _call_unconverted(f, args, kwargs, options, update_cache=True):
|
||||
if inspect_utils.istfmethodtarget(f):
|
||||
return f.__self__.call(args, kwargs)
|
||||
|
||||
try:
|
||||
if kwargs is not None:
|
||||
return f(*args, **kwargs)
|
||||
else:
|
||||
return f(*args)
|
||||
except Exception as e: # pylint:disable=broad-except
|
||||
_attach_metadata(e, f, False)
|
||||
raise
|
||||
if kwargs is not None:
|
||||
return f(*args, **kwargs)
|
||||
else:
|
||||
return f(*args)
|
||||
|
||||
|
||||
def _is_known_loaded_type(f, module_name, entity_name):
|
||||
@ -584,7 +581,7 @@ def converted_call(f,
|
||||
else:
|
||||
result = converted_f(*effective_args)
|
||||
except Exception as e:
|
||||
_attach_metadata(e, converted_f, True)
|
||||
_attach_metadata(e, converted_f)
|
||||
raise
|
||||
|
||||
return result
|
||||
|
@ -26,11 +26,13 @@ from tensorflow.python.autograph.pyct import origin_info
|
||||
class FrameInfo(
|
||||
collections.namedtuple(
|
||||
'FrameInfo',
|
||||
('filename', 'lineno', 'function_name', 'code', 'converted'))):
|
||||
pass
|
||||
('filename', 'lineno', 'function_name', 'code', 'is_converted',
|
||||
'is_whitelisted'))):
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
def _stack_trace_inside_mapped_code(tb, source_map):
|
||||
def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
|
||||
"""Summarizes inner traceback frames up to the call to a given function.
|
||||
|
||||
This functions locates the innermost (i.e. most recent) frame that corresponds
|
||||
@ -67,10 +69,14 @@ def _stack_trace_inside_mapped_code(tb, source_map):
|
||||
raise ...
|
||||
|
||||
Args:
|
||||
tb: List[Tuple], the traceback corresponding to an error; typically,
|
||||
the output of traceback.extract_tb.
|
||||
tb: traceback.FrameSummary, The traceback corresponding to an error.
|
||||
Typically, the output of traceback.Summary.extract(capture_locals=True).
|
||||
source_map: Dict[LineLocation, OriginInfo], a source map as created by
|
||||
origin_info.create_source_map.
|
||||
converter_filename: str, the file path of the converted module. Call frames
|
||||
corresponding to this module are elided and their preceding frames are
|
||||
marked as whitelisted. Note that frames enclosing converted code are
|
||||
dropped using a different mechanism.
|
||||
|
||||
Returns:
|
||||
List[FrameInfo]
|
||||
@ -81,21 +87,37 @@ def _stack_trace_inside_mapped_code(tb, source_map):
|
||||
loc = origin_info.LineLocation(filename=filename, lineno=line_number)
|
||||
if loc in source_map:
|
||||
origin = source_map[loc]
|
||||
origin_frame_info = FrameInfo(
|
||||
fi = FrameInfo(
|
||||
filename=origin.loc.filename,
|
||||
lineno=origin.loc.lineno,
|
||||
function_name=origin.function_name,
|
||||
code=origin.source_code_line,
|
||||
converted=True)
|
||||
result_frames.append(origin_frame_info)
|
||||
is_converted=True,
|
||||
is_whitelisted=False)
|
||||
result_frames.append(fi)
|
||||
break
|
||||
|
||||
if filename == converter_filename:
|
||||
if result_frames:
|
||||
prev = result_frames[-1]
|
||||
assert not prev.is_converted # See the if above.
|
||||
fi = FrameInfo(
|
||||
filename=prev.filename,
|
||||
lineno=prev.lineno,
|
||||
function_name=prev.function_name,
|
||||
code=prev.code,
|
||||
is_converted=False,
|
||||
is_whitelisted=True)
|
||||
result_frames[-1] = fi
|
||||
continue
|
||||
|
||||
fi = FrameInfo(
|
||||
filename=filename,
|
||||
lineno=line_number,
|
||||
function_name=function_name,
|
||||
code=text,
|
||||
converted=False)
|
||||
is_converted=False,
|
||||
is_whitelisted=False)
|
||||
result_frames.append(fi)
|
||||
|
||||
return tuple(result_frames)
|
||||
@ -136,8 +158,12 @@ class ErrorMetadataBase(object):
|
||||
code from which the generated code originated.
|
||||
"""
|
||||
|
||||
def __init__(self, callsite_tb, cause_metadata, cause_message, source_map):
|
||||
translated_stack = _stack_trace_inside_mapped_code(callsite_tb, source_map)
|
||||
__slots__ = ('translated_stack', 'cause_message')
|
||||
|
||||
def __init__(self, callsite_tb, cause_metadata, cause_message, source_map,
|
||||
converter_filename):
|
||||
translated_stack = _stack_trace_inside_mapped_code(
|
||||
callsite_tb, source_map, converter_filename)
|
||||
|
||||
if cause_metadata is None:
|
||||
self.translated_stack = translated_stack
|
||||
@ -156,12 +182,15 @@ class ErrorMetadataBase(object):
|
||||
lines.append('')
|
||||
|
||||
for frame_info in reversed(self.translated_stack):
|
||||
lines.append(' {}:{} {}{}'.format(
|
||||
frame_info.filename,
|
||||
frame_info.lineno,
|
||||
frame_info.function_name,
|
||||
' *' if frame_info.converted else '',
|
||||
))
|
||||
formatted_line = ' {}:{} {}'.format(frame_info.filename,
|
||||
frame_info.lineno,
|
||||
frame_info.function_name)
|
||||
if frame_info.is_converted:
|
||||
formatted_line += ' *'
|
||||
elif frame_info.is_whitelisted:
|
||||
formatted_line += ' **'
|
||||
lines.append(formatted_line)
|
||||
|
||||
if frame_info.code is None:
|
||||
code_snippet = '<source unavailable>'
|
||||
else:
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import re
|
||||
|
||||
from tensorflow.python.autograph.pyct import error_utils
|
||||
from tensorflow.python.autograph.pyct import origin_info
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -35,7 +36,8 @@ class ErrorMetadataBaseTest(test.TestCase):
|
||||
callsite_tb=(),
|
||||
cause_metadata=None,
|
||||
cause_message='test message',
|
||||
source_map={})
|
||||
source_map={},
|
||||
converter_filename=None)
|
||||
exc = em.create_exception(CustomError())
|
||||
self.assertIsInstance(exc, CustomError)
|
||||
self.assertIn('test message', str(exc))
|
||||
@ -51,11 +53,12 @@ class ErrorMetadataBaseTest(test.TestCase):
|
||||
callsite_tb=(),
|
||||
cause_metadata=None,
|
||||
cause_message='test message',
|
||||
source_map={})
|
||||
source_map={},
|
||||
converter_filename=None)
|
||||
exc = em.create_exception(CustomError())
|
||||
self.assertIsNone(exc)
|
||||
|
||||
def test_get_message_when_frame_info_code_is_none(self):
|
||||
def test_get_message_no_code(self):
|
||||
callsite_tb = [
|
||||
('/path/one.py', 11, 'test_fn_1', None),
|
||||
('/path/two.py', 171, 'test_fn_2', 'test code'),
|
||||
@ -65,11 +68,57 @@ class ErrorMetadataBaseTest(test.TestCase):
|
||||
callsite_tb=callsite_tb,
|
||||
cause_metadata=None,
|
||||
cause_message=cause_message,
|
||||
source_map={})
|
||||
source_map={},
|
||||
converter_filename=None)
|
||||
self.assertRegex(
|
||||
em.get_message(),
|
||||
re.compile('test_fn_1.*test_fn_2.*Test message', re.DOTALL))
|
||||
|
||||
def test_get_message_converted_code(self):
|
||||
callsite_tb = [
|
||||
('/path/one.py', 11, 'test_fn_1', 'test code 1'),
|
||||
('/path/two.py', 171, 'test_fn_2', 'test code 2'),
|
||||
('/path/three.py', 171, 'test_fn_3', 'test code 3'),
|
||||
]
|
||||
cause_message = 'Test message'
|
||||
em = error_utils.ErrorMetadataBase(
|
||||
callsite_tb=callsite_tb,
|
||||
cause_metadata=None,
|
||||
cause_message=cause_message,
|
||||
source_map={
|
||||
origin_info.LineLocation(filename='/path/two.py', lineno=171):
|
||||
origin_info.OriginInfo(
|
||||
loc=origin_info.LineLocation(
|
||||
filename='/path/other_two.py', lineno=13),
|
||||
function_name='converted_fn',
|
||||
source_code_line='converted test code',
|
||||
comment=None)
|
||||
},
|
||||
converter_filename=None)
|
||||
result = em.get_message()
|
||||
self.assertRegex(
|
||||
result,
|
||||
re.compile(r'converted_fn \*.*test_fn_3.*Test message', re.DOTALL))
|
||||
self.assertNotRegex(result, re.compile('test_fn_1'))
|
||||
|
||||
def test_get_message_call_overload(self):
|
||||
|
||||
callsite_tb = [
|
||||
('/path/one.py', 11, 'test_fn_1', 'test code 1'),
|
||||
('/path/two.py', 0, 'test_fn_2', 'test code 2'),
|
||||
('/path/three.py', 0, 'test_fn_3', 'test code 3'),
|
||||
]
|
||||
cause_message = 'Test message'
|
||||
em = error_utils.ErrorMetadataBase(
|
||||
callsite_tb=callsite_tb,
|
||||
cause_metadata=None,
|
||||
cause_message=cause_message,
|
||||
source_map={},
|
||||
converter_filename='/path/two.py')
|
||||
self.assertRegex(
|
||||
em.get_message(),
|
||||
re.compile(r'test_fn_1.*test_fn_3 \*\*.*Test message', re.DOTALL))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -694,8 +694,7 @@ class SymbolicSupportTest(test.TestCase):
|
||||
if hasattr(e, 'ag_error_metadata'):
|
||||
self.assertIn('easily_identifiable_name', str(e))
|
||||
# See ErrorMetadataBase in autograph/pyct/errors.py
|
||||
# Topmost frame corresponds to `call` itself.
|
||||
function_name = e.ag_error_metadata.translated_stack[-2].function_name
|
||||
function_name = e.ag_error_metadata.translated_stack[-1].function_name
|
||||
else:
|
||||
tb = traceback.extract_tb(sys.exc_info()[2])
|
||||
last_entry = tb[-1]
|
||||
|
Loading…
Reference in New Issue
Block a user