Add a supplemental check that aborts recursive conversion when conversion has been disabled. This covers code that has already been converted and is then called through a do_not_convert wrapper. Code that has already been converted (e.g. a function local to a previously-converted function) remains converted, but any unconverted functions that that code calls will no longer be recursively converted.

Example:

```
def h():
  ...

@tf.function
def f():
  def g():
    return h()
  return tf.autograph.do_not_convert(g)()
```

Before this change, `h()` was being transitively converted. After this change, `h()` will no longer be converted.
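Explicit conversion requests are unaffected: if `h` itself asks to be converted, it still is. A minimal sketch under that assumption, spelling out the public TF 2.x symbol (`tf.autograph.experimental.do_not_convert`):

```
import tensorflow as tf

@tf.function  # h explicitly requests conversion; this change does not alter that
def h():
  return tf.constant(1)

@tf.function
def f():
  def g():
    return h()  # h is already converted, so no recursive conversion is needed
  return tf.autograph.experimental.do_not_convert(g)()
```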

Mark py_function's targets with do_not_convert.
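For context, `tf.py_function` executes its target eagerly as ordinary Python, so AutoGraph conversion of the target is wasted work; this change marks the target accordingly at registration time. A short usage sketch (the function name `report` is illustrative):

```
import tensorflow as tf

def report(x):
  # Runs as plain eager Python even under tf.function; AutoGraph
  # conversion would add nothing here.
  print("saw:", x.numpy())
  return x

@tf.function
def f(x):
  return tf.py_function(report, inp=[x], Tout=tf.float32)
```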

PiperOrigin-RevId: 281280243
Change-Id: I49f6c92350815a1e8eb38a51d919fe26055a0bd4
Dan Moldovan 2019-11-19 06:12:33 -08:00 committed by TensorFlower Gardener
parent 67ed6a56ca
commit f2839145cf
4 changed files with 98 additions and 10 deletions

tensorflow/python/autograph/core/ag_ctx.py

@@ -63,5 +63,15 @@ class ControlStatusCtx(object):
     _control_ctx().pop()


+class NullCtx(object):
+  """Helper substitute for contextlib.nullcontext."""
+
+  def __enter__(self):
+    pass
+
+  def __exit__(self, unused_type, unused_value, unused_traceback):
+    pass
+
+
 def _default_control_status_ctx():
   return ControlStatusCtx(status=Status.UNSPECIFIED)
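For comparison, a minimal sketch of the standard-library equivalent that `NullCtx` substitutes for; `contextlib.nullcontext` requires Python 3.7, which TensorFlow could not yet assume at the time:

```
import contextlib

# Python 3.7+ equivalent of NullCtx: a context manager that does nothing.
with contextlib.nullcontext():
  pass

# NullCtx gives convert() a safe default, so it can write
# `with conversion_ctx:` unconditionally.
```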

tensorflow/python/autograph/impl/api.py

@@ -179,17 +179,28 @@ def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
   decorators, f = tf_decorator.unwrap(f)

   # TODO(mdan): Grab features from context.
+  # Note: we pass the original context through to convert to properly handle the
+  # following scenario, which can be used inside TF implementations:
+  #
+  #   ctx = ag_ctx.control_status_ctx()
+  #   @function(autograph=False)  # Low-level graph code
+  #   def inner_fn():
+  #     # The context is disabled here, but should be enabled in user_fn.
+  #     tf_convert(user_fn, ctx=ctx)
   if ctx.status == ag_ctx.Status.ENABLED:
-    wrapper = convert(recursive=True, user_requested=user_requested)(f)
+    wrapper_factory = convert(
+        recursive=True, user_requested=user_requested, conversion_ctx=ctx)
   elif ctx.status == ag_ctx.Status.DISABLED:
-    wrapper = do_not_convert(f)
+    wrapper_factory = do_not_convert
   elif ctx.status == ag_ctx.Status.UNSPECIFIED:
     if convert_by_default:
-      wrapper = convert(recursive=True, user_requested=user_requested)(f)
+      wrapper_factory = convert(
+          recursive=True, user_requested=user_requested, conversion_ctx=ctx)
     else:
-      wrapper = call_with_unspecified_conversion_status(f)
+      wrapper_factory = call_with_unspecified_conversion_status
   else:
-    raise ValueError(ctx.status)
+    assert False, 'This switch contains all possible cases!'
+
+  wrapper = wrapper_factory(f)

   if decorators:
     wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)
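From the user-API side, the statuses dispatched above correspond roughly to the following (a hedged sketch using public TF 2.x symbols):

```
import tensorflow as tf

@tf.function                   # autograph=True (default): Status.ENABLED
def converted(x):
  return x + 1

@tf.function(autograph=False)  # Status.DISABLED: no AutoGraph conversion
def unconverted(x):
  return x + 1

# Plain calls outside either decorator run with Status.UNSPECIFIED.
```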
@@ -199,7 +210,10 @@ def tf_convert(f, ctx, convert_by_default=True, user_requested=False):


 # TODO(mdan): Make private.
-def convert(recursive=False, optional_features=None, user_requested=True):
+def convert(recursive=False,
+            optional_features=None,
+            user_requested=True,
+            conversion_ctx=ag_ctx.NullCtx()):
   """Decorator that compiles a function to use TensorFlow ops.

   The decorator is dynamic - it recompiles the target whenever the decorated
@@ -213,8 +227,10 @@ def convert(recursive=False, optional_features=None, user_requested=True):
     optional_features: converted.Feature, allows toggling optional or
       experimental features. When set to None, only the core features are
       enabled.
-    user_requested: bool, whether to ignore the conversion whitelist. See
-      ConversionOptions.user_requested.
+    user_requested: bool, whether this is a function that the user explicitly
+      asked to be converted. See ConversionOptions.user_requested.
+    conversion_ctx: Optional ag_ctx.ControlStatusCtx, the AutoGraph context in
+      which `f` is used.

   Returns:
     Callable, a decorator that converts the given function into an equivalent
@@ -231,7 +247,8 @@ def convert(recursive=False, optional_features=None, user_requested=True):
           user_requested=user_requested,
           optional_features=optional_features)
       try:
-        return converted_call(f, args, kwargs, options=options)
+        with conversion_ctx:
+          return converted_call(f, args, kwargs, options=options)
       except Exception as e:  # pylint:disable=broad-except
         if hasattr(e, 'ag_error_metadata'):
           raise e.ag_error_metadata.to_exception(e)
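The `with conversion_ctx:` above re-enters, at call time, the context that was captured when the wrapper was built. A standalone sketch of that capture-and-re-enter pattern (toy names, not TF internals):

```
import contextlib

_status_stack = ["UNSPECIFIED"]  # toy stand-in for the AutoGraph context stack

@contextlib.contextmanager
def status_ctx(status):
  _status_stack.append(status)
  try:
    yield
  finally:
    _status_stack.pop()

def make_wrapper(fn, saved_status):
  # saved_status is captured once, when the wrapper is created...
  def wrapper(*args, **kwargs):
    # ...and re-established on every call, so code running inside fn
    # observes the status that was current at wrap time.
    with status_ctx(saved_status):
      return fn(*args, **kwargs)
  return wrapper
```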
@@ -368,7 +385,11 @@ def _errors_are_normally_possible(entity, error):
   return False


-def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
+def converted_call(f,
+                   args,
+                   kwargs,
+                   caller_fn_scope=None,
+                   options=None):
   """Compiles a function call inline.

   For internal use only.
@@ -405,6 +426,10 @@ def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
   if conversion.check_cached_unconverted(f, options):
     return _call_unconverted(f, args, kwargs, options, False)

+  if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
+    logging.log(2, 'Whitelisted: %s: AutoGraph is disabled in context', f)
+    return _call_unconverted(f, args, kwargs, options, False)
+
   if inspect_utils.isbuiltin(f):
     if f is eval:
       return py_builtins.eval_in_original_context(f, args, caller_fn_scope)

tensorflow/python/autograph/impl/api_test.py

@@ -722,6 +722,35 @@ class ApiTest(test.TestCase):

     self.assertNoMemoryLeaks(test_fn)

+  def test_converted_call_no_caching_on_abort(self):
+
+    def test_fn(needs_autograph):
+      if needs_autograph:
+        if constant_op.constant(True):
+          x = constant_op.constant(1)
+        else:
+          x = constant_op.constant(2)
+      else:
+        x = 3
+      return x
+
+    def call_in_disabled_context():
+      with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
+        return api.converted_call(
+            test_fn, (False,), None, options=DEFAULT_RECURSIVE)
+
+    def call_in_default_context():
+      with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED):
+        return api.converted_call(
+            test_fn, (True,), None, options=DEFAULT_RECURSIVE)
+
+    # Note: this is an invariant, not a test (see above).
+    assert call_in_disabled_context() == 3
+
+    # If api.convert placed test_fn in the unconverted cache, this second
+    # invocation would fail.
+    self.assertEqual(self.evaluate(call_in_default_context()), 1)
+
   def test_context_tracking_direct_calls(self):

     @api.do_not_convert()

tensorflow/python/ops/script_ops.py

@@ -19,6 +19,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import inspect
 import threading

 # Used by py_util.cc to get tracebacks.
@@ -39,9 +40,15 @@ from tensorflow.python.ops import gen_script_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.util import compat
 from tensorflow.python.util import deprecation
+from tensorflow.python.util import lazy_loader
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export

+autograph = lazy_loader.LazyLoader(
+    "autograph", globals(),
+    "tensorflow.python.autograph.impl.api")
+
+
 # Map from EagerPyFunc token to tuple (tape, eager args, eager outputs);
 # used for differentiation.
 tape_cache = {}
@@ -275,6 +282,23 @@ def _internal_py_func(func,
     raise ValueError("Expected func to be callable, got func of type {}".format(
         type(func)))

+  original_func = func
+  func = autograph.do_not_convert(func)
+
+  # Tying the registered function's lifetime with the current default graph is
+  # not reliable. For example, Estimator-based binaries may switch graphs in
+  # between model training and evaluation, via saved_model. Those binaries work
+  # because the original function is global, and break once the registered
+  # function is an anonymous lambda, like the one produced by do_not_convert.
+  # To avoid breaking those cases, we attach the wrapper to the original
+  # function so that their lifetime is connected.
+  # TODO(b/144286616): Remove this.
+  if inspect.isfunction(original_func):
+    # Note: this check is needed because original_func may be a descriptor
+    # (https://docs.python.org/3/howto/descriptor.html)
+    # and we can't attach attributes to those.
+    original_func.ag_dnc_wrapper__ = func
+
   is_list_or_tuple = False
   if isinstance(Tout, (list, tuple)):
     is_list_or_tuple = True
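A self-contained sketch of the lifetime concern described in the comment above, using `weakref` to make the effect observable (names are illustrative):

```
import weakref

def original():  # module-level function: lives as long as the module does
  return 42

wrapper = lambda: original()  # anonymous wrapper, akin to do_not_convert's
ref = weakref.ref(wrapper)

# Attach the wrapper to the original function, tying their lifetimes.
original.ag_dnc_wrapper__ = wrapper
del wrapper

assert ref()() == 42  # the attribute keeps the wrapper alive and callable
```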