Clean up the test to remove duplication and run with within the real framework.

PiperOrigin-RevId: 318854218
Change-Id: I45125a5a028a6d5fd7ae8cd9bed698d31d04196b
This commit is contained in:
Dan Moldovan 2020-06-29 11:23:20 -07:00 committed by TensorFlower Gardener
parent f072535ba3
commit ce054f48e6
19 changed files with 826 additions and 982 deletions

View File

@ -77,12 +77,6 @@ py_test(
srcs = ["call_trees_test.py"],
python_version = "PY3",
srcs_version = "PY3",
tags = [
"no_oss_py2",
"no_pip",
"no_windows",
"nopip",
],
deps = [
":converters",
"//tensorflow/python:client_testlib",
@ -119,12 +113,6 @@ py_test(
srcs = ["control_flow_test.py"],
python_version = "PY3",
srcs_version = "PY3",
tags = [
"no_oss_py2",
"no_pip",
"no_windows",
"nopip",
],
deps = [
":converters",
"//tensorflow/python:client_testlib",

View File

@ -24,7 +24,6 @@ from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
@ -32,17 +31,15 @@ class AssertsTest(converter_testing.TestCase):
def test_basic(self):
def test_fn(a):
def f(a):
assert a, 'testmsg'
return a
with ops.Graph().as_default():
with self.converted(
test_fn, (functions, asserts, return_statements), {}) as result:
op = result.test_fn(constant_op.constant(False))
tr = self.transform(f, (functions, asserts, return_statements))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'):
self.evaluate(op)
op = tr(constant_op.constant(False))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'):
self.evaluate(op)
if __name__ == '__main__':

View File

@ -21,20 +21,18 @@ from __future__ import print_function
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
class BreakCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
with self.converted(test_fn, break_statements, {},
(constant_op.constant,)) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def assertTransformedEquivalent(self, f, *inputs):
tr = self.transform(f, break_statements)
self.assertEqual(f(*inputs), tr(*inputs))
def test_while_loop(self):
def test_fn(x):
def f(x):
v = []
while x > 0:
x -= 1
@ -43,28 +41,29 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(f, 4)
def test_while_loop_preserves_directives(self):
def test_fn(x):
def f(x):
while x > 0:
x -= 1
if x % 2 == 0:
break
node, ctx = self.prepare(test_fn, {})
_, node, ctx = self.transform(f, (), include_ast=True)
fake_annotation = object()
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
node = break_statements.transform(node, ctx)
self.assertIs(
anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
def test_for_loop(self):
def test_fn(a):
def f(a):
v = []
for x in a:
x -= 1
@ -73,20 +72,18 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
with self.converted(test_fn, break_statements, {},
(constant_op.constant,)) as result:
# The break is incompletely canonicalized. The loop will not interrupt,
# but the section following the break will be skipped.
self.assertEqual([3], result.test_fn([5, 4]))
tr = self.transform(f, break_statements)
self.assertEqual([3], tr([5, 4]))
def test_for_loop_preserves_directives(self):
def test_fn(a):
def f(a):
for x in a:
if x % 2 == 0:
break
node, ctx = self.prepare(test_fn, {})
_, node, ctx = self.transform(f, (), include_ast=True)
fake_annotation = object()
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
node = break_statements.transform(node, ctx)
@ -95,7 +92,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
def test_nested(self):
def test_fn(x):
def f(x):
v = []
u = []
w = []
@ -110,13 +107,13 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u, w
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 11)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(f, 11)
def test_nested_loops(self):
def test_fn(x):
def f(x):
v = []
u = []
while x > 0:
@ -132,14 +129,14 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 5)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(f, 5)
def test_loop_orelse(self):
def test_fn(x):
def f(x):
v = []
u = []
while x > 0:
@ -153,12 +150,12 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, 3)
def test_multiple_correlated_breaks_with_side_effects(self):
def test_fn(cond1):
def f(cond1):
lst = []
while True:
if cond1:
@ -169,8 +166,9 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
break
return lst
self.assertTransformedEquivalent(test_fn, True)
self.assertTransformedEquivalent(test_fn, False)
self.assertTransformedEquivalent(f, True)
self.assertTransformedEquivalent(f, False)
if __name__ == '__main__':
test.main()

View File

@ -27,169 +27,193 @@ from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test
class MockConvertedCall(object):
def __init__(self):
self.calls = []
def __call__(self, f, args, kwargs, caller_fn_scope=None, options=None):
del caller_fn_scope, options
self.calls.append((args, kwargs))
kwargs = kwargs or {}
return f(*args, **kwargs)
class CallTreesTest(converter_testing.TestCase):
def _transform_with_mock(self, f):
mock = MockConvertedCall()
tr = self.transform(
f, (functions, call_trees),
ag_overrides={'converted_call': mock})
return tr, mock
def test_function_no_args(self):
def test_fn(f):
def f(f):
return f() + 20
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(lambda: 1), 21)
self.assertListEqual(self.dynamic_calls, [((), None)])
tr, mock = self._transform_with_mock(f)
self.assertEqual(tr(lambda: 1), 21)
self.assertListEqual(mock.calls, [((), None)])
def test_function_with_expression_in_argument(self):
def test_fn(f, g):
def f(f, g):
return f(g() + 20) + 4000
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(lambda x: x + 300, lambda: 1), 4321)
self.assertListEqual(self.dynamic_calls, [
((), None),
((21,), None),
])
tr, mock = self._transform_with_mock(f)
self.assertEqual(tr(lambda x: x + 300, lambda: 1), 4321)
self.assertListEqual(mock.calls, [
((), None),
((21,), None),
])
def test_function_with_call_in_argument(self):
def test_fn(f, g):
def f(f, g):
return f(g()) + 300
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(lambda x: x + 20, lambda: 1), 321)
self.assertListEqual(self.dynamic_calls, [
((), None),
((1,), None),
])
tr, mock = self._transform_with_mock(f)
self.assertEqual(tr(lambda x: x + 20, lambda: 1), 321)
self.assertListEqual(mock.calls, [
((), None),
((1,), None),
])
def test_function_chaining(self):
def get_one():
return 1
def test_fn():
def f():
return get_one().__add__(20)
with self.converted(test_fn, (functions, call_trees),
{'get_one': get_one}, ()) as result:
tr, mock = self._transform_with_mock(f)
self.assertEqual(result.test_fn(), 21)
self.assertListEqual(self.dynamic_calls, [
((), None),
((20,), None),
])
self.assertEqual(tr(), 21)
self.assertListEqual(mock.calls, [
((), None),
((20,), None),
])
def test_function_with_single_arg(self):
def test_fn(f, a):
def f(f, a):
return f(a) + 20
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(lambda a: a, 1), 21)
self.assertListEqual(self.dynamic_calls, [((1,), None)])
tr, mock = self._transform_with_mock(f)
self.assertEqual(tr(lambda a: a, 1), 21)
self.assertListEqual(mock.calls, [((1,), None)])
def test_function_with_args_only(self):
def test_fn(f, a, b):
def f(f, a, b):
return f(a, b) + 300
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(lambda a, b: a + b, 1, 20), 321)
self.assertListEqual(self.dynamic_calls, [((1, 20), None)])
tr, mock = self._transform_with_mock(f)
self.assertEqual(tr(lambda a, b: a + b, 1, 20), 321)
self.assertListEqual(mock.calls, [((1, 20), None)])
def test_function_with_kwarg(self):
def test_fn(f, a, b):
def f(f, a, b):
return f(a, c=b) + 300
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(lambda a, c: a + c, 1, 20), 321)
self.assertListEqual(self.dynamic_calls, [((1,), {'c': 20})])
tr, mock = self._transform_with_mock(f)
self.assertEqual(tr(lambda a, c: a + c, 1, 20), 321)
self.assertListEqual(mock.calls, [((1,), {'c': 20})])
def test_function_with_kwargs_starargs(self):
def test_fn(f, a, *args, **kwargs):
def f(f, a, *args, **kwargs):
return f(a, *args, **kwargs) + 5
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(
result.test_fn(lambda *args, **kwargs: 7, 1, *[2, 3], **{
'b': 4,
'c': 5
}), 12)
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})])
tr, mock = self._transform_with_mock(f)
self.assertEqual(
tr(lambda *args, **kwargs: 7, 1, *[2, 3], **{
'b': 4,
'c': 5
}), 12)
self.assertListEqual(mock.calls, [((1, 2, 3), {'b': 4, 'c': 5})])
def test_function_with_starargs_only(self):
def f(*args):
def g(*args):
return sum(args)
def test_fn():
def f():
args = [1, 20, 300]
return f(*args) + 4000
return g(*args) + 4000
with self.converted(test_fn, (functions, call_trees),
{'f': f}) as result:
self.assertEqual(result.test_fn(), 4321)
self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
tr, mock = self._transform_with_mock(f)
# TODO(b/142586827): Enable this test.
# def test_function_with_starargs_mixed(self):
#
# def f(a, b, c, d):
# return a * 1000 + b * 100 + c * 10 + d
#
# def test_fn():
# args1 = (1,)
# args2 = [3]
# return f(*args1, 2, *args2, 4)
#
# with self.converted(test_fn, (functions, call_trees),
# {'f': f}) as result:
# self.assertEqual(result.test_fn(), 1234)
# self.assertListEqual(self.dynamic_calls, [((1, 2, 3, 4), None)])
self.assertEqual(tr(), 4321)
self.assertListEqual(mock.calls, [((1, 20, 300), None)])
def test_function_with_starargs_mixed(self):
def g(a, b, c, d):
return a * 1000 + b * 100 + c * 10 + d
def f():
args1 = (1,)
args2 = [3]
return g(*args1, 2, *args2, 4)
tr, mock = self._transform_with_mock(f)
self.assertEqual(tr(), 1234)
self.assertListEqual(mock.calls, [((1, 2, 3, 4), None)])
def test_function_with_kwargs_keywords(self):
def test_fn(f, a, b, **kwargs):
def f(f, a, b, **kwargs):
return f(a, b=b, **kwargs) + 5
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(
result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
tr, mock = self._transform_with_mock(f)
# TODO(b/142586827): Enable this test.
# def test_function_with_multiple_kwargs(self):
#
# def test_fn(f, a, b, c, kwargs1, kwargs2):
# return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5
#
# with self.converted(test_fn, (functions, call_trees), {}) as result:
# self.assertEqual(
# result.test_fn(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4},
# {'e': 5}), 12)
# self.assertListEqual(self.dynamic_calls, [((1,), {
# 'b': 2,
# 'c': 3,
# 'd': 4,
# 'e': 5
# })])
self.assertEqual(
tr(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
self.assertListEqual(mock.calls, [((1,), {'b': 2, 'c': 3})])
def test_function_with_multiple_kwargs(self):
def f(f, a, b, c, kwargs1, kwargs2):
return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5
tr, mock = self._transform_with_mock(f)
self.assertEqual(
tr(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4}, {'e': 5}), 12)
self.assertListEqual(mock.calls, [((1,), {
'b': 2,
'c': 3,
'd': 4,
'e': 5
})])
def test_function_with_call_in_lambda_argument(self):
def f(l, a):
def h(l, a):
return l(a) + 4000
def g(a, *args):
return a + sum(args)
def test_fn(f, g, a, *args):
return f(lambda x: g(x, *args), a)
def f(h, g, a, *args):
return h(lambda x: g(x, *args), a)
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(f, g, 1, *(20, 300)), 4321)
tr, _ = self._transform_with_mock(f)
self.assertEqual(tr(h, g, 1, *(20, 300)), 4321)
def test_debugger_set_trace(self):
@ -198,13 +222,13 @@ class CallTreesTest(converter_testing.TestCase):
pdb = imp.new_module('fake_pdb')
pdb.set_trace = lambda: tracking_list.append(1)
def test_fn():
def f():
return pdb.set_trace()
with self.converted(test_fn, (functions, call_trees),
{'pdb': pdb}) as result:
result.test_fn()
self.assertListEqual(tracking_list, [1])
tr, _ = self._transform_with_mock(f)
tr()
self.assertListEqual(tracking_list, [1])
def test_class_method(self):
@ -217,10 +241,10 @@ class CallTreesTest(converter_testing.TestCase):
return self.other_method(a) + 300
tc = TestClass()
with self.converted(TestClass.test_method, (functions, call_trees),
{}) as result:
self.assertEqual(321, result.test_method(tc, 1))
self.assertListEqual(self.dynamic_calls, [((1,), None)])
tr, mock = self._transform_with_mock(TestClass.test_method)
self.assertEqual(321, tr(tc, 1))
self.assertListEqual(mock.calls, [((1,), None)])
def test_object_method(self):
@ -233,10 +257,10 @@ class CallTreesTest(converter_testing.TestCase):
return self.other_method(a) + 300
tc = TestClass()
with self.converted(tc.test_method, (functions, call_trees),
{}) as result:
self.assertEqual(321, result.test_method(tc, 1))
self.assertListEqual(self.dynamic_calls, [((1,), None)])
tr, mock = self._transform_with_mock(tc.test_method)
self.assertEqual(321, tr(tc, 1))
self.assertListEqual(mock.calls, [((1,), None)])
if __name__ == '__main__':

View File

@ -25,28 +25,27 @@ from tensorflow.python.platform import test
class ConditionalExpressionsTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
ns = {}
with self.converted(test_fn, conditional_expressions, ns) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def assertTransformedEquivalent(self, f, *inputs):
tr = self.transform(f, conditional_expressions)
self.assertEqual(f(*inputs), tr(*inputs))
def test_basic(self):
def test_fn(x):
def f(x):
return 1 if x else 0
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 3)
def test_nested_orelse(self):
def test_fn(x):
def f(x):
y = x * x if x > 0 else x if x else 1
return y
self.assertTransformedEquivalent(test_fn, -2)
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(f, -2)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 2)
if __name__ == '__main__':

