Autograph: Fix chained function conversion

Chained functions were not correctly converted. For example,
`foo().bar().baz()` only converted baz.  Now fixed.

PiperOrigin-RevId: 259608163
This commit is contained in:
A. Unique TensorFlower 2019-07-23 14:11:01 -07:00 committed by TensorFlower Gardener
parent c33f1d1a61
commit 8dc62ccf82
11 changed files with 79 additions and 71 deletions

View File

@ -38,7 +38,7 @@ class AssertsTest(converter_testing.TestCase):
return tf.no_op() # pylint:disable=undefined-variable
with self.converted(test_fn, (asserts, side_effect_guards), {},
gen_control_flow_ops.no_op) as result:
(gen_control_flow_ops.no_op,)) as result:
with self.cached_session() as sess:
op = result.test_fn(constant_op.constant(False))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,

View File

@ -28,7 +28,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
with self.converted(test_fn, break_statements, {},
constant_op.constant) as result:
(constant_op.constant,)) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_while_loop(self):
@ -58,7 +58,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
return v
with self.converted(test_fn, break_statements, {},
constant_op.constant) as result:
(constant_op.constant,)) as result:
# The break is incompletely canonicalized. The loop will not interrupt,
# but the section following the break will be skipped.
self.assertEqual([3], result.test_fn([5, 4]))

View File

