diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index 9cf3bba8dd5..f584038978f 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -77,12 +77,6 @@ py_test( srcs = ["call_trees_test.py"], python_version = "PY3", srcs_version = "PY3", - tags = [ - "no_oss_py2", - "no_pip", - "no_windows", - "nopip", - ], deps = [ ":converters", "//tensorflow/python:client_testlib", @@ -119,12 +113,6 @@ py_test( srcs = ["control_flow_test.py"], python_version = "PY3", srcs_version = "PY3", - tags = [ - "no_oss_py2", - "no_pip", - "no_windows", - "nopip", - ], deps = [ ":converters", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py index dc435cbc90e..bf063829e42 100644 --- a/tensorflow/python/autograph/converters/asserts_test.py +++ b/tensorflow/python/autograph/converters/asserts_test.py @@ -24,7 +24,6 @@ from tensorflow.python.autograph.converters import return_statements from tensorflow.python.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl -from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -32,17 +31,15 @@ class AssertsTest(converter_testing.TestCase): def test_basic(self): - def test_fn(a): + def f(a): assert a, 'testmsg' return a - with ops.Graph().as_default(): - with self.converted( - test_fn, (functions, asserts, return_statements), {}) as result: - op = result.test_fn(constant_op.constant(False)) + tr = self.transform(f, (functions, asserts, return_statements)) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'): - self.evaluate(op) + op = tr(constant_op.constant(False)) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'): + self.evaluate(op) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/converters/break_statements_test.py b/tensorflow/python/autograph/converters/break_statements_test.py index 37accdcc1be..6313cda37fc 100644 --- a/tensorflow/python/autograph/converters/break_statements_test.py +++ b/tensorflow/python/autograph/converters/break_statements_test.py @@ -21,20 +21,18 @@ from __future__ import print_function from tensorflow.python.autograph.converters import break_statements from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.pyct import anno -from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class BreakCanonicalizationTest(converter_testing.TestCase): - def assertTransformedEquivalent(self, test_fn, *inputs): - with self.converted(test_fn, break_statements, {}, - (constant_op.constant,)) as result: - self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + def assertTransformedEquivalent(self, f, *inputs): + tr = self.transform(f, break_statements) + self.assertEqual(f(*inputs), tr(*inputs)) def test_while_loop(self): - def test_fn(x): + def f(x): v = [] while x > 0: x -= 1 @@ -43,28 +41,29 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 4) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 1) + self.assertTransformedEquivalent(f, 4) def test_while_loop_preserves_directives(self): - def test_fn(x): + def f(x): while x > 0: x -= 1 if x % 2 == 0: break - node, ctx = self.prepare(test_fn, {}) + _, node, ctx = self.transform(f, (), include_ast=True) fake_annotation = object() anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation) node = break_statements.transform(node, ctx) + self.assertIs( anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation) def test_for_loop(self): - def test_fn(a): + def f(a): v = [] for x in a: x -= 1 @@ -73,20 +72,18 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - with self.converted(test_fn, break_statements, {}, - (constant_op.constant,)) as result: - # The break is incompletely canonicalized. The loop will not interrupt, - # but the section following the break will be skipped. - self.assertEqual([3], result.test_fn([5, 4])) + tr = self.transform(f, break_statements) + + self.assertEqual([3], tr([5, 4])) def test_for_loop_preserves_directives(self): - def test_fn(a): + def f(a): for x in a: if x % 2 == 0: break - node, ctx = self.prepare(test_fn, {}) + _, node, ctx = self.transform(f, (), include_ast=True) fake_annotation = object() anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation) node = break_statements.transform(node, ctx) @@ -95,7 +92,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase): def test_nested(self): - def test_fn(x): + def f(x): v = [] u = [] w = [] @@ -110,13 +107,13 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u, w - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 11) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 3) + self.assertTransformedEquivalent(f, 11) def test_nested_loops(self): - def test_fn(x): + def f(x): v = [] u = [] while x > 0: @@ -132,14 +129,14 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 5) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, 3) + self.assertTransformedEquivalent(f, 5) def test_loop_orelse(self): - def test_fn(x): + def f(x): v = [] u = [] while x > 0: @@ -153,12 +150,12 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, 3) def test_multiple_correlated_breaks_with_side_effects(self): - def test_fn(cond1): + def f(cond1): lst = [] while True: if cond1: @@ -169,8 +166,9 @@ class BreakCanonicalizationTest(converter_testing.TestCase): break return lst - self.assertTransformedEquivalent(test_fn, True) - self.assertTransformedEquivalent(test_fn, False) + self.assertTransformedEquivalent(f, True) + self.assertTransformedEquivalent(f, False) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py index 86ca2dc9c24..a0bae91af02 100644 --- a/tensorflow/python/autograph/converters/call_trees_test.py +++ b/tensorflow/python/autograph/converters/call_trees_test.py @@ -27,169 +27,193 @@ from tensorflow.python.autograph.core import converter_testing from tensorflow.python.platform import test +class MockConvertedCall(object): + + def __init__(self): + self.calls = [] + + def __call__(self, f, args, kwargs, caller_fn_scope=None, options=None): + del caller_fn_scope, options + self.calls.append((args, kwargs)) + kwargs = kwargs or {} + return f(*args, **kwargs) + + class CallTreesTest(converter_testing.TestCase): + def _transform_with_mock(self, f): + mock = MockConvertedCall() + tr = self.transform( + f, (functions, call_trees), + ag_overrides={'converted_call': mock}) + return tr, mock + def test_function_no_args(self): - def test_fn(f): + def f(f): return f() + 20 - with self.converted(test_fn, (functions, call_trees), {}) as result: - self.assertEqual(result.test_fn(lambda: 1), 21) - self.assertListEqual(self.dynamic_calls, [((), None)]) + tr, mock = self._transform_with_mock(f) + + self.assertEqual(tr(lambda: 1), 21) + self.assertListEqual(mock.calls, [((), None)]) def test_function_with_expression_in_argument(self): - def test_fn(f, g): + def f(f, g): return f(g() + 20) + 4000 - with self.converted(test_fn, (functions, call_trees), {}) as result: - self.assertEqual(result.test_fn(lambda x: x + 300, lambda: 1), 4321) - self.assertListEqual(self.dynamic_calls, [ - ((), None), - ((21,), None), - ]) + tr, mock = self._transform_with_mock(f) + + self.assertEqual(tr(lambda x: x + 300, lambda: 1), 4321) + self.assertListEqual(mock.calls, [ + ((), None), + ((21,), None), + ]) def test_function_with_call_in_argument(self): - def test_fn(f, g): + def f(f, g): return f(g()) + 300 - with self.converted(test_fn, (functions, call_trees), {}) as result: - self.assertEqual(result.test_fn(lambda x: x + 20, lambda: 1), 321) - self.assertListEqual(self.dynamic_calls, [ - ((), None), - ((1,), None), - ]) + tr, mock = self._transform_with_mock(f) + + self.assertEqual(tr(lambda x: x + 20, lambda: 1), 321) + self.assertListEqual(mock.calls, [ + ((), None), + ((1,), None), + ]) def test_function_chaining(self): def get_one(): return 1 - def test_fn(): + def f(): return get_one().__add__(20) - with self.converted(test_fn, (functions, call_trees), - {'get_one': get_one}, ()) as result: + tr, mock = self._transform_with_mock(f) - self.assertEqual(result.test_fn(), 21) - - self.assertListEqual(self.dynamic_calls, [ - ((), None), - ((20,), None), - ]) + self.assertEqual(tr(), 21) + self.assertListEqual(mock.calls, [ + ((), None), + ((20,), None), + ]) def test_function_with_single_arg(self): - def test_fn(f, a): + def f(f, a): return f(a) + 20 - with self.converted(test_fn, (functions, call_trees), {}) as result: - self.assertEqual(result.test_fn(lambda a: a, 1), 21) - self.assertListEqual(self.dynamic_calls, [((1,), None)]) + tr, mock = self._transform_with_mock(f) + + self.assertEqual(tr(lambda a: a, 1), 21) + self.assertListEqual(mock.calls, [((1,), None)]) def test_function_with_args_only(self): - def test_fn(f, a, b): + def f(f, a, b): return f(a, b) + 300 - with self.converted(test_fn, (functions, call_trees), {}) as result: - self.assertEqual(result.test_fn(lambda a, b: a + b, 1, 20), 321) - self.assertListEqual(self.dynamic_calls, [((1, 20), None)]) + tr, mock = self._transform_with_mock(f) + + self.assertEqual(tr(lambda a, b: a + b, 1, 20), 321) + self.assertListEqual(mock.calls, [((1, 20), None)]) def test_function_with_kwarg(self): - def test_fn(f, a, b): + def f(f, a, b): return f(a, c=b) + 300 - with self.converted(test_fn, (functions, call_trees), {}) as result: - self.assertEqual(result.test_fn(lambda a, c: a + c, 1, 20), 321) - self.assertListEqual(self.dynamic_calls, [((1,), {'c': 20})]) + tr, mock = self._transform_with_mock(f) + + self.assertEqual(tr(lambda a, c: a + c, 1, 20), 321) + self.assertListEqual(mock.calls, [((1,), {'c': 20})]) def test_function_with_kwargs_starargs(self): - def test_fn(f, a, *args, **kwargs): + def f(f, a, *args, **kwargs): return f(a, *args, **kwargs) + 5 - with self.converted(test_fn, (functions, call_trees), {}) as result: - self.assertEqual( - result.test_fn(lambda *args, **kwargs: 7, 1, *[2, 3], **{ - 'b': 4, - 'c': 5 - }), 12) - self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})]) + tr, mock = self._transform_with_mock(f) + + self.assertEqual( + tr(lambda *args, **kwargs: 7, 1, *[2, 3], **{ + 'b': 4, + 'c': 5 + }), 12) + self.assertListEqual(mock.calls, [((1, 2, 3), {'b': 4, 'c': 5})]) def test_function_with_starargs_only(self): - def f(*args): + def g(*args): return sum(args) - def test_fn(): + def f(): args = [1, 20, 300] - return f(*args) + 4000 + return g(*args) + 4000 - with self.converted(test_fn, (functions, call_trees), - {'f': f}) as result: - self.assertEqual(result.test_fn(), 4321) - self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)]) + tr, mock = self._transform_with_mock(f) - # TODO(b/142586827): Enable this test. - # def test_function_with_starargs_mixed(self): - # - # def f(a, b, c, d): - # return a * 1000 + b * 100 + c * 10 + d - # - # def test_fn(): - # args1 = (1,) - # args2 = [3] - # return f(*args1, 2, *args2, 4) - # - # with self.converted(test_fn, (functions, call_trees), - # {'f': f}) as result: - # self.assertEqual(result.test_fn(), 1234) - # self.assertListEqual(self.dynamic_calls, [((1, 2, 3, 4), None)]) + self.assertEqual(tr(), 4321) + self.assertListEqual(mock.calls, [((1, 20, 300), None)]) + + def test_function_with_starargs_mixed(self): + + def g(a, b, c, d): + return a * 1000 + b * 100 + c * 10 + d + + def f(): + args1 = (1,) + args2 = [3] + return g(*args1, 2, *args2, 4) + + tr, mock = self._transform_with_mock(f) + + self.assertEqual(tr(), 1234) + self.assertListEqual(mock.calls, [((1, 2, 3, 4), None)]) def test_function_with_kwargs_keywords(self): - def test_fn(f, a, b, **kwargs): + def f(f, a, b, **kwargs): return f(a, b=b, **kwargs) + 5 - with self.converted(test_fn, (functions, call_trees), {}) as result: - self.assertEqual( - result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12) - self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})]) + tr, mock = self._transform_with_mock(f) - # TODO(b/142586827): Enable this test. - # def test_function_with_multiple_kwargs(self): - # - # def test_fn(f, a, b, c, kwargs1, kwargs2): - # return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5 - # - # with self.converted(test_fn, (functions, call_trees), {}) as result: - # self.assertEqual( - # result.test_fn(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4}, - # {'e': 5}), 12) - # self.assertListEqual(self.dynamic_calls, [((1,), { - # 'b': 2, - # 'c': 3, - # 'd': 4, - # 'e': 5 - # })]) + self.assertEqual( + tr(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12) + self.assertListEqual(mock.calls, [((1,), {'b': 2, 'c': 3})]) + + def test_function_with_multiple_kwargs(self): + + def f(f, a, b, c, kwargs1, kwargs2): + return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5 + + tr, mock = self._transform_with_mock(f) + + self.assertEqual( + tr(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4}, {'e': 5}), 12) + self.assertListEqual(mock.calls, [((1,), { + 'b': 2, + 'c': 3, + 'd': 4, + 'e': 5 + })]) def test_function_with_call_in_lambda_argument(self): - def f(l, a): + def h(l, a): return l(a) + 4000 def g(a, *args): return a + sum(args) - def test_fn(f, g, a, *args): - return f(lambda x: g(x, *args), a) + def f(h, g, a, *args): + return h(lambda x: g(x, *args), a) - with self.converted(test_fn, (functions, call_trees), {}) as result: - self.assertEqual(result.test_fn(f, g, 1, *(20, 300)), 4321) + tr, _ = self._transform_with_mock(f) + + self.assertEqual(tr(h, g, 1, *(20, 300)), 4321) def test_debugger_set_trace(self): @@ -198,13 +222,13 @@ class CallTreesTest(converter_testing.TestCase): pdb = imp.new_module('fake_pdb') pdb.set_trace = lambda: tracking_list.append(1) - def test_fn(): + def f(): return pdb.set_trace() - with self.converted(test_fn, (functions, call_trees), - {'pdb': pdb}) as result: - result.test_fn() - self.assertListEqual(tracking_list, [1]) + tr, _ = self._transform_with_mock(f) + + tr() + self.assertListEqual(tracking_list, [1]) def test_class_method(self): @@ -217,10 +241,10 @@ class CallTreesTest(converter_testing.TestCase): return self.other_method(a) + 300 tc = TestClass() - with self.converted(TestClass.test_method, (functions, call_trees), - {}) as result: - self.assertEqual(321, result.test_method(tc, 1)) - self.assertListEqual(self.dynamic_calls, [((1,), None)]) + tr, mock = self._transform_with_mock(TestClass.test_method) + + self.assertEqual(321, tr(tc, 1)) + self.assertListEqual(mock.calls, [((1,), None)]) def test_object_method(self): @@ -233,10 +257,10 @@ class CallTreesTest(converter_testing.TestCase): return self.other_method(a) + 300 tc = TestClass() - with self.converted(tc.test_method, (functions, call_trees), - {}) as result: - self.assertEqual(321, result.test_method(tc, 1)) - self.assertListEqual(self.dynamic_calls, [((1,), None)]) + tr, mock = self._transform_with_mock(tc.test_method) + + self.assertEqual(321, tr(tc, 1)) + self.assertListEqual(mock.calls, [((1,), None)]) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/converters/conditional_expressions_test.py b/tensorflow/python/autograph/converters/conditional_expressions_test.py index dd1f8d485cc..020849d79f5 100644 --- a/tensorflow/python/autograph/converters/conditional_expressions_test.py +++ b/tensorflow/python/autograph/converters/conditional_expressions_test.py @@ -25,28 +25,27 @@ from tensorflow.python.platform import test class ConditionalExpressionsTest(converter_testing.TestCase): - def assertTransformedEquivalent(self, test_fn, *inputs): - ns = {} - with self.converted(test_fn, conditional_expressions, ns) as result: - self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + def assertTransformedEquivalent(self, f, *inputs): + tr = self.transform(f, conditional_expressions) + self.assertEqual(f(*inputs), tr(*inputs)) def test_basic(self): - def test_fn(x): + def f(x): return 1 if x else 0 - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 3) def test_nested_orelse(self): - def test_fn(x): + def f(x): y = x * x if x > 0 else x if x else 1 return y - self.assertTransformedEquivalent(test_fn, -2) - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(f, -2) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 2) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/converters/continue_statements_test.py b/tensorflow/python/autograph/converters/continue_statements_test.py index a24ddd5e527..ed6e27fca6f 100644 --- a/tensorflow/python/autograph/converters/continue_statements_test.py +++ b/tensorflow/python/autograph/converters/continue_statements_test.py @@ -20,21 +20,19 @@ from __future__ import print_function from tensorflow.python.autograph.converters import continue_statements from tensorflow.python.autograph.core import converter_testing -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.platform import test class ContinueCanonicalizationTest(converter_testing.TestCase): - def assertTransformedEquivalent(self, test_fn, *inputs): - with self.converted(test_fn, continue_statements, {'ops': ops}, - (constant_op.constant,)) as result: - self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + def assertTransformedEquivalent(self, f, *inputs): + tr = self.transform(f, continue_statements) + self.assertEqual(f(*inputs), tr(*inputs)) def test_basic(self): - def test_fn(x): + def f(x): v = [] while x > 0: x -= 1 @@ -43,14 +41,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 1) + self.assertTransformedEquivalent(f, 3) + self.assertTransformedEquivalent(f, 4) def test_multiple_continues(self): - def test_fn(x): + def f(x): v = [] while x > 0: x -= 1 @@ -61,14 +59,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 1) + self.assertTransformedEquivalent(f, 3) + self.assertTransformedEquivalent(f, 4) def test_multiple_continues_in_nested_scope(self): - def test_fn(a): + def f(a): v = [] for x in a: x -= 1 @@ -81,14 +79,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, []) - self.assertTransformedEquivalent(test_fn, [1]) - self.assertTransformedEquivalent(test_fn, [2]) - self.assertTransformedEquivalent(test_fn, [1, 2, 3]) + self.assertTransformedEquivalent(f, []) + self.assertTransformedEquivalent(f, [1]) + self.assertTransformedEquivalent(f, [2]) + self.assertTransformedEquivalent(f, [1, 2, 3]) def test_for_loop(self): - def test_fn(a): + def f(a): v = [] for x in a: x -= 1 @@ -97,14 +95,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, []) - self.assertTransformedEquivalent(test_fn, [1]) - self.assertTransformedEquivalent(test_fn, [2]) - self.assertTransformedEquivalent(test_fn, [1, 2, 3]) + self.assertTransformedEquivalent(f, []) + self.assertTransformedEquivalent(f, [1]) + self.assertTransformedEquivalent(f, [2]) + self.assertTransformedEquivalent(f, [1, 2, 3]) def test_nested_with(self): - def test_fn(x): + def f(x): v = [] while x > 0: x -= 1 @@ -114,14 +112,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 1) + self.assertTransformedEquivalent(f, 3) + self.assertTransformedEquivalent(f, 4) def test_nested_multiple_withs(self): - def test_fn(x): + def f(x): v = [] while x > 0: x -= 1 @@ -133,14 +131,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 1) + self.assertTransformedEquivalent(f, 3) + self.assertTransformedEquivalent(f, 4) def test_nested_multiple_withs_and_statements(self): - def test_fn(x): + def f(x): v = [] while x > 0: x -= 1 @@ -154,14 +152,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 1) + self.assertTransformedEquivalent(f, 3) + self.assertTransformedEquivalent(f, 4) def test_nested_multiple_withs_and_nested_withs(self): - def test_fn(x): + def f(x): v = [] while x > 0: x -= 1 @@ -176,14 +174,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 1) + self.assertTransformedEquivalent(f, 3) + self.assertTransformedEquivalent(f, 4) def test_nested(self): - def test_fn(x): + def f(x): v = [] u = [] w = [] @@ -198,14 +196,14 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u, w - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 1) + self.assertTransformedEquivalent(f, 3) + self.assertTransformedEquivalent(f, 4) def test_multiple_guarded_continues_with_side_effects(self): - def test_fn(x): + def f(x): def track(u, x): u.append(x) return x @@ -221,8 +219,8 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return u, v - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(f, 3) + self.assertTransformedEquivalent(f, 2) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py index f0681128698..87f59bef675 100644 --- a/tensorflow/python/autograph/converters/control_flow_test.py +++ b/tensorflow/python/autograph/converters/control_flow_test.py @@ -23,6 +23,8 @@ import collections import numpy as np +from tensorflow.python.autograph.converters import break_statements +from tensorflow.python.autograph.converters import continue_statements from tensorflow.python.autograph.converters import control_flow from tensorflow.python.autograph.core import converter_testing from tensorflow.python.eager import def_function @@ -34,7 +36,8 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.platform import test from tensorflow.python.util import nest -# TODO(mdan): These tests are not isolated - they also test the operators. + +for_unaffected_global = None class ControlFlowTestBase(converter_testing.TestCase): @@ -45,22 +48,19 @@ class ControlFlowTestBase(converter_testing.TestCase): actual) self.assertAllEqual(values, expected) - def assertTransformedResult(self, test_fn, inputs, expected, symbols=None): + def assertTransformedResult(self, f, inputs, expected): if not isinstance(inputs, tuple): inputs = (inputs,) - if not symbols: - symbols = {} - with self.converted(test_fn, control_flow, symbols, - (constant_op.constant,)) as result: - returns = result.test_fn(*inputs) - self.assertValuesEqual(returns, expected) + tr = self.transform(f, control_flow) + returns = tr(*inputs) + self.assertValuesEqual(returns, expected) class NestedControlFlowTest(ControlFlowTestBase): def test_basic(self): - def test_fn(n): + def f(n): i = 0 j = 0 s = 0 @@ -73,7 +73,7 @@ class NestedControlFlowTest(ControlFlowTestBase): j = 0 return s, i, j, n - self.assertTransformedResult(test_fn, constant_op.constant(5), + self.assertTransformedResult(f, constant_op.constant(5), (25, 5, 0, 5)) def test_composite_state_complex(self): @@ -88,7 +88,7 @@ class NestedControlFlowTest(ControlFlowTestBase): def __init__(self, y): self.y = y - def test_fn(n): + def f(n): tc = TestClassX(TestClassY({'z': TestClassX(n)})) if n > 0: while n > 0: @@ -97,19 +97,17 @@ class NestedControlFlowTest(ControlFlowTestBase): n -= 1 return n, tc - with self.converted(test_fn, control_flow, { - 'TestClassX': TestClassX, - 'TestClassY': TestClassY, - }) as result: - n, tc = result.test_fn(constant_op.constant(5)) - self.assertValuesEqual((n, tc.x.y['z'].x), (0, 6)) + tr = self.transform(f, control_flow) + + n, tc = tr(constant_op.constant(5)) + self.assertValuesEqual((n, tc.x.y['z'].x), (0, 6)) class WhileStatementTest(ControlFlowTestBase): def test_basic(self): - def test_fn(n): + def f(n): i = 0 s = 0 while i < n: @@ -117,16 +115,16 @@ class WhileStatementTest(ControlFlowTestBase): i += 1 return s, i, n - self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5)) + self.assertTransformedResult(f, constant_op.constant(5), (10, 5, 5)) def test_single_output(self): - def test_fn(n): + def f(n): while n > 0: n -= 1 return n - self.assertTransformedResult(test_fn, constant_op.constant(5), 0) + self.assertTransformedResult(f, constant_op.constant(5), 0) def test_composite_state_attr(self): @@ -135,19 +133,18 @@ class WhileStatementTest(ControlFlowTestBase): def __init__(self): self.x = constant_op.constant(3) - def test_fn(n): + def f(n): tc = TestClass() while n > 0: tc.x += 1 n -= 1 return n - self.assertTransformedResult( - test_fn, constant_op.constant(5), 0, symbols={'TestClass': TestClass}) + self.assertTransformedResult(f, constant_op.constant(5), 0) def test_composite_state_slice(self): - def test_fn(n): + def f(n): d = {'a': n} k = 'a' while n > 0: @@ -155,25 +152,25 @@ class WhileStatementTest(ControlFlowTestBase): n -= 1 return d[k], n - self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 0)) + self.assertTransformedResult(f, constant_op.constant(5), (10, 0)) def test_composite_state_literal_slice(self): - def test_fn(n): + def f(n): d = {'a': n} while n > 0: d['a'] += 1 n -= 1 return d['a'], n - self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 0)) + self.assertTransformedResult(f, constant_op.constant(5), (10, 0)) def test_composite_state_attr_initialized_in_loop(self): class TestClass(object): pass - def test_fn(n, x): + def f(n, x): tc = TestClass() while n < 5: if n == 0: @@ -183,19 +180,15 @@ class WhileStatementTest(ControlFlowTestBase): n += 1 return tc.subattr - self.assertTransformedResult( - test_fn, (0, constant_op.constant(10)), - 14, - symbols={'TestClass': TestClass}) - with self.converted( - test_fn, control_flow, {'TestClass': TestClass}) as result: - with self.assertRaisesRegex( - ValueError, "'tc.subattr' must be defined before the loop"): - result.test_fn(constant_op.constant(0), 0) + self.assertTransformedResult(f, (0, constant_op.constant(10)), 14) + tr = self.transform(f, control_flow) + with self.assertRaisesRegex( + ValueError, "'tc.subattr' must be defined before the loop"): + tr(constant_op.constant(0), 0) def test_composite_state_slice_initialized_in_loop(self): - def test_fn(n, x): + def f(n, x): d = {} k = 'subkey' while n < 5: @@ -206,16 +199,16 @@ class WhileStatementTest(ControlFlowTestBase): n += 1 return d - self.assertTransformedResult(test_fn, (0, constant_op.constant(10)), + self.assertTransformedResult(f, (0, constant_op.constant(10)), {'subkey': 14}) - with self.converted(test_fn, control_flow, {}) as result: - with self.assertRaisesRegex( - ValueError, r"'d\[k\]' must be defined before the loop"): - result.test_fn(constant_op.constant(0), 0) + tr = self.transform(f, control_flow) + with self.assertRaisesRegex( + ValueError, r"'d\[k\]' must be defined before the loop"): + tr(constant_op.constant(0), 0) def test_composite_state_literal_slice_initialized_in_loop(self): - def test_fn(n, x): + def f(n, x): d = {} while n < 5: if n == 0: @@ -225,16 +218,16 @@ class WhileStatementTest(ControlFlowTestBase): n += 1 return d - self.assertTransformedResult(test_fn, (0, constant_op.constant(10)), + self.assertTransformedResult(f, (0, constant_op.constant(10)), {'subkey': 14}) - with self.converted(test_fn, control_flow, {}) as result: - with self.assertRaisesRegex( - ValueError, r"'d\['subkey'\]' must be defined before the loop"): - result.test_fn(constant_op.constant(0), 0) + tr = self.transform(f, control_flow) + with self.assertRaisesRegex( + ValueError, r"'d\['subkey'\]' must be defined before the loop"): + tr(constant_op.constant(0), 0) def test_composite_state_slice_aliased_to_local(self): - def test_fn(n, x): + def f(n, x): d = {} while n < 5: k = 'subkey' @@ -242,15 +235,15 @@ class WhileStatementTest(ControlFlowTestBase): n += 1 return d - self.assertTransformedResult(test_fn, (0, constant_op.constant(10)), + self.assertTransformedResult(f, (0, constant_op.constant(10)), {'subkey': 11}) - with self.converted(test_fn, control_flow, {}) as result: - # TODO(b/136999953): Better error message. - # Note that this error happens at execution time. - with self.assertRaises(errors.InaccessibleTensorError): - graph_fn = def_function.function(result.test_fn, autograph=False) - self.evaluate( - graph_fn(constant_op.constant(0), constant_op.constant(5))) + tr = self.transform(f, control_flow) + # TODO(b/136999953): Better error message. + # Note that this error happens at execution time. + with self.assertRaises(errors.InaccessibleTensorError): + graph_fn = def_function.function(tr, autograph=False) + self.evaluate( + graph_fn(constant_op.constant(0), constant_op.constant(5))) def test_local_composite_attr(self): @@ -259,19 +252,18 @@ class WhileStatementTest(ControlFlowTestBase): def __init__(self): self.x = constant_op.constant(3) - def test_fn(n): + def f(n): while n > 0: tc = TestClass() tc.x = tc.x n -= 1 return n - self.assertTransformedResult( - test_fn, constant_op.constant(5), 0, symbols={'TestClass': TestClass}) + self.assertTransformedResult(f, constant_op.constant(5), 0) def test_local_composite_slice(self): - def test_fn(n): + def f(n): while n > 0: d = {'x': n} k = 'x' @@ -279,26 +271,26 @@ class WhileStatementTest(ControlFlowTestBase): n -= 1 return n - self.assertTransformedResult(test_fn, constant_op.constant(5), 0, {}) + self.assertTransformedResult(f, constant_op.constant(5), 0) def test_local_composite_literal_slice(self): - def test_fn(n): + def f(n): while n > 0: d = {'x': n} d['x'] = d['x'] n -= 1 return n - self.assertTransformedResult(test_fn, constant_op.constant(5), 0, {}) + self.assertTransformedResult(f, constant_op.constant(5), 0) def test_non_tensor_state(self): - # This class is ok to be in a tf.while_loop's state. + # This class is ok to be in a tf.while's state. class TestClass(collections.namedtuple('TestClass', ('x'))): pass - def test_fn(n): + def f(n): tc = TestClass([constant_op.constant(0)]) while n > 0: tc = TestClass([constant_op.constant(3)]) @@ -306,9 +298,7 @@ class WhileStatementTest(ControlFlowTestBase): n -= 1 return tc.x[0] - ns = {'TestClass': TestClass, 'constant_op': constant_op} - self.assertTransformedResult( - test_fn, constant_op.constant(5), 4, symbols=ns) + self.assertTransformedResult(f, constant_op.constant(5), 4) def test_non_tensor_state_illegal_type(self): @@ -317,20 +307,20 @@ class WhileStatementTest(ControlFlowTestBase): def __init__(self): self.x = [constant_op.constant(3)] - def test_fn(n): + def f(n): while n > 0: tc = TestClass() tc.x[0] = tc.x[0] + 1 n -= 1 return tc.x[0] - with self.converted( - test_fn, control_flow, {'TestClass': TestClass}) as result: - # The tested function would require `tc` to become part of the while loop - # state, but TensorFlow doesn't support classes at the moment. - with self.assertRaisesRegexp( - ValueError, 'tc.*must be defined before the loop'): - result.test_fn(constant_op.constant(5)) + tr = self.transform(f, control_flow) + + # The tested function would require `tc` to become part of the while loop + # state, but TensorFlow doesn't support classes at the moment. + with self.assertRaisesRegex( + ValueError, 'tc.*must be defined before the loop'): + tr(constant_op.constant(5)) def test_dispatches_by_cond_only(self): @@ -343,27 +333,27 @@ class WhileStatementTest(ControlFlowTestBase): def __add__(self, other): return TensorIncompatibleNumeric(self.val + other) - def test_fn(n, s): + def f(n, s): while n > 0: n -= 1 s += n return s - self.assertTransformedResult(test_fn, (constant_op.constant(5), 0), 10) - with self.converted(test_fn, control_flow, {}) as result: - # n alone controls the staging. When the loop is not staged, Python - # knows how to add the two objects. But when staged, tf.while_loop will - # not know how to deal with the TensorIncompatibleNumeric object. - self.assertEqual(result.test_fn(5, TensorIncompatibleNumeric(0)).val, 10) - with self.assertRaises(TypeError): - result.test_fn(constant_op.constant(5), TensorIncompatibleNumeric(0)) + self.assertTransformedResult(f, (constant_op.constant(5), 0), 10) + tr = self.transform(f, control_flow) + # n alone controls the staging. When the loop is not staged, Python + # knows how to add the two objects. But when staged, tf.while will + # not know how to deal with the TensorIncompatibleNumeric object. + self.assertEqual(tr(5, TensorIncompatibleNumeric(0)).val, 10) + with self.assertRaises(TypeError): + tr(constant_op.constant(5), TensorIncompatibleNumeric(0)) class IfStatementTest(ControlFlowTestBase): def test_basic(self): - def test_fn(n): + def f(n): a = 0 b = 0 if n > 0: @@ -372,20 +362,20 @@ class IfStatementTest(ControlFlowTestBase): b = 2 * n return a, b - self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0)) - self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2)) + self.assertTransformedResult(f, constant_op.constant(1), (-1, 0)) + self.assertTransformedResult(f, constant_op.constant(-1), (0, -2)) def test_sparse_tensor(self): - def test_fn(cond, a): + def f(cond, a): if cond: a = -a return a st = sparse_tensor.SparseTensor( indices=((0,),), values=(0,), dense_shape=(1,)) - self.assertTransformedResult(test_fn, (st, constant_op.constant(1)), -1) - self.assertTransformedResult(test_fn, (None, constant_op.constant(1)), 1) + self.assertTransformedResult(f, (st, constant_op.constant(1)), -1) + self.assertTransformedResult(f, (None, constant_op.constant(1)), 1) def test_complex_outputs(self): @@ -395,7 +385,7 @@ class IfStatementTest(ControlFlowTestBase): self.a = a self.b = b - def test_fn(n, obj): + def f(n, obj): obj.a = 0 obj.b = 0 if n > 0: @@ -404,94 +394,94 @@ class IfStatementTest(ControlFlowTestBase): obj.b = 2 * n return obj - with self.converted(test_fn, control_flow, {}) as result: - res_obj = result.test_fn(constant_op.constant(1), TestClass(0, 0)) - self.assertValuesEqual((res_obj.a, res_obj.b), (-1, 0)) - res_obj = result.test_fn(constant_op.constant(-1), TestClass(0, 0)) - self.assertValuesEqual((res_obj.a, res_obj.b), (0, -2)) + tr = self.transform(f, control_flow) + + res_obj = tr(constant_op.constant(1), TestClass(0, 0)) + self.assertValuesEqual((res_obj.a, res_obj.b), (-1, 0)) + res_obj = tr(constant_op.constant(-1), TestClass(0, 0)) + self.assertValuesEqual((res_obj.a, res_obj.b), (0, -2)) def test_single_output(self): - def test_fn(n): + def f(n): if n > 0: n = -n return n - self.assertTransformedResult(test_fn, constant_op.constant(1), -1) + self.assertTransformedResult(f, constant_op.constant(1), -1) def test_unbalanced(self): - def test_fn(n): + def f(n): if n > 0: n = 3 return n - self.assertTransformedResult(test_fn, constant_op.constant(2), 3) - self.assertTransformedResult(test_fn, constant_op.constant(-3), -3) + self.assertTransformedResult(f, constant_op.constant(2), 3) + self.assertTransformedResult(f, constant_op.constant(-3), -3) def test_unbalanced_raising(self): - def test_fn(n): + def f(n): if n > 0: n = n + 1 raise ValueError() return n - self.assertTransformedResult(test_fn, -3, -3) + self.assertTransformedResult(f, -3, -3) - with self.converted(test_fn, control_flow, {}) as result: - with self.assertRaises(ValueError): - result.test_fn(1) + tr = self.transform(f, control_flow) + + with self.assertRaises(ValueError): + tr(1) def test_local_var(self): - def test_fn(n): + def f(n): if n > 0: b = 4 n = b + 1 return n - self.assertTransformedResult(test_fn, constant_op.constant(1), 5) - self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) + self.assertTransformedResult(f, constant_op.constant(1), 5) + self.assertTransformedResult(f, constant_op.constant(-1), -1) def test_local_remains_local(self): - def test_fn(n): + def f(n): if n > 0: b = 4 n = b + 1 return n - self.assertTransformedResult(test_fn, constant_op.constant(1), 5) - self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) + self.assertTransformedResult(f, constant_op.constant(1), 5) + self.assertTransformedResult(f, constant_op.constant(-1), -1) def test_no_outputs(self): - def test_fn(n): + def f(n): if n > 0: b = 4 # pylint:disable=unused-variable return n - # Without side effect guards, the if statement will stage a cond, - # but that will be pruned at execution. - self.assertTransformedResult(test_fn, constant_op.constant(1), 1) - self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) + self.assertTransformedResult(f, constant_op.constant(1), 1) + self.assertTransformedResult(f, constant_op.constant(-1), -1) def test_created_outputs(self): - def test_fn(i): + def f(i): if i == 0: result = i - 1 else: result = i + 1 return result - self.assertTransformedResult(test_fn, 0, -1) - self.assertTransformedResult(test_fn, 1, 2) + self.assertTransformedResult(f, 0, -1) + self.assertTransformedResult(f, 1, 2) def test_created_loop_local_outputs(self): - def test_fn(n, x): + def f(n, x): for i in n: if i == 0: result = i - 1 @@ -501,11 +491,11 @@ class IfStatementTest(ControlFlowTestBase): x += 1 return x - self.assertTransformedResult(test_fn, (range(5), 10), 14) + self.assertTransformedResult(f, (range(5), 10), 14) def test_created_loop_variable(self): - def test_fn(n, x): + def f(n, x): for i in n: if i == 0: result = i - 1 @@ -514,22 +504,26 @@ class IfStatementTest(ControlFlowTestBase): x += 1 return x - self.assertTransformedResult(test_fn, (range(5), 10), 14) + self.assertTransformedResult(f, (range(5), 10), 14) def test_unaffected_global(self): - def test_fn(i): - global g # pylint:disable=global-variable-undefined - if i == 0: - g = i - 1 - return g + global for_unaffected_global + for_unaffected_global = 3 - self.assertTransformedResult(test_fn, 1, 3, symbols={'g': 3}) - self.assertTransformedResult(test_fn, 0, -1, symbols={'g': 3}) + def f(i): + global for_unaffected_global + if i == 0: + for_unaffected_global = i - 1 + return for_unaffected_global + + self.assertTransformedResult(f, 1, 3) + self.assertTransformedResult(f, 0, -1) + self.assertEqual(for_unaffected_global, -1) def test_unaffected_nonlocal(self): - def test_fn(i): + def f(i): def inner_fn(): nonlocal n if i == 0: @@ -539,12 +533,12 @@ class IfStatementTest(ControlFlowTestBase): inner_fn() return n - self.assertTransformedResult(test_fn, 1, 3) - self.assertTransformedResult(test_fn, 0, -1) + self.assertTransformedResult(f, 1, 3) + self.assertTransformedResult(f, 0, -1) def test_output_defined_in_prior_except(self): - def test_fn(i): + def f(i): try: raise ValueError() except ValueError: @@ -553,8 +547,8 @@ class IfStatementTest(ControlFlowTestBase): x = i - 1 return x - self.assertTransformedResult(test_fn, 1, 1) - self.assertTransformedResult(test_fn, 0, -1) + self.assertTransformedResult(f, 1, 1) + self.assertTransformedResult(f, 0, -1) def test_unbalanced_multiple_composites(self): @@ -564,7 +558,7 @@ class IfStatementTest(ControlFlowTestBase): self.b = 2 self.c = 3 - def test_fn(x, condition): + def f(x, condition): z = 5 if condition: @@ -574,9 +568,9 @@ class IfStatementTest(ControlFlowTestBase): return x.b, x.c, z - self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(True)), + self.assertTransformedResult(f, (Foo(), constant_op.constant(True)), (7, 11, 13)) - self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(False)), + self.assertTransformedResult(f, (Foo(), constant_op.constant(False)), (2, 3, 5)) def test_unbalanced_composite(self): @@ -586,7 +580,7 @@ class IfStatementTest(ControlFlowTestBase): def __init__(self): self.b = 2 - def test_fn(x, condition): + def f(x, condition): z = 5 if condition: @@ -595,9 +589,9 @@ class IfStatementTest(ControlFlowTestBase): return x.b, z - self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(True)), + self.assertTransformedResult(f, (Foo(), constant_op.constant(True)), (7, 13)) - self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(False)), + self.assertTransformedResult(f, (Foo(), constant_op.constant(False)), (2, 5)) @@ -605,7 +599,7 @@ class ForStatementTest(ControlFlowTestBase): def test_basic(self): - def test_fn(l): + def f(l): s1 = 0 s2 = 0 for e in l: @@ -613,21 +607,21 @@ class ForStatementTest(ControlFlowTestBase): s2 += e * e return s1, s2 - self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), (4, 10)) + self.assertTransformedResult(f, constant_op.constant([1, 3]), (4, 10)) empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32) - self.assertTransformedResult(test_fn, empty_vector, (0, 0)) + self.assertTransformedResult(f, empty_vector, (0, 0)) def test_single_output(self): - def test_fn(l): + def f(l): s = 0 for e in l: s += e return s - self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), 4) + self.assertTransformedResult(f, constant_op.constant([1, 3]), 4) empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32) - self.assertTransformedResult(test_fn, empty_vector, 0) + self.assertTransformedResult(f, empty_vector, 0) def test_iterated_expression(self): @@ -637,26 +631,23 @@ class ForStatementTest(ControlFlowTestBase): eval_count[0] += 1 return x - def test_fn(n): + def f(n): s = 0 for e in count_evals(range(n)): s += e return s - ns = {'count_evals': count_evals} - node, ctx = self.prepare(test_fn, ns) - node = control_flow.transform(node, ctx) + tr = self.transform(f, control_flow) - with self.compiled(node, ns) as result: - self.assertEqual(result.test_fn(5), 10) - self.assertEqual(eval_count[0], 1) + self.assertEqual(tr(5), 10) + self.assertEqual(eval_count[0], 1) def test_composite_state_initialized_in_loop(self): class TestClass(object): pass - def test_fn(n, x): + def f(n, x): tc = TestClass() for i in n: if i == 0: @@ -665,37 +656,97 @@ class ForStatementTest(ControlFlowTestBase): tc.x = tc.x + i return tc.x - self.assertTransformedResult( - test_fn, (range(5), constant_op.constant(10)), - 20, - symbols={'TestClass': TestClass}) - with self.converted( - test_fn, control_flow, {'TestClass': TestClass}) as result: - with self.assertRaisesRegex( - ValueError, "'tc.x' must be defined before the loop"): - result.test_fn(constant_op.constant(list(range(5))), 0) + self.assertTransformedResult(f, (range(5), constant_op.constant(10)), 20) + tr = self.transform(f, control_flow) + + with self.assertRaisesRegex( + ValueError, "'tc.x' must be defined before the loop"): + tr(constant_op.constant(list(range(5))), 0) def test_tuple_unpacking(self): - def test_fn(x_list): - z = tf.constant(0) # pylint:disable=undefined-variable + + def f(x_list): + z = constant_op.constant(0) # pylint:disable=undefined-variable for i, x in enumerate(x_list): z = z + x + i return z - self.assertTransformedResult(test_fn, [3, 3], 7) + self.assertTransformedResult(f, [3, 3], 7) def test_with_comprehension_in_body(self): - def test_fn(l, n): + def f(l, n): s = constant_op.constant(list(range(n))) for _ in l: s += constant_op.constant([a for a in range(n)]) return s - self.assertTransformedResult( - test_fn, (constant_op.constant([1, 2, 3]), 5), - np.array(range(5)) * 4, - symbols={'constant_op': constant_op}) + self.assertTransformedResult(f, (constant_op.constant([1, 2, 3]), 5), + np.array(range(5)) * 4) + + +class AdvancedControlFlowTest(ControlFlowTestBase): + + def assertTransformedEquivalent(self, f, *inputs): + tr = self.transform( + f, (break_statements, continue_statements, control_flow)) + self.assertEqual(f(*inputs), tr(*inputs)) + + def test_while_with_else(self): + + def f(x): + while x > 2: + x /= 2 + else: + x += 1 + return x + + self.assertTransformedEquivalent(f, 4) + self.assertTransformedEquivalent(f, 2) + + def test_while_with_else_and_break(self): + + def f(cond1): + x = 8 + while x > 2: + x /= 2 + if cond1: + break + else: + x += 1 + return x + + self.assertTransformedEquivalent(f, True) + self.assertTransformedEquivalent(f, False) + + def test_for_with_else(self): + + def f(l): + res = 0 + for x in l: + res += x + else: + res += 1 + return res + + self.assertTransformedEquivalent(f, []) + self.assertTransformedEquivalent(f, [1, 2]) + + def test_for_with_else_and_break(self): + + def f(flag): + l = [1, 2, 3] + res = 0 + for x in l: + res += x + if flag: + break + else: + res += 1 + return res + + self.assertTransformedEquivalent(f, True) + self.assertTransformedEquivalent(f, False) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/converters/directives_test.py b/tensorflow/python/autograph/converters/directives_test.py index f86e7a9a0bd..ac8730fe185 100644 --- a/tensorflow/python/autograph/converters/directives_test.py +++ b/tensorflow/python/autograph/converters/directives_test.py @@ -22,7 +22,6 @@ from tensorflow.python.autograph.converters import directives as directives_conv from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.lang import directives from tensorflow.python.autograph.pyct import anno -from tensorflow.python.autograph.pyct import parser from tensorflow.python.platform import test @@ -30,13 +29,12 @@ class DirectivesTest(converter_testing.TestCase): def test_local_target(self): - def test_fn(): + def f(): l = [] string_var = 0 directives.set_element_type(l, 'a', string_var) - node, ctx = self.prepare(test_fn, {'directives': directives}) - node = directives_converter.transform(node, ctx) + _, node, _ = self.transform(f, directives_converter, include_ast=True) def_, = anno.getanno(node.body[0].targets[0], anno.Static.DEFINITIONS) @@ -46,11 +44,11 @@ class DirectivesTest(converter_testing.TestCase): def test_argument_target(self): - def test_fn(a): + def f(a): directives.set_element_type(a, 1, shape=2) + pass - node, ctx = self.prepare(test_fn, {'directives': directives}) - node = directives_converter.transform(node, ctx) + _, node, _ = self.transform(f, directives_converter, include_ast=True) def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) d = def_.directives[directives.set_element_type] @@ -59,13 +57,13 @@ class DirectivesTest(converter_testing.TestCase): def test_loop_target(self): - def test_fn(): + def f(): a = True while True: directives.set_loop_options(parallel_iterations=10, back_prop=a) + pass - node, ctx = self.prepare(test_fn, {'directives': directives}) - node = directives_converter.transform(node, ctx) + _, node, _ = self.transform(f, directives_converter, include_ast=True) d = anno.getanno(node.body[1], anno.Basic.DIRECTIVES) d = d[directives.set_loop_options] @@ -75,40 +73,23 @@ class DirectivesTest(converter_testing.TestCase): def test_loop_target_no_loop(self): - def test_fn(): + def f(): directives.set_loop_options() + pass - node, ctx = self.prepare(test_fn, {'directives': directives}) with self.assertRaisesRegexp(ValueError, 'must be used inside a statement'): - node = directives_converter.transform(node, ctx) + self.transform(f, directives_converter, include_ast=True) def test_loop_target_not_first(self): - def test_fn(): + def f(): a = 1 while True: a = 2 directives.set_loop_options(parallel_iterations=10, back_prop=a) - node, ctx = self.prepare(test_fn, {'directives': directives}) with self.assertRaisesRegexp(ValueError, 'must be the first statement'): - node = directives_converter.transform(node, ctx) - - def test_invalid_default(self): - - def invalid_directive(valid_arg, invalid_default=object()): - del valid_arg - del invalid_default - return - - def call_invalid_directive(): - invalid_directive(1) - - node, _ = parser.parse_entity(call_invalid_directive, ()) - # Find the call to the invalid directive - node = node.body[0].value - with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'): - directives_converter._map_args(node, invalid_directive) + self.transform(f, directives_converter, include_ast=True) def test_value_verification_does_not_trigger_properties(self): @@ -122,11 +103,11 @@ class DirectivesTest(converter_testing.TestCase): tc = TestClass() - def test_fn(): + def f(): return tc.b + 1 - node, ctx = self.prepare(test_fn, {'tc': tc}) - node = directives_converter.transform(node, ctx) + _, node, _ = self.transform(f, directives_converter, include_ast=True) + self.assertIsNotNone(node) def test_value_verification_does_not_trigger_getattr(self): @@ -143,11 +124,11 @@ class DirectivesTest(converter_testing.TestCase): tc = TestClass() - def test_fn(): + def f(): return tc.b + 1 - node, ctx = self.prepare(test_fn, {'tc': tc}) - node = directives_converter.transform(node, ctx) + _, node, _ = self.transform(f, directives_converter, include_ast=True) + self.assertIsNotNone(node) self.assertFalse(tc.getattr_called) diff --git a/tensorflow/python/autograph/converters/functions_test.py b/tensorflow/python/autograph/converters/functions_test.py index 2a51ef71ebf..f659c3fdf83 100644 --- a/tensorflow/python/autograph/converters/functions_test.py +++ b/tensorflow/python/autograph/converters/functions_test.py @@ -23,51 +23,49 @@ from tensorflow.python.autograph.converters import return_statements from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter_testing +from tensorflow.python.autograph.impl import api from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.platform import test class FunctionTransformer(converter_testing.TestCase): - @test_util.run_deprecated_v1 def test_basic(self): - def test_fn(l): + def f(l): """Docstring.""" a = 1 l += a return l - with self.converted(test_fn, functions, {}) as result: - result_op = result.test_fn(constant_op.constant(1)) - self.assertIn('test_fn/', result_op.op.name) - self.assertEqual('Docstring.', result.test_fn.__doc__) + tr = self.transform(f, functions) + + result_op = tr(constant_op.constant(1)) + self.assertIn('f/', result_op.op.name) + self.assertEqual('Docstring.', tr.__doc__) - @test_util.run_deprecated_v1 def test_multiline_docstring(self): - tf = None - - def test_fn(): + def f(): """First sentence. Second sentence. + + Returns: + Something. """ - return tf.constant(1) + return constant_op.constant(1) - with self.converted(test_fn, functions, {}, - (constant_op.constant,)) as result: - result_op = result.test_fn() - self.assertIn('test_fn/', result_op.op.name) - self.assertIn('First sentence.', result.test_fn.__doc__) - self.assertIn('Second sentence.', result.test_fn.__doc__) + tr = self.transform(f, functions) + + result_op = tr() + self.assertIn('f/', result_op.op.name) + self.assertIn('First sentence.', tr.__doc__) + self.assertIn('Second sentence.', tr.__doc__) - @test_util.run_deprecated_v1 def test_nested_functions(self): - def test_fn(l): + def f(l): def inner_fn(i): return i + 1 @@ -75,41 +73,35 @@ class FunctionTransformer(converter_testing.TestCase): l += 1 return l, inner_fn(l) - with self.converted(test_fn, (functions, return_statements), {}, - (ops.name_scope,)) as result: - first, second = result.test_fn(constant_op.constant(1)) - self.assertIn('test_fn/', first.op.name) - self.assertNotIn('inner_fn', first.op.name) - self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name) + tr = self.transform(f, (functions, return_statements)) + + first, second = tr(constant_op.constant(1)) + self.assertIn('f/', first.op.name) + self.assertNotIn('inner_fn', first.op.name) + self.assertIn('f/inner_fn/', second.op.inputs[0].name) - @test_util.run_deprecated_v1 def test_conversion_context_preserves_in_inner_functions(self): def inner_fn_callee(): self.assertEqual( ag_ctx.control_status_ctx().status, ag_ctx.Status.DISABLED) - def test_fn(): + def f(): def inner_fn(): inner_fn_callee() with ag_ctx.ControlStatusCtx( ag_ctx.Status.DISABLED, converter.ConversionOptions(recursive=True)): inner_fn() - ns = { - 'inner_fn_callee': inner_fn_callee, - 'ag_ctx': ag_ctx, - 'converter': converter - } - with self.converted(test_fn, functions, ns) as result: - result.test_fn() + tr = self.transform(f, functions) + + tr() - @test_util.run_deprecated_v1 def test_method(self): class TestClass(object): - def test_fn(self, l): + def f(self, l): def inner_fn(i): return i + 1 @@ -117,25 +109,22 @@ class FunctionTransformer(converter_testing.TestCase): l += 1 return l, inner_fn(l) - ns = {'TestClass': TestClass} - node, ctx = self.prepare(TestClass, ns) - node = functions.transform(node, ctx) - node = return_statements.transform(node, ctx) + tr = self.transform(TestClass.f, (functions, return_statements)) - with self.compiled(node, {}, (ops.name_scope,)) as result: - first, second = result.TestClass().test_fn(constant_op.constant(1)) - self.assertIn('test_fn/', first.op.name) - self.assertNotIn('inner_fn', first.op.name) - self.assertIn('test_fn/inner_fn/', second.op.inputs[0].name) + first, second = tr(TestClass(), constant_op.constant(1)) + self.assertIn('f/', first.op.name) + self.assertNotIn('inner_fn', first.op.name) + self.assertIn('f/inner_fn/', second.op.inputs[0].name) def test_lambda_in_return_value(self): - def test_fn(): + def f(): return lambda x: x + 1 - with self.converted(test_fn, functions, {}) as result: - result_l = result.test_fn() - self.assertTrue(result_l.fake_autograph_artifact) + tr = self.transform(f, functions) + + result_l = tr() + self.assertTrue(api.is_autograph_artifact(result_l)) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/converters/list_comprehensions_test.py b/tensorflow/python/autograph/converters/list_comprehensions_test.py index 1e66139af63..7a075903673 100644 --- a/tensorflow/python/autograph/converters/list_comprehensions_test.py +++ b/tensorflow/python/autograph/converters/list_comprehensions_test.py @@ -25,36 +25,36 @@ from tensorflow.python.platform import test class ListCompTest(converter_testing.TestCase): - def assertTransformedEquivalent(self, test_fn, *inputs): - with self.converted(test_fn, list_comprehensions, {}) as result: - self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + def assertTransformedEquivalent(self, f, *inputs): + tr = self.transform(f, list_comprehensions) + self.assertEqual(f(*inputs), tr(*inputs)) def test_basic(self): - def test_fn(l): + def f(l): s = [e * e for e in l] return s - self.assertTransformedEquivalent(test_fn, []) - self.assertTransformedEquivalent(test_fn, [1, 2, 3]) + self.assertTransformedEquivalent(f, []) + self.assertTransformedEquivalent(f, [1, 2, 3]) def test_multiple_generators(self): - def test_fn(l): - s = [e * e for sublist in l for e in sublist] + def f(l): + s = [e * e for sublist in l for e in sublist] # pylint:disable=g-complex-comprehension return s - self.assertTransformedEquivalent(test_fn, []) - self.assertTransformedEquivalent(test_fn, [[1], [2], [3]]) + self.assertTransformedEquivalent(f, []) + self.assertTransformedEquivalent(f, [[1], [2], [3]]) def test_cond(self): - def test_fn(l): + def f(l): s = [e * e for e in l if e > 1] return s - self.assertTransformedEquivalent(test_fn, []) - self.assertTransformedEquivalent(test_fn, [1, 2, 3]) + self.assertTransformedEquivalent(f, []) + self.assertTransformedEquivalent(f, [1, 2, 3]) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/converters/lists_test.py b/tensorflow/python/autograph/converters/lists_test.py index 9436b69d749..75280730598 100644 --- a/tensorflow/python/autograph/converters/lists_test.py +++ b/tensorflow/python/autograph/converters/lists_test.py @@ -18,12 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.autograph.converters import directives as directives_converter from tensorflow.python.autograph.converters import lists from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.lang import directives from tensorflow.python.autograph.lang import special_functions -from tensorflow.python.autograph.pyct import anno -from tensorflow.python.autograph.pyct import parser from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -31,101 +30,81 @@ from tensorflow.python.ops import list_ops from tensorflow.python.platform import test -tf = None # Will be replaced by a mock. - - class ListTest(converter_testing.TestCase): def test_empty_list(self): - def test_fn(): + def f(): return [] - with self.converted(test_fn, lists, {}) as result: - tl = result.test_fn() - # Empty tensor lists cannot be evaluated or stacked. - self.assertTrue(isinstance(tl, ops.Tensor)) - self.assertEqual(tl.dtype, dtypes.variant) + tr = self.transform(f, lists) + + tl = tr() + # Empty tensor lists cannot be evaluated or stacked. + self.assertIsInstance(tl, ops.Tensor) + self.assertEqual(tl.dtype, dtypes.variant) def test_initialized_list(self): - def test_fn(): + def f(): return [1, 2, 3] - with self.converted(test_fn, lists, {}) as result: - self.assertAllEqual(result.test_fn(), [1, 2, 3]) + tr = self.transform(f, lists) + + self.assertAllEqual(tr(), [1, 2, 3]) def test_list_append(self): - def test_fn(): + def f(): l = special_functions.tensor_list([1]) l.append(2) l.append(3) return l - ns = {'special_functions': special_functions} - with self.converted(test_fn, lists, ns) as result: - with self.cached_session() as sess: - tl = result.test_fn() - r = list_ops.tensor_list_stack(tl, dtypes.int32) - self.assertAllEqual(self.evaluate(r), [1, 2, 3]) + tr = self.transform(f, lists) + + tl = tr() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(self.evaluate(r), [1, 2, 3]) def test_list_pop(self): - def test_fn(): + def f(): l = special_functions.tensor_list([1, 2, 3]) + directives.set_element_type(l, dtype=dtypes.int32, shape=()) s = l.pop() return s, l - ns = {'special_functions': special_functions} - node, ctx = self.prepare(test_fn, ns) - def_, = anno.getanno(node.body[0].targets[0], - anno.Static.ORIG_DEFINITIONS) - def_.directives[directives.set_element_type] = { - 'dtype': parser.parse_expression('tf.int32'), - 'shape': parser.parse_expression('()'), - } - node = lists.transform(node, ctx) + tr = self.transform(f, (directives_converter, lists)) - with self.compiled(node, ns, (dtypes.int32,)) as result: - with self.cached_session() as sess: - ts, tl = result.test_fn() - r = list_ops.tensor_list_stack(tl, dtypes.int32) - self.assertAllEqual(self.evaluate(r), [1, 2]) - self.assertAllEqual(self.evaluate(ts), 3) + ts, tl = tr() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(self.evaluate(r), [1, 2]) + self.assertAllEqual(self.evaluate(ts), 3) def test_double_list_pop(self): - def test_fn(l): + def f(l): s = l.pop().pop() return s - with self.converted(test_fn, lists, {}) as result: - test_input = [1, 2, [1, 2, 3]] - # TODO(mdan): Pass a list of lists of tensor when we fully support that. - # For now, we just pass a regular Python list of lists just to verify that - # the two pop calls are sequenced properly. - self.assertAllEqual(result.test_fn(test_input), 3) + tr = self.transform(f, lists) + + test_input = [1, 2, [1, 2, 3]] + # TODO(mdan): Pass a list of lists of tensor when we fully support that. + # For now, we just pass a regular Python list of lists just to verify that + # the two pop calls are sequenced properly. + self.assertAllEqual(tr(test_input), 3) def test_list_stack(self): - def test_fn(): + def f(): l = [1, 2, 3] - return tf.stack(l) + return array_ops.stack(l) - node, ctx = self.prepare(test_fn, {}) - def_, = anno.getanno(node.body[0].targets[0], - anno.Static.ORIG_DEFINITIONS) - def_.directives[directives.set_element_type] = { - 'dtype': parser.parse_expression('tf.int32') - } - node = lists.transform(node, ctx) + tr = self.transform(f, lists) - with self.compiled(node, {}, (array_ops.stack, dtypes.int32)) as result: - with self.cached_session() as sess: - self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3]) - - # TODO(mdan): Add a test with tf.stack with axis kwarg. + self.assertAllEqual(self.evaluate(tr()), [1, 2, 3]) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/converters/logical_expressions_test.py b/tensorflow/python/autograph/converters/logical_expressions_test.py index 67ccd1fb479..d201f746fc6 100644 --- a/tensorflow/python/autograph/converters/logical_expressions_test.py +++ b/tensorflow/python/autograph/converters/logical_expressions_test.py @@ -27,62 +27,59 @@ from tensorflow.python.platform import test class LogicalExpressionTest(converter_testing.TestCase): - @test_util.run_deprecated_v1 def test_equals(self): - def test_fn(a, b): + def f(a, b): return a == b - with self.converted(test_fn, logical_expressions, {}) as result: - with self.cached_session() as sess: - self.assertTrue(sess.run(result.test_fn(constant_op.constant(1), 1))) - self.assertFalse(sess.run(result.test_fn(constant_op.constant(1), 2))) + tr = self.transform(f, logical_expressions) + + self.assertTrue(self.evaluate(tr(constant_op.constant(1), 1))) + self.assertFalse(self.evaluate(tr(constant_op.constant(1), 2))) @test_util.run_deprecated_v1 def test_bool_ops(self): - def test_fn(a, b, c): + def f(a, b, c): return (a or b) and (a or b or c) and not c - with self.converted(test_fn, logical_expressions, {}) as result: - with self.cached_session() as sess: - self.assertTrue( - sess.run(result.test_fn(constant_op.constant(True), False, False))) - self.assertFalse( - sess.run(result.test_fn(constant_op.constant(True), False, True))) + tr = self.transform(f, logical_expressions) + + self.assertTrue(self.evaluate(tr(constant_op.constant(True), False, False))) + self.assertFalse(self.evaluate(tr(constant_op.constant(True), False, True))) - @test_util.run_deprecated_v1 def test_comparison(self): - def test_fn(a, b, c, d): + def f(a, b, c, d): return a < b == c > d - with self.converted(test_fn, logical_expressions, {}) as result: - with self.cached_session() as sess: - # Note: having just the first constant a tensor tests that the - # operations execute in the correct order. If anything other than - # a < b executed first, the result would be a Python scalar and not a - # Tensor. This is valid as long as the dispat is automatic based on - # type. - self.assertTrue( - sess.run(result.test_fn(constant_op.constant(1), 2, 2, 1))) - self.assertFalse( - sess.run(result.test_fn(constant_op.constant(1), 2, 2, 3))) + tr = self.transform(f, logical_expressions) + + # Note: having just the first constant a tensor tests that the + # operations execute in the correct order. If anything other than + # a < b executed first, the result would be a Python scalar and not a + # Tensor. This is valid as long as the dispat is automatic based on + # type. + self.assertTrue(self.evaluate(tr(constant_op.constant(1), 2, 2, 1))) + self.assertFalse(self.evaluate(tr(constant_op.constant(1), 2, 2, 3))) def test_default_ops(self): - def test_fn(a, b): + def f(a, b): return a in b - with self.converted(test_fn, logical_expressions, {}) as result: - self.assertTrue(result.test_fn('a', ('a',))) + tr = self.transform(f, logical_expressions) + + self.assertTrue(tr('a', ('a',))) def test_unary_ops(self): - def test_fn(a): + + def f(a): return ~a, -a, +a - with self.converted(test_fn, logical_expressions, {}) as result: - self.assertEqual(result.test_fn(1), (-2, -1, 1)) + tr = self.transform(f, logical_expressions) + + self.assertEqual(tr(1), (-2, -1, 1)) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/converters/loop_integration_test.py b/tensorflow/python/autograph/converters/loop_integration_test.py deleted file mode 100644 index 351eb7b92cf..00000000000 --- a/tensorflow/python/autograph/converters/loop_integration_test.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Integration Tests for loop.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.autograph.converters import break_statements -from tensorflow.python.autograph.converters import continue_statements -from tensorflow.python.autograph.converters import control_flow -from tensorflow.python.autograph.core import converter_testing -from tensorflow.python.framework import constant_op -from tensorflow.python.platform import test - - -class LoopIntegrationTest(converter_testing.TestCase): - - def assertTransformedEquivalent(self, test_fn, *inputs): - with self.converted(test_fn, - [break_statements, continue_statements, control_flow], - {}, (constant_op.constant,)) as result: - self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) - - def test_while_loop_with_else(self): - - def test_fn(x): - while x > 2: - x /= 2 - else: - x += 1 - return x - - self.assertTransformedEquivalent(test_fn, 4) - self.assertTransformedEquivalent(test_fn, 2) - - def test_while_loop_with_else_and_break(self): - - def test_fn(cond1): - x = 8 - while x > 2: - x /= 2 - if cond1: - break - else: - x += 1 - return x - - self.assertTransformedEquivalent(test_fn, True) - self.assertTransformedEquivalent(test_fn, False) - - def test_for_loop_with_else(self): - - def test_fn(l): - res = 0 - for x in l: - res += x - else: - res += 1 - return res - - self.assertTransformedEquivalent(test_fn, []) - self.assertTransformedEquivalent(test_fn, [1, 2]) - - def test_for_loop_with_else_and_break(self): - - def test_fn(flag): - l = [1, 2, 3] - res = 0 - for x in l: - res += x - if flag: - break - else: - res += 1 - return res - - self.assertTransformedEquivalent(test_fn, True) - self.assertTransformedEquivalent(test_fn, False) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/autograph/converters/return_statements_test.py b/tensorflow/python/autograph/converters/return_statements_test.py index 3f1e6a0bd97..de98d3b1b80 100644 --- a/tensorflow/python/autograph/converters/return_statements_test.py +++ b/tensorflow/python/autograph/converters/return_statements_test.py @@ -27,81 +27,80 @@ from tensorflow.python.platform import test class SingleReturnTest(converter_testing.TestCase): - def assertTransformedEquivalent(self, test_fn, *inputs): - ns = {'ops': ops} - with self.converted(test_fn, (functions, return_statements), ns) as result: - self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + def assertTransformedEquivalent(self, f, *inputs): + tr = self.transform(f, (functions, return_statements)) + self.assertEqual(f(*inputs), tr(*inputs)) def test_straightline(self): - def test_fn(x): + def f(x): return x * x - self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(f, 2) def test_superfluous_returns(self): - def test_fn(): + def f(): retval = 1 return retval retval = 2 # pylint:disable=unreachable return retval - self.assertTransformedEquivalent(test_fn) + self.assertTransformedEquivalent(f) def test_superfluous_returns_adjacent(self): - def test_fn(): + def f(): return 1 return 2 # pylint:disable=unreachable - self.assertTransformedEquivalent(test_fn) + self.assertTransformedEquivalent(f) def test_conditional(self): - def test_fn(x): + def f(x): if x > 0: return x else: return x * x - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, -2) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, -2) def test_conditional_missing_else(self): - def test_fn(x): + def f(x): if x > 0: return x - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, -2) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, -2) def test_conditional_missing_else_then_default(self): - def test_fn(x): + def f(x): if x > 0: return x return x * x - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, -2) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, -2) def test_conditional_else_only_then_default(self): - def test_fn(x): + def f(x): if x < 0: x *= x else: return x return x - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, -2) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, -2) def test_conditional_nested(self): - def test_fn(x): + def f(x): if x > 0: if x < 5: return x @@ -110,53 +109,53 @@ class SingleReturnTest(converter_testing.TestCase): else: return x * x * x - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, -2) - self.assertTransformedEquivalent(test_fn, 5) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, -2) + self.assertTransformedEquivalent(f, 5) def test_context_manager(self): - def test_fn(x): + def f(x): with ops.name_scope(''): return x * x - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, -2) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, -2) def test_context_manager_in_conditional(self): - def test_fn(x): + def f(x): if x > 0: with ops.name_scope(''): return x * x else: return x - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, -2) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, -2) def text_conditional_in_context_manager(self): - def test_fn(x): + def f(x): with ops.name_scope(''): if x > 0: return x * x else: return x - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, -2) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, -2) def test_no_return(self): - def test_fn(x): + def f(x): x *= x - self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(f, 2) def test_nested_function(self): - def test_fn(x): + def f(x): def inner_fn(y): if y > 0: @@ -166,33 +165,33 @@ class SingleReturnTest(converter_testing.TestCase): return inner_fn(x) - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, -2) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, -2) def test_nested_function_in_control_flow(self): - def test_fn(x): + def f(x): if x: def inner_fn(y): return y inner_fn(x) - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, -2) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, -2) def test_for_loop(self): - def test_fn(n): + def f(n): for _ in range(n): return 1 - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, 0) def test_while_loop(self): - def test_fn(n): + def f(n): i = 0 s = 0 while i < n: @@ -202,23 +201,23 @@ class SingleReturnTest(converter_testing.TestCase): return s return -1 - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 2) - self.assertTransformedEquivalent(test_fn, 4) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 2) + self.assertTransformedEquivalent(f, 4) def test_null_return(self): - def test_fn(n): + def f(n): if n > 4: return return - self.assertTransformedEquivalent(test_fn, 4) - self.assertTransformedEquivalent(test_fn, 5) + self.assertTransformedEquivalent(f, 4) + self.assertTransformedEquivalent(f, 5) def test_nested_multiple_withs(self): - def test_fn(x): + def f(x): v = [] while x > 0: x -= 1 @@ -230,14 +229,14 @@ class SingleReturnTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, 0) - self.assertTransformedEquivalent(test_fn, 1) - self.assertTransformedEquivalent(test_fn, 3) - self.assertTransformedEquivalent(test_fn, 4) + self.assertTransformedEquivalent(f, 0) + self.assertTransformedEquivalent(f, 1) + self.assertTransformedEquivalent(f, 3) + self.assertTransformedEquivalent(f, 4) def test_multiple_returns_in_nested_scope(self): - def test_fn(a): + def f(a): v = [] for x in a: x -= 1 @@ -250,10 +249,10 @@ class SingleReturnTest(converter_testing.TestCase): v.append(x) return v - self.assertTransformedEquivalent(test_fn, []) - self.assertTransformedEquivalent(test_fn, [1]) - self.assertTransformedEquivalent(test_fn, [2]) - self.assertTransformedEquivalent(test_fn, [1, 2, 3]) + self.assertTransformedEquivalent(f, []) + self.assertTransformedEquivalent(f, [1]) + self.assertTransformedEquivalent(f, [2]) + self.assertTransformedEquivalent(f, [1, 2, 3]) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/converters/slices_test.py b/tensorflow/python/autograph/converters/slices_test.py index 2fea1c7f81f..5a4bd6f65bd 100644 --- a/tensorflow/python/autograph/converters/slices_test.py +++ b/tensorflow/python/autograph/converters/slices_test.py @@ -18,11 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.autograph.converters import directives as directives_converter from tensorflow.python.autograph.converters import slices from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.lang import directives -from tensorflow.python.autograph.pyct import anno -from tensorflow.python.autograph.pyct import parser from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import list_ops @@ -33,42 +32,26 @@ class SliceTest(converter_testing.TestCase): def test_index_access(self): - def test_fn(l): + def f(l): + directives.set_element_type(l, dtypes.int32) return l[1] - node, ctx = self.prepare(test_fn, {}) - def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) - def_.directives[directives.set_element_type] = { - 'dtype': parser.parse_expression('tf.int32') - } - node = slices.transform(node, ctx) + tr = self.transform(f, (directives_converter, slices)) - with self.compiled(node, {}, (dtypes.int32,)) as result: - with self.cached_session() as sess: - tl = list_ops.tensor_list_from_tensor( - [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32)) - y = result.test_fn(tl) - self.assertEqual(2, self.evaluate(y)) + tl = list_ops.tensor_list_from_tensor( + [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32)) + y = tr(tl) + self.assertEqual(2, self.evaluate(y)) def test_index_access_multiple_definitions(self): - def test_fn(l): + def f(l): + directives.set_element_type(l, dtypes.int32) if l: l = [] return l[1] - node, ctx = self.prepare(test_fn, {}) - def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) - def_.directives[directives.set_element_type] = { - 'dtype': parser.parse_expression('tf.int32') - } - def_, = anno.getanno(node.body[0].body[0].targets[0], - anno.Static.DEFINITIONS) - def_.directives[directives.set_element_type] = { - 'dtype': parser.parse_expression('tf.float32') - } - with self.assertRaises(ValueError): - slices.transform(node, ctx) + self.transform(f, (directives_converter, slices)) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/converters/variables_test.py b/tensorflow/python/autograph/converters/variables_test.py index 93a31e63de3..2e22cdcb77f 100644 --- a/tensorflow/python/autograph/converters/variables_test.py +++ b/tensorflow/python/autograph/converters/variables_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib - from tensorflow.python.autograph.converters import variables from tensorflow.python.autograph.core import converter_testing from tensorflow.python.platform import test @@ -27,60 +25,63 @@ from tensorflow.python.platform import test class VariablesTest(converter_testing.TestCase): - @contextlib.contextmanager - def apply_add_one_conversion(self, fn): + def transform_with_test_ld(self, f): """Generates code which adds 1 to all variable reads.""" - with self.converted(fn, variables, {}) as result: - result.ag__.__dict__['ld'] = lambda x: x + 1 - yield result + return self.transform(f, variables, ag_overrides={'ld': lambda x: x + 1}) def test_read(self): - def test_fn(l): + def f(l): return l - with self.apply_add_one_conversion(test_fn) as result: - self.assertEqual(result.test_fn(1), 2) + tr = self.transform_with_test_ld(f) + + self.assertEqual(tr(1), 2) def test_aug_assign(self): - def test_fn(l): + def f(l): l *= 10 return l - with self.apply_add_one_conversion(test_fn) as result: - self.assertEqual(result.test_fn(1), (1 + 1) * 10 + 1) # two reads + tr = self.transform_with_test_ld(f) + + self.assertEqual(tr(1), (1 + 1) * 10 + 1) # two reads def test_del(self): - def test_fn(l): + def f(l): del l return l - with self.converted(test_fn, variables, {}) as result: - with self.assertRaisesRegex( - NameError, "'l' is used before assignment"): - result.test_fn(1) + tr = self.transform(f, variables) - def test_del_getitem_ignored(self): + with self.assertRaisesRegex(NameError, "'l' is used before assignment"): + tr(1) - def basic_slice(l): + def test_del_getitem_ignored_basic_slice(self): + + def f(l): del l[0] return l - with self.converted(basic_slice, variables, {}) as result: - self.assertListEqual([2], result.basic_slice([1, 2])) + tr = self.transform(f, variables) - def range_slice(l): + self.assertListEqual([2], tr([1, 2])) + + def test_del_getitem_ignored_range_slice(self): + + def f(l): del l[0:2] return l - with self.converted(range_slice, variables, {}) as result: - self.assertListEqual([], result.range_slice([1, 2])) + tr = self.transform(f, variables) + + self.assertListEqual([], tr([1, 2])) def test_del_getattr_ignored(self): - def test_fn(l): + def f(l): del l.a return l @@ -90,50 +91,60 @@ class VariablesTest(converter_testing.TestCase): self.a = 1 self.b = 2 - with self.converted(test_fn, variables, {}) as result: - self.assertFalse(hasattr(result.test_fn(TestClass()), 'a')) - self.assertEqual(result.test_fn(TestClass()).b, 2) + tr = self.transform(f, variables) - def test_del_packing_ignored(self): - # Note: test for UnboundLocalError, not NameError because in this case we + self.assertFalse(hasattr(tr(TestClass()), 'a')) + self.assertEqual(tr(TestClass()).b, 2) + + def test_del_packing_ignored_list(self): + # Note: testing for UnboundLocalError, not NameError because in this case we # don't rewrite the del. - def list_(a, b): + def f(a, b): del [a, b] return a - with self.converted(list_, variables, {}) as result: - with self.assertRaises(UnboundLocalError): - result.list_(1, 2) + tr = self.transform(f, variables) - def nested(a, b, c): + with self.assertRaises(UnboundLocalError): + tr(1, 2) + + def test_del_packing_ignored_nested(self): + # Note: testing for UnboundLocalError, not NameError because in this case we + # don't rewrite the del. + + def f(a, b, c): del [a, (b, c)] return c - with self.converted(nested, variables, {}) as result: - with self.assertRaises(UnboundLocalError): - result.nested(1, 2, 3) + tr = self.transform(f, variables) - def test_del_item_multiple_mixed(self): + with self.assertRaises(UnboundLocalError): + tr(1, 2, 3) - def test_fn_failing(a, b, c): + def test_del_item_multiple_mixed_used_after(self): + + def f(a, b, c): del a, b, c[0] a = 1 return a, b, c - with self.converted(test_fn_failing, variables, {}) as result: - with self.assertRaisesRegex( - NameError, "'b' is used before assignment"): - result.test_fn_failing(1, 2, [1, 2]) + tr = self.transform(f, variables) - def test_fn_passing(a, b, c): + with self.assertRaisesRegex(NameError, "'b' is used before assignment"): + tr(1, 2, [1, 2]) + + def test_del_item_multiple_mixed_unused_after(self): + + def f(a, b, c): del a, b, c[0] a = 1 b = 2 return c - with self.converted(test_fn_passing, variables, {}) as result: - self.assertListEqual([2], result.test_fn_passing(1, 2, [1, 2])) + tr = self.transform(f, variables) + + self.assertListEqual([2], tr(1, 2, [1, 2])) def test_attribute(self): @@ -146,12 +157,13 @@ class VariablesTest(converter_testing.TestCase): self.v += other return self - def test_fn(l): + def f(l): return l.v tc = TestClass() - with self.apply_add_one_conversion(test_fn) as result: - self.assertEqual(result.test_fn(tc), 2) + tr = self.transform_with_test_ld(f) + + self.assertEqual(tr(tc), 2) def test_subscript(self): @@ -167,12 +179,13 @@ class VariablesTest(converter_testing.TestCase): def __getitem__(self, _): return self.v - def test_fn(l): + def f(l): return l[0] tc = TestClass() - with self.apply_add_one_conversion(test_fn) as result: - self.assertEqual(result.test_fn(tc), 2) + tr = self.transform_with_test_ld(f) + + self.assertEqual(tr(tc), 2) def test_call(self): @@ -188,12 +201,13 @@ class VariablesTest(converter_testing.TestCase): def __call__(self): return self.v - def test_fn(l): + def f(l): return l() tc = TestClass() - with self.apply_add_one_conversion(test_fn) as result: - self.assertEqual(result.test_fn(tc), 2) + tr = self.transform_with_test_ld(f) + + self.assertEqual(tr(tc), 2) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/core/converter_test.py b/tensorflow/python/autograph/core/converter_test.py index 030ec761d95..f2533762c8c 100644 --- a/tensorflow/python/autograph/core/converter_test.py +++ b/tensorflow/python/autograph/core/converter_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import imp + from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter_testing from tensorflow.python.autograph.pyct import anno @@ -38,16 +40,18 @@ class ConversionOptionsTest(converter_testing.TestCase): opts_ast = opts.to_ast() template = ''' - def test_fn(): + def f(): return opts_ast ''' opts_packed = templates.replace(template, opts_ast=opts_ast) reparsed, _, _ = loader.load_ast(opts_packed) - reparsed.__dict__['ag__'] = self.make_fake_mod( - 'fake_ag', converter.ConversionOptions, converter.Feature) + fake_ag = imp.new_module('fake_ag') + fake_ag.ConversionOptions = converter.ConversionOptions + fake_ag.Feature = converter.Feature + reparsed.ag__ = fake_ag - reparsed_opts = reparsed.test_fn() + reparsed_opts = reparsed.f() self.assertEqual(opts.recursive, reparsed_opts.recursive) self.assertEqual(opts.user_requested, False) @@ -63,12 +67,12 @@ class ConverterBaseTest(converter_testing.TestCase): directive_key = object - def test_fn(): + def f(): a = 1 return a - ns = {} - node, ctx = self.prepare(test_fn, ns) + _, node, ctx = self.transform(f, (), include_ast=True) + symbol_a = node.body[1].value defs, = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS) defs.directives[directive_key] = { @@ -84,12 +88,12 @@ class ConverterBaseTest(converter_testing.TestCase): directive_key = object - def test_fn(): + def f(): a = 1 return a - ns = {} - node, ctx = self.prepare(test_fn, ns) + _, node, ctx = self.transform(f, (), include_ast=True) + symbol_a = node.body[1].value c = TestConverter(ctx) value = c.get_definition_directive(symbol_a, directive_key, 'test_arg', @@ -100,14 +104,14 @@ class ConverterBaseTest(converter_testing.TestCase): directive_key = object - def test_fn(): + def f(): a = 1 if a: a = 2 return a - ns = {} - node, ctx = self.prepare(test_fn, ns) + _, node, ctx = self.transform(f, (), include_ast=True) + symbol_a = node.body[2].value defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS) defs[0].directives[directive_key] = { @@ -127,14 +131,14 @@ class ConverterBaseTest(converter_testing.TestCase): directive_key = object - def test_fn(): + def f(): a = 1 if a: a = 2 return a - ns = {} - node, ctx = self.prepare(test_fn, ns) + _, node, ctx = self.transform(f, (), include_ast=True) + symbol_a = node.body[2].value defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS) defs[0].directives[directive_key] = { diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index fbb031876ad..22e06000906 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -25,27 +25,15 @@ import sys import six -from tensorflow.python.autograph import operators -from tensorflow.python.autograph import utils from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import converter -from tensorflow.python.autograph.core import function_wrappers -from tensorflow.python.autograph.lang import special_functions -from tensorflow.python.autograph.pyct import anno -from tensorflow.python.autograph.pyct import cfg -from tensorflow.python.autograph.pyct import loader -from tensorflow.python.autograph.pyct import naming -from tensorflow.python.autograph.pyct import origin_info -from tensorflow.python.autograph.pyct import parser -from tensorflow.python.autograph.pyct import pretty_printer -from tensorflow.python.autograph.pyct import qual_names -from tensorflow.python.autograph.pyct import transformer -from tensorflow.python.autograph.pyct.static_analysis import activity -from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions +from tensorflow.python.autograph.impl import api +from tensorflow.python.autograph.impl import conversion +from tensorflow.python.framework import ops from tensorflow.python.platform import test -def whitelist(entity): +def whitelist(f): """Helper that marks a callable as whtelitisted.""" if 'whitelisted_module_for_testing' not in sys.modules: whitelisted_mod = imp.new_module('whitelisted_module_for_testing') @@ -54,7 +42,7 @@ def whitelist(entity): (config.DoNotConvert('whitelisted_module_for_testing'),) + config.CONVERSION_RULES) - entity.__module__ = 'whitelisted_module_for_testing' + f.__module__ = 'whitelisted_module_for_testing' def is_inside_generated_code(): @@ -76,9 +64,39 @@ def is_inside_generated_code(): del frame +class TestingTranspiler(conversion.AutoGraphTranspiler): + """Testing version that only applies given transformations.""" + + def __init__(self, converters): + super(TestingTranspiler, self).__init__() + if isinstance(converters, (list, tuple)): + self._converters = converters + else: + self._converters = (converters,) + self.transformed_ast = None + + def transform_ast(self, node, ctx): + node = self.initial_analysis(node, ctx) + + for c in self._converters: + node = c.transform(node, ctx) + + self.transformed_ast = node + self.transform_ctx = ctx + return node + + class TestCase(test.TestCase): """Base class for unit tests in this module. Contains relevant utilities.""" + def setUp(self): + # AutoGraph tests must run in graph mode to properly test control flow. + self.graph = ops.Graph().as_default() + self.graph.__enter__() + + def tearDown(self): + self.graph.__exit__(None, None, None) + @contextlib.contextmanager def assertPrints(self, expected_result): try: @@ -89,108 +107,26 @@ class TestCase(test.TestCase): finally: sys.stdout = sys.__stdout__ - @contextlib.contextmanager - def compiled(self, node, namespace, symbols=()): - source = None - - self.dynamic_calls = [] - # See api.converted_call - def converted_call( - f, args, kwargs, unused_opts=None, unused_function_ctx=None): - """Mock version of api.converted_call.""" - self.dynamic_calls.append((args, kwargs)) - if kwargs is None: - kwargs = {} - return f(*args, **kwargs) - - def fake_autograph_artifact(f): - setattr(f, 'fake_autograph_artifact', True) - return f - - try: - result, source, source_map = loader.load_ast( - node, include_source_map=True) - # TODO(mdan): Move the unparsing from converter into pyct and reuse here. - - # TODO(mdan): Move this into self.prepare() - result.tf = self.make_fake_mod('fake_tf', *symbols) - fake_ag = self.make_fake_mod('fake_ag', converted_call, - converter.ConversionOptions) - fake_ag.__dict__.update(operators.__dict__) - fake_ag.__dict__.update(special_functions.__dict__) - fake_ag.ConversionOptions = converter.ConversionOptions - fake_ag.Feature = converter.Feature - fake_ag.utils = utils - fake_ag.FunctionScope = function_wrappers.FunctionScope - fake_ag.autograph_artifact = fake_autograph_artifact - result.ag__ = fake_ag - result.ag_source_map__ = source_map - for k, v in namespace.items(): - result.__dict__[k] = v - yield result - except Exception: # pylint:disable=broad-except - if source is None: - print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False)) - else: - print('Offending source code:\n%s' % source) - raise - - @contextlib.contextmanager - def converted(self, entity, converter_module, namespace, tf_symbols=()): - - node, ctx = self.prepare(entity, namespace) - - if not isinstance(converter_module, (list, tuple)): - converter_module = (converter_module,) - for m in converter_module: - node = m.transform(node, ctx) - - with self.compiled(node, namespace, tf_symbols) as result: - yield result - - def make_fake_mod(self, name, *symbols): - fake_mod = imp.new_module(name) - for s in symbols: - if hasattr(s, '__name__'): - setattr(fake_mod, s.__name__, s) - elif hasattr(s, 'name'): - # This is a bit of a hack, but works for things like tf.int32 - setattr(fake_mod, s.name, s) - else: - raise ValueError('can not attach %s - what should be its name?' % s) - return fake_mod - - def attach_namespace(self, module, **ns): - for k, v in ns.items(): - setattr(module, k, v) - - def prepare(self, test_fn, namespace, recursive=True): - namespace['ConversionOptions'] = converter.ConversionOptions - - future_features = ('print_function', 'division') - node, source = parser.parse_entity(test_fn, future_features=future_features) - namer = naming.Namer(namespace) + def transform( + self, f, converter_module, include_ast=False, ag_overrides=None): program_ctx = converter.ProgramContext( - options=converter.ConversionOptions(recursive=recursive), - autograph_module=None) - entity_info = transformer.EntityInfo( - name=test_fn.__name__, - source_code=source, - source_file='', - future_features=future_features, - namespace=namespace) - ctx = transformer.Context(entity_info, namer, program_ctx) - origin_info.resolve_entity(node, source, test_fn) + options=converter.ConversionOptions(recursive=True), + autograph_module=api) - graphs = cfg.build(node) - node = qual_names.resolve(node) - node = activity.resolve(node, ctx, None) - node = reaching_definitions.resolve(node, ctx, graphs) - anno.dup( - node, - { - anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS, - }, - ) + conversion.create_custom_vars(program_ctx) + custom_vars = dict(conversion.custom_vars) - return node, ctx + if ag_overrides: + modified_ag = imp.new_module('fake_autograph') + modified_ag.__dict__.update(custom_vars['ag__'].__dict__) + modified_ag.__dict__.update(ag_overrides) + custom_vars['ag__'] = modified_ag + + tr = TestingTranspiler(converter_module) + transformed, _, _ = tr.transform_function( + f, program_ctx.options, program_ctx, custom_vars) + + if include_ast: + return transformed, tr.transformed_ast, tr.transform_ctx + + return transformed diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index eeea0aef896..4d5ddeebcc1 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -60,11 +60,7 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler): def get_transformed_name(self, node): return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node) - def transform_ast(self, node, ctx): - # TODO(mdan): Insert list_comprehensions somewhere. - unsupported_features_checker.verify(node) - - # Run initial analysis. + def initial_analysis(self, node, ctx): graphs = cfg.build(node) node = qual_names.resolve(node) node = activity.resolve(node, ctx, None) @@ -75,6 +71,11 @@ class AutoGraphTranspiler(transpiler.FunctionTranspiler): anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS, }, ) + return node + + def transform_ast(self, node, ctx): + unsupported_features_checker.verify(node) + node = self.initial_analysis(node, ctx) node = functions.transform(node, ctx) node = directives.transform(node, ctx) @@ -114,7 +115,7 @@ def convert(entity, program_ctx): 'expose a __code__ object. If this is a @tf.function,' ' try passing f.python_function instead.') - _create_custom_vars(program_ctx) + create_custom_vars(program_ctx) transformed, module, source_map = _TRANSPILER.transform_function( entity, program_ctx.options, program_ctx, custom_vars) @@ -248,7 +249,8 @@ def cache_whitelisted(entity, options): # TODO(mdan): Move into core or replace with an actual importable module. -def _create_custom_vars(program_ctx): +# Visible for testing. +def create_custom_vars(program_ctx): """Adds namespace references to the module that exposes the api itself.""" global custom_vars if custom_vars is None: