Add support for list and tuple expansions, inferring types of list literals.

PiperOrigin-RevId: 346778490
Change-Id: Idbbf9079090e22df07c30082610344e885a69545
Authored by Dan Moldovan on 2020-12-10 07:19:29 -08:00; committed by TensorFlower Gardener
parent 82738f8e3c
commit 2a9070791f
3 changed files with 319 additions and 68 deletions
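The core of the change is a new list-literal resolution rule in tfr_gen.py: the element type sets are unioned, coerced toward the most specific common type, and then mapped to either TENSOR_LIST or ATTR. Below is a minimal standalone sketch of that rule, using a stand-in enum (_Types) and a hypothetical infer_list_literal helper rather than the real TFRTypes/TFRTypeResolver, so it can be run on its own.

import enum


class _Types(enum.Enum):
  # Stand-ins for the relevant TFRTypes members; not the real enum.
  TENSOR = 1
  TENSOR_LIST = 2
  INDEX = 3
  I64 = 4
  ATTR = 5


def _coerce_to_more_specific_type(elt_types):
  # Constants collapse to indices, and constants/indices collapse to tensors
  # whenever a tensor is present, mirroring the coercion added in the diff.
  if _Types.INDEX in elt_types:
    elt_types.discard(_Types.I64)
  if _Types.TENSOR in elt_types:
    elt_types.discard(_Types.I64)
    elt_types.discard(_Types.INDEX)
  return elt_types


def infer_list_literal(elt_types):
  # Union the per-element type sets, coerce, then pick the container type.
  all_elt_types = set()
  for t in elt_types:
    all_elt_types |= t
  if len(all_elt_types) != 1:
    all_elt_types = _coerce_to_more_specific_type(all_elt_types)
  if len(all_elt_types) != 1:
    raise ValueError('ambiguous list element types: {}'.format(elt_types))
  if _Types.TENSOR in all_elt_types:
    return {_Types.TENSOR_LIST}
  return {_Types.ATTR}


print(infer_list_literal(({_Types.TENSOR}, {_Types.I64})))  # {<_Types.TENSOR_LIST: 2>}
print(infer_list_literal(({_Types.I64}, {_Types.I64})))     # {<_Types.ATTR: 5>}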

File: tensorflow/compiler/mlir/tfr/python/tfr_gen.py

@ -23,10 +23,10 @@ from __future__ import division
from __future__ import print_function
import enum
import inspect
import os
import re
import types
from typing import List, Tuple
import gast as ast
from tensorflow.compiler.mlir.tfr import tfr_wrapper as tfr
@ -46,7 +46,8 @@ from tensorflow.python.autograph.pyct.static_analysis import type_inference
from tensorflow.python.framework import load_library
from tensorflow.python.framework import op_def_registry
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
# TODO(mdan): Use class definitions so that we can mix these with Python types.
class TFRTypes(enum.Enum):
@ -410,7 +411,7 @@ class TFRTypeResolver(type_inference.Resolver):
iterated_type = args[0]
assert iterated_type & {
TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, List[int]
TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, TFRTypes.ATTR
}, (
iterated_type)
self._for_loop_target_types[body_fn_name] = iterated_type
@ -443,7 +444,7 @@ class TFRTypeResolver(type_inference.Resolver):
elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
assert name.is_simple()
if name == QN('range'):
return {List[int]}, None
return {TFRTypes.ATTR}, None
if name == QN('len'):
return {TFRTypes.INDEX}, None
@ -459,7 +460,7 @@ class TFRTypeResolver(type_inference.Resolver):
if f_name_str in self._for_loop_target_types:
# See autograph/converters/control_flow.py - the function has a single
# argument, the iterate before any expansion.
assert self._for_loop_target_types[f_name_str] & {List[int]}
assert self._for_loop_target_types[f_name_str] & {TFRTypes.ATTR}
# Assume all loops are TF loops. Then the iterates are autoboxed into
# Tensors.
return {TFRTypes.INDEX}
@ -471,7 +472,7 @@ class TFRTypeResolver(type_inference.Resolver):
op_def, derived_attrs = self._op_defs.lookup(f_name, func)
if op_def is None:
return None
pos = tf_inspect.getfullargspec(func).args.index(str(name))
pos = inspect.getfullargspec(func).args.index(str(name))
if pos < len(op_def.input_arg):
arg_def = op_def.input_arg[pos]
@ -488,7 +489,7 @@ class TFRTypeResolver(type_inference.Resolver):
raise ValueError('Argument is not defined in OpDef: ' + str(name))
def res_subscript(self, ns, types_ns, node_or_slice, value, slice_):
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
assert len(value) == 1
value, = tuple(value)
if value == TFRTypes.TF_TENSOR_SHAPE_LIST:
@ -503,10 +504,40 @@ class TFRTypeResolver(type_inference.Resolver):
# TODO(fengliuai): make sure left and right are compatible
return {TFRTypes.I1}
def res_unop(self, ns, types_ns, node, opnd):
return opnd
def res_binop(self, ns, types_ns, node, left, right):
# TODO(fengliuai): make sure left and right are compatible
return left
def _coerce_to_more_specific_type(self, elt_types):
# TODO(mdan): This needs some type theory study.
if TFRTypes.INDEX in elt_types:
# Constants collapse to indices.
elt_types.discard(TFRTypes.I64)
if TFRTypes.TENSOR in elt_types:
# Constants collapse to tensors.
elt_types.discard(TFRTypes.I64)
# Indices collapse to tensors.
elt_types.discard(TFRTypes.INDEX)
return elt_types
def res_list_literal(self, ns, elt_types):
all_elt_types = set()
for t in elt_types:
all_elt_types |= t
if len(all_elt_types) != 1:
all_elt_types = self._coerce_to_more_specific_type(all_elt_types)
if len(all_elt_types) != 1:
raise ValueError('ambiguous list element types: {}'.format(elt_types))
if TFRTypes.TENSOR in all_elt_types:
return {TFRTypes.TENSOR_LIST}
return {TFRTypes.ATTR}
class SymbolTable(object):
"""Symbol Table for python code."""
@ -599,22 +630,6 @@ class TFRGen(transformer.CodeGenerator):
node, types_))
type_, = types_
# TODO(fengliuai): Tuple is added here to make return tuple work.
if type_ is list or type_ is Tuple:
# TODO(fengliuai): Seems like we need to move the followed list handling
# to the type inference and we shouldn't just put 'list' there. Otherwise
# we couldn't find out the right type for the Name node.
if not isinstance(node, ast.List):
return default
all_types = [
anno.getanno(elt, anno.Static.TYPES, None) for elt in node.elts
]
if (TFRTypes.TENSOR,) in all_types:
# For the elt which is not tfr.tensor, tfr.constant_tensor needs to be
# use to cast it to a tfr.tensor.
return TFRTypes.TENSOR_LIST
else:
return TFRTypes.ATTR
if default is not None and type_ != default:
print('WARN: type annotation {}({}) does not match {}({})'.format(
@ -704,7 +719,6 @@ class TFRGen(transformer.CodeGenerator):
if isinstance(node.value, ast.Attribute):
if isinstance(node.value.value, ast.Name):
if node.value.value.id == 'tf' and node.value.attr == 'raw_ops':
# This branch is used when it is outside tensorflow
return (node.attr, TFRTypes.TF_RAW_OP)
value, ty = self.visit(node.value)
@ -726,13 +740,24 @@ class TFRGen(transformer.CodeGenerator):
raise NotImplementedError('Assignment target type not recognized.')
if isinstance(values, list):
if isinstance(node.value, ast.Call):
expected = tuple(t for n, t in values)
if len(values) == 1:
expected = expected[0]
elif isinstance(node.value, ast.Tuple):
expected = tuple(t for n, t in values)
else:
raise ValueError('unknown assignment target node', node.value)
ty = self._get_inferred_type(node.value, expected)
if len(targets) == len(values):
for key, value in zip(targets, values):
ssa_value, ty_ = value
ty = self._get_inferred_type(node.value, ty_)
self.symbol_table.insert_symbol(key, ssa_value, ty)
# TODO(mdan): This should already be a tuple.
ty_ = (ty,) if len(values) == 1 else ty
for key, value, t in zip(targets, values, ty_):
ssa_value, _ = value
self.symbol_table.insert_symbol(key, ssa_value, t)
elif len(values) == 1:
n, ty = values[0]
n, _ = values[0]
assert ty == TFRTypes.TENSOR_LIST
# assign a tensor_list to multiple variables
for idx, key in enumerate(targets):
@ -747,10 +772,11 @@ class TFRGen(transformer.CodeGenerator):
self.symbol_table.insert_symbol(key, elt_name, TFRTypes.TENSOR)
elif len(targets) == 1:
ssa_names = [n for n, _ in values]
tys = [t for _, t in values]
self.symbol_table.insert_symbol(targets[0], ssa_names, tys)
else:
self.symbol_table.insert_symbol(targets[0], values[0], values[1])
self.symbol_table.insert_symbol(targets[0], ssa_names, ty)
return
ty = self._get_inferred_type(node.value, values[1])
self.symbol_table.insert_symbol(targets[0], values[0], ty)
def _emit_binary_op(self, op, lhs, lhs_ty, rhs, rhs_ty):
assert lhs_ty, rhs_ty
@ -795,7 +821,7 @@ class TFRGen(transformer.CodeGenerator):
def visit_Call(self, node):
func_name, func_type = self.visit(node.func)
_ = self._get_inferred_type(node.func, func_type)
func_type = self._get_inferred_type(node.func, func_type)
if func_type == TFRTypes.AG_BUILTIN_FUNC:
if func_name == 'if_stmt':
cond, _ = self.visit(node.args[0])
@ -1285,6 +1311,7 @@ class TFRGen(transformer.CodeGenerator):
tys = []
for elt in node.elts:
val, ty = self.visit(elt)
ty = self._get_inferred_type(elt, ty)
if ty in _attribute_types and out_type == TFRTypes.TENSOR_LIST:
# This list is a tensor list, then cast all the input values to tensors.
val, ty = self._value_to_tensor(val, ty, node)
@ -1405,7 +1432,7 @@ def tfr_gen_from_module(source, method_prefix=None, op_libraries=None):
py_funcs = [
func
for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction)
for name, func in inspect.getmembers(source, inspect.isfunction)
if not method_prefix or name.startswith(method_prefix)
]
# Sort the methods by the line number, to make sure the definitions are

File: tensorflow/python/autograph/pyct/static_analysis/type_inference.py

@ -31,7 +31,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, Callable, Tuple
import itertools
from typing import Any, Callable, Dict, Set
import gast
@ -39,6 +41,7 @@ from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos
@ -118,12 +121,20 @@ class Resolver(object):
"""Resolves the return type of a unary operation."""
raise NotImplementedError('subclasses must implement')
def res_binop(self, ns, types_ns, node, left, right):
def res_unop(self, ns, types_ns, node, opnd):
"""Resolves the return type of a unary operation."""
raise NotImplementedError('subclasses must implement')
def res_binop(self, ns, types_ns, node, left, right):
"""Resolves the return type of a binary operation."""
raise NotImplementedError('subclasses must implement')
class _SymbolTable(object):
def res_list_literal(self, ns, elt_types):
"""Resolves the type of a list literal from its elements."""
raise NotImplementedError('subclasses must implement')
class _TypeMap(object):
"""Abstraction for the state of the CFG walk for type inference.
This is a value type. Only implements the strictly necessary operators.
@ -135,7 +146,7 @@ class _SymbolTable(object):
def __init__(self, init_from=None):
if init_from:
assert isinstance(init_from, _SymbolTable)
assert isinstance(init_from, _TypeMap)
self.types = {
s: set(other_types) for s, other_types in init_from.types.items()
}
@ -152,8 +163,8 @@ class _SymbolTable(object):
return not self.__eq__(other)
def __or__(self, other):
assert isinstance(other, _SymbolTable)
result = _SymbolTable(self)
assert isinstance(other, _TypeMap)
result = _TypeMap(self)
for s, other_types in other.types.items():
if s not in result.types:
self_types = set()
@ -192,13 +203,22 @@ class StmtInferrer(gast.NodeVisitor):
print(a) # a = int; side effect of f() accounted for
"""
def __init__(self, resolver, scope, namespace, closure_types, types_in):
def __init__(self,
resolver: Resolver,
scope: activity.Scope,
namespace: Dict[qual_names.QN, Any],
closure_types: Dict[qual_names.QN, Set[Any]],
types_in: _TypeMap):
self.resolver = resolver
self.scope = scope
self.namespace = namespace
self.closure_types = closure_types
self.types_in = types_in
self.new_symbols = {}
# rvalue type. This property is set when encountering an assign operation,
# so that visiting nodes with Store ctx (typically found on left side of
# assignments) can infer the type they should receive.
self.rtype = None
def visit(self, node):
@ -221,36 +241,36 @@ class StmtInferrer(gast.NodeVisitor):
self._check_set(types)
return types
def visit_Tuple(self, node):
if isinstance(node.ctx, gast.Load):
for elt in node.elts:
self.visit(elt)
# TODO(mdan): Parameterize it.
return {Tuple}
def _apply_unpacking(self, node):
assert isinstance(node.ctx, gast.Store)
if self.rtype is not None:
original_stype = self.rtype
# TODO(mdan): Find a better way to express unpacking.
i_type = self.resolver.res_value(self.namespace, 0)
for i, elt in enumerate(node.elts):
self.rtype = self.resolver.res_subscript(
self.rtype = self.resolver.res_slice(
self.namespace, self.types_in.types, i, original_stype, i_type)
self.visit(elt)
self.rtype = original_stype
return original_stype
return None
def visit_Tuple(self, node):
if isinstance(node.ctx, gast.Load):
elt_types = ()
for elt in node.elts:
types_ = self.visit(elt)
if types_ is None:
return None
elt_types += (types_,)
return set(itertools.product(*elt_types))
return self._apply_unpacking(node)
def visit_List(self, node):
if isinstance(node.ctx, gast.Load):
el_types = []
for elt in node.elts:
el_types.append(self.visit(elt))
return {list}
raise NotImplementedError('list unpacking')
elt_types = tuple(self.visit(elt) for elt in node.elts)
return self.resolver.res_list_literal(self.namespace, elt_types)
return self._apply_unpacking(node)
def visit_Set(self, node):
raise NotImplementedError()
@ -442,7 +462,7 @@ class StmtInferrer(gast.NodeVisitor):
if val_types is None or slice_types is None:
return None
types = self.resolver.res_subscript(
types = self.resolver.res_slice(
self.namespace, self.types_in.types, node, val_types, slice_types)
if __debug__:
@ -480,6 +500,20 @@ class StmtInferrer(gast.NodeVisitor):
return types
def visit_UnaryOp(self, node):
opnd_types = self.visit(node.operand)
if opnd_types is None:
return None
types = self.resolver.res_unop(
self.namespace, self.types_in.types, node, opnd_types)
if __debug__:
self._check_set(types)
return types
class Analyzer(cfg.GraphVisitor):
"""CFG visitor that propagates type information across statements."""
@ -504,13 +538,13 @@ class Analyzer(cfg.GraphVisitor):
n: t for n, t in closure_types.items() if n not in scope.bound
}
if context_types:
self.context_types = _SymbolTable()
self.context_types = _TypeMap()
self.context_types.types = context_types
else:
self.context_types = None
def init_state(self, _):
return _SymbolTable()
return _TypeMap()
def _update_closure_types(self, ast_node, types):
existing_types = anno.Static.CLOSURE_TYPES.of(ast_node, None)
@ -528,13 +562,13 @@ class Analyzer(cfg.GraphVisitor):
def visit_node(self, node):
prev_types_out = self.out[node]
types_in = _SymbolTable()
types_in = _TypeMap()
for n in node.prev:
types_in |= self.out[n]
if (self.context_types is not None) and (node is self.graph.entry):
types_in |= self.context_types
types_out = _SymbolTable(types_in)
types_out = _TypeMap(types_in)
ast_node = node.ast_node
inferrer = StmtInferrer(self.resolver, self.scope, self.namespace,

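For reference, the new visit_Tuple path above types a tuple literal as the Cartesian product of its elements' type sets. A minimal standalone sketch of that computation (infer_tuple_literal is a hypothetical name, not an autograph API):

import itertools


def infer_tuple_literal(elt_types):
  # elt_types is a sequence of per-element type sets, e.g. ({int}, {float, str}).
  # Each candidate tuple type is one combination of element types.
  if any(t is None for t in elt_types):
    return None
  return set(itertools.product(*elt_types))


print(infer_tuple_literal(({int}, {float})))       # one candidate: (int, float)
print(infer_tuple_literal(({int}, {float, str})))  # two candidates: (int, float) and (int, str)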
File: tensorflow/python/autograph/pyct/static_analysis/type_inference_test.py

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, Callable, Tuple
from typing import Any, Callable, List
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
@ -171,7 +171,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
node, _ = tr.transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[0].body[0].value, Tuple)
self.assertTypes(fn_body[0].body[0].value, (('x_type', 'y_type'),))
self.assertTypes(fn_body[0].body[0].value.elts[0], 'x_type')
self.assertTypes(fn_body[0].body[0].value.elts[1], 'y_type')
@ -656,7 +656,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
def res_value(self, ns, value):
return {int}
def res_subscript(self, ns, types_ns, node, value, slice_):
def res_slice(self, ns, types_ns, node, value, slice_):
test_self.assertSetEqual(value, {list})
test_self.assertSetEqual(slice_, {int})
return {str}
@ -683,7 +683,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
def res_value(self, ns, value):
return {int}
def res_subscript(self, ns, types_ns, node_or_slice, value, slice_):
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
test_self.assertIn(node_or_slice, (0, 1))
test_self.assertSetEqual(value, {list})
test_self.assertSetEqual(slice_, {int})
@ -699,7 +699,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[1].value, Tuple)
self.assertTypes(fn_body[1].value, ((float, str),))
self.assertTypes(fn_body[1].value.elts[0], float)
self.assertTypes(fn_body[1].value.elts[1], str)
@ -751,6 +751,196 @@ class TypeInferenceAnalyzerTest(test.TestCase):
self.assertTypes(fn_body[0].value.left, list)
self.assertTypes(fn_body[0].value.right, list)
def test_unop(self):
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
return {list}
def res_unop(self, ns, types_ns, node, opnd):
return {float}
def test_fn(a):
return -a
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[0].value, float)
self.assertTypes(fn_body[0].value.operand, list)
def test_tuple_literal(self):
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
return {int}
def test_fn(a, b):
return a, b
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[0].value, ((int, int),))
self.assertTypes(fn_body[0].value.elts[0], int)
self.assertTypes(fn_body[0].value.elts[1], int)
def test_list_literal(self):
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
return {int}
def res_list_literal(self, ns, elt_types):
all_types = set()
for s in elt_types:
all_types |= s
return {List[t] for t in all_types}
def test_fn(a, b):
return [a, b]
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[0].value, List[int])
self.assertTypes(fn_body[0].value.elts[0], int)
self.assertTypes(fn_body[0].value.elts[1], int)
def test_tuple_unpacking_syntactic(self):
test_self = self
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
if name == qual_names.QN('a'):
return {int}
else:
return {float}
def res_value(self, ns, value):
test_self.assertIn(value, (0, 1))
return int
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
test_self.assertIn(node_or_slice, (0, 1))
test_self.assertSetEqual(value, {(int, float)})
test_self.assertEqual(slice_, int)
return {t[node_or_slice] for t in value}
def test_fn(a, b):
c, d = a, b
return c, d
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[1].value, ((int, float),))
self.assertTypes(fn_body[1].value.elts[0], int)
self.assertTypes(fn_body[1].value.elts[1], float)
def test_tuple_unpacking_operational(self):
test_self = self
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
return {(int, float)}
def res_value(self, ns, value):
test_self.assertIn(value, (0, 1))
return int
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
test_self.assertIn(node_or_slice, (0, 1))
test_self.assertSetEqual(value, {(int, float)})
test_self.assertEqual(slice_, int)
return {t[node_or_slice] for t in value}
def test_fn(a):
c, d = a
return c, d
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[1].value, ((int, float),))
self.assertTypes(fn_body[1].value.elts[0], int)
self.assertTypes(fn_body[1].value.elts[1], float)
def test_list_expansion_syntactic(self):
test_self = self
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
if name == qual_names.QN('a'):
return {int}
else:
return {float}
def res_value(self, ns, value):
test_self.assertIn(value, (0, 1))
return int
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
test_self.assertIn(node_or_slice, (0, 1))
test_self.assertSetEqual(value, {(int, float)})
test_self.assertEqual(slice_, int)
return {t[node_or_slice] for t in value}
def test_fn(a, b):
[c, d] = a, b
return c, d
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
# TODO(mdan): Whether it's List or Tuple might be open for interpretation.
self.assertTypes(fn_body[1].value, ((int, float),))
self.assertTypes(fn_body[1].value.elts[0], int)
self.assertTypes(fn_body[1].value.elts[1], float)
def test_list_expansion_operational(self):
test_self = self
class Resolver(type_inference.Resolver):
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
if name == qual_names.QN('a'):
return {int}
else:
return {float}
def res_value(self, ns, value):
test_self.assertIn(value, (0, 1))
return int
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
test_self.assertIn(node_or_slice, (0, 1))
test_self.assertSetEqual(value, {(int, float)})
test_self.assertEqual(slice_, int)
return {t[node_or_slice] for t in value}
def test_fn(a, b):
[c, d] = a, b
return c, d
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
# TODO(mdan): Whether it's List or Tuple might be open for interpretation.
self.assertTypes(fn_body[1].value, ((int, float),))
self.assertTypes(fn_body[1].value.elts[0], int)
self.assertTypes(fn_body[1].value.elts[1], float)
if __name__ == '__main__':
test.main()
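The unpacking tests above exercise _apply_unpacking, which types each Store target by slicing the rvalue's tuple types at the target's position (res_value types the index constant, res_slice performs the element lookup). A standalone sketch of the net effect, under the simplifying assumption that the rvalue's candidate types are plain Python tuples (unpack_types is a hypothetical helper, not part of the module):

def unpack_types(rvalue_types, num_targets):
  # rvalue_types is a set of candidate tuple types, e.g. {(int, float)}.
  # Target i receives the set of i-th element types across all candidates.
  return [
      {t[i] for t in rvalue_types if isinstance(t, tuple)}
      for i in range(num_targets)
  ]


# For `c, d = a` with a typed as {(int, float)}: c -> {int}, d -> {float}.
print(unpack_types({(int, float)}, 2))  # prints [{<class 'int'>}, {<class 'float'>}]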