@ -71,24 +71,26 @@ class CallTreeTransformer(converter.Base):
return node
def visit_Call(self, node):
full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
node = self.generic_visit(node)
# TODO(mdan): Refactor converted_call as a 'Call' operator.
# Calls to the internal 'ag__' module are never converted (though their
# arguments might be).
full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
if full_name.startswith('ag__.'):
return self.generic_visit(node)
return node
# Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
# the normal mechanisms to bypass these literals because they are sensitive
# to the frame they are being called from.
# TODO(mdan): Generalize this to a "static whitelist" config.
if full_name in ('pdb.set_trace', 'ipdb.set_trace'):
return self.generic_visit(node)
return node
if (full_name == 'print' and
not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
return self.generic_visit(node)
return node
func = node.func
@ -99,7 +101,6 @@ class CallTreeTransformer(converter.Base):
assert starred_arg is None, 'Multiple *args should be impossible.'
starred_arg = a
else:
a = self.visit(a)
normal_args.append(a)
if starred_arg is None:
args = templates.replace_as_expression('(args,)', args=normal_args)
@ -116,7 +117,6 @@ class CallTreeTransformer(converter.Base):
assert kwargs_arg is None, 'Multiple **kwargs should be impossible.'
kwargs_arg = k
else:
k = self.visit(k)
normal_keywords.append(k)
if kwargs_arg is None:
if not normal_keywords:

View File

@ -30,52 +30,62 @@ class CallTreesTest(converter_testing.TestCase):
def test_normal_function(self):
def test_fn(f):
return f() + 3
return f() + 20
with self.converted(test_fn, call_trees, {}) as result:
self.assertEqual(
result.test_fn(None),
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
self.assertEqual(result.test_fn(lambda: 1), 21)
self.assertListEqual(self.dynamic_calls, [((), None)])
def test_function_with_expression_in_argument(self):
def test_fn(f, g):
return f(g() + 7) + 3
return f(g() + 20) + 4000
with self.converted(test_fn, call_trees, {}) as result:
self.assertEqual(
result.test_fn(None, None),
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
self.assertEqual(result.test_fn(lambda x: x + 300, lambda: 1), 4321)
self.assertListEqual(self.dynamic_calls, [
((), None),
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 7,), None),
((21,), None),
])
def test_function_with_call_in_argument(self):
def test_fn(f, g):
return f(g()) + 3
return f(g()) + 300
with self.converted(test_fn, call_trees, {}) as result:
self.assertEqual(
result.test_fn(None, None),
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
self.assertEqual(result.test_fn(lambda x: x + 20, lambda: 1), 321)
self.assertListEqual(self.dynamic_calls, [
((), None),
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL,), None),
((1,), None),
])
def test_function_chaining(self):
def get_one():
return 1
def test_fn():
return get_one().__add__(20)
with self.converted(test_fn, call_trees, {'get_one': get_one},
()) as result:
self.assertEqual(result.test_fn(), 21)
self.assertListEqual(self.dynamic_calls, [
((), None),
((20,), None),
])
def test_function_with_kwarg(self):
def test_fn(f, a, b):
return f(a, c=b) + 3
return f(a, c=b) + 300
with self.converted(test_fn, call_trees, {}) as result:
self.assertEqual(
result.test_fn(None, 1, 2),
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
self.assertListEqual(self.dynamic_calls, [((1,), {'c': 2})])
self.assertEqual(result.test_fn(lambda a, c: a + c, 1, 20), 321)
self.assertListEqual(self.dynamic_calls, [((1,), {'c': 20})])
def test_function_with_kwargs_starargs(self):
@ -84,25 +94,24 @@ class CallTreesTest(converter_testing.TestCase):
with self.converted(test_fn, call_trees, {}) as result:
self.assertEqual(
result.test_fn(None, 1, *[2, 3], **{
result.test_fn(lambda *args, **kwargs: 7, 1, *[2, 3], **{
'b': 4,
'c': 5
}), converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 5)
}), 12)
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})])
def test_function_with_kwargs_starargs_only(self):
def f(*unused_args): # Will not be called.
pass
def f(*args):
return sum(args)
def test_fn():
args = [1, 2, 3]
return f(*args) + 11
args = [1, 20, 300]
return f(*args) + 4000
with self.converted(test_fn, call_trees, {'f': f}) as result:
self.assertEqual(result.test_fn(),
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 11)
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), None)])
self.assertEqual(result.test_fn(), 4321)
self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
def test_function_with_kwargs_keywords(self):
@ -111,8 +120,7 @@ class CallTreesTest(converter_testing.TestCase):
with self.converted(test_fn, call_trees, {}) as result:
self.assertEqual(
result.test_fn(None, 1, 2, **{'c': 3}),
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 5)
result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
def test_debugger_set_trace(self):
@ -133,32 +141,30 @@ class CallTreesTest(converter_testing.TestCase):
class TestClass(object):
def other_method(self, _):
raise ValueError('this should not be called')
def other_method(self, x):
return x + 20
def test_method(self, a):
return self.other_method(a) + 1
return self.other_method(a) + 300
tc = TestClass()
with self.converted(TestClass.test_method, call_trees, {}) as result:
self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1,
result.test_method(tc, 1))
self.assertEqual(321, result.test_method(tc, 1))
self.assertListEqual(self.dynamic_calls, [((1,), None)])
def test_object_method(self):
class TestClass(object):
def other_method(self, _):
raise ValueError('this should not be called')
def other_method(self, x):
return x + 20
def test_method(self, a):
return self.other_method(a) + 1
return self.other_method(a) + 300
tc = TestClass()
with self.converted(tc.test_method, call_trees, {}) as result:
self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1,
result.test_method(tc, 1))
self.assertEqual(321, result.test_method(tc, 1))
self.assertListEqual(self.dynamic_calls, [((1,), None)])

View File

@ -29,7 +29,7 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
def assertTransformedEquivalent(self, test_fn, *inputs):
with self.converted(test_fn, continue_statements, {'ops': ops},
constant_op.constant) as result:
(constant_op.constant,)) as result:
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
def test_basic(self):

View File

@ -39,7 +39,7 @@ class ControlFlowTest(converter_testing.TestCase):
if not symbols:
symbols = {}
with self.converted(test_fn, control_flow, symbols,
constant_op.constant) as result:
(constant_op.constant,)) as result:
self.assertAllEqual(self.evaluate(result.test_fn(*inputs)), expected)
@test_util.run_deprecated_v1

View File

