Avoid force-converting the target function for tf_decorator. Unlike @tf.function, this is not a user annotation so there isn't any expectation that the target function should be converted no matter what.
PiperOrigin-RevId: 253713786
This commit is contained in:
parent
6982da07f3
commit
f9dd464df8
@ -130,7 +130,7 @@ class StackTraceMapper(tf_stack.StackTraceMapper):
|
|||||||
return origin.loc.filename, origin.loc.lineno, origin.function_name
|
return origin.loc.filename, origin.loc.lineno, origin.function_name
|
||||||
|
|
||||||
|
|
||||||
def tf_convert(f, ctx, convert_by_default=True):
|
def tf_convert(f, ctx, convert_by_default=True, force_conversion=False):
|
||||||
"""Decorator that applies AutoGraph to a function.
|
"""Decorator that applies AutoGraph to a function.
|
||||||
|
|
||||||
Use in internal APIs.
|
Use in internal APIs.
|
||||||
@ -147,6 +147,8 @@ def tf_convert(f, ctx, convert_by_default=True):
|
|||||||
ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
|
ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
|
||||||
convert_by_default: bool, whether to use AutoGraph when the context doesn't
|
convert_by_default: bool, whether to use AutoGraph when the context doesn't
|
||||||
specify.
|
specify.
|
||||||
|
force_conversion: bool, whether to ignore the conversion whitelist. See
|
||||||
|
ConversionOptions.force_conversion.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Either `f or the converted version of `f`.
|
Either `f or the converted version of `f`.
|
||||||
@ -162,7 +164,7 @@ def tf_convert(f, ctx, convert_by_default=True):
|
|||||||
ctx.status == ag_ctx.Status.UNSPECIFIED))
|
ctx.status == ag_ctx.Status.UNSPECIFIED))
|
||||||
if apply_autograph:
|
if apply_autograph:
|
||||||
# TODO(mdan): Grab features from context.
|
# TODO(mdan): Grab features from context.
|
||||||
wrapper = convert(recursive=True)(f)
|
wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
|
||||||
else:
|
else:
|
||||||
wrapper = do_not_convert(f)
|
wrapper = do_not_convert(f)
|
||||||
|
|
||||||
@ -174,7 +176,7 @@ def tf_convert(f, ctx, convert_by_default=True):
|
|||||||
|
|
||||||
|
|
||||||
# TODO(mdan): Make private.
|
# TODO(mdan): Make private.
|
||||||
def convert(recursive=False, optional_features=None):
|
def convert(recursive=False, optional_features=None, force_conversion=True):
|
||||||
"""Decorator that compiles a function to use TensorFlow ops.
|
"""Decorator that compiles a function to use TensorFlow ops.
|
||||||
|
|
||||||
The decorator is dynamic - it recompiles the target whenever the decorated
|
The decorator is dynamic - it recompiles the target whenever the decorated
|
||||||
@ -188,6 +190,8 @@ def convert(recursive=False, optional_features=None):
|
|||||||
optional_features: converted.Feature, allows toggling optional or
|
optional_features: converted.Feature, allows toggling optional or
|
||||||
experimental features. When set to None, only the core features are
|
experimental features. When set to None, only the core features are
|
||||||
enabled.
|
enabled.
|
||||||
|
force_conversion: bool, whether to ignore the conversion whitelist. See
|
||||||
|
ConversionOptions.force_conversion.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Callable, a decorator that converts the given function into an equivalent
|
Callable, a decorator that converts the given function into an equivalent
|
||||||
@ -207,7 +211,7 @@ def convert(recursive=False, optional_features=None):
|
|||||||
f, None,
|
f, None,
|
||||||
converter.ConversionOptions(
|
converter.ConversionOptions(
|
||||||
recursive=recursive,
|
recursive=recursive,
|
||||||
force_conversion=True,
|
force_conversion=force_conversion,
|
||||||
optional_features=optional_features,
|
optional_features=optional_features,
|
||||||
), args, kwargs)
|
), args, kwargs)
|
||||||
except Exception as e: # pylint:disable=broad-except
|
except Exception as e: # pylint:disable=broad-except
|
||||||
|
@ -869,6 +869,16 @@ class ApiTest(test.TestCase):
|
|||||||
# The code in `f` is only valid with AutoGraph.
|
# The code in `f` is only valid with AutoGraph.
|
||||||
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))
|
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))
|
||||||
|
|
||||||
|
def test_tf_convert_whitelisted_method(self):
|
||||||
|
|
||||||
|
model = sequential.Sequential([
|
||||||
|
core.Dense(2)
|
||||||
|
])
|
||||||
|
converted_call = api.tf_convert(
|
||||||
|
model.call, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
|
||||||
|
_, converted_target = tf_decorator.unwrap(converted_call)
|
||||||
|
self.assertIs(converted_target.__func__, model.call.__func__)
|
||||||
|
|
||||||
def test_tf_convert_wrapped(self):
|
def test_tf_convert_wrapped(self):
|
||||||
|
|
||||||
def f():
|
def f():
|
||||||
|
@ -564,7 +564,8 @@ class SymbolicSupportTest(test.TestCase):
|
|||||||
if hasattr(e, 'ag_error_metadata'):
|
if hasattr(e, 'ag_error_metadata'):
|
||||||
self.assertIn('easily_identifiable_name', str(e))
|
self.assertIn('easily_identifiable_name', str(e))
|
||||||
# See ErrorMetadataBase in autograph/pyct/errors.py
|
# See ErrorMetadataBase in autograph/pyct/errors.py
|
||||||
function_name = e.ag_error_metadata.translated_stack[-1].function_name
|
# Topmost frame corresponds to `call` itself.
|
||||||
|
function_name = e.ag_error_metadata.translated_stack[-2].function_name
|
||||||
else:
|
else:
|
||||||
tb = traceback.extract_tb(sys.exc_info()[2])
|
tb = traceback.extract_tb(sys.exc_info()[2])
|
||||||
last_entry = tb[-1]
|
last_entry = tb[-1]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user