View File

@ -20,21 +20,19 @@ from __future__ import print_function
from tensorflow.python.autograph.converters import continue_statements
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
class ContinueCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
with self.converted(test_fn, continue_statements, {'ops': ops},
(constant_op.constant,)) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def assertTransformedEquivalent(self, f, *inputs):
tr = self.transform(f, continue_statements)
self.assertEqual(f(*inputs), tr(*inputs))
def test_basic(self):
def test_fn(x):
def f(x):
v = []
while x > 0:
x -= 1
@ -43,14 +41,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(f, 4)
def test_multiple_continues(self):
def test_fn(x):
def f(x):
v = []
while x > 0:
x -= 1
@ -61,14 +59,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(f, 4)
def test_multiple_continues_in_nested_scope(self):
def test_fn(a):
def f(a):
v = []
for x in a:
x -= 1
@ -81,14 +79,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, [])
self.assertTransformedEquivalent(test_fn, [1])
self.assertTransformedEquivalent(test_fn, [2])
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(f, [1])
self.assertTransformedEquivalent(f, [2])
self.assertTransformedEquivalent(f, [1, 2, 3])
def test_for_loop(self):
def test_fn(a):
def f(a):
v = []
for x in a:
x -= 1
@ -97,14 +95,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, [])
self.assertTransformedEquivalent(test_fn, [1])
self.assertTransformedEquivalent(test_fn, [2])
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(f, [1])
self.assertTransformedEquivalent(f, [2])
self.assertTransformedEquivalent(f, [1, 2, 3])
def test_nested_with(self):
def test_fn(x):
def f(x):
v = []
while x > 0:
x -= 1
@ -114,14 +112,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(f, 4)
def test_nested_multiple_withs(self):
def test_fn(x):
def f(x):
v = []
while x > 0:
x -= 1
@ -133,14 +131,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(f, 4)
def test_nested_multiple_withs_and_statements(self):
def test_fn(x):
def f(x):
v = []
while x > 0:
x -= 1
@ -154,14 +152,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(f, 4)
def test_nested_multiple_withs_and_nested_withs(self):
def test_fn(x):
def f(x):
v = []
while x > 0:
x -= 1
@ -176,14 +174,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(f, 4)
def test_nested(self):
def test_fn(x):
def f(x):
v = []
u = []
w = []
@ -198,14 +196,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return v, u, w
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(f, 4)
def test_multiple_guarded_continues_with_side_effects(self):
def test_fn(x):
def f(x):
def track(u, x):
u.append(x)
return x
@ -221,8 +219,8 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x)
return u, v
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(f, 2)
if __name__ == '__main__':

View File

