diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 180b0a54fbe..dc92f1a990f 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -130,7 +130,7 @@ class StackTraceMapper(tf_stack.StackTraceMapper): return origin.loc.filename, origin.loc.lineno, origin.function_name -def tf_convert(f, ctx, convert_by_default=True): +def tf_convert(f, ctx, convert_by_default=True, force_conversion=False): """Decorator that applies AutoGraph to a function. Use in internal APIs. @@ -147,6 +147,8 @@ def tf_convert(f, ctx, convert_by_default=True): ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used. convert_by_default: bool, whether to use AutoGraph when the context doesn't specify. + force_conversion: bool, whether to ignore the conversion whitelist. See + ConversionOptions.force_conversion. Returns: Either `f or the converted version of `f`. @@ -162,7 +164,7 @@ def tf_convert(f, ctx, convert_by_default=True): ctx.status == ag_ctx.Status.UNSPECIFIED)) if apply_autograph: # TODO(mdan): Grab features from context. - wrapper = convert(recursive=True)(f) + wrapper = convert(recursive=True, force_conversion=force_conversion)(f) else: wrapper = do_not_convert(f) @@ -174,7 +176,7 @@ def tf_convert(f, ctx, convert_by_default=True): # TODO(mdan): Make private. -def convert(recursive=False, optional_features=None): +def convert(recursive=False, optional_features=None, force_conversion=True): """Decorator that compiles a function to use TensorFlow ops. The decorator is dynamic - it recompiles the target whenever the decorated @@ -188,6 +190,8 @@ def convert(recursive=False, optional_features=None): optional_features: converted.Feature, allows toggling optional or experimental features. When set to None, only the core features are enabled. + force_conversion: bool, whether to ignore the conversion whitelist. See + ConversionOptions.force_conversion. Returns: Callable, a decorator that converts the given function into an equivalent @@ -207,7 +211,7 @@ def convert(recursive=False, optional_features=None): f, None, converter.ConversionOptions( recursive=recursive, - force_conversion=True, + force_conversion=force_conversion, optional_features=optional_features, ), args, kwargs) except Exception as e: # pylint:disable=broad-except diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py index 749e0e5c252..2c8feb3b27d 100644 --- a/tensorflow/python/autograph/impl/api_test.py +++ b/tensorflow/python/autograph/impl/api_test.py @@ -869,6 +869,16 @@ class ApiTest(test.TestCase): # The code in `f` is only valid with AutoGraph. test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED)) + def test_tf_convert_whitelisted_method(self): + + model = sequential.Sequential([ + core.Dense(2) + ]) + converted_call = api.tf_convert( + model.call, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)) + _, converted_target = tf_decorator.unwrap(converted_call) + self.assertIs(converted_target.__func__, model.call.__func__) + def test_tf_convert_wrapped(self): def f(): diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index f59e4c6a652..d1847a530e9 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -564,7 +564,8 @@ class SymbolicSupportTest(test.TestCase): if hasattr(e, 'ag_error_metadata'): self.assertIn('easily_identifiable_name', str(e)) # See ErrorMetadataBase in autograph/pyct/errors.py - function_name = e.ag_error_metadata.translated_stack[-1].function_name + # Topmost frame corresponds to `call` itself. + function_name = e.ag_error_metadata.translated_stack[-2].function_name else: tb = traceback.extract_tb(sys.exc_info()[2]) last_entry = tb[-1]