diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_gen.py b/tensorflow/compiler/mlir/tfr/python/tfr_gen.py
index 9482785138c..110a4aa680d 100644
--- a/tensorflow/compiler/mlir/tfr/python/tfr_gen.py
+++ b/tensorflow/compiler/mlir/tfr/python/tfr_gen.py
@@ -23,10 +23,10 @@ from __future__ import division
 from __future__ import print_function
 
 import enum
+import inspect
 import os
 import re
 import types
-from typing import List, Tuple
 
 import gast as ast
 
 from tensorflow.compiler.mlir.tfr import tfr_wrapper as tfr
@@ -46,7 +46,8 @@ from tensorflow.python.autograph.pyct.static_analysis import type_inference
 from tensorflow.python.framework import load_library
 from tensorflow.python.framework import op_def_registry
 from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.util import tf_inspect
+
+# TODO(mdan): Use class definitions so that we can mix these with Python types.
 class TFRTypes(enum.Enum):
@@ -410,7 +411,7 @@ class TFRTypeResolver(type_inference.Resolver):
         iterated_type = args[0]
         assert iterated_type & {
-            TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, List[int]
+            TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, TFRTypes.ATTR
         }, (
             iterated_type)
         self._for_loop_target_types[body_fn_name] = iterated_type
@@ -443,7 +444,7 @@ class TFRTypeResolver(type_inference.Resolver):
     elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
       assert name.is_simple()
       if name == QN('range'):
-        return {List[int]}, None
+        return {TFRTypes.ATTR}, None
 
       if name == QN('len'):
         return {TFRTypes.INDEX}, None
@@ -459,7 +460,7 @@ class TFRTypeResolver(type_inference.Resolver):
     if f_name_str in self._for_loop_target_types:
       # See autograph/converters/control_flow.py - the function has a single
       # argument, the iterate before any expansion.
-      assert self._for_loop_target_types[f_name_str] & {List[int]}
+      assert self._for_loop_target_types[f_name_str] & {TFRTypes.ATTR}
       # Assume all loops are TF loops. Then the iterates are autoboxed into
      # Tensors.
       return {TFRTypes.INDEX}
@@ -471,7 +472,7 @@ class TFRTypeResolver(type_inference.Resolver):
     op_def, derived_attrs = self._op_defs.lookup(f_name, func)
     if op_def is None:
       return None
-    pos = tf_inspect.getfullargspec(func).args.index(str(name))
+    pos = inspect.getfullargspec(func).args.index(str(name))
 
     if pos < len(op_def.input_arg):
       arg_def = op_def.input_arg[pos]
@@ -488,7 +489,7 @@ class TFRTypeResolver(type_inference.Resolver):
     raise ValueError('Argument is not defined in OpDef: ' + str(name))
 
-  def res_subscript(self, ns, types_ns, node_or_slice, value, slice_):
+  def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
     assert len(value) == 1
     value, = tuple(value)
     if value == TFRTypes.TF_TENSOR_SHAPE_LIST:
@@ -503,10 +504,40 @@ class TFRTypeResolver(type_inference.Resolver):
     # TODO(fengliuai): make sure left and right are compatible
     return {TFRTypes.I1}
 
+  def res_unop(self, ns, types_ns, node, opnd):
+    return opnd
+
   def res_binop(self, ns, types_ns, node, left, right):
     # TODO(fengliuai): make sure left and right are compatible
     return left
 
+  def _coerce_to_more_specific_type(self, elt_types):
+    # TODO(mdan): This needs some type theory study.
+    if TFRTypes.INDEX in elt_types:
+      # Constants collapse to indices.
+      elt_types.discard(TFRTypes.I64)
+    if TFRTypes.TENSOR in elt_types:
+      # Constants collapse to tensors.
+      elt_types.discard(TFRTypes.I64)
+      # Indices collapse to tensors.
+      elt_types.discard(TFRTypes.INDEX)
+    return elt_types
+
+  def res_list_literal(self, ns, elt_types):
+    all_elt_types = set()
+    for t in elt_types:
+      all_elt_types |= t
+
+    if len(all_elt_types) != 1:
+      all_elt_types = self._coerce_to_more_specific_type(all_elt_types)
+
+    if len(all_elt_types) != 1:
+      raise ValueError('ambiguous list element types: {}'.format(elt_types))
+
+    if TFRTypes.TENSOR in all_elt_types:
+      return {TFRTypes.TENSOR_LIST}
+    return {TFRTypes.ATTR}
+
 
 class SymbolTable(object):
   """Symbol Table for python code."""
@@ -599,22 +630,6 @@ class TFRGen(transformer.CodeGenerator):
           node, types_))
     type_, = types_
 
-    # TODO(fengliuai): Tuple is added here to make return tuple work.
-    if type_ is list or type_ is Tuple:
-      # TODO(fengliuai): Seems like we need to move the followed list handling
-      # to the type inference and we shouldn't just put 'list' there. Otherwise
-      # we couldn't find out the right type for the Name node.
-      if not isinstance(node, ast.List):
-        return default
-      all_types = [
-          anno.getanno(elt, anno.Static.TYPES, None) for elt in node.elts
-      ]
-      if (TFRTypes.TENSOR,) in all_types:
-        # For the elt which is not tfr.tensor, tfr.constant_tensor needs to be
-        # use to cast it to a tfr.tensor.
-        return TFRTypes.TENSOR_LIST
-      else:
-        return TFRTypes.ATTR
 
     if default is not None and type_ != default:
       print('WARN: type annotation {}({}) does not match {}({})'.format(
@@ -704,7 +719,6 @@ class TFRGen(transformer.CodeGenerator):
     if isinstance(node.value, ast.Attribute):
       if isinstance(node.value.value, ast.Name):
         if node.value.value.id == 'tf' and node.value.attr == 'raw_ops':
-          # This branch is used when it is outside tensorflow
           return (node.attr, TFRTypes.TF_RAW_OP)
 
     value, ty = self.visit(node.value)
@@ -726,13 +740,24 @@ class TFRGen(transformer.CodeGenerator):
       raise NotImplementedError('Assignment target type not recognized.')
 
     if isinstance(values, list):
+      if isinstance(node.value, ast.Call):
+        expected = tuple(t for n, t in values)
+        if len(values) == 1:
+          expected = expected[0]
+      elif isinstance(node.value, ast.Tuple):
+        expected = tuple(t for n, t in values)
+      else:
+        raise ValueError('unknown assignment target node', node.value)
+      ty = self._get_inferred_type(node.value, expected)
+
       if len(targets) == len(values):
-        for key, value in zip(targets, values):
-          ssa_value, ty_ = value
-          ty = self._get_inferred_type(node.value, ty_)
-          self.symbol_table.insert_symbol(key, ssa_value, ty)
+        # TODO(mdan): This should already be a tuple.
+        ty_ = (ty,) if len(values) == 1 else ty
+        for key, value, t in zip(targets, values, ty_):
+          ssa_value, _ = value
+          self.symbol_table.insert_symbol(key, ssa_value, t)
       elif len(values) == 1:
-        n, ty = values[0]
+        n, _ = values[0]
         assert ty == TFRTypes.TENSOR_LIST
         # assign a tensor_list to multiple variables
         for idx, key in enumerate(targets):
@@ -747,10 +772,11 @@ class TFRGen(transformer.CodeGenerator):
           self.symbol_table.insert_symbol(key, elt_name, TFRTypes.TENSOR)
       elif len(targets) == 1:
         ssa_names = [n for n, _ in values]
-        tys = [t for _, t in values]
-        self.symbol_table.insert_symbol(targets[0], ssa_names, tys)
-    else:
-      self.symbol_table.insert_symbol(targets[0], values[0], values[1])
+        self.symbol_table.insert_symbol(targets[0], ssa_names, ty)
+      return
+
+    ty = self._get_inferred_type(node.value, values[1])
+    self.symbol_table.insert_symbol(targets[0], values[0], ty)
 
   def _emit_binary_op(self, op, lhs, lhs_ty, rhs, rhs_ty):
     assert lhs_ty, rhs_ty
@@ -795,7 +821,7 @@ class TFRGen(transformer.CodeGenerator):
 
   def visit_Call(self, node):
     func_name, func_type = self.visit(node.func)
-    _ = self._get_inferred_type(node.func, func_type)
+    func_type = self._get_inferred_type(node.func, func_type)
     if func_type == TFRTypes.AG_BUILTIN_FUNC:
       if func_name == 'if_stmt':
         cond, _ = self.visit(node.args[0])
@@ -1285,6 +1311,7 @@ class TFRGen(transformer.CodeGenerator):
     tys = []
     for elt in node.elts:
       val, ty = self.visit(elt)
+      ty = self._get_inferred_type(elt, ty)
       if ty in _attribute_types and out_type == TFRTypes.TENSOR_LIST:
         # This list is a tensor list, then cast all the input values to tensors.
         val, ty = self._value_to_tensor(val, ty, node)
@@ -1405,7 +1432,7 @@ def tfr_gen_from_module(source, method_prefix=None, op_libraries=None):
 
   py_funcs = [
       func
-      for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction)
+      for name, func in inspect.getmembers(source, inspect.isfunction)
      if not method_prefix or name.startswith(method_prefix)
   ]
   # Sort the methods by the line number, to make sure the definitions are
diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_inference.py b/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
index b35b1d2c9d8..639e0dd19a2 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
@@ -31,7 +31,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from typing import Any, Callable, Tuple
+import itertools
+
+from typing import Any, Callable, Dict, Set
 
 import gast
 
@@ -39,6 +41,7 @@ from tensorflow.python.autograph.pyct import anno
 from tensorflow.python.autograph.pyct import cfg
 from tensorflow.python.autograph.pyct import qual_names
 from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
 from tensorflow.python.autograph.pyct.static_analysis import annos
 
 
@@ -118,12 +121,20 @@ class Resolver(object):
     """Resolves the return type of a unary operation."""
     raise NotImplementedError('subclasses must implement')
 
-  def res_binop(self, ns, types_ns, node, left, right):
+  def res_unop(self, ns, types_ns, node, opnd):
     """Resolves the return type of a unary operation."""
     raise NotImplementedError('subclasses must implement')
 
+  def res_binop(self, ns, types_ns, node, left, right):
+    """Resolves the return type of a binary operation."""
+    raise NotImplementedError('subclasses must implement')
 
-class _SymbolTable(object):
+  def res_list_literal(self, ns, elt_types):
+    """Resolves the type of a list literal from its elements."""
+    raise NotImplementedError('subclasses must implement')
+
+
+class _TypeMap(object):
   """Abstraction for the state of the CFG walk for type inference.
 
   This is a value type. Only implements the strictly necessary operators.
@@ -135,7 +146,7 @@ class _SymbolTable(object):
 
   def __init__(self, init_from=None):
     if init_from:
-      assert isinstance(init_from, _SymbolTable)
+      assert isinstance(init_from, _TypeMap)
       self.types = {
           s: set(other_types) for s, other_types in init_from.types.items()
       }
@@ -152,8 +163,8 @@ class _SymbolTable(object):
     return not self.__eq__(other)
 
   def __or__(self, other):
-    assert isinstance(other, _SymbolTable)
-    result = _SymbolTable(self)
+    assert isinstance(other, _TypeMap)
+    result = _TypeMap(self)
     for s, other_types in other.types.items():
       if s not in result.types:
         self_types = set()
@@ -192,13 +203,22 @@ class StmtInferrer(gast.NodeVisitor):
       print(a)  # a = int; side effect of f() accounted for
   """
 
-  def __init__(self, resolver, scope, namespace, closure_types, types_in):
+  def __init__(self,
+               resolver: Resolver,
+               scope: activity.Scope,
+               namespace: Dict[qual_names.QN, Any],
+               closure_types: Dict[qual_names.QN, Set[Any]],
+               types_in: _TypeMap):
     self.resolver = resolver
     self.scope = scope
     self.namespace = namespace
     self.closure_types = closure_types
     self.types_in = types_in
     self.new_symbols = {}
+
+    # rvalue type. This property is set when encountering an assign operation,
+    # so that visiting nodes with Store ctx (typically found on left side of
+    # assignments) can infer the type they should receive.
     self.rtype = None
 
   def visit(self, node):
@@ -221,36 +241,36 @@ class StmtInferrer(gast.NodeVisitor):
       self._check_set(types)
     return types
 
-  def visit_Tuple(self, node):
-    if isinstance(node.ctx, gast.Load):
-      for elt in node.elts:
-        self.visit(elt)
-      # TODO(mdan): Parameterize it.
-      return {Tuple}
-
+  def _apply_unpacking(self, node):
     assert isinstance(node.ctx, gast.Store)
-
     if self.rtype is not None:
       original_stype = self.rtype
       # TODO(mdan): Find a better way to express unpacking.
       i_type = self.resolver.res_value(self.namespace, 0)
       for i, elt in enumerate(node.elts):
-        self.rtype = self.resolver.res_subscript(
+        self.rtype = self.resolver.res_slice(
             self.namespace, self.types_in.types, i, original_stype, i_type)
         self.visit(elt)
       self.rtype = original_stype
       return original_stype
-
     return None
 
+  def visit_Tuple(self, node):
+    if isinstance(node.ctx, gast.Load):
+      elt_types = ()
+      for elt in node.elts:
+        types_ = self.visit(elt)
+        if types_ is None:
+          return None
+        elt_types += (types_,)
+      return set(itertools.product(*elt_types))
+    return self._apply_unpacking(node)
+
   def visit_List(self, node):
     if isinstance(node.ctx, gast.Load):
-      el_types = []
-      for elt in node.elts:
-        el_types.append(self.visit(elt))
-      return {list}
-
-    raise NotImplementedError('list unpacking')
+      elt_types = tuple(self.visit(elt) for elt in node.elts)
+      return self.resolver.res_list_literal(self.namespace, elt_types)
+    return self._apply_unpacking(node)
 
   def visit_Set(self, node):
     raise NotImplementedError()
@@ -442,7 +462,7 @@ class StmtInferrer(gast.NodeVisitor):
     if val_types is None or slice_types is None:
       return None
 
-    types = self.resolver.res_subscript(
+    types = self.resolver.res_slice(
         self.namespace, self.types_in.types, node, val_types, slice_types)
 
     if __debug__:
@@ -480,6 +500,20 @@ class StmtInferrer(gast.NodeVisitor):
 
     return types
 
+  def visit_UnaryOp(self, node):
+    opnd_types = self.visit(node.operand)
+
+    if opnd_types is None:
+      return None
+
+    types = self.resolver.res_unop(
+        self.namespace, self.types_in.types, node, opnd_types)
+
+    if __debug__:
+      self._check_set(types)
+
+    return types
+
 
 class Analyzer(cfg.GraphVisitor):
   """CFG visitor that propagates type information across statements."""
@@ -504,13 +538,13 @@ class Analyzer(cfg.GraphVisitor):
         n: t for n, t in closure_types.items() if n not in scope.bound
     }
     if context_types:
-      self.context_types = _SymbolTable()
+      self.context_types = _TypeMap()
       self.context_types.types = context_types
     else:
       self.context_types = None
 
   def init_state(self, _):
-    return _SymbolTable()
+    return _TypeMap()
 
   def _update_closure_types(self, ast_node, types):
     existing_types = anno.Static.CLOSURE_TYPES.of(ast_node, None)
@@ -528,13 +562,13 @@ class Analyzer(cfg.GraphVisitor):
   def visit_node(self, node):
     prev_types_out = self.out[node]
 
-    types_in = _SymbolTable()
+    types_in = _TypeMap()
     for n in node.prev:
       types_in |= self.out[n]
     if (self.context_types is not None) and (node is self.graph.entry):
       types_in |= self.context_types
 
-    types_out = _SymbolTable(types_in)
+    types_out = _TypeMap(types_in)
 
     ast_node = node.ast_node
     inferrer = StmtInferrer(self.resolver, self.scope, self.namespace,
diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_inference_test.py b/tensorflow/python/autograph/pyct/static_analysis/type_inference_test.py
index 5648f8dcb62..861e62b8f35 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/type_inference_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_inference_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from typing import Any, Callable, Tuple
+from typing import Any, Callable, List
 
 from tensorflow.python.autograph.pyct import anno
 from tensorflow.python.autograph.pyct import cfg
@@ -171,7 +171,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
     node, _ = tr.transform(test_fn, None)
     fn_body = node.body
 
-    self.assertTypes(fn_body[0].body[0].value, Tuple)
+    self.assertTypes(fn_body[0].body[0].value, (('x_type', 'y_type'),))
     self.assertTypes(fn_body[0].body[0].value.elts[0], 'x_type')
     self.assertTypes(fn_body[0].body[0].value.elts[1], 'y_type')
@@ -656,7 +656,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
       def res_value(self, ns, value):
         return {int}
 
-      def res_subscript(self, ns, types_ns, node, value, slice_):
+      def res_slice(self, ns, types_ns, node, value, slice_):
         test_self.assertSetEqual(value, {list})
         test_self.assertSetEqual(slice_, {int})
         return {str}
@@ -683,7 +683,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
       def res_value(self, ns, value):
         return {int}
 
-      def res_subscript(self, ns, types_ns, node_or_slice, value, slice_):
+      def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
         test_self.assertIn(node_or_slice, (0, 1))
         test_self.assertSetEqual(value, {list})
         test_self.assertSetEqual(slice_, {int})
@@ -699,7 +699,7 @@ class TypeInferenceAnalyzerTest(test.TestCase):
     node, _ = TestTranspiler(Resolver).transform(test_fn, None)
     fn_body = node.body
 
-    self.assertTypes(fn_body[1].value, Tuple)
+    self.assertTypes(fn_body[1].value, ((float, str),))
     self.assertTypes(fn_body[1].value.elts[0], float)
     self.assertTypes(fn_body[1].value.elts[1], str)
@@ -751,6 +751,196 @@ class TypeInferenceAnalyzerTest(test.TestCase):
     self.assertTypes(fn_body[0].value.left, list)
     self.assertTypes(fn_body[0].value.right, list)
 
+  def test_unop(self):
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        return {list}
+
+      def res_unop(self, ns, types_ns, node, opnd):
+        return {float}
+
+    def test_fn(a):
+      return -a
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    self.assertTypes(fn_body[0].value, float)
+    self.assertTypes(fn_body[0].value.operand, list)
+
+  def test_tuple_literal(self):
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        return {int}
+
+    def test_fn(a, b):
+      return a, b
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    self.assertTypes(fn_body[0].value, ((int, int),))
+    self.assertTypes(fn_body[0].value.elts[0], int)
+    self.assertTypes(fn_body[0].value.elts[1], int)
+
+  def test_list_literal(self):
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        return {int}
+
+      def res_list_literal(self, ns, elt_types):
+        all_types = set()
+        for s in elt_types:
+          all_types |= s
+        return {List[t] for t in all_types}
+
+    def test_fn(a, b):
+      return [a, b]
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    self.assertTypes(fn_body[0].value, List[int])
+    self.assertTypes(fn_body[0].value.elts[0], int)
+    self.assertTypes(fn_body[0].value.elts[1], int)
+
+  def test_tuple_unpacking_syntactic(self):
+
+    test_self = self
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        if name == qual_names.QN('a'):
+          return {int}
+        else:
+          return {float}
+
+      def res_value(self, ns, value):
+        test_self.assertIn(value, (0, 1))
+        return int
+
+      def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
+        test_self.assertIn(node_or_slice, (0, 1))
+        test_self.assertSetEqual(value, {(int, float)})
+        test_self.assertEqual(slice_, int)
+        return {t[node_or_slice] for t in value}
+
+    def test_fn(a, b):
+      c, d = a, b
+      return c, d
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    self.assertTypes(fn_body[1].value, ((int, float),))
+    self.assertTypes(fn_body[1].value.elts[0], int)
+    self.assertTypes(fn_body[1].value.elts[1], float)
+
+  def test_tuple_unpacking_operational(self):
+
+    test_self = self
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        return {(int, float)}
+
+      def res_value(self, ns, value):
+        test_self.assertIn(value, (0, 1))
+        return int
+
+      def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
+        test_self.assertIn(node_or_slice, (0, 1))
+        test_self.assertSetEqual(value, {(int, float)})
+        test_self.assertEqual(slice_, int)
+        return {t[node_or_slice] for t in value}
+
+    def test_fn(a):
+      c, d = a
+      return c, d
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    self.assertTypes(fn_body[1].value, ((int, float),))
+    self.assertTypes(fn_body[1].value.elts[0], int)
+    self.assertTypes(fn_body[1].value.elts[1], float)
+
+  def test_list_expansion_syntactic(self):
+
+    test_self = self
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        if name == qual_names.QN('a'):
+          return {int}
+        else:
+          return {float}
+
+      def res_value(self, ns, value):
+        test_self.assertIn(value, (0, 1))
+        return int
+
+      def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
+        test_self.assertIn(node_or_slice, (0, 1))
+        test_self.assertSetEqual(value, {(int, float)})
+        test_self.assertEqual(slice_, int)
+        return {t[node_or_slice] for t in value}
+
+    def test_fn(a, b):
+      [c, d] = a, b
+      return c, d
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    # TODO(mdan): Whether it's List or Tuple might be open for interpretation.
+    self.assertTypes(fn_body[1].value, ((int, float),))
+    self.assertTypes(fn_body[1].value.elts[0], int)
+    self.assertTypes(fn_body[1].value.elts[1], float)
+
+  def test_list_expansion_operational(self):
+
+    test_self = self
+
+    class Resolver(type_inference.Resolver):
+
+      def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
+        if name == qual_names.QN('a'):
+          return {int}
+        else:
+          return {float}
+
+      def res_value(self, ns, value):
+        test_self.assertIn(value, (0, 1))
+        return int
+
+      def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
+        test_self.assertIn(node_or_slice, (0, 1))
+        test_self.assertSetEqual(value, {(int, float)})
+        test_self.assertEqual(slice_, int)
+        return {t[node_or_slice] for t in value}
+
+    def test_fn(a, b):
+      [c, d] = a, b
+      return c, d
+
+    node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+    fn_body = node.body
+
+    # TODO(mdan): Whether it's List or Tuple might be open for interpretation.
+    self.assertTypes(fn_body[1].value, ((int, float),))
+    self.assertTypes(fn_body[1].value.elts[0], int)
+    self.assertTypes(fn_body[1].value.elts[1], float)
+
 
 if __name__ == '__main__':
   test.main()
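For illustration only (not part of the patch): a minimal sketch of how a downstream type_inference.Resolver subclass might implement the res_unop and res_list_literal hooks this change introduces. MyListResolver and its List-boxing rule are hypothetical; they only mirror the shape of the test resolvers above, under the assumption that unary ops preserve their operand's type set.

from typing import List

from tensorflow.python.autograph.pyct.static_analysis import type_inference


class MyListResolver(type_inference.Resolver):
  """Hypothetical resolver sketch exercising the new hooks."""

  def res_unop(self, ns, types_ns, node, opnd):
    # Assumption: unary operators preserve the operand's type set.
    return opnd

  def res_list_literal(self, ns, elt_types):
    # Union the per-element type sets and box them into typing.List,
    # mirroring what test_list_literal above asserts.
    all_types = set()
    for types_ in elt_types:
      all_types |= types_
    return {List[t] for t in all_types}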