Do not catch and warn for exceptions raised while calling a function without conversion. Fixes #36792.

PiperOrigin-RevId: 303366732
Change-Id: I579793f5625bf90e07ab8ea1cb97bd324bae9935
This commit is contained in:
Dan Moldovan 2020-03-27 11:09:26 -07:00 committed by TensorFlower Gardener
parent 8c65545d87
commit a0f6399164
2 changed files with 84 additions and 54 deletions

View File

@ -375,6 +375,44 @@ def _is_known_loaded_type(f, module_name, entity_name):
return False
def _fall_back_unconverted(f, args, kwargs, options, exc):
"""Falls back to calling the function unconverted, in case of error."""
# TODO(mdan): Consider adding an internal metric.
warning_template = (
'AutoGraph could not transform %s and will run it as-is.\n'
'%s'
'Cause: %s\n'
'To silence this warning, decorate the function with'
' @tf.autograph.experimental.do_not_convert')
if isinstance(exc, errors.UnsupportedLanguageElementError):
if not conversion.is_in_whitelist_cache(f, options):
logging.warn(warning_template, f, '', exc)
else:
file_bug_message = (
'Please report this to the TensorFlow team. When filing the bug, set'
' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
' attach the full output.\n')
logging.warn(warning_template, f, file_bug_message, exc)
return _call_unconverted(f, args, kwargs, options)
def _log_callargs(f, args, kwargs):
"""Logging helper."""
logging.log(2, 'Defaults of %s : %s', f, f.__defaults__)
if not six.PY2:
logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__)
if kwargs is not None:
callargs = tf_inspect.getcallargs(f, *args, **kwargs)
else:
callargs = tf_inspect.getcallargs(f, *args)
formatted_callargs = '\n'.join(
' {}: {}'.format(k, v) for k, v in callargs.items())
logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs)
def converted_call(f,
args,
kwargs,
@ -498,9 +536,7 @@ def converted_call(f,
if not options.internal_convert_user_code:
return _call_unconverted(f, args, kwargs, options)
# TODO(mdan): Move this entire block inside to_graph.
try: # Begin of transformation error guards
try:
if inspect.ismethod(f) or inspect.isfunction(f):
target_entity = f
effective_args = args
@ -514,6 +550,8 @@ def converted_call(f,
elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
# Callable objects. Dunder methods have special lookup rules, see:
# https://docs.python.org/3/reference/datamodel.html#specialnames
# TODO(mdan): Recurse into converted_call to simplify other verifications.
# This should be handled in the same way as partials.
target_entity = f.__class__.__call__
effective_args = (f,) + args
@ -521,63 +559,34 @@ def converted_call(f,
target_entity = f
raise NotImplementedError('unknown callable type "%s"' % type(f))
if not inspect.isclass(target_entity):
if not hasattr(target_entity, '__code__'):
logging.log(2, 'Permanently whitelisted: %s: native binding',
target_entity)
return _call_unconverted(f, args, kwargs, options)
elif (hasattr(target_entity.__code__, 'co_filename') and
target_entity.__code__.co_filename == '<string>'):
# TODO(mdan): __globals__['txt'] might work in Py3.
logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)',
target_entity)
return _call_unconverted(f, args, kwargs, options)
program_ctx = converter.ProgramContext(
options=options, autograph_module=tf_inspect.getmodule(converted_call))
converted_f = conversion.convert(target_entity, program_ctx)
if logging.has_verbosity(2):
logging.log(2, 'Defaults of %s : %s', converted_f,
converted_f.__defaults__)
if not six.PY2:
logging.log(2, 'KW defaults of %s : %s',
converted_f, converted_f.__kwdefaults__)
if kwargs is not None:
callargs = tf_inspect.getcallargs(converted_f, *effective_args,
**kwargs)
else:
callargs = tf_inspect.getcallargs(converted_f, *effective_args)
formatted_callargs = '\n'.join(
' {}: {}'.format(k, v) for k, v in callargs.items())
logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)
except Exception as e: # pylint:disable=broad-except
logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
if is_autograph_strict_conversion_mode():
raise
return _fall_back_unconverted(f, args, kwargs, options, e)
warning_template = (
'AutoGraph could not transform %s and will run it as-is.\n'
'%s'
'Cause: %s\n'
'To silence this warning, decorate the function with'
' @tf.autograph.experimental.do_not_convert')
if isinstance(e, errors.UnsupportedLanguageElementError):
# Repeating the check made upon function entry because the state might
# have updated in the meantime.
if not conversion.is_in_whitelist_cache(f, options):
logging.warn(warning_template, target_entity, '', e)
else:
file_bug_message = (
'Please report this to the TensorFlow team. When filing the bug, set'
' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
' attach the full output.\n')
logging.warn(warning_template, target_entity, file_bug_message, e)
if not hasattr(target_entity, '__code__'):
logging.log(2, 'Permanently whitelisted: %s: native binding',
target_entity)
return _call_unconverted(f, args, kwargs, options)
elif (hasattr(target_entity.__code__, 'co_filename') and
target_entity.__code__.co_filename == '<string>'):
# TODO(mdan): __globals__['txt'] might work in Py3.
logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)',
target_entity)
return _call_unconverted(f, args, kwargs, options)
try:
program_ctx = converter.ProgramContext(
options=options, autograph_module=tf_inspect.getmodule(converted_call))
converted_f = conversion.convert(target_entity, program_ctx)
if logging.has_verbosity(2):
_log_callargs(converted_f, effective_args, kwargs)
except Exception as e: # pylint:disable=broad-except
logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
if is_autograph_strict_conversion_mode():
raise
return _fall_back_unconverted(f, args, kwargs, options, e)
with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
try:

View File

@ -753,6 +753,27 @@ class ApiTest(test.TestCase):
self.assertAllEqual(1, self.evaluate(x))
def test_converted_call_native_binding(self):
x = api.converted_call(np.power, (2, 2), None, options=DEFAULT_RECURSIVE)
self.assertAllEqual(x, 4)
def test_converted_call_native_binding_errorneous(self):
class FaultyBinding(object):
def __array__(self):
raise ValueError('fault')
bad_obj = FaultyBinding()
def fail_if_warning(*_):
self.fail('No warning should be issued')
with test.mock.patch.object(ag_logging, 'warn', fail_if_warning):
with self.assertRaisesRegex(ValueError, 'fault'):
api.converted_call(
np.power, (bad_obj, 2), None, options=DEFAULT_RECURSIVE)
def test_converted_call_through_tf_dataset(self):
def other_fn(x):