diff --git a/tensorflow/python/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py index 9ae448892a0..061b63f9d10 100644 --- a/tensorflow/python/autograph/converters/asserts_test.py +++ b/tensorflow/python/autograph/converters/asserts_test.py @@ -38,7 +38,7 @@ class AssertsTest(converter_testing.TestCase): return tf.no_op() # pylint:disable=undefined-variable with self.converted(test_fn, (asserts, side_effect_guards), {}, - gen_control_flow_ops.no_op) as result: + (gen_control_flow_ops.no_op,)) as result: with self.cached_session() as sess: op = result.test_fn(constant_op.constant(False)) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, diff --git a/tensorflow/python/autograph/converters/break_statements_test.py b/tensorflow/python/autograph/converters/break_statements_test.py index 816d3bb1b65..c789ced095d 100644 --- a/tensorflow/python/autograph/converters/break_statements_test.py +++ b/tensorflow/python/autograph/converters/break_statements_test.py @@ -28,7 +28,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase): def assertTransformedEquivalent(self, test_fn, *inputs): with self.converted(test_fn, break_statements, {}, - constant_op.constant) as result: + (constant_op.constant,)) as result: self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) def test_while_loop(self): @@ -58,7 +58,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase): return v with self.converted(test_fn, break_statements, {}, - constant_op.constant) as result: + (constant_op.constant,)) as result: # The break is incompletely canonicalized. The loop will not interrupt, # but the section following the break will be skipped. self.assertEqual([3], result.test_fn([5, 4])) diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py index 657d880620f..52e6af52b6f 100644 --- a/tensorflow/python/autograph/converters/call_trees.py +++ b/tensorflow/python/autograph/converters/call_trees.py @@ -71,24 +71,26 @@ class CallTreeTransformer(converter.Base): return node def visit_Call(self, node): + full_name = str(anno.getanno(node.func, anno.Basic.QN, default='')) + node = self.generic_visit(node) + # TODO(mdan): Refactor converted_call as a 'Call' operator. # Calls to the internal 'ag__' module are never converted (though their # arguments might be). - full_name = str(anno.getanno(node.func, anno.Basic.QN, default='')) if full_name.startswith('ag__.'): - return self.generic_visit(node) + return node # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use # the normal mechanisms to bypass these literals because they are sensitive # to the frame they are being called from. # TODO(mdan): Generalize this to a "static whitelist" config. if full_name in ('pdb.set_trace', 'ipdb.set_trace'): - return self.generic_visit(node) + return node if (full_name == 'print' and not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)): - return self.generic_visit(node) + return node func = node.func @@ -99,7 +101,6 @@ class CallTreeTransformer(converter.Base): assert starred_arg is None, 'Multiple *args should be impossible.' starred_arg = a else: - a = self.visit(a) normal_args.append(a) if starred_arg is None: args = templates.replace_as_expression('(args,)', args=normal_args) @@ -116,7 +117,6 @@ class CallTreeTransformer(converter.Base): assert kwargs_arg is None, 'Multiple **kwargs should be impossible.' kwargs_arg = k else: - k = self.visit(k) normal_keywords.append(k) if kwargs_arg is None: if not normal_keywords: diff --git a/tensorflow/python/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py index d61908fc8e8..b77248b8711 100644 --- a/tensorflow/python/autograph/converters/call_trees_test.py +++ b/tensorflow/python/autograph/converters/call_trees_test.py @@ -30,52 +30,62 @@ class CallTreesTest(converter_testing.TestCase): def test_normal_function(self): def test_fn(f): - return f() + 3 + return f() + 20 with self.converted(test_fn, call_trees, {}) as result: - self.assertEqual( - result.test_fn(None), - converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3) + self.assertEqual(result.test_fn(lambda: 1), 21) self.assertListEqual(self.dynamic_calls, [((), None)]) def test_function_with_expression_in_argument(self): def test_fn(f, g): - return f(g() + 7) + 3 + return f(g() + 20) + 4000 with self.converted(test_fn, call_trees, {}) as result: - self.assertEqual( - result.test_fn(None, None), - converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3) + self.assertEqual(result.test_fn(lambda x: x + 300, lambda: 1), 4321) self.assertListEqual(self.dynamic_calls, [ ((), None), - ((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 7,), None), + ((21,), None), ]) def test_function_with_call_in_argument(self): def test_fn(f, g): - return f(g()) + 3 + return f(g()) + 300 with self.converted(test_fn, call_trees, {}) as result: - self.assertEqual( - result.test_fn(None, None), - converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3) + self.assertEqual(result.test_fn(lambda x: x + 20, lambda: 1), 321) self.assertListEqual(self.dynamic_calls, [ ((), None), - ((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL,), None), + ((1,), None), + ]) + + def test_function_chaining(self): + + def get_one(): + return 1 + + def test_fn(): + return get_one().__add__(20) + + with self.converted(test_fn, call_trees, {'get_one': get_one}, + ()) as result: + + self.assertEqual(result.test_fn(), 21) + + self.assertListEqual(self.dynamic_calls, [ + ((), None), + ((20,), None), ]) def test_function_with_kwarg(self): def test_fn(f, a, b): - return f(a, c=b) + 3 + return f(a, c=b) + 300 with self.converted(test_fn, call_trees, {}) as result: - self.assertEqual( - result.test_fn(None, 1, 2), - converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3) - self.assertListEqual(self.dynamic_calls, [((1,), {'c': 2})]) + self.assertEqual(result.test_fn(lambda a, c: a + c, 1, 20), 321) + self.assertListEqual(self.dynamic_calls, [((1,), {'c': 20})]) def test_function_with_kwargs_starargs(self): @@ -84,25 +94,24 @@ class CallTreesTest(converter_testing.TestCase): with self.converted(test_fn, call_trees, {}) as result: self.assertEqual( - result.test_fn(None, 1, *[2, 3], **{ + result.test_fn(lambda *args, **kwargs: 7, 1, *[2, 3], **{ 'b': 4, 'c': 5 - }), converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 5) + }), 12) self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})]) def test_function_with_kwargs_starargs_only(self): - def f(*unused_args): # Will not be called. - pass + def f(*args): + return sum(args) def test_fn(): - args = [1, 2, 3] - return f(*args) + 11 + args = [1, 20, 300] + return f(*args) + 4000 with self.converted(test_fn, call_trees, {'f': f}) as result: - self.assertEqual(result.test_fn(), - converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 11) - self.assertListEqual(self.dynamic_calls, [((1, 2, 3), None)]) + self.assertEqual(result.test_fn(), 4321) + self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)]) def test_function_with_kwargs_keywords(self): @@ -111,8 +120,7 @@ class CallTreesTest(converter_testing.TestCase): with self.converted(test_fn, call_trees, {}) as result: self.assertEqual( - result.test_fn(None, 1, 2, **{'c': 3}), - converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 5) + result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12) self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})]) def test_debugger_set_trace(self): @@ -133,32 +141,30 @@ class CallTreesTest(converter_testing.TestCase): class TestClass(object): - def other_method(self, _): - raise ValueError('this should not be called') + def other_method(self, x): + return x + 20 def test_method(self, a): - return self.other_method(a) + 1 + return self.other_method(a) + 300 tc = TestClass() with self.converted(TestClass.test_method, call_trees, {}) as result: - self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1, - result.test_method(tc, 1)) + self.assertEqual(321, result.test_method(tc, 1)) self.assertListEqual(self.dynamic_calls, [((1,), None)]) def test_object_method(self): class TestClass(object): - def other_method(self, _): - raise ValueError('this should not be called') + def other_method(self, x): + return x + 20 def test_method(self, a): - return self.other_method(a) + 1 + return self.other_method(a) + 300 tc = TestClass() with self.converted(tc.test_method, call_trees, {}) as result: - self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1, - result.test_method(tc, 1)) + self.assertEqual(321, result.test_method(tc, 1)) self.assertListEqual(self.dynamic_calls, [((1,), None)]) diff --git a/tensorflow/python/autograph/converters/continue_statements_test.py b/tensorflow/python/autograph/converters/continue_statements_test.py index 97a975b1698..a24ddd5e527 100644 --- a/tensorflow/python/autograph/converters/continue_statements_test.py +++ b/tensorflow/python/autograph/converters/continue_statements_test.py @@ -29,7 +29,7 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): def assertTransformedEquivalent(self, test_fn, *inputs): with self.converted(test_fn, continue_statements, {'ops': ops}, - constant_op.constant) as result: + (constant_op.constant,)) as result: self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) def test_basic(self): diff --git a/tensorflow/python/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py index 4690b114a77..e1ba82043bc 100644 --- a/tensorflow/python/autograph/converters/control_flow_test.py +++ b/tensorflow/python/autograph/converters/control_flow_test.py @@ -39,7 +39,7 @@ class ControlFlowTest(converter_testing.TestCase): if not symbols: symbols = {} with self.converted(test_fn, control_flow, symbols, - constant_op.constant) as result: + (constant_op.constant,)) as result: self.assertAllEqual(self.evaluate(result.test_fn(*inputs)), expected) @test_util.run_deprecated_v1 diff --git a/tensorflow/python/autograph/converters/function_scopes_test.py b/tensorflow/python/autograph/converters/function_scopes_test.py index 0eccf39db7d..f973687e8bb 100644 --- a/tensorflow/python/autograph/converters/function_scopes_test.py +++ b/tensorflow/python/autograph/converters/function_scopes_test.py @@ -55,7 +55,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase): return tf.constant(1) with self.converted(test_fn, function_scopes, {}, - constant_op.constant) as result: + (constant_op.constant,)) as result: result_op = result.test_fn() self.assertIn('test_fn/', result_op.op.name) self.assertIn('First sentence.', result.test_fn.__doc__) @@ -72,7 +72,8 @@ class FunctionBodyTransformerTest(converter_testing.TestCase): l += 1 return l, inner_fn(l) - with self.converted(test_fn, function_scopes, {}, ops.name_scope) as result: + with self.converted(test_fn, function_scopes, {}, + (ops.name_scope,)) as result: first, second = result.test_fn(constant_op.constant(1)) self.assertIn('test_fn/', first.op.name) self.assertNotIn('inner_fn', first.op.name) @@ -95,7 +96,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase): node, ctx = self.prepare(TestClass, ns) node = function_scopes.transform(node, ctx) - with self.compiled(node, {}, ops.name_scope) as result: + with self.compiled(node, {}, (ops.name_scope,)) as result: first, second = result.TestClass().test_fn(constant_op.constant(1)) self.assertIn('TestClass/test_fn/', first.op.name) self.assertNotIn('inner_fn', first.op.name) diff --git a/tensorflow/python/autograph/converters/lists_test.py b/tensorflow/python/autograph/converters/lists_test.py index 39843c7d74f..9436b69d749 100644 --- a/tensorflow/python/autograph/converters/lists_test.py +++ b/tensorflow/python/autograph/converters/lists_test.py @@ -87,7 +87,7 @@ class ListTest(converter_testing.TestCase): } node = lists.transform(node, ctx) - with self.compiled(node, ns, dtypes.int32) as result: + with self.compiled(node, ns, (dtypes.int32,)) as result: with self.cached_session() as sess: ts, tl = result.test_fn() r = list_ops.tensor_list_stack(tl, dtypes.int32) @@ -121,7 +121,7 @@ class ListTest(converter_testing.TestCase): } node = lists.transform(node, ctx) - with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result: + with self.compiled(node, {}, (array_ops.stack, dtypes.int32)) as result: with self.cached_session() as sess: self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3]) diff --git a/tensorflow/python/autograph/converters/side_effect_guards_test.py b/tensorflow/python/autograph/converters/side_effect_guards_test.py index 645267e5600..ead05d041aa 100644 --- a/tensorflow/python/autograph/converters/side_effect_guards_test.py +++ b/tensorflow/python/autograph/converters/side_effect_guards_test.py @@ -47,7 +47,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): self.assertEqual(len(node.body), 1) - with self.compiled(node, {}, state_ops.assign) as result: + with self.compiled(node, {}, (state_ops.assign,)) as result: with self.cached_session() as sess: v = variable_scope.get_variable('test', initializer=2) self.evaluate(v.initializer) @@ -68,7 +68,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): self.assertEqual(len(node.body), 1) - with self.compiled(node, {}, state_ops.assign) as result: + with self.compiled(node, {}, (state_ops.assign,)) as result: with self.cached_session() as sess: v = variable_scope.get_variable('test', initializer=2) self.evaluate(v.initializer) @@ -89,7 +89,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): self.assertEqual(len(node.body), 1) - with self.compiled(node, {}, control_flow_ops.Assert) as result: + with self.compiled(node, {}, (control_flow_ops.Assert,)) as result: with self.cached_session() as sess: with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'expected in throw'): @@ -109,7 +109,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): self.assertEqual(len(node.body), 1) - with self.compiled(node, {}, state_ops.assign_add) as result: + with self.compiled(node, {}, (state_ops.assign_add,)) as result: with self.cached_session() as sess: v = variable_scope.get_variable('test', initializer=2) self.evaluate(v.initializer) @@ -130,7 +130,7 @@ class SideEffectGuardsTest(converter_testing.TestCase): self.assertEqual(len(node.body[0].body), 1) - with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result: + with self.compiled(node, {}, (state_ops.assign, ops.name_scope)) as result: with self.cached_session() as sess: v = variable_scope.get_variable('test', initializer=2) self.evaluate(v.initializer) @@ -152,8 +152,8 @@ class SideEffectGuardsTest(converter_testing.TestCase): self.assertEqual(len(node.body), 1) - with self.compiled(node, {}, state_ops.assign, - state_ops.assign_add) as result: + with self.compiled(node, {}, + (state_ops.assign, state_ops.assign_add)) as result: with self.cached_session() as sess: v = variable_scope.get_variable('test', initializer=2) self.evaluate(v.initializer) diff --git a/tensorflow/python/autograph/converters/slices_test.py b/tensorflow/python/autograph/converters/slices_test.py index 11e3736d4fb..2fea1c7f81f 100644 --- a/tensorflow/python/autograph/converters/slices_test.py +++ b/tensorflow/python/autograph/converters/slices_test.py @@ -43,7 +43,7 @@ class SliceTest(converter_testing.TestCase): } node = slices.transform(node, ctx) - with self.compiled(node, {}, dtypes.int32) as result: + with self.compiled(node, {}, (dtypes.int32,)) as result: with self.cached_session() as sess: tl = list_ops.tensor_list_from_tensor( [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32)) diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index bb2ed38fbbb..507739fdbc2 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -37,8 +37,6 @@ from tensorflow.python.autograph.pyct import pretty_printer from tensorflow.python.autograph.pyct import transformer from tensorflow.python.platform import test -RESULT_OF_MOCK_CONVERTED_CALL = 7 - class TestCase(test.TestCase): """Base class for unit tests in this module. Contains relevant utilities.""" @@ -54,15 +52,17 @@ class TestCase(test.TestCase): sys.stdout = sys.__stdout__ @contextlib.contextmanager - def compiled(self, node, namespace, *symbols): + def compiled(self, node, namespace, symbols=()): source = None self.dynamic_calls = [] # See api.converted_call - def converted_call(unused_f, unused_opts, args, kwargs): + def converted_call(f, unused_opts, args, kwargs): """Mock version of api.converted_call.""" self.dynamic_calls.append((args, kwargs)) - return RESULT_OF_MOCK_CONVERTED_CALL + if kwargs is None: + kwargs = {} + return f(*args, **kwargs) try: result, source, source_map = compiler.ast_to_object( @@ -92,7 +92,8 @@ class TestCase(test.TestCase): raise @contextlib.contextmanager - def converted(self, entity, converter_module, namespace, *tf_symbols): + def converted(self, entity, converter_module, namespace, tf_symbols=()): + node, ctx = self.prepare(entity, namespace) if not isinstance(converter_module, (list, tuple)): @@ -101,7 +102,7 @@ class TestCase(test.TestCase): node = converter.standard_analysis(node, ctx, is_initial=not i) node = m.transform(node, ctx) - with self.compiled(node, namespace, *tf_symbols) as result: + with self.compiled(node, namespace, tf_symbols) as result: yield result def make_fake_mod(self, name, *symbols):