Autograph: Fix chained function conversion
Chained functions were not correctly converted. For example, `foo().bar().baz()` only converted baz. Now fixed. PiperOrigin-RevId: 259608163
This commit is contained in:
parent
c33f1d1a61
commit
8dc62ccf82
@ -38,7 +38,7 @@ class AssertsTest(converter_testing.TestCase):
|
||||
return tf.no_op() # pylint:disable=undefined-variable
|
||||
|
||||
with self.converted(test_fn, (asserts, side_effect_guards), {},
|
||||
gen_control_flow_ops.no_op) as result:
|
||||
(gen_control_flow_ops.no_op,)) as result:
|
||||
with self.cached_session() as sess:
|
||||
op = result.test_fn(constant_op.constant(False))
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
|
@ -28,7 +28,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
|
||||
def assertTransformedEquivalent(self, test_fn, *inputs):
|
||||
with self.converted(test_fn, break_statements, {},
|
||||
constant_op.constant) as result:
|
||||
(constant_op.constant,)) as result:
|
||||
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
|
||||
|
||||
def test_while_loop(self):
|
||||
@ -58,7 +58,7 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
return v
|
||||
|
||||
with self.converted(test_fn, break_statements, {},
|
||||
constant_op.constant) as result:
|
||||
(constant_op.constant,)) as result:
|
||||
# The break is incompletely canonicalized. The loop will not interrupt,
|
||||
# but the section following the break will be skipped.
|
||||
self.assertEqual([3], result.test_fn([5, 4]))
|
||||
|
@ -71,24 +71,26 @@ class CallTreeTransformer(converter.Base):
|
||||
return node
|
||||
|
||||
def visit_Call(self, node):
|
||||
full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
|
||||
node = self.generic_visit(node)
|
||||
|
||||
# TODO(mdan): Refactor converted_call as a 'Call' operator.
|
||||
|
||||
# Calls to the internal 'ag__' module are never converted (though their
|
||||
# arguments might be).
|
||||
full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
|
||||
if full_name.startswith('ag__.'):
|
||||
return self.generic_visit(node)
|
||||
return node
|
||||
|
||||
# Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
|
||||
# the normal mechanisms to bypass these literals because they are sensitive
|
||||
# to the frame they are being called from.
|
||||
# TODO(mdan): Generalize this to a "static whitelist" config.
|
||||
if full_name in ('pdb.set_trace', 'ipdb.set_trace'):
|
||||
return self.generic_visit(node)
|
||||
return node
|
||||
|
||||
if (full_name == 'print' and
|
||||
not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
|
||||
return self.generic_visit(node)
|
||||
return node
|
||||
|
||||
func = node.func
|
||||
|
||||
@ -99,7 +101,6 @@ class CallTreeTransformer(converter.Base):
|
||||
assert starred_arg is None, 'Multiple *args should be impossible.'
|
||||
starred_arg = a
|
||||
else:
|
||||
a = self.visit(a)
|
||||
normal_args.append(a)
|
||||
if starred_arg is None:
|
||||
args = templates.replace_as_expression('(args,)', args=normal_args)
|
||||
@ -116,7 +117,6 @@ class CallTreeTransformer(converter.Base):
|
||||
assert kwargs_arg is None, 'Multiple **kwargs should be impossible.'
|
||||
kwargs_arg = k
|
||||
else:
|
||||
k = self.visit(k)
|
||||
normal_keywords.append(k)
|
||||
if kwargs_arg is None:
|
||||
if not normal_keywords:
|
||||
|
@ -30,52 +30,62 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
def test_normal_function(self):
|
||||
|
||||
def test_fn(f):
|
||||
return f() + 3
|
||||
return f() + 20
|
||||
|
||||
with self.converted(test_fn, call_trees, {}) as result:
|
||||
self.assertEqual(
|
||||
result.test_fn(None),
|
||||
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
|
||||
self.assertEqual(result.test_fn(lambda: 1), 21)
|
||||
self.assertListEqual(self.dynamic_calls, [((), None)])
|
||||
|
||||
def test_function_with_expression_in_argument(self):
|
||||
|
||||
def test_fn(f, g):
|
||||
return f(g() + 7) + 3
|
||||
return f(g() + 20) + 4000
|
||||
|
||||
with self.converted(test_fn, call_trees, {}) as result:
|
||||
self.assertEqual(
|
||||
result.test_fn(None, None),
|
||||
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
|
||||
self.assertEqual(result.test_fn(lambda x: x + 300, lambda: 1), 4321)
|
||||
self.assertListEqual(self.dynamic_calls, [
|
||||
((), None),
|
||||
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 7,), None),
|
||||
((21,), None),
|
||||
])
|
||||
|
||||
def test_function_with_call_in_argument(self):
|
||||
|
||||
def test_fn(f, g):
|
||||
return f(g()) + 3
|
||||
return f(g()) + 300
|
||||
|
||||
with self.converted(test_fn, call_trees, {}) as result:
|
||||
self.assertEqual(
|
||||
result.test_fn(None, None),
|
||||
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
|
||||
self.assertEqual(result.test_fn(lambda x: x + 20, lambda: 1), 321)
|
||||
self.assertListEqual(self.dynamic_calls, [
|
||||
((), None),
|
||||
((converter_testing.RESULT_OF_MOCK_CONVERTED_CALL,), None),
|
||||
((1,), None),
|
||||
])
|
||||
|
||||
def test_function_chaining(self):
|
||||
|
||||
def get_one():
|
||||
return 1
|
||||
|
||||
def test_fn():
|
||||
return get_one().__add__(20)
|
||||
|
||||
with self.converted(test_fn, call_trees, {'get_one': get_one},
|
||||
()) as result:
|
||||
|
||||
self.assertEqual(result.test_fn(), 21)
|
||||
|
||||
self.assertListEqual(self.dynamic_calls, [
|
||||
((), None),
|
||||
((20,), None),
|
||||
])
|
||||
|
||||
def test_function_with_kwarg(self):
|
||||
|
||||
def test_fn(f, a, b):
|
||||
return f(a, c=b) + 3
|
||||
return f(a, c=b) + 300
|
||||
|
||||
with self.converted(test_fn, call_trees, {}) as result:
|
||||
self.assertEqual(
|
||||
result.test_fn(None, 1, 2),
|
||||
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), {'c': 2})])
|
||||
self.assertEqual(result.test_fn(lambda a, c: a + c, 1, 20), 321)
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), {'c': 20})])
|
||||
|
||||
def test_function_with_kwargs_starargs(self):
|
||||
|
||||
@ -84,25 +94,24 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
|
||||
with self.converted(test_fn, call_trees, {}) as result:
|
||||
self.assertEqual(
|
||||
result.test_fn(None, 1, *[2, 3], **{
|
||||
result.test_fn(lambda *args, **kwargs: 7, 1, *[2, 3], **{
|
||||
'b': 4,
|
||||
'c': 5
|
||||
}), converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 5)
|
||||
}), 12)
|
||||
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), {'b': 4, 'c': 5})])
|
||||
|
||||
def test_function_with_kwargs_starargs_only(self):
|
||||
|
||||
def f(*unused_args): # Will not be called.
|
||||
pass
|
||||
def f(*args):
|
||||
return sum(args)
|
||||
|
||||
def test_fn():
|
||||
args = [1, 2, 3]
|
||||
return f(*args) + 11
|
||||
args = [1, 20, 300]
|
||||
return f(*args) + 4000
|
||||
|
||||
with self.converted(test_fn, call_trees, {'f': f}) as result:
|
||||
self.assertEqual(result.test_fn(),
|
||||
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 11)
|
||||
self.assertListEqual(self.dynamic_calls, [((1, 2, 3), None)])
|
||||
self.assertEqual(result.test_fn(), 4321)
|
||||
self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
|
||||
|
||||
def test_function_with_kwargs_keywords(self):
|
||||
|
||||
@ -111,8 +120,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
|
||||
with self.converted(test_fn, call_trees, {}) as result:
|
||||
self.assertEqual(
|
||||
result.test_fn(None, 1, 2, **{'c': 3}),
|
||||
converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 5)
|
||||
result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
|
||||
|
||||
def test_debugger_set_trace(self):
|
||||
@ -133,32 +141,30 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
def other_method(self, _):
|
||||
raise ValueError('this should not be called')
|
||||
def other_method(self, x):
|
||||
return x + 20
|
||||
|
||||
def test_method(self, a):
|
||||
return self.other_method(a) + 1
|
||||
return self.other_method(a) + 300
|
||||
|
||||
tc = TestClass()
|
||||
with self.converted(TestClass.test_method, call_trees, {}) as result:
|
||||
self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1,
|
||||
result.test_method(tc, 1))
|
||||
self.assertEqual(321, result.test_method(tc, 1))
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), None)])
|
||||
|
||||
def test_object_method(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
def other_method(self, _):
|
||||
raise ValueError('this should not be called')
|
||||
def other_method(self, x):
|
||||
return x + 20
|
||||
|
||||
def test_method(self, a):
|
||||
return self.other_method(a) + 1
|
||||
return self.other_method(a) + 300
|
||||
|
||||
tc = TestClass()
|
||||
with self.converted(tc.test_method, call_trees, {}) as result:
|
||||
self.assertEqual(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1,
|
||||
result.test_method(tc, 1))
|
||||
self.assertEqual(321, result.test_method(tc, 1))
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), None)])
|
||||
|
||||
|
||||
|
@ -29,7 +29,7 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
||||
|
||||
def assertTransformedEquivalent(self, test_fn, *inputs):
|
||||
with self.converted(test_fn, continue_statements, {'ops': ops},
|
||||
constant_op.constant) as result:
|
||||
(constant_op.constant,)) as result:
|
||||
self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
|
||||
|
||||
def test_basic(self):
|
||||
|
@ -39,7 +39,7 @@ class ControlFlowTest(converter_testing.TestCase):
|
||||
if not symbols:
|
||||
symbols = {}
|
||||
with self.converted(test_fn, control_flow, symbols,
|
||||
constant_op.constant) as result:
|
||||
(constant_op.constant,)) as result:
|
||||
self.assertAllEqual(self.evaluate(result.test_fn(*inputs)), expected)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
|
@ -55,7 +55,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
||||
return tf.constant(1)
|
||||
|
||||
with self.converted(test_fn, function_scopes, {},
|
||||
constant_op.constant) as result:
|
||||
(constant_op.constant,)) as result:
|
||||
result_op = result.test_fn()
|
||||
self.assertIn('test_fn/', result_op.op.name)
|
||||
self.assertIn('First sentence.', result.test_fn.__doc__)
|
||||
@ -72,7 +72,8 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
||||
l += 1
|
||||
return l, inner_fn(l)
|
||||
|
||||
with self.converted(test_fn, function_scopes, {}, ops.name_scope) as result:
|
||||
with self.converted(test_fn, function_scopes, {},
|
||||
(ops.name_scope,)) as result:
|
||||
first, second = result.test_fn(constant_op.constant(1))
|
||||
self.assertIn('test_fn/', first.op.name)
|
||||
self.assertNotIn('inner_fn', first.op.name)
|
||||
@ -95,7 +96,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
||||
node, ctx = self.prepare(TestClass, ns)
|
||||
node = function_scopes.transform(node, ctx)
|
||||
|
||||
with self.compiled(node, {}, ops.name_scope) as result:
|
||||
with self.compiled(node, {}, (ops.name_scope,)) as result:
|
||||
first, second = result.TestClass().test_fn(constant_op.constant(1))
|
||||
self.assertIn('TestClass/test_fn/', first.op.name)
|
||||
self.assertNotIn('inner_fn', first.op.name)
|
||||
|
@ -87,7 +87,7 @@ class ListTest(converter_testing.TestCase):
|
||||
}
|
||||
node = lists.transform(node, ctx)
|
||||
|
||||
with self.compiled(node, ns, dtypes.int32) as result:
|
||||
with self.compiled(node, ns, (dtypes.int32,)) as result:
|
||||
with self.cached_session() as sess:
|
||||
ts, tl = result.test_fn()
|
||||
r = list_ops.tensor_list_stack(tl, dtypes.int32)
|
||||
@ -121,7 +121,7 @@ class ListTest(converter_testing.TestCase):
|
||||
}
|
||||
node = lists.transform(node, ctx)
|
||||
|
||||
with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
|
||||
with self.compiled(node, {}, (array_ops.stack, dtypes.int32)) as result:
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])
|
||||
|
||||
|
@ -47,7 +47,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
|
||||
self.assertEqual(len(node.body), 1)
|
||||
|
||||
with self.compiled(node, {}, state_ops.assign) as result:
|
||||
with self.compiled(node, {}, (state_ops.assign,)) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
self.evaluate(v.initializer)
|
||||
@ -68,7 +68,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
|
||||
self.assertEqual(len(node.body), 1)
|
||||
|
||||
with self.compiled(node, {}, state_ops.assign) as result:
|
||||
with self.compiled(node, {}, (state_ops.assign,)) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
self.evaluate(v.initializer)
|
||||
@ -89,7 +89,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
|
||||
self.assertEqual(len(node.body), 1)
|
||||
|
||||
with self.compiled(node, {}, control_flow_ops.Assert) as result:
|
||||
with self.compiled(node, {}, (control_flow_ops.Assert,)) as result:
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
'expected in throw'):
|
||||
@ -109,7 +109,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
|
||||
self.assertEqual(len(node.body), 1)
|
||||
|
||||
with self.compiled(node, {}, state_ops.assign_add) as result:
|
||||
with self.compiled(node, {}, (state_ops.assign_add,)) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
self.evaluate(v.initializer)
|
||||
@ -130,7 +130,7 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
|
||||
self.assertEqual(len(node.body[0].body), 1)
|
||||
|
||||
with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
|
||||
with self.compiled(node, {}, (state_ops.assign, ops.name_scope)) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
self.evaluate(v.initializer)
|
||||
@ -152,8 +152,8 @@ class SideEffectGuardsTest(converter_testing.TestCase):
|
||||
|
||||
self.assertEqual(len(node.body), 1)
|
||||
|
||||
with self.compiled(node, {}, state_ops.assign,
|
||||
state_ops.assign_add) as result:
|
||||
with self.compiled(node, {},
|
||||
(state_ops.assign, state_ops.assign_add)) as result:
|
||||
with self.cached_session() as sess:
|
||||
v = variable_scope.get_variable('test', initializer=2)
|
||||
self.evaluate(v.initializer)
|
||||
|
@ -43,7 +43,7 @@ class SliceTest(converter_testing.TestCase):
|
||||
}
|
||||
node = slices.transform(node, ctx)
|
||||
|
||||
with self.compiled(node, {}, dtypes.int32) as result:
|
||||
with self.compiled(node, {}, (dtypes.int32,)) as result:
|
||||
with self.cached_session() as sess:
|
||||
tl = list_ops.tensor_list_from_tensor(
|
||||
[1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
|
||||
|
@ -37,8 +37,6 @@ from tensorflow.python.autograph.pyct import pretty_printer
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
RESULT_OF_MOCK_CONVERTED_CALL = 7
|
||||
|
||||
|
||||
class TestCase(test.TestCase):
|
||||
"""Base class for unit tests in this module. Contains relevant utilities."""
|
||||
@ -54,15 +52,17 @@ class TestCase(test.TestCase):
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
@contextlib.contextmanager
|
||||
def compiled(self, node, namespace, *symbols):
|
||||
def compiled(self, node, namespace, symbols=()):
|
||||
source = None
|
||||
|
||||
self.dynamic_calls = []
|
||||
# See api.converted_call
|
||||
def converted_call(unused_f, unused_opts, args, kwargs):
|
||||
def converted_call(f, unused_opts, args, kwargs):
|
||||
"""Mock version of api.converted_call."""
|
||||
self.dynamic_calls.append((args, kwargs))
|
||||
return RESULT_OF_MOCK_CONVERTED_CALL
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
return f(*args, **kwargs)
|
||||
|
||||
try:
|
||||
result, source, source_map = compiler.ast_to_object(
|
||||
@ -92,7 +92,8 @@ class TestCase(test.TestCase):
|
||||
raise
|
||||
|
||||
@contextlib.contextmanager
|
||||
def converted(self, entity, converter_module, namespace, *tf_symbols):
|
||||
def converted(self, entity, converter_module, namespace, tf_symbols=()):
|
||||
|
||||
node, ctx = self.prepare(entity, namespace)
|
||||
|
||||
if not isinstance(converter_module, (list, tuple)):
|
||||
@ -101,7 +102,7 @@ class TestCase(test.TestCase):
|
||||
node = converter.standard_analysis(node, ctx, is_initial=not i)
|
||||
node = m.transform(node, ctx)
|
||||
|
||||
with self.compiled(node, namespace, *tf_symbols) as result:
|
||||
with self.compiled(node, namespace, tf_symbols) as result:
|
||||
yield result
|
||||
|
||||
def make_fake_mod(self, name, *symbols):
|
||||
|
Loading…
Reference in New Issue
Block a user