Add mechanisms for inferring the arguments of a local function from another higher-order call (this pattern is used in for_stmt).
Add support for tuple unpacking in assignment. PiperOrigin-RevId: 326700596 Change-Id: I9d7349b1c47d35af7cf3827ee04e0613fa53fd8b
This commit is contained in:
parent
6b6df31ddc
commit
300bb99547
@ -75,8 +75,19 @@ class Resolver(object):
|
||||
"""Resolves the type a literal or static value."""
|
||||
raise NotImplementedError('subclasses must implement')
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno):
|
||||
"""Resolves the type of a (possibly annotated) function argument."""
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
"""Resolves the type of a (possibly annotated) function argument.
|
||||
|
||||
Args:
|
||||
ns: namespace
|
||||
types_ns: types namespace
|
||||
f_name: str, the function name
|
||||
name: str, the argument name
|
||||
type_anno: the type annotating the argument, if any
|
||||
f_is_local: bool, whether the function is a local function
|
||||
Returns:
|
||||
Set of the argument types.
|
||||
"""
|
||||
raise NotImplementedError('subclasses must implement')
|
||||
|
||||
def res_call(self, ns, types_ns, node, f_type, args, keywords):
|
||||
@ -98,8 +109,9 @@ class Resolver(object):
|
||||
"""
|
||||
raise NotImplementedError('subclasses must implement')
|
||||
|
||||
def res_subscript(self, ns, types_ns, node, value, slice_):
|
||||
"""Resolves the return type of a unary operation."""
|
||||
# TODO(mdan): Clean this up.
|
||||
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
|
||||
"""Resolves the return type of slice operation."""
|
||||
raise NotImplementedError('subclasses must implement')
|
||||
|
||||
def res_compare(self, ns, types_ns, node, left, right):
|
||||
@ -217,7 +229,18 @@ class StmtInferrer(gast.NodeVisitor):
|
||||
return {Tuple}
|
||||
|
||||
assert isinstance(node.ctx, gast.Store)
|
||||
# TODO(mdan): Implement tuple unpacking.
|
||||
|
||||
if self.rtype is not None:
|
||||
original_stype = self.rtype
|
||||
# TODO(mdan): Find a better way to express unpacking.
|
||||
i_type = self.resolver.res_value(self.namespace, 0)
|
||||
for i, elt in enumerate(node.elts):
|
||||
self.rtype = self.resolver.res_subscript(
|
||||
self.namespace, self.types_in.types, i, original_stype, i_type)
|
||||
self.visit(elt)
|
||||
self.rtype = original_stype
|
||||
return original_stype
|
||||
|
||||
return None
|
||||
|
||||
def visit_List(self, node):
|
||||
@ -249,9 +272,13 @@ class StmtInferrer(gast.NodeVisitor):
|
||||
anno.setanno(node, anno.Static.VALUE, value)
|
||||
|
||||
elif isinstance(node.ctx, gast.Param):
|
||||
# The direct parent it the whole function scope. See activity.py.
|
||||
f_is_local = self.scope.parent.parent is not None
|
||||
|
||||
type_name = anno.getanno(node.annotation, anno.Basic.QN, None)
|
||||
types = self.resolver.res_arg(self.namespace, self.types_in.types,
|
||||
self.scope.function_name, name, type_name)
|
||||
self.scope.function_name, name, type_name,
|
||||
f_is_local)
|
||||
if types is not None:
|
||||
self.new_symbols[name] = types
|
||||
|
||||
@ -317,8 +344,6 @@ class StmtInferrer(gast.NodeVisitor):
|
||||
if node.decorator_list:
|
||||
raise NotImplementedError('decorators: {}'.format(node.decorator_list))
|
||||
|
||||
# TODO(mdan): Use args.
|
||||
|
||||
ret_types = None
|
||||
if node.returns:
|
||||
ret_types, _ = self.resolver.res_name(
|
||||
@ -371,7 +396,7 @@ class StmtInferrer(gast.NodeVisitor):
|
||||
ret_type, side_effects = None, None
|
||||
else:
|
||||
ret_type, side_effects = self._resolve_typed_callable(
|
||||
self.types_in.types.get(f_name), arg_types, keyword_types)
|
||||
f_type, arg_types, keyword_types)
|
||||
|
||||
else:
|
||||
# Nonlocal function, resolve externally.
|
||||
|
||||
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Tuple
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
@ -43,7 +43,7 @@ class BasicTestResolver(type_inference.Resolver):
|
||||
def res_value(self, ns, value):
|
||||
return {type(value)}
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno):
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
if type_anno is None:
|
||||
return None
|
||||
return {str(type_anno)}
|
||||
@ -87,7 +87,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno):
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
return None
|
||||
|
||||
def test_fn(a, b):
|
||||
@ -106,7 +106,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno):
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
return 1
|
||||
|
||||
def test_fn(a):
|
||||
@ -122,7 +122,8 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno):
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
test_self.assertFalse(f_is_local)
|
||||
if name == qual_names.QN('a'):
|
||||
test_self.assertEqual(type_anno, qual_names.QN('int'))
|
||||
return {str(name) + '_type'}
|
||||
@ -138,19 +139,41 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
|
||||
def test_argument_of_local_function(self):
|
||||
|
||||
test_self = self
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
if f_name == 'test_fn':
|
||||
test_self.assertFalse(f_is_local)
|
||||
test_self.assertEqual(name, qual_names.QN('a'))
|
||||
test_self.assertEqual(type_anno, qual_names.QN('int'))
|
||||
elif f_name == 'foo':
|
||||
test_self.assertTrue(f_is_local)
|
||||
if name == qual_names.QN('x'):
|
||||
test_self.assertEqual(type_anno, qual_names.QN('float'))
|
||||
elif name == qual_names.QN('y'):
|
||||
test_self.assertIsNone(type_anno)
|
||||
else:
|
||||
test_self.fail('unexpected argument {} for {}'.format(name, f_name))
|
||||
else:
|
||||
test_self.fail('unexpected function name {}'.format(f_name))
|
||||
return {str(name) + '_type'}
|
||||
|
||||
def test_fn(a: int):
|
||||
|
||||
def foo(x: float):
|
||||
return x
|
||||
def foo(x: float, y):
|
||||
return x, y
|
||||
|
||||
return foo(a)
|
||||
return foo(a, a)
|
||||
|
||||
tr = TestTranspiler(BasicTestResolver)
|
||||
tr = TestTranspiler(Resolver)
|
||||
node, _ = tr.transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertTypes(fn_body[0].body[0].value, 'float')
|
||||
self.assertClosureTypes(fn_body[0], {'a': {'int'}})
|
||||
self.assertTypes(fn_body[0].body[0].value, Tuple)
|
||||
self.assertTypes(fn_body[0].body[0].value.elts[0], 'x_type')
|
||||
self.assertTypes(fn_body[0].body[0].value.elts[1], 'y_type')
|
||||
|
||||
def test_assign_straightline(self):
|
||||
|
||||
@ -434,7 +457,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
test_self.assertEqual(name, qual_names.QN('g'))
|
||||
return None, g
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno):
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
return {str(type_anno)}
|
||||
|
||||
def res_call(self, ns, types_ns, node, f_type, args, keywords):
|
||||
@ -591,7 +614,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
test_self.assertEqual(value, 1.0)
|
||||
return {float}
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno):
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
return {str(type_anno)}
|
||||
|
||||
def res_call(self, ns, types_ns, node, f_type, args, keywords):
|
||||
@ -627,7 +650,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno):
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
return {list}
|
||||
|
||||
def res_value(self, ns, value):
|
||||
@ -648,13 +671,45 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
self.assertTypes(fn_body[0].value.value, list)
|
||||
self.assertTypes(fn_body[0].value.slice.value, int)
|
||||
|
||||
def test_tuple_unpacking(self):
|
||||
|
||||
test_self = self
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
return {list}
|
||||
|
||||
def res_value(self, ns, value):
|
||||
return {int}
|
||||
|
||||
def res_subscript(self, ns, types_ns, node_or_slice, value, slice_):
|
||||
test_self.assertIn(node_or_slice, (0, 1))
|
||||
test_self.assertSetEqual(value, {list})
|
||||
test_self.assertSetEqual(slice_, {int})
|
||||
if node_or_slice == 0:
|
||||
return {float}
|
||||
else:
|
||||
return {str}
|
||||
|
||||
def test_fn(t):
|
||||
a, b = t
|
||||
return a, b
|
||||
|
||||
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertTypes(fn_body[1].value, Tuple)
|
||||
self.assertTypes(fn_body[1].value.elts[0], float)
|
||||
self.assertTypes(fn_body[1].value.elts[1], str)
|
||||
|
||||
def test_compare(self):
|
||||
|
||||
test_self = self
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno):
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
return {int}
|
||||
|
||||
def res_compare(self, ns, types_ns, node, left, right):
|
||||
@ -678,7 +733,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno):
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
return {list}
|
||||
|
||||
def res_binop(self, ns, types_ns, node, left, right):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user