@ -23,6 +23,8 @@ import collections
import numpy as np
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.converters import continue_statements
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.eager import def_function
@ -34,7 +36,8 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import test
from tensorflow.python.util import nest
# TODO(mdan): These tests are not isolated - they also test the operators.
for_unaffected_global = None
class ControlFlowTestBase(converter_testing.TestCase):
@ -45,22 +48,19 @@ class ControlFlowTestBase(converter_testing.TestCase):
actual)
self.assertAllEqual(values, expected)
def assertTransformedResult(self, test_fn, inputs, expected, symbols=None):
def assertTransformedResult(self, f, inputs, expected):
if not isinstance(inputs, tuple):
inputs = (inputs,)
if not symbols:
symbols = {}
with self.converted(test_fn, control_flow, symbols,
(constant_op.constant,)) as result:
returns = result.test_fn(*inputs)
self.assertValuesEqual(returns, expected)
tr = self.transform(f, control_flow)
returns = tr(*inputs)
self.assertValuesEqual(returns, expected)
class NestedControlFlowTest(ControlFlowTestBase):
def test_basic(self):
def test_fn(n):
def f(n):
i = 0
j = 0
s = 0
@ -73,7 +73,7 @@ class NestedControlFlowTest(ControlFlowTestBase):
j = 0
return s, i, j, n
self.assertTransformedResult(test_fn, constant_op.constant(5),
self.assertTransformedResult(f, constant_op.constant(5),
(25, 5, 0, 5))
def test_composite_state_complex(self):
@ -88,7 +88,7 @@ class NestedControlFlowTest(ControlFlowTestBase):
def __init__(self, y):
self.y = y
def test_fn(n):
def f(n):
tc = TestClassX(TestClassY({'z': TestClassX(n)}))
if n > 0:
while n > 0:
@ -97,19 +97,17 @@ class NestedControlFlowTest(ControlFlowTestBase):
n -= 1
return n, tc
with self.converted(test_fn, control_flow, {
'TestClassX': TestClassX,
'TestClassY': TestClassY,
}) as result:
n, tc = result.test_fn(constant_op.constant(5))
self.assertValuesEqual((n, tc.x.y['z'].x), (0, 6))
tr = self.transform(f, control_flow)
n, tc = tr(constant_op.constant(5))
self.assertValuesEqual((n, tc.x.y['z'].x), (0, 6))
class WhileStatementTest(ControlFlowTestBase):
def test_basic(self):
def test_fn(n):
def f(n):
i = 0
s = 0
while i < n:
@ -117,16 +115,16 @@ class WhileStatementTest(ControlFlowTestBase):
i += 1
return s, i, n
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5))
self.assertTransformedResult(f, constant_op.constant(5), (10, 5, 5))
def test_single_output(self):
def test_fn(n):
def f(n):
while n > 0:
n -= 1
return n
self.assertTransformedResult(test_fn, constant_op.constant(5), 0)
self.assertTransformedResult(f, constant_op.constant(5), 0)
def test_composite_state_attr(self):
@ -135,19 +133,18 @@ class WhileStatementTest(ControlFlowTestBase):
def __init__(self):
self.x = constant_op.constant(3)
def test_fn(n):
def f(n):
tc = TestClass()
while n > 0:
tc.x += 1
n -= 1
return n
self.assertTransformedResult(
test_fn, constant_op.constant(5), 0, symbols={'TestClass': TestClass})
self.assertTransformedResult(f, constant_op.constant(5), 0)
def test_composite_state_slice(self):
def test_fn(n):
def f(n):
d = {'a': n}
k = 'a'
while n > 0:
@ -155,25 +152,25 @@ class WhileStatementTest(ControlFlowTestBase):
n -= 1
return d[k], n
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 0))
self.assertTransformedResult(f, constant_op.constant(5), (10, 0))
def test_composite_state_literal_slice(self):
def test_fn(n):
def f(n):
d = {'a': n}
while n > 0:
d['a'] += 1
n -= 1
return d['a'], n
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 0))
self.assertTransformedResult(f, constant_op.constant(5), (10, 0))
def test_composite_state_attr_initialized_in_loop(self):
class TestClass(object):
pass
def test_fn(n, x):
def f(n, x):
tc = TestClass()
while n < 5:
if n == 0:
@ -183,19 +180,15 @@ class WhileStatementTest(ControlFlowTestBase):
n += 1
return tc.subattr
self.assertTransformedResult(
test_fn, (0, constant_op.constant(10)),
14,
symbols={'TestClass': TestClass})
with self.converted(
test_fn, control_flow, {'TestClass': TestClass}) as result:
with self.assertRaisesRegex(
ValueError, "'tc.subattr' must be defined before the loop"):
result.test_fn(constant_op.constant(0), 0)
self.assertTransformedResult(f, (0, constant_op.constant(10)), 14)
tr = self.transform(f, control_flow)
with self.assertRaisesRegex(
ValueError, "'tc.subattr' must be defined before the loop"):
tr(constant_op.constant(0), 0)
def test_composite_state_slice_initialized_in_loop(self):
def test_fn(n, x):
def f(n, x):
d = {}
k = 'subkey'
while n < 5:
@ -206,16 +199,16 @@ class WhileStatementTest(ControlFlowTestBase):
n += 1
return d
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
self.assertTransformedResult(f, (0, constant_op.constant(10)),
{'subkey': 14})
with self.converted(test_fn, control_flow, {}) as result:
with self.assertRaisesRegex(
ValueError, r"'d\[k\]' must be defined before the loop"):
result.test_fn(constant_op.constant(0), 0)
tr = self.transform(f, control_flow)
with self.assertRaisesRegex(
ValueError, r"'d\[k\]' must be defined before the loop"):
tr(constant_op.constant(0), 0)
def test_composite_state_literal_slice_initialized_in_loop(self):
def test_fn(n, x):
def f(n, x):
d = {}
while n < 5:
if n == 0:
@ -225,16 +218,16 @@ class WhileStatementTest(ControlFlowTestBase):
n += 1
return d
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
self.assertTransformedResult(f, (0, constant_op.constant(10)),
{'subkey': 14})
with self.converted(test_fn, control_flow, {}) as result:
with self.assertRaisesRegex(
ValueError, r"'d\['subkey'\]' must be defined before the loop"):
result.test_fn(constant_op.constant(0), 0)
tr = self.transform(f, control_flow)
with self.assertRaisesRegex(
ValueError, r"'d\['subkey'\]' must be defined before the loop"):
tr(constant_op.constant(0), 0)
def test_composite_state_slice_aliased_to_local(self):
def test_fn(n, x):
def f(n, x):
d = {}
while n < 5:
k = 'subkey'
@ -242,15 +235,15 @@ class WhileStatementTest(ControlFlowTestBase):
n += 1
return d
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
self.assertTransformedResult(f, (0, constant_op.constant(10)),
{'subkey': 11})
with self.converted(test_fn, control_flow, {}) as result:
# TODO(b/136999953): Better error message.
# Note that this error happens at execution time.
with self.assertRaises(errors.InaccessibleTensorError):
graph_fn = def_function.function(result.test_fn, autograph=False)
self.evaluate(
graph_fn(constant_op.constant(0), constant_op.constant(5)))
tr = self.transform(f, control_flow)
# TODO(b/136999953): Better error message.
# Note that this error happens at execution time.
with self.assertRaises(errors.InaccessibleTensorError):
graph_fn = def_function.function(tr, autograph=False)
self.evaluate(
graph_fn(constant_op.constant(0), constant_op.constant(5)))
def test_local_composite_attr(self):
@ -259,19 +252,18 @@ class WhileStatementTest(ControlFlowTestBase):
def __init__(self):
self.x = constant_op.constant(3)
def test_fn(n):
def f(n):
while n > 0:
tc = TestClass()
tc.x = tc.x
n -= 1
return n
self.assertTransformedResult(
test_fn, constant_op.constant(5), 0, symbols={'TestClass': TestClass})
self.assertTransformedResult(f, constant_op.constant(5), 0)
def test_local_composite_slice(self):
def test_fn(n):
def f(n):
while n > 0:
d = {'x': n}
k = 'x'
@ -279,26 +271,26 @@ class WhileStatementTest(ControlFlowTestBase):
n -= 1
return n
self.assertTransformedResult(test_fn, constant_op.constant(5), 0, {})
self.assertTransformedResult(f, constant_op.constant(5), 0)
def test_local_composite_literal_slice(self):
def test_fn(n):
def f(n):
while n > 0:
d = {'x': n}
d['x'] = d['x']
n -= 1
return n
self.assertTransformedResult(test_fn, constant_op.constant(5), 0, {})
self.assertTransformedResult(f, constant_op.constant(5), 0)
def test_non_tensor_state(self):
# This class is ok to be in a tf.while_loop's state.
# This class is ok to be in a tf.while's state.
class TestClass(collections.namedtuple('TestClass', ('x'))):
pass
def test_fn(n):
def f(n):
tc = TestClass([constant_op.constant(0)])
while n > 0:
tc = TestClass([constant_op.constant(3)])
@ -306,9 +298,7 @@ class WhileStatementTest(ControlFlowTestBase):
n -= 1
return tc.x[0]
ns = {'TestClass': TestClass, 'constant_op': constant_op}
self.assertTransformedResult(
test_fn, constant_op.constant(5), 4, symbols=ns)
self.assertTransformedResult(f, constant_op.constant(5), 4)
def test_non_tensor_state_illegal_type(self):
@ -317,20 +307,20 @@ class WhileStatementTest(ControlFlowTestBase):
def __init__(self):
self.x = [constant_op.constant(3)]
def test_fn(n):
def f(n):
while n > 0:
tc = TestClass()
tc.x[0] = tc.x[0] + 1
n -= 1
return tc.x[0]
with self.converted(
test_fn, control_flow, {'TestClass': TestClass}) as result:
# The tested function would require `tc` to become part of the while loop
# state, but TensorFlow doesn't support classes at the moment.
with self.assertRaisesRegexp(
ValueError, 'tc.*must be defined before the loop'):
result.test_fn(constant_op.constant(5))
tr = self.transform(f, control_flow)
# The tested function would require `tc` to become part of the while loop
# state, but TensorFlow doesn't support classes at the moment.
with self.assertRaisesRegex(
ValueError, 'tc.*must be defined before the loop'):
tr(constant_op.constant(5))
def test_dispatches_by_cond_only(self):
@ -343,27 +333,27 @@ class WhileStatementTest(ControlFlowTestBase):
def __add__(self, other):
return TensorIncompatibleNumeric(self.val + other)
def test_fn(n, s):
def f(n, s):
while n > 0:
n -= 1
s += n
return s
self.assertTransformedResult(test_fn, (constant_op.constant(5), 0), 10)
with self.converted(test_fn, control_flow, {}) as result:
# n alone controls the staging. When the loop is not staged, Python
# knows how to add the two objects. But when staged, tf.while_loop will
# not know how to deal with the TensorIncompatibleNumeric object.
self.assertEqual(result.test_fn(5, TensorIncompatibleNumeric(0)).val, 10)
with self.assertRaises(TypeError):
result.test_fn(constant_op.constant(5), TensorIncompatibleNumeric(0))
self.assertTransformedResult(f, (constant_op.constant(5), 0), 10)
tr = self.transform(f, control_flow)
# n alone controls the staging. When the loop is not staged, Python
# knows how to add the two objects. But when staged, tf.while will
# not know how to deal with the TensorIncompatibleNumeric object.
self.assertEqual(tr(5, TensorIncompatibleNumeric(0)).val, 10)
with self.assertRaises(TypeError):
tr(constant_op.constant(5), TensorIncompatibleNumeric(0))
class IfStatementTest(ControlFlowTestBase):
def test_basic(self):
def test_fn(n):
def f(n):
a = 0
b = 0
if n > 0:
@ -372,20 +362,20 @@ class IfStatementTest(ControlFlowTestBase):
b = 2 * n
return a, b
self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0))
self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2))
self.assertTransformedResult(f, constant_op.constant(1), (-1, 0))
self.assertTransformedResult(f, constant_op.constant(-1), (0, -2))
def test_sparse_tensor(self):
def test_fn(cond, a):
def f(cond, a):
if cond:
a = -a
return a
st = sparse_tensor.SparseTensor(
indices=((0,),), values=(0,), dense_shape=(1,))
self.assertTransformedResult(test_fn, (st, constant_op.constant(1)), -1)
self.assertTransformedResult(test_fn, (None, constant_op.constant(1)), 1)
self.assertTransformedResult(f, (st, constant_op.constant(1)), -1)
self.assertTransformedResult(f, (None, constant_op.constant(1)), 1)
def test_complex_outputs(self):
@ -395,7 +385,7 @@ class IfStatementTest(ControlFlowTestBase):
self.a = a
self.b = b
def test_fn(n, obj):
def f(n, obj):
obj.a = 0
obj.b = 0
if n > 0:
@ -404,94 +394,94 @@ class IfStatementTest(ControlFlowTestBase):
obj.b = 2 * n
return obj
with self.converted(test_fn, control_flow, {}) as result:
res_obj = result.test_fn(constant_op.constant(1), TestClass(0, 0))
self.assertValuesEqual((res_obj.a, res_obj.b), (-1, 0))
res_obj = result.test_fn(constant_op.constant(-1), TestClass(0, 0))
self.assertValuesEqual((res_obj.a, res_obj.b), (0, -2))
tr = self.transform(f, control_flow)
res_obj = tr(constant_op.constant(1), TestClass(0, 0))
self.assertValuesEqual((res_obj.a, res_obj.b), (-1, 0))
res_obj = tr(constant_op.constant(-1), TestClass(0, 0))
self.assertValuesEqual((res_obj.a, res_obj.b), (0, -2))
def test_single_output(self):
def test_fn(n):
def f(n):
if n > 0:
n = -n
return n
self.assertTransformedResult(test_fn, constant_op.constant(1), -1)
self.assertTransformedResult(f, constant_op.constant(1), -1)
def test_unbalanced(self):
def test_fn(n):
def f(n):
if n > 0:
n = 3
return n
self.assertTransformedResult(test_fn, constant_op.constant(2), 3)
self.assertTransformedResult(test_fn, constant_op.constant(-3), -3)
self.assertTransformedResult(f, constant_op.constant(2), 3)
self.assertTransformedResult(f, constant_op.constant(-3), -3)
def test_unbalanced_raising(self):
def test_fn(n):
def f(n):
if n > 0:
n = n + 1
raise ValueError()
return n
self.assertTransformedResult(test_fn, -3, -3)
self.assertTransformedResult(f, -3, -3)
with self.converted(test_fn, control_flow, {}) as result:
with self.assertRaises(ValueError):
result.test_fn(1)
tr = self.transform(f, control_flow)
with self.assertRaises(ValueError):
tr(1)
def test_local_var(self):
def test_fn(n):
def f(n):
if n > 0:
b = 4
n = b + 1
return n
self.assertTransformedResult(test_fn, constant_op.constant(1), 5)
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
self.assertTransformedResult(f, constant_op.constant(1), 5)
self.assertTransformedResult(f, constant_op.constant(-1), -1)
def test_local_remains_local(self):
def test_fn(n):
def f(n):
if n > 0:
b = 4
n = b + 1
return n
self.assertTransformedResult(test_fn, constant_op.constant(1), 5)
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
self.assertTransformedResult(f, constant_op.constant(1), 5)
self.assertTransformedResult(f, constant_op.constant(-1), -1)
def test_no_outputs(self):
def test_fn(n):
def f(n):
if n > 0:
b = 4 # pylint:disable=unused-variable
return n
# Without side effect guards, the if statement will stage a cond,
# but that will be pruned at execution.
self.assertTransformedResult(test_fn, constant_op.constant(1), 1)
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
self.assertTransformedResult(f, constant_op.constant(1), 1)
self.assertTransformedResult(f, constant_op.constant(-1), -1)
def test_created_outputs(self):
def test_fn(i):
def f(i):
if i == 0:
result = i - 1
else:
result = i + 1
return result
self.assertTransformedResult(test_fn, 0, -1)
self.assertTransformedResult(test_fn, 1, 2)
self.assertTransformedResult(f, 0, -1)
self.assertTransformedResult(f, 1, 2)
def test_created_loop_local_outputs(self):
def test_fn(n, x):
def f(n, x):
for i in n:
if i == 0:
result = i - 1
@ -501,11 +491,11 @@ class IfStatementTest(ControlFlowTestBase):
x += 1
return x
self.assertTransformedResult(test_fn, (range(5), 10), 14)
self.assertTransformedResult(f, (range(5), 10), 14)
def test_created_loop_variable(self):
def test_fn(n, x):
def f(n, x):
for i in n:
if i == 0:
result = i - 1
@ -514,22 +504,26 @@ class IfStatementTest(ControlFlowTestBase):
x += 1
return x
self.assertTransformedResult(test_fn, (range(5), 10), 14)
self.assertTransformedResult(f, (range(5), 10), 14)
def test_unaffected_global(self):
def test_fn(i):
global g # pylint:disable=global-variable-undefined
if i == 0:
g = i - 1
return g
global for_unaffected_global
for_unaffected_global = 3
self.assertTransformedResult(test_fn, 1, 3, symbols={'g': 3})
self.assertTransformedResult(test_fn, 0, -1, symbols={'g': 3})
def f(i):
global for_unaffected_global
if i == 0:
for_unaffected_global = i - 1
return for_unaffected_global
self.assertTransformedResult(f, 1, 3)
self.assertTransformedResult(f, 0, -1)
self.assertEqual(for_unaffected_global, -1)
def test_unaffected_nonlocal(self):
def test_fn(i):
def f(i):
def inner_fn():
nonlocal n
if i == 0:
@ -539,12 +533,12 @@ class IfStatementTest(ControlFlowTestBase):
inner_fn()
return n
self.assertTransformedResult(test_fn, 1, 3)
self.assertTransformedResult(test_fn, 0, -1)
self.assertTransformedResult(f, 1, 3)
self.assertTransformedResult(f, 0, -1)
def test_output_defined_in_prior_except(self):
def test_fn(i):
def f(i):
try:
raise ValueError()
except ValueError:
@ -553,8 +547,8 @@ class IfStatementTest(ControlFlowTestBase):
x = i - 1
return x
self.assertTransformedResult(test_fn, 1, 1)
self.assertTransformedResult(test_fn, 0, -1)
self.assertTransformedResult(f, 1, 1)
self.assertTransformedResult(f, 0, -1)
def test_unbalanced_multiple_composites(self):
@ -564,7 +558,7 @@ class IfStatementTest(ControlFlowTestBase):
self.b = 2
self.c = 3
def test_fn(x, condition):
def f(x, condition):
z = 5
if condition:
@ -574,9 +568,9 @@ class IfStatementTest(ControlFlowTestBase):
return x.b, x.c, z
self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(True)),
self.assertTransformedResult(f, (Foo(), constant_op.constant(True)),
(7, 11, 13))
self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(False)),
self.assertTransformedResult(f, (Foo(), constant_op.constant(False)),
(2, 3, 5))
def test_unbalanced_composite(self):
@ -586,7 +580,7 @@ class IfStatementTest(ControlFlowTestBase):
def __init__(self):
self.b = 2
def test_fn(x, condition):
def f(x, condition):
z = 5
if condition:
@ -595,9 +589,9 @@ class IfStatementTest(ControlFlowTestBase):
return x.b, z
self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(True)),
self.assertTransformedResult(f, (Foo(), constant_op.constant(True)),
(7, 13))
self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(False)),
self.assertTransformedResult(f, (Foo(), constant_op.constant(False)),
(2, 5))
@ -605,7 +599,7 @@ class ForStatementTest(ControlFlowTestBase):
def test_basic(self):
def test_fn(l):
def f(l):
s1 = 0
s2 = 0
for e in l:
@ -613,21 +607,21 @@ class ForStatementTest(ControlFlowTestBase):
s2 += e * e
return s1, s2
self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), (4, 10))
self.assertTransformedResult(f, constant_op.constant([1, 3]), (4, 10))
empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
self.assertTransformedResult(test_fn, empty_vector, (0, 0))
self.assertTransformedResult(f, empty_vector, (0, 0))
def test_single_output(self):
def test_fn(l):
def f(l):
s = 0
for e in l:
s += e
return s
self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), 4)
self.assertTransformedResult(f, constant_op.constant([1, 3]), 4)
empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
self.assertTransformedResult(test_fn, empty_vector, 0)
self.assertTransformedResult(f, empty_vector, 0)
def test_iterated_expression(self):
@ -637,26 +631,23 @@ class ForStatementTest(ControlFlowTestBase):
eval_count[0] += 1
return x
def test_fn(n):
def f(n):
s = 0
for e in count_evals(range(n)):
s += e
return s
ns = {'count_evals': count_evals}
node, ctx = self.prepare(test_fn, ns)
node = control_flow.transform(node, ctx)
tr = self.transform(f, control_flow)
with self.compiled(node, ns) as result:
self.assertEqual(result.test_fn(5), 10)
self.assertEqual(eval_count[0], 1)
self.assertEqual(tr(5), 10)
self.assertEqual(eval_count[0], 1)
def test_composite_state_initialized_in_loop(self):
class TestClass(object):
pass
def test_fn(n, x):
def f(n, x):
tc = TestClass()
for i in n:
if i == 0:
@ -665,37 +656,97 @@ class ForStatementTest(ControlFlowTestBase):
tc.x = tc.x + i
return tc.x
self.assertTransformedResult(
test_fn, (range(5), constant_op.constant(10)),
20,
symbols={'TestClass': TestClass})
with self.converted(
test_fn, control_flow, {'TestClass': TestClass}) as result:
with self.assertRaisesRegex(
ValueError, "'tc.x' must be defined before the loop"):
result.test_fn(constant_op.constant(list(range(5))), 0)
self.assertTransformedResult(f, (range(5), constant_op.constant(10)), 20)
tr = self.transform(f, control_flow)
with self.assertRaisesRegex(
ValueError, "'tc.x' must be defined before the loop"):
tr(constant_op.constant(list(range(5))), 0)
def test_tuple_unpacking(self):
def test_fn(x_list):
z = tf.constant(0) # pylint:disable=undefined-variable
def f(x_list):
z = constant_op.constant(0) # pylint:disable=undefined-variable
for i, x in enumerate(x_list):
z = z + x + i
return z
self.assertTransformedResult(test_fn, [3, 3], 7)
self.assertTransformedResult(f, [3, 3], 7)
def test_with_comprehension_in_body(self):
def test_fn(l, n):
def f(l, n):
s = constant_op.constant(list(range(n)))
for _ in l:
s += constant_op.constant([a for a in range(n)])
return s
self.assertTransformedResult(
test_fn, (constant_op.constant([1, 2, 3]), 5),
np.array(range(5)) * 4,
symbols={'constant_op': constant_op})
self.assertTransformedResult(f, (constant_op.constant([1, 2, 3]), 5),
np.array(range(5)) * 4)
class AdvancedControlFlowTest(ControlFlowTestBase):
def assertTransformedEquivalent(self, f, *inputs):
tr = self.transform(
f, (break_statements, continue_statements, control_flow))
self.assertEqual(f(*inputs), tr(*inputs))
def test_while_with_else(self):
def f(x):
while x > 2:
x /= 2
else:
x += 1
return x
self.assertTransformedEquivalent(f, 4)
self.assertTransformedEquivalent(f, 2)
def test_while_with_else_and_break(self):
def f(cond1):
x = 8
while x > 2:
x /= 2
if cond1:
break
else:
x += 1
return x
self.assertTransformedEquivalent(f, True)
self.assertTransformedEquivalent(f, False)
def test_for_with_else(self):
def f(l):
res = 0
for x in l:
res += x
else:
res += 1
return res
self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(f, [1, 2])
def test_for_with_else_and_break(self):
def f(flag):
l = [1, 2, 3]
res = 0
for x in l:
res += x
if flag:
break
else:
res += 1
return res
self.assertTransformedEquivalent(f, True)
self.assertTransformedEquivalent(f, False)
if __name__ == '__main__':

