Add supplemental verification to abort recursive conversion when it has been disabled. This covers code that has been already converted and is called through a do_not_convert wrapper. The code that has already been converted (e.g. as a function local to a previously-converted function) will remain converted, but any unconverted functions that that code calls will no longer be recursively converted.
Example: ``` def h(): ... @tf.function def f(): def g(): return h() return tf.autograph.do_not_convert(g)() ``` Before this change, `h()` was being transitively converted. After this change, `h()` will no longer be converted. Mark py_function's targets with do_not_convert. PiperOrigin-RevId: 281280243 Change-Id: I49f6c92350815a1e8eb38a51d919fe26055a0bd4
This commit is contained in:
parent
67ed6a56ca
commit
f2839145cf
@ -63,5 +63,15 @@ class ControlStatusCtx(object):
|
||||
_control_ctx().pop()
|
||||
|
||||
|
||||
class NullCtx(object):
|
||||
"""Helper substitute for contextlib.nullcontext."""
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, unused_type, unused_value, unused_traceback):
|
||||
pass
|
||||
|
||||
|
||||
def _default_control_status_ctx():
|
||||
return ControlStatusCtx(status=Status.UNSPECIFIED)
|
||||
|
@ -179,17 +179,28 @@ def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
|
||||
decorators, f = tf_decorator.unwrap(f)
|
||||
|
||||
# TODO(mdan): Grab features from context.
|
||||
# Note: we pass the original context through to convert to properly handle the
|
||||
# following scenario, which can be used insite TF implementations:
|
||||
#
|
||||
# ctx = ag_ctx.control_status_ctx()
|
||||
# @function(autograph=False) # Low-level graph code
|
||||
# def inner_fn():
|
||||
# # The context is disabled here, but should be enabled in user user_fn
|
||||
# tf_convert(user_fn, ctx=ctx)
|
||||
if ctx.status == ag_ctx.Status.ENABLED:
|
||||
wrapper = convert(recursive=True, user_requested=user_requested)(f)
|
||||
wrapper_factory = convert(
|
||||
recursive=True, user_requested=user_requested, conversion_ctx=ctx)
|
||||
elif ctx.status == ag_ctx.Status.DISABLED:
|
||||
wrapper = do_not_convert(f)
|
||||
wrapper_factory = do_not_convert
|
||||
elif ctx.status == ag_ctx.Status.UNSPECIFIED:
|
||||
if convert_by_default:
|
||||
wrapper = convert(recursive=True, user_requested=user_requested)(f)
|
||||
wrapper_factory = convert(
|
||||
recursive=True, user_requested=user_requested, conversion_ctx=ctx)
|
||||
else:
|
||||
wrapper = call_with_unspecified_conversion_status(f)
|
||||
wrapper_factory = call_with_unspecified_conversion_status
|
||||
else:
|
||||
raise ValueError(ctx.status)
|
||||
assert False, 'This switch contains all possible cases!'
|
||||
wrapper = wrapper_factory(f)
|
||||
|
||||
if decorators:
|
||||
wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)
|
||||
@ -199,7 +210,10 @@ def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
|
||||
|
||||
|
||||
# TODO(mdan): Make private.
|
||||
def convert(recursive=False, optional_features=None, user_requested=True):
|
||||
def convert(recursive=False,
|
||||
optional_features=None,
|
||||
user_requested=True,
|
||||
conversion_ctx=ag_ctx.NullCtx()):
|
||||
"""Decorator that compiles a function to use TensorFlow ops.
|
||||
|
||||
The decorator is dynamic - it recompiles the target whenever the decorated
|
||||
@ -213,8 +227,10 @@ def convert(recursive=False, optional_features=None, user_requested=True):
|
||||
optional_features: converted.Feature, allows toggling optional or
|
||||
experimental features. When set to None, only the core features are
|
||||
enabled.
|
||||
user_requested: bool, whether to ignore the conversion whitelist. See
|
||||
ConversionOptions.user_requested.
|
||||
user_requested: bool, whether this is a function that the user explicitly
|
||||
asked to be converted. See ConversionOptions.user_requested.
|
||||
conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in
|
||||
which `f` is used.
|
||||
|
||||
Returns:
|
||||
Callable, a decorator that converts the given function into an equivalent
|
||||
@ -231,7 +247,8 @@ def convert(recursive=False, optional_features=None, user_requested=True):
|
||||
user_requested=user_requested,
|
||||
optional_features=optional_features)
|
||||
try:
|
||||
return converted_call(f, args, kwargs, options=options)
|
||||
with conversion_ctx:
|
||||
return converted_call(f, args, kwargs, options=options)
|
||||
except Exception as e: # pylint:disable=broad-except
|
||||
if hasattr(e, 'ag_error_metadata'):
|
||||
raise e.ag_error_metadata.to_exception(e)
|
||||
@ -368,7 +385,11 @@ def _errors_are_normally_possible(entity, error):
|
||||
return False
|
||||
|
||||
|
||||
def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
|
||||
def converted_call(f,
|
||||
args,
|
||||
kwargs,
|
||||
caller_fn_scope=None,
|
||||
options=None):
|
||||
"""Compiles a function call inline.
|
||||
|
||||
For internal use only.
|
||||
@ -405,6 +426,10 @@ def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
|
||||
if conversion.check_cached_unconverted(f, options):
|
||||
return _call_unconverted(f, args, kwargs, options, False)
|
||||
|
||||
if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
|
||||
logging.log(2, 'Whitelisted: %s: AutoGraph is disabled in context', f)
|
||||
return _call_unconverted(f, args, kwargs, options, False)
|
||||
|
||||
if inspect_utils.isbuiltin(f):
|
||||
if f is eval:
|
||||
return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
|
||||
|
@ -722,6 +722,35 @@ class ApiTest(test.TestCase):
|
||||
|
||||
self.assertNoMemoryLeaks(test_fn)
|
||||
|
||||
def test_converted_call_no_caching_on_abort(self):
|
||||
|
||||
def test_fn(needs_autograph):
|
||||
if needs_autograph:
|
||||
if constant_op.constant(True):
|
||||
x = constant_op.constant(1)
|
||||
else:
|
||||
x = constant_op.constant(2)
|
||||
else:
|
||||
x = 3
|
||||
return x
|
||||
|
||||
def call_in_disabled_context():
|
||||
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
|
||||
return api.converted_call(
|
||||
test_fn, (False,), None, options=DEFAULT_RECURSIVE)
|
||||
|
||||
def call_in_default_context():
|
||||
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED):
|
||||
return api.converted_call(
|
||||
test_fn, (True,), None, options=DEFAULT_RECURSIVE)
|
||||
|
||||
# Note: this is an invariant, not a test (see above).
|
||||
assert call_in_disabled_context() == 3
|
||||
|
||||
# If api.convert placed test_fn in the unconverted cache, this second
|
||||
# invocation would fail.
|
||||
self.assertEqual(self.evaluate(call_in_default_context()), 1)
|
||||
|
||||
def test_context_tracking_direct_calls(self):
|
||||
|
||||
@api.do_not_convert()
|
||||
|
@ -19,6 +19,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import inspect
|
||||
import threading
|
||||
|
||||
# Used by py_util.cc to get tracebacks.
|
||||
@ -39,9 +40,15 @@ from tensorflow.python.ops import gen_script_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import lazy_loader
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
autograph = lazy_loader.LazyLoader(
|
||||
"autograph", globals(),
|
||||
"tensorflow.python.autograph.impl.api")
|
||||
|
||||
|
||||
# Map from EagerPyFunc token to tuple (tape, eager args, eager outputs);
|
||||
# used for differentiation.
|
||||
tape_cache = {}
|
||||
@ -275,6 +282,23 @@ def _internal_py_func(func,
|
||||
raise ValueError("Expected func to be callable, got func of type {}".format(
|
||||
type(func)))
|
||||
|
||||
original_func = func
|
||||
func = autograph.do_not_convert(func)
|
||||
|
||||
# Tying the registered function's lifetime with the current default graph is
|
||||
# not reliable. For example, Estimator-based binaries may switch graphs in
|
||||
# between model training end evaluation, via saved_model. Those binaries work
|
||||
# because the original function is global, and break once the registered
|
||||
# function is an anonymous lambda, like the one produced by do_not_convert.
|
||||
# To avoid breaking those cases, we attach the wrapper to the original
|
||||
# function so that their lifetime is connected.
|
||||
# TODO(b/144286616): Remove this.
|
||||
if inspect.isfunction(original_func):
|
||||
# Note: this check is needed because original_func may be a descriptor
|
||||
# (https://docs.python.org/3/howto/descriptor.html)
|
||||
# and we can't attach attributes to those.
|
||||
original_func.ag_dnc_wrapper__ = func
|
||||
|
||||
is_list_or_tuple = False
|
||||
if isinstance(Tout, (list, tuple)):
|
||||
is_list_or_tuple = True
|
||||
|
Loading…
Reference in New Issue
Block a user