@ -55,7 +55,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
return tf.constant(1)
with self.converted(test_fn, function_scopes, {},
constant_op.constant) as result:
(constant_op.constant,)) as result:
result_op = result.test_fn()
self.assertIn('test_fn/', result_op.op.name)
self.assertIn('First sentence.', result.test_fn.__doc__)
@ -72,7 +72,8 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
l += 1
return l, inner_fn(l)
with self.converted(test_fn, function_scopes, {}, ops.name_scope) as result:
with self.converted(test_fn, function_scopes, {},
(ops.name_scope,)) as result:
first, second = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', first.op.name)
self.assertNotIn('inner_fn', first.op.name)
@ -95,7 +96,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
node, ctx = self.prepare(TestClass, ns)
node = function_scopes.transform(node, ctx)
with self.compiled(node, {}, ops.name_scope) as result:
with self.compiled(node, {}, (ops.name_scope,)) as result:
first, second = result.TestClass().test_fn(constant_op.constant(1))
self.assertIn('TestClass/test_fn/', first.op.name)
self.assertNotIn('inner_fn', first.op.name)

View File

@ -87,7 +87,7 @@ class ListTest(converter_testing.TestCase):
}
node = lists.transform(node, ctx)
with self.compiled(node, ns, dtypes.int32) as result:
with self.compiled(node, ns, (dtypes.int32,)) as result:
with self.cached_session() as sess:
ts, tl = result.test_fn()
r = list_ops.tensor_list_stack(tl, dtypes.int32)
@ -121,7 +121,7 @@ class ListTest(converter_testing.TestCase):
}
node = lists.transform(node, ctx)
with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
with self.compiled(node, {}, (array_ops.stack, dtypes.int32)) as result:
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])

View File

@ -47,7 +47,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign) as result:
with self.compiled(node, {}, (state_ops.assign,)) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
@ -68,7 +68,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign) as result:
with self.compiled(node, {}, (state_ops.assign,)) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
@ -89,7 +89,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, control_flow_ops.Assert) as result:
with self.compiled(node, {}, (control_flow_ops.Assert,)) as result:
with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'expected in throw'):
@ -109,7 +109,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign_add) as result:
with self.compiled(node, {}, (state_ops.assign_add,)) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
@ -130,7 +130,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body[0].body), 1)
with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
with self.compiled(node, {}, (state_ops.assign, ops.name_scope)) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
@ -152,8 +152,8 @@ class SideEffectGuardsTest(converter_testing.TestCase):
self.assertEqual(len(node.body), 1)
with self.compiled(node, {}, state_ops.assign,
state_ops.assign_add) as result:
with self.compiled(node, {},
(state_ops.assign, state_ops.assign_add)) as result:
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)

View File

@ -43,7 +43,7 @@ class SliceTest(converter_testing.TestCase):
}
node = slices.transform(node, ctx)
with self.compiled(node, {}, dtypes.int32) as result:
with self.compiled(node, {}, (dtypes.int32,)) as result:
with self.cached_session() as sess:
tl = list_ops.tensor_list_from_tensor(
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))

View File

@ -37,8 +37,6 @@ from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.platform import test
RESULT_OF_MOCK_CONVERTED_CALL = 7
class TestCase(test.TestCase):
"""Base class for unit tests in this module. Contains relevant utilities."""
@ -54,15 +52,17 @@ class TestCase(test.TestCase):
sys.stdout = sys.__stdout__
@contextlib.contextmanager
def compiled(self, node, namespace, *symbols):
def compiled(self, node, namespace, symbols=()):
source = None
self.dynamic_calls = []
# See api.converted_call
def converted_call(unused_f, unused_opts, args, kwargs):
def converted_call(f, unused_opts, args, kwargs):
"""Mock version of api.converted_call."""
self.dynamic_calls.append((args, kwargs))
return RESULT_OF_MOCK_CONVERTED_CALL
if kwargs is None:
kwargs = {}
return f(*args, **kwargs)
try:
result, source, source_map = compiler.ast_to_object(
@ -92,7 +92,8 @@ class TestCase(test.TestCase):
raise
@contextlib.contextmanager
def converted(self, entity, converter_module, namespace, *tf_symbols):
def converted(self, entity, converter_module, namespace, tf_symbols=()):
node, ctx = self.prepare(entity, namespace)
if not isinstance(converter_module, (list, tuple)):
@ -101,7 +102,7 @@ class TestCase(test.TestCase):
node = converter.standard_analysis(node, ctx, is_initial=not i)
node = m.transform(node, ctx)
with self.compiled(node, namespace, *tf_symbols) as result:
with self.compiled(node, namespace, tf_symbols) as result:
yield result
def make_fake_mod(self, name, *symbols):