View File

@ -22,7 +22,6 @@ from tensorflow.python.autograph.converters import directives as directives_conv
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
@ -30,13 +29,12 @@ class DirectivesTest(converter_testing.TestCase):
def test_local_target(self):
def test_fn():
def f():
l = []
string_var = 0
directives.set_element_type(l, 'a', string_var)
node, ctx = self.prepare(test_fn, {'directives': directives})
node = directives_converter.transform(node, ctx)
_, node, _ = self.transform(f, directives_converter, include_ast=True)
def_, = anno.getanno(node.body[0].targets[0],
anno.Static.DEFINITIONS)
@ -46,11 +44,11 @@ class DirectivesTest(converter_testing.TestCase):
def test_argument_target(self):
def test_fn(a):
def f(a):
directives.set_element_type(a, 1, shape=2)
pass
node, ctx = self.prepare(test_fn, {'directives': directives})
node = directives_converter.transform(node, ctx)
_, node, _ = self.transform(f, directives_converter, include_ast=True)
def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
d = def_.directives[directives.set_element_type]
@ -59,13 +57,13 @@ class DirectivesTest(converter_testing.TestCase):
def test_loop_target(self):
def test_fn():
def f():
a = True
while True:
directives.set_loop_options(parallel_iterations=10, back_prop=a)
pass
node, ctx = self.prepare(test_fn, {'directives': directives})
node = directives_converter.transform(node, ctx)
_, node, _ = self.transform(f, directives_converter, include_ast=True)
d = anno.getanno(node.body[1], anno.Basic.DIRECTIVES)
d = d[directives.set_loop_options]
@ -75,40 +73,23 @@ class DirectivesTest(converter_testing.TestCase):
def test_loop_target_no_loop(self):
def test_fn():
def f():
directives.set_loop_options()
pass
node, ctx = self.prepare(test_fn, {'directives': directives})
with self.assertRaisesRegexp(ValueError, 'must be used inside a statement'):
node = directives_converter.transform(node, ctx)
self.transform(f, directives_converter, include_ast=True)
def test_loop_target_not_first(self):
def test_fn():
def f():
a = 1
while True:
a = 2
directives.set_loop_options(parallel_iterations=10, back_prop=a)
node, ctx = self.prepare(test_fn, {'directives': directives})
with self.assertRaisesRegexp(ValueError, 'must be the first statement'):
node = directives_converter.transform(node, ctx)
def test_invalid_default(self):
def invalid_directive(valid_arg, invalid_default=object()):
del valid_arg
del invalid_default
return
def call_invalid_directive():
invalid_directive(1)
node, _ = parser.parse_entity(call_invalid_directive, ())
# Find the call to the invalid directive
node = node.body[0].value
with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):
directives_converter._map_args(node, invalid_directive)
self.transform(f, directives_converter, include_ast=True)
def test_value_verification_does_not_trigger_properties(self):
@ -122,11 +103,11 @@ class DirectivesTest(converter_testing.TestCase):
tc = TestClass()
def test_fn():
def f():
return tc.b + 1
node, ctx = self.prepare(test_fn, {'tc': tc})
node = directives_converter.transform(node, ctx)
_, node, _ = self.transform(f, directives_converter, include_ast=True)
self.assertIsNotNone(node)
def test_value_verification_does_not_trigger_getattr(self):
@ -143,11 +124,11 @@ class DirectivesTest(converter_testing.TestCase):
tc = TestClass()
def test_fn():
def f():
return tc.b + 1
node, ctx = self.prepare(test_fn, {'tc': tc})
node = directives_converter.transform(node, ctx)
_, node, _ = self.transform(f, directives_converter, include_ast=True)
self.assertIsNotNone(node)
self.assertFalse(tc.getattr_called)

