Clean up the test to remove duplication and run with within the real framework.

PiperOrigin-RevId: 318854218
Change-Id: I45125a5a028a6d5fd7ae8cd9bed698d31d04196b
This commit is contained in:
Dan Moldovan 2020-06-29 11:23:20 -07:00 committed by TensorFlower Gardener
parent f072535ba3
commit ce054f48e6
19 changed files with 826 additions and 982 deletions

View File

@ -77,12 +77,6 @@ py_test(
srcs = ["call_trees_test.py"], srcs = ["call_trees_test.py"],
python_version = "PY3", python_version = "PY3",
srcs_version = "PY3", srcs_version = "PY3",
tags = [
"no_oss_py2",
"no_pip",
"no_windows",
"nopip",
],
deps = [ deps = [
":converters", ":converters",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
@ -119,12 +113,6 @@ py_test(
srcs = ["control_flow_test.py"], srcs = ["control_flow_test.py"],
python_version = "PY3", python_version = "PY3",
srcs_version = "PY3", srcs_version = "PY3",
tags = [
"no_oss_py2",
"no_pip",
"no_windows",
"nopip",
],
deps = [ deps = [
":converters", ":converters",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",

View File

@ -24,7 +24,6 @@ from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -32,17 +31,15 @@ class AssertsTest(converter_testing.TestCase):
def test_basic(self): def test_basic(self):
def test_fn(a): def f(a):
assert a, 'testmsg' assert a, 'testmsg'
return a return a
with ops.Graph().as_default(): tr = self.transform(f, (functions, asserts, return_statements))
with self.converted(
test_fn, (functions, asserts, return_statements), {}) as result:
op = result.test_fn(constant_op.constant(False))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'): op = tr(constant_op.constant(False))
self.evaluate(op) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'):
self.evaluate(op)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -21,20 +21,18 @@ from __future__ import print_function
from tensorflow.python.autograph.converters import break_statements from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import anno
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test from tensorflow.python.platform import test
class BreakCanonicalizationTest(converter_testing.TestCase): class BreakCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs): def assertTransformedEquivalent(self, f, *inputs):
with self.converted(test_fn, break_statements, {}, tr = self.transform(f, break_statements)
(constant_op.constant,)) as result: self.assertEqual(f(*inputs), tr(*inputs))
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_while_loop(self): def test_while_loop(self):
def test_fn(x): def f(x):
v = [] v = []
while x > 0: while x > 0:
x -= 1 x -= 1
@ -43,28 +41,29 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v return v
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 1) self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(test_fn, 4) self.assertTransformedEquivalent(f, 4)
def test_while_loop_preserves_directives(self): def test_while_loop_preserves_directives(self):
def test_fn(x): def f(x):
while x > 0: while x > 0:
x -= 1 x -= 1
if x % 2 == 0: if x % 2 == 0:
break break
node, ctx = self.prepare(test_fn, {}) _, node, ctx = self.transform(f, (), include_ast=True)
fake_annotation = object() fake_annotation = object()
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation) anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
node = break_statements.transform(node, ctx) node = break_statements.transform(node, ctx)
self.assertIs( self.assertIs(
anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation) anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
def test_for_loop(self): def test_for_loop(self):
def test_fn(a): def f(a):
v = [] v = []
for x in a: for x in a:
x -= 1 x -= 1
@ -73,20 +72,18 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v return v
with self.converted(test_fn, break_statements, {}, tr = self.transform(f, break_statements)
(constant_op.constant,)) as result:
# The break is incompletely canonicalized. The loop will not interrupt, self.assertEqual([3], tr([5, 4]))
# but the section following the break will be skipped.
self.assertEqual([3], result.test_fn([5, 4]))
def test_for_loop_preserves_directives(self): def test_for_loop_preserves_directives(self):
def test_fn(a): def f(a):
for x in a: for x in a:
if x % 2 == 0: if x % 2 == 0:
break break
node, ctx = self.prepare(test_fn, {}) _, node, ctx = self.transform(f, (), include_ast=True)
fake_annotation = object() fake_annotation = object()
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation) anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
node = break_statements.transform(node, ctx) node = break_statements.transform(node, ctx)
@ -95,7 +92,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
def test_nested(self): def test_nested(self):
def test_fn(x): def f(x):
v = [] v = []
u = [] u = []
w = [] w = []
@ -110,13 +107,13 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v, u, w return v, u, w
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 3) self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(test_fn, 11) self.assertTransformedEquivalent(f, 11)
def test_nested_loops(self): def test_nested_loops(self):
def test_fn(x): def f(x):
v = [] v = []
u = [] u = []
while x > 0: while x > 0:
@ -132,14 +129,14 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v, u return v, u
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, 3) self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(test_fn, 5) self.assertTransformedEquivalent(f, 5)
def test_loop_orelse(self): def test_loop_orelse(self):
def test_fn(x): def f(x):
v = [] v = []
u = [] u = []
while x > 0: while x > 0:
@ -153,12 +150,12 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v, u return v, u
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, 3) self.assertTransformedEquivalent(f, 3)
def test_multiple_correlated_breaks_with_side_effects(self): def test_multiple_correlated_breaks_with_side_effects(self):
def test_fn(cond1): def f(cond1):
lst = [] lst = []
while True: while True:
if cond1: if cond1:
@ -169,8 +166,9 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
break break
return lst return lst
self.assertTransformedEquivalent(test_fn, True) self.assertTransformedEquivalent(f, True)
self.assertTransformedEquivalent(test_fn, False) self.assertTransformedEquivalent(f, False)
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -27,169 +27,193 @@ from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test from tensorflow.python.platform import test
class MockConvertedCall(object):
def __init__(self):
self.calls = []
def __call__(self, f, args, kwargs, caller_fn_scope=None, options=None):
del caller_fn_scope, options
self.calls.append((args, kwargs))
kwargs = kwargs or {}
return f(*args, **kwargs)
class CallTreesTest(converter_testing.TestCase): class CallTreesTest(converter_testing.TestCase):
def _transform_with_mock(self, f):
mock = MockConvertedCall()
tr = self.transform(
f, (functions, call_trees),
ag_overrides={'converted_call': mock})
return tr, mock
def test_function_no_args(self): def test_function_no_args(self):
def test_fn(f): def f(f):
return f() + 20 return f() + 20
with self.converted(test_fn, (functions, call_trees), {}) as result: tr, mock = self._transform_with_mock(f)
self.assertEqual(result.test_fn(lambda: 1), 21)
self.assertListEqual(self.dynamic_calls, [((), None)]) self.assertEqual(tr(lambda: 1), 21)
self.assertListEqual(mock.calls, [((), None)])
def test_function_with_expression_in_argument(self): def test_function_with_expression_in_argument(self):
def test_fn(f, g): def f(f, g):
return f(g() + 20) + 4000 return f(g() + 20) + 4000
with self.converted(test_fn, (functions, call_trees), {}) as result: tr, mock = self._transform_with_mock(f)
self.assertEqual(result.test_fn(lambda x: x + 300, lambda: 1), 4321)
self.assertListEqual(self.dynamic_calls, [ self.assertEqual(tr(lambda x: x + 300, lambda: 1), 4321)
((), None), self.assertListEqual(mock.calls, [
((21,), None), ((), None),
]) ((21,), None),
])
def test_function_with_call_in_argument(self): def test_function_with_call_in_argument(self):
def test_fn(f, g): def f(f, g):
return f(g()) + 300 return f(g()) + 300
with self.converted(test_fn, (functions, call_trees), {}) as result: tr, mock = self._transform_with_mock(f)
self.assertEqual(result.test_fn(lambda x: x + 20, lambda: 1), 321)
self.assertListEqual(self.dynamic_calls, [ self.assertEqual(tr(lambda x: x + 20, lambda: 1), 321)
((), None), self.assertListEqual(mock.calls, [
((1,), None), ((), None),
]) ((1,), None),
])
def test_function_chaining(self): def test_function_chaining(self):
def get_one(): def get_one():
return 1 return 1
def test_fn(): def f():
return get_one().__add__(20) return get_one().__add__(20)
with self.converted(test_fn, (functions, call_trees), tr, mock = self._transform_with_mock(f)
{'get_one': get_one}, ()) as result:
self.assertEqual(result.test_fn(), 21) self.assertEqual(tr(), 21)
self.assertListEqual(mock.calls, [
self.assertListEqual(self.dynamic_calls, [ ((), None),
((), None), ((20,), None),
((20,), None), ])
])
def test_function_with_single_arg(self): def test_function_with_single_arg(self):
def test_fn(f, a): def f(f, a):
return f(a) + 20 return f(a) + 20
with self.converted(test_fn, (functions, call_trees), {}) as result: tr, mock = self._transform_with_mock(f)
self.assertEqual(result.test_fn(lambda a: a, 1), 21)
self.assertListEqual(self.dynamic_calls, [((1,), None)]) self.assertEqual(tr(lambda a: a, 1), 21)
self.assertListEqual(mock.calls, [((1,), None)])
def test_function_with_args_only(self): def test_function_with_args_only(self):
def test_fn(f, a, b): def f(f, a, b):
return f(a, b) + 300 return f(a, b) + 300
with self.converted(test_fn, (functions, call_trees), {}) as result: tr, mock = self._transform_with_mock(f)
self.assertEqual(result.test_fn(lambda a, b: a + b, 1, 20), 321)
self.assertListEqual(self.dynamic_calls, [((1, 20), None)]) self.assertEqual(tr(lambda a, b: a + b, 1, 20), 321)
self.assertListEqual(mock.calls, [((1, 20), None)])
def test_function_with_kwarg(self): def test_function_with_kwarg(self):
def test_fn(f, a, b): def f(f, a, b):
return f(a, c=b) + 300 return f(a, c=b) + 300
with self.converted(test_fn, (functions, call_trees), {}) as result: tr, mock = self._transform_with_mock(f)
self.assertEqual(result.test_fn(lambda a, c: a + c, 1, 20), 321)
self.assertListEqual(self.dynamic_calls, [((1,), {'c': 20})]) self.assertEqual(tr(lambda a, c: a + c, 1, 20), 321)
self.assertListEqual(mock.calls, [((1,), {'c': 20})])
def test_function_with_kwargs_starargs(self): def test_function_with_kwargs_starargs(self):
def test_fn(f, a, *args, **kwargs): def f(f, a, *args, **kwargs):
return f(a, *args, **kwargs) + 5 return f(a, *args, **kwargs) + 5
with self.converted(test_fn, (functions, call_trees), {}) as result: tr, mock = self._transform_with_mock(f)
self.assertEqual(
result.test_fn(lambda *args, **kwargs: 7, 1, *[2, 3], **{ self.assertEqual(
'b': 4, tr(lambda *args, **kwargs: 7, 1, *[2, 3], **{
'c': 5 'b': 4,
}), 12) 'c': 5
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})]) }), 12)
self.assertListEqual(mock.calls, [((1, 2, 3), {'b': 4, 'c': 5})])
def test_function_with_starargs_only(self): def test_function_with_starargs_only(self):
def f(*args): def g(*args):
return sum(args) return sum(args)
def test_fn(): def f():
args = [1, 20, 300] args = [1, 20, 300]
return f(*args) + 4000 return g(*args) + 4000
with self.converted(test_fn, (functions, call_trees), tr, mock = self._transform_with_mock(f)
{'f': f}) as result:
self.assertEqual(result.test_fn(), 4321)
self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
# TODO(b/142586827): Enable this test. self.assertEqual(tr(), 4321)
# def test_function_with_starargs_mixed(self): self.assertListEqual(mock.calls, [((1, 20, 300), None)])
#
# def f(a, b, c, d): def test_function_with_starargs_mixed(self):
# return a * 1000 + b * 100 + c * 10 + d
# def g(a, b, c, d):
# def test_fn(): return a * 1000 + b * 100 + c * 10 + d
# args1 = (1,)
# args2 = [3] def f():
# return f(*args1, 2, *args2, 4) args1 = (1,)
# args2 = [3]
# with self.converted(test_fn, (functions, call_trees), return g(*args1, 2, *args2, 4)
# {'f': f}) as result:
# self.assertEqual(result.test_fn(), 1234) tr, mock = self._transform_with_mock(f)
# self.assertListEqual(self.dynamic_calls, [((1, 2, 3, 4), None)])
self.assertEqual(tr(), 1234)
self.assertListEqual(mock.calls, [((1, 2, 3, 4), None)])
def test_function_with_kwargs_keywords(self): def test_function_with_kwargs_keywords(self):
def test_fn(f, a, b, **kwargs): def f(f, a, b, **kwargs):
return f(a, b=b, **kwargs) + 5 return f(a, b=b, **kwargs) + 5
with self.converted(test_fn, (functions, call_trees), {}) as result: tr, mock = self._transform_with_mock(f)
self.assertEqual(
result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
# TODO(b/142586827): Enable this test. self.assertEqual(
# def test_function_with_multiple_kwargs(self): tr(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
# self.assertListEqual(mock.calls, [((1,), {'b': 2, 'c': 3})])
# def test_fn(f, a, b, c, kwargs1, kwargs2):
# return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5 def test_function_with_multiple_kwargs(self):
#
# with self.converted(test_fn, (functions, call_trees), {}) as result: def f(f, a, b, c, kwargs1, kwargs2):
# self.assertEqual( return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5
# result.test_fn(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4},
# {'e': 5}), 12) tr, mock = self._transform_with_mock(f)
# self.assertListEqual(self.dynamic_calls, [((1,), {
# 'b': 2, self.assertEqual(
# 'c': 3, tr(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4}, {'e': 5}), 12)
# 'd': 4, self.assertListEqual(mock.calls, [((1,), {
# 'e': 5 'b': 2,
# })]) 'c': 3,
'd': 4,
'e': 5
})])
def test_function_with_call_in_lambda_argument(self): def test_function_with_call_in_lambda_argument(self):
def f(l, a): def h(l, a):
return l(a) + 4000 return l(a) + 4000
def g(a, *args): def g(a, *args):
return a + sum(args) return a + sum(args)
def test_fn(f, g, a, *args): def f(h, g, a, *args):
return f(lambda x: g(x, *args), a) return h(lambda x: g(x, *args), a)
with self.converted(test_fn, (functions, call_trees), {}) as result: tr, _ = self._transform_with_mock(f)
self.assertEqual(result.test_fn(f, g, 1, *(20, 300)), 4321)
self.assertEqual(tr(h, g, 1, *(20, 300)), 4321)
def test_debugger_set_trace(self): def test_debugger_set_trace(self):
@ -198,13 +222,13 @@ class CallTreesTest(converter_testing.TestCase):
pdb = imp.new_module('fake_pdb') pdb = imp.new_module('fake_pdb')
pdb.set_trace = lambda: tracking_list.append(1) pdb.set_trace = lambda: tracking_list.append(1)
def test_fn(): def f():
return pdb.set_trace() return pdb.set_trace()
with self.converted(test_fn, (functions, call_trees), tr, _ = self._transform_with_mock(f)
{'pdb': pdb}) as result:
result.test_fn() tr()
self.assertListEqual(tracking_list, [1]) self.assertListEqual(tracking_list, [1])
def test_class_method(self): def test_class_method(self):
@ -217,10 +241,10 @@ class CallTreesTest(converter_testing.TestCase):
return self.other_method(a) + 300 return self.other_method(a) + 300
tc = TestClass() tc = TestClass()
with self.converted(TestClass.test_method, (functions, call_trees), tr, mock = self._transform_with_mock(TestClass.test_method)
{}) as result:
self.assertEqual(321, result.test_method(tc, 1)) self.assertEqual(321, tr(tc, 1))
self.assertListEqual(self.dynamic_calls, [((1,), None)]) self.assertListEqual(mock.calls, [((1,), None)])
def test_object_method(self): def test_object_method(self):
@ -233,10 +257,10 @@ class CallTreesTest(converter_testing.TestCase):
return self.other_method(a) + 300 return self.other_method(a) + 300
tc = TestClass() tc = TestClass()
with self.converted(tc.test_method, (functions, call_trees), tr, mock = self._transform_with_mock(tc.test_method)
{}) as result:
self.assertEqual(321, result.test_method(tc, 1)) self.assertEqual(321, tr(tc, 1))
self.assertListEqual(self.dynamic_calls, [((1,), None)]) self.assertListEqual(mock.calls, [((1,), None)])
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -25,28 +25,27 @@ from tensorflow.python.platform import test
class ConditionalExpressionsTest(converter_testing.TestCase): class ConditionalExpressionsTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs): def assertTransformedEquivalent(self, f, *inputs):
ns = {} tr = self.transform(f, conditional_expressions)
with self.converted(test_fn, conditional_expressions, ns) as result: self.assertEqual(f(*inputs), tr(*inputs))
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_basic(self): def test_basic(self):
def test_fn(x): def f(x):
return 1 if x else 0 return 1 if x else 0
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 3) self.assertTransformedEquivalent(f, 3)
def test_nested_orelse(self): def test_nested_orelse(self):
def test_fn(x): def f(x):
y = x * x if x > 0 else x if x else 1 y = x * x if x > 0 else x if x else 1
return y return y
self.assertTransformedEquivalent(test_fn, -2) self.assertTransformedEquivalent(f, -2)
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -20,21 +20,19 @@ from __future__ import print_function
from tensorflow.python.autograph.converters import continue_statements from tensorflow.python.autograph.converters import continue_statements
from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
class ContinueCanonicalizationTest(converter_testing.TestCase): class ContinueCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs): def assertTransformedEquivalent(self, f, *inputs):
with self.converted(test_fn, continue_statements, {'ops': ops}, tr = self.transform(f, continue_statements)
(constant_op.constant,)) as result: self.assertEqual(f(*inputs), tr(*inputs))
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_basic(self): def test_basic(self):
def test_fn(x): def f(x):
v = [] v = []
while x > 0: while x > 0:
x -= 1 x -= 1
@ -43,14 +41,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v return v
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 1) self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(test_fn, 3) self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(test_fn, 4) self.assertTransformedEquivalent(f, 4)
def test_multiple_continues(self): def test_multiple_continues(self):
def test_fn(x): def f(x):
v = [] v = []
while x > 0: while x > 0:
x -= 1 x -= 1
@ -61,14 +59,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v return v
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 1) self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(test_fn, 3) self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(test_fn, 4) self.assertTransformedEquivalent(f, 4)
def test_multiple_continues_in_nested_scope(self): def test_multiple_continues_in_nested_scope(self):
def test_fn(a): def f(a):
v = [] v = []
for x in a: for x in a:
x -= 1 x -= 1
@ -81,14 +79,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v return v
self.assertTransformedEquivalent(test_fn, []) self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(test_fn, [1]) self.assertTransformedEquivalent(f, [1])
self.assertTransformedEquivalent(test_fn, [2]) self.assertTransformedEquivalent(f, [2])
self.assertTransformedEquivalent(test_fn, [1, 2, 3]) self.assertTransformedEquivalent(f, [1, 2, 3])
def test_for_loop(self): def test_for_loop(self):
def test_fn(a): def f(a):
v = [] v = []
for x in a: for x in a:
x -= 1 x -= 1
@ -97,14 +95,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v return v
self.assertTransformedEquivalent(test_fn, []) self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(test_fn, [1]) self.assertTransformedEquivalent(f, [1])
self.assertTransformedEquivalent(test_fn, [2]) self.assertTransformedEquivalent(f, [2])
self.assertTransformedEquivalent(test_fn, [1, 2, 3]) self.assertTransformedEquivalent(f, [1, 2, 3])
def test_nested_with(self): def test_nested_with(self):
def test_fn(x): def f(x):
v = [] v = []
while x > 0: while x > 0:
x -= 1 x -= 1
@ -114,14 +112,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v return v
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 1) self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(test_fn, 3) self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(test_fn, 4) self.assertTransformedEquivalent(f, 4)
def test_nested_multiple_withs(self): def test_nested_multiple_withs(self):
def test_fn(x): def f(x):
v = [] v = []
while x > 0: while x > 0:
x -= 1 x -= 1
@ -133,14 +131,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v return v
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 1) self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(test_fn, 3) self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(test_fn, 4) self.assertTransformedEquivalent(f, 4)
def test_nested_multiple_withs_and_statements(self): def test_nested_multiple_withs_and_statements(self):
def test_fn(x): def f(x):
v = [] v = []
while x > 0: while x > 0:
x -= 1 x -= 1
@ -154,14 +152,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v return v
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 1) self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(test_fn, 3) self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(test_fn, 4) self.assertTransformedEquivalent(f, 4)
def test_nested_multiple_withs_and_nested_withs(self): def test_nested_multiple_withs_and_nested_withs(self):
def test_fn(x): def f(x):
v = [] v = []
while x > 0: while x > 0:
x -= 1 x -= 1
@ -176,14 +174,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v return v
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 1) self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(test_fn, 3) self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(test_fn, 4) self.assertTransformedEquivalent(f, 4)
def test_nested(self): def test_nested(self):
def test_fn(x): def f(x):
v = [] v = []
u = [] u = []
w = [] w = []
@ -198,14 +196,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v, u, w return v, u, w
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 1) self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(test_fn, 3) self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(test_fn, 4) self.assertTransformedEquivalent(f, 4)
def test_multiple_guarded_continues_with_side_effects(self): def test_multiple_guarded_continues_with_side_effects(self):
def test_fn(x): def f(x):
def track(u, x): def track(u, x):
u.append(x) u.append(x)
return x return x
@ -221,8 +219,8 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
v.append(x) v.append(x)
return u, v return u, v
self.assertTransformedEquivalent(test_fn, 3) self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -23,6 +23,8 @@ import collections
import numpy as np import numpy as np
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.converters import continue_statements
from tensorflow.python.autograph.converters import control_flow from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
@ -34,7 +36,8 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.util import nest from tensorflow.python.util import nest
# TODO(mdan): These tests are not isolated - they also test the operators.
for_unaffected_global = None
class ControlFlowTestBase(converter_testing.TestCase): class ControlFlowTestBase(converter_testing.TestCase):
@ -45,22 +48,19 @@ class ControlFlowTestBase(converter_testing.TestCase):
actual) actual)
self.assertAllEqual(values, expected) self.assertAllEqual(values, expected)
def assertTransformedResult(self, test_fn, inputs, expected, symbols=None): def assertTransformedResult(self, f, inputs, expected):
if not isinstance(inputs, tuple): if not isinstance(inputs, tuple):
inputs = (inputs,) inputs = (inputs,)
if not symbols: tr = self.transform(f, control_flow)
symbols = {} returns = tr(*inputs)
with self.converted(test_fn, control_flow, symbols, self.assertValuesEqual(returns, expected)
(constant_op.constant,)) as result:
returns = result.test_fn(*inputs)
self.assertValuesEqual(returns, expected)
class NestedControlFlowTest(ControlFlowTestBase): class NestedControlFlowTest(ControlFlowTestBase):
def test_basic(self): def test_basic(self):
def test_fn(n): def f(n):
i = 0 i = 0
j = 0 j = 0
s = 0 s = 0
@ -73,7 +73,7 @@ class NestedControlFlowTest(ControlFlowTestBase):
j = 0 j = 0
return s, i, j, n return s, i, j, n
self.assertTransformedResult(test_fn, constant_op.constant(5), self.assertTransformedResult(f, constant_op.constant(5),
(25, 5, 0, 5)) (25, 5, 0, 5))
def test_composite_state_complex(self): def test_composite_state_complex(self):
@ -88,7 +88,7 @@ class NestedControlFlowTest(ControlFlowTestBase):
def __init__(self, y): def __init__(self, y):
self.y = y self.y = y
def test_fn(n): def f(n):
tc = TestClassX(TestClassY({'z': TestClassX(n)})) tc = TestClassX(TestClassY({'z': TestClassX(n)}))
if n > 0: if n > 0:
while n > 0: while n > 0:
@ -97,19 +97,17 @@ class NestedControlFlowTest(ControlFlowTestBase):
n -= 1 n -= 1
return n, tc return n, tc
with self.converted(test_fn, control_flow, { tr = self.transform(f, control_flow)
'TestClassX': TestClassX,
'TestClassY': TestClassY, n, tc = tr(constant_op.constant(5))
}) as result: self.assertValuesEqual((n, tc.x.y['z'].x), (0, 6))
n, tc = result.test_fn(constant_op.constant(5))
self.assertValuesEqual((n, tc.x.y['z'].x), (0, 6))
class WhileStatementTest(ControlFlowTestBase): class WhileStatementTest(ControlFlowTestBase):
def test_basic(self): def test_basic(self):
def test_fn(n): def f(n):
i = 0 i = 0
s = 0 s = 0
while i < n: while i < n:
@ -117,16 +115,16 @@ class WhileStatementTest(ControlFlowTestBase):
i += 1 i += 1
return s, i, n return s, i, n
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5)) self.assertTransformedResult(f, constant_op.constant(5), (10, 5, 5))
def test_single_output(self): def test_single_output(self):
def test_fn(n): def f(n):
while n > 0: while n > 0:
n -= 1 n -= 1
return n return n
self.assertTransformedResult(test_fn, constant_op.constant(5), 0) self.assertTransformedResult(f, constant_op.constant(5), 0)
def test_composite_state_attr(self): def test_composite_state_attr(self):
@ -135,19 +133,18 @@ class WhileStatementTest(ControlFlowTestBase):
def __init__(self): def __init__(self):
self.x = constant_op.constant(3) self.x = constant_op.constant(3)
def test_fn(n): def f(n):
tc = TestClass() tc = TestClass()
while n > 0: while n > 0:
tc.x += 1 tc.x += 1
n -= 1 n -= 1
return n return n
self.assertTransformedResult( self.assertTransformedResult(f, constant_op.constant(5), 0)
test_fn, constant_op.constant(5), 0, symbols={'TestClass': TestClass})
def test_composite_state_slice(self): def test_composite_state_slice(self):
def test_fn(n): def f(n):
d = {'a': n} d = {'a': n}
k = 'a' k = 'a'
while n > 0: while n > 0:
@ -155,25 +152,25 @@ class WhileStatementTest(ControlFlowTestBase):
n -= 1 n -= 1
return d[k], n return d[k], n
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 0)) self.assertTransformedResult(f, constant_op.constant(5), (10, 0))
def test_composite_state_literal_slice(self): def test_composite_state_literal_slice(self):
def test_fn(n): def f(n):
d = {'a': n} d = {'a': n}
while n > 0: while n > 0:
d['a'] += 1 d['a'] += 1
n -= 1 n -= 1
return d['a'], n return d['a'], n
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 0)) self.assertTransformedResult(f, constant_op.constant(5), (10, 0))
def test_composite_state_attr_initialized_in_loop(self): def test_composite_state_attr_initialized_in_loop(self):
class TestClass(object): class TestClass(object):
pass pass
def test_fn(n, x): def f(n, x):
tc = TestClass() tc = TestClass()
while n < 5: while n < 5:
if n == 0: if n == 0:
@ -183,19 +180,15 @@ class WhileStatementTest(ControlFlowTestBase):
n += 1 n += 1
return tc.subattr return tc.subattr
self.assertTransformedResult( self.assertTransformedResult(f, (0, constant_op.constant(10)), 14)
test_fn, (0, constant_op.constant(10)), tr = self.transform(f, control_flow)
14, with self.assertRaisesRegex(
symbols={'TestClass': TestClass}) ValueError, "'tc.subattr' must be defined before the loop"):
with self.converted( tr(constant_op.constant(0), 0)
test_fn, control_flow, {'TestClass': TestClass}) as result:
with self.assertRaisesRegex(
ValueError, "'tc.subattr' must be defined before the loop"):
result.test_fn(constant_op.constant(0), 0)
def test_composite_state_slice_initialized_in_loop(self): def test_composite_state_slice_initialized_in_loop(self):
def test_fn(n, x): def f(n, x):
d = {} d = {}
k = 'subkey' k = 'subkey'
while n < 5: while n < 5:
@ -206,16 +199,16 @@ class WhileStatementTest(ControlFlowTestBase):
n += 1 n += 1
return d return d
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)), self.assertTransformedResult(f, (0, constant_op.constant(10)),
{'subkey': 14}) {'subkey': 14})
with self.converted(test_fn, control_flow, {}) as result: tr = self.transform(f, control_flow)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, r"'d\[k\]' must be defined before the loop"): ValueError, r"'d\[k\]' must be defined before the loop"):
result.test_fn(constant_op.constant(0), 0) tr(constant_op.constant(0), 0)
def test_composite_state_literal_slice_initialized_in_loop(self): def test_composite_state_literal_slice_initialized_in_loop(self):
def test_fn(n, x): def f(n, x):
d = {} d = {}
while n < 5: while n < 5:
if n == 0: if n == 0:
@ -225,16 +218,16 @@ class WhileStatementTest(ControlFlowTestBase):
n += 1 n += 1
return d return d
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)), self.assertTransformedResult(f, (0, constant_op.constant(10)),
{'subkey': 14}) {'subkey': 14})
with self.converted(test_fn, control_flow, {}) as result: tr = self.transform(f, control_flow)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, r"'d\['subkey'\]' must be defined before the loop"): ValueError, r"'d\['subkey'\]' must be defined before the loop"):
result.test_fn(constant_op.constant(0), 0) tr(constant_op.constant(0), 0)
def test_composite_state_slice_aliased_to_local(self): def test_composite_state_slice_aliased_to_local(self):
def test_fn(n, x): def f(n, x):
d = {} d = {}
while n < 5: while n < 5:
k = 'subkey' k = 'subkey'
@ -242,15 +235,15 @@ class WhileStatementTest(ControlFlowTestBase):
n += 1 n += 1
return d return d
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)), self.assertTransformedResult(f, (0, constant_op.constant(10)),
{'subkey': 11}) {'subkey': 11})
with self.converted(test_fn, control_flow, {}) as result: tr = self.transform(f, control_flow)
# TODO(b/136999953): Better error message. # TODO(b/136999953): Better error message.
# Note that this error happens at execution time. # Note that this error happens at execution time.
with self.assertRaises(errors.InaccessibleTensorError): with self.assertRaises(errors.InaccessibleTensorError):
graph_fn = def_function.function(result.test_fn, autograph=False) graph_fn = def_function.function(tr, autograph=False)
self.evaluate( self.evaluate(
graph_fn(constant_op.constant(0), constant_op.constant(5))) graph_fn(constant_op.constant(0), constant_op.constant(5)))
def test_local_composite_attr(self): def test_local_composite_attr(self):
@ -259,19 +252,18 @@ class WhileStatementTest(ControlFlowTestBase):
def __init__(self): def __init__(self):
self.x = constant_op.constant(3) self.x = constant_op.constant(3)
def test_fn(n): def f(n):
while n > 0: while n > 0:
tc = TestClass() tc = TestClass()
tc.x = tc.x tc.x = tc.x
n -= 1 n -= 1
return n return n
self.assertTransformedResult( self.assertTransformedResult(f, constant_op.constant(5), 0)
test_fn, constant_op.constant(5), 0, symbols={'TestClass': TestClass})
def test_local_composite_slice(self): def test_local_composite_slice(self):
def test_fn(n): def f(n):
while n > 0: while n > 0:
d = {'x': n} d = {'x': n}
k = 'x' k = 'x'
@ -279,26 +271,26 @@ class WhileStatementTest(ControlFlowTestBase):
n -= 1 n -= 1
return n return n
self.assertTransformedResult(test_fn, constant_op.constant(5), 0, {}) self.assertTransformedResult(f, constant_op.constant(5), 0)
def test_local_composite_literal_slice(self): def test_local_composite_literal_slice(self):
def test_fn(n): def f(n):
while n > 0: while n > 0:
d = {'x': n} d = {'x': n}
d['x'] = d['x'] d['x'] = d['x']
n -= 1 n -= 1
return n return n
self.assertTransformedResult(test_fn, constant_op.constant(5), 0, {}) self.assertTransformedResult(f, constant_op.constant(5), 0)
def test_non_tensor_state(self): def test_non_tensor_state(self):
# This class is ok to be in a tf.while_loop's state. # This class is ok to be in a tf.while's state.
class TestClass(collections.namedtuple('TestClass', ('x'))): class TestClass(collections.namedtuple('TestClass', ('x'))):
pass pass
def test_fn(n): def f(n):
tc = TestClass([constant_op.constant(0)]) tc = TestClass([constant_op.constant(0)])
while n > 0: while n > 0:
tc = TestClass([constant_op.constant(3)]) tc = TestClass([constant_op.constant(3)])
@ -306,9 +298,7 @@ class WhileStatementTest(ControlFlowTestBase):
n -= 1 n -= 1
return tc.x[0] return tc.x[0]
ns = {'TestClass': TestClass, 'constant_op': constant_op} self.assertTransformedResult(f, constant_op.constant(5), 4)
self.assertTransformedResult(
test_fn, constant_op.constant(5), 4, symbols=ns)
def test_non_tensor_state_illegal_type(self): def test_non_tensor_state_illegal_type(self):
@ -317,20 +307,20 @@ class WhileStatementTest(ControlFlowTestBase):
def __init__(self): def __init__(self):
self.x = [constant_op.constant(3)] self.x = [constant_op.constant(3)]
def test_fn(n): def f(n):
while n > 0: while n > 0:
tc = TestClass() tc = TestClass()
tc.x[0] = tc.x[0] + 1 tc.x[0] = tc.x[0] + 1
n -= 1 n -= 1
return tc.x[0] return tc.x[0]
with self.converted( tr = self.transform(f, control_flow)
test_fn, control_flow, {'TestClass': TestClass}) as result:
# The tested function would require `tc` to become part of the while loop # The tested function would require `tc` to become part of the while loop
# state, but TensorFlow doesn't support classes at the moment. # state, but TensorFlow doesn't support classes at the moment.
with self.assertRaisesRegexp( with self.assertRaisesRegex(
ValueError, 'tc.*must be defined before the loop'): ValueError, 'tc.*must be defined before the loop'):
result.test_fn(constant_op.constant(5)) tr(constant_op.constant(5))
def test_dispatches_by_cond_only(self): def test_dispatches_by_cond_only(self):
@ -343,27 +333,27 @@ class WhileStatementTest(ControlFlowTestBase):
def __add__(self, other): def __add__(self, other):
return TensorIncompatibleNumeric(self.val + other) return TensorIncompatibleNumeric(self.val + other)
def test_fn(n, s): def f(n, s):
while n > 0: while n > 0:
n -= 1 n -= 1
s += n s += n
return s return s
self.assertTransformedResult(test_fn, (constant_op.constant(5), 0), 10) self.assertTransformedResult(f, (constant_op.constant(5), 0), 10)
with self.converted(test_fn, control_flow, {}) as result: tr = self.transform(f, control_flow)
# n alone controls the staging. When the loop is not staged, Python # n alone controls the staging. When the loop is not staged, Python
# knows how to add the two objects. But when staged, tf.while_loop will # knows how to add the two objects. But when staged, tf.while will
# not know how to deal with the TensorIncompatibleNumeric object. # not know how to deal with the TensorIncompatibleNumeric object.
self.assertEqual(result.test_fn(5, TensorIncompatibleNumeric(0)).val, 10) self.assertEqual(tr(5, TensorIncompatibleNumeric(0)).val, 10)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
result.test_fn(constant_op.constant(5), TensorIncompatibleNumeric(0)) tr(constant_op.constant(5), TensorIncompatibleNumeric(0))
class IfStatementTest(ControlFlowTestBase): class IfStatementTest(ControlFlowTestBase):
def test_basic(self): def test_basic(self):
def test_fn(n): def f(n):
a = 0 a = 0
b = 0 b = 0
if n > 0: if n > 0:
@ -372,20 +362,20 @@ class IfStatementTest(ControlFlowTestBase):
b = 2 * n b = 2 * n
return a, b return a, b
self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0)) self.assertTransformedResult(f, constant_op.constant(1), (-1, 0))
self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2)) self.assertTransformedResult(f, constant_op.constant(-1), (0, -2))
def test_sparse_tensor(self): def test_sparse_tensor(self):
def test_fn(cond, a): def f(cond, a):
if cond: if cond:
a = -a a = -a
return a return a
st = sparse_tensor.SparseTensor( st = sparse_tensor.SparseTensor(
indices=((0,),), values=(0,), dense_shape=(1,)) indices=((0,),), values=(0,), dense_shape=(1,))
self.assertTransformedResult(test_fn, (st, constant_op.constant(1)), -1) self.assertTransformedResult(f, (st, constant_op.constant(1)), -1)
self.assertTransformedResult(test_fn, (None, constant_op.constant(1)), 1) self.assertTransformedResult(f, (None, constant_op.constant(1)), 1)
def test_complex_outputs(self): def test_complex_outputs(self):
@ -395,7 +385,7 @@ class IfStatementTest(ControlFlowTestBase):
self.a = a self.a = a
self.b = b self.b = b
def test_fn(n, obj): def f(n, obj):
obj.a = 0 obj.a = 0
obj.b = 0 obj.b = 0
if n > 0: if n > 0:
@ -404,94 +394,94 @@ class IfStatementTest(ControlFlowTestBase):
obj.b = 2 * n obj.b = 2 * n
return obj return obj
with self.converted(test_fn, control_flow, {}) as result: tr = self.transform(f, control_flow)
res_obj = result.test_fn(constant_op.constant(1), TestClass(0, 0))
self.assertValuesEqual((res_obj.a, res_obj.b), (-1, 0)) res_obj = tr(constant_op.constant(1), TestClass(0, 0))
res_obj = result.test_fn(constant_op.constant(-1), TestClass(0, 0)) self.assertValuesEqual((res_obj.a, res_obj.b), (-1, 0))
self.assertValuesEqual((res_obj.a, res_obj.b), (0, -2)) res_obj = tr(constant_op.constant(-1), TestClass(0, 0))
self.assertValuesEqual((res_obj.a, res_obj.b), (0, -2))
def test_single_output(self): def test_single_output(self):
def test_fn(n): def f(n):
if n > 0: if n > 0:
n = -n n = -n
return n return n
self.assertTransformedResult(test_fn, constant_op.constant(1), -1) self.assertTransformedResult(f, constant_op.constant(1), -1)
def test_unbalanced(self): def test_unbalanced(self):
def test_fn(n): def f(n):
if n > 0: if n > 0:
n = 3 n = 3
return n return n
self.assertTransformedResult(test_fn, constant_op.constant(2), 3) self.assertTransformedResult(f, constant_op.constant(2), 3)
self.assertTransformedResult(test_fn, constant_op.constant(-3), -3) self.assertTransformedResult(f, constant_op.constant(-3), -3)
def test_unbalanced_raising(self): def test_unbalanced_raising(self):
def test_fn(n): def f(n):
if n > 0: if n > 0:
n = n + 1 n = n + 1
raise ValueError() raise ValueError()
return n return n
self.assertTransformedResult(test_fn, -3, -3) self.assertTransformedResult(f, -3, -3)
with self.converted(test_fn, control_flow, {}) as result: tr = self.transform(f, control_flow)
with self.assertRaises(ValueError):
result.test_fn(1) with self.assertRaises(ValueError):
tr(1)
def test_local_var(self): def test_local_var(self):
def test_fn(n): def f(n):
if n > 0: if n > 0:
b = 4 b = 4
n = b + 1 n = b + 1
return n return n
self.assertTransformedResult(test_fn, constant_op.constant(1), 5) self.assertTransformedResult(f, constant_op.constant(1), 5)
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) self.assertTransformedResult(f, constant_op.constant(-1), -1)
def test_local_remains_local(self): def test_local_remains_local(self):
def test_fn(n): def f(n):
if n > 0: if n > 0:
b = 4 b = 4
n = b + 1 n = b + 1
return n return n
self.assertTransformedResult(test_fn, constant_op.constant(1), 5) self.assertTransformedResult(f, constant_op.constant(1), 5)
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) self.assertTransformedResult(f, constant_op.constant(-1), -1)
def test_no_outputs(self): def test_no_outputs(self):
def test_fn(n): def f(n):
if n > 0: if n > 0:
b = 4 # pylint:disable=unused-variable b = 4 # pylint:disable=unused-variable
return n return n
# Without side effect guards, the if statement will stage a cond, self.assertTransformedResult(f, constant_op.constant(1), 1)
# but that will be pruned at execution. self.assertTransformedResult(f, constant_op.constant(-1), -1)
self.assertTransformedResult(test_fn, constant_op.constant(1), 1)
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
def test_created_outputs(self): def test_created_outputs(self):
def test_fn(i): def f(i):
if i == 0: if i == 0:
result = i - 1 result = i - 1
else: else:
result = i + 1 result = i + 1
return result return result
self.assertTransformedResult(test_fn, 0, -1) self.assertTransformedResult(f, 0, -1)
self.assertTransformedResult(test_fn, 1, 2) self.assertTransformedResult(f, 1, 2)
def test_created_loop_local_outputs(self): def test_created_loop_local_outputs(self):
def test_fn(n, x): def f(n, x):
for i in n: for i in n:
if i == 0: if i == 0:
result = i - 1 result = i - 1
@ -501,11 +491,11 @@ class IfStatementTest(ControlFlowTestBase):
x += 1 x += 1
return x return x
self.assertTransformedResult(test_fn, (range(5), 10), 14) self.assertTransformedResult(f, (range(5), 10), 14)
def test_created_loop_variable(self): def test_created_loop_variable(self):
def test_fn(n, x): def f(n, x):
for i in n: for i in n:
if i == 0: if i == 0:
result = i - 1 result = i - 1
@ -514,22 +504,26 @@ class IfStatementTest(ControlFlowTestBase):
x += 1 x += 1
return x return x
self.assertTransformedResult(test_fn, (range(5), 10), 14) self.assertTransformedResult(f, (range(5), 10), 14)
def test_unaffected_global(self): def test_unaffected_global(self):
def test_fn(i): global for_unaffected_global
global g # pylint:disable=global-variable-undefined for_unaffected_global = 3
if i == 0:
g = i - 1
return g
self.assertTransformedResult(test_fn, 1, 3, symbols={'g': 3}) def f(i):
self.assertTransformedResult(test_fn, 0, -1, symbols={'g': 3}) global for_unaffected_global
if i == 0:
for_unaffected_global = i - 1
return for_unaffected_global
self.assertTransformedResult(f, 1, 3)
self.assertTransformedResult(f, 0, -1)
self.assertEqual(for_unaffected_global, -1)
def test_unaffected_nonlocal(self): def test_unaffected_nonlocal(self):
def test_fn(i): def f(i):
def inner_fn(): def inner_fn():
nonlocal n nonlocal n
if i == 0: if i == 0:
@ -539,12 +533,12 @@ class IfStatementTest(ControlFlowTestBase):
inner_fn() inner_fn()
return n return n
self.assertTransformedResult(test_fn, 1, 3) self.assertTransformedResult(f, 1, 3)
self.assertTransformedResult(test_fn, 0, -1) self.assertTransformedResult(f, 0, -1)
def test_output_defined_in_prior_except(self): def test_output_defined_in_prior_except(self):
def test_fn(i): def f(i):
try: try:
raise ValueError() raise ValueError()
except ValueError: except ValueError:
@ -553,8 +547,8 @@ class IfStatementTest(ControlFlowTestBase):
x = i - 1 x = i - 1
return x return x
self.assertTransformedResult(test_fn, 1, 1) self.assertTransformedResult(f, 1, 1)
self.assertTransformedResult(test_fn, 0, -1) self.assertTransformedResult(f, 0, -1)
def test_unbalanced_multiple_composites(self): def test_unbalanced_multiple_composites(self):
@ -564,7 +558,7 @@ class IfStatementTest(ControlFlowTestBase):
self.b = 2 self.b = 2
self.c = 3 self.c = 3
def test_fn(x, condition): def f(x, condition):
z = 5 z = 5
if condition: if condition:
@ -574,9 +568,9 @@ class IfStatementTest(ControlFlowTestBase):
return x.b, x.c, z return x.b, x.c, z
self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(True)), self.assertTransformedResult(f, (Foo(), constant_op.constant(True)),
(7, 11, 13)) (7, 11, 13))
self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(False)), self.assertTransformedResult(f, (Foo(), constant_op.constant(False)),
(2, 3, 5)) (2, 3, 5))
def test_unbalanced_composite(self): def test_unbalanced_composite(self):
@ -586,7 +580,7 @@ class IfStatementTest(ControlFlowTestBase):
def __init__(self): def __init__(self):
self.b = 2 self.b = 2
def test_fn(x, condition): def f(x, condition):
z = 5 z = 5
if condition: if condition:
@ -595,9 +589,9 @@ class IfStatementTest(ControlFlowTestBase):
return x.b, z return x.b, z
self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(True)), self.assertTransformedResult(f, (Foo(), constant_op.constant(True)),
(7, 13)) (7, 13))
self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(False)), self.assertTransformedResult(f, (Foo(), constant_op.constant(False)),
(2, 5)) (2, 5))
@ -605,7 +599,7 @@ class ForStatementTest(ControlFlowTestBase):
def test_basic(self): def test_basic(self):
def test_fn(l): def f(l):
s1 = 0 s1 = 0
s2 = 0 s2 = 0
for e in l: for e in l:
@ -613,21 +607,21 @@ class ForStatementTest(ControlFlowTestBase):
s2 += e * e s2 += e * e
return s1, s2 return s1, s2
self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), (4, 10)) self.assertTransformedResult(f, constant_op.constant([1, 3]), (4, 10))
empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32) empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
self.assertTransformedResult(test_fn, empty_vector, (0, 0)) self.assertTransformedResult(f, empty_vector, (0, 0))
def test_single_output(self): def test_single_output(self):
def test_fn(l): def f(l):
s = 0 s = 0
for e in l: for e in l:
s += e s += e
return s return s
self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), 4) self.assertTransformedResult(f, constant_op.constant([1, 3]), 4)
empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32) empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
self.assertTransformedResult(test_fn, empty_vector, 0) self.assertTransformedResult(f, empty_vector, 0)
def test_iterated_expression(self): def test_iterated_expression(self):
@ -637,26 +631,23 @@ class ForStatementTest(ControlFlowTestBase):
eval_count[0] += 1 eval_count[0] += 1
return x return x
def test_fn(n): def f(n):
s = 0 s = 0
for e in count_evals(range(n)): for e in count_evals(range(n)):
s += e s += e
return s return s
ns = {'count_evals': count_evals} tr = self.transform(f, control_flow)
node, ctx = self.prepare(test_fn, ns)
node = control_flow.transform(node, ctx)
with self.compiled(node, ns) as result: self.assertEqual(tr(5), 10)
self.assertEqual(result.test_fn(5), 10) self.assertEqual(eval_count[0], 1)
self.assertEqual(eval_count[0], 1)
def test_composite_state_initialized_in_loop(self): def test_composite_state_initialized_in_loop(self):
class TestClass(object): class TestClass(object):
pass pass
def test_fn(n, x): def f(n, x):
tc = TestClass() tc = TestClass()
for i in n: for i in n:
if i == 0: if i == 0:
@ -665,37 +656,97 @@ class ForStatementTest(ControlFlowTestBase):
tc.x = tc.x + i tc.x = tc.x + i
return tc.x return tc.x
self.assertTransformedResult( self.assertTransformedResult(f, (range(5), constant_op.constant(10)), 20)
test_fn, (range(5), constant_op.constant(10)), tr = self.transform(f, control_flow)
20,
symbols={'TestClass': TestClass}) with self.assertRaisesRegex(
with self.converted( ValueError, "'tc.x' must be defined before the loop"):
test_fn, control_flow, {'TestClass': TestClass}) as result: tr(constant_op.constant(list(range(5))), 0)
with self.assertRaisesRegex(
ValueError, "'tc.x' must be defined before the loop"):
result.test_fn(constant_op.constant(list(range(5))), 0)
def test_tuple_unpacking(self): def test_tuple_unpacking(self):
def test_fn(x_list):
z = tf.constant(0) # pylint:disable=undefined-variable def f(x_list):
z = constant_op.constant(0) # pylint:disable=undefined-variable
for i, x in enumerate(x_list): for i, x in enumerate(x_list):
z = z + x + i z = z + x + i
return z return z
self.assertTransformedResult(test_fn, [3, 3], 7) self.assertTransformedResult(f, [3, 3], 7)
def test_with_comprehension_in_body(self): def test_with_comprehension_in_body(self):
def test_fn(l, n): def f(l, n):
s = constant_op.constant(list(range(n))) s = constant_op.constant(list(range(n)))
for _ in l: for _ in l:
s += constant_op.constant([a for a in range(n)]) s += constant_op.constant([a for a in range(n)])
return s return s
self.assertTransformedResult( self.assertTransformedResult(f, (constant_op.constant([1, 2, 3]), 5),
test_fn, (constant_op.constant([1, 2, 3]), 5), np.array(range(5)) * 4)
np.array(range(5)) * 4,
symbols={'constant_op': constant_op})
class AdvancedControlFlowTest(ControlFlowTestBase):
def assertTransformedEquivalent(self, f, *inputs):
tr = self.transform(
f, (break_statements, continue_statements, control_flow))
self.assertEqual(f(*inputs), tr(*inputs))
def test_while_with_else(self):
def f(x):
while x > 2:
x /= 2
else:
x += 1
return x
self.assertTransformedEquivalent(f, 4)
self.assertTransformedEquivalent(f, 2)
def test_while_with_else_and_break(self):
def f(cond1):
x = 8
while x > 2:
x /= 2
if cond1:
break
else:
x += 1
return x
self.assertTransformedEquivalent(f, True)
self.assertTransformedEquivalent(f, False)
def test_for_with_else(self):
def f(l):
res = 0
for x in l:
res += x
else:
res += 1
return res
self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(f, [1, 2])
def test_for_with_else_and_break(self):
def f(flag):
l = [1, 2, 3]
res = 0
for x in l:
res += x
if flag:
break
else:
res += 1
return res
self.assertTransformedEquivalent(f, True)
self.assertTransformedEquivalent(f, False)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -22,7 +22,6 @@ from tensorflow.python.autograph.converters import directives as directives_conv
from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.lang import directives from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -30,13 +29,12 @@ class DirectivesTest(converter_testing.TestCase):
def test_local_target(self): def test_local_target(self):
def test_fn(): def f():
l = [] l = []
string_var = 0 string_var = 0
directives.set_element_type(l, 'a', string_var) directives.set_element_type(l, 'a', string_var)
node, ctx = self.prepare(test_fn, {'directives': directives}) _, node, _ = self.transform(f, directives_converter, include_ast=True)
node = directives_converter.transform(node, ctx)
def_, = anno.getanno(node.body[0].targets[0], def_, = anno.getanno(node.body[0].targets[0],
anno.Static.DEFINITIONS) anno.Static.DEFINITIONS)
@ -46,11 +44,11 @@ class DirectivesTest(converter_testing.TestCase):
def test_argument_target(self): def test_argument_target(self):
def test_fn(a): def f(a):
directives.set_element_type(a, 1, shape=2) directives.set_element_type(a, 1, shape=2)
pass
node, ctx = self.prepare(test_fn, {'directives': directives}) _, node, _ = self.transform(f, directives_converter, include_ast=True)
node = directives_converter.transform(node, ctx)
def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
d = def_.directives[directives.set_element_type] d = def_.directives[directives.set_element_type]
@ -59,13 +57,13 @@ class DirectivesTest(converter_testing.TestCase):
def test_loop_target(self): def test_loop_target(self):
def test_fn(): def f():
a = True a = True
while True: while True:
directives.set_loop_options(parallel_iterations=10, back_prop=a) directives.set_loop_options(parallel_iterations=10, back_prop=a)
pass
node, ctx = self.prepare(test_fn, {'directives': directives}) _, node, _ = self.transform(f, directives_converter, include_ast=True)
node = directives_converter.transform(node, ctx)
d = anno.getanno(node.body[1], anno.Basic.DIRECTIVES) d = anno.getanno(node.body[1], anno.Basic.DIRECTIVES)
d = d[directives.set_loop_options] d = d[directives.set_loop_options]
@ -75,40 +73,23 @@ class DirectivesTest(converter_testing.TestCase):
def test_loop_target_no_loop(self): def test_loop_target_no_loop(self):
def test_fn(): def f():
directives.set_loop_options() directives.set_loop_options()
pass
node, ctx = self.prepare(test_fn, {'directives': directives})
with self.assertRaisesRegexp(ValueError, 'must be used inside a statement'): with self.assertRaisesRegexp(ValueError, 'must be used inside a statement'):
node = directives_converter.transform(node, ctx) self.transform(f, directives_converter, include_ast=True)
def test_loop_target_not_first(self): def test_loop_target_not_first(self):
def test_fn(): def f():
a = 1 a = 1
while True: while True:
a = 2 a = 2
directives.set_loop_options(parallel_iterations=10, back_prop=a) directives.set_loop_options(parallel_iterations=10, back_prop=a)
node, ctx = self.prepare(test_fn, {'directives': directives})
with self.assertRaisesRegexp(ValueError, 'must be the first statement'): with self.assertRaisesRegexp(ValueError, 'must be the first statement'):
node = directives_converter.transform(node, ctx) self.transform(f, directives_converter, include_ast=True)
def test_invalid_default(self):
def invalid_directive(valid_arg, invalid_default=object()):
del valid_arg
del invalid_default
return
def call_invalid_directive():
invalid_directive(1)
node, _ = parser.parse_entity(call_invalid_directive, ())
# Find the call to the invalid directive
node = node.body[0].value
with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):
directives_converter._map_args(node, invalid_directive)
def test_value_verification_does_not_trigger_properties(self): def test_value_verification_does_not_trigger_properties(self):
@ -122,11 +103,11 @@ class DirectivesTest(converter_testing.TestCase):
tc = TestClass() tc = TestClass()
def test_fn(): def f():
return tc.b + 1 return tc.b + 1
node, ctx = self.prepare(test_fn, {'tc': tc}) _, node, _ = self.transform(f, directives_converter, include_ast=True)
node = directives_converter.transform(node, ctx)
self.assertIsNotNone(node) self.assertIsNotNone(node)
def test_value_verification_does_not_trigger_getattr(self): def test_value_verification_does_not_trigger_getattr(self):
@ -143,11 +124,11 @@ class DirectivesTest(converter_testing.TestCase):
tc = TestClass() tc = TestClass()
def test_fn(): def f():
return tc.b + 1 return tc.b + 1
node, ctx = self.prepare(test_fn, {'tc': tc}) _, node, _ = self.transform(f, directives_converter, include_ast=True)
node = directives_converter.transform(node, ctx)
self.assertIsNotNone(node) self.assertIsNotNone(node)
self.assertFalse(tc.getattr_called) self.assertFalse(tc.getattr_called)

