From 8dc62ccf8218dd2d6d8e1757ef63e7c360d35b4d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 23 Jul 2019 14:11:01 -0700
Subject: [PATCH] Autograph: Fix chained function conversion

Chained functions were not correctly converted. For example,
`foo().bar().baz()` only converted baz.  Now fixed.

PiperOrigin-RevId: 259608163
---
 .../autograph/converters/asserts_test.py      |  2 +-
 .../converters/break_statements_test.py       |  4 +-
 .../python/autograph/converters/call_trees.py | 12 +--
 .../autograph/converters/call_trees_test.py   | 86 ++++++++++---------
 .../converters/continue_statements_test.py    |  2 +-
 .../autograph/converters/control_flow_test.py |  2 +-
 .../converters/function_scopes_test.py        |  7 +-
 .../python/autograph/converters/lists_test.py |  4 +-
 .../converters/side_effect_guards_test.py     | 14 +--
 .../autograph/converters/slices_test.py       |  2 +-
 .../autograph/core/converter_testing.py       | 15 ++--
 11 files changed, 79 insertions(+), 71 deletions(-)

diff --git a/tensorflow/python/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py
index 9ae448892a0..061b63f9d10 100644
--- a/tensorflow/python/autograph/converters/asserts_test.py
+++ b/tensorflow/python/autograph/converters/asserts_test.py
@@ -38,7 +38,7 @@ class AssertsTest(converter_testing.TestCase):
       return tf.no_op()  # pylint:disable=undefined-variable
 
     with self.converted(test_fn, (asserts, side_effect_guards), {},
-                        gen_control_flow_ops.no_op) as result:
+                        (gen_control_flow_ops.no_op,)) as result:
       with self.cached_session() as sess:
         op = result.test_fn(constant_op.constant(False))
         with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
diff --git a/tensorflow/python/autograph/converters/break_statements_test.py b/tensorflow/python/autograph/converters/break_statements_test.py
index 816d3bb1b65..c789ced095d 100644
--- a/tensorflow/python/autograph/converters/break_statements_test.py
+++ b/tensorflow/python/autograph/converters/break_statements_test.py
@@ -28,7 +28,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
 
   def assertTransformedEquivalent(self, test_fn, *inputs):
     with self.converted(test_fn, break_statements, {},
-                        constant_op.constant) as result:
+                        (constant_op.constant,)) as result:
       self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
 
   def test_while_loop(self):
@@ -58,7 +58,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
       return v
 
     with self.converted(test_fn, break_statements, {},
-                        constant_op.constant) as result:
+                        (constant_op.constant,)) as result:
       # The break is incompletely canonicalized. The loop will not interrupt,
       # but the section following the break will be skipped.
       self.assertEqual([3], result.test_fn([5, 4]))
diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py
index 657d880620f..52e6af52b6f 100644
--- a/tensorflow/python/autograph/converters/call_trees.py
+++ b/tensorflow/python/autograph/converters/call_trees.py
@@ -71,24 +71,26 @@ class CallTreeTransformer(converter.Base):
     return node
 
   def visit_Call(self, node):
+    full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
+    node = self.generic_visit(node)
+
     # TODO(mdan): Refactor converted_call as a 'Call' operator.
 
     # Calls to the internal 'ag__' module are never converted (though their
     # arguments might be).
-    full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
     if full_name.startswith('ag__.'):
-      return self.generic_visit(node)
+      return node
 
     # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
     # the normal mechanisms to bypass these literals because they are sensitive
     # to the frame they are being called from.
     # TODO(mdan): Generalize this to a "static whitelist" config.
     if full_name in ('pdb.set_trace', 'ipdb.set_trace'):
-      return self.generic_visit(node)
+      return node
 
     if (full_name == 'print' and
         not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
-      return self.generic_visit(node)
+      return node
 
     func = node.func
 
@@ -99,7 +101,6 @@ class CallTreeTransformer(converter.Base):
         assert starred_arg is None, 'Multiple *args should be impossible.'
         starred_arg = a
       else:
-        a = self.visit(a)
         normal_args.append(a)
     if starred_arg is None:
       args = templates.replace_as_expression('(args,)', args=normal_args)
@@ -116,7 +117,6 @@ class CallTreeTransformer(converter.Base):
         assert kwargs_arg is None, 'Multiple **kwargs should be impossible.'
         kwargs_arg = k
       else:
-        k = self.visit(k)
         normal_keywords.append(k)
     if kwargs_arg is None:
       if not normal_keywords:
diff --git a/tensorflow/python/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py
index d61908fc8e8..b77248b8711 100644
--- a/tensorflow/python/autograph/converters/call_trees_test.py
+++ b/tensorflow/python/autograph/converters/call_trees_test.py
@@ -30,52 +30,62 @@ class CallTreesTest(converter_testing.TestCase):
   def test_normal_function(self):
 
     def test_fn(f):
-      return f() + 3
+      return f() + 20
 
     with self.converted(test_fn, call_trees, {}) as result:
-      self.assertEqual(
-          result.test_fn(None),
-          converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
+      self.assertEqual(result.test_fn(lambda: 1), 21)
       self.assertListEqual(self.dynamic_calls, [((), None)])
 
   def test_function_with_expression_in_argument(self):
 
     def test_fn(f, g):
-      return f(g() + 7) + 3
+      return f(g() + 20) + 4000
 
     with self.converted(test_fn, call_trees, {}) as result:
-      self.assertEqual(
-          result.test_fn(None, None),
-          converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
+      self.assertEqual(result.test_fn(lambda x: x + 300, lambda: 1), 4321)
       self.assertListEqual(self.dynamic_calls, [
           ((), None),
-          ((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 7,), None),
+          ((21,), None),
       ])
 
   def test_function_with_call_in_argument(self):
 
     def test_fn(f, g):
-      return f(g()) + 3
+      return f(g()) + 300
 
     with self.converted(test_fn, call_trees, {}) as result:
-      self.assertEqual(
-          result.test_fn(None, None),
-          converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
+      self.assertEqual(result.test_fn(lambda x: x + 20, lambda: 1), 321)
       self.assertListEqual(self.dynamic_calls, [
           ((), None),
-          ((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL,), None),
+          ((1,), None),
+      ])
+
+  def test_function_chaining(self):
+
+    def get_one():
+      return 1
+
+    def test_fn():
+      return get_one().__add__(20)
+
+    with self.converted(test_fn, call_trees, {'get_one': get_one},
+                        ()) as result:
+
+      self.assertEqual(result.test_fn(), 21)
+
+      self.assertListEqual(self.dynamic_calls, [
+          ((), None),
+          ((20,), None),
       ])
 
   def test_function_with_kwarg(self):
 
     def test_fn(f, a, b):
-      return f(a, c=b) + 3
+      return f(a, c=b) + 300
 
     with self.converted(test_fn, call_trees, {}) as result:
-      self.assertEqual(
-          result.test_fn(None, 1, 2),
-          converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
-      self.assertListEqual(self.dynamic_calls, [((1,), {'c': 2})])
+      self.assertEqual(result.test_fn(lambda a, c: a + c, 1, 20), 321)
+      self.assertListEqual(self.dynamic_calls, [((1,), {'c': 20})])
 
   def test_function_with_kwargs_starargs(self):
 
@@ -84,25 +94,24 @@ class CallTreesTest(converter_testing.TestCase):
 
     with self.converted(test_fn, call_trees, {}) as result:
       self.assertEqual(
-          result.test_fn(None, 1, *[2, 3], **{
+          result.test_fn(lambda *args, **kwargs: 7, 1, *[2, 3], **{
               'b': 4,
               'c': 5
-          }), converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 5)
+          }), 12)
       self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})])
 
   def test_function_with_kwargs_starargs_only(self):
 
-    def f(*unused_args):  # Will not be called.
-      pass
+    def f(*args):
+      return sum(args)
 
     def test_fn():
-      args = [1, 2, 3]
-      return f(*args) + 11
+      args = [1, 20, 300]
+      return f(*args) + 4000
 
     with self.converted(test_fn, call_trees, {'f': f}) as result:
-      self.assertEqual(result.test_fn(),
-                       converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 11)
-      self.assertListEqual(self.dynamic_calls, [((1, 2, 3), None)])
+      self.assertEqual(result.test_fn(), 4321)
+      self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
 
   def test_function_with_kwargs_keywords(self):
 
@@ -111,8 +120,7 @@ class CallTreesTest(converter_testing.TestCase):
 
     with self.converted(test_fn, call_trees, {}) as result:
       self.assertEqual(
-          result.test_fn(None, 1, 2, **{'c': 3}),
-          converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 5)
+          result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
       self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
 
   def test_debugger_set_trace(self):
@@ -133,32 +141,30 @@ class CallTreesTest(converter_testing.TestCase):
 
     class TestClass(object):
 
-      def other_method(self, _):
-        raise ValueError('this should not be called')
+      def other_method(self, x):
+        return x + 20
 
       def test_method(self, a):
-        return self.other_method(a) + 1
+        return self.other_method(a) + 300
 
     tc = TestClass()
     with self.converted(TestClass.test_method, call_trees, {}) as result:
-      self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1,
-                       result.test_method(tc, 1))
+      self.assertEqual(321, result.test_method(tc, 1))
       self.assertListEqual(self.dynamic_calls, [((1,), None)])
 
   def test_object_method(self):
 
     class TestClass(object):
 
-      def other_method(self, _):
-        raise ValueError('this should not be called')
+      def other_method(self, x):
+        return x + 20
 
       def test_method(self, a):
-        return self.other_method(a) + 1
+        return self.other_method(a) + 300
 
     tc = TestClass()
     with self.converted(tc.test_method, call_trees, {}) as result:
-      self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1,
-                       result.test_method(tc, 1))
+      self.assertEqual(321, result.test_method(tc, 1))
       self.assertListEqual(self.dynamic_calls, [((1,), None)])
 
 
diff --git a/tensorflow/python/autograph/converters/continue_statements_test.py b/tensorflow/python/autograph/converters/continue_statements_test.py
index 97a975b1698..a24ddd5e527 100644
--- a/tensorflow/python/autograph/converters/continue_statements_test.py
+++ b/tensorflow/python/autograph/converters/continue_statements_test.py
@@ -29,7 +29,7 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
 
   def assertTransformedEquivalent(self, test_fn, *inputs):
     with self.converted(test_fn, continue_statements, {'ops': ops},
-                        constant_op.constant) as result:
+                        (constant_op.constant,)) as result:
       self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
 
   def test_basic(self):
diff --git a/tensorflow/python/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py
index 4690b114a77..e1ba82043bc 100644
--- a/tensorflow/python/autograph/converters/control_flow_test.py
+++ b/tensorflow/python/autograph/converters/control_flow_test.py
@@ -39,7 +39,7 @@ class ControlFlowTest(converter_testing.TestCase):
     if not symbols:
       symbols = {}
     with self.converted(test_fn, control_flow, symbols,
-                        constant_op.constant) as result:
+                        (constant_op.constant,)) as result:
       self.assertAllEqual(self.evaluate(result.test_fn(*inputs)), expected)
 
   @test_util.run_deprecated_v1
diff --git a/tensorflow/python/autograph/converters/function_scopes_test.py b/tensorflow/python/autograph/converters/function_scopes_test.py
index 0eccf39db7d..f973687e8bb 100644
--- a/tensorflow/python/autograph/converters/function_scopes_test.py
+++ b/tensorflow/python/autograph/converters/function_scopes_test.py
@@ -55,7 +55,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
       return tf.constant(1)
 
     with self.converted(test_fn, function_scopes, {},
-                        constant_op.constant) as result:
+                        (constant_op.constant,)) as result:
       result_op = result.test_fn()
       self.assertIn('test_fn/', result_op.op.name)
       self.assertIn('First sentence.', result.test_fn.__doc__)
@@ -72,7 +72,8 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
       l += 1
       return l, inner_fn(l)
 
-    with self.converted(test_fn, function_scopes, {}, ops.name_scope) as result:
+    with self.converted(test_fn, function_scopes, {},
+                        (ops.name_scope,)) as result:
       first, second = result.test_fn(constant_op.constant(1))
       self.assertIn('test_fn/', first.op.name)
       self.assertNotIn('inner_fn', first.op.name)
@@ -95,7 +96,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
     node, ctx = self.prepare(TestClass, ns)
     node = function_scopes.transform(node, ctx)
 
-    with self.compiled(node, {}, ops.name_scope) as result:
+    with self.compiled(node, {}, (ops.name_scope,)) as result:
       first, second = result.TestClass().test_fn(constant_op.constant(1))
       self.assertIn('TestClass/test_fn/', first.op.name)
       self.assertNotIn('inner_fn', first.op.name)
diff --git a/tensorflow/python/autograph/converters/lists_test.py b/tensorflow/python/autograph/converters/lists_test.py
index 39843c7d74f..9436b69d749 100644
--- a/tensorflow/python/autograph/converters/lists_test.py
+++ b/tensorflow/python/autograph/converters/lists_test.py
@@ -87,7 +87,7 @@ class ListTest(converter_testing.TestCase):
     }
     node = lists.transform(node, ctx)
 
-    with self.compiled(node, ns, dtypes.int32) as result:
+    with self.compiled(node, ns, (dtypes.int32,)) as result:
       with self.cached_session() as sess:
         ts, tl = result.test_fn()
         r = list_ops.tensor_list_stack(tl, dtypes.int32)
@@ -121,7 +121,7 @@ class ListTest(converter_testing.TestCase):
     }
     node = lists.transform(node, ctx)
 
-    with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
+    with self.compiled(node, {}, (array_ops.stack, dtypes.int32)) as result:
       with self.cached_session() as sess:
         self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])
 
diff --git a/tensorflow/python/autograph/converters/side_effect_guards_test.py b/tensorflow/python/autograph/converters/side_effect_guards_test.py
index 645267e5600..ead05d041aa 100644
--- a/tensorflow/python/autograph/converters/side_effect_guards_test.py
+++ b/tensorflow/python/autograph/converters/side_effect_guards_test.py
@@ -47,7 +47,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
 
     self.assertEqual(len(node.body), 1)
 
-    with self.compiled(node, {}, state_ops.assign) as result:
+    with self.compiled(node, {}, (state_ops.assign,)) as result:
       with self.cached_session() as sess:
         v = variable_scope.get_variable('test', initializer=2)
         self.evaluate(v.initializer)
@@ -68,7 +68,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
 
     self.assertEqual(len(node.body), 1)
 
-    with self.compiled(node, {}, state_ops.assign) as result:
+    with self.compiled(node, {}, (state_ops.assign,)) as result:
       with self.cached_session() as sess:
         v = variable_scope.get_variable('test', initializer=2)
         self.evaluate(v.initializer)
@@ -89,7 +89,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
 
     self.assertEqual(len(node.body), 1)
 
-    with self.compiled(node, {}, control_flow_ops.Assert) as result:
+    with self.compiled(node, {}, (control_flow_ops.Assert,)) as result:
       with self.cached_session() as sess:
         with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                      'expected in throw'):
@@ -109,7 +109,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
 
     self.assertEqual(len(node.body), 1)
 
-    with self.compiled(node, {}, state_ops.assign_add) as result:
+    with self.compiled(node, {}, (state_ops.assign_add,)) as result:
       with self.cached_session() as sess:
         v = variable_scope.get_variable('test', initializer=2)
         self.evaluate(v.initializer)
@@ -130,7 +130,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
 
     self.assertEqual(len(node.body[0].body), 1)
 
-    with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
+    with self.compiled(node, {}, (state_ops.assign, ops.name_scope)) as result:
       with self.cached_session() as sess:
         v = variable_scope.get_variable('test', initializer=2)
         self.evaluate(v.initializer)
@@ -152,8 +152,8 @@ class SideEffectGuardsTest(converter_testing.TestCase):
 
     self.assertEqual(len(node.body), 1)
 
-    with self.compiled(node, {}, state_ops.assign,
-                       state_ops.assign_add) as result:
+    with self.compiled(node, {},
+                       (state_ops.assign, state_ops.assign_add)) as result:
       with self.cached_session() as sess:
         v = variable_scope.get_variable('test', initializer=2)
         self.evaluate(v.initializer)
diff --git a/tensorflow/python/autograph/converters/slices_test.py b/tensorflow/python/autograph/converters/slices_test.py
index 11e3736d4fb..2fea1c7f81f 100644
--- a/tensorflow/python/autograph/converters/slices_test.py
+++ b/tensorflow/python/autograph/converters/slices_test.py
@@ -43,7 +43,7 @@ class SliceTest(converter_testing.TestCase):
     }
     node = slices.transform(node, ctx)
 
-    with self.compiled(node, {}, dtypes.int32) as result:
+    with self.compiled(node, {}, (dtypes.int32,)) as result:
       with self.cached_session() as sess:
         tl = list_ops.tensor_list_from_tensor(
             [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
index bb2ed38fbbb..507739fdbc2 100644
--- a/tensorflow/python/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -37,8 +37,6 @@ from tensorflow.python.autograph.pyct import pretty_printer
 from tensorflow.python.autograph.pyct import transformer
 from tensorflow.python.platform import test
 
-RESULT_OF_MOCK_CONVERTED_CALL = 7
-
 
 class TestCase(test.TestCase):
   """Base class for unit tests in this module. Contains relevant utilities."""
@@ -54,15 +52,17 @@ class TestCase(test.TestCase):
       sys.stdout = sys.__stdout__
 
   @contextlib.contextmanager
-  def compiled(self, node, namespace, *symbols):
+  def compiled(self, node, namespace, symbols=()):
     source = None
 
     self.dynamic_calls = []
     # See api.converted_call
-    def converted_call(unused_f, unused_opts, args, kwargs):
+    def converted_call(f, unused_opts, args, kwargs):
       """Mock version of api.converted_call."""
       self.dynamic_calls.append((args, kwargs))
-      return RESULT_OF_MOCK_CONVERTED_CALL
+      if kwargs is None:
+        kwargs = {}
+      return f(*args, **kwargs)
 
     try:
       result, source, source_map = compiler.ast_to_object(
@@ -92,7 +92,8 @@ class TestCase(test.TestCase):
       raise
 
   @contextlib.contextmanager
-  def converted(self, entity, converter_module, namespace, *tf_symbols):
+  def converted(self, entity, converter_module, namespace, tf_symbols=()):
+
     node, ctx = self.prepare(entity, namespace)
 
     if not isinstance(converter_module, (list, tuple)):
@@ -101,7 +102,7 @@ class TestCase(test.TestCase):
       node = converter.standard_analysis(node, ctx, is_initial=not i)
       node = m.transform(node, ctx)
 
-    with self.compiled(node, namespace, *tf_symbols) as result:
+    with self.compiled(node, namespace, tf_symbols) as result:
       yield result
 
   def make_fake_mod(self, name, *symbols):