View File

@ -23,51 +23,49 @@ from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.impl import api
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
class FunctionTransformer(converter_testing.TestCase):
@test_util.run_deprecated_v1
def test_basic(self):
def test_fn(l):
def f(l):
"""Docstring."""
a = 1
l += a
return l
with self.converted(test_fn, functions, {}) as result:
result_op = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', result_op.op.name)
self.assertEqual('Docstring.', result.test_fn.__doc__)
tr = self.transform(f, functions)
result_op = tr(constant_op.constant(1))
self.assertIn('f/', result_op.op.name)
self.assertEqual('Docstring.', tr.__doc__)
@test_util.run_deprecated_v1
def test_multiline_docstring(self):
tf = None
def test_fn():
def f():
"""First sentence.
Second sentence.
Returns:
Something.
"""
return tf.constant(1)
return constant_op.constant(1)
with self.converted(test_fn, functions, {},
(constant_op.constant,)) as result:
result_op = result.test_fn()
self.assertIn('test_fn/', result_op.op.name)
self.assertIn('First sentence.', result.test_fn.__doc__)
self.assertIn('Second sentence.', result.test_fn.__doc__)
tr = self.transform(f, functions)
result_op = tr()
self.assertIn('f/', result_op.op.name)
self.assertIn('First sentence.', tr.__doc__)
self.assertIn('Second sentence.', tr.__doc__)
@test_util.run_deprecated_v1
def test_nested_functions(self):
def test_fn(l):
def f(l):
def inner_fn(i):
return i + 1
@ -75,41 +73,35 @@ class FunctionTransformer(converter_testing.TestCase):
l += 1
return l, inner_fn(l)
with self.converted(test_fn, (functions, return_statements), {},
(ops.name_scope,)) as result:
first, second = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', first.op.name)
self.assertNotIn('inner_fn', first.op.name)
self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name)
tr = self.transform(f, (functions, return_statements))
first, second = tr(constant_op.constant(1))
self.assertIn('f/', first.op.name)
self.assertNotIn('inner_fn', first.op.name)
self.assertIn('f/inner_fn/', second.op.inputs[0].name)
@test_util.run_deprecated_v1
def test_conversion_context_preserves_in_inner_functions(self):
def inner_fn_callee():
self.assertEqual(
ag_ctx.control_status_ctx().status, ag_ctx.Status.DISABLED)
def test_fn():
def f():
def inner_fn():
inner_fn_callee()
with ag_ctx.ControlStatusCtx(
ag_ctx.Status.DISABLED, converter.ConversionOptions(recursive=True)):
inner_fn()
ns = {
'inner_fn_callee': inner_fn_callee,
'ag_ctx': ag_ctx,
'converter': converter
}
with self.converted(test_fn, functions, ns) as result:
result.test_fn()
tr = self.transform(f, functions)
tr()
@test_util.run_deprecated_v1
def test_method(self):
class TestClass(object):
def test_fn(self, l):
def f(self, l):
def inner_fn(i):
return i + 1
@ -117,25 +109,22 @@ class FunctionTransformer(converter_testing.TestCase):
l += 1
return l, inner_fn(l)
ns = {'TestClass': TestClass}
node, ctx = self.prepare(TestClass, ns)
node = functions.transform(node, ctx)
node = return_statements.transform(node, ctx)
tr = self.transform(TestClass.f, (functions, return_statements))
with self.compiled(node, {}, (ops.name_scope,)) as result:
first, second = result.TestClass().test_fn(constant_op.constant(1))
self.assertIn('test_fn/', first.op.name)
self.assertNotIn('inner_fn', first.op.name)
self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name)
first, second = tr(TestClass(), constant_op.constant(1))
self.assertIn('f/', first.op.name)
self.assertNotIn('inner_fn', first.op.name)
self.assertIn('f/inner_fn/', second.op.inputs[0].name)
def test_lambda_in_return_value(self):
def test_fn():
def f():
return lambda x: x + 1
with self.converted(test_fn, functions, {}) as result:
result_l = result.test_fn()
self.assertTrue(result_l.fake_autograph_artifact)
tr = self.transform(f, functions)
result_l = tr()
self.assertTrue(api.is_autograph_artifact(result_l))
if __name__ == '__main__':

View File

@ -25,36 +25,36 @@ from tensorflow.python.platform import test
class ListCompTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
with self.converted(test_fn, list_comprehensions, {}) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def assertTransformedEquivalent(self, f, *inputs):
tr = self.transform(f, list_comprehensions)
self.assertEqual(f(*inputs), tr(*inputs))
def test_basic(self):
def test_fn(l):
def f(l):
s = [e * e for e in l]
return s
self.assertTransformedEquivalent(test_fn, [])
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(f, [1, 2, 3])
def test_multiple_generators(self):
def test_fn(l):
s = [e * e for sublist in l for e in sublist]
def f(l):
s = [e * e for sublist in l for e in sublist] # pylint:disable=g-complex-comprehension
return s
self.assertTransformedEquivalent(test_fn, [])
self.assertTransformedEquivalent(test_fn, [[1], [2], [3]])
self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(f, [[1], [2], [3]])
def test_cond(self):
def test_fn(l):
def f(l):
s = [e * e for e in l if e > 1]
return s
self.assertTransformedEquivalent(test_fn, [])
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(f, [1, 2, 3])
if __name__ == '__main__':