View File

@ -23,51 +23,49 @@ from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.impl import api
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test from tensorflow.python.platform import test
class FunctionTransformer(converter_testing.TestCase): class FunctionTransformer(converter_testing.TestCase):
@test_util.run_deprecated_v1
def test_basic(self): def test_basic(self):
def test_fn(l): def f(l):
"""Docstring.""" """Docstring."""
a = 1 a = 1
l += a l += a
return l return l
with self.converted(test_fn, functions, {}) as result: tr = self.transform(f, functions)
result_op = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', result_op.op.name) result_op = tr(constant_op.constant(1))
self.assertEqual('Docstring.', result.test_fn.__doc__) self.assertIn('f/', result_op.op.name)
self.assertEqual('Docstring.', tr.__doc__)
@test_util.run_deprecated_v1
def test_multiline_docstring(self): def test_multiline_docstring(self):
tf = None def f():
def test_fn():
"""First sentence. """First sentence.
Second sentence. Second sentence.
Returns:
Something.
""" """
return tf.constant(1) return constant_op.constant(1)
with self.converted(test_fn, functions, {}, tr = self.transform(f, functions)
(constant_op.constant,)) as result:
result_op = result.test_fn() result_op = tr()
self.assertIn('test_fn/', result_op.op.name) self.assertIn('f/', result_op.op.name)
self.assertIn('First sentence.', result.test_fn.__doc__) self.assertIn('First sentence.', tr.__doc__)
self.assertIn('Second sentence.', result.test_fn.__doc__) self.assertIn('Second sentence.', tr.__doc__)
@test_util.run_deprecated_v1
def test_nested_functions(self): def test_nested_functions(self):
def test_fn(l): def f(l):
def inner_fn(i): def inner_fn(i):
return i + 1 return i + 1
@ -75,41 +73,35 @@ class FunctionTransformer(converter_testing.TestCase):
l += 1 l += 1
return l, inner_fn(l) return l, inner_fn(l)
with self.converted(test_fn, (functions, return_statements), {}, tr = self.transform(f, (functions, return_statements))
(ops.name_scope,)) as result:
first, second = result.test_fn(constant_op.constant(1)) first, second = tr(constant_op.constant(1))
self.assertIn('test_fn/', first.op.name) self.assertIn('f/', first.op.name)
self.assertNotIn('inner_fn', first.op.name) self.assertNotIn('inner_fn', first.op.name)
self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name) self.assertIn('f/inner_fn/', second.op.inputs[0].name)
@test_util.run_deprecated_v1
def test_conversion_context_preserves_in_inner_functions(self): def test_conversion_context_preserves_in_inner_functions(self):
def inner_fn_callee(): def inner_fn_callee():
self.assertEqual( self.assertEqual(
ag_ctx.control_status_ctx().status, ag_ctx.Status.DISABLED) ag_ctx.control_status_ctx().status, ag_ctx.Status.DISABLED)
def test_fn(): def f():
def inner_fn(): def inner_fn():
inner_fn_callee() inner_fn_callee()
with ag_ctx.ControlStatusCtx( with ag_ctx.ControlStatusCtx(
ag_ctx.Status.DISABLED, converter.ConversionOptions(recursive=True)): ag_ctx.Status.DISABLED, converter.ConversionOptions(recursive=True)):
inner_fn() inner_fn()
ns = { tr = self.transform(f, functions)
'inner_fn_callee': inner_fn_callee,
'ag_ctx': ag_ctx, tr()
'converter': converter
}
with self.converted(test_fn, functions, ns) as result:
result.test_fn()
@test_util.run_deprecated_v1
def test_method(self): def test_method(self):
class TestClass(object): class TestClass(object):
def test_fn(self, l): def f(self, l):
def inner_fn(i): def inner_fn(i):
return i + 1 return i + 1
@ -117,25 +109,22 @@ class FunctionTransformer(converter_testing.TestCase):
l += 1 l += 1
return l, inner_fn(l) return l, inner_fn(l)
ns = {'TestClass': TestClass} tr = self.transform(TestClass.f, (functions, return_statements))
node, ctx = self.prepare(TestClass, ns)
node = functions.transform(node, ctx)
node = return_statements.transform(node, ctx)
with self.compiled(node, {}, (ops.name_scope,)) as result: first, second = tr(TestClass(), constant_op.constant(1))
first, second = result.TestClass().test_fn(constant_op.constant(1)) self.assertIn('f/', first.op.name)
self.assertIn('test_fn/', first.op.name) self.assertNotIn('inner_fn', first.op.name)
self.assertNotIn('inner_fn', first.op.name) self.assertIn('f/inner_fn/', second.op.inputs[0].name)
self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name)
def test_lambda_in_return_value(self): def test_lambda_in_return_value(self):
def test_fn(): def f():
return lambda x: x + 1 return lambda x: x + 1
with self.converted(test_fn, functions, {}) as result: tr = self.transform(f, functions)
result_l = result.test_fn()
self.assertTrue(result_l.fake_autograph_artifact) result_l = tr()
self.assertTrue(api.is_autograph_artifact(result_l))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -25,36 +25,36 @@ from tensorflow.python.platform import test
class ListCompTest(converter_testing.TestCase): class ListCompTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs): def assertTransformedEquivalent(self, f, *inputs):
with self.converted(test_fn, list_comprehensions, {}) as result: tr = self.transform(f, list_comprehensions)
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) self.assertEqual(f(*inputs), tr(*inputs))
def test_basic(self): def test_basic(self):
def test_fn(l): def f(l):
s = [e * e for e in l] s = [e * e for e in l]
return s return s
self.assertTransformedEquivalent(test_fn, []) self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(test_fn, [1, 2, 3]) self.assertTransformedEquivalent(f, [1, 2, 3])
def test_multiple_generators(self): def test_multiple_generators(self):
def test_fn(l): def f(l):
s = [e * e for sublist in l for e in sublist] s = [e * e for sublist in l for e in sublist] # pylint:disable=g-complex-comprehension
return s return s
self.assertTransformedEquivalent(test_fn, []) self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(test_fn, [[1], [2], [3]]) self.assertTransformedEquivalent(f, [[1], [2], [3]])
def test_cond(self): def test_cond(self):
def test_fn(l): def f(l):
s = [e * e for e in l if e > 1] s = [e * e for e in l if e > 1]
return s return s
self.assertTransformedEquivalent(test_fn, []) self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(test_fn, [1, 2, 3]) self.assertTransformedEquivalent(f, [1, 2, 3])
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -18,12 +18,11 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.autograph.converters import directives as directives_converter
from tensorflow.python.autograph.converters import lists from tensorflow.python.autograph.converters import lists
from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.lang import directives from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.lang import special_functions from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -31,101 +30,81 @@ from tensorflow.python.ops import list_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
tf = None # Will be replaced by a mock.
class ListTest(converter_testing.TestCase): class ListTest(converter_testing.TestCase):
def test_empty_list(self): def test_empty_list(self):
def test_fn(): def f():
return [] return []
with self.converted(test_fn, lists, {}) as result: tr = self.transform(f, lists)
tl = result.test_fn()
# Empty tensor lists cannot be evaluated or stacked. tl = tr()
self.assertTrue(isinstance(tl, ops.Tensor)) # Empty tensor lists cannot be evaluated or stacked.
self.assertEqual(tl.dtype, dtypes.variant) self.assertIsInstance(tl, ops.Tensor)
self.assertEqual(tl.dtype, dtypes.variant)
def test_initialized_list(self): def test_initialized_list(self):
def test_fn(): def f():
return [1, 2, 3] return [1, 2, 3]
with self.converted(test_fn, lists, {}) as result: tr = self.transform(f, lists)
self.assertAllEqual(result.test_fn(), [1, 2, 3])
self.assertAllEqual(tr(), [1, 2, 3])
def test_list_append(self): def test_list_append(self):
def test_fn(): def f():
l = special_functions.tensor_list([1]) l = special_functions.tensor_list([1])
l.append(2) l.append(2)
l.append(3) l.append(3)
return l return l
ns = {'special_functions': special_functions} tr = self.transform(f, lists)
with self.converted(test_fn, lists, ns) as result:
with self.cached_session() as sess: tl = tr()
tl = result.test_fn() r = list_ops.tensor_list_stack(tl, dtypes.int32)
r = list_ops.tensor_list_stack(tl, dtypes.int32) self.assertAllEqual(self.evaluate(r), [1, 2, 3])
self.assertAllEqual(self.evaluate(r), [1, 2, 3])
def test_list_pop(self): def test_list_pop(self):
def test_fn(): def f():
l = special_functions.tensor_list([1, 2, 3]) l = special_functions.tensor_list([1, 2, 3])
directives.set_element_type(l, dtype=dtypes.int32, shape=())
s = l.pop() s = l.pop()
return s, l return s, l
ns = {'special_functions': special_functions} tr = self.transform(f, (directives_converter, lists))
node, ctx = self.prepare(test_fn, ns)
def_, = anno.getanno(node.body[0].targets[0],
anno.Static.ORIG_DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32'),
'shape': parser.parse_expression('()'),
}
node = lists.transform(node, ctx)
with self.compiled(node, ns, (dtypes.int32,)) as result: ts, tl = tr()
with self.cached_session() as sess: r = list_ops.tensor_list_stack(tl, dtypes.int32)
ts, tl = result.test_fn() self.assertAllEqual(self.evaluate(r), [1, 2])
r = list_ops.tensor_list_stack(tl, dtypes.int32) self.assertAllEqual(self.evaluate(ts), 3)
self.assertAllEqual(self.evaluate(r), [1, 2])
self.assertAllEqual(self.evaluate(ts), 3)
def test_double_list_pop(self): def test_double_list_pop(self):
def test_fn(l): def f(l):
s = l.pop().pop() s = l.pop().pop()
return s return s
with self.converted(test_fn, lists, {}) as result: tr = self.transform(f, lists)
test_input = [1, 2, [1, 2, 3]]
# TODO(mdan): Pass a list of lists of tensor when we fully support that. test_input = [1, 2, [1, 2, 3]]
# For now, we just pass a regular Python list of lists just to verify that # TODO(mdan): Pass a list of lists of tensor when we fully support that.
# the two pop calls are sequenced properly. # For now, we just pass a regular Python list of lists just to verify that
self.assertAllEqual(result.test_fn(test_input), 3) # the two pop calls are sequenced properly.
self.assertAllEqual(tr(test_input), 3)
def test_list_stack(self): def test_list_stack(self):
def test_fn(): def f():
l = [1, 2, 3] l = [1, 2, 3]
return tf.stack(l) return array_ops.stack(l)
node, ctx = self.prepare(test_fn, {}) tr = self.transform(f, lists)
def_, = anno.getanno(node.body[0].targets[0],
anno.Static.ORIG_DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32')
}
node = lists.transform(node, ctx)
with self.compiled(node, {}, (array_ops.stack, dtypes.int32)) as result: self.assertAllEqual(self.evaluate(tr()), [1, 2, 3])
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])
# TODO(mdan): Add a test with tf.stack with axis kwarg.
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -27,62 +27,59 @@ from tensorflow.python.platform import test
class LogicalExpressionTest(converter_testing.TestCase): class LogicalExpressionTest(converter_testing.TestCase):
@test_util.run_deprecated_v1
def test_equals(self): def test_equals(self):
def test_fn(a, b): def f(a, b):
return a == b return a == b
with self.converted(test_fn, logical_expressions, {}) as result: tr = self.transform(f, logical_expressions)
with self.cached_session() as sess:
self.assertTrue(sess.run(result.test_fn(constant_op.constant(1), 1))) self.assertTrue(self.evaluate(tr(constant_op.constant(1), 1)))
self.assertFalse(sess.run(result.test_fn(constant_op.constant(1), 2))) self.assertFalse(self.evaluate(tr(constant_op.constant(1), 2)))
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def test_bool_ops(self): def test_bool_ops(self):
def test_fn(a, b, c): def f(a, b, c):
return (a or b) and (a or b or c) and not c return (a or b) and (a or b or c) and not c
with self.converted(test_fn, logical_expressions, {}) as result: tr = self.transform(f, logical_expressions)
with self.cached_session() as sess:
self.assertTrue( self.assertTrue(self.evaluate(tr(constant_op.constant(True), False, False)))
sess.run(result.test_fn(constant_op.constant(True), False, False))) self.assertFalse(self.evaluate(tr(constant_op.constant(True), False, True)))
self.assertFalse(
sess.run(result.test_fn(constant_op.constant(True), False, True)))
@test_util.run_deprecated_v1
def test_comparison(self): def test_comparison(self):
def test_fn(a, b, c, d): def f(a, b, c, d):
return a < b == c > d return a < b == c > d
with self.converted(test_fn, logical_expressions, {}) as result: tr = self.transform(f, logical_expressions)
with self.cached_session() as sess:
# Note: having just the first constant a tensor tests that the # Note: having just the first constant a tensor tests that the
# operations execute in the correct order. If anything other than # operations execute in the correct order. If anything other than
# a < b executed first, the result would be a Python scalar and not a # a < b executed first, the result would be a Python scalar and not a
# Tensor. This is valid as long as the dispat is automatic based on # Tensor. This is valid as long as the dispat is automatic based on
# type. # type.
self.assertTrue( self.assertTrue(self.evaluate(tr(constant_op.constant(1), 2, 2, 1)))
sess.run(result.test_fn(constant_op.constant(1), 2, 2, 1))) self.assertFalse(self.evaluate(tr(constant_op.constant(1), 2, 2, 3)))
self.assertFalse(
sess.run(result.test_fn(constant_op.constant(1), 2, 2, 3)))
def test_default_ops(self): def test_default_ops(self):
def test_fn(a, b): def f(a, b):
return a in b return a in b
with self.converted(test_fn, logical_expressions, {}) as result: tr = self.transform(f, logical_expressions)
self.assertTrue(result.test_fn('a', ('a',)))
self.assertTrue(tr('a', ('a',)))
def test_unary_ops(self): def test_unary_ops(self):
def test_fn(a):
def f(a):
return ~a, -a, +a return ~a, -a, +a
with self.converted(test_fn, logical_expressions, {}) as result: tr = self.transform(f, logical_expressions)
self.assertEqual(result.test_fn(1), (-2, -1, 1))
self.assertEqual(tr(1), (-2, -1, 1))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,95 +0,0 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration Tests for loop."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.converters import continue_statements
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
class LoopIntegrationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
with self.converted(test_fn,
[break_statements, continue_statements, control_flow],
{}, (constant_op.constant,)) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_while_loop_with_else(self):
def test_fn(x):
while x > 2:
x /= 2
else:
x += 1
return x
self.assertTransformedEquivalent(test_fn, 4)
self.assertTransformedEquivalent(test_fn, 2)
def test_while_loop_with_else_and_break(self):
def test_fn(cond1):
x = 8
while x > 2:
x /= 2
if cond1:
break
else:
x += 1
return x
self.assertTransformedEquivalent(test_fn, True)
self.assertTransformedEquivalent(test_fn, False)
def test_for_loop_with_else(self):
def test_fn(l):
res = 0
for x in l:
res += x
else:
res += 1
return res
self.assertTransformedEquivalent(test_fn, [])
self.assertTransformedEquivalent(test_fn, [1, 2])
def test_for_loop_with_else_and_break(self):
def test_fn(flag):
l = [1, 2, 3]
res = 0
for x in l:
res += x
if flag:
break
else:
res += 1
return res
self.assertTransformedEquivalent(test_fn, True)
self.assertTransformedEquivalent(test_fn, False)
if __name__ == '__main__':
test.main()

View File

@ -27,81 +27,80 @@ from tensorflow.python.platform import test
class SingleReturnTest(converter_testing.TestCase): class SingleReturnTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs): def assertTransformedEquivalent(self, f, *inputs):
ns = {'ops': ops} tr = self.transform(f, (functions, return_statements))
with self.converted(test_fn, (functions, return_statements), ns) as result: self.assertEqual(f(*inputs), tr(*inputs))
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_straightline(self): def test_straightline(self):
def test_fn(x): def f(x):
return x * x return x * x
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
def test_superfluous_returns(self): def test_superfluous_returns(self):
def test_fn(): def f():
retval = 1 retval = 1
return retval return retval
retval = 2 # pylint:disable=unreachable retval = 2 # pylint:disable=unreachable
return retval return retval
self.assertTransformedEquivalent(test_fn) self.assertTransformedEquivalent(f)
def test_superfluous_returns_adjacent(self): def test_superfluous_returns_adjacent(self):
def test_fn(): def f():
return 1 return 1
return 2 # pylint:disable=unreachable return 2 # pylint:disable=unreachable
self.assertTransformedEquivalent(test_fn) self.assertTransformedEquivalent(f)
def test_conditional(self): def test_conditional(self):
def test_fn(x): def f(x):
if x > 0: if x > 0:
return x return x
else: else:
return x * x return x * x
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, -2) self.assertTransformedEquivalent(f, -2)
def test_conditional_missing_else(self): def test_conditional_missing_else(self):
def test_fn(x): def f(x):
if x > 0: if x > 0:
return x return x
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, -2) self.assertTransformedEquivalent(f, -2)
def test_conditional_missing_else_then_default(self): def test_conditional_missing_else_then_default(self):
def test_fn(x): def f(x):
if x > 0: if x > 0:
return x return x
return x * x return x * x
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, -2) self.assertTransformedEquivalent(f, -2)
def test_conditional_else_only_then_default(self): def test_conditional_else_only_then_default(self):
def test_fn(x): def f(x):
if x < 0: if x < 0:
x *= x x *= x
else: else:
return x return x
return x return x
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, -2) self.assertTransformedEquivalent(f, -2)
def test_conditional_nested(self): def test_conditional_nested(self):
def test_fn(x): def f(x):
if x > 0: if x > 0:
if x < 5: if x < 5:
return x return x
@ -110,53 +109,53 @@ class SingleReturnTest(converter_testing.TestCase):
else: else:
return x * x * x return x * x * x
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, -2) self.assertTransformedEquivalent(f, -2)
self.assertTransformedEquivalent(test_fn, 5) self.assertTransformedEquivalent(f, 5)
def test_context_manager(self): def test_context_manager(self):
def test_fn(x): def f(x):
with ops.name_scope(''): with ops.name_scope(''):
return x * x return x * x
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, -2) self.assertTransformedEquivalent(f, -2)
def test_context_manager_in_conditional(self): def test_context_manager_in_conditional(self):
def test_fn(x): def f(x):
if x > 0: if x > 0:
with ops.name_scope(''): with ops.name_scope(''):
return x * x return x * x
else: else:
return x return x
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, -2) self.assertTransformedEquivalent(f, -2)
def text_conditional_in_context_manager(self): def text_conditional_in_context_manager(self):
def test_fn(x): def f(x):
with ops.name_scope(''): with ops.name_scope(''):
if x > 0: if x > 0:
return x * x return x * x
else: else:
return x return x
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, -2) self.assertTransformedEquivalent(f, -2)
def test_no_return(self): def test_no_return(self):
def test_fn(x): def f(x):
x *= x x *= x
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
def test_nested_function(self): def test_nested_function(self):
def test_fn(x): def f(x):
def inner_fn(y): def inner_fn(y):
if y > 0: if y > 0:
@ -166,33 +165,33 @@ class SingleReturnTest(converter_testing.TestCase):
return inner_fn(x) return inner_fn(x)
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, -2) self.assertTransformedEquivalent(f, -2)
def test_nested_function_in_control_flow(self): def test_nested_function_in_control_flow(self):
def test_fn(x): def f(x):
if x: if x:
def inner_fn(y): def inner_fn(y):
return y return y
inner_fn(x) inner_fn(x)
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, -2) self.assertTransformedEquivalent(f, -2)
def test_for_loop(self): def test_for_loop(self):
def test_fn(n): def f(n):
for _ in range(n): for _ in range(n):
return 1 return 1
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
def test_while_loop(self): def test_while_loop(self):
def test_fn(n): def f(n):
i = 0 i = 0
s = 0 s = 0
while i < n: while i < n:
@ -202,23 +201,23 @@ class SingleReturnTest(converter_testing.TestCase):
return s return s
return -1 return -1
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 2) self.assertTransformedEquivalent(f, 2)
self.assertTransformedEquivalent(test_fn, 4) self.assertTransformedEquivalent(f, 4)
def test_null_return(self): def test_null_return(self):
def test_fn(n): def f(n):
if n > 4: if n > 4:
return return
return return
self.assertTransformedEquivalent(test_fn, 4) self.assertTransformedEquivalent(f, 4)
self.assertTransformedEquivalent(test_fn, 5) self.assertTransformedEquivalent(f, 5)
def test_nested_multiple_withs(self): def test_nested_multiple_withs(self):
def test_fn(x): def f(x):
v = [] v = []
while x > 0: while x > 0:
x -= 1 x -= 1
@ -230,14 +229,14 @@ class SingleReturnTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v return v
self.assertTransformedEquivalent(test_fn, 0) self.assertTransformedEquivalent(f, 0)
self.assertTransformedEquivalent(test_fn, 1) self.assertTransformedEquivalent(f, 1)
self.assertTransformedEquivalent(test_fn, 3) self.assertTransformedEquivalent(f, 3)
self.assertTransformedEquivalent(test_fn, 4) self.assertTransformedEquivalent(f, 4)
def test_multiple_returns_in_nested_scope(self): def test_multiple_returns_in_nested_scope(self):
def test_fn(a): def f(a):
v = [] v = []
for x in a: for x in a:
x -= 1 x -= 1
@ -250,10 +249,10 @@ class SingleReturnTest(converter_testing.TestCase):
v.append(x) v.append(x)
return v return v
self.assertTransformedEquivalent(test_fn, []) self.assertTransformedEquivalent(f, [])
self.assertTransformedEquivalent(test_fn, [1]) self.assertTransformedEquivalent(f, [1])
self.assertTransformedEquivalent(test_fn, [2]) self.assertTransformedEquivalent(f, [2])
self.assertTransformedEquivalent(test_fn, [1, 2, 3]) self.assertTransformedEquivalent(f, [1, 2, 3])
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -18,11 +18,10 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.autograph.converters import directives as directives_converter
from tensorflow.python.autograph.converters import slices from tensorflow.python.autograph.converters import slices
from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.lang import directives from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops from tensorflow.python.ops import list_ops
@ -33,42 +32,26 @@ class SliceTest(converter_testing.TestCase):
def test_index_access(self): def test_index_access(self):
def test_fn(l): def f(l):
directives.set_element_type(l, dtypes.int32)
return l[1] return l[1]
node, ctx = self.prepare(test_fn, {}) tr = self.transform(f, (directives_converter, slices))
def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32')
}
node = slices.transform(node, ctx)
with self.compiled(node, {}, (dtypes.int32,)) as result: tl = list_ops.tensor_list_from_tensor(
with self.cached_session() as sess: [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
tl = list_ops.tensor_list_from_tensor( y = tr(tl)
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32)) self.assertEqual(2, self.evaluate(y))
y = result.test_fn(tl)
self.assertEqual(2, self.evaluate(y))
def test_index_access_multiple_definitions(self): def test_index_access_multiple_definitions(self):
def test_fn(l): def f(l):
directives.set_element_type(l, dtypes.int32)
if l: if l:
l = [] l = []
return l[1] return l[1]
node, ctx = self.prepare(test_fn, {}) self.transform(f, (directives_converter, slices))
def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.int32')
}
def_, = anno.getanno(node.body[0].body[0].targets[0],
anno.Static.DEFINITIONS)
def_.directives[directives.set_element_type] = {
'dtype': parser.parse_expression('tf.float32')
}
with self.assertRaises(ValueError):
slices.transform(node, ctx)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import contextlib
from tensorflow.python.autograph.converters import variables from tensorflow.python.autograph.converters import variables
from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -27,60 +25,63 @@ from tensorflow.python.platform import test
class VariablesTest(converter_testing.TestCase): class VariablesTest(converter_testing.TestCase):
@contextlib.contextmanager def transform_with_test_ld(self, f):
def apply_add_one_conversion(self, fn):
"""Generates code which adds 1 to all variable reads.""" """Generates code which adds 1 to all variable reads."""
with self.converted(fn, variables, {}) as result: return self.transform(f, variables, ag_overrides={'ld': lambda x: x + 1})
result.ag__.__dict__['ld'] = lambda x: x + 1
yield result
def test_read(self): def test_read(self):
def test_fn(l): def f(l):
return l return l
with self.apply_add_one_conversion(test_fn) as result: tr = self.transform_with_test_ld(f)
self.assertEqual(result.test_fn(1), 2)
self.assertEqual(tr(1), 2)
def test_aug_assign(self): def test_aug_assign(self):
def test_fn(l): def f(l):
l *= 10 l *= 10
return l return l
with self.apply_add_one_conversion(test_fn) as result: tr = self.transform_with_test_ld(f)
self.assertEqual(result.test_fn(1), (1 + 1) * 10 + 1) # two reads
self.assertEqual(tr(1), (1 + 1) * 10 + 1) # two reads
def test_del(self): def test_del(self):
def test_fn(l): def f(l):
del l del l
return l return l
with self.converted(test_fn, variables, {}) as result: tr = self.transform(f, variables)
with self.assertRaisesRegex(
NameError, "'l' is used before assignment"):
result.test_fn(1)
def test_del_getitem_ignored(self): with self.assertRaisesRegex(NameError, "'l' is used before assignment"):
tr(1)
def basic_slice(l): def test_del_getitem_ignored_basic_slice(self):
def f(l):
del l[0] del l[0]
return l return l
with self.converted(basic_slice, variables, {}) as result: tr = self.transform(f, variables)
self.assertListEqual([2], result.basic_slice([1, 2]))
def range_slice(l): self.assertListEqual([2], tr([1, 2]))
def test_del_getitem_ignored_range_slice(self):
def f(l):
del l[0:2] del l[0:2]
return l return l
with self.converted(range_slice, variables, {}) as result: tr = self.transform(f, variables)
self.assertListEqual([], result.range_slice([1, 2]))
self.assertListEqual([], tr([1, 2]))
def test_del_getattr_ignored(self): def test_del_getattr_ignored(self):
def test_fn(l): def f(l):
del l.a del l.a
return l return l
@ -90,50 +91,60 @@ class VariablesTest(converter_testing.TestCase):
self.a = 1 self.a = 1
self.b = 2 self.b = 2
with self.converted(test_fn, variables, {}) as result: tr = self.transform(f, variables)
self.assertFalse(hasattr(result.test_fn(TestClass()), 'a'))
self.assertEqual(result.test_fn(TestClass()).b, 2)
def test_del_packing_ignored(self): self.assertFalse(hasattr(tr(TestClass()), 'a'))
# Note: test for UnboundLocalError, not NameError because in this case we self.assertEqual(tr(TestClass()).b, 2)
def test_del_packing_ignored_list(self):
# Note: testing for UnboundLocalError, not NameError because in this case we
# don't rewrite the del. # don't rewrite the del.
def list_(a, b): def f(a, b):
del [a, b] del [a, b]
return a return a
with self.converted(list_, variables, {}) as result: tr = self.transform(f, variables)
with self.assertRaises(UnboundLocalError):
result.list_(1, 2)
def nested(a, b, c): with self.assertRaises(UnboundLocalError):
tr(1, 2)
def test_del_packing_ignored_nested(self):
# Note: testing for UnboundLocalError, not NameError because in this case we
# don't rewrite the del.
def f(a, b, c):
del [a, (b, c)] del [a, (b, c)]
return c return c
with self.converted(nested, variables, {}) as result: tr = self.transform(f, variables)
with self.assertRaises(UnboundLocalError):
result.nested(1, 2, 3)
def test_del_item_multiple_mixed(self): with self.assertRaises(UnboundLocalError):
tr(1, 2, 3)
def test_fn_failing(a, b, c): def test_del_item_multiple_mixed_used_after(self):
def f(a, b, c):
del a, b, c[0] del a, b, c[0]
a = 1 a = 1
return a, b, c return a, b, c
with self.converted(test_fn_failing, variables, {}) as result: tr = self.transform(f, variables)
with self.assertRaisesRegex(
NameError, "'b' is used before assignment"):
result.test_fn_failing(1, 2, [1, 2])
def test_fn_passing(a, b, c): with self.assertRaisesRegex(NameError, "'b' is used before assignment"):
tr(1, 2, [1, 2])
def test_del_item_multiple_mixed_unused_after(self):
def f(a, b, c):
del a, b, c[0] del a, b, c[0]
a = 1 a = 1
b = 2 b = 2
return c return c
with self.converted(test_fn_passing, variables, {}) as result: tr = self.transform(f, variables)
self.assertListEqual([2], result.test_fn_passing(1, 2, [1, 2]))
self.assertListEqual([2], tr(1, 2, [1, 2]))
def test_attribute(self): def test_attribute(self):
@ -146,12 +157,13 @@ class VariablesTest(converter_testing.TestCase):
self.v += other self.v += other
return self return self
def test_fn(l): def f(l):
return l.v return l.v
tc = TestClass() tc = TestClass()
with self.apply_add_one_conversion(test_fn) as result: tr = self.transform_with_test_ld(f)
self.assertEqual(result.test_fn(tc), 2)
self.assertEqual(tr(tc), 2)
def test_subscript(self): def test_subscript(self):
@ -167,12 +179,13 @@ class VariablesTest(converter_testing.TestCase):
def __getitem__(self, _): def __getitem__(self, _):
return self.v return self.v
def test_fn(l): def f(l):
return l[0] return l[0]
tc = TestClass() tc = TestClass()
with self.apply_add_one_conversion(test_fn) as result: tr = self.transform_with_test_ld(f)
self.assertEqual(result.test_fn(tc), 2)
self.assertEqual(tr(tc), 2)
def test_call(self): def test_call(self):
@ -188,12 +201,13 @@ class VariablesTest(converter_testing.TestCase):
def __call__(self): def __call__(self):
return self.v return self.v
def test_fn(l): def f(l):
return l() return l()
tc = TestClass() tc = TestClass()
with self.apply_add_one_conversion(test_fn) as result: tr = self.transform_with_test_ld(f)
self.assertEqual(result.test_fn(tc), 2)
self.assertEqual(tr(tc), 2)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import imp
from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import anno
@ -38,16 +40,18 @@ class ConversionOptionsTest(converter_testing.TestCase):
opts_ast = opts.to_ast() opts_ast = opts.to_ast()
template = ''' template = '''
def test_fn(): def f():
return opts_ast return opts_ast
''' '''
opts_packed = templates.replace(template, opts_ast=opts_ast) opts_packed = templates.replace(template, opts_ast=opts_ast)
reparsed, _, _ = loader.load_ast(opts_packed) reparsed, _, _ = loader.load_ast(opts_packed)
reparsed.__dict__['ag__'] = self.make_fake_mod( fake_ag = imp.new_module('fake_ag')
'fake_ag', converter.ConversionOptions, converter.Feature) fake_ag.ConversionOptions = converter.ConversionOptions
fake_ag.Feature = converter.Feature
reparsed.ag__ = fake_ag
reparsed_opts = reparsed.test_fn() reparsed_opts = reparsed.f()
self.assertEqual(opts.recursive, reparsed_opts.recursive) self.assertEqual(opts.recursive, reparsed_opts.recursive)
self.assertEqual(opts.user_requested, False) self.assertEqual(opts.user_requested, False)
@ -63,12 +67,12 @@ class ConverterBaseTest(converter_testing.TestCase):
directive_key = object directive_key = object
def test_fn(): def f():
a = 1 a = 1
return a return a
ns = {} _, node, ctx = self.transform(f, (), include_ast=True)
node, ctx = self.prepare(test_fn, ns)
symbol_a = node.body[1].value symbol_a = node.body[1].value
defs, = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS) defs, = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
defs.directives[directive_key] = { defs.directives[directive_key] = {
@ -84,12 +88,12 @@ class ConverterBaseTest(converter_testing.TestCase):
directive_key = object directive_key = object
def test_fn(): def f():
a = 1 a = 1
return a return a
ns = {} _, node, ctx = self.transform(f, (), include_ast=True)
node, ctx = self.prepare(test_fn, ns)
symbol_a = node.body[1].value symbol_a = node.body[1].value
c = TestConverter(ctx) c = TestConverter(ctx)
value = c.get_definition_directive(symbol_a, directive_key, 'test_arg', value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
@ -100,14 +104,14 @@ class ConverterBaseTest(converter_testing.TestCase):
directive_key = object directive_key = object
def test_fn(): def f():
a = 1 a = 1
if a: if a:
a = 2 a = 2
return a return a
ns = {} _, node, ctx = self.transform(f, (), include_ast=True)
node, ctx = self.prepare(test_fn, ns)
symbol_a = node.body[2].value symbol_a = node.body[2].value
defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS) defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
defs[0].directives[directive_key] = { defs[0].directives[directive_key] = {
@ -127,14 +131,14 @@ class ConverterBaseTest(converter_testing.TestCase):
directive_key = object directive_key = object
def test_fn(): def f():
a = 1 a = 1
if a: if a:
a = 2 a = 2
return a return a
ns = {} _, node, ctx = self.transform(f, (), include_ast=True)
node, ctx = self.prepare(test_fn, ns)
symbol_a = node.body[2].value symbol_a = node.body[2].value
defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS) defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
defs[0].directives[directive_key] = { defs[0].directives[directive_key] = {

View File

@ -25,27 +25,15 @@ import sys
import six import six
from tensorflow.python.autograph import operators
from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.lang import special_functions from tensorflow.python.autograph.impl import conversion
from tensorflow.python.autograph.pyct import anno from tensorflow.python.framework import ops
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.platform import test from tensorflow.python.platform import test
def whitelist(entity): def whitelist(f):
"""Helper that marks a callable as whtelitisted.""" """Helper that marks a callable as whtelitisted."""
if 'whitelisted_module_for_testing' not in sys.modules: if 'whitelisted_module_for_testing' not in sys.modules:
whitelisted_mod = imp.new_module('whitelisted_module_for_testing') whitelisted_mod = imp.new_module('whitelisted_module_for_testing')
@ -54,7 +42,7 @@ def whitelist(entity):
(config.DoNotConvert('whitelisted_module_for_testing'),) + (config.DoNotConvert('whitelisted_module_for_testing'),) +
config.CONVERSION_RULES) config.CONVERSION_RULES)
entity.__module__ = 'whitelisted_module_for_testing' f.__module__ = 'whitelisted_module_for_testing'
def is_inside_generated_code(): def is_inside_generated_code():
@ -76,9 +64,39 @@ def is_inside_generated_code():
del frame del frame
class TestingTranspiler(conversion.AutoGraphTranspiler):
"""Testing version that only applies given transformations."""
def __init__(self, converters):
super(TestingTranspiler, self).__init__()
if isinstance(converters, (list, tuple)):
self._converters = converters
else:
self._converters = (converters,)
self.transformed_ast = None
def transform_ast(self, node, ctx):
node = self.initial_analysis(node, ctx)
for c in self._converters:
node = c.transform(node, ctx)
self.transformed_ast = node
self.transform_ctx = ctx
return node
class TestCase(test.TestCase): class TestCase(test.TestCase):
"""Base class for unit tests in this module. Contains relevant utilities.""" """Base class for unit tests in this module. Contains relevant utilities."""
def setUp(self):
# AutoGraph tests must run in graph mode to properly test control flow.
self.graph = ops.Graph().as_default()
self.graph.__enter__()
def tearDown(self):
self.graph.__exit__(None, None, None)
@contextlib.contextmanager @contextlib.contextmanager
def assertPrints(self, expected_result): def assertPrints(self, expected_result):
try: try:
@ -89,108 +107,26 @@ class TestCase(test.TestCase):
finally: finally:
sys.stdout = sys.__stdout__ sys.stdout = sys.__stdout__
@contextlib.contextmanager def transform(
def compiled(self, node, namespace, symbols=()): self, f, converter_module, include_ast=False, ag_overrides=None):
source = None
self.dynamic_calls = []
# See api.converted_call
def converted_call(
f, args, kwargs, unused_opts=None, unused_function_ctx=None):
"""Mock version of api.converted_call."""
self.dynamic_calls.append((args, kwargs))
if kwargs is None:
kwargs = {}
return f(*args, **kwargs)
def fake_autograph_artifact(f):
setattr(f, 'fake_autograph_artifact', True)
return f
try:
result, source, source_map = loader.load_ast(
node, include_source_map=True)
# TODO(mdan): Move the unparsing from converter into pyct and reuse here.
# TODO(mdan): Move this into self.prepare()
result.tf = self.make_fake_mod('fake_tf', *symbols)
fake_ag = self.make_fake_mod('fake_ag', converted_call,
converter.ConversionOptions)
fake_ag.__dict__.update(operators.__dict__)
fake_ag.__dict__.update(special_functions.__dict__)
fake_ag.ConversionOptions = converter.ConversionOptions
fake_ag.Feature = converter.Feature
fake_ag.utils = utils
fake_ag.FunctionScope = function_wrappers.FunctionScope
fake_ag.autograph_artifact = fake_autograph_artifact
result.ag__ = fake_ag
result.ag_source_map__ = source_map
for k, v in namespace.items():
result.__dict__[k] = v
yield result
except Exception: # pylint:disable=broad-except
if source is None:
print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
else:
print('Offending source code:\n%s' % source)
raise
@contextlib.contextmanager
def converted(self, entity, converter_module, namespace, tf_symbols=()):
node, ctx = self.prepare(entity, namespace)
if not isinstance(converter_module, (list, tuple)):
converter_module = (converter_module,)
for m in converter_module:
node = m.transform(node, ctx)
with self.compiled(node, namespace, tf_symbols) as result:
yield result
def make_fake_mod(self, name, *symbols):
fake_mod = imp.new_module(name)
for s in symbols:
if hasattr(s, '__name__'):
setattr(fake_mod, s.__name__, s)
elif hasattr(s, 'name'):
# This is a bit of a hack, but works for things like tf.int32
setattr(fake_mod, s.name, s)
else:
raise ValueError('can not attach %s - what should be its name?' % s)
return fake_mod
def attach_namespace(self, module, **ns):
for k, v in ns.items():
setattr(module, k, v)
def prepare(self, test_fn, namespace, recursive=True):
namespace['ConversionOptions'] = converter.ConversionOptions
future_features = ('print_function', 'division')
node, source = parser.parse_entity(test_fn, future_features=future_features)
namer = naming.Namer(namespace)
program_ctx = converter.ProgramContext( program_ctx = converter.ProgramContext(
options=converter.ConversionOptions(recursive=recursive), options=converter.ConversionOptions(recursive=True),
autograph_module=None) autograph_module=api)
entity_info = transformer.EntityInfo(
name=test_fn.__name__,
source_code=source,
source_file='<fragment>',
future_features=future_features,
namespace=namespace)
ctx = transformer.Context(entity_info, namer, program_ctx)
origin_info.resolve_entity(node, source, test_fn)
graphs = cfg.build(node) conversion.create_custom_vars(program_ctx)
node = qual_names.resolve(node) custom_vars = dict(conversion.custom_vars)
node = activity.resolve(node, ctx, None)
node = reaching_definitions.resolve(node, ctx, graphs)
anno.dup(
node,
{
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
},
)
return node, ctx if ag_overrides:
modified_ag = imp.new_module('fake_autograph')
modified_ag.__dict__.update(custom_vars['ag__'].__dict__)
modified_ag.__dict__.update(ag_overrides)
custom_vars['ag__'] = modified_ag
tr = TestingTranspiler(converter_module)
transformed, _, _ = tr.transform_function(
f, program_ctx.options, program_ctx, custom_vars)
if include_ast:
return transformed, tr.transformed_ast, tr.transform_ctx
return transformed

View File

@ -60,11 +60,7 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler):
def get_transformed_name(self, node): def get_transformed_name(self, node):
return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node) return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node)
def transform_ast(self, node, ctx): def initial_analysis(self, node, ctx):
# TODO(mdan): Insert list_comprehensions somewhere.
unsupported_features_checker.verify(node)
# Run initial analysis.
graphs = cfg.build(node) graphs = cfg.build(node)
node = qual_names.resolve(node) node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None) node = activity.resolve(node, ctx, None)
@ -75,6 +71,11 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler):
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS, anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
}, },
) )
return node
def transform_ast(self, node, ctx):
unsupported_features_checker.verify(node)
node = self.initial_analysis(node, ctx)
node = functions.transform(node, ctx) node = functions.transform(node, ctx)
node = directives.transform(node, ctx) node = directives.transform(node, ctx)
@ -114,7 +115,7 @@ def convert(entity, program_ctx):
'expose a __code__ object. If this is a @tf.function,' 'expose a __code__ object. If this is a @tf.function,'
' try passing f.python_function instead.') ' try passing f.python_function instead.')
_create_custom_vars(program_ctx) create_custom_vars(program_ctx)
transformed, module, source_map = _TRANSPILER.transform_function( transformed, module, source_map = _TRANSPILER.transform_function(
entity, program_ctx.options, program_ctx, custom_vars) entity, program_ctx.options, program_ctx, custom_vars)
@ -248,7 +249,8 @@ def cache_whitelisted(entity, options):
# TODO(mdan): Move into core or replace with an actual importable module. # TODO(mdan): Move into core or replace with an actual importable module.
def _create_custom_vars(program_ctx): # Visible for testing.
def create_custom_vars(program_ctx):
"""Adds namespace references to the module that exposes the api itself.""" """Adds namespace references to the module that exposes the api itself."""
global custom_vars global custom_vars
if custom_vars is None: if custom_vars is None: