Autograph: Fix chained function conversion

Chained functions were not correctly converted. For example,
`foo().bar().baz()` only converted baz.  Now fixed.

PiperOrigin-RevId: 259608163
This commit is contained in:
A. Unique TensorFlower 2019-07-23 14:11:01 -07:00 committed by TensorFlower Gardener
parent c33f1d1a61
commit 8dc62ccf82
11 changed files with 79 additions and 71 deletions

View File

@ -38,7 +38,7 @@ class AssertsTest(converter_testing.TestCase):
return tf.no_op() # pylint:disable=undefined-variable return tf.no_op() # pylint:disable=undefined-variable
with self.converted(test_fn, (asserts, side_effect_guards), {}, with self.converted(test_fn, (asserts, side_effect_guards), {},
gen_control_flow_ops.no_op) as result: (gen_control_flow_ops.no_op,)) as result:
with self.cached_session() as sess: with self.cached_session() as sess:
op = result.test_fn(constant_op.constant(False)) op = result.test_fn(constant_op.constant(False))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,

View File

@ -28,7 +28,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs): def assertTransformedEquivalent(self, test_fn, *inputs):
with self.converted(test_fn, break_statements, {}, with self.converted(test_fn, break_statements, {},
constant_op.constant) as result: (constant_op.constant,)) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_while_loop(self): def test_while_loop(self):
@ -58,7 +58,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
return v return v
with self.converted(test_fn, break_statements, {}, with self.converted(test_fn, break_statements, {},
constant_op.constant) as result: (constant_op.constant,)) as result:
# The break is incompletely canonicalized. The loop will not interrupt, # The break is incompletely canonicalized. The loop will not interrupt,
# but the section following the break will be skipped. # but the section following the break will be skipped.
self.assertEqual([3], result.test_fn([5, 4])) self.assertEqual([3], result.test_fn([5, 4]))

View File

@ -71,24 +71,26 @@ class CallTreeTransformer(converter.Base):
return node return node
def visit_Call(self, node): def visit_Call(self, node):
full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
node = self.generic_visit(node)
# TODO(mdan): Refactor converted_call as a 'Call' operator. # TODO(mdan): Refactor converted_call as a 'Call' operator.
# Calls to the internal 'ag__' module are never converted (though their # Calls to the internal 'ag__' module are never converted (though their
# arguments might be). # arguments might be).
full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
if full_name.startswith('ag__.'): if full_name.startswith('ag__.'):
return self.generic_visit(node) return node
# Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
# the normal mechanisms to bypass these literals because they are sensitive # the normal mechanisms to bypass these literals because they are sensitive
# to the frame they are being called from. # to the frame they are being called from.
# TODO(mdan): Generalize this to a "static whitelist" config. # TODO(mdan): Generalize this to a "static whitelist" config.
if full_name in ('pdb.set_trace', 'ipdb.set_trace'): if full_name in ('pdb.set_trace', 'ipdb.set_trace'):
return self.generic_visit(node) return node
if (full_name == 'print' and if (full_name == 'print' and
not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)): not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
return self.generic_visit(node) return node
func = node.func func = node.func
@ -99,7 +101,6 @@ class CallTreeTransformer(converter.Base):
assert starred_arg is None, 'Multiple *args should be impossible.' assert starred_arg is None, 'Multiple *args should be impossible.'
starred_arg = a starred_arg = a
else: else:
a = self.visit(a)
normal_args.append(a) normal_args.append(a)
if starred_arg is None: if starred_arg is None:
args = templates.replace_as_expression('(args,)', args=normal_args) args = templates.replace_as_expression('(args,)', args=normal_args)
@ -116,7 +117,6 @@ class CallTreeTransformer(converter.Base):
assert kwargs_arg is None, 'Multiple **kwargs should be impossible.' assert kwargs_arg is None, 'Multiple **kwargs should be impossible.'
kwargs_arg = k kwargs_arg = k
else: else:
k = self.visit(k)
normal_keywords.append(k) normal_keywords.append(k)
if kwargs_arg is None: if kwargs_arg is None:
if not normal_keywords: if not normal_keywords:

View File

@ -30,52 +30,62 @@ class CallTreesTest(converter_testing.TestCase):
def test_normal_function(self): def test_normal_function(self):
def test_fn(f): def test_fn(f):
return f() + 3 return f() + 20
with self.converted(test_fn, call_trees, {}) as result: with self.converted(test_fn, call_trees, {}) as result:
self.assertEqual( self.assertEqual(result.test_fn(lambda: 1), 21)
result.test_fn(None),
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
self.assertListEqual(self.dynamic_calls, [((), None)]) self.assertListEqual(self.dynamic_calls, [((), None)])
def test_function_with_expression_in_argument(self): def test_function_with_expression_in_argument(self):
def test_fn(f, g): def test_fn(f, g):
return f(g() + 7) + 3 return f(g() + 20) + 4000
with self.converted(test_fn, call_trees, {}) as result: with self.converted(test_fn, call_trees, {}) as result:
self.assertEqual( self.assertEqual(result.test_fn(lambda x: x + 300, lambda: 1), 4321)
result.test_fn(None, None),
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
self.assertListEqual(self.dynamic_calls, [ self.assertListEqual(self.dynamic_calls, [
((), None), ((), None),
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 7,), None), ((21,), None),
]) ])
def test_function_with_call_in_argument(self): def test_function_with_call_in_argument(self):
def test_fn(f, g): def test_fn(f, g):
return f(g()) + 3 return f(g()) + 300
with self.converted(test_fn, call_trees, {}) as result: with self.converted(test_fn, call_trees, {}) as result:
self.assertEqual( self.assertEqual(result.test_fn(lambda x: x + 20, lambda: 1), 321)
result.test_fn(None, None),
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
self.assertListEqual(self.dynamic_calls, [ self.assertListEqual(self.dynamic_calls, [
((), None), ((), None),
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL,), None), ((1,), None),
])
def test_function_chaining(self):
def get_one():
return 1
def test_fn():
return get_one().__add__(20)
with self.converted(test_fn, call_trees, {'get_one': get_one},
()) as result:
self.assertEqual(result.test_fn(), 21)
self.assertListEqual(self.dynamic_calls, [
((), None),
((20,), None),
]) ])
def test_function_with_kwarg(self): def test_function_with_kwarg(self):
def test_fn(f, a, b): def test_fn(f, a, b):
return f(a, c=b) + 3 return f(a, c=b) + 300
with self.converted(test_fn, call_trees, {}) as result: with self.converted(test_fn, call_trees, {}) as result:
self.assertEqual( self.assertEqual(result.test_fn(lambda a, c: a + c, 1, 20), 321)
result.test_fn(None, 1, 2), self.assertListEqual(self.dynamic_calls, [((1,), {'c': 20})])
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
self.assertListEqual(self.dynamic_calls, [((1,), {'c': 2})])
def test_function_with_kwargs_starargs(self): def test_function_with_kwargs_starargs(self):
@ -84,25 +94,24 @@ class CallTreesTest(converter_testing.TestCase):
with self.converted(test_fn, call_trees, {}) as result: with self.converted(test_fn, call_trees, {}) as result:
self.assertEqual( self.assertEqual(
result.test_fn(None, 1, *[2, 3], **{ result.test_fn(lambda *args, **kwargs: 7, 1, *[2, 3], **{
'b': 4, 'b': 4,
'c': 5 'c': 5
}), converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 5) }), 12)
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})]) self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})])
def test_function_with_kwargs_starargs_only(self): def test_function_with_kwargs_starargs_only(self):
def f(*unused_args): # Will not be called. def f(*args):
pass return sum(args)
def test_fn(): def test_fn():
args = [1, 2, 3] args = [1, 20, 300]
return f(*args) + 11 return f(*args) + 4000
with self.converted(test_fn, call_trees, {'f': f}) as result: with self.converted(test_fn, call_trees, {'f': f}) as result:
self.assertEqual(result.test_fn(), self.assertEqual(result.test_fn(), 4321)
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 11) self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), None)])
def test_function_with_kwargs_keywords(self): def test_function_with_kwargs_keywords(self):
@ -111,8 +120,7 @@ class CallTreesTest(converter_testing.TestCase):
with self.converted(test_fn, call_trees, {}) as result: with self.converted(test_fn, call_trees, {}) as result:
self.assertEqual( self.assertEqual(
result.test_fn(None, 1, 2, **{'c': 3}), result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 5)
self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})]) self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
def test_debugger_set_trace(self): def test_debugger_set_trace(self):
@ -133,32 +141,30 @@ class CallTreesTest(converter_testing.TestCase):
class TestClass(object): class TestClass(object):
def other_method(self, _): def other_method(self, x):
raise ValueError('this should not be called') return x + 20
def test_method(self, a): def test_method(self, a):
return self.other_method(a) + 1 return self.other_method(a) + 300
tc = TestClass() tc = TestClass()
with self.converted(TestClass.test_method, call_trees, {}) as result: with self.converted(TestClass.test_method, call_trees, {}) as result:
self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1, self.assertEqual(321, result.test_method(tc, 1))
result.test_method(tc, 1))
self.assertListEqual(self.dynamic_calls, [((1,), None)]) self.assertListEqual(self.dynamic_calls, [((1,), None)])
def test_object_method(self): def test_object_method(self):
class TestClass(object): class TestClass(object):
def other_method(self, _): def other_method(self, x):
raise ValueError('this should not be called') return x + 20
def test_method(self, a): def test_method(self, a):
return self.other_method(a) + 1 return self.other_method(a) + 300
tc = TestClass() tc = TestClass()
with self.converted(tc.test_method, call_trees, {}) as result: with self.converted(tc.test_method, call_trees, {}) as result:
self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1, self.assertEqual(321, result.test_method(tc, 1))
result.test_method(tc, 1))
self.assertListEqual(self.dynamic_calls, [((1,), None)]) self.assertListEqual(self.dynamic_calls, [((1,), None)])

View File

@ -29,7 +29,7 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs): def assertTransformedEquivalent(self, test_fn, *inputs):
with self.converted(test_fn, continue_statements, {'ops': ops}, with self.converted(test_fn, continue_statements, {'ops': ops},
constant_op.constant) as result: (constant_op.constant,)) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_basic(self): def test_basic(self):

View File

@ -39,7 +39,7 @@ class ControlFlowTest(converter_testing.TestCase):
if not symbols: if not symbols:
symbols = {} symbols = {}
with self.converted(test_fn, control_flow, symbols, with self.converted(test_fn, control_flow, symbols,
constant_op.constant) as result: (constant_op.constant,)) as result:
self.assertAllEqual(self.evaluate(result.test_fn(*inputs)), expected) self.assertAllEqual(self.evaluate(result.test_fn(*inputs)), expected)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1

View File

@ -55,7 +55,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
return tf.constant(1) return tf.constant(1)
with self.converted(test_fn, function_scopes, {}, with self.converted(test_fn, function_scopes, {},
constant_op.constant) as result: (constant_op.constant,)) as result:
result_op = result.test_fn() result_op = result.test_fn()
self.assertIn('test_fn/', result_op.op.name) self.assertIn('test_fn/', result_op.op.name)
self.assertIn('First sentence.', result.test_fn.__doc__) self.assertIn('First sentence.', result.test_fn.__doc__)
@ -72,7 +72,8 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
l += 1 l += 1
return l, inner_fn(l) return l, inner_fn(l)
with self.converted(test_fn, function_scopes, {}, ops.name_scope) as result: with self.converted(test_fn, function_scopes, {},
(ops.name_scope,)) as result:
first, second = result.test_fn(constant_op.constant(1)) first, second = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', first.op.name) self.assertIn('test_fn/', first.op.name)
self.assertNotIn('inner_fn', first.op.name) self.assertNotIn('inner_fn', first.op.name)
@ -95,7 +96,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
node, ctx = self.prepare(TestClass, ns) node, ctx = self.prepare(TestClass, ns)
node = function_scopes.transform(node, ctx) node = function_scopes.transform(node, ctx)
with self.compiled(node, {}, ops.name_scope) as result: with self.compiled(node, {}, (ops.name_scope,)) as result:
first, second = result.TestClass().test_fn(constant_op.constant(1)) first, second = result.TestClass().test_fn(constant_op.constant(1))
self.assertIn('TestClass/test_fn/', first.op.name) self.assertIn('TestClass/test_fn/', first.op.name)
self.assertNotIn('inner_fn', first.op.name) self.assertNotIn('inner_fn', first.op.name)

View File

@ -87,7 +87,7 @@ class ListTest(converter_testing.TestCase):
} }
node = lists.transform(node, ctx) node = lists.transform(node, ctx)
with self.compiled(node, ns, dtypes.int32) as result: with self.compiled(node, ns, (dtypes.int32,)) as result:
with self.cached_session() as sess: with self.cached_session() as sess:
ts, tl = result.test_fn() ts, tl = result.test_fn()
r = list_ops.tensor_list_stack(tl, dtypes.int32) r = list_ops.tensor_list_stack(tl, dtypes.int32)
@ -121,7 +121,7 @@ class ListTest(converter_testing.TestCase):
} }
node = lists.transform(node, ctx) node = lists.transform(node, ctx)
with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result: with self.compiled(node, {}, (array_ops.stack, dtypes.int32)) as result:
with self.cached_session() as sess: with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3]) self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])

View File

@ -47,7 +47,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1) self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign) as result: with self.compiled(node, {}, (state_ops.assign,)) as result:
with self.cached_session() as sess: with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2) v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer) self.evaluate(v.initializer)
@ -68,7 +68,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1) self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign) as result: with self.compiled(node, {}, (state_ops.assign,)) as result:
with self.cached_session() as sess: with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2) v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer) self.evaluate(v.initializer)
@ -89,7 +89,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1) self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, control_flow_ops.Assert) as result: with self.compiled(node, {}, (control_flow_ops.Assert,)) as result:
with self.cached_session() as sess: with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'expected in throw'): 'expected in throw'):
@ -109,7 +109,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1) self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign_add) as result: with self.compiled(node, {}, (state_ops.assign_add,)) as result:
with self.cached_session() as sess: with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2) v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer) self.evaluate(v.initializer)
@ -130,7 +130,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body[0].body), 1) self.assertEqual(len(node.body[0].body), 1)
with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result: with self.compiled(node, {}, (state_ops.assign, ops.name_scope)) as result:
with self.cached_session() as sess: with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2) v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer) self.evaluate(v.initializer)
@ -152,8 +152,8 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1) self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign, with self.compiled(node, {},
state_ops.assign_add) as result: (state_ops.assign, state_ops.assign_add)) as result:
with self.cached_session() as sess: with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2) v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer) self.evaluate(v.initializer)

View File

@ -43,7 +43,7 @@ class SliceTest(converter_testing.TestCase):
} }
node = slices.transform(node, ctx) node = slices.transform(node, ctx)
with self.compiled(node, {}, dtypes.int32) as result: with self.compiled(node, {}, (dtypes.int32,)) as result:
with self.cached_session() as sess: with self.cached_session() as sess:
tl = list_ops.tensor_list_from_tensor( tl = list_ops.tensor_list_from_tensor(
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32)) [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))

View File

@ -37,8 +37,6 @@ from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import transformer from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.platform import test from tensorflow.python.platform import test
RESULT_OF_MOCK_CONVERTED_CALL = 7
class TestCase(test.TestCase): class TestCase(test.TestCase):
"""Base class for unit tests in this module. Contains relevant utilities.""" """Base class for unit tests in this module. Contains relevant utilities."""
@ -54,15 +52,17 @@ class TestCase(test.TestCase):
sys.stdout = sys.__stdout__ sys.stdout = sys.__stdout__
@contextlib.contextmanager @contextlib.contextmanager
def compiled(self, node, namespace, *symbols): def compiled(self, node, namespace, symbols=()):
source = None source = None
self.dynamic_calls = [] self.dynamic_calls = []
# See api.converted_call # See api.converted_call
def converted_call(unused_f, unused_opts, args, kwargs): def converted_call(f, unused_opts, args, kwargs):
"""Mock version of api.converted_call.""" """Mock version of api.converted_call."""
self.dynamic_calls.append((args, kwargs)) self.dynamic_calls.append((args, kwargs))
return RESULT_OF_MOCK_CONVERTED_CALL if kwargs is None:
kwargs = {}
return f(*args, **kwargs)
try: try:
result, source, source_map = compiler.ast_to_object( result, source, source_map = compiler.ast_to_object(
@ -92,7 +92,8 @@ class TestCase(test.TestCase):
raise raise
@contextlib.contextmanager @contextlib.contextmanager
def converted(self, entity, converter_module, namespace, *tf_symbols): def converted(self, entity, converter_module, namespace, tf_symbols=()):
node, ctx = self.prepare(entity, namespace) node, ctx = self.prepare(entity, namespace)
if not isinstance(converter_module, (list, tuple)): if not isinstance(converter_module, (list, tuple)):
@ -101,7 +102,7 @@ class TestCase(test.TestCase):
node = converter.standard_analysis(node, ctx, is_initial=not i) node = converter.standard_analysis(node, ctx, is_initial=not i)
node = m.transform(node, ctx) node = m.transform(node, ctx)
with self.compiled(node, namespace, *tf_symbols) as result: with self.compiled(node, namespace, tf_symbols) as result:
yield result yield result
def make_fake_mod(self, name, *symbols): def make_fake_mod(self, name, *symbols):