View File

@ -18,12 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.converters import directives as directives_converter
from tensorflow.python.autograph.converters import lists
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -31,101 +30,81 @@ from tensorflow.python.ops import list_ops
from tensorflow.python.platform import test
tf = None # Will be replaced by a mock.
class ListTest(converter_testing.TestCase):
def test_empty_list(self):
def test_fn():
def f():
return []
with self.converted(test_fn, lists, {}) as result:
tl = result.test_fn()
# Empty tensor lists cannot be evaluated or stacked.
self.assertTrue(isinstance(tl, ops.Tensor))
self.assertEqual(tl.dtype, dtypes.variant)
tr = self.transform(f, lists)
tl = tr()
# Empty tensor lists cannot be evaluated or stacked.
self.assertIsInstance(tl, ops.Tensor)
self.assertEqual(tl.dtype, dtypes.variant)
def test_initialized_list(self):
def test_fn():
def f():
return [1, 2, 3]
with self.converted(test_fn, lists, {}) as result:
self.assertAllEqual(result.test_fn(), [1, 2, 3])
tr = self.transform(f, lists)
self.assertAllEqual(tr(), [1, 2, 3])
def test_list_append(self):
def test_fn():
def f():
l = special_functions.tensor_list([1])
l.append(2)
l.append(3)
return l
ns = {'special_functions': special_functions}
with self.converted(test_fn, lists, ns) as result:
with self.cached_session() as sess:
tl = result.test_fn()
r = list_ops.tensor_list_stack(tl, dtypes.int32)
self.assertAllEqual(self.evaluate(r), [1, 2, 3])
tr = self.transform(f, lists)
tl = tr()
r = list_ops.tensor_list_stack(tl, dtypes.int32)
self.assertAllEqual(self.evaluate(r), [1, 2, 3])
def test_list_pop(self):
def test_fn():
def f():
l = special_functions.tensor_list([1, 2, 3])
directives.set_element_type(l, dtype=dtypes.int32, shape=())
s = l.pop()
return s, l
ns = {'special_functions': special_functions}
node, ctx = self.prepare(test_fn, ns)
def_, = anno.getanno(node.body[0].targets[0],
anno.Static.ORIG_DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32'),
'shape': parser.parse_expression('()'),
}
node = lists.transform(node, ctx)
tr = self.transform(f, (directives_converter, lists))
with self.compiled(node, ns, (dtypes.int32,)) as result:
with self.cached_session() as sess:
ts, tl = result.test_fn()
r = list_ops.tensor_list_stack(tl, dtypes.int32)
self.assertAllEqual(self.evaluate(r), [1, 2])
self.assertAllEqual(self.evaluate(ts), 3)
ts, tl = tr()
r = list_ops.tensor_list_stack(tl, dtypes.int32)
self.assertAllEqual(self.evaluate(r), [1, 2])
self.assertAllEqual(self.evaluate(ts), 3)
def test_double_list_pop(self):
def test_fn(l):
def f(l):
s = l.pop().pop()
return s
with self.converted(test_fn, lists, {}) as result:
test_input = [1, 2, [1, 2, 3]]
# TODO(mdan): Pass a list of lists of tensor when we fully support that.
# For now, we just pass a regular Python list of lists just to verify that
# the two pop calls are sequenced properly.
self.assertAllEqual(result.test_fn(test_input), 3)
tr = self.transform(f, lists)
test_input = [1, 2, [1, 2, 3]]
# TODO(mdan): Pass a list of lists of tensor when we fully support that.
# For now, we just pass a regular Python list of lists just to verify that
# the two pop calls are sequenced properly.
self.assertAllEqual(tr(test_input), 3)
def test_list_stack(self):
def test_fn():
def f():
l = [1, 2, 3]
return tf.stack(l)
return array_ops.stack(l)
node, ctx = self.prepare(test_fn, {})
def_, = anno.getanno(node.body[0].targets[0],
anno.Static.ORIG_DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32')
}
node = lists.transform(node, ctx)
tr = self.transform(f, lists)
with self.compiled(node, {}, (array_ops.stack, dtypes.int32)) as result:
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])
# TODO(mdan): Add a test with tf.stack with axis kwarg.
self.assertAllEqual(self.evaluate(tr()), [1, 2, 3])
if __name__ == '__main__':

View File

@ -27,62 +27,59 @@ from tensorflow.python.platform import test
class LogicalExpressionTest(converter_testing.TestCase):
@test_util.run_deprecated_v1
def test_equals(self):
def test_fn(a, b):
def f(a, b):
return a == b
with self.converted(test_fn, logical_expressions, {}) as result:
with self.cached_session() as sess:
self.assertTrue(sess.run(result.test_fn(constant_op.constant(1), 1)))
self.assertFalse(sess.run(result.test_fn(constant_op.constant(1), 2)))
tr = self.transform(f, logical_expressions)
self.assertTrue(self.evaluate(tr(constant_op.constant(1), 1)))
self.assertFalse(self.evaluate(tr(constant_op.constant(1), 2)))
@test_util.run_deprecated_v1
def test_bool_ops(self):
def test_fn(a, b, c):
def f(a, b, c):
return (a or b) and (a or b or c) and not c
with self.converted(test_fn, logical_expressions, {}) as result:
with self.cached_session() as sess:
self.assertTrue(
sess.run(result.test_fn(constant_op.constant(True), False, False)))
self.assertFalse(
sess.run(result.test_fn(constant_op.constant(True), False, True)))
tr = self.transform(f, logical_expressions)
self.assertTrue(self.evaluate(tr(constant_op.constant(True), False, False)))
self.assertFalse(self.evaluate(tr(constant_op.constant(True), False, True)))
@test_util.run_deprecated_v1
def test_comparison(self):
def test_fn(a, b, c, d):
def f(a, b, c, d):
return a < b == c > d
with self.converted(test_fn, logical_expressions, {}) as result:
with self.cached_session() as sess:
# Note: having just the first constant a tensor tests that the
# operations execute in the correct order. If anything other than
# a < b executed first, the result would be a Python scalar and not a
# Tensor. This is valid as long as the dispat is automatic based on
# type.
self.assertTrue(
sess.run(result.test_fn(constant_op.constant(1), 2, 2, 1)))
self.assertFalse(
sess.run(result.test_fn(constant_op.constant(1), 2, 2, 3)))
tr = self.transform(f, logical_expressions)
# Note: having just the first constant a tensor tests that the
# operations execute in the correct order. If anything other than
# a < b executed first, the result would be a Python scalar and not a
# Tensor. This is valid as long as the dispat is automatic based on
# type.
self.assertTrue(self.evaluate(tr(constant_op.constant(1), 2, 2, 1)))
self.assertFalse(self.evaluate(tr(constant_op.constant(1), 2, 2, 3)))
def test_default_ops(self):
def test_fn(a, b):
def f(a, b):
return a in b
with self.converted(test_fn, logical_expressions, {}) as result:
self.assertTrue(result.test_fn('a', ('a',)))
tr = self.transform(f, logical_expressions)
self.assertTrue(tr('a', ('a',)))
def test_unary_ops(self):
def test_fn(a):
def f(a):
return ~a, -a, +a
with self.converted(test_fn, logical_expressions, {}) as result:
self.assertEqual(result.test_fn(1), (-2, -1, 1))
tr = self.transform(f, logical_expressions)
self.assertEqual(tr(1), (-2, -1, 1))
if __name__ == '__main__':

View File

@ -1,95 +0,0 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration Tests for loop."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.converters import continue_statements
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
class LoopIntegrationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
with self.converted(test_fn,
[break_statements, continue_statements, control_flow],
{}, (constant_op.constant,)) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_while_loop_with_else(self):
def test_fn(x):
while x > 2:
x /= 2
else:
x += 1
return x
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(test_fn, 2)
def test_while_loop_with_else_and_break(self):
def test_fn(cond1):
x = 8
while x > 2:
x /= 2
if cond1:
break
else:
x += 1
return x
self.assertTransformedEquivalent(test_fn, True)
self.assertTransformedEquivalent(test_fn, False)
def test_for_loop_with_else(self):
def test_fn(l):
res = 0
for x in l:
res += x
else:
res += 1
return res
self.assertTransformedEquivalent(test_fn, [])
self.assertTransformedEquivalent(test_fn, [1, 2])
def test_for_loop_with_else_and_break(self):
def test_fn(flag):
l = [1, 2, 3]
res = 0
for x in l:
res += x
if flag:
break
else:
res += 1
return res
self.assertTransformedEquivalent(test_fn, True)
self.assertTransformedEquivalent(test_fn, False)
if __name__ == '__main__':
test.main()

View File

