STT-tensorflow/tensorflow/python/autograph/impl/api_test.py
Ran Chen 8c7d50fcfd Fix tf_convert on decorated methods, again
The original fix was incorrect. tf_convert() should continue to return the bound
method, instead of the underlying function.

PiperOrigin-RevId: 334471644
Change-Id: Ic48ce06e54dd9240144bf946aa5a987a87078de5
2020-09-29 15:16:29 -07:00

1267 lines
35 KiB
Python

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for api module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import collections
import contextlib
import functools
import gc
import imp
import os
import re
import sys
import textwrap
import types
import numpy as np
import six
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.impl import conversion
from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
global_n = 2
DEFAULT_RECURSIVE = converter.ConversionOptions(recursive=True)
class TestResource(object):
def __init__(self):
self.x = 3
class ApiTest(test.TestCase):
@contextlib.contextmanager
def assertPrints(self, expected, not_expected):
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
yield
self.assertIn(expected, out_capturer.getvalue())
self.assertNotIn(not_expected, out_capturer.getvalue())
finally:
sys.stdout = sys.__stdout__
def assertNoMemoryLeaks(self, f):
object_ids_before = {id(o) for o in gc.get_objects()}
f()
gc.collect()
objects_after = tuple(
o for o in gc.get_objects() if id(o) not in object_ids_before)
self.assertEmpty(
tuple(o for o in objects_after if isinstance(o, TestResource)))
def test_converted_call_kwonly_args(self):
def test_fn(*, a):
return a
x = api.converted_call(
test_fn, (), {'a': constant_op.constant(-1)}, options=DEFAULT_RECURSIVE)
self.assertEqual(-1, self.evaluate(x))
def test_super_with_no_arg(self):
test_case_self = self
class TestBase:
def plus_three(self, x):
return x + 3
class TestSubclass(TestBase):
def plus_three(self, x):
test_case_self.fail('This should never be called.')
def no_arg(self, x):
return super().plus_three(x)
tc = api.converted_call(TestSubclass, (), None, options=DEFAULT_RECURSIVE)
self.assertEqual(5, tc.no_arg(2))
def test_converted_call_avoids_triggering_operators(self):
test_self = self
class Pair(collections.namedtuple('Pair', ['a', 'b'])):
def __call__(self):
return self.a + self.b
def __eq__(self, other):
test_self.fail('Triggered operator')
p = Pair(constant_op.constant(1), constant_op.constant(2))
x = api.converted_call(p, (), {}, options=DEFAULT_RECURSIVE)
self.assertIsNotNone(self.evaluate(x), 3)
@test_util.run_deprecated_v1
def test_decorator_recursive(self):
class TestClass(object):
def called_member(self, a):
if a < 0:
a = -a
return a
@api.convert(recursive=True)
def test_method(self, x, s, a):
while math_ops.reduce_sum(x) > s:
x //= self.called_member(a)
return x
tc = TestClass()
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
@test_util.run_deprecated_v1
def test_decorator_not_recursive(self):
class TestClass(object):
def called_member(self, a):
return math_ops.negative(a)
@api.convert(recursive=False)
def test_method(self, x, s, a):
while math_ops.reduce_sum(x) > s:
x //= self.called_member(a)
return x
tc = TestClass()
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
@test_util.run_deprecated_v1
def test_convert_then_do_not_convert(self):
class TestClass(object):
@api.do_not_convert
def called_member(self, a):
return math_ops.negative(a)
@api.convert(recursive=True)
def test_method(self, x, s, a):
while math_ops.reduce_sum(x) > s:
x //= self.called_member(a)
return x
tc = TestClass()
x = tc.test_method(
constant_op.constant((2, 4)), constant_op.constant(1),
constant_op.constant(-2))
self.assertAllEqual((0, 1), self.evaluate(x))
@test_util.run_deprecated_v1
def test_decorator_calls_decorated(self):
class TestClass(object):
@api.convert()
def called_member(self, a):
if a < 0:
a = -a
return a
@api.convert(recursive=True)
def test_method(self, x, s, a):
while math_ops.reduce_sum(x) > s:
x //= self.called_member(a)
return x
tc = TestClass()
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
def test_decorator_preserves_argspec(self):
class TestClass(object):
def test_method(self, a):
if a < 0:
a = -a
return a
test_method_converted = api.convert()(test_method)
tc = TestClass()
self.assertListEqual(
list(tf_inspect.getfullargspec(tc.test_method)),
list(tf_inspect.getfullargspec(tc.test_method_converted)))
def test_do_not_convert_argspec(self):
class TestClass(object):
def test_method(self, x, y):
z = x + y
return z
test_method_allowlisted = api.do_not_convert(test_method)
tc = TestClass()
self.assertTrue(tf_inspect.ismethod(tc.test_method_allowlisted))
# Because the wrapped function is not generated, we can't preserve its
# arg spec.
self.assertEqual((),
tuple(function_utils.fn_args(tc.test_method_allowlisted)))
def test_do_not_convert_callable_object(self):
class TestClass(object):
def __call__(self):
return 1
tc = TestClass()
self.assertEqual(1, api.do_not_convert(tc)())
@test_util.run_deprecated_v1
def test_convert_call_site_decorator(self):
class TestClass(object):
def called_member(self, a):
if a < 0:
a = -a
return a
@api.convert(recursive=True)
def test_method(self, x, s, a):
while math_ops.reduce_sum(x) > s:
x //= api.converted_call(
self.called_member, (a,), None, options=DEFAULT_RECURSIVE)
return x
tc = TestClass()
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
def test_converted_call_builtin(self):
x = api.converted_call(range, (3,), None, options=DEFAULT_RECURSIVE)
self.assertEqual((0, 1, 2), tuple(x))
x = api.converted_call(
re.compile, ('mnas_v4_a.*\\/.*(weights|kernel):0$',),
None,
options=DEFAULT_RECURSIVE)
self.assertIsNotNone(x.match('mnas_v4_a/weights:0'))
def test_converted_call_function(self):
def test_fn(x):
if x < 0:
return -x
return x
x = api.converted_call(
test_fn, (constant_op.constant(-1),), None, options=DEFAULT_RECURSIVE)
self.assertEqual(1, self.evaluate(x))
@test_util.run_v1_only('b/120545219')
def test_converted_call_functools_partial(self):
def test_fn(x, y, z):
if x < 0:
return -x, -y, -z
return x, y, z
x = api.converted_call(
functools.partial(test_fn, constant_op.constant(-1), z=-3),
(constant_op.constant(-2),),
None,
options=DEFAULT_RECURSIVE)
self.assertEqual((1, 2, 3), self.evaluate(x))
x = api.converted_call(
functools.partial(
functools.partial(test_fn, constant_op.constant(-1)), z=-3),
(constant_op.constant(-2),),
None,
options=DEFAULT_RECURSIVE)
self.assertEqual((1, 2, 3), self.evaluate(x))
@test_util.run_v1_only('b/120545219')
def test_converted_call_functools_partial_kwarg_mutation(self):
def test_fn(x, y, z):
if x < 0:
return -x, -y, -z
return x, y, z
partial_fn = functools.partial(test_fn, constant_op.constant(-1), z=-3)
# Call using kwargs to assign y first to ensure that partial_fn.keywords is
# not mutated for subsequent calls (where y is assign through args).
x = api.converted_call(
partial_fn,
args=(),
kwargs={
'y': constant_op.constant(-2),
},
options=DEFAULT_RECURSIVE)
self.assertEqual((1, 2, 3), self.evaluate(x))
x = api.converted_call(
partial_fn,
args=(constant_op.constant(-4),),
kwargs=None,
options=DEFAULT_RECURSIVE)
self.assertEqual((1, 4, 3), self.evaluate(x))
def test_converted_call_method(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def test_method(self):
if self.x < 0:
return -self.x
return self.x
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE)
self.assertEqual(1, self.evaluate(x))
def test_converted_call_synthetic_method(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def test_function(self):
if self.x < 0:
return -self.x
return self.x
tc = TestClass(constant_op.constant(-1))
test_method = types.MethodType(test_function, tc)
x = api.converted_call(test_method, (), None, options=DEFAULT_RECURSIVE)
self.assertEqual(1, self.evaluate(x))
def test_converted_call_method_wrapper(self):
class TestClass(object):
def foo(self):
pass
tc = TestClass()
# `method.__get__()` returns a so-called method-wrapper.
wrapper = api.converted_call(
tc.foo.__get__, (tc,), None, options=DEFAULT_RECURSIVE)
self.assertEqual(wrapper, tc.foo)
def test_converted_call_method_as_object_attribute(self):
class AnotherClass(object):
def __init__(self):
self.another_class_attr = constant_op.constant(1)
def method(self):
if self.another_class_attr > 0:
return self.another_class_attr + 1
return self.another_class_attr + 10
class TestClass(object):
def __init__(self, another_obj_method):
self.another_obj_method = another_obj_method
obj = AnotherClass()
tc = TestClass(obj.method)
x = api.converted_call(
tc.another_obj_method, (), None, options=DEFAULT_RECURSIVE)
self.assertEqual(self.evaluate(x), 2)
def test_converted_call_method_converts_recursively(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def other_method(self):
if self.x < 0:
return -self.x
return self.x
def test_method(self):
return self.other_method()
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE)
self.assertEqual(1, self.evaluate(x))
def test_converted_call_method_by_class(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def test_method(self):
if self.x < 0:
return -self.x
return self.x
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(
TestClass.test_method, (tc,), None, options=DEFAULT_RECURSIVE)
self.assertEqual(1, self.evaluate(x))
def test_converted_call_callable_object(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def __call__(self):
if self.x < 0:
return -self.x
return self.x
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(tc, (), None, options=DEFAULT_RECURSIVE)
self.assertEqual(1, self.evaluate(x))
def test_converted_call_callable_metaclass(self):
test_self = self
class TestMetaclass(type):
def __call__(cls):
self.assertTrue(converter_testing.is_inside_generated_code())
inst = object.__new__(cls)
inst.__init__()
def instance_call(unused_self):
test_self.fail(
'The class-bound __call__ should be called, not the instance'
' bound one.')
inst.__call__ = instance_call
return inst
tmc = TestMetaclass('TestClass', (), {})
tc = api.converted_call(tmc, (), None, options=DEFAULT_RECURSIVE)
self.assertIsInstance(tc, tmc)
def test_converted_call_callable_abc(self):
test_self = self
@six.add_metaclass(abc.ABCMeta)
class TestBase(object):
@abc.abstractmethod
def __call__(self):
test_self.fail('This should not be called')
class TestSubclass(TestBase):
def __init__(self):
test_self.assertFalse(converter_testing.is_inside_generated_code())
def __call__(self, expected):
test_self.assertTrue(expected)
test_self.assertTrue(converter_testing.is_inside_generated_code())
tc = api.converted_call(TestSubclass, (), None, options=DEFAULT_RECURSIVE)
api.converted_call(tc, (True,), None, options=DEFAULT_RECURSIVE)
@test_util.run_deprecated_v1
def test_converted_call_constructor(self):
test_self = self
class TestClass(object):
def __init__(self):
test_self.assertFalse(converter_testing.is_inside_generated_code())
tc = api.converted_call(TestClass, (), None, options=DEFAULT_RECURSIVE)
self.assertIsInstance(tc, TestClass)
def test_converted_call_mangled_properties(self):
class TestClass(object):
def __init__(self):
self.__private = constant_op.constant(-1)
def test_method(self):
return self.__private
tc = TestClass()
with self.assertRaisesRegex(
errors.UnsupportedLanguageElementError, 'mangled names'):
api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE)
# TODO(mdan): Refactor to avoid this use of global state.
ag_logging.set_verbosity(0, True)
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '0'
with self.assertPrints('could not transform', 'bug'):
api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE)
ag_logging.set_verbosity(0, False)
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
def test_converted_call_partial_of_allowlisted_function(self):
def test_fn(_):
self.assertFalse(converter_testing.is_inside_generated_code())
converter_testing.allowlist(test_fn)
api.converted_call(
functools.partial(test_fn, None), (), None, options=DEFAULT_RECURSIVE)
def test_converted_call_already_converted(self):
def f(x):
return x == 0
x = api.converted_call(
f, (constant_op.constant(0),), None, options=DEFAULT_RECURSIVE)
self.assertTrue(self.evaluate(x))
converted_f = api.to_graph(
f, experimental_optional_features=converter.Feature.ALL)
x = api.converted_call(
converted_f, (constant_op.constant(0),),
None,
options=DEFAULT_RECURSIVE)
self.assertTrue(self.evaluate(x))
def test_converted_call_then_already_converted_dynamic(self):
@api.convert()
def g(x):
if x > 0:
return x
else:
return -x
def f(g, x):
return g(x)
x = api.converted_call(
f, (g, constant_op.constant(1)), None, options=DEFAULT_RECURSIVE)
self.assertEqual(self.evaluate(x), 1)
def test_converted_call_forced_when_explicitly_allowlisted(self):
@api.do_not_convert()
def f(x):
return x + 1
opts = converter.ConversionOptions(recursive=True, user_requested=True)
x = api.converted_call(f, (constant_op.constant(0),), None, options=opts)
self.assertTrue(self.evaluate(x))
converted_f = api.to_graph(
f, experimental_optional_features=converter.Feature.ALL)
x = api.converted_call(converted_f, (0,), None, options=DEFAULT_RECURSIVE)
self.assertEqual(x, 1)
@test_util.run_deprecated_v1
def test_converted_call_no_user_code(self):
def f(x):
return len(x)
opts = converter.ConversionOptions(internal_convert_user_code=False)
# f should not be converted, causing len to error out.
with self.assertRaisesRegex(Exception, 'len is not well defined'):
api.converted_call(f, (constant_op.constant([0]),), None, options=opts)
# len on the other hand should work fine.
x = api.converted_call(
len, (constant_op.constant([0]),), None, options=opts)
# The constant has static shape so the result is a primitive not a Tensor.
self.assertEqual(x, 1)
def test_converted_call_no_kwargs_allowed(self):
def f(*args):
# Note: np.broadcast rejects any **kwargs, even *{}
return np.broadcast(args[:1])
opts = converter.ConversionOptions(internal_convert_user_code=False)
self.assertIsNotNone(
api.converted_call(f, (1, 2, 3, 4), None, options=opts))
def test_converted_call_allowlisted_method(self):
class TestClass(object):
def method(self):
return converter_testing.is_inside_generated_code()
obj = TestClass()
converter_testing.allowlist(obj.method.__func__)
self.assertFalse(
api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE))
def test_converted_call_allowlisted_method_via_owner(self):
class TestClass(object):
def method(self):
return converter_testing.is_inside_generated_code()
converter_testing.allowlist(TestClass)
obj = TestClass()
self.assertFalse(
api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE))
def test_converted_call_numpy(self):
x = api.converted_call(np.arange, (5,), None, options=DEFAULT_RECURSIVE)
self.assertAllEqual(x, list(range(5)))
def test_converted_call_tf_op_forced(self):
# TODO(mdan): Add the missing level of support to LOGICAL_EXPRESSIONS.
opts = converter.ConversionOptions(
user_requested=True, optional_features=None)
x = api.converted_call(math_ops.add, (1, 1), None, options=opts)
self.assertAllEqual(self.evaluate(x), 2)
def test_converted_call_exec_generated_code(self):
temp_mod = imp.new_module('test_module')
dynamic_code = """
def foo(x):
return x + 1
"""
exec(textwrap.dedent(dynamic_code), temp_mod.__dict__) # pylint:disable=exec-used
opts = converter.ConversionOptions(optional_features=None)
x = api.converted_call(temp_mod.foo, (1,), None, options=opts)
self.assertAllEqual(x, 2)
def test_converted_call_namedtuple(self):
x = api.converted_call(
collections.namedtuple, ('TestNamedtuple', ('a', 'b')),
None,
options=DEFAULT_RECURSIVE)
self.assertTrue(inspect_utils.isnamedtuple(x))
def test_converted_call_namedtuple_via_collections(self):
x = api.converted_call(
collections.namedtuple, ('TestNamedtuple', ('a', 'b')),
None,
options=DEFAULT_RECURSIVE)
self.assertTrue(inspect_utils.isnamedtuple(x))
def test_converted_call_namedtuple_subclass_bound_method(self):
class TestClass(collections.namedtuple('TestNamedtuple', ('a', 'b'))):
def test_method(self, x):
while math_ops.reduce_sum(x) > self.a:
x //= self.b
return x
obj = TestClass(5, 2)
x = api.converted_call(
obj.test_method, (constant_op.constant([2, 4]),),
None,
options=DEFAULT_RECURSIVE)
self.assertAllEqual(self.evaluate(x), [1, 2])
def test_converted_call_namedtuple_method(self):
class TestClass(collections.namedtuple('TestNamedtuple', ('a', 'b'))):
pass
obj = TestClass(5, 2)
# _asdict is a documented method of namedtuple.
x = api.converted_call(obj._asdict, (), None, options=DEFAULT_RECURSIVE)
self.assertDictEqual(x, {'a': 5, 'b': 2})
def test_converted_call_namedtuple_subclass_unbound_method(self):
class TestClass(collections.namedtuple('TestNamedtuple', ('a', 'b'))):
def test_method(self, x):
while math_ops.reduce_sum(x) > self.a:
x //= self.b
return x
obj = TestClass(5, 2)
x = api.converted_call(
TestClass.test_method, (obj, constant_op.constant([2, 4])),
None,
options=DEFAULT_RECURSIVE)
self.assertAllEqual(self.evaluate(x), [1, 2])
def test_converted_call_lambda(self):
l = lambda x: x == 0
x = api.converted_call(
l, (constant_op.constant(0),), None, options=DEFAULT_RECURSIVE)
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual(True, self.evaluate(x))
def test_converted_call_defun_object_method(self):
# pylint:disable=method-hidden
class TestClass(object):
def method(self):
return 1
def prepare(self):
self.method = function.defun(self.method)
# pylint:enable=method-hidden
tc = TestClass()
tc.prepare()
x = api.converted_call(tc.method, (), None, options=DEFAULT_RECURSIVE)
self.assertAllEqual(1, self.evaluate(x))
def test_converted_call_native_binding(self):
x = api.converted_call(np.power, (2, 2), None, options=DEFAULT_RECURSIVE)
self.assertAllEqual(x, 4)
def test_converted_call_native_binding_errorneous(self):
class FaultyBinding(object):
def __array__(self):
raise ValueError('fault')
bad_obj = FaultyBinding()
def fail_if_warning(*_):
self.fail('No warning should be issued')
with test.mock.patch.object(ag_logging, 'warn', fail_if_warning):
with self.assertRaisesRegex(ValueError, 'fault'):
api.converted_call(
np.power, (bad_obj, 2), None, options=DEFAULT_RECURSIVE)
def test_converted_call_through_tf_dataset(self):
def other_fn(x):
if x > 0:
return x
return -x
def f():
return dataset_ops.Dataset.range(-3, 3).map(other_fn)
# Dataset iteration only works inside math_ops.
@def_function.function
def graph_fn():
ds = api.converted_call(f, (), None, options=DEFAULT_RECURSIVE)
itr = iter(ds)
return next(itr), next(itr), next(itr)
self.assertAllEqual(self.evaluate(graph_fn()), (3, 2, 1))
def test_converted_call_no_leaks_via_closure(self):
def test_fn():
res = TestResource()
def f(y):
return res.x + y
api.converted_call(f, (1,), None, options=DEFAULT_RECURSIVE)
self.assertNoMemoryLeaks(test_fn)
def test_converted_call_no_leaks_via_inner_function_closure(self):
def test_fn():
res = TestResource()
def f(y):
def inner_f():
return res.x + y
return inner_f
api.converted_call(f, (1,), None, options=DEFAULT_RECURSIVE)()
self.assertNoMemoryLeaks(test_fn)
def test_converted_call_no_caching_on_abort(self):
def test_fn(needs_autograph):
if needs_autograph:
if constant_op.constant(True):
x = constant_op.constant(1)
else:
x = constant_op.constant(2)
else:
x = 3
return x
def call_in_disabled_context():
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
return api.converted_call(
test_fn, (False,), None, options=DEFAULT_RECURSIVE)
def call_in_default_context():
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED):
return api.converted_call(
test_fn, (True,), None, options=DEFAULT_RECURSIVE)
# Note: this is an invariant, not a test (see above).
assert call_in_disabled_context() == 3
# If api.convert placed test_fn in the unconverted cache, this second
# invocation would fail.
self.assertEqual(self.evaluate(call_in_default_context()), 1)
def test_converted_call_caching_of_allowlisted_bound_methods(self):
class TestClass(object):
def __init__(self):
self.__private = constant_op.constant(-1)
def test_method(self):
return self.__private
# TODO(mdan): Refactor to avoid this use of global state.
cache_size_before = len(conversion._ALLOWLIST_CACHE)
# First invocation with fallback on, to allow recording it into cache.
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '0'
tc = TestClass()
api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE)
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
# Entry should be added to the allowlist cache.
self.assertEqual(len(conversion._ALLOWLIST_CACHE), cache_size_before + 1)
# A second invocation should go through even with fallback off.
tc = TestClass()
api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE)
# No new entries should appear in the allowlist cache.
self.assertEqual(len(conversion._ALLOWLIST_CACHE), cache_size_before + 1)
def test_context_tracking_direct_calls(self):
@api.do_not_convert()
def unconverted_fn():
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.DISABLED)
@api.convert()
def converted_fn():
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.ENABLED)
unconverted_fn()
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.ENABLED)
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.UNSPECIFIED)
converted_fn()
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.UNSPECIFIED)
@api.call_with_unspecified_conversion_status
def unspecified_fn():
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.UNSPECIFIED)
unspecified_fn()
def test_to_graph_basic(self):
def test_fn(x, s):
while math_ops.reduce_sum(x) > s:
x //= 2
return x
compiled_fn = api.to_graph(test_fn)
with ops.Graph().as_default():
x = compiled_fn(constant_op.constant((4, 8)), 4)
self.assertAllEqual(self.evaluate(x), (1, 2))
@test_util.run_deprecated_v1
def test_to_graph_with_defaults(self):
foo = 4
def test_fn(x, s=foo):
while math_ops.reduce_sum(x) > s:
x //= 2
return x
compiled_fn = api.to_graph(test_fn)
x = compiled_fn(constant_op.constant([4, 8]))
self.assertListEqual([1, 2], self.evaluate(x).tolist())
def test_to_graph_with_globals(self):
def test_fn(x):
global global_n
global_n = x + global_n
return global_n
converted_fn = api.to_graph(test_fn)
prev_val = global_n
converted_fn(10)
self.assertGreater(global_n, prev_val)
def test_to_graph_with_kwargs_clashing_converted_call(self):
def called_fn(**kwargs):
return kwargs['f'] + kwargs['owner']
def test_fn():
# These arg names intentionally match converted_call's
return called_fn(f=1, owner=2)
compiled_fn = api.to_graph(test_fn)
self.assertEqual(compiled_fn(), 3)
def test_to_graph_with_kwargs_clashing_unconverted_call(self):
@api.do_not_convert
def called_fn(**kwargs):
return kwargs['f'] + kwargs['owner']
def test_fn():
# These arg names intentionally match _call_unconverted's
return called_fn(f=1, owner=2)
compiled_fn = api.to_graph(test_fn)
self.assertEqual(compiled_fn(), 3)
def test_to_graph_caching(self):
def test_fn(x):
if x > 0:
return x
else:
return -x
converted_functions = tuple(api.to_graph(test_fn) for _ in (-1, 0, 1))
# All outputs are from the same module. We can't use __module__ because
# that's reset when we instantiate the function (see conversion.py).
# TODO(mdan): Can and should we overwrite __module__ instead?
module_names = frozenset(f.ag_module for f in converted_functions)
self.assertEqual(len(module_names), 1)
self.assertNotIn('__main__', module_names)
self.assertEqual(len(frozenset(id(f) for f in converted_functions)), 3)
def test_to_graph_caching_different_options(self):
def called_fn():
pass
def test_fn():
return called_fn()
converted_recursive = api.to_graph(test_fn, recursive=True)
converted_non_recursive = api.to_graph(test_fn, recursive=False)
self.assertNotEqual(converted_recursive.ag_module,
converted_non_recursive.ag_module)
self.assertRegex(
tf_inspect.getsource(converted_recursive),
'FunctionScope(.*recursive=True.*)')
self.assertRegex(
tf_inspect.getsource(converted_non_recursive),
'FunctionScope(.*recursive=False.*)')
def test_to_graph_preserves_bindings(self):
y = 3
def test_fn():
return y
converted = api.to_graph(test_fn)
self.assertEqual(converted(), 3)
y = 7
self.assertEqual(converted(), 7)
def test_to_graph_source_map(self):
def test_fn(y):
return y**2
self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map'))
def test_to_graph_sets_conversion_context(self):
def g():
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.ENABLED)
return 0
# Note: the autograph=False sets the connect to Status.DISABLED. The test
# verifies that to_graph overrides that.
@def_function.function(autograph=False)
def f():
converted_g = api.to_graph(g)
converted_g()
f()
def test_to_code_basic(self):
def test_fn(x, s):
while math_ops.reduce_sum(x) > s:
x /= 2
return x
# Just check that the output is parseable Python code.
self.assertIsNotNone(parser.parse(api.to_code(test_fn)))
def test_to_code_with_wrapped_function(self):
@def_function.function
def test_fn(x, s):
while math_ops.reduce_sum(x) > s:
x /= 2
return x
with self.assertRaisesRegex(Exception, 'try passing.*python_function'):
api.to_code(test_fn)
def test_tf_convert_overrides_current_context(self):
def f(expect_converted):
self.assertEqual(
converter_testing.is_inside_generated_code(), expect_converted)
@api.do_not_convert
def test_fn(ctx, expect_converted):
return api.tf_convert(f, ctx)(expect_converted)
test_fn(
ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED), True)
test_fn(
ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED), False)
def test_tf_convert_unspecified_not_converted_by_default(self):
def f():
self.assertEqual(ag_ctx.control_status_ctx().status,
ag_ctx.Status.UNSPECIFIED)
self.assertFalse(converter_testing.is_inside_generated_code())
@def_function.function
def test_fn(ctx):
return api.tf_convert(f, ctx, convert_by_default=False)()
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED))
def test_tf_convert_allowlisted_method(self):
if six.PY2:
self.skipTest('Test bank not comptible with Python 2.')
class TestClass(object):
def method(self):
return converter_testing.is_inside_generated_code()
converter_testing.allowlist(TestClass.method)
obj = TestClass()
converted_call = api.tf_convert(
obj.method, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
_, converted_target = tf_decorator.unwrap(converted_call)
self.assertIs(converted_target.__func__, obj.method.__func__)
def test_tf_convert_tf_decorator_unwrapping_context_enabled(self):
def f():
self.assertTrue(converter_testing.is_inside_generated_code())
@functools.wraps(f)
def wrapper(*args, **kwargs):
return wrapper.__wrapped__(*args, **kwargs)
decorated_f = tf_decorator.make_decorator(f, wrapper)
def test_fn(ctx):
return api.tf_convert(decorated_f, ctx)()
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
def test_tf_convert_tf_decorator_unwrapping_context_disabled(self):
def f():
self.assertFalse(converter_testing.is_inside_generated_code())
@functools.wraps(f)
def wrapper(*args, **kwargs):
return wrapper.__wrapped__(*args, **kwargs)
decorated_f = tf_decorator.make_decorator(f, wrapper)
def test_fn(ctx):
return api.tf_convert(decorated_f, ctx)()
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED))
def test_tf_convert_tf_decorator_allowlist_method(self):
def wrap(f):
def wrapper(*args, **kwargs):
return wrapper.__wrapped__(*args, **kwargs)
return tf_decorator.make_decorator(f, wrapper)
class TestClass(object):
@wrap
def method(self):
return converter_testing.is_inside_generated_code()
converter_testing.allowlist(TestClass.method)
obj = TestClass()
# It's intended that tf_convert modifies the original method in this case.
# This is not desirable, but options are limited.
converted = api.tf_convert(
obj.method, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
self.assertTrue(converted())
self.assertTrue(obj.method())
def test_super_with_one_arg(self):
test_case_self = self
class TestBase(object):
def plus_three(self, x):
return x + 3
class TestSubclass(TestBase):
def plus_three(self, x):
test_case_self.fail('This should never be called.')
def one_arg(self, x):
test_base_unbound = super(TestSubclass)
test_base = test_base_unbound.__get__(self, TestSubclass)
return test_base.plus_three(x)
tc = api.converted_call(TestSubclass, (), None, options=DEFAULT_RECURSIVE)
self.assertEqual(5, tc.one_arg(2))
def test_super_with_two_args(self):
test_case_self = self
class TestBase(object):
def plus_three(self, x):
return x + 3
class TestSubclass(TestBase):
def plus_three(self, x):
test_case_self.fail('This should never be called.')
def two_args(self, x):
return super(TestSubclass, self).plus_three(x)
tc = api.converted_call(TestSubclass, (), None, options=DEFAULT_RECURSIVE)
self.assertEqual(5, tc.two_args(2))
if __name__ == '__main__':
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
test.main()