Preserve the UNSPECIFIED conversion status when not converting by default. Fix the conversion decorator tot only apply functools.wraps if the target is a function.

PiperOrigin-RevId: 254323896
This commit is contained in:
Dan Moldovan 2019-06-20 20:15:29 -07:00 committed by TensorFlower Gardener
parent 9fe47db864
commit 40b85798db
2 changed files with 59 additions and 8 deletions

View File

@ -159,14 +159,18 @@ def tf_convert(f, ctx, convert_by_default=True, force_conversion=False):
f_wrapper = f f_wrapper = f
decorators, f = tf_decorator.unwrap(f) decorators, f = tf_decorator.unwrap(f)
apply_autograph = ((ctx.status == ag_ctx.Status.ENABLED) or
(convert_by_default and
ctx.status == ag_ctx.Status.UNSPECIFIED))
if apply_autograph:
# TODO(mdan): Grab features from context. # TODO(mdan): Grab features from context.
if ctx.status == ag_ctx.Status.ENABLED:
wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
elif ctx.status == ag_ctx.Status.DISABLED:
wrapper = do_not_convert(f)
elif ctx.status == ag_ctx.Status.UNSPECIFIED:
if convert_by_default:
wrapper = convert(recursive=True, force_conversion=force_conversion)(f) wrapper = convert(recursive=True, force_conversion=force_conversion)(f)
else: else:
wrapper = do_not_convert(f) wrapper = call_with_unspecified_conversion_status(f)
else:
raise ValueError(ctx.status)
if decorators: if decorators:
wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper) wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)
@ -246,6 +250,19 @@ class RunMode(Enum):
PY_FUNC = 2 PY_FUNC = 2
def call_with_unspecified_conversion_status(func):
"""Decorator that resets the conversion context to the unspecified status."""
def wrapper(*args, **kwargs):
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED):
return func(*args, **kwargs)
if inspect.isfunction(func) or inspect.ismethod(func):
wrapper = functools.update_wrapper(wrapper, func)
setattr(wrapper, '__ag_compiled', True)
return wrapper
def do_not_convert_internal(f): def do_not_convert_internal(f):
"""Decorator that marks internal functions which do not need conversion.""" """Decorator that marks internal functions which do not need conversion."""
setattr(f, '__ag_compiled', True) setattr(f, '__ag_compiled', True)
@ -279,12 +296,10 @@ def do_not_convert(func=None, run_as=RunMode.GRAPH, return_dtypes=None):
run_as=run_as, run_as=run_as,
return_dtypes=return_dtypes) return_dtypes=return_dtypes)
@functools.wraps(func)
def graph_wrapper(*args, **kwargs): def graph_wrapper(*args, **kwargs):
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED): with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
return func(*args, **kwargs) return func(*args, **kwargs)
@functools.wraps(func)
def py_func_wrapper(*args, **kwargs): def py_func_wrapper(*args, **kwargs):
if kwargs: if kwargs:
raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs') raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
@ -299,6 +314,9 @@ def do_not_convert(func=None, run_as=RunMode.GRAPH, return_dtypes=None):
else: else:
raise ValueError('unknown value for run_as: %s' % run_as) raise ValueError('unknown value for run_as: %s' % run_as)
if inspect.isfunction(func) or inspect.ismethod(func):
wrapper = functools.update_wrapper(wrapper, func)
setattr(wrapper, '__ag_compiled', True) setattr(wrapper, '__ag_compiled', True)
return wrapper return wrapper

View File

@ -212,6 +212,16 @@ class ApiTest(test.TestCase):
self.assertEqual((), self.assertEqual((),
tuple(function_utils.fn_args(tc.test_method_whitelisted))) tuple(function_utils.fn_args(tc.test_method_whitelisted)))
def test_do_not_convert_callable_object(self):
class TestClass(object):
def __call__(self):
return 1
tc = TestClass()
self.assertEqual(1, api.do_not_convert(tc)())
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def test_convert_call_site_decorator(self): def test_convert_call_site_decorator(self):
@ -729,6 +739,12 @@ class ApiTest(test.TestCase):
self.assertEqual( self.assertEqual(
ag_ctx.control_status_ctx().status, ag_ctx.Status.UNSPECIFIED) ag_ctx.control_status_ctx().status, ag_ctx.Status.UNSPECIFIED)
@api.call_with_unspecified_conversion_status
def unspecified_fn():
self.assertEqual(
ag_ctx.control_status_ctx().status, ag_ctx.Status.UNSPECIFIED)
unspecified_fn()
def test_to_graph_basic(self): def test_to_graph_basic(self):
def test_fn(x, s): def test_fn(x, s):
@ -888,6 +904,23 @@ class ApiTest(test.TestCase):
# The code in `f` is only valid with AutoGraph. # The code in `f` is only valid with AutoGraph.
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED)) test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))
def test_tf_convert_unspecified_not_converted_by_default(self):
def f():
self.assertEqual(
ag_ctx.control_status_ctx().status, ag_ctx.Status.UNSPECIFIED)
if tf.reduce_sum([1, 2]) > 0:
return -1
return 1
@def_function.function
def test_fn(ctx):
return api.tf_convert(f, ctx, convert_by_default=False)()
with self.assertRaisesRegex(TypeError, 'tf.Tensor.*bool'):
# The code in `f` is only valid with AutoGraph.
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED))
def test_tf_convert_whitelisted_method(self): def test_tf_convert_whitelisted_method(self):
model = sequential.Sequential([ model = sequential.Sequential([