@ -27,81 +27,80 @@ from tensorflow.python.platform import test
class SingleReturnTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
ns = {'ops': ops}
with self.converted(test_fn, (functions, return_statements), ns) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def assertTransformedEquivalent(self, f, *inputs):
tr = self.transform(f, (functions, return_statements))
self.assertEqual(f(*inputs), tr(*inputs))
def test_straightline(self):
def test_fn(x):
def f(x):
return x * x
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(f, 2)
def test_superfluous_returns(self):
def test_fn():
def f():
retval = 1
return retval
retval = 2 # pylint:disable=unreachable
return retval
self.assertTransformedEquivalent(test_fn)
self.assertTransformedEquivalent(f)
def test_superfluous_returns_adjacent(self):
def test_fn():
def f():
return 1
return 2 # pylint:disable=unreachable
self.assertTransformedEquivalent(test_fn)
self.assertTransformedEquivalent(f)
def test_conditional(self):
def test_fn(x):
def f(x):
if x > 0:
return x
else:
return x * x
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, -2)
def test_conditional_missing_else(self):
def test_fn(x):
def f(x):
if x > 0:
return x
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, -2)
def test_conditional_missing_else_then_default(self):
def test_fn(x):
def f(x):
if x > 0:
return x
return x * x
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, -2)
def test_conditional_else_only_then_default(self):
def test_fn(x):
def f(x):
if x < 0:
x *= x
else:
return x
return x
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, -2)
def test_conditional_nested(self):
def test_fn(x):
def f(x):
if x > 0:
if x < 5:
return x
@ -110,53 +109,53 @@ class SingleReturnTest(converter_testing.TestCase):
else:
return x * x * x
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
self.assertTransformedEquivalent(test_fn, 5)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, -2)
self.assertTransformedEquivalent(f, 5)
def test_context_manager(self):
def test_fn(x):
def f(x):
with ops.name_scope(''):
return x * x
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, -2)
def test_context_manager_in_conditional(self):
def test_fn(x):
def f(x):
if x > 0:
with ops.name_scope(''):
return x * x
else:
return x
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, -2)
def text_conditional_in_context_manager(self):
def test_fn(x):
def f(x):
with ops.name_scope(''):
if x > 0:
return x * x
else:
return x
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, -2)
def test_no_return(self):
def test_fn(x):
def f(x):
x *= x
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(f, 2)
def test_nested_function(self):
def test_fn(x):
def f(x):
def inner_fn(y):
if y > 0:
@ -166,33 +165,33 @@ class SingleReturnTest(converter_testing.TestCase):
return inner_fn(x)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, -2)
def test_nested_function_in_control_flow(self):
def test_fn(x):
def f(x):
if x:
def inner_fn(y):
return y
inner_fn(x)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, -2)
def test_for_loop(self):
def test_fn(n):
def f(n):
for _ in range(n):
return 1
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, 0)
def test_while_loop(self):
def test_fn(n):
def f(n):
i = 0
s = 0
while i < n:
@ -202,23 +201,23 @@ class SingleReturnTest(converter_testing.TestCase):
return s
return -1
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(f, 4)
def test_null_return(self):
def test_fn(n):
def f(n):
if n > 4:
return
return
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(test_fn, 5)
self.assertTransformedEquivalent(f, 4)
self.assertTransformedEquivalent(f, 5)
def test_nested_multiple_withs(self):
def test_fn(x):
def f(x):
v = []
while x > 0:
x -= 1
@ -230,14 +229,14 @@ class SingleReturnTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, 0)
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(f, 4)
def test_multiple_returns_in_nested_scope(self):
def test_fn(a):
def f(a):
v = []
for x in a:
x -= 1
@ -250,10 +249,10 @@ class SingleReturnTest(converter_testing.TestCase):
v.append(x)
return v
self.assertTransformedEquivalent(test_fn, [])
self.assertTransformedEquivalent(test_fn, [1])
self.assertTransformedEquivalent(test_fn, [2])
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(f, [1])
self.assertTransformedEquivalent(f, [2])
self.assertTransformedEquivalent(f, [1, 2, 3])
if __name__ == '__main__':
test.main()

View File

