Clean up the test to remove duplication and run with within the real framework.
PiperOrigin-RevId: 318854218 Change-Id: I45125a5a028a6d5fd7ae8cd9bed698d31d04196b
This commit is contained in:
parent
f072535ba3
commit
ce054f48e6
@ -77,12 +77,6 @@ py_test(
|
|||||||
srcs = ["call_trees_test.py"],
|
srcs = ["call_trees_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
tags = [
|
|
||||||
"no_oss_py2",
|
|
||||||
"no_pip",
|
|
||||||
"no_windows",
|
|
||||||
"nopip",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":converters",
|
":converters",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
@ -119,12 +113,6 @@ py_test(
|
|||||||
srcs = ["control_flow_test.py"],
|
srcs = ["control_flow_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
tags = [
|
|
||||||
"no_oss_py2",
|
|
||||||
"no_pip",
|
|
||||||
"no_windows",
|
|
||||||
"nopip",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
":converters",
|
":converters",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
|
@ -24,7 +24,6 @@ from tensorflow.python.autograph.converters import return_statements
|
|||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import ops
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -32,17 +31,15 @@ class AssertsTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
|
||||||
def test_fn(a):
|
def f(a):
|
||||||
assert a, 'testmsg'
|
assert a, 'testmsg'
|
||||||
return a
|
return a
|
||||||
|
|
||||||
with ops.Graph().as_default():
|
tr = self.transform(f, (functions, asserts, return_statements))
|
||||||
with self.converted(
|
|
||||||
test_fn, (functions, asserts, return_statements), {}) as result:
|
|
||||||
op = result.test_fn(constant_op.constant(False))
|
|
||||||
|
|
||||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'):
|
op = tr(constant_op.constant(False))
|
||||||
self.evaluate(op)
|
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'):
|
||||||
|
self.evaluate(op)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -21,20 +21,18 @@ from __future__ import print_function
|
|||||||
from tensorflow.python.autograph.converters import break_statements
|
from tensorflow.python.autograph.converters import break_statements
|
||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
from tensorflow.python.autograph.pyct import anno
|
from tensorflow.python.autograph.pyct import anno
|
||||||
from tensorflow.python.framework import constant_op
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class BreakCanonicalizationTest(converter_testing.TestCase):
|
class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||||
|
|
||||||
def assertTransformedEquivalent(self, test_fn, *inputs):
|
def assertTransformedEquivalent(self, f, *inputs):
|
||||||
with self.converted(test_fn, break_statements, {},
|
tr = self.transform(f, break_statements)
|
||||||
(constant_op.constant,)) as result:
|
self.assertEqual(f(*inputs), tr(*inputs))
|
||||||
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
|
|
||||||
|
|
||||||
def test_while_loop(self):
|
def test_while_loop(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
v = []
|
v = []
|
||||||
while x > 0:
|
while x > 0:
|
||||||
x -= 1
|
x -= 1
|
||||||
@ -43,28 +41,29 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 1)
|
self.assertTransformedEquivalent(f, 1)
|
||||||
self.assertTransformedEquivalent(test_fn, 4)
|
self.assertTransformedEquivalent(f, 4)
|
||||||
|
|
||||||
def test_while_loop_preserves_directives(self):
|
def test_while_loop_preserves_directives(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
while x > 0:
|
while x > 0:
|
||||||
x -= 1
|
x -= 1
|
||||||
if x % 2 == 0:
|
if x % 2 == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {})
|
_, node, ctx = self.transform(f, (), include_ast=True)
|
||||||
fake_annotation = object()
|
fake_annotation = object()
|
||||||
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
|
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
|
||||||
node = break_statements.transform(node, ctx)
|
node = break_statements.transform(node, ctx)
|
||||||
|
|
||||||
self.assertIs(
|
self.assertIs(
|
||||||
anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
|
anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
|
||||||
|
|
||||||
def test_for_loop(self):
|
def test_for_loop(self):
|
||||||
|
|
||||||
def test_fn(a):
|
def f(a):
|
||||||
v = []
|
v = []
|
||||||
for x in a:
|
for x in a:
|
||||||
x -= 1
|
x -= 1
|
||||||
@ -73,20 +72,18 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
with self.converted(test_fn, break_statements, {},
|
tr = self.transform(f, break_statements)
|
||||||
(constant_op.constant,)) as result:
|
|
||||||
# The break is incompletely canonicalized. The loop will not interrupt,
|
self.assertEqual([3], tr([5, 4]))
|
||||||
# but the section following the break will be skipped.
|
|
||||||
self.assertEqual([3], result.test_fn([5, 4]))
|
|
||||||
|
|
||||||
def test_for_loop_preserves_directives(self):
|
def test_for_loop_preserves_directives(self):
|
||||||
|
|
||||||
def test_fn(a):
|
def f(a):
|
||||||
for x in a:
|
for x in a:
|
||||||
if x % 2 == 0:
|
if x % 2 == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {})
|
_, node, ctx = self.transform(f, (), include_ast=True)
|
||||||
fake_annotation = object()
|
fake_annotation = object()
|
||||||
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
|
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
|
||||||
node = break_statements.transform(node, ctx)
|
node = break_statements.transform(node, ctx)
|
||||||
@ -95,7 +92,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
def test_nested(self):
|
def test_nested(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
v = []
|
v = []
|
||||||
u = []
|
u = []
|
||||||
w = []
|
w = []
|
||||||
@ -110,13 +107,13 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v, u, w
|
return v, u, w
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 3)
|
self.assertTransformedEquivalent(f, 3)
|
||||||
self.assertTransformedEquivalent(test_fn, 11)
|
self.assertTransformedEquivalent(f, 11)
|
||||||
|
|
||||||
def test_nested_loops(self):
|
def test_nested_loops(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
v = []
|
v = []
|
||||||
u = []
|
u = []
|
||||||
while x > 0:
|
while x > 0:
|
||||||
@ -132,14 +129,14 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v, u
|
return v, u
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, 3)
|
self.assertTransformedEquivalent(f, 3)
|
||||||
self.assertTransformedEquivalent(test_fn, 5)
|
self.assertTransformedEquivalent(f, 5)
|
||||||
|
|
||||||
def test_loop_orelse(self):
|
def test_loop_orelse(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
v = []
|
v = []
|
||||||
u = []
|
u = []
|
||||||
while x > 0:
|
while x > 0:
|
||||||
@ -153,12 +150,12 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v, u
|
return v, u
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, 3)
|
self.assertTransformedEquivalent(f, 3)
|
||||||
|
|
||||||
def test_multiple_correlated_breaks_with_side_effects(self):
|
def test_multiple_correlated_breaks_with_side_effects(self):
|
||||||
def test_fn(cond1):
|
def f(cond1):
|
||||||
lst = []
|
lst = []
|
||||||
while True:
|
while True:
|
||||||
if cond1:
|
if cond1:
|
||||||
@ -169,8 +166,9 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
|||||||
break
|
break
|
||||||
return lst
|
return lst
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, True)
|
self.assertTransformedEquivalent(f, True)
|
||||||
self.assertTransformedEquivalent(test_fn, False)
|
self.assertTransformedEquivalent(f, False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -27,169 +27,193 @@ from tensorflow.python.autograph.core import converter_testing
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class MockConvertedCall(object):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def __call__(self, f, args, kwargs, caller_fn_scope=None, options=None):
|
||||||
|
del caller_fn_scope, options
|
||||||
|
self.calls.append((args, kwargs))
|
||||||
|
kwargs = kwargs or {}
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class CallTreesTest(converter_testing.TestCase):
|
class CallTreesTest(converter_testing.TestCase):
|
||||||
|
|
||||||
|
def _transform_with_mock(self, f):
|
||||||
|
mock = MockConvertedCall()
|
||||||
|
tr = self.transform(
|
||||||
|
f, (functions, call_trees),
|
||||||
|
ag_overrides={'converted_call': mock})
|
||||||
|
return tr, mock
|
||||||
|
|
||||||
def test_function_no_args(self):
|
def test_function_no_args(self):
|
||||||
|
|
||||||
def test_fn(f):
|
def f(f):
|
||||||
return f() + 20
|
return f() + 20
|
||||||
|
|
||||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
tr, mock = self._transform_with_mock(f)
|
||||||
self.assertEqual(result.test_fn(lambda: 1), 21)
|
|
||||||
self.assertListEqual(self.dynamic_calls, [((), None)])
|
self.assertEqual(tr(lambda: 1), 21)
|
||||||
|
self.assertListEqual(mock.calls, [((), None)])
|
||||||
|
|
||||||
def test_function_with_expression_in_argument(self):
|
def test_function_with_expression_in_argument(self):
|
||||||
|
|
||||||
def test_fn(f, g):
|
def f(f, g):
|
||||||
return f(g() + 20) + 4000
|
return f(g() + 20) + 4000
|
||||||
|
|
||||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
tr, mock = self._transform_with_mock(f)
|
||||||
self.assertEqual(result.test_fn(lambda x: x + 300, lambda: 1), 4321)
|
|
||||||
self.assertListEqual(self.dynamic_calls, [
|
self.assertEqual(tr(lambda x: x + 300, lambda: 1), 4321)
|
||||||
((), None),
|
self.assertListEqual(mock.calls, [
|
||||||
((21,), None),
|
((), None),
|
||||||
])
|
((21,), None),
|
||||||
|
])
|
||||||
|
|
||||||
def test_function_with_call_in_argument(self):
|
def test_function_with_call_in_argument(self):
|
||||||
|
|
||||||
def test_fn(f, g):
|
def f(f, g):
|
||||||
return f(g()) + 300
|
return f(g()) + 300
|
||||||
|
|
||||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
tr, mock = self._transform_with_mock(f)
|
||||||
self.assertEqual(result.test_fn(lambda x: x + 20, lambda: 1), 321)
|
|
||||||
self.assertListEqual(self.dynamic_calls, [
|
self.assertEqual(tr(lambda x: x + 20, lambda: 1), 321)
|
||||||
((), None),
|
self.assertListEqual(mock.calls, [
|
||||||
((1,), None),
|
((), None),
|
||||||
])
|
((1,), None),
|
||||||
|
])
|
||||||
|
|
||||||
def test_function_chaining(self):
|
def test_function_chaining(self):
|
||||||
|
|
||||||
def get_one():
|
def get_one():
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
return get_one().__add__(20)
|
return get_one().__add__(20)
|
||||||
|
|
||||||
with self.converted(test_fn, (functions, call_trees),
|
tr, mock = self._transform_with_mock(f)
|
||||||
{'get_one': get_one}, ()) as result:
|
|
||||||
|
|
||||||
self.assertEqual(result.test_fn(), 21)
|
self.assertEqual(tr(), 21)
|
||||||
|
self.assertListEqual(mock.calls, [
|
||||||
self.assertListEqual(self.dynamic_calls, [
|
((), None),
|
||||||
((), None),
|
((20,), None),
|
||||||
((20,), None),
|
])
|
||||||
])
|
|
||||||
|
|
||||||
def test_function_with_single_arg(self):
|
def test_function_with_single_arg(self):
|
||||||
|
|
||||||
def test_fn(f, a):
|
def f(f, a):
|
||||||
return f(a) + 20
|
return f(a) + 20
|
||||||
|
|
||||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
tr, mock = self._transform_with_mock(f)
|
||||||
self.assertEqual(result.test_fn(lambda a: a, 1), 21)
|
|
||||||
self.assertListEqual(self.dynamic_calls, [((1,), None)])
|
self.assertEqual(tr(lambda a: a, 1), 21)
|
||||||
|
self.assertListEqual(mock.calls, [((1,), None)])
|
||||||
|
|
||||||
def test_function_with_args_only(self):
|
def test_function_with_args_only(self):
|
||||||
|
|
||||||
def test_fn(f, a, b):
|
def f(f, a, b):
|
||||||
return f(a, b) + 300
|
return f(a, b) + 300
|
||||||
|
|
||||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
tr, mock = self._transform_with_mock(f)
|
||||||
self.assertEqual(result.test_fn(lambda a, b: a + b, 1, 20), 321)
|
|
||||||
self.assertListEqual(self.dynamic_calls, [((1, 20), None)])
|
self.assertEqual(tr(lambda a, b: a + b, 1, 20), 321)
|
||||||
|
self.assertListEqual(mock.calls, [((1, 20), None)])
|
||||||
|
|
||||||
def test_function_with_kwarg(self):
|
def test_function_with_kwarg(self):
|
||||||
|
|
||||||
def test_fn(f, a, b):
|
def f(f, a, b):
|
||||||
return f(a, c=b) + 300
|
return f(a, c=b) + 300
|
||||||
|
|
||||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
tr, mock = self._transform_with_mock(f)
|
||||||
self.assertEqual(result.test_fn(lambda a, c: a + c, 1, 20), 321)
|
|
||||||
self.assertListEqual(self.dynamic_calls, [((1,), {'c': 20})])
|
self.assertEqual(tr(lambda a, c: a + c, 1, 20), 321)
|
||||||
|
self.assertListEqual(mock.calls, [((1,), {'c': 20})])
|
||||||
|
|
||||||
def test_function_with_kwargs_starargs(self):
|
def test_function_with_kwargs_starargs(self):
|
||||||
|
|
||||||
def test_fn(f, a, *args, **kwargs):
|
def f(f, a, *args, **kwargs):
|
||||||
return f(a, *args, **kwargs) + 5
|
return f(a, *args, **kwargs) + 5
|
||||||
|
|
||||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
tr, mock = self._transform_with_mock(f)
|
||||||
self.assertEqual(
|
|
||||||
result.test_fn(lambda *args, **kwargs: 7, 1, *[2, 3], **{
|
self.assertEqual(
|
||||||
'b': 4,
|
tr(lambda *args, **kwargs: 7, 1, *[2, 3], **{
|
||||||
'c': 5
|
'b': 4,
|
||||||
}), 12)
|
'c': 5
|
||||||
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})])
|
}), 12)
|
||||||
|
self.assertListEqual(mock.calls, [((1, 2, 3), {'b': 4, 'c': 5})])
|
||||||
|
|
||||||
def test_function_with_starargs_only(self):
|
def test_function_with_starargs_only(self):
|
||||||
|
|
||||||
def f(*args):
|
def g(*args):
|
||||||
return sum(args)
|
return sum(args)
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
args = [1, 20, 300]
|
args = [1, 20, 300]
|
||||||
return f(*args) + 4000
|
return g(*args) + 4000
|
||||||
|
|
||||||
with self.converted(test_fn, (functions, call_trees),
|
tr, mock = self._transform_with_mock(f)
|
||||||
{'f': f}) as result:
|
|
||||||
self.assertEqual(result.test_fn(), 4321)
|
|
||||||
self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
|
|
||||||
|
|
||||||
# TODO(b/142586827): Enable this test.
|
self.assertEqual(tr(), 4321)
|
||||||
# def test_function_with_starargs_mixed(self):
|
self.assertListEqual(mock.calls, [((1, 20, 300), None)])
|
||||||
#
|
|
||||||
# def f(a, b, c, d):
|
def test_function_with_starargs_mixed(self):
|
||||||
# return a * 1000 + b * 100 + c * 10 + d
|
|
||||||
#
|
def g(a, b, c, d):
|
||||||
# def test_fn():
|
return a * 1000 + b * 100 + c * 10 + d
|
||||||
# args1 = (1,)
|
|
||||||
# args2 = [3]
|
def f():
|
||||||
# return f(*args1, 2, *args2, 4)
|
args1 = (1,)
|
||||||
#
|
args2 = [3]
|
||||||
# with self.converted(test_fn, (functions, call_trees),
|
return g(*args1, 2, *args2, 4)
|
||||||
# {'f': f}) as result:
|
|
||||||
# self.assertEqual(result.test_fn(), 1234)
|
tr, mock = self._transform_with_mock(f)
|
||||||
# self.assertListEqual(self.dynamic_calls, [((1, 2, 3, 4), None)])
|
|
||||||
|
self.assertEqual(tr(), 1234)
|
||||||
|
self.assertListEqual(mock.calls, [((1, 2, 3, 4), None)])
|
||||||
|
|
||||||
def test_function_with_kwargs_keywords(self):
|
def test_function_with_kwargs_keywords(self):
|
||||||
|
|
||||||
def test_fn(f, a, b, **kwargs):
|
def f(f, a, b, **kwargs):
|
||||||
return f(a, b=b, **kwargs) + 5
|
return f(a, b=b, **kwargs) + 5
|
||||||
|
|
||||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
tr, mock = self._transform_with_mock(f)
|
||||||
self.assertEqual(
|
|
||||||
result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
|
|
||||||
self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
|
|
||||||
|
|
||||||
# TODO(b/142586827): Enable this test.
|
self.assertEqual(
|
||||||
# def test_function_with_multiple_kwargs(self):
|
tr(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
|
||||||
#
|
self.assertListEqual(mock.calls, [((1,), {'b': 2, 'c': 3})])
|
||||||
# def test_fn(f, a, b, c, kwargs1, kwargs2):
|
|
||||||
# return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5
|
def test_function_with_multiple_kwargs(self):
|
||||||
#
|
|
||||||
# with self.converted(test_fn, (functions, call_trees), {}) as result:
|
def f(f, a, b, c, kwargs1, kwargs2):
|
||||||
# self.assertEqual(
|
return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5
|
||||||
# result.test_fn(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4},
|
|
||||||
# {'e': 5}), 12)
|
tr, mock = self._transform_with_mock(f)
|
||||||
# self.assertListEqual(self.dynamic_calls, [((1,), {
|
|
||||||
# 'b': 2,
|
self.assertEqual(
|
||||||
# 'c': 3,
|
tr(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4}, {'e': 5}), 12)
|
||||||
# 'd': 4,
|
self.assertListEqual(mock.calls, [((1,), {
|
||||||
# 'e': 5
|
'b': 2,
|
||||||
# })])
|
'c': 3,
|
||||||
|
'd': 4,
|
||||||
|
'e': 5
|
||||||
|
})])
|
||||||
|
|
||||||
def test_function_with_call_in_lambda_argument(self):
|
def test_function_with_call_in_lambda_argument(self):
|
||||||
|
|
||||||
def f(l, a):
|
def h(l, a):
|
||||||
return l(a) + 4000
|
return l(a) + 4000
|
||||||
|
|
||||||
def g(a, *args):
|
def g(a, *args):
|
||||||
return a + sum(args)
|
return a + sum(args)
|
||||||
|
|
||||||
def test_fn(f, g, a, *args):
|
def f(h, g, a, *args):
|
||||||
return f(lambda x: g(x, *args), a)
|
return h(lambda x: g(x, *args), a)
|
||||||
|
|
||||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
tr, _ = self._transform_with_mock(f)
|
||||||
self.assertEqual(result.test_fn(f, g, 1, *(20, 300)), 4321)
|
|
||||||
|
self.assertEqual(tr(h, g, 1, *(20, 300)), 4321)
|
||||||
|
|
||||||
def test_debugger_set_trace(self):
|
def test_debugger_set_trace(self):
|
||||||
|
|
||||||
@ -198,13 +222,13 @@ class CallTreesTest(converter_testing.TestCase):
|
|||||||
pdb = imp.new_module('fake_pdb')
|
pdb = imp.new_module('fake_pdb')
|
||||||
pdb.set_trace = lambda: tracking_list.append(1)
|
pdb.set_trace = lambda: tracking_list.append(1)
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
return pdb.set_trace()
|
return pdb.set_trace()
|
||||||
|
|
||||||
with self.converted(test_fn, (functions, call_trees),
|
tr, _ = self._transform_with_mock(f)
|
||||||
{'pdb': pdb}) as result:
|
|
||||||
result.test_fn()
|
tr()
|
||||||
self.assertListEqual(tracking_list, [1])
|
self.assertListEqual(tracking_list, [1])
|
||||||
|
|
||||||
def test_class_method(self):
|
def test_class_method(self):
|
||||||
|
|
||||||
@ -217,10 +241,10 @@ class CallTreesTest(converter_testing.TestCase):
|
|||||||
return self.other_method(a) + 300
|
return self.other_method(a) + 300
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
with self.converted(TestClass.test_method, (functions, call_trees),
|
tr, mock = self._transform_with_mock(TestClass.test_method)
|
||||||
{}) as result:
|
|
||||||
self.assertEqual(321, result.test_method(tc, 1))
|
self.assertEqual(321, tr(tc, 1))
|
||||||
self.assertListEqual(self.dynamic_calls, [((1,), None)])
|
self.assertListEqual(mock.calls, [((1,), None)])
|
||||||
|
|
||||||
def test_object_method(self):
|
def test_object_method(self):
|
||||||
|
|
||||||
@ -233,10 +257,10 @@ class CallTreesTest(converter_testing.TestCase):
|
|||||||
return self.other_method(a) + 300
|
return self.other_method(a) + 300
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
with self.converted(tc.test_method, (functions, call_trees),
|
tr, mock = self._transform_with_mock(tc.test_method)
|
||||||
{}) as result:
|
|
||||||
self.assertEqual(321, result.test_method(tc, 1))
|
self.assertEqual(321, tr(tc, 1))
|
||||||
self.assertListEqual(self.dynamic_calls, [((1,), None)])
|
self.assertListEqual(mock.calls, [((1,), None)])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -25,28 +25,27 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
class ConditionalExpressionsTest(converter_testing.TestCase):
|
class ConditionalExpressionsTest(converter_testing.TestCase):
|
||||||
|
|
||||||
def assertTransformedEquivalent(self, test_fn, *inputs):
|
def assertTransformedEquivalent(self, f, *inputs):
|
||||||
ns = {}
|
tr = self.transform(f, conditional_expressions)
|
||||||
with self.converted(test_fn, conditional_expressions, ns) as result:
|
self.assertEqual(f(*inputs), tr(*inputs))
|
||||||
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
|
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
return 1 if x else 0
|
return 1 if x else 0
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 3)
|
self.assertTransformedEquivalent(f, 3)
|
||||||
|
|
||||||
def test_nested_orelse(self):
|
def test_nested_orelse(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
y = x * x if x > 0 else x if x else 1
|
y = x * x if x > 0 else x if x else 1
|
||||||
return y
|
return y
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, -2)
|
self.assertTransformedEquivalent(f, -2)
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -20,21 +20,19 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.autograph.converters import continue_statements
|
from tensorflow.python.autograph.converters import continue_statements
|
||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
from tensorflow.python.framework import constant_op
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class ContinueCanonicalizationTest(converter_testing.TestCase):
|
class ContinueCanonicalizationTest(converter_testing.TestCase):
|
||||||
|
|
||||||
def assertTransformedEquivalent(self, test_fn, *inputs):
|
def assertTransformedEquivalent(self, f, *inputs):
|
||||||
with self.converted(test_fn, continue_statements, {'ops': ops},
|
tr = self.transform(f, continue_statements)
|
||||||
(constant_op.constant,)) as result:
|
self.assertEqual(f(*inputs), tr(*inputs))
|
||||||
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
|
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
v = []
|
v = []
|
||||||
while x > 0:
|
while x > 0:
|
||||||
x -= 1
|
x -= 1
|
||||||
@ -43,14 +41,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 1)
|
self.assertTransformedEquivalent(f, 1)
|
||||||
self.assertTransformedEquivalent(test_fn, 3)
|
self.assertTransformedEquivalent(f, 3)
|
||||||
self.assertTransformedEquivalent(test_fn, 4)
|
self.assertTransformedEquivalent(f, 4)
|
||||||
|
|
||||||
def test_multiple_continues(self):
|
def test_multiple_continues(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
v = []
|
v = []
|
||||||
while x > 0:
|
while x > 0:
|
||||||
x -= 1
|
x -= 1
|
||||||
@ -61,14 +59,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 1)
|
self.assertTransformedEquivalent(f, 1)
|
||||||
self.assertTransformedEquivalent(test_fn, 3)
|
self.assertTransformedEquivalent(f, 3)
|
||||||
self.assertTransformedEquivalent(test_fn, 4)
|
self.assertTransformedEquivalent(f, 4)
|
||||||
|
|
||||||
def test_multiple_continues_in_nested_scope(self):
|
def test_multiple_continues_in_nested_scope(self):
|
||||||
|
|
||||||
def test_fn(a):
|
def f(a):
|
||||||
v = []
|
v = []
|
||||||
for x in a:
|
for x in a:
|
||||||
x -= 1
|
x -= 1
|
||||||
@ -81,14 +79,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, [])
|
self.assertTransformedEquivalent(f, [])
|
||||||
self.assertTransformedEquivalent(test_fn, [1])
|
self.assertTransformedEquivalent(f, [1])
|
||||||
self.assertTransformedEquivalent(test_fn, [2])
|
self.assertTransformedEquivalent(f, [2])
|
||||||
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
|
self.assertTransformedEquivalent(f, [1, 2, 3])
|
||||||
|
|
||||||
def test_for_loop(self):
|
def test_for_loop(self):
|
||||||
|
|
||||||
def test_fn(a):
|
def f(a):
|
||||||
v = []
|
v = []
|
||||||
for x in a:
|
for x in a:
|
||||||
x -= 1
|
x -= 1
|
||||||
@ -97,14 +95,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, [])
|
self.assertTransformedEquivalent(f, [])
|
||||||
self.assertTransformedEquivalent(test_fn, [1])
|
self.assertTransformedEquivalent(f, [1])
|
||||||
self.assertTransformedEquivalent(test_fn, [2])
|
self.assertTransformedEquivalent(f, [2])
|
||||||
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
|
self.assertTransformedEquivalent(f, [1, 2, 3])
|
||||||
|
|
||||||
def test_nested_with(self):
|
def test_nested_with(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
v = []
|
v = []
|
||||||
while x > 0:
|
while x > 0:
|
||||||
x -= 1
|
x -= 1
|
||||||
@ -114,14 +112,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 1)
|
self.assertTransformedEquivalent(f, 1)
|
||||||
self.assertTransformedEquivalent(test_fn, 3)
|
self.assertTransformedEquivalent(f, 3)
|
||||||
self.assertTransformedEquivalent(test_fn, 4)
|
self.assertTransformedEquivalent(f, 4)
|
||||||
|
|
||||||
def test_nested_multiple_withs(self):
|
def test_nested_multiple_withs(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
v = []
|
v = []
|
||||||
while x > 0:
|
while x > 0:
|
||||||
x -= 1
|
x -= 1
|
||||||
@ -133,14 +131,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 1)
|
self.assertTransformedEquivalent(f, 1)
|
||||||
self.assertTransformedEquivalent(test_fn, 3)
|
self.assertTransformedEquivalent(f, 3)
|
||||||
self.assertTransformedEquivalent(test_fn, 4)
|
self.assertTransformedEquivalent(f, 4)
|
||||||
|
|
||||||
def test_nested_multiple_withs_and_statements(self):
|
def test_nested_multiple_withs_and_statements(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
v = []
|
v = []
|
||||||
while x > 0:
|
while x > 0:
|
||||||
x -= 1
|
x -= 1
|
||||||
@ -154,14 +152,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 1)
|
self.assertTransformedEquivalent(f, 1)
|
||||||
self.assertTransformedEquivalent(test_fn, 3)
|
self.assertTransformedEquivalent(f, 3)
|
||||||
self.assertTransformedEquivalent(test_fn, 4)
|
self.assertTransformedEquivalent(f, 4)
|
||||||
|
|
||||||
def test_nested_multiple_withs_and_nested_withs(self):
|
def test_nested_multiple_withs_and_nested_withs(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
v = []
|
v = []
|
||||||
while x > 0:
|
while x > 0:
|
||||||
x -= 1
|
x -= 1
|
||||||
@ -176,14 +174,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 1)
|
self.assertTransformedEquivalent(f, 1)
|
||||||
self.assertTransformedEquivalent(test_fn, 3)
|
self.assertTransformedEquivalent(f, 3)
|
||||||
self.assertTransformedEquivalent(test_fn, 4)
|
self.assertTransformedEquivalent(f, 4)
|
||||||
|
|
||||||
def test_nested(self):
|
def test_nested(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
v = []
|
v = []
|
||||||
u = []
|
u = []
|
||||||
w = []
|
w = []
|
||||||
@ -198,14 +196,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v, u, w
|
return v, u, w
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 1)
|
self.assertTransformedEquivalent(f, 1)
|
||||||
self.assertTransformedEquivalent(test_fn, 3)
|
self.assertTransformedEquivalent(f, 3)
|
||||||
self.assertTransformedEquivalent(test_fn, 4)
|
self.assertTransformedEquivalent(f, 4)
|
||||||
|
|
||||||
def test_multiple_guarded_continues_with_side_effects(self):
|
def test_multiple_guarded_continues_with_side_effects(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
def track(u, x):
|
def track(u, x):
|
||||||
u.append(x)
|
u.append(x)
|
||||||
return x
|
return x
|
||||||
@ -221,8 +219,8 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return u, v
|
return u, v
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 3)
|
self.assertTransformedEquivalent(f, 3)
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -23,6 +23,8 @@ import collections
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.autograph.converters import break_statements
|
||||||
|
from tensorflow.python.autograph.converters import continue_statements
|
||||||
from tensorflow.python.autograph.converters import control_flow
|
from tensorflow.python.autograph.converters import control_flow
|
||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
@ -34,7 +36,8 @@ from tensorflow.python.framework import tensor_util
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
# TODO(mdan): These tests are not isolated - they also test the operators.
|
|
||||||
|
for_unaffected_global = None
|
||||||
|
|
||||||
|
|
||||||
class ControlFlowTestBase(converter_testing.TestCase):
|
class ControlFlowTestBase(converter_testing.TestCase):
|
||||||
@ -45,22 +48,19 @@ class ControlFlowTestBase(converter_testing.TestCase):
|
|||||||
actual)
|
actual)
|
||||||
self.assertAllEqual(values, expected)
|
self.assertAllEqual(values, expected)
|
||||||
|
|
||||||
def assertTransformedResult(self, test_fn, inputs, expected, symbols=None):
|
def assertTransformedResult(self, f, inputs, expected):
|
||||||
if not isinstance(inputs, tuple):
|
if not isinstance(inputs, tuple):
|
||||||
inputs = (inputs,)
|
inputs = (inputs,)
|
||||||
if not symbols:
|
tr = self.transform(f, control_flow)
|
||||||
symbols = {}
|
returns = tr(*inputs)
|
||||||
with self.converted(test_fn, control_flow, symbols,
|
self.assertValuesEqual(returns, expected)
|
||||||
(constant_op.constant,)) as result:
|
|
||||||
returns = result.test_fn(*inputs)
|
|
||||||
self.assertValuesEqual(returns, expected)
|
|
||||||
|
|
||||||
|
|
||||||
class NestedControlFlowTest(ControlFlowTestBase):
|
class NestedControlFlowTest(ControlFlowTestBase):
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
i = 0
|
i = 0
|
||||||
j = 0
|
j = 0
|
||||||
s = 0
|
s = 0
|
||||||
@ -73,7 +73,7 @@ class NestedControlFlowTest(ControlFlowTestBase):
|
|||||||
j = 0
|
j = 0
|
||||||
return s, i, j, n
|
return s, i, j, n
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(5),
|
self.assertTransformedResult(f, constant_op.constant(5),
|
||||||
(25, 5, 0, 5))
|
(25, 5, 0, 5))
|
||||||
|
|
||||||
def test_composite_state_complex(self):
|
def test_composite_state_complex(self):
|
||||||
@ -88,7 +88,7 @@ class NestedControlFlowTest(ControlFlowTestBase):
|
|||||||
def __init__(self, y):
|
def __init__(self, y):
|
||||||
self.y = y
|
self.y = y
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
tc = TestClassX(TestClassY({'z': TestClassX(n)}))
|
tc = TestClassX(TestClassY({'z': TestClassX(n)}))
|
||||||
if n > 0:
|
if n > 0:
|
||||||
while n > 0:
|
while n > 0:
|
||||||
@ -97,19 +97,17 @@ class NestedControlFlowTest(ControlFlowTestBase):
|
|||||||
n -= 1
|
n -= 1
|
||||||
return n, tc
|
return n, tc
|
||||||
|
|
||||||
with self.converted(test_fn, control_flow, {
|
tr = self.transform(f, control_flow)
|
||||||
'TestClassX': TestClassX,
|
|
||||||
'TestClassY': TestClassY,
|
n, tc = tr(constant_op.constant(5))
|
||||||
}) as result:
|
self.assertValuesEqual((n, tc.x.y['z'].x), (0, 6))
|
||||||
n, tc = result.test_fn(constant_op.constant(5))
|
|
||||||
self.assertValuesEqual((n, tc.x.y['z'].x), (0, 6))
|
|
||||||
|
|
||||||
|
|
||||||
class WhileStatementTest(ControlFlowTestBase):
|
class WhileStatementTest(ControlFlowTestBase):
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
i = 0
|
i = 0
|
||||||
s = 0
|
s = 0
|
||||||
while i < n:
|
while i < n:
|
||||||
@ -117,16 +115,16 @@ class WhileStatementTest(ControlFlowTestBase):
|
|||||||
i += 1
|
i += 1
|
||||||
return s, i, n
|
return s, i, n
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5))
|
self.assertTransformedResult(f, constant_op.constant(5), (10, 5, 5))
|
||||||
|
|
||||||
def test_single_output(self):
|
def test_single_output(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
while n > 0:
|
while n > 0:
|
||||||
n -= 1
|
n -= 1
|
||||||
return n
|
return n
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(5), 0)
|
self.assertTransformedResult(f, constant_op.constant(5), 0)
|
||||||
|
|
||||||
def test_composite_state_attr(self):
|
def test_composite_state_attr(self):
|
||||||
|
|
||||||
@ -135,19 +133,18 @@ class WhileStatementTest(ControlFlowTestBase):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.x = constant_op.constant(3)
|
self.x = constant_op.constant(3)
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
while n > 0:
|
while n > 0:
|
||||||
tc.x += 1
|
tc.x += 1
|
||||||
n -= 1
|
n -= 1
|
||||||
return n
|
return n
|
||||||
|
|
||||||
self.assertTransformedResult(
|
self.assertTransformedResult(f, constant_op.constant(5), 0)
|
||||||
test_fn, constant_op.constant(5), 0, symbols={'TestClass': TestClass})
|
|
||||||
|
|
||||||
def test_composite_state_slice(self):
|
def test_composite_state_slice(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
d = {'a': n}
|
d = {'a': n}
|
||||||
k = 'a'
|
k = 'a'
|
||||||
while n > 0:
|
while n > 0:
|
||||||
@ -155,25 +152,25 @@ class WhileStatementTest(ControlFlowTestBase):
|
|||||||
n -= 1
|
n -= 1
|
||||||
return d[k], n
|
return d[k], n
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 0))
|
self.assertTransformedResult(f, constant_op.constant(5), (10, 0))
|
||||||
|
|
||||||
def test_composite_state_literal_slice(self):
|
def test_composite_state_literal_slice(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
d = {'a': n}
|
d = {'a': n}
|
||||||
while n > 0:
|
while n > 0:
|
||||||
d['a'] += 1
|
d['a'] += 1
|
||||||
n -= 1
|
n -= 1
|
||||||
return d['a'], n
|
return d['a'], n
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 0))
|
self.assertTransformedResult(f, constant_op.constant(5), (10, 0))
|
||||||
|
|
||||||
def test_composite_state_attr_initialized_in_loop(self):
|
def test_composite_state_attr_initialized_in_loop(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_fn(n, x):
|
def f(n, x):
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
while n < 5:
|
while n < 5:
|
||||||
if n == 0:
|
if n == 0:
|
||||||
@ -183,19 +180,15 @@ class WhileStatementTest(ControlFlowTestBase):
|
|||||||
n += 1
|
n += 1
|
||||||
return tc.subattr
|
return tc.subattr
|
||||||
|
|
||||||
self.assertTransformedResult(
|
self.assertTransformedResult(f, (0, constant_op.constant(10)), 14)
|
||||||
test_fn, (0, constant_op.constant(10)),
|
tr = self.transform(f, control_flow)
|
||||||
14,
|
with self.assertRaisesRegex(
|
||||||
symbols={'TestClass': TestClass})
|
ValueError, "'tc.subattr' must be defined before the loop"):
|
||||||
with self.converted(
|
tr(constant_op.constant(0), 0)
|
||||||
test_fn, control_flow, {'TestClass': TestClass}) as result:
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, "'tc.subattr' must be defined before the loop"):
|
|
||||||
result.test_fn(constant_op.constant(0), 0)
|
|
||||||
|
|
||||||
def test_composite_state_slice_initialized_in_loop(self):
|
def test_composite_state_slice_initialized_in_loop(self):
|
||||||
|
|
||||||
def test_fn(n, x):
|
def f(n, x):
|
||||||
d = {}
|
d = {}
|
||||||
k = 'subkey'
|
k = 'subkey'
|
||||||
while n < 5:
|
while n < 5:
|
||||||
@ -206,16 +199,16 @@ class WhileStatementTest(ControlFlowTestBase):
|
|||||||
n += 1
|
n += 1
|
||||||
return d
|
return d
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
|
self.assertTransformedResult(f, (0, constant_op.constant(10)),
|
||||||
{'subkey': 14})
|
{'subkey': 14})
|
||||||
with self.converted(test_fn, control_flow, {}) as result:
|
tr = self.transform(f, control_flow)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, r"'d\[k\]' must be defined before the loop"):
|
ValueError, r"'d\[k\]' must be defined before the loop"):
|
||||||
result.test_fn(constant_op.constant(0), 0)
|
tr(constant_op.constant(0), 0)
|
||||||
|
|
||||||
def test_composite_state_literal_slice_initialized_in_loop(self):
|
def test_composite_state_literal_slice_initialized_in_loop(self):
|
||||||
|
|
||||||
def test_fn(n, x):
|
def f(n, x):
|
||||||
d = {}
|
d = {}
|
||||||
while n < 5:
|
while n < 5:
|
||||||
if n == 0:
|
if n == 0:
|
||||||
@ -225,16 +218,16 @@ class WhileStatementTest(ControlFlowTestBase):
|
|||||||
n += 1
|
n += 1
|
||||||
return d
|
return d
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
|
self.assertTransformedResult(f, (0, constant_op.constant(10)),
|
||||||
{'subkey': 14})
|
{'subkey': 14})
|
||||||
with self.converted(test_fn, control_flow, {}) as result:
|
tr = self.transform(f, control_flow)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, r"'d\['subkey'\]' must be defined before the loop"):
|
ValueError, r"'d\['subkey'\]' must be defined before the loop"):
|
||||||
result.test_fn(constant_op.constant(0), 0)
|
tr(constant_op.constant(0), 0)
|
||||||
|
|
||||||
def test_composite_state_slice_aliased_to_local(self):
|
def test_composite_state_slice_aliased_to_local(self):
|
||||||
|
|
||||||
def test_fn(n, x):
|
def f(n, x):
|
||||||
d = {}
|
d = {}
|
||||||
while n < 5:
|
while n < 5:
|
||||||
k = 'subkey'
|
k = 'subkey'
|
||||||
@ -242,15 +235,15 @@ class WhileStatementTest(ControlFlowTestBase):
|
|||||||
n += 1
|
n += 1
|
||||||
return d
|
return d
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, (0, constant_op.constant(10)),
|
self.assertTransformedResult(f, (0, constant_op.constant(10)),
|
||||||
{'subkey': 11})
|
{'subkey': 11})
|
||||||
with self.converted(test_fn, control_flow, {}) as result:
|
tr = self.transform(f, control_flow)
|
||||||
# TODO(b/136999953): Better error message.
|
# TODO(b/136999953): Better error message.
|
||||||
# Note that this error happens at execution time.
|
# Note that this error happens at execution time.
|
||||||
with self.assertRaises(errors.InaccessibleTensorError):
|
with self.assertRaises(errors.InaccessibleTensorError):
|
||||||
graph_fn = def_function.function(result.test_fn, autograph=False)
|
graph_fn = def_function.function(tr, autograph=False)
|
||||||
self.evaluate(
|
self.evaluate(
|
||||||
graph_fn(constant_op.constant(0), constant_op.constant(5)))
|
graph_fn(constant_op.constant(0), constant_op.constant(5)))
|
||||||
|
|
||||||
def test_local_composite_attr(self):
|
def test_local_composite_attr(self):
|
||||||
|
|
||||||
@ -259,19 +252,18 @@ class WhileStatementTest(ControlFlowTestBase):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.x = constant_op.constant(3)
|
self.x = constant_op.constant(3)
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
while n > 0:
|
while n > 0:
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
tc.x = tc.x
|
tc.x = tc.x
|
||||||
n -= 1
|
n -= 1
|
||||||
return n
|
return n
|
||||||
|
|
||||||
self.assertTransformedResult(
|
self.assertTransformedResult(f, constant_op.constant(5), 0)
|
||||||
test_fn, constant_op.constant(5), 0, symbols={'TestClass': TestClass})
|
|
||||||
|
|
||||||
def test_local_composite_slice(self):
|
def test_local_composite_slice(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
while n > 0:
|
while n > 0:
|
||||||
d = {'x': n}
|
d = {'x': n}
|
||||||
k = 'x'
|
k = 'x'
|
||||||
@ -279,26 +271,26 @@ class WhileStatementTest(ControlFlowTestBase):
|
|||||||
n -= 1
|
n -= 1
|
||||||
return n
|
return n
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(5), 0, {})
|
self.assertTransformedResult(f, constant_op.constant(5), 0)
|
||||||
|
|
||||||
def test_local_composite_literal_slice(self):
|
def test_local_composite_literal_slice(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
while n > 0:
|
while n > 0:
|
||||||
d = {'x': n}
|
d = {'x': n}
|
||||||
d['x'] = d['x']
|
d['x'] = d['x']
|
||||||
n -= 1
|
n -= 1
|
||||||
return n
|
return n
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(5), 0, {})
|
self.assertTransformedResult(f, constant_op.constant(5), 0)
|
||||||
|
|
||||||
def test_non_tensor_state(self):
|
def test_non_tensor_state(self):
|
||||||
|
|
||||||
# This class is ok to be in a tf.while_loop's state.
|
# This class is ok to be in a tf.while's state.
|
||||||
class TestClass(collections.namedtuple('TestClass', ('x'))):
|
class TestClass(collections.namedtuple('TestClass', ('x'))):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
tc = TestClass([constant_op.constant(0)])
|
tc = TestClass([constant_op.constant(0)])
|
||||||
while n > 0:
|
while n > 0:
|
||||||
tc = TestClass([constant_op.constant(3)])
|
tc = TestClass([constant_op.constant(3)])
|
||||||
@ -306,9 +298,7 @@ class WhileStatementTest(ControlFlowTestBase):
|
|||||||
n -= 1
|
n -= 1
|
||||||
return tc.x[0]
|
return tc.x[0]
|
||||||
|
|
||||||
ns = {'TestClass': TestClass, 'constant_op': constant_op}
|
self.assertTransformedResult(f, constant_op.constant(5), 4)
|
||||||
self.assertTransformedResult(
|
|
||||||
test_fn, constant_op.constant(5), 4, symbols=ns)
|
|
||||||
|
|
||||||
def test_non_tensor_state_illegal_type(self):
|
def test_non_tensor_state_illegal_type(self):
|
||||||
|
|
||||||
@ -317,20 +307,20 @@ class WhileStatementTest(ControlFlowTestBase):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.x = [constant_op.constant(3)]
|
self.x = [constant_op.constant(3)]
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
while n > 0:
|
while n > 0:
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
tc.x[0] = tc.x[0] + 1
|
tc.x[0] = tc.x[0] + 1
|
||||||
n -= 1
|
n -= 1
|
||||||
return tc.x[0]
|
return tc.x[0]
|
||||||
|
|
||||||
with self.converted(
|
tr = self.transform(f, control_flow)
|
||||||
test_fn, control_flow, {'TestClass': TestClass}) as result:
|
|
||||||
# The tested function would require `tc` to become part of the while loop
|
# The tested function would require `tc` to become part of the while loop
|
||||||
# state, but TensorFlow doesn't support classes at the moment.
|
# state, but TensorFlow doesn't support classes at the moment.
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegex(
|
||||||
ValueError, 'tc.*must be defined before the loop'):
|
ValueError, 'tc.*must be defined before the loop'):
|
||||||
result.test_fn(constant_op.constant(5))
|
tr(constant_op.constant(5))
|
||||||
|
|
||||||
def test_dispatches_by_cond_only(self):
|
def test_dispatches_by_cond_only(self):
|
||||||
|
|
||||||
@ -343,27 +333,27 @@ class WhileStatementTest(ControlFlowTestBase):
|
|||||||
def __add__(self, other):
|
def __add__(self, other):
|
||||||
return TensorIncompatibleNumeric(self.val + other)
|
return TensorIncompatibleNumeric(self.val + other)
|
||||||
|
|
||||||
def test_fn(n, s):
|
def f(n, s):
|
||||||
while n > 0:
|
while n > 0:
|
||||||
n -= 1
|
n -= 1
|
||||||
s += n
|
s += n
|
||||||
return s
|
return s
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, (constant_op.constant(5), 0), 10)
|
self.assertTransformedResult(f, (constant_op.constant(5), 0), 10)
|
||||||
with self.converted(test_fn, control_flow, {}) as result:
|
tr = self.transform(f, control_flow)
|
||||||
# n alone controls the staging. When the loop is not staged, Python
|
# n alone controls the staging. When the loop is not staged, Python
|
||||||
# knows how to add the two objects. But when staged, tf.while_loop will
|
# knows how to add the two objects. But when staged, tf.while will
|
||||||
# not know how to deal with the TensorIncompatibleNumeric object.
|
# not know how to deal with the TensorIncompatibleNumeric object.
|
||||||
self.assertEqual(result.test_fn(5, TensorIncompatibleNumeric(0)).val, 10)
|
self.assertEqual(tr(5, TensorIncompatibleNumeric(0)).val, 10)
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
result.test_fn(constant_op.constant(5), TensorIncompatibleNumeric(0))
|
tr(constant_op.constant(5), TensorIncompatibleNumeric(0))
|
||||||
|
|
||||||
|
|
||||||
class IfStatementTest(ControlFlowTestBase):
|
class IfStatementTest(ControlFlowTestBase):
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
a = 0
|
a = 0
|
||||||
b = 0
|
b = 0
|
||||||
if n > 0:
|
if n > 0:
|
||||||
@ -372,20 +362,20 @@ class IfStatementTest(ControlFlowTestBase):
|
|||||||
b = 2 * n
|
b = 2 * n
|
||||||
return a, b
|
return a, b
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0))
|
self.assertTransformedResult(f, constant_op.constant(1), (-1, 0))
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2))
|
self.assertTransformedResult(f, constant_op.constant(-1), (0, -2))
|
||||||
|
|
||||||
def test_sparse_tensor(self):
|
def test_sparse_tensor(self):
|
||||||
|
|
||||||
def test_fn(cond, a):
|
def f(cond, a):
|
||||||
if cond:
|
if cond:
|
||||||
a = -a
|
a = -a
|
||||||
return a
|
return a
|
||||||
|
|
||||||
st = sparse_tensor.SparseTensor(
|
st = sparse_tensor.SparseTensor(
|
||||||
indices=((0,),), values=(0,), dense_shape=(1,))
|
indices=((0,),), values=(0,), dense_shape=(1,))
|
||||||
self.assertTransformedResult(test_fn, (st, constant_op.constant(1)), -1)
|
self.assertTransformedResult(f, (st, constant_op.constant(1)), -1)
|
||||||
self.assertTransformedResult(test_fn, (None, constant_op.constant(1)), 1)
|
self.assertTransformedResult(f, (None, constant_op.constant(1)), 1)
|
||||||
|
|
||||||
def test_complex_outputs(self):
|
def test_complex_outputs(self):
|
||||||
|
|
||||||
@ -395,7 +385,7 @@ class IfStatementTest(ControlFlowTestBase):
|
|||||||
self.a = a
|
self.a = a
|
||||||
self.b = b
|
self.b = b
|
||||||
|
|
||||||
def test_fn(n, obj):
|
def f(n, obj):
|
||||||
obj.a = 0
|
obj.a = 0
|
||||||
obj.b = 0
|
obj.b = 0
|
||||||
if n > 0:
|
if n > 0:
|
||||||
@ -404,94 +394,94 @@ class IfStatementTest(ControlFlowTestBase):
|
|||||||
obj.b = 2 * n
|
obj.b = 2 * n
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
with self.converted(test_fn, control_flow, {}) as result:
|
tr = self.transform(f, control_flow)
|
||||||
res_obj = result.test_fn(constant_op.constant(1), TestClass(0, 0))
|
|
||||||
self.assertValuesEqual((res_obj.a, res_obj.b), (-1, 0))
|
res_obj = tr(constant_op.constant(1), TestClass(0, 0))
|
||||||
res_obj = result.test_fn(constant_op.constant(-1), TestClass(0, 0))
|
self.assertValuesEqual((res_obj.a, res_obj.b), (-1, 0))
|
||||||
self.assertValuesEqual((res_obj.a, res_obj.b), (0, -2))
|
res_obj = tr(constant_op.constant(-1), TestClass(0, 0))
|
||||||
|
self.assertValuesEqual((res_obj.a, res_obj.b), (0, -2))
|
||||||
|
|
||||||
def test_single_output(self):
|
def test_single_output(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
if n > 0:
|
if n > 0:
|
||||||
n = -n
|
n = -n
|
||||||
return n
|
return n
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(1), -1)
|
self.assertTransformedResult(f, constant_op.constant(1), -1)
|
||||||
|
|
||||||
def test_unbalanced(self):
|
def test_unbalanced(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
if n > 0:
|
if n > 0:
|
||||||
n = 3
|
n = 3
|
||||||
return n
|
return n
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(2), 3)
|
self.assertTransformedResult(f, constant_op.constant(2), 3)
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(-3), -3)
|
self.assertTransformedResult(f, constant_op.constant(-3), -3)
|
||||||
|
|
||||||
def test_unbalanced_raising(self):
|
def test_unbalanced_raising(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
if n > 0:
|
if n > 0:
|
||||||
n = n + 1
|
n = n + 1
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
return n
|
return n
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, -3, -3)
|
self.assertTransformedResult(f, -3, -3)
|
||||||
|
|
||||||
with self.converted(test_fn, control_flow, {}) as result:
|
tr = self.transform(f, control_flow)
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
result.test_fn(1)
|
with self.assertRaises(ValueError):
|
||||||
|
tr(1)
|
||||||
|
|
||||||
def test_local_var(self):
|
def test_local_var(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
if n > 0:
|
if n > 0:
|
||||||
b = 4
|
b = 4
|
||||||
n = b + 1
|
n = b + 1
|
||||||
return n
|
return n
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(1), 5)
|
self.assertTransformedResult(f, constant_op.constant(1), 5)
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
|
self.assertTransformedResult(f, constant_op.constant(-1), -1)
|
||||||
|
|
||||||
def test_local_remains_local(self):
|
def test_local_remains_local(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
if n > 0:
|
if n > 0:
|
||||||
b = 4
|
b = 4
|
||||||
n = b + 1
|
n = b + 1
|
||||||
return n
|
return n
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(1), 5)
|
self.assertTransformedResult(f, constant_op.constant(1), 5)
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
|
self.assertTransformedResult(f, constant_op.constant(-1), -1)
|
||||||
|
|
||||||
def test_no_outputs(self):
|
def test_no_outputs(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
if n > 0:
|
if n > 0:
|
||||||
b = 4 # pylint:disable=unused-variable
|
b = 4 # pylint:disable=unused-variable
|
||||||
return n
|
return n
|
||||||
|
|
||||||
# Without side effect guards, the if statement will stage a cond,
|
self.assertTransformedResult(f, constant_op.constant(1), 1)
|
||||||
# but that will be pruned at execution.
|
self.assertTransformedResult(f, constant_op.constant(-1), -1)
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(1), 1)
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
|
|
||||||
|
|
||||||
def test_created_outputs(self):
|
def test_created_outputs(self):
|
||||||
|
|
||||||
def test_fn(i):
|
def f(i):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
result = i - 1
|
result = i - 1
|
||||||
else:
|
else:
|
||||||
result = i + 1
|
result = i + 1
|
||||||
return result
|
return result
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, 0, -1)
|
self.assertTransformedResult(f, 0, -1)
|
||||||
self.assertTransformedResult(test_fn, 1, 2)
|
self.assertTransformedResult(f, 1, 2)
|
||||||
|
|
||||||
def test_created_loop_local_outputs(self):
|
def test_created_loop_local_outputs(self):
|
||||||
|
|
||||||
def test_fn(n, x):
|
def f(n, x):
|
||||||
for i in n:
|
for i in n:
|
||||||
if i == 0:
|
if i == 0:
|
||||||
result = i - 1
|
result = i - 1
|
||||||
@ -501,11 +491,11 @@ class IfStatementTest(ControlFlowTestBase):
|
|||||||
x += 1
|
x += 1
|
||||||
return x
|
return x
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, (range(5), 10), 14)
|
self.assertTransformedResult(f, (range(5), 10), 14)
|
||||||
|
|
||||||
def test_created_loop_variable(self):
|
def test_created_loop_variable(self):
|
||||||
|
|
||||||
def test_fn(n, x):
|
def f(n, x):
|
||||||
for i in n:
|
for i in n:
|
||||||
if i == 0:
|
if i == 0:
|
||||||
result = i - 1
|
result = i - 1
|
||||||
@ -514,22 +504,26 @@ class IfStatementTest(ControlFlowTestBase):
|
|||||||
x += 1
|
x += 1
|
||||||
return x
|
return x
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, (range(5), 10), 14)
|
self.assertTransformedResult(f, (range(5), 10), 14)
|
||||||
|
|
||||||
def test_unaffected_global(self):
|
def test_unaffected_global(self):
|
||||||
|
|
||||||
def test_fn(i):
|
global for_unaffected_global
|
||||||
global g # pylint:disable=global-variable-undefined
|
for_unaffected_global = 3
|
||||||
if i == 0:
|
|
||||||
g = i - 1
|
|
||||||
return g
|
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, 1, 3, symbols={'g': 3})
|
def f(i):
|
||||||
self.assertTransformedResult(test_fn, 0, -1, symbols={'g': 3})
|
global for_unaffected_global
|
||||||
|
if i == 0:
|
||||||
|
for_unaffected_global = i - 1
|
||||||
|
return for_unaffected_global
|
||||||
|
|
||||||
|
self.assertTransformedResult(f, 1, 3)
|
||||||
|
self.assertTransformedResult(f, 0, -1)
|
||||||
|
self.assertEqual(for_unaffected_global, -1)
|
||||||
|
|
||||||
def test_unaffected_nonlocal(self):
|
def test_unaffected_nonlocal(self):
|
||||||
|
|
||||||
def test_fn(i):
|
def f(i):
|
||||||
def inner_fn():
|
def inner_fn():
|
||||||
nonlocal n
|
nonlocal n
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@ -539,12 +533,12 @@ class IfStatementTest(ControlFlowTestBase):
|
|||||||
inner_fn()
|
inner_fn()
|
||||||
return n
|
return n
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, 1, 3)
|
self.assertTransformedResult(f, 1, 3)
|
||||||
self.assertTransformedResult(test_fn, 0, -1)
|
self.assertTransformedResult(f, 0, -1)
|
||||||
|
|
||||||
def test_output_defined_in_prior_except(self):
|
def test_output_defined_in_prior_except(self):
|
||||||
|
|
||||||
def test_fn(i):
|
def f(i):
|
||||||
try:
|
try:
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -553,8 +547,8 @@ class IfStatementTest(ControlFlowTestBase):
|
|||||||
x = i - 1
|
x = i - 1
|
||||||
return x
|
return x
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, 1, 1)
|
self.assertTransformedResult(f, 1, 1)
|
||||||
self.assertTransformedResult(test_fn, 0, -1)
|
self.assertTransformedResult(f, 0, -1)
|
||||||
|
|
||||||
def test_unbalanced_multiple_composites(self):
|
def test_unbalanced_multiple_composites(self):
|
||||||
|
|
||||||
@ -564,7 +558,7 @@ class IfStatementTest(ControlFlowTestBase):
|
|||||||
self.b = 2
|
self.b = 2
|
||||||
self.c = 3
|
self.c = 3
|
||||||
|
|
||||||
def test_fn(x, condition):
|
def f(x, condition):
|
||||||
|
|
||||||
z = 5
|
z = 5
|
||||||
if condition:
|
if condition:
|
||||||
@ -574,9 +568,9 @@ class IfStatementTest(ControlFlowTestBase):
|
|||||||
|
|
||||||
return x.b, x.c, z
|
return x.b, x.c, z
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(True)),
|
self.assertTransformedResult(f, (Foo(), constant_op.constant(True)),
|
||||||
(7, 11, 13))
|
(7, 11, 13))
|
||||||
self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(False)),
|
self.assertTransformedResult(f, (Foo(), constant_op.constant(False)),
|
||||||
(2, 3, 5))
|
(2, 3, 5))
|
||||||
|
|
||||||
def test_unbalanced_composite(self):
|
def test_unbalanced_composite(self):
|
||||||
@ -586,7 +580,7 @@ class IfStatementTest(ControlFlowTestBase):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.b = 2
|
self.b = 2
|
||||||
|
|
||||||
def test_fn(x, condition):
|
def f(x, condition):
|
||||||
|
|
||||||
z = 5
|
z = 5
|
||||||
if condition:
|
if condition:
|
||||||
@ -595,9 +589,9 @@ class IfStatementTest(ControlFlowTestBase):
|
|||||||
|
|
||||||
return x.b, z
|
return x.b, z
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(True)),
|
self.assertTransformedResult(f, (Foo(), constant_op.constant(True)),
|
||||||
(7, 13))
|
(7, 13))
|
||||||
self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(False)),
|
self.assertTransformedResult(f, (Foo(), constant_op.constant(False)),
|
||||||
(2, 5))
|
(2, 5))
|
||||||
|
|
||||||
|
|
||||||
@ -605,7 +599,7 @@ class ForStatementTest(ControlFlowTestBase):
|
|||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
s1 = 0
|
s1 = 0
|
||||||
s2 = 0
|
s2 = 0
|
||||||
for e in l:
|
for e in l:
|
||||||
@ -613,21 +607,21 @@ class ForStatementTest(ControlFlowTestBase):
|
|||||||
s2 += e * e
|
s2 += e * e
|
||||||
return s1, s2
|
return s1, s2
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), (4, 10))
|
self.assertTransformedResult(f, constant_op.constant([1, 3]), (4, 10))
|
||||||
empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
|
empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
|
||||||
self.assertTransformedResult(test_fn, empty_vector, (0, 0))
|
self.assertTransformedResult(f, empty_vector, (0, 0))
|
||||||
|
|
||||||
def test_single_output(self):
|
def test_single_output(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
s = 0
|
s = 0
|
||||||
for e in l:
|
for e in l:
|
||||||
s += e
|
s += e
|
||||||
return s
|
return s
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), 4)
|
self.assertTransformedResult(f, constant_op.constant([1, 3]), 4)
|
||||||
empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
|
empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
|
||||||
self.assertTransformedResult(test_fn, empty_vector, 0)
|
self.assertTransformedResult(f, empty_vector, 0)
|
||||||
|
|
||||||
def test_iterated_expression(self):
|
def test_iterated_expression(self):
|
||||||
|
|
||||||
@ -637,26 +631,23 @@ class ForStatementTest(ControlFlowTestBase):
|
|||||||
eval_count[0] += 1
|
eval_count[0] += 1
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
s = 0
|
s = 0
|
||||||
for e in count_evals(range(n)):
|
for e in count_evals(range(n)):
|
||||||
s += e
|
s += e
|
||||||
return s
|
return s
|
||||||
|
|
||||||
ns = {'count_evals': count_evals}
|
tr = self.transform(f, control_flow)
|
||||||
node, ctx = self.prepare(test_fn, ns)
|
|
||||||
node = control_flow.transform(node, ctx)
|
|
||||||
|
|
||||||
with self.compiled(node, ns) as result:
|
self.assertEqual(tr(5), 10)
|
||||||
self.assertEqual(result.test_fn(5), 10)
|
self.assertEqual(eval_count[0], 1)
|
||||||
self.assertEqual(eval_count[0], 1)
|
|
||||||
|
|
||||||
def test_composite_state_initialized_in_loop(self):
|
def test_composite_state_initialized_in_loop(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_fn(n, x):
|
def f(n, x):
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
for i in n:
|
for i in n:
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@ -665,37 +656,97 @@ class ForStatementTest(ControlFlowTestBase):
|
|||||||
tc.x = tc.x + i
|
tc.x = tc.x + i
|
||||||
return tc.x
|
return tc.x
|
||||||
|
|
||||||
self.assertTransformedResult(
|
self.assertTransformedResult(f, (range(5), constant_op.constant(10)), 20)
|
||||||
test_fn, (range(5), constant_op.constant(10)),
|
tr = self.transform(f, control_flow)
|
||||||
20,
|
|
||||||
symbols={'TestClass': TestClass})
|
with self.assertRaisesRegex(
|
||||||
with self.converted(
|
ValueError, "'tc.x' must be defined before the loop"):
|
||||||
test_fn, control_flow, {'TestClass': TestClass}) as result:
|
tr(constant_op.constant(list(range(5))), 0)
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, "'tc.x' must be defined before the loop"):
|
|
||||||
result.test_fn(constant_op.constant(list(range(5))), 0)
|
|
||||||
|
|
||||||
def test_tuple_unpacking(self):
|
def test_tuple_unpacking(self):
|
||||||
def test_fn(x_list):
|
|
||||||
z = tf.constant(0) # pylint:disable=undefined-variable
|
def f(x_list):
|
||||||
|
z = constant_op.constant(0) # pylint:disable=undefined-variable
|
||||||
for i, x in enumerate(x_list):
|
for i, x in enumerate(x_list):
|
||||||
z = z + x + i
|
z = z + x + i
|
||||||
return z
|
return z
|
||||||
|
|
||||||
self.assertTransformedResult(test_fn, [3, 3], 7)
|
self.assertTransformedResult(f, [3, 3], 7)
|
||||||
|
|
||||||
def test_with_comprehension_in_body(self):
|
def test_with_comprehension_in_body(self):
|
||||||
|
|
||||||
def test_fn(l, n):
|
def f(l, n):
|
||||||
s = constant_op.constant(list(range(n)))
|
s = constant_op.constant(list(range(n)))
|
||||||
for _ in l:
|
for _ in l:
|
||||||
s += constant_op.constant([a for a in range(n)])
|
s += constant_op.constant([a for a in range(n)])
|
||||||
return s
|
return s
|
||||||
|
|
||||||
self.assertTransformedResult(
|
self.assertTransformedResult(f, (constant_op.constant([1, 2, 3]), 5),
|
||||||
test_fn, (constant_op.constant([1, 2, 3]), 5),
|
np.array(range(5)) * 4)
|
||||||
np.array(range(5)) * 4,
|
|
||||||
symbols={'constant_op': constant_op})
|
|
||||||
|
class AdvancedControlFlowTest(ControlFlowTestBase):
|
||||||
|
|
||||||
|
def assertTransformedEquivalent(self, f, *inputs):
|
||||||
|
tr = self.transform(
|
||||||
|
f, (break_statements, continue_statements, control_flow))
|
||||||
|
self.assertEqual(f(*inputs), tr(*inputs))
|
||||||
|
|
||||||
|
def test_while_with_else(self):
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
while x > 2:
|
||||||
|
x /= 2
|
||||||
|
else:
|
||||||
|
x += 1
|
||||||
|
return x
|
||||||
|
|
||||||
|
self.assertTransformedEquivalent(f, 4)
|
||||||
|
self.assertTransformedEquivalent(f, 2)
|
||||||
|
|
||||||
|
def test_while_with_else_and_break(self):
|
||||||
|
|
||||||
|
def f(cond1):
|
||||||
|
x = 8
|
||||||
|
while x > 2:
|
||||||
|
x /= 2
|
||||||
|
if cond1:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
x += 1
|
||||||
|
return x
|
||||||
|
|
||||||
|
self.assertTransformedEquivalent(f, True)
|
||||||
|
self.assertTransformedEquivalent(f, False)
|
||||||
|
|
||||||
|
def test_for_with_else(self):
|
||||||
|
|
||||||
|
def f(l):
|
||||||
|
res = 0
|
||||||
|
for x in l:
|
||||||
|
res += x
|
||||||
|
else:
|
||||||
|
res += 1
|
||||||
|
return res
|
||||||
|
|
||||||
|
self.assertTransformedEquivalent(f, [])
|
||||||
|
self.assertTransformedEquivalent(f, [1, 2])
|
||||||
|
|
||||||
|
def test_for_with_else_and_break(self):
|
||||||
|
|
||||||
|
def f(flag):
|
||||||
|
l = [1, 2, 3]
|
||||||
|
res = 0
|
||||||
|
for x in l:
|
||||||
|
res += x
|
||||||
|
if flag:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
res += 1
|
||||||
|
return res
|
||||||
|
|
||||||
|
self.assertTransformedEquivalent(f, True)
|
||||||
|
self.assertTransformedEquivalent(f, False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -22,7 +22,6 @@ from tensorflow.python.autograph.converters import directives as directives_conv
|
|||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
from tensorflow.python.autograph.lang import directives
|
from tensorflow.python.autograph.lang import directives
|
||||||
from tensorflow.python.autograph.pyct import anno
|
from tensorflow.python.autograph.pyct import anno
|
||||||
from tensorflow.python.autograph.pyct import parser
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -30,13 +29,12 @@ class DirectivesTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
def test_local_target(self):
|
def test_local_target(self):
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
l = []
|
l = []
|
||||||
string_var = 0
|
string_var = 0
|
||||||
directives.set_element_type(l, 'a', string_var)
|
directives.set_element_type(l, 'a', string_var)
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {'directives': directives})
|
_, node, _ = self.transform(f, directives_converter, include_ast=True)
|
||||||
node = directives_converter.transform(node, ctx)
|
|
||||||
|
|
||||||
def_, = anno.getanno(node.body[0].targets[0],
|
def_, = anno.getanno(node.body[0].targets[0],
|
||||||
anno.Static.DEFINITIONS)
|
anno.Static.DEFINITIONS)
|
||||||
@ -46,11 +44,11 @@ class DirectivesTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
def test_argument_target(self):
|
def test_argument_target(self):
|
||||||
|
|
||||||
def test_fn(a):
|
def f(a):
|
||||||
directives.set_element_type(a, 1, shape=2)
|
directives.set_element_type(a, 1, shape=2)
|
||||||
|
pass
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {'directives': directives})
|
_, node, _ = self.transform(f, directives_converter, include_ast=True)
|
||||||
node = directives_converter.transform(node, ctx)
|
|
||||||
|
|
||||||
def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
|
def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
|
||||||
d = def_.directives[directives.set_element_type]
|
d = def_.directives[directives.set_element_type]
|
||||||
@ -59,13 +57,13 @@ class DirectivesTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
def test_loop_target(self):
|
def test_loop_target(self):
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
a = True
|
a = True
|
||||||
while True:
|
while True:
|
||||||
directives.set_loop_options(parallel_iterations=10, back_prop=a)
|
directives.set_loop_options(parallel_iterations=10, back_prop=a)
|
||||||
|
pass
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {'directives': directives})
|
_, node, _ = self.transform(f, directives_converter, include_ast=True)
|
||||||
node = directives_converter.transform(node, ctx)
|
|
||||||
|
|
||||||
d = anno.getanno(node.body[1], anno.Basic.DIRECTIVES)
|
d = anno.getanno(node.body[1], anno.Basic.DIRECTIVES)
|
||||||
d = d[directives.set_loop_options]
|
d = d[directives.set_loop_options]
|
||||||
@ -75,40 +73,23 @@ class DirectivesTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
def test_loop_target_no_loop(self):
|
def test_loop_target_no_loop(self):
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
directives.set_loop_options()
|
directives.set_loop_options()
|
||||||
|
pass
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {'directives': directives})
|
|
||||||
with self.assertRaisesRegexp(ValueError, 'must be used inside a statement'):
|
with self.assertRaisesRegexp(ValueError, 'must be used inside a statement'):
|
||||||
node = directives_converter.transform(node, ctx)
|
self.transform(f, directives_converter, include_ast=True)
|
||||||
|
|
||||||
def test_loop_target_not_first(self):
|
def test_loop_target_not_first(self):
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
a = 1
|
a = 1
|
||||||
while True:
|
while True:
|
||||||
a = 2
|
a = 2
|
||||||
directives.set_loop_options(parallel_iterations=10, back_prop=a)
|
directives.set_loop_options(parallel_iterations=10, back_prop=a)
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {'directives': directives})
|
|
||||||
with self.assertRaisesRegexp(ValueError, 'must be the first statement'):
|
with self.assertRaisesRegexp(ValueError, 'must be the first statement'):
|
||||||
node = directives_converter.transform(node, ctx)
|
self.transform(f, directives_converter, include_ast=True)
|
||||||
|
|
||||||
def test_invalid_default(self):
|
|
||||||
|
|
||||||
def invalid_directive(valid_arg, invalid_default=object()):
|
|
||||||
del valid_arg
|
|
||||||
del invalid_default
|
|
||||||
return
|
|
||||||
|
|
||||||
def call_invalid_directive():
|
|
||||||
invalid_directive(1)
|
|
||||||
|
|
||||||
node, _ = parser.parse_entity(call_invalid_directive, ())
|
|
||||||
# Find the call to the invalid directive
|
|
||||||
node = node.body[0].value
|
|
||||||
with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):
|
|
||||||
directives_converter._map_args(node, invalid_directive)
|
|
||||||
|
|
||||||
def test_value_verification_does_not_trigger_properties(self):
|
def test_value_verification_does_not_trigger_properties(self):
|
||||||
|
|
||||||
@ -122,11 +103,11 @@ class DirectivesTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
return tc.b + 1
|
return tc.b + 1
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {'tc': tc})
|
_, node, _ = self.transform(f, directives_converter, include_ast=True)
|
||||||
node = directives_converter.transform(node, ctx)
|
|
||||||
self.assertIsNotNone(node)
|
self.assertIsNotNone(node)
|
||||||
|
|
||||||
def test_value_verification_does_not_trigger_getattr(self):
|
def test_value_verification_does_not_trigger_getattr(self):
|
||||||
@ -143,11 +124,11 @@ class DirectivesTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
return tc.b + 1
|
return tc.b + 1
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {'tc': tc})
|
_, node, _ = self.transform(f, directives_converter, include_ast=True)
|
||||||
node = directives_converter.transform(node, ctx)
|
|
||||||
self.assertIsNotNone(node)
|
self.assertIsNotNone(node)
|
||||||
self.assertFalse(tc.getattr_called)
|
self.assertFalse(tc.getattr_called)
|
||||||
|
|
||||||
|
@ -23,51 +23,49 @@ from tensorflow.python.autograph.converters import return_statements
|
|||||||
from tensorflow.python.autograph.core import ag_ctx
|
from tensorflow.python.autograph.core import ag_ctx
|
||||||
from tensorflow.python.autograph.core import converter
|
from tensorflow.python.autograph.core import converter
|
||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
|
from tensorflow.python.autograph.impl import api
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
|
||||||
from tensorflow.python.framework import test_util
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class FunctionTransformer(converter_testing.TestCase):
|
class FunctionTransformer(converter_testing.TestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
"""Docstring."""
|
"""Docstring."""
|
||||||
a = 1
|
a = 1
|
||||||
l += a
|
l += a
|
||||||
return l
|
return l
|
||||||
|
|
||||||
with self.converted(test_fn, functions, {}) as result:
|
tr = self.transform(f, functions)
|
||||||
result_op = result.test_fn(constant_op.constant(1))
|
|
||||||
self.assertIn('test_fn/', result_op.op.name)
|
result_op = tr(constant_op.constant(1))
|
||||||
self.assertEqual('Docstring.', result.test_fn.__doc__)
|
self.assertIn('f/', result_op.op.name)
|
||||||
|
self.assertEqual('Docstring.', tr.__doc__)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_multiline_docstring(self):
|
def test_multiline_docstring(self):
|
||||||
|
|
||||||
tf = None
|
def f():
|
||||||
|
|
||||||
def test_fn():
|
|
||||||
"""First sentence.
|
"""First sentence.
|
||||||
|
|
||||||
Second sentence.
|
Second sentence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Something.
|
||||||
"""
|
"""
|
||||||
return tf.constant(1)
|
return constant_op.constant(1)
|
||||||
|
|
||||||
with self.converted(test_fn, functions, {},
|
tr = self.transform(f, functions)
|
||||||
(constant_op.constant,)) as result:
|
|
||||||
result_op = result.test_fn()
|
result_op = tr()
|
||||||
self.assertIn('test_fn/', result_op.op.name)
|
self.assertIn('f/', result_op.op.name)
|
||||||
self.assertIn('First sentence.', result.test_fn.__doc__)
|
self.assertIn('First sentence.', tr.__doc__)
|
||||||
self.assertIn('Second sentence.', result.test_fn.__doc__)
|
self.assertIn('Second sentence.', tr.__doc__)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_nested_functions(self):
|
def test_nested_functions(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
|
|
||||||
def inner_fn(i):
|
def inner_fn(i):
|
||||||
return i + 1
|
return i + 1
|
||||||
@ -75,41 +73,35 @@ class FunctionTransformer(converter_testing.TestCase):
|
|||||||
l += 1
|
l += 1
|
||||||
return l, inner_fn(l)
|
return l, inner_fn(l)
|
||||||
|
|
||||||
with self.converted(test_fn, (functions, return_statements), {},
|
tr = self.transform(f, (functions, return_statements))
|
||||||
(ops.name_scope,)) as result:
|
|
||||||
first, second = result.test_fn(constant_op.constant(1))
|
first, second = tr(constant_op.constant(1))
|
||||||
self.assertIn('test_fn/', first.op.name)
|
self.assertIn('f/', first.op.name)
|
||||||
self.assertNotIn('inner_fn', first.op.name)
|
self.assertNotIn('inner_fn', first.op.name)
|
||||||
self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name)
|
self.assertIn('f/inner_fn/', second.op.inputs[0].name)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_conversion_context_preserves_in_inner_functions(self):
|
def test_conversion_context_preserves_in_inner_functions(self):
|
||||||
|
|
||||||
def inner_fn_callee():
|
def inner_fn_callee():
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
ag_ctx.control_status_ctx().status, ag_ctx.Status.DISABLED)
|
ag_ctx.control_status_ctx().status, ag_ctx.Status.DISABLED)
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
def inner_fn():
|
def inner_fn():
|
||||||
inner_fn_callee()
|
inner_fn_callee()
|
||||||
with ag_ctx.ControlStatusCtx(
|
with ag_ctx.ControlStatusCtx(
|
||||||
ag_ctx.Status.DISABLED, converter.ConversionOptions(recursive=True)):
|
ag_ctx.Status.DISABLED, converter.ConversionOptions(recursive=True)):
|
||||||
inner_fn()
|
inner_fn()
|
||||||
|
|
||||||
ns = {
|
tr = self.transform(f, functions)
|
||||||
'inner_fn_callee': inner_fn_callee,
|
|
||||||
'ag_ctx': ag_ctx,
|
tr()
|
||||||
'converter': converter
|
|
||||||
}
|
|
||||||
with self.converted(test_fn, functions, ns) as result:
|
|
||||||
result.test_fn()
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_method(self):
|
def test_method(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
|
|
||||||
def test_fn(self, l):
|
def f(self, l):
|
||||||
|
|
||||||
def inner_fn(i):
|
def inner_fn(i):
|
||||||
return i + 1
|
return i + 1
|
||||||
@ -117,25 +109,22 @@ class FunctionTransformer(converter_testing.TestCase):
|
|||||||
l += 1
|
l += 1
|
||||||
return l, inner_fn(l)
|
return l, inner_fn(l)
|
||||||
|
|
||||||
ns = {'TestClass': TestClass}
|
tr = self.transform(TestClass.f, (functions, return_statements))
|
||||||
node, ctx = self.prepare(TestClass, ns)
|
|
||||||
node = functions.transform(node, ctx)
|
|
||||||
node = return_statements.transform(node, ctx)
|
|
||||||
|
|
||||||
with self.compiled(node, {}, (ops.name_scope,)) as result:
|
first, second = tr(TestClass(), constant_op.constant(1))
|
||||||
first, second = result.TestClass().test_fn(constant_op.constant(1))
|
self.assertIn('f/', first.op.name)
|
||||||
self.assertIn('test_fn/', first.op.name)
|
self.assertNotIn('inner_fn', first.op.name)
|
||||||
self.assertNotIn('inner_fn', first.op.name)
|
self.assertIn('f/inner_fn/', second.op.inputs[0].name)
|
||||||
self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name)
|
|
||||||
|
|
||||||
def test_lambda_in_return_value(self):
|
def test_lambda_in_return_value(self):
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
return lambda x: x + 1
|
return lambda x: x + 1
|
||||||
|
|
||||||
with self.converted(test_fn, functions, {}) as result:
|
tr = self.transform(f, functions)
|
||||||
result_l = result.test_fn()
|
|
||||||
self.assertTrue(result_l.fake_autograph_artifact)
|
result_l = tr()
|
||||||
|
self.assertTrue(api.is_autograph_artifact(result_l))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -25,36 +25,36 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
class ListCompTest(converter_testing.TestCase):
|
class ListCompTest(converter_testing.TestCase):
|
||||||
|
|
||||||
def assertTransformedEquivalent(self, test_fn, *inputs):
|
def assertTransformedEquivalent(self, f, *inputs):
|
||||||
with self.converted(test_fn, list_comprehensions, {}) as result:
|
tr = self.transform(f, list_comprehensions)
|
||||||
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
|
self.assertEqual(f(*inputs), tr(*inputs))
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
s = [e * e for e in l]
|
s = [e * e for e in l]
|
||||||
return s
|
return s
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, [])
|
self.assertTransformedEquivalent(f, [])
|
||||||
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
|
self.assertTransformedEquivalent(f, [1, 2, 3])
|
||||||
|
|
||||||
def test_multiple_generators(self):
|
def test_multiple_generators(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
s = [e * e for sublist in l for e in sublist]
|
s = [e * e for sublist in l for e in sublist] # pylint:disable=g-complex-comprehension
|
||||||
return s
|
return s
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, [])
|
self.assertTransformedEquivalent(f, [])
|
||||||
self.assertTransformedEquivalent(test_fn, [[1], [2], [3]])
|
self.assertTransformedEquivalent(f, [[1], [2], [3]])
|
||||||
|
|
||||||
def test_cond(self):
|
def test_cond(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
s = [e * e for e in l if e > 1]
|
s = [e * e for e in l if e > 1]
|
||||||
return s
|
return s
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, [])
|
self.assertTransformedEquivalent(f, [])
|
||||||
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
|
self.assertTransformedEquivalent(f, [1, 2, 3])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -18,12 +18,11 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.autograph.converters import directives as directives_converter
|
||||||
from tensorflow.python.autograph.converters import lists
|
from tensorflow.python.autograph.converters import lists
|
||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
from tensorflow.python.autograph.lang import directives
|
from tensorflow.python.autograph.lang import directives
|
||||||
from tensorflow.python.autograph.lang import special_functions
|
from tensorflow.python.autograph.lang import special_functions
|
||||||
from tensorflow.python.autograph.pyct import anno
|
|
||||||
from tensorflow.python.autograph.pyct import parser
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -31,101 +30,81 @@ from tensorflow.python.ops import list_ops
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
tf = None # Will be replaced by a mock.
|
|
||||||
|
|
||||||
|
|
||||||
class ListTest(converter_testing.TestCase):
|
class ListTest(converter_testing.TestCase):
|
||||||
|
|
||||||
def test_empty_list(self):
|
def test_empty_list(self):
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
return []
|
return []
|
||||||
|
|
||||||
with self.converted(test_fn, lists, {}) as result:
|
tr = self.transform(f, lists)
|
||||||
tl = result.test_fn()
|
|
||||||
# Empty tensor lists cannot be evaluated or stacked.
|
tl = tr()
|
||||||
self.assertTrue(isinstance(tl, ops.Tensor))
|
# Empty tensor lists cannot be evaluated or stacked.
|
||||||
self.assertEqual(tl.dtype, dtypes.variant)
|
self.assertIsInstance(tl, ops.Tensor)
|
||||||
|
self.assertEqual(tl.dtype, dtypes.variant)
|
||||||
|
|
||||||
def test_initialized_list(self):
|
def test_initialized_list(self):
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
return [1, 2, 3]
|
return [1, 2, 3]
|
||||||
|
|
||||||
with self.converted(test_fn, lists, {}) as result:
|
tr = self.transform(f, lists)
|
||||||
self.assertAllEqual(result.test_fn(), [1, 2, 3])
|
|
||||||
|
self.assertAllEqual(tr(), [1, 2, 3])
|
||||||
|
|
||||||
def test_list_append(self):
|
def test_list_append(self):
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
l = special_functions.tensor_list([1])
|
l = special_functions.tensor_list([1])
|
||||||
l.append(2)
|
l.append(2)
|
||||||
l.append(3)
|
l.append(3)
|
||||||
return l
|
return l
|
||||||
|
|
||||||
ns = {'special_functions': special_functions}
|
tr = self.transform(f, lists)
|
||||||
with self.converted(test_fn, lists, ns) as result:
|
|
||||||
with self.cached_session() as sess:
|
tl = tr()
|
||||||
tl = result.test_fn()
|
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
||||||
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
self.assertAllEqual(self.evaluate(r), [1, 2, 3])
|
||||||
self.assertAllEqual(self.evaluate(r), [1, 2, 3])
|
|
||||||
|
|
||||||
def test_list_pop(self):
|
def test_list_pop(self):
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
l = special_functions.tensor_list([1, 2, 3])
|
l = special_functions.tensor_list([1, 2, 3])
|
||||||
|
directives.set_element_type(l, dtype=dtypes.int32, shape=())
|
||||||
s = l.pop()
|
s = l.pop()
|
||||||
return s, l
|
return s, l
|
||||||
|
|
||||||
ns = {'special_functions': special_functions}
|
tr = self.transform(f, (directives_converter, lists))
|
||||||
node, ctx = self.prepare(test_fn, ns)
|
|
||||||
def_, = anno.getanno(node.body[0].targets[0],
|
|
||||||
anno.Static.ORIG_DEFINITIONS)
|
|
||||||
def_.directives[directives.set_element_type] = {
|
|
||||||
'dtype': parser.parse_expression('tf.int32'),
|
|
||||||
'shape': parser.parse_expression('()'),
|
|
||||||
}
|
|
||||||
node = lists.transform(node, ctx)
|
|
||||||
|
|
||||||
with self.compiled(node, ns, (dtypes.int32,)) as result:
|
ts, tl = tr()
|
||||||
with self.cached_session() as sess:
|
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
||||||
ts, tl = result.test_fn()
|
self.assertAllEqual(self.evaluate(r), [1, 2])
|
||||||
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
self.assertAllEqual(self.evaluate(ts), 3)
|
||||||
self.assertAllEqual(self.evaluate(r), [1, 2])
|
|
||||||
self.assertAllEqual(self.evaluate(ts), 3)
|
|
||||||
|
|
||||||
def test_double_list_pop(self):
|
def test_double_list_pop(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
s = l.pop().pop()
|
s = l.pop().pop()
|
||||||
return s
|
return s
|
||||||
|
|
||||||
with self.converted(test_fn, lists, {}) as result:
|
tr = self.transform(f, lists)
|
||||||
test_input = [1, 2, [1, 2, 3]]
|
|
||||||
# TODO(mdan): Pass a list of lists of tensor when we fully support that.
|
test_input = [1, 2, [1, 2, 3]]
|
||||||
# For now, we just pass a regular Python list of lists just to verify that
|
# TODO(mdan): Pass a list of lists of tensor when we fully support that.
|
||||||
# the two pop calls are sequenced properly.
|
# For now, we just pass a regular Python list of lists just to verify that
|
||||||
self.assertAllEqual(result.test_fn(test_input), 3)
|
# the two pop calls are sequenced properly.
|
||||||
|
self.assertAllEqual(tr(test_input), 3)
|
||||||
|
|
||||||
def test_list_stack(self):
|
def test_list_stack(self):
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
l = [1, 2, 3]
|
l = [1, 2, 3]
|
||||||
return tf.stack(l)
|
return array_ops.stack(l)
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {})
|
tr = self.transform(f, lists)
|
||||||
def_, = anno.getanno(node.body[0].targets[0],
|
|
||||||
anno.Static.ORIG_DEFINITIONS)
|
|
||||||
def_.directives[directives.set_element_type] = {
|
|
||||||
'dtype': parser.parse_expression('tf.int32')
|
|
||||||
}
|
|
||||||
node = lists.transform(node, ctx)
|
|
||||||
|
|
||||||
with self.compiled(node, {}, (array_ops.stack, dtypes.int32)) as result:
|
self.assertAllEqual(self.evaluate(tr()), [1, 2, 3])
|
||||||
with self.cached_session() as sess:
|
|
||||||
self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])
|
|
||||||
|
|
||||||
# TODO(mdan): Add a test with tf.stack with axis kwarg.
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -27,62 +27,59 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
class LogicalExpressionTest(converter_testing.TestCase):
|
class LogicalExpressionTest(converter_testing.TestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_equals(self):
|
def test_equals(self):
|
||||||
|
|
||||||
def test_fn(a, b):
|
def f(a, b):
|
||||||
return a == b
|
return a == b
|
||||||
|
|
||||||
with self.converted(test_fn, logical_expressions, {}) as result:
|
tr = self.transform(f, logical_expressions)
|
||||||
with self.cached_session() as sess:
|
|
||||||
self.assertTrue(sess.run(result.test_fn(constant_op.constant(1), 1)))
|
self.assertTrue(self.evaluate(tr(constant_op.constant(1), 1)))
|
||||||
self.assertFalse(sess.run(result.test_fn(constant_op.constant(1), 2)))
|
self.assertFalse(self.evaluate(tr(constant_op.constant(1), 2)))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_bool_ops(self):
|
def test_bool_ops(self):
|
||||||
|
|
||||||
def test_fn(a, b, c):
|
def f(a, b, c):
|
||||||
return (a or b) and (a or b or c) and not c
|
return (a or b) and (a or b or c) and not c
|
||||||
|
|
||||||
with self.converted(test_fn, logical_expressions, {}) as result:
|
tr = self.transform(f, logical_expressions)
|
||||||
with self.cached_session() as sess:
|
|
||||||
self.assertTrue(
|
self.assertTrue(self.evaluate(tr(constant_op.constant(True), False, False)))
|
||||||
sess.run(result.test_fn(constant_op.constant(True), False, False)))
|
self.assertFalse(self.evaluate(tr(constant_op.constant(True), False, True)))
|
||||||
self.assertFalse(
|
|
||||||
sess.run(result.test_fn(constant_op.constant(True), False, True)))
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def test_comparison(self):
|
def test_comparison(self):
|
||||||
|
|
||||||
def test_fn(a, b, c, d):
|
def f(a, b, c, d):
|
||||||
return a < b == c > d
|
return a < b == c > d
|
||||||
|
|
||||||
with self.converted(test_fn, logical_expressions, {}) as result:
|
tr = self.transform(f, logical_expressions)
|
||||||
with self.cached_session() as sess:
|
|
||||||
# Note: having just the first constant a tensor tests that the
|
# Note: having just the first constant a tensor tests that the
|
||||||
# operations execute in the correct order. If anything other than
|
# operations execute in the correct order. If anything other than
|
||||||
# a < b executed first, the result would be a Python scalar and not a
|
# a < b executed first, the result would be a Python scalar and not a
|
||||||
# Tensor. This is valid as long as the dispat is automatic based on
|
# Tensor. This is valid as long as the dispat is automatic based on
|
||||||
# type.
|
# type.
|
||||||
self.assertTrue(
|
self.assertTrue(self.evaluate(tr(constant_op.constant(1), 2, 2, 1)))
|
||||||
sess.run(result.test_fn(constant_op.constant(1), 2, 2, 1)))
|
self.assertFalse(self.evaluate(tr(constant_op.constant(1), 2, 2, 3)))
|
||||||
self.assertFalse(
|
|
||||||
sess.run(result.test_fn(constant_op.constant(1), 2, 2, 3)))
|
|
||||||
|
|
||||||
def test_default_ops(self):
|
def test_default_ops(self):
|
||||||
|
|
||||||
def test_fn(a, b):
|
def f(a, b):
|
||||||
return a in b
|
return a in b
|
||||||
|
|
||||||
with self.converted(test_fn, logical_expressions, {}) as result:
|
tr = self.transform(f, logical_expressions)
|
||||||
self.assertTrue(result.test_fn('a', ('a',)))
|
|
||||||
|
self.assertTrue(tr('a', ('a',)))
|
||||||
|
|
||||||
def test_unary_ops(self):
|
def test_unary_ops(self):
|
||||||
def test_fn(a):
|
|
||||||
|
def f(a):
|
||||||
return ~a, -a, +a
|
return ~a, -a, +a
|
||||||
|
|
||||||
with self.converted(test_fn, logical_expressions, {}) as result:
|
tr = self.transform(f, logical_expressions)
|
||||||
self.assertEqual(result.test_fn(1), (-2, -1, 1))
|
|
||||||
|
self.assertEqual(tr(1), (-2, -1, 1))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -1,95 +0,0 @@
|
|||||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Integration Tests for loop."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
from tensorflow.python.autograph.converters import break_statements
|
|
||||||
from tensorflow.python.autograph.converters import continue_statements
|
|
||||||
from tensorflow.python.autograph.converters import control_flow
|
|
||||||
from tensorflow.python.autograph.core import converter_testing
|
|
||||||
from tensorflow.python.framework import constant_op
|
|
||||||
from tensorflow.python.platform import test
|
|
||||||
|
|
||||||
|
|
||||||
class LoopIntegrationTest(converter_testing.TestCase):
|
|
||||||
|
|
||||||
def assertTransformedEquivalent(self, test_fn, *inputs):
|
|
||||||
with self.converted(test_fn,
|
|
||||||
[break_statements, continue_statements, control_flow],
|
|
||||||
{}, (constant_op.constant,)) as result:
|
|
||||||
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
|
|
||||||
|
|
||||||
def test_while_loop_with_else(self):
|
|
||||||
|
|
||||||
def test_fn(x):
|
|
||||||
while x > 2:
|
|
||||||
x /= 2
|
|
||||||
else:
|
|
||||||
x += 1
|
|
||||||
return x
|
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 4)
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
|
||||||
|
|
||||||
def test_while_loop_with_else_and_break(self):
|
|
||||||
|
|
||||||
def test_fn(cond1):
|
|
||||||
x = 8
|
|
||||||
while x > 2:
|
|
||||||
x /= 2
|
|
||||||
if cond1:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
x += 1
|
|
||||||
return x
|
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, True)
|
|
||||||
self.assertTransformedEquivalent(test_fn, False)
|
|
||||||
|
|
||||||
def test_for_loop_with_else(self):
|
|
||||||
|
|
||||||
def test_fn(l):
|
|
||||||
res = 0
|
|
||||||
for x in l:
|
|
||||||
res += x
|
|
||||||
else:
|
|
||||||
res += 1
|
|
||||||
return res
|
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, [])
|
|
||||||
self.assertTransformedEquivalent(test_fn, [1, 2])
|
|
||||||
|
|
||||||
def test_for_loop_with_else_and_break(self):
|
|
||||||
|
|
||||||
def test_fn(flag):
|
|
||||||
l = [1, 2, 3]
|
|
||||||
res = 0
|
|
||||||
for x in l:
|
|
||||||
res += x
|
|
||||||
if flag:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
res += 1
|
|
||||||
return res
|
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, True)
|
|
||||||
self.assertTransformedEquivalent(test_fn, False)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test.main()
|
|
@ -27,81 +27,80 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
class SingleReturnTest(converter_testing.TestCase):
|
class SingleReturnTest(converter_testing.TestCase):
|
||||||
|
|
||||||
def assertTransformedEquivalent(self, test_fn, *inputs):
|
def assertTransformedEquivalent(self, f, *inputs):
|
||||||
ns = {'ops': ops}
|
tr = self.transform(f, (functions, return_statements))
|
||||||
with self.converted(test_fn, (functions, return_statements), ns) as result:
|
self.assertEqual(f(*inputs), tr(*inputs))
|
||||||
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
|
|
||||||
|
|
||||||
def test_straightline(self):
|
def test_straightline(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
return x * x
|
return x * x
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
|
|
||||||
def test_superfluous_returns(self):
|
def test_superfluous_returns(self):
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
retval = 1
|
retval = 1
|
||||||
return retval
|
return retval
|
||||||
retval = 2 # pylint:disable=unreachable
|
retval = 2 # pylint:disable=unreachable
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn)
|
self.assertTransformedEquivalent(f)
|
||||||
|
|
||||||
def test_superfluous_returns_adjacent(self):
|
def test_superfluous_returns_adjacent(self):
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
return 1
|
return 1
|
||||||
return 2 # pylint:disable=unreachable
|
return 2 # pylint:disable=unreachable
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn)
|
self.assertTransformedEquivalent(f)
|
||||||
|
|
||||||
def test_conditional(self):
|
def test_conditional(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
if x > 0:
|
if x > 0:
|
||||||
return x
|
return x
|
||||||
else:
|
else:
|
||||||
return x * x
|
return x * x
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, -2)
|
self.assertTransformedEquivalent(f, -2)
|
||||||
|
|
||||||
def test_conditional_missing_else(self):
|
def test_conditional_missing_else(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
if x > 0:
|
if x > 0:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, -2)
|
self.assertTransformedEquivalent(f, -2)
|
||||||
|
|
||||||
def test_conditional_missing_else_then_default(self):
|
def test_conditional_missing_else_then_default(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
if x > 0:
|
if x > 0:
|
||||||
return x
|
return x
|
||||||
return x * x
|
return x * x
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, -2)
|
self.assertTransformedEquivalent(f, -2)
|
||||||
|
|
||||||
def test_conditional_else_only_then_default(self):
|
def test_conditional_else_only_then_default(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
if x < 0:
|
if x < 0:
|
||||||
x *= x
|
x *= x
|
||||||
else:
|
else:
|
||||||
return x
|
return x
|
||||||
return x
|
return x
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, -2)
|
self.assertTransformedEquivalent(f, -2)
|
||||||
|
|
||||||
def test_conditional_nested(self):
|
def test_conditional_nested(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
if x > 0:
|
if x > 0:
|
||||||
if x < 5:
|
if x < 5:
|
||||||
return x
|
return x
|
||||||
@ -110,53 +109,53 @@ class SingleReturnTest(converter_testing.TestCase):
|
|||||||
else:
|
else:
|
||||||
return x * x * x
|
return x * x * x
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, -2)
|
self.assertTransformedEquivalent(f, -2)
|
||||||
self.assertTransformedEquivalent(test_fn, 5)
|
self.assertTransformedEquivalent(f, 5)
|
||||||
|
|
||||||
def test_context_manager(self):
|
def test_context_manager(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
with ops.name_scope(''):
|
with ops.name_scope(''):
|
||||||
return x * x
|
return x * x
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, -2)
|
self.assertTransformedEquivalent(f, -2)
|
||||||
|
|
||||||
def test_context_manager_in_conditional(self):
|
def test_context_manager_in_conditional(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
if x > 0:
|
if x > 0:
|
||||||
with ops.name_scope(''):
|
with ops.name_scope(''):
|
||||||
return x * x
|
return x * x
|
||||||
else:
|
else:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, -2)
|
self.assertTransformedEquivalent(f, -2)
|
||||||
|
|
||||||
def text_conditional_in_context_manager(self):
|
def text_conditional_in_context_manager(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
with ops.name_scope(''):
|
with ops.name_scope(''):
|
||||||
if x > 0:
|
if x > 0:
|
||||||
return x * x
|
return x * x
|
||||||
else:
|
else:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, -2)
|
self.assertTransformedEquivalent(f, -2)
|
||||||
|
|
||||||
def test_no_return(self):
|
def test_no_return(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
x *= x
|
x *= x
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
|
|
||||||
def test_nested_function(self):
|
def test_nested_function(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
|
|
||||||
def inner_fn(y):
|
def inner_fn(y):
|
||||||
if y > 0:
|
if y > 0:
|
||||||
@ -166,33 +165,33 @@ class SingleReturnTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
return inner_fn(x)
|
return inner_fn(x)
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, -2)
|
self.assertTransformedEquivalent(f, -2)
|
||||||
|
|
||||||
def test_nested_function_in_control_flow(self):
|
def test_nested_function_in_control_flow(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
|
|
||||||
if x:
|
if x:
|
||||||
def inner_fn(y):
|
def inner_fn(y):
|
||||||
return y
|
return y
|
||||||
inner_fn(x)
|
inner_fn(x)
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, -2)
|
self.assertTransformedEquivalent(f, -2)
|
||||||
|
|
||||||
def test_for_loop(self):
|
def test_for_loop(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
for _ in range(n):
|
for _ in range(n):
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
|
|
||||||
def test_while_loop(self):
|
def test_while_loop(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
i = 0
|
i = 0
|
||||||
s = 0
|
s = 0
|
||||||
while i < n:
|
while i < n:
|
||||||
@ -202,23 +201,23 @@ class SingleReturnTest(converter_testing.TestCase):
|
|||||||
return s
|
return s
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 2)
|
self.assertTransformedEquivalent(f, 2)
|
||||||
self.assertTransformedEquivalent(test_fn, 4)
|
self.assertTransformedEquivalent(f, 4)
|
||||||
|
|
||||||
def test_null_return(self):
|
def test_null_return(self):
|
||||||
|
|
||||||
def test_fn(n):
|
def f(n):
|
||||||
if n > 4:
|
if n > 4:
|
||||||
return
|
return
|
||||||
return
|
return
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 4)
|
self.assertTransformedEquivalent(f, 4)
|
||||||
self.assertTransformedEquivalent(test_fn, 5)
|
self.assertTransformedEquivalent(f, 5)
|
||||||
|
|
||||||
def test_nested_multiple_withs(self):
|
def test_nested_multiple_withs(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def f(x):
|
||||||
v = []
|
v = []
|
||||||
while x > 0:
|
while x > 0:
|
||||||
x -= 1
|
x -= 1
|
||||||
@ -230,14 +229,14 @@ class SingleReturnTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, 0)
|
self.assertTransformedEquivalent(f, 0)
|
||||||
self.assertTransformedEquivalent(test_fn, 1)
|
self.assertTransformedEquivalent(f, 1)
|
||||||
self.assertTransformedEquivalent(test_fn, 3)
|
self.assertTransformedEquivalent(f, 3)
|
||||||
self.assertTransformedEquivalent(test_fn, 4)
|
self.assertTransformedEquivalent(f, 4)
|
||||||
|
|
||||||
def test_multiple_returns_in_nested_scope(self):
|
def test_multiple_returns_in_nested_scope(self):
|
||||||
|
|
||||||
def test_fn(a):
|
def f(a):
|
||||||
v = []
|
v = []
|
||||||
for x in a:
|
for x in a:
|
||||||
x -= 1
|
x -= 1
|
||||||
@ -250,10 +249,10 @@ class SingleReturnTest(converter_testing.TestCase):
|
|||||||
v.append(x)
|
v.append(x)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
self.assertTransformedEquivalent(test_fn, [])
|
self.assertTransformedEquivalent(f, [])
|
||||||
self.assertTransformedEquivalent(test_fn, [1])
|
self.assertTransformedEquivalent(f, [1])
|
||||||
self.assertTransformedEquivalent(test_fn, [2])
|
self.assertTransformedEquivalent(f, [2])
|
||||||
self.assertTransformedEquivalent(test_fn, [1, 2, 3])
|
self.assertTransformedEquivalent(f, [1, 2, 3])
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -18,11 +18,10 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.autograph.converters import directives as directives_converter
|
||||||
from tensorflow.python.autograph.converters import slices
|
from tensorflow.python.autograph.converters import slices
|
||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
from tensorflow.python.autograph.lang import directives
|
from tensorflow.python.autograph.lang import directives
|
||||||
from tensorflow.python.autograph.pyct import anno
|
|
||||||
from tensorflow.python.autograph.pyct import parser
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.ops import list_ops
|
from tensorflow.python.ops import list_ops
|
||||||
@ -33,42 +32,26 @@ class SliceTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
def test_index_access(self):
|
def test_index_access(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
|
directives.set_element_type(l, dtypes.int32)
|
||||||
return l[1]
|
return l[1]
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {})
|
tr = self.transform(f, (directives_converter, slices))
|
||||||
def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
|
|
||||||
def_.directives[directives.set_element_type] = {
|
|
||||||
'dtype': parser.parse_expression('tf.int32')
|
|
||||||
}
|
|
||||||
node = slices.transform(node, ctx)
|
|
||||||
|
|
||||||
with self.compiled(node, {}, (dtypes.int32,)) as result:
|
tl = list_ops.tensor_list_from_tensor(
|
||||||
with self.cached_session() as sess:
|
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
|
||||||
tl = list_ops.tensor_list_from_tensor(
|
y = tr(tl)
|
||||||
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
|
self.assertEqual(2, self.evaluate(y))
|
||||||
y = result.test_fn(tl)
|
|
||||||
self.assertEqual(2, self.evaluate(y))
|
|
||||||
|
|
||||||
def test_index_access_multiple_definitions(self):
|
def test_index_access_multiple_definitions(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
|
directives.set_element_type(l, dtypes.int32)
|
||||||
if l:
|
if l:
|
||||||
l = []
|
l = []
|
||||||
return l[1]
|
return l[1]
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {})
|
self.transform(f, (directives_converter, slices))
|
||||||
def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
|
|
||||||
def_.directives[directives.set_element_type] = {
|
|
||||||
'dtype': parser.parse_expression('tf.int32')
|
|
||||||
}
|
|
||||||
def_, = anno.getanno(node.body[0].body[0].targets[0],
|
|
||||||
anno.Static.DEFINITIONS)
|
|
||||||
def_.directives[directives.set_element_type] = {
|
|
||||||
'dtype': parser.parse_expression('tf.float32')
|
|
||||||
}
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
slices.transform(node, ctx)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -18,8 +18,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import contextlib
|
|
||||||
|
|
||||||
from tensorflow.python.autograph.converters import variables
|
from tensorflow.python.autograph.converters import variables
|
||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -27,60 +25,63 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
class VariablesTest(converter_testing.TestCase):
|
class VariablesTest(converter_testing.TestCase):
|
||||||
|
|
||||||
@contextlib.contextmanager
|
def transform_with_test_ld(self, f):
|
||||||
def apply_add_one_conversion(self, fn):
|
|
||||||
"""Generates code which adds 1 to all variable reads."""
|
"""Generates code which adds 1 to all variable reads."""
|
||||||
with self.converted(fn, variables, {}) as result:
|
return self.transform(f, variables, ag_overrides={'ld': lambda x: x + 1})
|
||||||
result.ag__.__dict__['ld'] = lambda x: x + 1
|
|
||||||
yield result
|
|
||||||
|
|
||||||
def test_read(self):
|
def test_read(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
return l
|
return l
|
||||||
|
|
||||||
with self.apply_add_one_conversion(test_fn) as result:
|
tr = self.transform_with_test_ld(f)
|
||||||
self.assertEqual(result.test_fn(1), 2)
|
|
||||||
|
self.assertEqual(tr(1), 2)
|
||||||
|
|
||||||
def test_aug_assign(self):
|
def test_aug_assign(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
l *= 10
|
l *= 10
|
||||||
return l
|
return l
|
||||||
|
|
||||||
with self.apply_add_one_conversion(test_fn) as result:
|
tr = self.transform_with_test_ld(f)
|
||||||
self.assertEqual(result.test_fn(1), (1 + 1) * 10 + 1) # two reads
|
|
||||||
|
self.assertEqual(tr(1), (1 + 1) * 10 + 1) # two reads
|
||||||
|
|
||||||
def test_del(self):
|
def test_del(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
del l
|
del l
|
||||||
return l
|
return l
|
||||||
|
|
||||||
with self.converted(test_fn, variables, {}) as result:
|
tr = self.transform(f, variables)
|
||||||
with self.assertRaisesRegex(
|
|
||||||
NameError, "'l' is used before assignment"):
|
|
||||||
result.test_fn(1)
|
|
||||||
|
|
||||||
def test_del_getitem_ignored(self):
|
with self.assertRaisesRegex(NameError, "'l' is used before assignment"):
|
||||||
|
tr(1)
|
||||||
|
|
||||||
def basic_slice(l):
|
def test_del_getitem_ignored_basic_slice(self):
|
||||||
|
|
||||||
|
def f(l):
|
||||||
del l[0]
|
del l[0]
|
||||||
return l
|
return l
|
||||||
|
|
||||||
with self.converted(basic_slice, variables, {}) as result:
|
tr = self.transform(f, variables)
|
||||||
self.assertListEqual([2], result.basic_slice([1, 2]))
|
|
||||||
|
|
||||||
def range_slice(l):
|
self.assertListEqual([2], tr([1, 2]))
|
||||||
|
|
||||||
|
def test_del_getitem_ignored_range_slice(self):
|
||||||
|
|
||||||
|
def f(l):
|
||||||
del l[0:2]
|
del l[0:2]
|
||||||
return l
|
return l
|
||||||
|
|
||||||
with self.converted(range_slice, variables, {}) as result:
|
tr = self.transform(f, variables)
|
||||||
self.assertListEqual([], result.range_slice([1, 2]))
|
|
||||||
|
self.assertListEqual([], tr([1, 2]))
|
||||||
|
|
||||||
def test_del_getattr_ignored(self):
|
def test_del_getattr_ignored(self):
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
del l.a
|
del l.a
|
||||||
return l
|
return l
|
||||||
|
|
||||||
@ -90,50 +91,60 @@ class VariablesTest(converter_testing.TestCase):
|
|||||||
self.a = 1
|
self.a = 1
|
||||||
self.b = 2
|
self.b = 2
|
||||||
|
|
||||||
with self.converted(test_fn, variables, {}) as result:
|
tr = self.transform(f, variables)
|
||||||
self.assertFalse(hasattr(result.test_fn(TestClass()), 'a'))
|
|
||||||
self.assertEqual(result.test_fn(TestClass()).b, 2)
|
|
||||||
|
|
||||||
def test_del_packing_ignored(self):
|
self.assertFalse(hasattr(tr(TestClass()), 'a'))
|
||||||
# Note: test for UnboundLocalError, not NameError because in this case we
|
self.assertEqual(tr(TestClass()).b, 2)
|
||||||
|
|
||||||
|
def test_del_packing_ignored_list(self):
|
||||||
|
# Note: testing for UnboundLocalError, not NameError because in this case we
|
||||||
# don't rewrite the del.
|
# don't rewrite the del.
|
||||||
|
|
||||||
def list_(a, b):
|
def f(a, b):
|
||||||
del [a, b]
|
del [a, b]
|
||||||
return a
|
return a
|
||||||
|
|
||||||
with self.converted(list_, variables, {}) as result:
|
tr = self.transform(f, variables)
|
||||||
with self.assertRaises(UnboundLocalError):
|
|
||||||
result.list_(1, 2)
|
|
||||||
|
|
||||||
def nested(a, b, c):
|
with self.assertRaises(UnboundLocalError):
|
||||||
|
tr(1, 2)
|
||||||
|
|
||||||
|
def test_del_packing_ignored_nested(self):
|
||||||
|
# Note: testing for UnboundLocalError, not NameError because in this case we
|
||||||
|
# don't rewrite the del.
|
||||||
|
|
||||||
|
def f(a, b, c):
|
||||||
del [a, (b, c)]
|
del [a, (b, c)]
|
||||||
return c
|
return c
|
||||||
|
|
||||||
with self.converted(nested, variables, {}) as result:
|
tr = self.transform(f, variables)
|
||||||
with self.assertRaises(UnboundLocalError):
|
|
||||||
result.nested(1, 2, 3)
|
|
||||||
|
|
||||||
def test_del_item_multiple_mixed(self):
|
with self.assertRaises(UnboundLocalError):
|
||||||
|
tr(1, 2, 3)
|
||||||
|
|
||||||
def test_fn_failing(a, b, c):
|
def test_del_item_multiple_mixed_used_after(self):
|
||||||
|
|
||||||
|
def f(a, b, c):
|
||||||
del a, b, c[0]
|
del a, b, c[0]
|
||||||
a = 1
|
a = 1
|
||||||
return a, b, c
|
return a, b, c
|
||||||
|
|
||||||
with self.converted(test_fn_failing, variables, {}) as result:
|
tr = self.transform(f, variables)
|
||||||
with self.assertRaisesRegex(
|
|
||||||
NameError, "'b' is used before assignment"):
|
|
||||||
result.test_fn_failing(1, 2, [1, 2])
|
|
||||||
|
|
||||||
def test_fn_passing(a, b, c):
|
with self.assertRaisesRegex(NameError, "'b' is used before assignment"):
|
||||||
|
tr(1, 2, [1, 2])
|
||||||
|
|
||||||
|
def test_del_item_multiple_mixed_unused_after(self):
|
||||||
|
|
||||||
|
def f(a, b, c):
|
||||||
del a, b, c[0]
|
del a, b, c[0]
|
||||||
a = 1
|
a = 1
|
||||||
b = 2
|
b = 2
|
||||||
return c
|
return c
|
||||||
|
|
||||||
with self.converted(test_fn_passing, variables, {}) as result:
|
tr = self.transform(f, variables)
|
||||||
self.assertListEqual([2], result.test_fn_passing(1, 2, [1, 2]))
|
|
||||||
|
self.assertListEqual([2], tr(1, 2, [1, 2]))
|
||||||
|
|
||||||
def test_attribute(self):
|
def test_attribute(self):
|
||||||
|
|
||||||
@ -146,12 +157,13 @@ class VariablesTest(converter_testing.TestCase):
|
|||||||
self.v += other
|
self.v += other
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
return l.v
|
return l.v
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
with self.apply_add_one_conversion(test_fn) as result:
|
tr = self.transform_with_test_ld(f)
|
||||||
self.assertEqual(result.test_fn(tc), 2)
|
|
||||||
|
self.assertEqual(tr(tc), 2)
|
||||||
|
|
||||||
def test_subscript(self):
|
def test_subscript(self):
|
||||||
|
|
||||||
@ -167,12 +179,13 @@ class VariablesTest(converter_testing.TestCase):
|
|||||||
def __getitem__(self, _):
|
def __getitem__(self, _):
|
||||||
return self.v
|
return self.v
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
return l[0]
|
return l[0]
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
with self.apply_add_one_conversion(test_fn) as result:
|
tr = self.transform_with_test_ld(f)
|
||||||
self.assertEqual(result.test_fn(tc), 2)
|
|
||||||
|
self.assertEqual(tr(tc), 2)
|
||||||
|
|
||||||
def test_call(self):
|
def test_call(self):
|
||||||
|
|
||||||
@ -188,12 +201,13 @@ class VariablesTest(converter_testing.TestCase):
|
|||||||
def __call__(self):
|
def __call__(self):
|
||||||
return self.v
|
return self.v
|
||||||
|
|
||||||
def test_fn(l):
|
def f(l):
|
||||||
return l()
|
return l()
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
with self.apply_add_one_conversion(test_fn) as result:
|
tr = self.transform_with_test_ld(f)
|
||||||
self.assertEqual(result.test_fn(tc), 2)
|
|
||||||
|
self.assertEqual(tr(tc), 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import imp
|
||||||
|
|
||||||
from tensorflow.python.autograph.core import converter
|
from tensorflow.python.autograph.core import converter
|
||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
from tensorflow.python.autograph.pyct import anno
|
from tensorflow.python.autograph.pyct import anno
|
||||||
@ -38,16 +40,18 @@ class ConversionOptionsTest(converter_testing.TestCase):
|
|||||||
opts_ast = opts.to_ast()
|
opts_ast = opts.to_ast()
|
||||||
|
|
||||||
template = '''
|
template = '''
|
||||||
def test_fn():
|
def f():
|
||||||
return opts_ast
|
return opts_ast
|
||||||
'''
|
'''
|
||||||
opts_packed = templates.replace(template, opts_ast=opts_ast)
|
opts_packed = templates.replace(template, opts_ast=opts_ast)
|
||||||
|
|
||||||
reparsed, _, _ = loader.load_ast(opts_packed)
|
reparsed, _, _ = loader.load_ast(opts_packed)
|
||||||
reparsed.__dict__['ag__'] = self.make_fake_mod(
|
fake_ag = imp.new_module('fake_ag')
|
||||||
'fake_ag', converter.ConversionOptions, converter.Feature)
|
fake_ag.ConversionOptions = converter.ConversionOptions
|
||||||
|
fake_ag.Feature = converter.Feature
|
||||||
|
reparsed.ag__ = fake_ag
|
||||||
|
|
||||||
reparsed_opts = reparsed.test_fn()
|
reparsed_opts = reparsed.f()
|
||||||
|
|
||||||
self.assertEqual(opts.recursive, reparsed_opts.recursive)
|
self.assertEqual(opts.recursive, reparsed_opts.recursive)
|
||||||
self.assertEqual(opts.user_requested, False)
|
self.assertEqual(opts.user_requested, False)
|
||||||
@ -63,12 +67,12 @@ class ConverterBaseTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
directive_key = object
|
directive_key = object
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
a = 1
|
a = 1
|
||||||
return a
|
return a
|
||||||
|
|
||||||
ns = {}
|
_, node, ctx = self.transform(f, (), include_ast=True)
|
||||||
node, ctx = self.prepare(test_fn, ns)
|
|
||||||
symbol_a = node.body[1].value
|
symbol_a = node.body[1].value
|
||||||
defs, = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
|
defs, = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
|
||||||
defs.directives[directive_key] = {
|
defs.directives[directive_key] = {
|
||||||
@ -84,12 +88,12 @@ class ConverterBaseTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
directive_key = object
|
directive_key = object
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
a = 1
|
a = 1
|
||||||
return a
|
return a
|
||||||
|
|
||||||
ns = {}
|
_, node, ctx = self.transform(f, (), include_ast=True)
|
||||||
node, ctx = self.prepare(test_fn, ns)
|
|
||||||
symbol_a = node.body[1].value
|
symbol_a = node.body[1].value
|
||||||
c = TestConverter(ctx)
|
c = TestConverter(ctx)
|
||||||
value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
|
value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
|
||||||
@ -100,14 +104,14 @@ class ConverterBaseTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
directive_key = object
|
directive_key = object
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
a = 1
|
a = 1
|
||||||
if a:
|
if a:
|
||||||
a = 2
|
a = 2
|
||||||
return a
|
return a
|
||||||
|
|
||||||
ns = {}
|
_, node, ctx = self.transform(f, (), include_ast=True)
|
||||||
node, ctx = self.prepare(test_fn, ns)
|
|
||||||
symbol_a = node.body[2].value
|
symbol_a = node.body[2].value
|
||||||
defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
|
defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
|
||||||
defs[0].directives[directive_key] = {
|
defs[0].directives[directive_key] = {
|
||||||
@ -127,14 +131,14 @@ class ConverterBaseTest(converter_testing.TestCase):
|
|||||||
|
|
||||||
directive_key = object
|
directive_key = object
|
||||||
|
|
||||||
def test_fn():
|
def f():
|
||||||
a = 1
|
a = 1
|
||||||
if a:
|
if a:
|
||||||
a = 2
|
a = 2
|
||||||
return a
|
return a
|
||||||
|
|
||||||
ns = {}
|
_, node, ctx = self.transform(f, (), include_ast=True)
|
||||||
node, ctx = self.prepare(test_fn, ns)
|
|
||||||
symbol_a = node.body[2].value
|
symbol_a = node.body[2].value
|
||||||
defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
|
defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
|
||||||
defs[0].directives[directive_key] = {
|
defs[0].directives[directive_key] = {
|
||||||
|
@ -25,27 +25,15 @@ import sys
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python.autograph import operators
|
|
||||||
from tensorflow.python.autograph import utils
|
|
||||||
from tensorflow.python.autograph.core import config
|
from tensorflow.python.autograph.core import config
|
||||||
from tensorflow.python.autograph.core import converter
|
from tensorflow.python.autograph.core import converter
|
||||||
from tensorflow.python.autograph.core import function_wrappers
|
from tensorflow.python.autograph.impl import api
|
||||||
from tensorflow.python.autograph.lang import special_functions
|
from tensorflow.python.autograph.impl import conversion
|
||||||
from tensorflow.python.autograph.pyct import anno
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.autograph.pyct import cfg
|
|
||||||
from tensorflow.python.autograph.pyct import loader
|
|
||||||
from tensorflow.python.autograph.pyct import naming
|
|
||||||
from tensorflow.python.autograph.pyct import origin_info
|
|
||||||
from tensorflow.python.autograph.pyct import parser
|
|
||||||
from tensorflow.python.autograph.pyct import pretty_printer
|
|
||||||
from tensorflow.python.autograph.pyct import qual_names
|
|
||||||
from tensorflow.python.autograph.pyct import transformer
|
|
||||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
|
||||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
def whitelist(entity):
|
def whitelist(f):
|
||||||
"""Helper that marks a callable as whtelitisted."""
|
"""Helper that marks a callable as whtelitisted."""
|
||||||
if 'whitelisted_module_for_testing' not in sys.modules:
|
if 'whitelisted_module_for_testing' not in sys.modules:
|
||||||
whitelisted_mod = imp.new_module('whitelisted_module_for_testing')
|
whitelisted_mod = imp.new_module('whitelisted_module_for_testing')
|
||||||
@ -54,7 +42,7 @@ def whitelist(entity):
|
|||||||
(config.DoNotConvert('whitelisted_module_for_testing'),) +
|
(config.DoNotConvert('whitelisted_module_for_testing'),) +
|
||||||
config.CONVERSION_RULES)
|
config.CONVERSION_RULES)
|
||||||
|
|
||||||
entity.__module__ = 'whitelisted_module_for_testing'
|
f.__module__ = 'whitelisted_module_for_testing'
|
||||||
|
|
||||||
|
|
||||||
def is_inside_generated_code():
|
def is_inside_generated_code():
|
||||||
@ -76,9 +64,39 @@ def is_inside_generated_code():
|
|||||||
del frame
|
del frame
|
||||||
|
|
||||||
|
|
||||||
|
class TestingTranspiler(conversion.AutoGraphTranspiler):
|
||||||
|
"""Testing version that only applies given transformations."""
|
||||||
|
|
||||||
|
def __init__(self, converters):
|
||||||
|
super(TestingTranspiler, self).__init__()
|
||||||
|
if isinstance(converters, (list, tuple)):
|
||||||
|
self._converters = converters
|
||||||
|
else:
|
||||||
|
self._converters = (converters,)
|
||||||
|
self.transformed_ast = None
|
||||||
|
|
||||||
|
def transform_ast(self, node, ctx):
|
||||||
|
node = self.initial_analysis(node, ctx)
|
||||||
|
|
||||||
|
for c in self._converters:
|
||||||
|
node = c.transform(node, ctx)
|
||||||
|
|
||||||
|
self.transformed_ast = node
|
||||||
|
self.transform_ctx = ctx
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
class TestCase(test.TestCase):
|
class TestCase(test.TestCase):
|
||||||
"""Base class for unit tests in this module. Contains relevant utilities."""
|
"""Base class for unit tests in this module. Contains relevant utilities."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# AutoGraph tests must run in graph mode to properly test control flow.
|
||||||
|
self.graph = ops.Graph().as_default()
|
||||||
|
self.graph.__enter__()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.graph.__exit__(None, None, None)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def assertPrints(self, expected_result):
|
def assertPrints(self, expected_result):
|
||||||
try:
|
try:
|
||||||
@ -89,108 +107,26 @@ class TestCase(test.TestCase):
|
|||||||
finally:
|
finally:
|
||||||
sys.stdout = sys.__stdout__
|
sys.stdout = sys.__stdout__
|
||||||
|
|
||||||
@contextlib.contextmanager
|
def transform(
|
||||||
def compiled(self, node, namespace, symbols=()):
|
self, f, converter_module, include_ast=False, ag_overrides=None):
|
||||||
source = None
|
|
||||||
|
|
||||||
self.dynamic_calls = []
|
|
||||||
# See api.converted_call
|
|
||||||
def converted_call(
|
|
||||||
f, args, kwargs, unused_opts=None, unused_function_ctx=None):
|
|
||||||
"""Mock version of api.converted_call."""
|
|
||||||
self.dynamic_calls.append((args, kwargs))
|
|
||||||
if kwargs is None:
|
|
||||||
kwargs = {}
|
|
||||||
return f(*args, **kwargs)
|
|
||||||
|
|
||||||
def fake_autograph_artifact(f):
|
|
||||||
setattr(f, 'fake_autograph_artifact', True)
|
|
||||||
return f
|
|
||||||
|
|
||||||
try:
|
|
||||||
result, source, source_map = loader.load_ast(
|
|
||||||
node, include_source_map=True)
|
|
||||||
# TODO(mdan): Move the unparsing from converter into pyct and reuse here.
|
|
||||||
|
|
||||||
# TODO(mdan): Move this into self.prepare()
|
|
||||||
result.tf = self.make_fake_mod('fake_tf', *symbols)
|
|
||||||
fake_ag = self.make_fake_mod('fake_ag', converted_call,
|
|
||||||
converter.ConversionOptions)
|
|
||||||
fake_ag.__dict__.update(operators.__dict__)
|
|
||||||
fake_ag.__dict__.update(special_functions.__dict__)
|
|
||||||
fake_ag.ConversionOptions = converter.ConversionOptions
|
|
||||||
fake_ag.Feature = converter.Feature
|
|
||||||
fake_ag.utils = utils
|
|
||||||
fake_ag.FunctionScope = function_wrappers.FunctionScope
|
|
||||||
fake_ag.autograph_artifact = fake_autograph_artifact
|
|
||||||
result.ag__ = fake_ag
|
|
||||||
result.ag_source_map__ = source_map
|
|
||||||
for k, v in namespace.items():
|
|
||||||
result.__dict__[k] = v
|
|
||||||
yield result
|
|
||||||
except Exception: # pylint:disable=broad-except
|
|
||||||
if source is None:
|
|
||||||
print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
|
|
||||||
else:
|
|
||||||
print('Offending source code:\n%s' % source)
|
|
||||||
raise
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def converted(self, entity, converter_module, namespace, tf_symbols=()):
|
|
||||||
|
|
||||||
node, ctx = self.prepare(entity, namespace)
|
|
||||||
|
|
||||||
if not isinstance(converter_module, (list, tuple)):
|
|
||||||
converter_module = (converter_module,)
|
|
||||||
for m in converter_module:
|
|
||||||
node = m.transform(node, ctx)
|
|
||||||
|
|
||||||
with self.compiled(node, namespace, tf_symbols) as result:
|
|
||||||
yield result
|
|
||||||
|
|
||||||
def make_fake_mod(self, name, *symbols):
|
|
||||||
fake_mod = imp.new_module(name)
|
|
||||||
for s in symbols:
|
|
||||||
if hasattr(s, '__name__'):
|
|
||||||
setattr(fake_mod, s.__name__, s)
|
|
||||||
elif hasattr(s, 'name'):
|
|
||||||
# This is a bit of a hack, but works for things like tf.int32
|
|
||||||
setattr(fake_mod, s.name, s)
|
|
||||||
else:
|
|
||||||
raise ValueError('can not attach %s - what should be its name?' % s)
|
|
||||||
return fake_mod
|
|
||||||
|
|
||||||
def attach_namespace(self, module, **ns):
|
|
||||||
for k, v in ns.items():
|
|
||||||
setattr(module, k, v)
|
|
||||||
|
|
||||||
def prepare(self, test_fn, namespace, recursive=True):
|
|
||||||
namespace['ConversionOptions'] = converter.ConversionOptions
|
|
||||||
|
|
||||||
future_features = ('print_function', 'division')
|
|
||||||
node, source = parser.parse_entity(test_fn, future_features=future_features)
|
|
||||||
namer = naming.Namer(namespace)
|
|
||||||
program_ctx = converter.ProgramContext(
|
program_ctx = converter.ProgramContext(
|
||||||
options=converter.ConversionOptions(recursive=recursive),
|
options=converter.ConversionOptions(recursive=True),
|
||||||
autograph_module=None)
|
autograph_module=api)
|
||||||
entity_info = transformer.EntityInfo(
|
|
||||||
name=test_fn.__name__,
|
|
||||||
source_code=source,
|
|
||||||
source_file='<fragment>',
|
|
||||||
future_features=future_features,
|
|
||||||
namespace=namespace)
|
|
||||||
ctx = transformer.Context(entity_info, namer, program_ctx)
|
|
||||||
origin_info.resolve_entity(node, source, test_fn)
|
|
||||||
|
|
||||||
graphs = cfg.build(node)
|
conversion.create_custom_vars(program_ctx)
|
||||||
node = qual_names.resolve(node)
|
custom_vars = dict(conversion.custom_vars)
|
||||||
node = activity.resolve(node, ctx, None)
|
|
||||||
node = reaching_definitions.resolve(node, ctx, graphs)
|
|
||||||
anno.dup(
|
|
||||||
node,
|
|
||||||
{
|
|
||||||
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
return node, ctx
|
if ag_overrides:
|
||||||
|
modified_ag = imp.new_module('fake_autograph')
|
||||||
|
modified_ag.__dict__.update(custom_vars['ag__'].__dict__)
|
||||||
|
modified_ag.__dict__.update(ag_overrides)
|
||||||
|
custom_vars['ag__'] = modified_ag
|
||||||
|
|
||||||
|
tr = TestingTranspiler(converter_module)
|
||||||
|
transformed, _, _ = tr.transform_function(
|
||||||
|
f, program_ctx.options, program_ctx, custom_vars)
|
||||||
|
|
||||||
|
if include_ast:
|
||||||
|
return transformed, tr.transformed_ast, tr.transform_ctx
|
||||||
|
|
||||||
|
return transformed
|
||||||
|
@ -60,11 +60,7 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler):
|
|||||||
def get_transformed_name(self, node):
|
def get_transformed_name(self, node):
|
||||||
return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node)
|
return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node)
|
||||||
|
|
||||||
def transform_ast(self, node, ctx):
|
def initial_analysis(self, node, ctx):
|
||||||
# TODO(mdan): Insert list_comprehensions somewhere.
|
|
||||||
unsupported_features_checker.verify(node)
|
|
||||||
|
|
||||||
# Run initial analysis.
|
|
||||||
graphs = cfg.build(node)
|
graphs = cfg.build(node)
|
||||||
node = qual_names.resolve(node)
|
node = qual_names.resolve(node)
|
||||||
node = activity.resolve(node, ctx, None)
|
node = activity.resolve(node, ctx, None)
|
||||||
@ -75,6 +71,11 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler):
|
|||||||
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
|
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
return node
|
||||||
|
|
||||||
|
def transform_ast(self, node, ctx):
|
||||||
|
unsupported_features_checker.verify(node)
|
||||||
|
node = self.initial_analysis(node, ctx)
|
||||||
|
|
||||||
node = functions.transform(node, ctx)
|
node = functions.transform(node, ctx)
|
||||||
node = directives.transform(node, ctx)
|
node = directives.transform(node, ctx)
|
||||||
@ -114,7 +115,7 @@ def convert(entity, program_ctx):
|
|||||||
'expose a __code__ object. If this is a @tf.function,'
|
'expose a __code__ object. If this is a @tf.function,'
|
||||||
' try passing f.python_function instead.')
|
' try passing f.python_function instead.')
|
||||||
|
|
||||||
_create_custom_vars(program_ctx)
|
create_custom_vars(program_ctx)
|
||||||
transformed, module, source_map = _TRANSPILER.transform_function(
|
transformed, module, source_map = _TRANSPILER.transform_function(
|
||||||
entity, program_ctx.options, program_ctx, custom_vars)
|
entity, program_ctx.options, program_ctx, custom_vars)
|
||||||
|
|
||||||
@ -248,7 +249,8 @@ def cache_whitelisted(entity, options):
|
|||||||
|
|
||||||
|
|
||||||
# TODO(mdan): Move into core or replace with an actual importable module.
|
# TODO(mdan): Move into core or replace with an actual importable module.
|
||||||
def _create_custom_vars(program_ctx):
|
# Visible for testing.
|
||||||
|
def create_custom_vars(program_ctx):
|
||||||
"""Adds namespace references to the module that exposes the api itself."""
|
"""Adds namespace references to the module that exposes the api itself."""
|
||||||
global custom_vars
|
global custom_vars
|
||||||
if custom_vars is None:
|
if custom_vars is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user