Exclude frames corresponding to function calls found in the main autograph module (api.py) from the stack trace attached to errors that occur in converted code.

### Example

Given the code:

```
d = {}

@tf.function
def f(d):
  d.pop('a')

f(d)
```

Error message before:

```
KeyError: in converted code:

    <file>:<line> f  *
        d.pop('a')
    <internal TF file>:<line> converted_call
        return py_builtins.overload_of(f)(*args)

    KeyError: 'a'
```

Error message after:

```
KeyError: in converted code:

    <file>:<line> f  **
        d.pop('a')

    KeyError: 'a'
```

PiperOrigin-RevId: 285178695
Change-Id: Ic6c3d4c57b5478dd560b17c7708f5e502f62a41b
This commit is contained in:
Dan Moldovan 2019-12-12 06:22:48 -08:00 committed by TensorFlower Gardener
parent b963d6a436
commit 61f2288b4e
5 changed files with 116 additions and 38 deletions
tensorflow/python

View File

@ -38,6 +38,10 @@ Among the distinctive features of the re-raised exception:
the `@tf.function`
* the references corresponding to converted code are marked with an
asterisk (`*`)
* the references corresponding to code which AutoGraph reached, but decided not
to convert, are marked with a double asterisk (`**`)
* the references corresponding to code that AutoGraph didn't reach at all have
no marking
For example, the code below triggers an exception in the Python runtime, at
graph construction time:
@ -62,7 +66,7 @@ TypeError: in converted code:
<ipython-input-9-002fa22f79df>:8 f *
tf.constant(1) + tf.constant(1.0)
tensorflow/python/ops/math_ops.py:900 binary_op_wrapper
tensorflow/python/ops/math_ops.py:900 binary_op_wrapper **
return func(x, y, name=name)
... more TensorFlow internal frames ...

View File

