Properly inherit closure types in local functions. Add partial support for resolving local functions based on their type annotations. Propagate types into Expr nodes (although these are roots in expression trees).
PiperOrigin-RevId: 325488762 Change-Id: Id1754d65bf15b47ca0ef991959d6491c7ebdc118
This commit is contained in:
parent
f3fad99f6f
commit
51ecfb3061
@ -35,10 +35,14 @@ import gast
|
||||
|
||||
|
||||
class NoValue(enum.Enum):
|
||||
"""Base class for different types of AST annotations."""
|
||||
|
||||
def of(self, node, default=None):
|
||||
return getanno(node, self, default=default)
|
||||
|
||||
def add_to(self, node, value):
|
||||
setanno(node, self, value)
|
||||
|
||||
def exists(self, node):
|
||||
return hasanno(node, self)
|
||||
|
||||
|
@ -31,7 +31,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Any, Callable, Tuple
|
||||
|
||||
import gast
|
||||
|
||||
@ -187,16 +187,13 @@ class StmtInferrer(gast.NodeVisitor):
|
||||
|
||||
def visit(self, node):
|
||||
types = super().visit(node)
|
||||
if __debug__:
|
||||
self._check_set(types)
|
||||
if types is not None:
|
||||
# TODO(mdan): Normalize by removing subtypes.
|
||||
anno.setanno(node, anno.Static.TYPES, tuple(types))
|
||||
return types
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
# Skip local function definitions. They are analyzed separately.
|
||||
# TODO(mdan): Don't skip. Analyze side effects instead.
|
||||
return None
|
||||
|
||||
def _check_set(self, value):
|
||||
if value is not None and not isinstance(value, set):
|
||||
raise ValueError('{} method expected to return set, got {}'.format(
|
||||
@ -300,21 +297,73 @@ class StmtInferrer(gast.NodeVisitor):
|
||||
|
||||
return types
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
f_name = qual_names.QN(node.name)
|
||||
|
||||
if node.decorator_list:
|
||||
raise NotImplementedError('decorators: {}'.format(node.decorator_list))
|
||||
|
||||
# TODO(mdan): Use args.
|
||||
|
||||
ret_types = None
|
||||
if node.returns:
|
||||
ret_types, _ = self.resolver.res_name(
|
||||
self.namespace, self.types_in.types, anno.Basic.QN.of(node.returns))
|
||||
if __debug__:
|
||||
self._check_set(ret_types)
|
||||
|
||||
if ret_types is None:
|
||||
ret_types = {Any}
|
||||
|
||||
fn_types = set()
|
||||
for rt in ret_types:
|
||||
fn_types.add(Callable[[Any], rt])
|
||||
|
||||
self.new_symbols[f_name] = fn_types
|
||||
# The definition of a function is an expression, hence has no return value.
|
||||
return None
|
||||
|
||||
def _resolve_typed_callable(self, fn_types, arg_types, keyword_types):
|
||||
ret_types = set()
|
||||
for t in fn_types:
|
||||
|
||||
if isinstance(t, Callable):
|
||||
# Note: these are undocummented - may be version-specific!
|
||||
# Callable[[x], y]: __args__ are (x, y)
|
||||
args = t.__args__
|
||||
if args:
|
||||
ret_types.add(args[-1])
|
||||
else:
|
||||
ret_types.add(Any)
|
||||
else:
|
||||
raise NotImplementedError('callable type {}'.format(type(t)))
|
||||
|
||||
# Side effects can not be inferred based on type alone.
|
||||
side_effects = None
|
||||
return ret_types, side_effects
|
||||
|
||||
def visit_Call(self, node):
|
||||
self.visit(node.func)
|
||||
|
||||
f_name = anno.getanno(node.func, anno.Basic.QN)
|
||||
if f_name in self.scope.bound:
|
||||
# Don't attempt external resolution of local functions.
|
||||
# TODO(mdan): Use type annotations of the local definition.
|
||||
return None
|
||||
|
||||
f_name = anno.Basic.QN.of(node.func)
|
||||
arg_types = [self.visit(a) for a in node.args]
|
||||
keyword_types = [self.visit(kw.value) for kw in node.keywords]
|
||||
|
||||
ret_type, side_effects = self.resolver.res_call(self.namespace,
|
||||
self.types_in.types, node,
|
||||
arg_types, keyword_types)
|
||||
if f_name in self.scope.bound:
|
||||
# Local function, use local type definitions, if available.
|
||||
fn_type = self.types_in.types.get(f_name, None)
|
||||
if fn_type is None:
|
||||
# No static type info available, nothing more to do.
|
||||
ret_type, side_effects = None, None
|
||||
else:
|
||||
ret_type, side_effects = self._resolve_typed_callable(
|
||||
self.types_in.types.get(f_name), arg_types, keyword_types)
|
||||
|
||||
else:
|
||||
# Nonlocal function, resolve externally.
|
||||
ret_type, side_effects = self.resolver.res_call(self.namespace,
|
||||
self.types_in.types, node,
|
||||
arg_types, keyword_types)
|
||||
if __debug__:
|
||||
self._check_set(ret_type)
|
||||
if side_effects:
|
||||
@ -330,6 +379,9 @@ class StmtInferrer(gast.NodeVisitor):
|
||||
self.new_symbols.update(side_effects)
|
||||
return ret_type
|
||||
|
||||
def visit_Expr(self, node):
|
||||
return self.visit(node.value)
|
||||
|
||||
def visit_Index(self, node):
|
||||
return self.visit(node.value)
|
||||
|
||||
@ -406,15 +458,24 @@ class Analyzer(cfg.GraphVisitor):
|
||||
self.scope = scope
|
||||
self.closure_types = closure_types
|
||||
|
||||
context_types = {
|
||||
n: t for n, t in closure_types.items() if n not in scope.bound
|
||||
}
|
||||
if context_types:
|
||||
self.context_types = _SymbolTable()
|
||||
self.context_types.types = context_types
|
||||
else:
|
||||
self.context_types = None
|
||||
|
||||
def init_state(self, _):
|
||||
return _SymbolTable()
|
||||
|
||||
def _update_closure_types(self, ast_node, types):
|
||||
existing_types = anno.getanno(ast_node, anno.Static.CLOSURE_TYPES, None)
|
||||
existing_types = anno.Static.CLOSURE_TYPES.of(ast_node, None)
|
||||
|
||||
if existing_types is None:
|
||||
existing_types = {}
|
||||
anno.setanno(ast_node, anno.Static.CLOSURE_TYPES, existing_types)
|
||||
anno.Static.CLOSURE_TYPES.add_to(ast_node, existing_types)
|
||||
|
||||
for k, v in types.types.items():
|
||||
if k in existing_types:
|
||||
@ -428,6 +489,8 @@ class Analyzer(cfg.GraphVisitor):
|
||||
types_in = _SymbolTable()
|
||||
for n in node.prev:
|
||||
types_in |= self.out[n]
|
||||
if (self.context_types is not None) and (node is self.graph.entry):
|
||||
types_in |= self.context_types
|
||||
|
||||
types_out = _SymbolTable(types_in)
|
||||
ast_node = node.ast_node
|
||||
@ -437,8 +500,8 @@ class Analyzer(cfg.GraphVisitor):
|
||||
inferrer.visit(ast_node)
|
||||
types_out.types.update(inferrer.new_symbols)
|
||||
|
||||
reaching_fndefs = anno.getanno(ast_node, anno.Static.DEFINED_FNS_IN)
|
||||
node_scope = anno.getanno(ast_node, anno.Static.SCOPE, None)
|
||||
reaching_fndefs = anno.Static.DEFINED_FNS_IN.of(ast_node)
|
||||
node_scope = anno.Static.SCOPE.of(ast_node, None)
|
||||
if node_scope is not None:
|
||||
# TODO(mdan): Check that it's actually safe to skip nodes without scope.
|
||||
reads = {str(qn) for qn in node_scope.read}
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
@ -33,7 +35,10 @@ class BasicTestResolver(type_inference.Resolver):
|
||||
"""A very basic resolver for testing."""
|
||||
|
||||
def res_name(self, ns, types_ns, name):
|
||||
return {type(ns[str(name)])}, ns[str(name)]
|
||||
str_name = str(name)
|
||||
if str_name == 'int':
|
||||
return {int}, int
|
||||
return {type(ns[str_name])}, ns[str_name]
|
||||
|
||||
def res_value(self, ns, value):
|
||||
return {type(value)}
|
||||
@ -72,7 +77,9 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
def assertClosureTypes(self, node, expected):
|
||||
actual = anno.getanno(node, anno.Static.CLOSURE_TYPES)
|
||||
actual = {str(k): v for k, v in actual.items()}
|
||||
self.assertDictEqual(actual, expected)
|
||||
for k, v in expected.items():
|
||||
self.assertIn(k, actual)
|
||||
self.assertEqual(actual[k], v)
|
||||
|
||||
def test_no_inference_on_unknown_operand_types(self):
|
||||
|
||||
@ -188,10 +195,11 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertTypes(fn_body[0].value, int)
|
||||
self.assertTypes(fn_body[0].value.func, str)
|
||||
self.assertEqual(
|
||||
anno.getanno(fn_body[0].value.func, anno.Static.VALUE), tc.a)
|
||||
self.assertTypes(fn_body[0].value.func, str)
|
||||
self.assertTypes(fn_body[0].value, int)
|
||||
self.assertTypes(fn_body[0], int)
|
||||
|
||||
def test_assign_overwriting(self):
|
||||
|
||||
@ -463,6 +471,26 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
self.assertTypes(fn_body[0].body[0].value, 'int')
|
||||
self.assertClosureTypes(fn_body[0], {'x': {'int'}})
|
||||
|
||||
def test_local_function_closure_nested(self):
|
||||
|
||||
def test_fn(x: int):
|
||||
|
||||
def foo():
|
||||
|
||||
def bar():
|
||||
return x
|
||||
|
||||
bar()
|
||||
|
||||
foo()
|
||||
|
||||
node, _ = TestTranspiler(BasicTestResolver).transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertTypes(fn_body[0].body[0].body[0].value, 'int')
|
||||
self.assertClosureTypes(fn_body[0], {'x': {'int'}})
|
||||
self.assertClosureTypes(fn_body[0].body[0], {'x': {'int'}})
|
||||
|
||||
def test_local_function_closure_mutable_var(self):
|
||||
|
||||
def test_fn(x: int):
|
||||
@ -512,6 +540,22 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
self.assertTypes(fn_body[1].targets[0], float)
|
||||
self.assertClosureTypes(fn_body[0], {'x': {float}})
|
||||
|
||||
def test_local_function_type(self):
|
||||
|
||||
def test_fn(x: int):
|
||||
|
||||
def foo() -> int:
|
||||
return x
|
||||
|
||||
foo()
|
||||
|
||||
node, _ = TestTranspiler(BasicTestResolver).transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertTypes(fn_body[1].value.func, Callable[[Any], int])
|
||||
self.assertTypes(fn_body[1].value, int)
|
||||
self.assertTypes(fn_body[1], int)
|
||||
|
||||
def test_side_effects_on_arg_function_closure(self):
|
||||
|
||||
test_self = self
|
||||
|
Loading…
Reference in New Issue
Block a user