@ -18,11 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.converters import directives as directives_converter
from tensorflow.python.autograph.converters import slices
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops
@ -33,42 +32,26 @@ class SliceTest(converter_testing.TestCase):
def test_index_access(self):
def test_fn(l):
def f(l):
directives.set_element_type(l, dtypes.int32)
return l[1]
node, ctx = self.prepare(test_fn, {})
def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32')
}
node = slices.transform(node, ctx)
tr = self.transform(f, (directives_converter, slices))
with self.compiled(node, {}, (dtypes.int32,)) as result:
with self.cached_session() as sess:
tl = list_ops.tensor_list_from_tensor(
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
y = result.test_fn(tl)
self.assertEqual(2, self.evaluate(y))
tl = list_ops.tensor_list_from_tensor(
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
y = tr(tl)
self.assertEqual(2, self.evaluate(y))
def test_index_access_multiple_definitions(self):
def test_fn(l):
def f(l):
directives.set_element_type(l, dtypes.int32)
if l:
l = []
return l[1]
node, ctx = self.prepare(test_fn, {})
def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32')
}
def_, = anno.getanno(node.body[0].body[0].targets[0],
anno.Static.DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.float32')
}
with self.assertRaises(ValueError):
slices.transform(node, ctx)
self.transform(f, (directives_converter, slices))
if __name__ == '__main__':

View File

@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
from tensorflow.python.autograph.converters import variables
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test
@ -27,60 +25,63 @@ from tensorflow.python.platform import test
class VariablesTest(converter_testing.TestCase):
@contextlib.contextmanager
def apply_add_one_conversion(self, fn):
def transform_with_test_ld(self, f):
"""Generates code which adds 1 to all variable reads."""
with self.converted(fn, variables, {}) as result:
result.ag__.__dict__['ld'] = lambda x: x + 1
yield result
return self.transform(f, variables, ag_overrides={'ld': lambda x: x + 1})
def test_read(self):
def test_fn(l):
def f(l):
return l
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(1), 2)
tr = self.transform_with_test_ld(f)
self.assertEqual(tr(1), 2)
def test_aug_assign(self):
def test_fn(l):
def f(l):
l *= 10
return l
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(1), (1 + 1) * 10 + 1) # two reads
tr = self.transform_with_test_ld(f)
self.assertEqual(tr(1), (1 + 1) * 10 + 1) # two reads
def test_del(self):
def test_fn(l):
def f(l):
del l
return l
with self.converted(test_fn, variables, {}) as result:
with self.assertRaisesRegex(
NameError, "'l' is used before assignment"):
result.test_fn(1)
tr = self.transform(f, variables)
def test_del_getitem_ignored(self):
with self.assertRaisesRegex(NameError, "'l' is used before assignment"):
tr(1)
def basic_slice(l):
def test_del_getitem_ignored_basic_slice(self):
def f(l):
del l[0]
return l
with self.converted(basic_slice, variables, {}) as result:
self.assertListEqual([2], result.basic_slice([1, 2]))
tr = self.transform(f, variables)
def range_slice(l):
self.assertListEqual([2], tr([1, 2]))
def test_del_getitem_ignored_range_slice(self):
def f(l):
del l[0:2]
return l
with self.converted(range_slice, variables, {}) as result:
self.assertListEqual([], result.range_slice([1, 2]))
tr = self.transform(f, variables)
self.assertListEqual([], tr([1, 2]))
def test_del_getattr_ignored(self):
def test_fn(l):
def f(l):
del l.a
return l
@ -90,50 +91,60 @@ class VariablesTest(converter_testing.TestCase):
self.a = 1
self.b = 2
with self.converted(test_fn, variables, {}) as result:
self.assertFalse(hasattr(result.test_fn(TestClass()), 'a'))
self.assertEqual(result.test_fn(TestClass()).b, 2)
tr = self.transform(f, variables)
def test_del_packing_ignored(self):
# Note: test for UnboundLocalError, not NameError because in this case we
self.assertFalse(hasattr(tr(TestClass()), 'a'))
self.assertEqual(tr(TestClass()).b, 2)
def test_del_packing_ignored_list(self):
# Note: testing for UnboundLocalError, not NameError because in this case we
# don't rewrite the del.
def list_(a, b):
def f(a, b):
del [a, b]
return a
with self.converted(list_, variables, {}) as result:
with self.assertRaises(UnboundLocalError):
result.list_(1, 2)
tr = self.transform(f, variables)
def nested(a, b, c):
with self.assertRaises(UnboundLocalError):
tr(1, 2)
def test_del_packing_ignored_nested(self):
# Note: testing for UnboundLocalError, not NameError because in this case we
# don't rewrite the del.
def f(a, b, c):
del [a, (b, c)]
return c
with self.converted(nested, variables, {}) as result:
with self.assertRaises(UnboundLocalError):
result.nested(1, 2, 3)
tr = self.transform(f, variables)
def test_del_item_multiple_mixed(self):
with self.assertRaises(UnboundLocalError):
tr(1, 2, 3)
def test_fn_failing(a, b, c):
def test_del_item_multiple_mixed_used_after(self):
def f(a, b, c):
del a, b, c[0]
a = 1
return a, b, c
with self.converted(test_fn_failing, variables, {}) as result:
with self.assertRaisesRegex(
NameError, "'b' is used before assignment"):
result.test_fn_failing(1, 2, [1, 2])
tr = self.transform(f, variables)
def test_fn_passing(a, b, c):
with self.assertRaisesRegex(NameError, "'b' is used before assignment"):
tr(1, 2, [1, 2])
def test_del_item_multiple_mixed_unused_after(self):
def f(a, b, c):
del a, b, c[0]
a = 1
b = 2
return c
with self.converted(test_fn_passing, variables, {}) as result:
self.assertListEqual([2], result.test_fn_passing(1, 2, [1, 2]))
tr = self.transform(f, variables)
self.assertListEqual([2], tr(1, 2, [1, 2]))
def test_attribute(self):
@ -146,12 +157,13 @@ class VariablesTest(converter_testing.TestCase):
self.v += other
return self
def test_fn(l):
def f(l):
return l.v
tc = TestClass()
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(tc), 2)
tr = self.transform_with_test_ld(f)
self.assertEqual(tr(tc), 2)
def test_subscript(self):
@ -167,12 +179,13 @@ class VariablesTest(converter_testing.TestCase):
def __getitem__(self, _):
return self.v
def test_fn(l):
def f(l):
return l[0]
tc = TestClass()
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(tc), 2)
tr = self.transform_with_test_ld(f)
self.assertEqual(tr(tc), 2)
def test_call(self):
@ -188,12 +201,13 @@ class VariablesTest(converter_testing.TestCase):
def __call__(self):
return self.v
def test_fn(l):
def f(l):
return l()
tc = TestClass()
with self.apply_add_one_conversion(test_fn) as result:
self.assertEqual(result.test_fn(tc), 2)
tr = self.transform_with_test_ld(f)
self.assertEqual(tr(tc), 2)
if __name__ == '__main__':

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import imp
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.pyct import anno
@ -38,16 +40,18 @@ class ConversionOptionsTest(converter_testing.TestCase):
opts_ast = opts.to_ast()
template = '''
def test_fn():
def f():
return opts_ast
'''
opts_packed = templates.replace(template, opts_ast=opts_ast)
reparsed, _, _ = loader.load_ast(opts_packed)
reparsed.__dict__['ag__'] = self.make_fake_mod(
'fake_ag', converter.ConversionOptions, converter.Feature)
fake_ag = imp.new_module('fake_ag')
fake_ag.ConversionOptions = converter.ConversionOptions
fake_ag.Feature = converter.Feature
reparsed.ag__ = fake_ag
reparsed_opts = reparsed.test_fn()
reparsed_opts = reparsed.f()
self.assertEqual(opts.recursive, reparsed_opts.recursive)
self.assertEqual(opts.user_requested, False)
@ -63,12 +67,12 @@ class ConverterBaseTest(converter_testing.TestCase):
directive_key = object
def test_fn():
def f():
a = 1
return a
ns = {}
node, ctx = self.prepare(test_fn, ns)
_, node, ctx = self.transform(f, (), include_ast=True)
symbol_a = node.body[1].value
defs, = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
defs.directives[directive_key] = {
@ -84,12 +88,12 @@ class ConverterBaseTest(converter_testing.TestCase):
directive_key = object
def test_fn():
def f():
a = 1
return a
ns = {}
node, ctx = self.prepare(test_fn, ns)
_, node, ctx = self.transform(f, (), include_ast=True)
symbol_a = node.body[1].value
c = TestConverter(ctx)
value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
@ -100,14 +104,14 @@ class ConverterBaseTest(converter_testing.TestCase):
directive_key = object
def test_fn():
def f():
a = 1
if a:
a = 2
return a
ns = {}
node, ctx = self.prepare(test_fn, ns)
_, node, ctx = self.transform(f, (), include_ast=True)
symbol_a = node.body[2].value
defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
defs[0].directives[directive_key] = {
@ -127,14 +131,14 @@ class ConverterBaseTest(converter_testing.TestCase):
directive_key = object
def test_fn():
def f():
a = 1
if a:
a = 2
return a
ns = {}
node, ctx = self.prepare(test_fn, ns)
_, node, ctx = self.transform(f, (), include_ast=True)
symbol_a = node.body[2].value
defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
defs[0].directives[directive_key] = {

View File

@ -25,27 +25,15 @@ import sys
import six
from tensorflow.python.autograph import operators
from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.impl import conversion
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
def whitelist(entity):
def whitelist(f):
"""Helper that marks a callable as whtelitisted."""
if 'whitelisted_module_for_testing' not in sys.modules:
whitelisted_mod = imp.new_module('whitelisted_module_for_testing')
@ -54,7 +42,7 @@ def whitelist(entity):
(config.DoNotConvert('whitelisted_module_for_testing'),) +
config.CONVERSION_RULES)
entity.__module__ = 'whitelisted_module_for_testing'
f.__module__ = 'whitelisted_module_for_testing'
def is_inside_generated_code():
@ -76,9 +64,39 @@ def is_inside_generated_code():
del frame
class TestingTranspiler(conversion.AutoGraphTranspiler):
"""Testing version that only applies given transformations."""
def __init__(self, converters):
super(TestingTranspiler, self).__init__()
if isinstance(converters, (list, tuple)):
self._converters = converters
else:
self._converters = (converters,)
self.transformed_ast = None
def transform_ast(self, node, ctx):
node = self.initial_analysis(node, ctx)
for c in self._converters:
node = c.transform(node, ctx)
self.transformed_ast = node
self.transform_ctx = ctx
return node
class TestCase(test.TestCase):
"""Base class for unit tests in this module. Contains relevant utilities."""
def setUp(self):
# AutoGraph tests must run in graph mode to properly test control flow.
self.graph = ops.Graph().as_default()
self.graph.__enter__()
def tearDown(self):
self.graph.__exit__(None, None, None)
@contextlib.contextmanager
def assertPrints(self, expected_result):
try:
@ -89,108 +107,26 @@ class TestCase(test.TestCase):
finally:
sys.stdout = sys.__stdout__
@contextlib.contextmanager
def compiled(self, node, namespace, symbols=()):
source = None
self.dynamic_calls = []
# See api.converted_call
def converted_call(
f, args, kwargs, unused_opts=None, unused_function_ctx=None):
"""Mock version of api.converted_call."""
self.dynamic_calls.append((args, kwargs))
if kwargs is None:
kwargs = {}
return f(*args, **kwargs)
def fake_autograph_artifact(f):
setattr(f, 'fake_autograph_artifact', True)
return f
try:
result, source, source_map = loader.load_ast(
node, include_source_map=True)
# TODO(mdan): Move the unparsing from converter into pyct and reuse here.
# TODO(mdan): Move this into self.prepare()
result.tf = self.make_fake_mod('fake_tf', *symbols)
fake_ag = self.make_fake_mod('fake_ag', converted_call,
converter.ConversionOptions)
fake_ag.__dict__.update(operators.__dict__)
fake_ag.__dict__.update(special_functions.__dict__)
fake_ag.ConversionOptions = converter.ConversionOptions
fake_ag.Feature = converter.Feature
fake_ag.utils = utils
fake_ag.FunctionScope = function_wrappers.FunctionScope
fake_ag.autograph_artifact = fake_autograph_artifact
result.ag__ = fake_ag
result.ag_source_map__ = source_map
for k, v in namespace.items():
result.__dict__[k] = v
yield result
except Exception: # pylint:disable=broad-except
if source is None:
print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
else:
print('Offending source code:\n%s' % source)
raise
@contextlib.contextmanager
def converted(self, entity, converter_module, namespace, tf_symbols=()):
node, ctx = self.prepare(entity, namespace)
if not isinstance(converter_module, (list, tuple)):
converter_module = (converter_module,)
for m in converter_module:
node = m.transform(node, ctx)
with self.compiled(node, namespace, tf_symbols) as result:
yield result
def make_fake_mod(self, name, *symbols):
fake_mod = imp.new_module(name)
for s in symbols:
if hasattr(s, '__name__'):
setattr(fake_mod, s.__name__, s)
elif hasattr(s, 'name'):
# This is a bit of a hack, but works for things like tf.int32
setattr(fake_mod, s.name, s)
else:
raise ValueError('can not attach %s - what should be its name?' % s)
return fake_mod
def attach_namespace(self, module, **ns):
for k, v in ns.items():
setattr(module, k, v)
def prepare(self, test_fn, namespace, recursive=True):
namespace['ConversionOptions'] = converter.ConversionOptions
future_features = ('print_function', 'division')
node, source = parser.parse_entity(test_fn, future_features=future_features)
namer = naming.Namer(namespace)
def transform(
self, f, converter_module, include_ast=False, ag_overrides=None):
program_ctx = converter.ProgramContext(
options=converter.ConversionOptions(recursive=recursive),
autograph_module=None)
entity_info = transformer.EntityInfo(
name=test_fn.__name__,
source_code=source,
source_file='<fragment>',
future_features=future_features,
namespace=namespace)
ctx = transformer.Context(entity_info, namer, program_ctx)
origin_info.resolve_entity(node, source, test_fn)
options=converter.ConversionOptions(recursive=True),
autograph_module=api)
graphs = cfg.build(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
node = reaching_definitions.resolve(node, ctx, graphs)
anno.dup(
node,
{
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
},
)
conversion.create_custom_vars(program_ctx)
custom_vars = dict(conversion.custom_vars)
return node, ctx
if ag_overrides:
modified_ag = imp.new_module('fake_autograph')
modified_ag.__dict__.update(custom_vars['ag__'].__dict__)
modified_ag.__dict__.update(ag_overrides)
custom_vars['ag__'] = modified_ag
tr = TestingTranspiler(converter_module)
transformed, _, _ = tr.transform_function(
f, program_ctx.options, program_ctx, custom_vars)
if include_ast:
return transformed, tr.transformed_ast, tr.transform_ctx
return transformed

View File

@ -60,11 +60,7 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler):
def get_transformed_name(self, node):
return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node)
def transform_ast(self, node, ctx):
# TODO(mdan): Insert list_comprehensions somewhere.
unsupported_features_checker.verify(node)
# Run initial analysis.
def initial_analysis(self, node, ctx):
graphs = cfg.build(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
@ -75,6 +71,11 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler):
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
},
)
return node
def transform_ast(self, node, ctx):
unsupported_features_checker.verify(node)
node = self.initial_analysis(node, ctx)
node = functions.transform(node, ctx)
node = directives.transform(node, ctx)
@ -114,7 +115,7 @@ def convert(entity, program_ctx):
'expose a __code__ object. If this is a @tf.function,'
' try passing f.python_function instead.')
_create_custom_vars(program_ctx)
create_custom_vars(program_ctx)
transformed, module, source_map = _TRANSPILER.transform_function(
entity, program_ctx.options, program_ctx, custom_vars)
@ -248,7 +249,8 @@ def cache_whitelisted(entity, options):
# TODO(mdan): Move into core or replace with an actual importable module.
def _create_custom_vars(program_ctx):
# Visible for testing.
def create_custom_vars(program_ctx):
"""Adds namespace references to the module that exposes the api itself."""
global custom_vars
if custom_vars is None: