diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 616a74e4f2a..146d4b6ec2c 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -375,6 +375,44 @@ def _is_known_loaded_type(f, module_name, entity_name): return False +def _fall_back_unconverted(f, args, kwargs, options, exc): + """Falls back to calling the function unconverted, in case of error.""" + # TODO(mdan): Consider adding an internal metric. + warning_template = ( + 'AutoGraph could not transform %s and will run it as-is.\n' + '%s' + 'Cause: %s\n' + 'To silence this warning, decorate the function with' + ' @tf.autograph.experimental.do_not_convert') + if isinstance(exc, errors.UnsupportedLanguageElementError): + if not conversion.is_in_whitelist_cache(f, options): + logging.warn(warning_template, f, '', exc) + else: + file_bug_message = ( + 'Please report this to the TensorFlow team. When filing the bug, set' + ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and' + ' attach the full output.\n') + logging.warn(warning_template, f, file_bug_message, exc) + + return _call_unconverted(f, args, kwargs, options) + + +def _log_callargs(f, args, kwargs): + """Logging helper.""" + logging.log(2, 'Defaults of %s : %s', f, f.__defaults__) + if not six.PY2: + logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__) + + if kwargs is not None: + callargs = tf_inspect.getcallargs(f, *args, **kwargs) + else: + callargs = tf_inspect.getcallargs(f, *args) + + formatted_callargs = '\n'.join( + ' {}: {}'.format(k, v) for k, v in callargs.items()) + logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs) + + def converted_call(f, args, kwargs, @@ -498,9 +536,7 @@ def converted_call(f, if not options.internal_convert_user_code: return _call_unconverted(f, args, kwargs, options) - # TODO(mdan): Move this entire block inside to_graph. - try: # Begin of transformation error guards - + try: if inspect.ismethod(f) or inspect.isfunction(f): target_entity = f effective_args = args @@ -514,6 +550,8 @@ def converted_call(f, elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'): # Callable objects. Dunder methods have special lookup rules, see: # https://docs.python.org/3/reference/datamodel.html#specialnames + # TODO(mdan): Recurse into converted_call to simplify other verifications. + # This should be handled in the same way as partials. target_entity = f.__class__.__call__ effective_args = (f,) + args @@ -521,63 +559,34 @@ def converted_call(f, target_entity = f raise NotImplementedError('unknown callable type "%s"' % type(f)) - if not inspect.isclass(target_entity): - if not hasattr(target_entity, '__code__'): - logging.log(2, 'Permanently whitelisted: %s: native binding', - target_entity) - return _call_unconverted(f, args, kwargs, options) - elif (hasattr(target_entity.__code__, 'co_filename') and - target_entity.__code__.co_filename == ''): - # TODO(mdan): __globals__['txt'] might work in Py3. - logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)', - target_entity) - return _call_unconverted(f, args, kwargs, options) - - program_ctx = converter.ProgramContext( - options=options, autograph_module=tf_inspect.getmodule(converted_call)) - converted_f = conversion.convert(target_entity, program_ctx) - - if logging.has_verbosity(2): - logging.log(2, 'Defaults of %s : %s', converted_f, - converted_f.__defaults__) - if not six.PY2: - logging.log(2, 'KW defaults of %s : %s', - converted_f, converted_f.__kwdefaults__) - - if kwargs is not None: - callargs = tf_inspect.getcallargs(converted_f, *effective_args, - **kwargs) - else: - callargs = tf_inspect.getcallargs(converted_f, *effective_args) - - formatted_callargs = '\n'.join( - ' {}: {}'.format(k, v) for k, v in callargs.items()) - logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs) - except Exception as e: # pylint:disable=broad-except logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True) if is_autograph_strict_conversion_mode(): raise + return _fall_back_unconverted(f, args, kwargs, options, e) - warning_template = ( - 'AutoGraph could not transform %s and will run it as-is.\n' - '%s' - 'Cause: %s\n' - 'To silence this warning, decorate the function with' - ' @tf.autograph.experimental.do_not_convert') - if isinstance(e, errors.UnsupportedLanguageElementError): - # Repeating the check made upon function entry because the state might - # have updated in the meantime. - if not conversion.is_in_whitelist_cache(f, options): - logging.warn(warning_template, target_entity, '', e) - else: - file_bug_message = ( - 'Please report this to the TensorFlow team. When filing the bug, set' - ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and' - ' attach the full output.\n') - logging.warn(warning_template, target_entity, file_bug_message, e) - + if not hasattr(target_entity, '__code__'): + logging.log(2, 'Permanently whitelisted: %s: native binding', + target_entity) return _call_unconverted(f, args, kwargs, options) + elif (hasattr(target_entity.__code__, 'co_filename') and + target_entity.__code__.co_filename == ''): + # TODO(mdan): __globals__['txt'] might work in Py3. + logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)', + target_entity) + return _call_unconverted(f, args, kwargs, options) + + try: + program_ctx = converter.ProgramContext( + options=options, autograph_module=tf_inspect.getmodule(converted_call)) + converted_f = conversion.convert(target_entity, program_ctx) + if logging.has_verbosity(2): + _log_callargs(converted_f, effective_args, kwargs) + except Exception as e: # pylint:disable=broad-except + logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True) + if is_autograph_strict_conversion_mode(): + raise + return _fall_back_unconverted(f, args, kwargs, options, e) with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter(): try: diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py index 4365edaaa8e..d8f73f20674 100644 --- a/tensorflow/python/autograph/impl/api_test.py +++ b/tensorflow/python/autograph/impl/api_test.py @@ -753,6 +753,27 @@ class ApiTest(test.TestCase): self.assertAllEqual(1, self.evaluate(x)) + def test_converted_call_native_binding(self): + x = api.converted_call(np.power, (2, 2), None, options=DEFAULT_RECURSIVE) + self.assertAllEqual(x, 4) + + def test_converted_call_native_binding_errorneous(self): + + class FaultyBinding(object): + + def __array__(self): + raise ValueError('fault') + + bad_obj = FaultyBinding() + + def fail_if_warning(*_): + self.fail('No warning should be issued') + + with test.mock.patch.object(ag_logging, 'warn', fail_if_warning): + with self.assertRaisesRegex(ValueError, 'fault'): + api.converted_call( + np.power, (bad_obj, 2), None, options=DEFAULT_RECURSIVE) + def test_converted_call_through_tf_dataset(self): def other_fn(x):