Preserve the UNSPECIFIED conversion status when not converting by default. Fix the conversion decorator tot only apply functools.wraps if the target is a function.
PiperOrigin-RevId: 254323896
This commit is contained in:
parent
9fe47db864
commit
40b85798db
@ -159,14 +159,18 @@ def tf_convert(f, ctx, convert_by_default=True, force_conversion=False):
|
||||
f_wrapper = f
|
||||
decorators, f = tf_decorator.unwrap(f)
|
||||
|
||||
apply_autograph = ((ctx.status == ag_ctx.Status.ENABLED) or
|
||||
(convert_by_default and
|
||||
ctx.status == ag_ctx.Status.UNSPECIFIED))
|
||||
if apply_autograph:
|
||||
# TODO(mdan): Grab features from context.
|
||||
if ctx.status == ag_ctx.Status.ENABLED:
|
||||
wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
|
||||
elif ctx.status == ag_ctx.Status.DISABLED:
|
||||
wrapper = do_not_convert(f)
|
||||
elif ctx.status == ag_ctx.Status.UNSPECIFIED:
|
||||
if convert_by_default:
|
||||
wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
|
||||
else:
|
||||
wrapper = do_not_convert(f)
|
||||
wrapper = call_with_unspecified_conversion_status(f)
|
||||
else:
|
||||
raise ValueError(ctx.status)
|
||||
|
||||
if decorators:
|
||||
wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)
|
||||
@ -246,6 +250,19 @@ class RunMode(Enum):
|
||||
PY_FUNC = 2
|
||||
|
||||
|
||||
def call_with_unspecified_conversion_status(func):
|
||||
"""Decorator that resets the conversion context to the unspecified status."""
|
||||
def wrapper(*args, **kwargs):
|
||||
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if inspect.isfunction(func) or inspect.ismethod(func):
|
||||
wrapper = functools.update_wrapper(wrapper, func)
|
||||
|
||||
setattr(wrapper, '__ag_compiled', True)
|
||||
return wrapper
|
||||
|
||||
|
||||
def do_not_convert_internal(f):
|
||||
"""Decorator that marks internal functions which do not need conversion."""
|
||||
setattr(f, '__ag_compiled', True)
|
||||
@ -279,12 +296,10 @@ def do_not_convert(func=None, run_as=RunMode.GRAPH, return_dtypes=None):
|
||||
run_as=run_as,
|
||||
return_dtypes=return_dtypes)
|
||||
|
||||
@functools.wraps(func)
|
||||
def graph_wrapper(*args, **kwargs):
|
||||
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@functools.wraps(func)
|
||||
def py_func_wrapper(*args, **kwargs):
|
||||
if kwargs:
|
||||
raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
|
||||
@ -299,6 +314,9 @@ def do_not_convert(func=None, run_as=RunMode.GRAPH, return_dtypes=None):
|
||||
else:
|
||||
raise ValueError('unknown value for run_as: %s' % run_as)
|
||||
|
||||
if inspect.isfunction(func) or inspect.ismethod(func):
|
||||
wrapper = functools.update_wrapper(wrapper, func)
|
||||
|
||||
setattr(wrapper, '__ag_compiled', True)
|
||||
return wrapper
|
||||
|
||||
|
@ -212,6 +212,16 @@ class ApiTest(test.TestCase):
|
||||
self.assertEqual((),
|
||||
tuple(function_utils.fn_args(tc.test_method_whitelisted)))
|
||||
|
||||
def test_do_not_convert_callable_object(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
def __call__(self):
|
||||
return 1
|
||||
|
||||
tc = TestClass()
|
||||
self.assertEqual(1, api.do_not_convert(tc)())
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_convert_call_site_decorator(self):
|
||||
|
||||
@ -729,6 +739,12 @@ class ApiTest(test.TestCase):
|
||||
self.assertEqual(
|
||||
ag_ctx.control_status_ctx().status, ag_ctx.Status.UNSPECIFIED)
|
||||
|
||||
@api.call_with_unspecified_conversion_status
|
||||
def unspecified_fn():
|
||||
self.assertEqual(
|
||||
ag_ctx.control_status_ctx().status, ag_ctx.Status.UNSPECIFIED)
|
||||
unspecified_fn()
|
||||
|
||||
def test_to_graph_basic(self):
|
||||
|
||||
def test_fn(x, s):
|
||||
@ -888,6 +904,23 @@ class ApiTest(test.TestCase):
|
||||
# The code in `f` is only valid with AutoGraph.
|
||||
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))
|
||||
|
||||
def test_tf_convert_unspecified_not_converted_by_default(self):
|
||||
|
||||
def f():
|
||||
self.assertEqual(
|
||||
ag_ctx.control_status_ctx().status, ag_ctx.Status.UNSPECIFIED)
|
||||
if tf.reduce_sum([1, 2]) > 0:
|
||||
return -1
|
||||
return 1
|
||||
|
||||
@def_function.function
|
||||
def test_fn(ctx):
|
||||
return api.tf_convert(f, ctx, convert_by_default=False)()
|
||||
|
||||
with self.assertRaisesRegex(TypeError, 'tf.Tensor.*bool'):
|
||||
# The code in `f` is only valid with AutoGraph.
|
||||
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED))
|
||||
|
||||
def test_tf_convert_whitelisted_method(self):
|
||||
|
||||
model = sequential.Sequential([
|
||||
|
Loading…
Reference in New Issue
Block a user