Avoid force-converting the target function for tf_decorator. Unlike @tf.function, this is not a user annotation, so there isn't any expectation that the target function should be converted no matter what.

PiperOrigin-RevId: 253713786
Author: Dan Moldovan, 2019-06-17 20:11:31 -07:00 (committed by TensorFlower Gardener)
Parent: 6982da07f3
Commit: f9dd464df8
3 changed files with 20 additions and 5 deletions
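
In practice, this means internal wrappers built on `tf_convert` now respect the AutoGraph conversion whitelist by default, while forced conversion remains available as an opt-in. A minimal sketch of the resulting behavior (not part of this commit; `plain_fn` is a hypothetical target function, and the imports assume the internal modules touched in the diffs below):

```python
# Sketch only: illustrates the new default, not code from this commit.
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api


def plain_fn(x):  # hypothetical target function
  return x * 2


ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)

# New default (force_conversion=False): the conversion whitelist is honored,
# so whitelisted callables (e.g. built-in Keras methods) are left unconverted.
wrapped = api.tf_convert(plain_fn, ctx)

# Previous behavior, now opt-in: ignore the whitelist and convert regardless.
forced = api.tf_convert(plain_fn, ctx, force_conversion=True)
```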


@@ -130,7 +130,7 @@ class StackTraceMapper(tf_stack.StackTraceMapper):
     return origin.loc.filename, origin.loc.lineno, origin.function_name


-def tf_convert(f, ctx, convert_by_default=True):
+def tf_convert(f, ctx, convert_by_default=True, force_conversion=False):
   """Decorator that applies AutoGraph to a function.

   Use in internal APIs.
@@ -147,6 +147,8 @@ def tf_convert(f, ctx, convert_by_default=True):
     ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
     convert_by_default: bool, whether to use AutoGraph when the context doesn't
       specify.
+    force_conversion: bool, whether to ignore the conversion whitelist. See
+      ConversionOptions.force_conversion.

   Returns:
     Either `f or the converted version of `f`.
@@ -162,7 +164,7 @@ def tf_convert(f, ctx, convert_by_default=True):
                       ctx.status == ag_ctx.Status.UNSPECIFIED))

   if apply_autograph:
     # TODO(mdan): Grab features from context.
-    wrapper = convert(recursive=True)(f)
+    wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
   else:
     wrapper = do_not_convert(f)
@@ -174,7 +176,7 @@ def tf_convert(f, ctx, convert_by_default=True):


 # TODO(mdan): Make private.
-def convert(recursive=False, optional_features=None):
+def convert(recursive=False, optional_features=None, force_conversion=True):
   """Decorator that compiles a function to use TensorFlow ops.

   The decorator is dynamic - it recompiles the target whenever the decorated
@@ -188,6 +190,8 @@ def convert(recursive=False, optional_features=None):
     optional_features: converted.Feature, allows toggling optional or
       experimental features. When set to None, only the core features are
       enabled.
+    force_conversion: bool, whether to ignore the conversion whitelist. See
+      ConversionOptions.force_conversion.

   Returns:
     Callable, a decorator that converts the given function into an equivalent
@@ -207,7 +211,7 @@
             f, None,
             converter.ConversionOptions(
                 recursive=recursive,
-                force_conversion=True,
+                force_conversion=force_conversion,
                 optional_features=optional_features,
             ), args, kwargs)
       except Exception as e:  # pylint:disable=broad-except
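
Note that the internal `convert` decorator keeps `force_conversion=True` as its default, so existing direct callers retain the old semantics; only `tf_convert` flips the default. A hedged usage sketch of the decorator after this change (not from this commit; `square_all` is a hypothetical function):

```python
# Sketch only: the decorator comes from the module changed above; passing
# force_conversion=False makes it honor the conversion whitelist as well.
from tensorflow.python.autograph.impl.api import convert


@convert(recursive=True, force_conversion=False)
def square_all(items):
  result = []
  for item in items:
    result.append(item * item)
  return result
```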


@@ -869,6 +869,16 @@ class ApiTest(test.TestCase):
     # The code in `f` is only valid with AutoGraph.
     test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))

+  def test_tf_convert_whitelisted_method(self):
+    model = sequential.Sequential([
+        core.Dense(2)
+    ])
+    converted_call = api.tf_convert(
+        model.call, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
+    _, converted_target = tf_decorator.unwrap(converted_call)
+    self.assertIs(converted_target.__func__, model.call.__func__)
+
   def test_tf_convert_wrapped(self):

     def f():


@@ -564,7 +564,8 @@ class SymbolicSupportTest(test.TestCase):
       if hasattr(e, 'ag_error_metadata'):
         self.assertIn('easily_identifiable_name', str(e))
         # See ErrorMetadataBase in autograph/pyct/errors.py
-        function_name = e.ag_error_metadata.translated_stack[-1].function_name
+        # Topmost frame corresponds to `call` itself.
+        function_name = e.ag_error_metadata.translated_stack[-2].function_name
       else:
         tb = traceback.extract_tb(sys.exc_info()[2])
         last_entry = tb[-1]