Add support for list and tuple expansions, inferring types of list literals.
PiperOrigin-RevId: 346778490 Change-Id: Idbbf9079090e22df07c30082610344e885a69545
This commit is contained in:
parent
82738f8e3c
commit
2a9070791f
@ -23,10 +23,10 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import enum
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import types
|
||||
from typing import List, Tuple
|
||||
import gast as ast
|
||||
|
||||
from tensorflow.compiler.mlir.tfr import tfr_wrapper as tfr
|
||||
@ -46,7 +46,8 @@ from tensorflow.python.autograph.pyct.static_analysis import type_inference
|
||||
from tensorflow.python.framework import load_library
|
||||
from tensorflow.python.framework import op_def_registry
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
# TODO(mdan): Use class definitions so that we can mix these with Python types.
|
||||
|
||||
|
||||
class TFRTypes(enum.Enum):
|
||||
@ -410,7 +411,7 @@ class TFRTypeResolver(type_inference.Resolver):
|
||||
|
||||
iterated_type = args[0]
|
||||
assert iterated_type & {
|
||||
TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, List[int]
|
||||
TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, TFRTypes.ATTR
|
||||
}, (
|
||||
iterated_type)
|
||||
self._for_loop_target_types[body_fn_name] = iterated_type
|
||||
@ -443,7 +444,7 @@ class TFRTypeResolver(type_inference.Resolver):
|
||||
elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
|
||||
assert name.is_simple()
|
||||
if name == QN('range'):
|
||||
return {List[int]}, None
|
||||
return {TFRTypes.ATTR}, None
|
||||
|
||||
if name == QN('len'):
|
||||
return {TFRTypes.INDEX}, None
|
||||
@ -459,7 +460,7 @@ class TFRTypeResolver(type_inference.Resolver):
|
||||
if f_name_str in self._for_loop_target_types:
|
||||
# See autograph/converters/control_flow.py - the function has a single
|
||||
# argument, the iterate before any expansion.
|
||||
assert self._for_loop_target_types[f_name_str] & {List[int]}
|
||||
assert self._for_loop_target_types[f_name_str] & {TFRTypes.ATTR}
|
||||
# Assume all loops are TF loops. Then the iterates are autoboxed into
|
||||
# Tensors.
|
||||
return {TFRTypes.INDEX}
|
||||
@ -471,7 +472,7 @@ class TFRTypeResolver(type_inference.Resolver):
|
||||
op_def, derived_attrs = self._op_defs.lookup(f_name, func)
|
||||
if op_def is None:
|
||||
return None
|
||||
pos = tf_inspect.getfullargspec(func).args.index(str(name))
|
||||
pos = inspect.getfullargspec(func).args.index(str(name))
|
||||
|
||||
if pos < len(op_def.input_arg):
|
||||
arg_def = op_def.input_arg[pos]
|
||||
@ -488,7 +489,7 @@ class TFRTypeResolver(type_inference.Resolver):
|
||||
|
||||
raise ValueError('Argument is not defined in OpDef: ' + str(name))
|
||||
|
||||
def res_subscript(self, ns, types_ns, node_or_slice, value, slice_):
|
||||
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
|
||||
assert len(value) == 1
|
||||
value, = tuple(value)
|
||||
if value == TFRTypes.TF_TENSOR_SHAPE_LIST:
|
||||
@ -503,10 +504,40 @@ class TFRTypeResolver(type_inference.Resolver):
|
||||
# TODO(fengliuai): make sure left and right are compatible
|
||||
return {TFRTypes.I1}
|
||||
|
||||
def res_unop(self, ns, types_ns, node, opnd):
|
||||
return opnd
|
||||
|
||||
def res_binop(self, ns, types_ns, node, left, right):
|
||||
# TODO(fengliuai): make sure left and right are compatible
|
||||
return left
|
||||
|
||||
def _coerce_to_more_specific_type(self, elt_types):
|
||||
# TODO(mdan): This needs some type theory study.
|
||||
if TFRTypes.INDEX in elt_types:
|
||||
# Constants collapse to indices.
|
||||
elt_types.discard(TFRTypes.I64)
|
||||
if TFRTypes.TENSOR in elt_types:
|
||||
# Constants collapse to tensors.
|
||||
elt_types.discard(TFRTypes.I64)
|
||||
# Indices collapse to tensors.
|
||||
elt_types.discard(TFRTypes.INDEX)
|
||||
return elt_types
|
||||
|
||||
def res_list_literal(self, ns, elt_types):
|
||||
all_elt_types = set()
|
||||
for t in elt_types:
|
||||
all_elt_types |= t
|
||||
|
||||
if len(all_elt_types) != 1:
|
||||
all_elt_types = self._coerce_to_more_specific_type(all_elt_types)
|
||||
|
||||
if len(all_elt_types) != 1:
|
||||
raise ValueError('ambiguous list element types: {}'.format(elt_types))
|
||||
|
||||
if TFRTypes.TENSOR in all_elt_types:
|
||||
return {TFRTypes.TENSOR_LIST}
|
||||
return {TFRTypes.ATTR}
|
||||
|
||||
|
||||
class SymbolTable(object):
|
||||
"""Symbol Table for python code."""
|
||||
@ -599,22 +630,6 @@ class TFRGen(transformer.CodeGenerator):
|
||||
node, types_))
|
||||
|
||||
type_, = types_
|
||||
# TODO(fengliuai): Tuple is added here to make return tuple work.
|
||||
if type_ is list or type_ is Tuple:
|
||||
# TODO(fengliuai): Seems like we need to move the followed list handling
|
||||
# to the type inference and we shouldn't just put 'list' there. Otherwise
|
||||
# we couldn't find out the right type for the Name node.
|
||||
if not isinstance(node, ast.List):
|
||||
return default
|
||||
all_types = [
|
||||
anno.getanno(elt, anno.Static.TYPES, None) for elt in node.elts
|
||||
]
|
||||
if (TFRTypes.TENSOR,) in all_types:
|
||||
# For the elt which is not tfr.tensor, tfr.constant_tensor needs to be
|
||||
# use to cast it to a tfr.tensor.
|
||||
return TFRTypes.TENSOR_LIST
|
||||
else:
|
||||
return TFRTypes.ATTR
|
||||
|
||||
if default is not None and type_ != default:
|
||||
print('WARN: type annotation {}({}) does not match {}({})'.format(
|
||||
@ -704,7 +719,6 @@ class TFRGen(transformer.CodeGenerator):
|
||||
if isinstance(node.value, ast.Attribute):
|
||||
if isinstance(node.value.value, ast.Name):
|
||||
if node.value.value.id == 'tf' and node.value.attr == 'raw_ops':
|
||||
# This branch is used when it is outside tensorflow
|
||||
return (node.attr, TFRTypes.TF_RAW_OP)
|
||||
|
||||
value, ty = self.visit(node.value)
|
||||
@ -726,13 +740,24 @@ class TFRGen(transformer.CodeGenerator):
|
||||
raise NotImplementedError('Assignment target type not recognized.')
|
||||
|
||||
if isinstance(values, list):
|
||||
if isinstance(node.value, ast.Call):
|
||||
expected = tuple(t for n, t in values)
|
||||
if len(values) == 1:
|
||||
expected = expected[0]
|
||||
elif isinstance(node.value, ast.Tuple):
|
||||
expected = tuple(t for n, t in values)
|
||||
else:
|
||||
raise ValueError('unknown assignment target node', node.value)
|
||||
ty = self._get_inferred_type(node.value, expected)
|
||||
|
||||
if len(targets) == len(values):
|
||||
for key, value in zip(targets, values):
|
||||
ssa_value, ty_ = value
|
||||
ty = self._get_inferred_type(node.value, ty_)
|
||||
self.symbol_table.insert_symbol(key, ssa_value, ty)
|
||||
# TODO(mdan): This should already be a tuple.
|
||||
ty_ = (ty,) if len(values) == 1 else ty
|
||||
for key, value, t in zip(targets, values, ty_):
|
||||
ssa_value, _ = value
|
||||
self.symbol_table.insert_symbol(key, ssa_value, t)
|
||||
elif len(values) == 1:
|
||||
n, ty = values[0]
|
||||
n, _ = values[0]
|
||||
assert ty == TFRTypes.TENSOR_LIST
|
||||
# assign a tensor_list to multiple variables
|
||||
for idx, key in enumerate(targets):
|
||||
@ -747,10 +772,11 @@ class TFRGen(transformer.CodeGenerator):
|
||||
self.symbol_table.insert_symbol(key, elt_name, TFRTypes.TENSOR)
|
||||
elif len(targets) == 1:
|
||||
ssa_names = [n for n, _ in values]
|
||||
tys = [t for _, t in values]
|
||||
self.symbol_table.insert_symbol(targets[0], ssa_names, tys)
|
||||
else:
|
||||
self.symbol_table.insert_symbol(targets[0], values[0], values[1])
|
||||
self.symbol_table.insert_symbol(targets[0], ssa_names, ty)
|
||||
return
|
||||
|
||||
ty = self._get_inferred_type(node.value, values[1])
|
||||
self.symbol_table.insert_symbol(targets[0], values[0], ty)
|
||||
|
||||
def _emit_binary_op(self, op, lhs, lhs_ty, rhs, rhs_ty):
|
||||
assert lhs_ty, rhs_ty
|
||||
@ -795,7 +821,7 @@ class TFRGen(transformer.CodeGenerator):
|
||||
|
||||
def visit_Call(self, node):
|
||||
func_name, func_type = self.visit(node.func)
|
||||
_ = self._get_inferred_type(node.func, func_type)
|
||||
func_type = self._get_inferred_type(node.func, func_type)
|
||||
if func_type == TFRTypes.AG_BUILTIN_FUNC:
|
||||
if func_name == 'if_stmt':
|
||||
cond, _ = self.visit(node.args[0])
|
||||
@ -1285,6 +1311,7 @@ class TFRGen(transformer.CodeGenerator):
|
||||
tys = []
|
||||
for elt in node.elts:
|
||||
val, ty = self.visit(elt)
|
||||
ty = self._get_inferred_type(elt, ty)
|
||||
if ty in _attribute_types and out_type == TFRTypes.TENSOR_LIST:
|
||||
# This list is a tensor list, then cast all the input values to tensors.
|
||||
val, ty = self._value_to_tensor(val, ty, node)
|
||||
@ -1405,7 +1432,7 @@ def tfr_gen_from_module(source, method_prefix=None, op_libraries=None):
|
||||
|
||||
py_funcs = [
|
||||
func
|
||||
for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction)
|
||||
for name, func in inspect.getmembers(source, inspect.isfunction)
|
||||
if not method_prefix or name.startswith(method_prefix)
|
||||
]
|
||||
# Sort the methods by the line number, to make sure the definitions are
|
||||
|
@ -31,7 +31,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import Any, Callable, Tuple
|
||||
import itertools
|
||||
|
||||
from typing import Any, Callable, Dict, Set
|
||||
|
||||
import gast
|
||||
|
||||
@ -39,6 +41,7 @@ from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||||
|
||||
|
||||
@ -118,12 +121,20 @@ class Resolver(object):
|
||||
"""Resolves the return type of a unary operation."""
|
||||
raise NotImplementedError('subclasses must implement')
|
||||
|
||||
def res_binop(self, ns, types_ns, node, left, right):
|
||||
def res_unop(self, ns, types_ns, node, opnd):
|
||||
"""Resolves the return type of a unary operation."""
|
||||
raise NotImplementedError('subclasses must implement')
|
||||
|
||||
def res_binop(self, ns, types_ns, node, left, right):
|
||||
"""Resolves the return type of a binary operation."""
|
||||
raise NotImplementedError('subclasses must implement')
|
||||
|
||||
class _SymbolTable(object):
|
||||
def res_list_literal(self, ns, elt_types):
|
||||
"""Resolves the type of a list literal from its elements."""
|
||||
raise NotImplementedError('subclasses must implement')
|
||||
|
||||
|
||||
class _TypeMap(object):
|
||||
"""Abstraction for the state of the CFG walk for type inference.
|
||||
|
||||
This is a value type. Only implements the strictly necessary operators.
|
||||
@ -135,7 +146,7 @@ class _SymbolTable(object):
|
||||
|
||||
def __init__(self, init_from=None):
|
||||
if init_from:
|
||||
assert isinstance(init_from, _SymbolTable)
|
||||
assert isinstance(init_from, _TypeMap)
|
||||
self.types = {
|
||||
s: set(other_types) for s, other_types in init_from.types.items()
|
||||
}
|
||||
@ -152,8 +163,8 @@ class _SymbolTable(object):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __or__(self, other):
|
||||
assert isinstance(other, _SymbolTable)
|
||||
result = _SymbolTable(self)
|
||||
assert isinstance(other, _TypeMap)
|
||||
result = _TypeMap(self)
|
||||
for s, other_types in other.types.items():
|
||||
if s not in result.types:
|
||||
self_types = set()
|
||||
@ -192,13 +203,22 @@ class StmtInferrer(gast.NodeVisitor):
|
||||
print(a) # a = int; side effect of f() accounted for
|
||||
"""
|
||||
|
||||
def __init__(self, resolver, scope, namespace, closure_types, types_in):
|
||||
def __init__(self,
|
||||
resolver: Resolver,
|
||||
scope: activity.Scope,
|
||||
namespace: Dict[qual_names.QN, Any],
|
||||
closure_types: Dict[qual_names.QN, Set[Any]],
|
||||
types_in: _TypeMap):
|
||||
self.resolver = resolver
|
||||
self.scope = scope
|
||||
self.namespace = namespace
|
||||
self.closure_types = closure_types
|
||||
self.types_in = types_in
|
||||
self.new_symbols = {}
|
||||
|
||||
# rvalue type. This property is set when encountering an assign operation,
|
||||
# so that visiting nodes with Store ctx (typically found on left side of
|
||||
# assignments) can infer the type they should receive.
|
||||
self.rtype = None
|
||||
|
||||
def visit(self, node):
|
||||
@ -221,36 +241,36 @@ class StmtInferrer(gast.NodeVisitor):
|
||||
self._check_set(types)
|
||||
return types
|
||||
|
||||
def visit_Tuple(self, node):
|
||||
if isinstance(node.ctx, gast.Load):
|
||||
for elt in node.elts:
|
||||
self.visit(elt)
|
||||
# TODO(mdan): Parameterize it.
|
||||
return {Tuple}
|
||||
|
||||
def _apply_unpacking(self, node):
|
||||
assert isinstance(node.ctx, gast.Store)
|
||||
|
||||
if self.rtype is not None:
|
||||
original_stype = self.rtype
|
||||
# TODO(mdan): Find a better way to express unpacking.
|
||||
i_type = self.resolver.res_value(self.namespace, 0)
|
||||
for i, elt in enumerate(node.elts):
|
||||
self.rtype = self.resolver.res_subscript(
|
||||
self.rtype = self.resolver.res_slice(
|
||||
self.namespace, self.types_in.types, i, original_stype, i_type)
|
||||
self.visit(elt)
|
||||
self.rtype = original_stype
|
||||
return original_stype
|
||||
|
||||
return None
|
||||
|
||||
def visit_Tuple(self, node):
|
||||
if isinstance(node.ctx, gast.Load):
|
||||
elt_types = ()
|
||||
for elt in node.elts:
|
||||
types_ = self.visit(elt)
|
||||
if types_ is None:
|
||||
return None
|
||||
elt_types += (types_,)
|
||||
return set(itertools.product(*elt_types))
|
||||
return self._apply_unpacking(node)
|
||||
|
||||
def visit_List(self, node):
|
||||
if isinstance(node.ctx, gast.Load):
|
||||
el_types = []
|
||||
for elt in node.elts:
|
||||
el_types.append(self.visit(elt))
|
||||
return {list}
|
||||
|
||||
raise NotImplementedError('list unpacking')
|
||||
elt_types = tuple(self.visit(elt) for elt in node.elts)
|
||||
return self.resolver.res_list_literal(self.namespace, elt_types)
|
||||
return self._apply_unpacking(node)
|
||||
|
||||
def visit_Set(self, node):
|
||||
raise NotImplementedError()
|
||||
@ -442,7 +462,7 @@ class StmtInferrer(gast.NodeVisitor):
|
||||
if val_types is None or slice_types is None:
|
||||
return None
|
||||
|
||||
types = self.resolver.res_subscript(
|
||||
types = self.resolver.res_slice(
|
||||
self.namespace, self.types_in.types, node, val_types, slice_types)
|
||||
|
||||
if __debug__:
|
||||
@ -480,6 +500,20 @@ class StmtInferrer(gast.NodeVisitor):
|
||||
|
||||
return types
|
||||
|
||||
def visit_UnaryOp(self, node):
|
||||
opnd_types = self.visit(node.operand)
|
||||
|
||||
if opnd_types is None:
|
||||
return None
|
||||
|
||||
types = self.resolver.res_unop(
|
||||
self.namespace, self.types_in.types, node, opnd_types)
|
||||
|
||||
if __debug__:
|
||||
self._check_set(types)
|
||||
|
||||
return types
|
||||
|
||||
|
||||
class Analyzer(cfg.GraphVisitor):
|
||||
"""CFG visitor that propagates type information across statements."""
|
||||
@ -504,13 +538,13 @@ class Analyzer(cfg.GraphVisitor):
|
||||
n: t for n, t in closure_types.items() if n not in scope.bound
|
||||
}
|
||||
if context_types:
|
||||
self.context_types = _SymbolTable()
|
||||
self.context_types = _TypeMap()
|
||||
self.context_types.types = context_types
|
||||
else:
|
||||
self.context_types = None
|
||||
|
||||
def init_state(self, _):
|
||||
return _SymbolTable()
|
||||
return _TypeMap()
|
||||
|
||||
def _update_closure_types(self, ast_node, types):
|
||||
existing_types = anno.Static.CLOSURE_TYPES.of(ast_node, None)
|
||||
@ -528,13 +562,13 @@ class Analyzer(cfg.GraphVisitor):
|
||||
def visit_node(self, node):
|
||||
prev_types_out = self.out[node]
|
||||
|
||||
types_in = _SymbolTable()
|
||||
types_in = _TypeMap()
|
||||
for n in node.prev:
|
||||
types_in |= self.out[n]
|
||||
if (self.context_types is not None) and (node is self.graph.entry):
|
||||
types_in |= self.context_types
|
||||
|
||||
types_out = _SymbolTable(types_in)
|
||||
types_out = _TypeMap(types_in)
|
||||
ast_node = node.ast_node
|
||||
|
||||
inferrer = StmtInferrer(self.resolver, self.scope, self.namespace,
|
||||
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import Any, Callable, Tuple
|
||||
from typing import Any, Callable, List
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
@ -171,7 +171,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
node, _ = tr.transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertTypes(fn_body[0].body[0].value, Tuple)
|
||||
self.assertTypes(fn_body[0].body[0].value, (('x_type', 'y_type'),))
|
||||
self.assertTypes(fn_body[0].body[0].value.elts[0], 'x_type')
|
||||
self.assertTypes(fn_body[0].body[0].value.elts[1], 'y_type')
|
||||
|
||||
@ -656,7 +656,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
def res_value(self, ns, value):
|
||||
return {int}
|
||||
|
||||
def res_subscript(self, ns, types_ns, node, value, slice_):
|
||||
def res_slice(self, ns, types_ns, node, value, slice_):
|
||||
test_self.assertSetEqual(value, {list})
|
||||
test_self.assertSetEqual(slice_, {int})
|
||||
return {str}
|
||||
@ -683,7 +683,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
def res_value(self, ns, value):
|
||||
return {int}
|
||||
|
||||
def res_subscript(self, ns, types_ns, node_or_slice, value, slice_):
|
||||
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
|
||||
test_self.assertIn(node_or_slice, (0, 1))
|
||||
test_self.assertSetEqual(value, {list})
|
||||
test_self.assertSetEqual(slice_, {int})
|
||||
@ -699,7 +699,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertTypes(fn_body[1].value, Tuple)
|
||||
self.assertTypes(fn_body[1].value, ((float, str),))
|
||||
self.assertTypes(fn_body[1].value.elts[0], float)
|
||||
self.assertTypes(fn_body[1].value.elts[1], str)
|
||||
|
||||
@ -751,6 +751,196 @@ class TypeInferenceAnalyzerTest(test.TestCase):
|
||||
self.assertTypes(fn_body[0].value.left, list)
|
||||
self.assertTypes(fn_body[0].value.right, list)
|
||||
|
||||
def test_unop(self):
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
return {list}
|
||||
|
||||
def res_unop(self, ns, types_ns, node, opnd):
|
||||
return {float}
|
||||
|
||||
def test_fn(a):
|
||||
return -a
|
||||
|
||||
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertTypes(fn_body[0].value, float)
|
||||
self.assertTypes(fn_body[0].value.operand, list)
|
||||
|
||||
def test_tuple_literal(self):
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
return {int}
|
||||
|
||||
def test_fn(a, b):
|
||||
return a, b
|
||||
|
||||
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertTypes(fn_body[0].value, ((int, int),))
|
||||
self.assertTypes(fn_body[0].value.elts[0], int)
|
||||
self.assertTypes(fn_body[0].value.elts[1], int)
|
||||
|
||||
def test_list_literal(self):
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
return {int}
|
||||
|
||||
def res_list_literal(self, ns, elt_types):
|
||||
all_types = set()
|
||||
for s in elt_types:
|
||||
all_types |= s
|
||||
return {List[t] for t in all_types}
|
||||
|
||||
def test_fn(a, b):
|
||||
return [a, b]
|
||||
|
||||
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertTypes(fn_body[0].value, List[int])
|
||||
self.assertTypes(fn_body[0].value.elts[0], int)
|
||||
self.assertTypes(fn_body[0].value.elts[1], int)
|
||||
|
||||
def test_tuple_unpacking_syntactic(self):
|
||||
|
||||
test_self = self
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
if name == qual_names.QN('a'):
|
||||
return {int}
|
||||
else:
|
||||
return {float}
|
||||
|
||||
def res_value(self, ns, value):
|
||||
test_self.assertIn(value, (0, 1))
|
||||
return int
|
||||
|
||||
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
|
||||
test_self.assertIn(node_or_slice, (0, 1))
|
||||
test_self.assertSetEqual(value, {(int, float)})
|
||||
test_self.assertEqual(slice_, int)
|
||||
return {t[node_or_slice] for t in value}
|
||||
|
||||
def test_fn(a, b):
|
||||
c, d = a, b
|
||||
return c, d
|
||||
|
||||
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertTypes(fn_body[1].value, ((int, float),))
|
||||
self.assertTypes(fn_body[1].value.elts[0], int)
|
||||
self.assertTypes(fn_body[1].value.elts[1], float)
|
||||
|
||||
def test_tuple_unpacking_operational(self):
|
||||
|
||||
test_self = self
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
return {(int, float)}
|
||||
|
||||
def res_value(self, ns, value):
|
||||
test_self.assertIn(value, (0, 1))
|
||||
return int
|
||||
|
||||
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
|
||||
test_self.assertIn(node_or_slice, (0, 1))
|
||||
test_self.assertSetEqual(value, {(int, float)})
|
||||
test_self.assertEqual(slice_, int)
|
||||
return {t[node_or_slice] for t in value}
|
||||
|
||||
def test_fn(a):
|
||||
c, d = a
|
||||
return c, d
|
||||
|
||||
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
self.assertTypes(fn_body[1].value, ((int, float),))
|
||||
self.assertTypes(fn_body[1].value.elts[0], int)
|
||||
self.assertTypes(fn_body[1].value.elts[1], float)
|
||||
|
||||
def test_list_expansion_syntactic(self):
|
||||
|
||||
test_self = self
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
if name == qual_names.QN('a'):
|
||||
return {int}
|
||||
else:
|
||||
return {float}
|
||||
|
||||
def res_value(self, ns, value):
|
||||
test_self.assertIn(value, (0, 1))
|
||||
return int
|
||||
|
||||
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
|
||||
test_self.assertIn(node_or_slice, (0, 1))
|
||||
test_self.assertSetEqual(value, {(int, float)})
|
||||
test_self.assertEqual(slice_, int)
|
||||
return {t[node_or_slice] for t in value}
|
||||
|
||||
def test_fn(a, b):
|
||||
[c, d] = a, b
|
||||
return c, d
|
||||
|
||||
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
# TODO(mdan): Whether it's List or Tuple might be open for interpretation.
|
||||
self.assertTypes(fn_body[1].value, ((int, float),))
|
||||
self.assertTypes(fn_body[1].value.elts[0], int)
|
||||
self.assertTypes(fn_body[1].value.elts[1], float)
|
||||
|
||||
def test_list_expansion_operational(self):
|
||||
|
||||
test_self = self
|
||||
|
||||
class Resolver(type_inference.Resolver):
|
||||
|
||||
def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
|
||||
if name == qual_names.QN('a'):
|
||||
return {int}
|
||||
else:
|
||||
return {float}
|
||||
|
||||
def res_value(self, ns, value):
|
||||
test_self.assertIn(value, (0, 1))
|
||||
return int
|
||||
|
||||
def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
|
||||
test_self.assertIn(node_or_slice, (0, 1))
|
||||
test_self.assertSetEqual(value, {(int, float)})
|
||||
test_self.assertEqual(slice_, int)
|
||||
return {t[node_or_slice] for t in value}
|
||||
|
||||
def test_fn(a, b):
|
||||
[c, d] = a, b
|
||||
return c, d
|
||||
|
||||
node, _ = TestTranspiler(Resolver).transform(test_fn, None)
|
||||
fn_body = node.body
|
||||
|
||||
# TODO(mdan): Whether it's List or Tuple might be open for interpretation.
|
||||
self.assertTypes(fn_body[1].value, ((int, float),))
|
||||
self.assertTypes(fn_body[1].value.elts[0], int)
|
||||
self.assertTypes(fn_body[1].value.elts[1], float)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user