@ -312,23 +312,24 @@ def do_not_convert(func=None):
return autograph_artifact(wrapper)
def _attach_metadata(e, f, converted):
def _attach_metadata(e, f):
"""Augments an error with the metadata necessary for rewrite."""
if hasattr(e, 'ag_pass_through'):
return
metadata = getattr(e, 'ag_error_metadata', None)
source_map = f.ag_source_map if converted else {}
source_map = f.ag_source_map
if metadata is None:
logging.log(
1, 'Caught error in %s (converted=%s)', f, converted, exc_info=True)
logging.log(1, 'Caught error in user callable %s', f, exc_info=True)
message = '{}: {}'.format(e.__class__.__name__, e)
else:
message = None
cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:]
e.ag_error_metadata = _ErrorMetadata(cause_tb, metadata, message, source_map)
e.ag_error_metadata = _ErrorMetadata(
cause_tb, metadata, message, source_map, __file__)
def _call_unconverted(f, args, kwargs, options, update_cache=True):
@ -339,14 +340,10 @@ def _call_unconverted(f, args, kwargs, options, update_cache=True):
if inspect_utils.istfmethodtarget(f):
return f.__self__.call(args, kwargs)
try:
if kwargs is not None:
return f(*args, **kwargs)
else:
return f(*args)
except Exception as e: # pylint:disable=broad-except
_attach_metadata(e, f, False)
raise
if kwargs is not None:
return f(*args, **kwargs)
else:
return f(*args)
def _is_known_loaded_type(f, module_name, entity_name):
@ -584,7 +581,7 @@ def converted_call(f,
else:
result = converted_f(*effective_args)
except Exception as e:
_attach_metadata(e, converted_f, True)
_attach_metadata(e, converted_f)
raise
return result

View File

@ -26,11 +26,13 @@ from tensorflow.python.autograph.pyct import origin_info
class FrameInfo(
collections.namedtuple(
'FrameInfo',
('filename', 'lineno', 'function_name', 'code', 'converted'))):
pass
('filename', 'lineno', 'function_name', 'code', 'is_converted',
'is_whitelisted'))):
__slots__ = ()
def _stack_trace_inside_mapped_code(tb, source_map):
def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
"""Summarizes inner traceback frames up to the call to a given function.
This functions locates the innermost (i.e. most recent) frame that corresponds
@ -67,10 +69,14 @@ def _stack_trace_inside_mapped_code(tb, source_map):
raise ...
Args:
tb: List[Tuple], the traceback corresponding to an error; typically,
the output of traceback.extract_tb.
tb: traceback.FrameSummary, The traceback corresponding to an error.
Typically, the output of traceback.Summary.extract(capture_locals=True).
source_map: Dict[LineLocation, OriginInfo], a source map as created by
origin_info.create_source_map.
converter_filename: str, the file path of the converted module. Call frames
corresponding to this module are elided and their preceding frames are
marked as whitelisted. Note that frames enclosing converted code are
dropped using a different mechanism.
Returns:
List[FrameInfo]
@ -81,21 +87,37 @@ def _stack_trace_inside_mapped_code(tb, source_map):
loc = origin_info.LineLocation(filename=filename, lineno=line_number)
if loc in source_map:
origin = source_map[loc]
origin_frame_info = FrameInfo(
fi = FrameInfo(
filename=origin.loc.filename,
lineno=origin.loc.lineno,
function_name=origin.function_name,
code=origin.source_code_line,
converted=True)
result_frames.append(origin_frame_info)
is_converted=True,
is_whitelisted=False)
result_frames.append(fi)
break
if filename == converter_filename:
if result_frames:
prev = result_frames[-1]
assert not prev.is_converted # See the if above.
fi = FrameInfo(
filename=prev.filename,
lineno=prev.lineno,
function_name=prev.function_name,
code=prev.code,
is_converted=False,
is_whitelisted=True)
result_frames[-1] = fi
continue
fi = FrameInfo(
filename=filename,
lineno=line_number,
function_name=function_name,
code=text,
converted=False)
is_converted=False,
is_whitelisted=False)
result_frames.append(fi)
return tuple(result_frames)
@ -136,8 +158,12 @@ class ErrorMetadataBase(object):
code from which the generated code originated.
"""
def __init__(self, callsite_tb, cause_metadata, cause_message, source_map):
translated_stack = _stack_trace_inside_mapped_code(callsite_tb, source_map)
__slots__ = ('translated_stack', 'cause_message')
def __init__(self, callsite_tb, cause_metadata, cause_message, source_map,
converter_filename):
translated_stack = _stack_trace_inside_mapped_code(
callsite_tb, source_map, converter_filename)
if cause_metadata is None:
self.translated_stack = translated_stack
@ -156,12 +182,15 @@ class ErrorMetadataBase(object):
lines.append('')
for frame_info in reversed(self.translated_stack):
lines.append(' {}:{} {}{}'.format(
frame_info.filename,
frame_info.lineno,
frame_info.function_name,
' *' if frame_info.converted else '',
))
formatted_line = ' {}:{} {}'.format(frame_info.filename,
frame_info.lineno,
frame_info.function_name)
if frame_info.is_converted:
formatted_line += ' *'
elif frame_info.is_whitelisted:
formatted_line += ' **'
lines.append(formatted_line)
if frame_info.code is None:
code_snippet = '<source unavailable>'
else:

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import re
from tensorflow.python.autograph.pyct import error_utils
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.platform import test
@ -35,7 +36,8 @@ class ErrorMetadataBaseTest(test.TestCase):
callsite_tb=(),
cause_metadata=None,
cause_message='test message',
source_map={})
source_map={},
converter_filename=None)
exc = em.create_exception(CustomError())
self.assertIsInstance(exc, CustomError)
self.assertIn('test message', str(exc))
@ -51,11 +53,12 @@ class ErrorMetadataBaseTest(test.TestCase):
callsite_tb=(),
cause_metadata=None,
cause_message='test message',
source_map={})
source_map={},
converter_filename=None)
exc = em.create_exception(CustomError())
self.assertIsNone(exc)
def test_get_message_when_frame_info_code_is_none(self):
def test_get_message_no_code(self):
callsite_tb = [
('/path/one.py', 11, 'test_fn_1', None),
('/path/two.py', 171, 'test_fn_2', 'test code'),
@ -65,11 +68,57 @@ class ErrorMetadataBaseTest(test.TestCase):
callsite_tb=callsite_tb,
cause_metadata=None,
cause_message=cause_message,
source_map={})
source_map={},
converter_filename=None)
self.assertRegex(
em.get_message(),
re.compile('test_fn_1.*test_fn_2.*Test message', re.DOTALL))
def test_get_message_converted_code(self):
callsite_tb = [
('/path/one.py', 11, 'test_fn_1', 'test code 1'),
('/path/two.py', 171, 'test_fn_2', 'test code 2'),
('/path/three.py', 171, 'test_fn_3', 'test code 3'),
]
cause_message = 'Test message'
em = error_utils.ErrorMetadataBase(
callsite_tb=callsite_tb,
cause_metadata=None,
cause_message=cause_message,
source_map={
origin_info.LineLocation(filename='/path/two.py', lineno=171):
origin_info.OriginInfo(
loc=origin_info.LineLocation(
filename='/path/other_two.py', lineno=13),
function_name='converted_fn',
source_code_line='converted test code',
comment=None)
},
converter_filename=None)
result = em.get_message()
self.assertRegex(
result,
re.compile(r'converted_fn \*.*test_fn_3.*Test message', re.DOTALL))
self.assertNotRegex(result, re.compile('test_fn_1'))
def test_get_message_call_overload(self):
callsite_tb = [
('/path/one.py', 11, 'test_fn_1', 'test code 1'),
('/path/two.py', 0, 'test_fn_2', 'test code 2'),
('/path/three.py', 0, 'test_fn_3', 'test code 3'),
]
cause_message = 'Test message'
em = error_utils.ErrorMetadataBase(
callsite_tb=callsite_tb,
cause_metadata=None,
cause_message=cause_message,
source_map={},
converter_filename='/path/two.py')
self.assertRegex(
em.get_message(),
re.compile(r'test_fn_1.*test_fn_3 \*\*.*Test message', re.DOTALL))
if __name__ == '__main__':
test.main()

View File

@ -694,8 +694,7 @@ class SymbolicSupportTest(test.TestCase):
if hasattr(e, 'ag_error_metadata'):
self.assertIn('easily_identifiable_name', str(e))
# See ErrorMetadataBase in autograph/pyct/errors.py
# Topmost frame corresponds to `call` itself.
function_name = e.ag_error_metadata.translated_stack[-2].function_name
function_name = e.ag_error_metadata.translated_stack[-1].function_name
else:
tb = traceback.extract_tb(sys.exc_info()[2])
last_entry = tb[-1]