Avoid force-converting the target function for tf_decorator. Unlike @tf.function, this is not a user annotation so there isn't any expectation that the target function should be converted no matter what.
PiperOrigin-RevId: 253713786
This commit is contained in:
parent
6982da07f3
commit
f9dd464df8
@ -130,7 +130,7 @@ class StackTraceMapper(tf_stack.StackTraceMapper):
|
||||
return origin.loc.filename, origin.loc.lineno, origin.function_name
|
||||
|
||||
|
||||
def tf_convert(f, ctx, convert_by_default=True):
|
||||
def tf_convert(f, ctx, convert_by_default=True, force_conversion=False):
|
||||
"""Decorator that applies AutoGraph to a function.
|
||||
|
||||
Use in internal APIs.
|
||||
@ -147,6 +147,8 @@ def tf_convert(f, ctx, convert_by_default=True):
|
||||
ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
|
||||
convert_by_default: bool, whether to use AutoGraph when the context doesn't
|
||||
specify.
|
||||
force_conversion: bool, whether to ignore the conversion whitelist. See
|
||||
ConversionOptions.force_conversion.
|
||||
|
||||
Returns:
|
||||
Either `f or the converted version of `f`.
|
||||
@ -162,7 +164,7 @@ def tf_convert(f, ctx, convert_by_default=True):
|
||||
ctx.status == ag_ctx.Status.UNSPECIFIED))
|
||||
if apply_autograph:
|
||||
# TODO(mdan): Grab features from context.
|
||||
wrapper = convert(recursive=True)(f)
|
||||
wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
|
||||
else:
|
||||
wrapper = do_not_convert(f)
|
||||
|
||||
@ -174,7 +176,7 @@ def tf_convert(f, ctx, convert_by_default=True):
|
||||
|
||||
|
||||
# TODO(mdan): Make private.
|
||||
def convert(recursive=False, optional_features=None):
|
||||
def convert(recursive=False, optional_features=None, force_conversion=True):
|
||||
"""Decorator that compiles a function to use TensorFlow ops.
|
||||
|
||||
The decorator is dynamic - it recompiles the target whenever the decorated
|
||||
@ -188,6 +190,8 @@ def convert(recursive=False, optional_features=None):
|
||||
optional_features: converted.Feature, allows toggling optional or
|
||||
experimental features. When set to None, only the core features are
|
||||
enabled.
|
||||
force_conversion: bool, whether to ignore the conversion whitelist. See
|
||||
ConversionOptions.force_conversion.
|
||||
|
||||
Returns:
|
||||
Callable, a decorator that converts the given function into an equivalent
|
||||
@ -207,7 +211,7 @@ def convert(recursive=False, optional_features=None):
|
||||
f, None,
|
||||
converter.ConversionOptions(
|
||||
recursive=recursive,
|
||||
force_conversion=True,
|
||||
force_conversion=force_conversion,
|
||||
optional_features=optional_features,
|
||||
), args, kwargs)
|
||||
except Exception as e: # pylint:disable=broad-except
|
||||
|
@ -869,6 +869,16 @@ class ApiTest(test.TestCase):
|
||||
# The code in `f` is only valid with AutoGraph.
|
||||
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))
|
||||
|
||||
def test_tf_convert_whitelisted_method(self):
|
||||
|
||||
model = sequential.Sequential([
|
||||
core.Dense(2)
|
||||
])
|
||||
converted_call = api.tf_convert(
|
||||
model.call, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
|
||||
_, converted_target = tf_decorator.unwrap(converted_call)
|
||||
self.assertIs(converted_target.__func__, model.call.__func__)
|
||||
|
||||
def test_tf_convert_wrapped(self):
|
||||
|
||||
def f():
|
||||
|
@ -564,7 +564,8 @@ class SymbolicSupportTest(test.TestCase):
|
||||
if hasattr(e, 'ag_error_metadata'):
|
||||
self.assertIn('easily_identifiable_name', str(e))
|
||||
# See ErrorMetadataBase in autograph/pyct/errors.py
|
||||
function_name = e.ag_error_metadata.translated_stack[-1].function_name
|
||||
# Topmost frame corresponds to `call` itself.
|
||||
function_name = e.ag_error_metadata.translated_stack[-2].function_name
|
||||
else:
|
||||
tb = traceback.extract_tb(sys.exc_info()[2])
|
||||
last_entry = tb[-1]
|
||||
|
Loading…
x
Reference in New Issue
Block a user