Add mechanisms for inferring the arguments of a local function from another higher-order call (this pattern is used in for_stmt).

Add support for tuple unpacking in assignment.

PiperOrigin-RevId: 326700596
Change-Id: I9d7349b1c47d35af7cf3827ee04e0613fa53fd8b
This commit is contained in:
Dan Moldovan 2020-08-14 11:55:50 -07:00 committed by TensorFlower Gardener
parent 6b6df31ddc
commit 300bb99547
2 changed files with 105 additions and 25 deletions

View File

@ -75,8 +75,19 @@ class Resolver(object):
"""Resolves the type a literal or static value."""
raise NotImplementedError('subclasses must implement')
def res_arg(self, ns, types_ns, f_name, name, type_anno):
"""Resolves the type of a (possibly annotated) function argument."""
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
"""Resolves the type of a (possibly annotated) function argument.
Args:
ns: namespace
types_ns: types namespace
f_name: str, the function name
name: str, the argument name
type_anno: the type annotating the argument, if any
f_is_local: bool, whether the function is a local function
Returns:
Set of the argument types.
"""
raise NotImplementedError('subclasses must implement')
def res_call(self, ns, types_ns, node, f_type, args, keywords):
@ -98,8 +109,9 @@ class Resolver(object):
"""
raise NotImplementedError('subclasses must implement')
def res_subscript(self, ns, types_ns, node, value, slice_):
"""Resolves the return type of a unary operation."""
# TODO(mdan): Clean this up.
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
"""Resolves the return type of slice operation."""
raise NotImplementedError('subclasses must implement')
def res_compare(self, ns, types_ns, node, left, right):
@ -217,7 +229,18 @@ class StmtInferrer(gast.NodeVisitor):
return {Tuple}
assert isinstance(node.ctx, gast.Store)
# TODO(mdan): Implement tuple unpacking.
if self.rtype is not None:
original_stype = self.rtype
# TODO(mdan): Find a better way to express unpacking.
i_type = self.resolver.res_value(self.namespace, 0)
for i, elt in enumerate(node.elts):
self.rtype = self.resolver.res_subscript(
self.namespace, self.types_in.types, i, original_stype, i_type)
self.visit(elt)
self.rtype = original_stype
return original_stype
return None
def visit_List(self, node):
@ -249,9 +272,13 @@ class StmtInferrer(gast.NodeVisitor):
anno.setanno(node, anno.Static.VALUE, value)
elif isinstance(node.ctx, gast.Param):
# The direct parent it the whole function scope. See activity.py.
f_is_local = self.scope.parent.parent is not None
type_name = anno.getanno(node.annotation, anno.Basic.QN, None)
types = self.resolver.res_arg(self.namespace, self.types_in.types,
self.scope.function_name, name, type_name)
self.scope.function_name, name, type_name,
f_is_local)
if types is not None:
self.new_symbols[name] = types
@ -317,8 +344,6 @@ class StmtInferrer(gast.NodeVisitor):
if node.decorator_list:
raise NotImplementedError('decorators: {}'.format(node.decorator_list))
# TODO(mdan): Use args.
ret_types = None
if node.returns:
ret_types, _ = self.resolver.res_name(
@ -371,7 +396,7 @@ class StmtInferrer(gast.NodeVisitor):
ret_type, side_effects = None, None
else:
ret_type, side_effects = self._resolve_typed_callable(
self.types_in.types.get(f_name), arg_types, keyword_types)
f_type, arg_types, keyword_types)
else:
# Nonlocal function, resolve externally.

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, Callable
from typing import Any, Callable, Tuple
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
@ -43,7 +43,7 @@ class BasicTestResolver(type_inference.Resolver):
def res_value(self, ns, value):
return {type(value)}
def res_arg(self, ns, types_ns, f_name, name, type_anno):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
if type_anno is None:
return None
return {str(type_anno)}
@ -87,7 +87,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
return None
def test_fn(a, b):
@ -106,7 +106,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
return 1
def test_fn(a):
@ -122,7 +122,8 @@ class TypeInferenceAnalyzerTest(test.TestCase):
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
test_self.assertFalse(f_is_local)
if name == qual_names.QN('a'):
test_self.assertEqual(type_anno, qual_names.QN('int'))
return {str(name) + '_type'}
@ -138,19 +139,41 @@ class TypeInferenceAnalyzerTest(test.TestCase):
def test_argument_of_local_function(self):
test_self = self
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
if f_name == 'test_fn':
test_self.assertFalse(f_is_local)
test_self.assertEqual(name, qual_names.QN('a'))
test_self.assertEqual(type_anno, qual_names.QN('int'))
elif f_name == 'foo':
test_self.assertTrue(f_is_local)
if name == qual_names.QN('x'):
test_self.assertEqual(type_anno, qual_names.QN('float'))
elif name == qual_names.QN('y'):
test_self.assertIsNone(type_anno)
else:
test_self.fail('unexpected argument {} for {}'.format(name, f_name))
else:
test_self.fail('unexpected function name {}'.format(f_name))
return {str(name) + '_type'}
def test_fn(a: int):
def foo(x: float):
return x
def foo(x: float, y):
return x, y
return foo(a)
return foo(a, a)
tr = TestTranspiler(BasicTestResolver)
tr = TestTranspiler(Resolver)
node, _ = tr.transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[0].body[0].value, 'float')
self.assertClosureTypes(fn_body[0], {'a': {'int'}})
self.assertTypes(fn_body[0].body[0].value, Tuple)
self.assertTypes(fn_body[0].body[0].value.elts[0], 'x_type')
self.assertTypes(fn_body[0].body[0].value.elts[1], 'y_type')
def test_assign_straightline(self):
@ -434,7 +457,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
test_self.assertEqual(name, qual_names.QN('g'))
return None, g
def res_arg(self, ns, types_ns, f_name, name, type_anno):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
return {str(type_anno)}
def res_call(self, ns, types_ns, node, f_type, args, keywords):
@ -591,7 +614,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
test_self.assertEqual(value, 1.0)
return {float}
def res_arg(self, ns, types_ns, f_name, name, type_anno):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
return {str(type_anno)}
def res_call(self, ns, types_ns, node, f_type, args, keywords):
@ -627,7 +650,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
return {list}
def res_value(self, ns, value):
@ -648,13 +671,45 @@ class TypeInferenceAnalyzerTest(test.TestCase):
self.assertTypes(fn_body[0].value.value, list)
self.assertTypes(fn_body[0].value.slice.value, int)
def test_tuple_unpacking(self):
test_self = self
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
return {list}
def res_value(self, ns, value):
return {int}
def res_subscript(self, ns, types_ns, node_or_slice, value, slice_):
test_self.assertIn(node_or_slice, (0, 1))
test_self.assertSetEqual(value, {list})
test_self.assertSetEqual(slice_, {int})
if node_or_slice == 0:
return {float}
else:
return {str}
def test_fn(t):
a, b = t
return a, b
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[1].value, Tuple)
self.assertTypes(fn_body[1].value.elts[0], float)
self.assertTypes(fn_body[1].value.elts[1], str)
def test_compare(self):
test_self = self
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
return {int}
def res_compare(self, ns, types_ns, node, left, right):
@ -678,7 +733,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
return {list}
def res_binop(self, ns, types_ns, node, left, right):