Enable nested composition functions

This patch allowed the users to define nested composition functions. all the
composition functions need to be defined before their uses.

PiperOrigin-RevId: 345073100
Change-Id: Ideab901b270e6036b5361feb82c64503a734d57e
This commit is contained in:
Feng Liu 2020-12-01 12:14:09 -08:00 committed by TensorFlower Gardener
parent d5eb6779f2
commit f432994964
2 changed files with 44 additions and 1 deletions

View File

@ -431,6 +431,15 @@ class TFRTypeResolver(type_inference.Resolver):
return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
None)
elif f_type == (types.FunctionType,):
# A composition Python function name is used directly.
op_name = name.qn[0]
op_def, _ = self._op_defs.lookup(op_name)
if len(op_def.output_arg) == 1:
return {_get_type_from_proto(op_def.output_arg[0])}, None
return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
None)
elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
assert name.is_simple()
if name == QN('range'):
@ -809,6 +818,9 @@ class TFRGen(transformer.CodeGenerator):
if func_type == TFRTypes.TF_RAW_OP:
return self._visit_tf_op(func_name, node.args, node.keywords, node)
if func_type == types.FunctionType:
return self._visit_tf_op(func_name, node.args, node.keywords, node)
if func_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
return (func_name, TFRTypes.TF_TENSOR_SHAPE_LIST)
@ -1184,7 +1196,13 @@ class TFRGen(transformer.CodeGenerator):
raise NotImplementedError('If not supported.')
def visit_Name(self, node):
val, lookup_type = self.symbol_table.lookup(node.id)
val_and_lookup_type = self.symbol_table.lookup(node.id)
if val_and_lookup_type:
(val, lookup_type) = val_and_lookup_type
else:
op_def, _ = self._op_defs.lookup(node.id)
val = op_def.name
lookup_type = anno.getanno(node, anno.Static.TYPES, types.FunctionType)
type_ = self._get_inferred_type(node, lookup_type)
return val, type_

View File

@ -211,6 +211,20 @@ def _tfr_shapes(x):
return x
#--- test fn for nested functions ---
@composite.Composite('TestIdentityNOp')
def _tfr_temp_op(x):
return x
@composite.Composite('TestIdentityOp')
def _tfr_temp_use_op(x):
y = _tfr_temp_op([x])
return y[0]
class TFRGenTestBase(test.TestCase):
def _check_code(self, tfr_code, exp_tfr_code):
@ -557,6 +571,17 @@ class TFRGenTensorTest(TFRGenTestBase):
"""
self._check_code(mlir_code, mlir_code_exp)
def test_temp_function(self):
mlir_code = tfr_gen(sys.modules[__name__], '_tfr_temp', [test_ops])
mlir_code_exp = r"""
CHECK-LABEL: tfr.func @tf__test_identity_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor_list)
CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) {
CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%x) : (!tfr.tensor) -> !tfr.tensor_list
CHECK-NEXT: %[[call:.*]] = tfr.call @tf__test_identity_n_op(%[[list]]) : (!tfr.tensor_list)
"""
self._check_code(mlir_code, mlir_code_exp)
if __name__ == '__main__':
test.main()