Avoid force-converting the target function for tf_decorator. Unlike @tf.function, this is not a user annotation so there isn't any expectation that the target function should be converted no matter what.

PiperOrigin-RevId: 253713786
This commit is contained in:
Dan Moldovan 2019-06-17 20:11:31 -07:00 committed by TensorFlower Gardener
parent 6982da07f3
commit f9dd464df8
3 changed files with 20 additions and 5 deletions

View File

@ -130,7 +130,7 @@ class StackTraceMapper(tf_stack.StackTraceMapper):
return origin.loc.filename, origin.loc.lineno, origin.function_name
def tf_convert(f, ctx, convert_by_default=True):
def tf_convert(f, ctx, convert_by_default=True, force_conversion=False):
"""Decorator that applies AutoGraph to a function.
Use in internal APIs.
@ -147,6 +147,8 @@ def tf_convert(f, ctx, convert_by_default=True):
ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
convert_by_default: bool, whether to use AutoGraph when the context doesn't
specify.
force_conversion: bool, whether to ignore the conversion whitelist. See
ConversionOptions.force_conversion.
Returns:
Either `f or the converted version of `f`.
@ -162,7 +164,7 @@ def tf_convert(f, ctx, convert_by_default=True):
ctx.status == ag_ctx.Status.UNSPECIFIED))
if apply_autograph:
# TODO(mdan): Grab features from context.
wrapper = convert(recursive=True)(f)
wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
else:
wrapper = do_not_convert(f)
@ -174,7 +176,7 @@ def tf_convert(f, ctx, convert_by_default=True):
# TODO(mdan): Make private.
def convert(recursive=False, optional_features=None):
def convert(recursive=False, optional_features=None, force_conversion=True):
"""Decorator that compiles a function to use TensorFlow ops.
The decorator is dynamic - it recompiles the target whenever the decorated
@ -188,6 +190,8 @@ def convert(recursive=False, optional_features=None):
optional_features: converted.Feature, allows toggling optional or
experimental features. When set to None, only the core features are
enabled.
force_conversion: bool, whether to ignore the conversion whitelist. See
ConversionOptions.force_conversion.
Returns:
Callable, a decorator that converts the given function into an equivalent
@ -207,7 +211,7 @@ def convert(recursive=False, optional_features=None):
f, None,
converter.ConversionOptions(
recursive=recursive,
force_conversion=True,
force_conversion=force_conversion,
optional_features=optional_features,
), args, kwargs)
except Exception as e: # pylint:disable=broad-except

View File

@ -869,6 +869,16 @@ class ApiTest(test.TestCase):
# The code in `f` is only valid with AutoGraph.
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))
def test_tf_convert_whitelisted_method(self):
model = sequential.Sequential([
core.Dense(2)
])
converted_call = api.tf_convert(
model.call, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
_, converted_target = tf_decorator.unwrap(converted_call)
self.assertIs(converted_target.__func__, model.call.__func__)
def test_tf_convert_wrapped(self):
def f():

View File

@ -564,7 +564,8 @@ class SymbolicSupportTest(test.TestCase):
if hasattr(e, 'ag_error_metadata'):
self.assertIn('easily_identifiable_name', str(e))
# See ErrorMetadataBase in autograph/pyct/errors.py
function_name = e.ag_error_metadata.translated_stack[-1].function_name
# Topmost frame corresponds to `call` itself.
function_name = e.ag_error_metadata.translated_stack[-2].function_name
else:
tb = traceback.extract_tb(sys.exc_info()[2])
last_entry = tb[-1]