diff --git a/tensorflow/python/autograph/g3doc/reference/error_handling.md b/tensorflow/python/autograph/g3doc/reference/error_handling.md index ce3a64f8f28..6b1808404aa 100644 --- a/tensorflow/python/autograph/g3doc/reference/error_handling.md +++ b/tensorflow/python/autograph/g3doc/reference/error_handling.md @@ -38,6 +38,10 @@ Among the distinctive features of the re-raised exception: the `@tf.function` * the references corresponding to converted code are marked with an asterisk (`*`) + * the references corresponding to code which AutoGraph reached, but decided not + to convert, are marked with a double asterisk (`**`) + * the references corresponding to code that AutoGraph didn't reach at all have + no marking For example, the code below triggers an exception in the Python runtime, at graph construction time: @@ -62,7 +66,7 @@ TypeError: in converted code: :8 f * tf.constant(1) + tf.constant(1.0) - tensorflow/python/ops/math_ops.py:900 binary_op_wrapper + tensorflow/python/ops/math_ops.py:900 binary_op_wrapper ** return func(x, y, name=name) ... more TensorFlow internal frames ... diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 613add8a881..dbcdf4333c6 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -312,23 +312,24 @@ def do_not_convert(func=None): return autograph_artifact(wrapper) -def _attach_metadata(e, f, converted): +def _attach_metadata(e, f): """Augments an error with the metadata necessary for rewrite.""" if hasattr(e, 'ag_pass_through'): return metadata = getattr(e, 'ag_error_metadata', None) - source_map = f.ag_source_map if converted else {} + source_map = f.ag_source_map if metadata is None: - logging.log( - 1, 'Caught error in %s (converted=%s)', f, converted, exc_info=True) + logging.log(1, 'Caught error in user callable %s', f, exc_info=True) message = '{}: {}'.format(e.__class__.__name__, e) else: message = None cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:] - e.ag_error_metadata = _ErrorMetadata(cause_tb, metadata, message, source_map) + + e.ag_error_metadata = _ErrorMetadata( + cause_tb, metadata, message, source_map, __file__) def _call_unconverted(f, args, kwargs, options, update_cache=True): @@ -339,14 +340,10 @@ def _call_unconverted(f, args, kwargs, options, update_cache=True): if inspect_utils.istfmethodtarget(f): return f.__self__.call(args, kwargs) - try: - if kwargs is not None: - return f(*args, **kwargs) - else: - return f(*args) - except Exception as e: # pylint:disable=broad-except - _attach_metadata(e, f, False) - raise + if kwargs is not None: + return f(*args, **kwargs) + else: + return f(*args) def _is_known_loaded_type(f, module_name, entity_name): @@ -584,7 +581,7 @@ def converted_call(f, else: result = converted_f(*effective_args) except Exception as e: - _attach_metadata(e, converted_f, True) + _attach_metadata(e, converted_f) raise return result diff --git a/tensorflow/python/autograph/pyct/error_utils.py b/tensorflow/python/autograph/pyct/error_utils.py index 150f1bb0584..3f7ace067fe 100644 --- a/tensorflow/python/autograph/pyct/error_utils.py +++ b/tensorflow/python/autograph/pyct/error_utils.py @@ -26,11 +26,13 @@ from tensorflow.python.autograph.pyct import origin_info class FrameInfo( collections.namedtuple( 'FrameInfo', - ('filename', 'lineno', 'function_name', 'code', 'converted'))): - pass + ('filename', 'lineno', 'function_name', 'code', 'is_converted', + 'is_whitelisted'))): + + __slots__ = () -def _stack_trace_inside_mapped_code(tb, source_map): +def _stack_trace_inside_mapped_code(tb, source_map, converter_filename): """Summarizes inner traceback frames up to the call to a given function. This functions locates the innermost (i.e. most recent) frame that corresponds @@ -67,10 +69,14 @@ def _stack_trace_inside_mapped_code(tb, source_map): raise ... Args: - tb: List[Tuple], the traceback corresponding to an error; typically, - the output of traceback.extract_tb. + tb: traceback.FrameSummary, The traceback corresponding to an error. + Typically, the output of traceback.Summary.extract(capture_locals=True). source_map: Dict[LineLocation, OriginInfo], a source map as created by origin_info.create_source_map. + converter_filename: str, the file path of the converted module. Call frames + corresponding to this module are elided and their preceding frames are + marked as whitelisted. Note that frames enclosing converted code are + dropped using a different mechanism. Returns: List[FrameInfo] @@ -81,21 +87,37 @@ def _stack_trace_inside_mapped_code(tb, source_map): loc = origin_info.LineLocation(filename=filename, lineno=line_number) if loc in source_map: origin = source_map[loc] - origin_frame_info = FrameInfo( + fi = FrameInfo( filename=origin.loc.filename, lineno=origin.loc.lineno, function_name=origin.function_name, code=origin.source_code_line, - converted=True) - result_frames.append(origin_frame_info) + is_converted=True, + is_whitelisted=False) + result_frames.append(fi) break + if filename == converter_filename: + if result_frames: + prev = result_frames[-1] + assert not prev.is_converted # See the if above. + fi = FrameInfo( + filename=prev.filename, + lineno=prev.lineno, + function_name=prev.function_name, + code=prev.code, + is_converted=False, + is_whitelisted=True) + result_frames[-1] = fi + continue + fi = FrameInfo( filename=filename, lineno=line_number, function_name=function_name, code=text, - converted=False) + is_converted=False, + is_whitelisted=False) result_frames.append(fi) return tuple(result_frames) @@ -136,8 +158,12 @@ class ErrorMetadataBase(object): code from which the generated code originated. """ - def __init__(self, callsite_tb, cause_metadata, cause_message, source_map): - translated_stack = _stack_trace_inside_mapped_code(callsite_tb, source_map) + __slots__ = ('translated_stack', 'cause_message') + + def __init__(self, callsite_tb, cause_metadata, cause_message, source_map, + converter_filename): + translated_stack = _stack_trace_inside_mapped_code( + callsite_tb, source_map, converter_filename) if cause_metadata is None: self.translated_stack = translated_stack @@ -156,12 +182,15 @@ class ErrorMetadataBase(object): lines.append('') for frame_info in reversed(self.translated_stack): - lines.append(' {}:{} {}{}'.format( - frame_info.filename, - frame_info.lineno, - frame_info.function_name, - ' *' if frame_info.converted else '', - )) + formatted_line = ' {}:{} {}'.format(frame_info.filename, + frame_info.lineno, + frame_info.function_name) + if frame_info.is_converted: + formatted_line += ' *' + elif frame_info.is_whitelisted: + formatted_line += ' **' + lines.append(formatted_line) + if frame_info.code is None: code_snippet = '' else: diff --git a/tensorflow/python/autograph/pyct/error_utils_test.py b/tensorflow/python/autograph/pyct/error_utils_test.py index 9fdbc55579e..601a2c59796 100644 --- a/tensorflow/python/autograph/pyct/error_utils_test.py +++ b/tensorflow/python/autograph/pyct/error_utils_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import re from tensorflow.python.autograph.pyct import error_utils +from tensorflow.python.autograph.pyct import origin_info from tensorflow.python.platform import test @@ -35,7 +36,8 @@ class ErrorMetadataBaseTest(test.TestCase): callsite_tb=(), cause_metadata=None, cause_message='test message', - source_map={}) + source_map={}, + converter_filename=None) exc = em.create_exception(CustomError()) self.assertIsInstance(exc, CustomError) self.assertIn('test message', str(exc)) @@ -51,11 +53,12 @@ class ErrorMetadataBaseTest(test.TestCase): callsite_tb=(), cause_metadata=None, cause_message='test message', - source_map={}) + source_map={}, + converter_filename=None) exc = em.create_exception(CustomError()) self.assertIsNone(exc) - def test_get_message_when_frame_info_code_is_none(self): + def test_get_message_no_code(self): callsite_tb = [ ('/path/one.py', 11, 'test_fn_1', None), ('/path/two.py', 171, 'test_fn_2', 'test code'), @@ -65,11 +68,57 @@ class ErrorMetadataBaseTest(test.TestCase): callsite_tb=callsite_tb, cause_metadata=None, cause_message=cause_message, - source_map={}) + source_map={}, + converter_filename=None) self.assertRegex( em.get_message(), re.compile('test_fn_1.*test_fn_2.*Test message', re.DOTALL)) + def test_get_message_converted_code(self): + callsite_tb = [ + ('/path/one.py', 11, 'test_fn_1', 'test code 1'), + ('/path/two.py', 171, 'test_fn_2', 'test code 2'), + ('/path/three.py', 171, 'test_fn_3', 'test code 3'), + ] + cause_message = 'Test message' + em = error_utils.ErrorMetadataBase( + callsite_tb=callsite_tb, + cause_metadata=None, + cause_message=cause_message, + source_map={ + origin_info.LineLocation(filename='/path/two.py', lineno=171): + origin_info.OriginInfo( + loc=origin_info.LineLocation( + filename='/path/other_two.py', lineno=13), + function_name='converted_fn', + source_code_line='converted test code', + comment=None) + }, + converter_filename=None) + result = em.get_message() + self.assertRegex( + result, + re.compile(r'converted_fn \*.*test_fn_3.*Test message', re.DOTALL)) + self.assertNotRegex(result, re.compile('test_fn_1')) + + def test_get_message_call_overload(self): + + callsite_tb = [ + ('/path/one.py', 11, 'test_fn_1', 'test code 1'), + ('/path/two.py', 0, 'test_fn_2', 'test code 2'), + ('/path/three.py', 0, 'test_fn_3', 'test code 3'), + ] + cause_message = 'Test message' + em = error_utils.ErrorMetadataBase( + callsite_tb=callsite_tb, + cause_metadata=None, + cause_message=cause_message, + source_map={}, + converter_filename='/path/two.py') + self.assertRegex( + em.get_message(), + re.compile(r'test_fn_1.*test_fn_3 \*\*.*Test message', re.DOTALL)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index d0ff510ebc3..02433be946d 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -694,8 +694,7 @@ class SymbolicSupportTest(test.TestCase): if hasattr(e, 'ag_error_metadata'): self.assertIn('easily_identifiable_name', str(e)) # See ErrorMetadataBase in autograph/pyct/errors.py - # Topmost frame corresponds to `call` itself. - function_name = e.ag_error_metadata.translated_stack[-2].function_name + function_name = e.ag_error_metadata.translated_stack[-1].function_name else: tb = traceback.extract_tb(sys.exc_info()[2]) last_entry = tb[-1]