Properly inherit closure types in local functions. Add partial support for resolving local functions based on their type annotations. Propagate types into Expr nodes (although these are roots in expression trees).

PiperOrigin-RevId: 325488762
Change-Id: Id1754d65bf15b47ca0ef991959d6491c7ebdc118
This commit is contained in:
Dan Moldovan 2020-08-07 12:45:04 -07:00 committed by TensorFlower Gardener
parent f3fad99f6f
commit 51ecfb3061
3 changed files with 134 additions and 23 deletions

View File

@ -35,10 +35,14 @@ import gast
class NoValue(enum.Enum):
"""Base class for different types of AST annotations."""
def of(self, node, default=None):
return getanno(node, self, default=default)
def add_to(self, node, value):
setanno(node, self, value)
def exists(self, node):
return hasanno(node, self)

View File

@ -31,7 +31,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Tuple
from typing import Any, Callable, Tuple
import gast
@ -187,16 +187,13 @@ class StmtInferrer(gast.NodeVisitor):
def visit(self, node):
types = super().visit(node)
if __debug__:
self._check_set(types)
if types is not None:
# TODO(mdan): Normalize by removing subtypes.
anno.setanno(node, anno.Static.TYPES, tuple(types))
return types
def visit_FunctionDef(self, node):
# Skip local function definitions. They are analyzed separately.
# TODO(mdan): Don't skip. Analyze side effects instead.
return None
def _check_set(self, value):
if value is not None and not isinstance(value, set):
raise ValueError('{} method expected to return set, got {}'.format(
@ -300,21 +297,73 @@ class StmtInferrer(gast.NodeVisitor):
return types
def visit_FunctionDef(self, node):
f_name = qual_names.QN(node.name)
if node.decorator_list:
raise NotImplementedError('decorators: {}'.format(node.decorator_list))
# TODO(mdan): Use args.
ret_types = None
if node.returns:
ret_types, _ = self.resolver.res_name(
self.namespace, self.types_in.types, anno.Basic.QN.of(node.returns))
if __debug__:
self._check_set(ret_types)
if ret_types is None:
ret_types = {Any}
fn_types = set()
for rt in ret_types:
fn_types.add(Callable[[Any], rt])
self.new_symbols[f_name] = fn_types
# The definition of a function is an expression, hence has no return value.
return None
def _resolve_typed_callable(self, fn_types, arg_types, keyword_types):
ret_types = set()
for t in fn_types:
if isinstance(t, Callable):
# Note: these are undocummented - may be version-specific!
# Callable[[x], y]: __args__ are (x, y)
args = t.__args__
if args:
ret_types.add(args[-1])
else:
ret_types.add(Any)
else:
raise NotImplementedError('callable type {}'.format(type(t)))
# Side effects can not be inferred based on type alone.
side_effects = None
return ret_types, side_effects
def visit_Call(self, node):
self.visit(node.func)
f_name = anno.getanno(node.func, anno.Basic.QN)
if f_name in self.scope.bound:
# Don't attempt external resolution of local functions.
# TODO(mdan): Use type annotations of the local definition.
return None
f_name = anno.Basic.QN.of(node.func)
arg_types = [self.visit(a) for a in node.args]
keyword_types = [self.visit(kw.value) for kw in node.keywords]
ret_type, side_effects = self.resolver.res_call(self.namespace,
self.types_in.types, node,
arg_types, keyword_types)
if f_name in self.scope.bound:
# Local function, use local type definitions, if available.
fn_type = self.types_in.types.get(f_name, None)
if fn_type is None:
# No static type info available, nothing more to do.
ret_type, side_effects = None, None
else:
ret_type, side_effects = self._resolve_typed_callable(
self.types_in.types.get(f_name), arg_types, keyword_types)
else:
# Nonlocal function, resolve externally.
ret_type, side_effects = self.resolver.res_call(self.namespace,
self.types_in.types, node,
arg_types, keyword_types)
if __debug__:
self._check_set(ret_type)
if side_effects:
@ -330,6 +379,9 @@ class StmtInferrer(gast.NodeVisitor):
self.new_symbols.update(side_effects)
return ret_type
def visit_Expr(self, node):
return self.visit(node.value)
def visit_Index(self, node):
return self.visit(node.value)
@ -406,15 +458,24 @@ class Analyzer(cfg.GraphVisitor):
self.scope = scope
self.closure_types = closure_types
context_types = {
n: t for n, t in closure_types.items() if n not in scope.bound
}
if context_types:
self.context_types = _SymbolTable()
self.context_types.types = context_types
else:
self.context_types = None
def init_state(self, _):
return _SymbolTable()
def _update_closure_types(self, ast_node, types):
existing_types = anno.getanno(ast_node, anno.Static.CLOSURE_TYPES, None)
existing_types = anno.Static.CLOSURE_TYPES.of(ast_node, None)
if existing_types is None:
existing_types = {}
anno.setanno(ast_node, anno.Static.CLOSURE_TYPES, existing_types)
anno.Static.CLOSURE_TYPES.add_to(ast_node, existing_types)
for k, v in types.types.items():
if k in existing_types:
@ -428,6 +489,8 @@ class Analyzer(cfg.GraphVisitor):
types_in = _SymbolTable()
for n in node.prev:
types_in |= self.out[n]
if (self.context_types is not None) and (node is self.graph.entry):
types_in |= self.context_types
types_out = _SymbolTable(types_in)
ast_node = node.ast_node
@ -437,8 +500,8 @@ class Analyzer(cfg.GraphVisitor):
inferrer.visit(ast_node)
types_out.types.update(inferrer.new_symbols)
reaching_fndefs = anno.getanno(ast_node, anno.Static.DEFINED_FNS_IN)
node_scope = anno.getanno(ast_node, anno.Static.SCOPE, None)
reaching_fndefs = anno.Static.DEFINED_FNS_IN.of(ast_node)
node_scope = anno.Static.SCOPE.of(ast_node, None)
if node_scope is not None:
# TODO(mdan): Check that it's actually safe to skip nodes without scope.
reads = {str(qn) for qn in node_scope.read}

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, Callable
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import qual_names
@ -33,7 +35,10 @@ class BasicTestResolver(type_inference.Resolver):
"""A very basic resolver for testing."""
def res_name(self, ns, types_ns, name):
return {type(ns[str(name)])}, ns[str(name)]
str_name = str(name)
if str_name == 'int':
return {int}, int
return {type(ns[str_name])}, ns[str_name]
def res_value(self, ns, value):
return {type(value)}
@ -72,7 +77,9 @@ class TypeInferenceAnalyzerTest(test.TestCase):
def assertClosureTypes(self, node, expected):
actual = anno.getanno(node, anno.Static.CLOSURE_TYPES)
actual = {str(k): v for k, v in actual.items()}
self.assertDictEqual(actual, expected)
for k, v in expected.items():
self.assertIn(k, actual)
self.assertEqual(actual[k], v)
def test_no_inference_on_unknown_operand_types(self):
@ -188,10 +195,11 @@ class TypeInferenceAnalyzerTest(test.TestCase):
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[0].value, int)
self.assertTypes(fn_body[0].value.func, str)
self.assertEqual(
anno.getanno(fn_body[0].value.func, anno.Static.VALUE), tc.a)
self.assertTypes(fn_body[0].value.func, str)
self.assertTypes(fn_body[0].value, int)
self.assertTypes(fn_body[0], int)
def test_assign_overwriting(self):
@ -463,6 +471,26 @@ class TypeInferenceAnalyzerTest(test.TestCase):
self.assertTypes(fn_body[0].body[0].value, 'int')
self.assertClosureTypes(fn_body[0], {'x': {'int'}})
def test_local_function_closure_nested(self):
def test_fn(x: int):
def foo():
def bar():
return x
bar()
foo()
node, _ = TestTranspiler(BasicTestResolver).transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[0].body[0].body[0].value, 'int')
self.assertClosureTypes(fn_body[0], {'x': {'int'}})
self.assertClosureTypes(fn_body[0].body[0], {'x': {'int'}})
def test_local_function_closure_mutable_var(self):
def test_fn(x: int):
@ -512,6 +540,22 @@ class TypeInferenceAnalyzerTest(test.TestCase):
self.assertTypes(fn_body[1].targets[0], float)
self.assertClosureTypes(fn_body[0], {'x': {float}})
def test_local_function_type(self):
def test_fn(x: int):
def foo() -> int:
return x
foo()
node, _ = TestTranspiler(BasicTestResolver).transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[1].value.func, Callable[[Any], int])
self.assertTypes(fn_body[1].value, int)
self.assertTypes(fn_body[1], int)
def test_side_effects_on_arg_function_closure(self):
test_self = self