Enable nested composition functions
This patch allowed the users to define nested composition functions. all the composition functions need to be defined before their uses. PiperOrigin-RevId: 345073100 Change-Id: Ideab901b270e6036b5361feb82c64503a734d57e
This commit is contained in:
parent
d5eb6779f2
commit
f432994964
@ -431,6 +431,15 @@ class TFRTypeResolver(type_inference.Resolver):
|
||||
return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
|
||||
None)
|
||||
|
||||
elif f_type == (types.FunctionType,):
|
||||
# A composition Python function name is used directly.
|
||||
op_name = name.qn[0]
|
||||
op_def, _ = self._op_defs.lookup(op_name)
|
||||
if len(op_def.output_arg) == 1:
|
||||
return {_get_type_from_proto(op_def.output_arg[0])}, None
|
||||
return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)},
|
||||
None)
|
||||
|
||||
elif f_type == (TFRTypes.PY_BUILTIN_FUNC,):
|
||||
assert name.is_simple()
|
||||
if name == QN('range'):
|
||||
@ -809,6 +818,9 @@ class TFRGen(transformer.CodeGenerator):
|
||||
if func_type == TFRTypes.TF_RAW_OP:
|
||||
return self._visit_tf_op(func_name, node.args, node.keywords, node)
|
||||
|
||||
if func_type == types.FunctionType:
|
||||
return self._visit_tf_op(func_name, node.args, node.keywords, node)
|
||||
|
||||
if func_type == TFRTypes.TF_TENSOR_SHAPE_FUNC:
|
||||
return (func_name, TFRTypes.TF_TENSOR_SHAPE_LIST)
|
||||
|
||||
@ -1184,7 +1196,13 @@ class TFRGen(transformer.CodeGenerator):
|
||||
raise NotImplementedError('If not supported.')
|
||||
|
||||
def visit_Name(self, node):
|
||||
val, lookup_type = self.symbol_table.lookup(node.id)
|
||||
val_and_lookup_type = self.symbol_table.lookup(node.id)
|
||||
if val_and_lookup_type:
|
||||
(val, lookup_type) = val_and_lookup_type
|
||||
else:
|
||||
op_def, _ = self._op_defs.lookup(node.id)
|
||||
val = op_def.name
|
||||
lookup_type = anno.getanno(node, anno.Static.TYPES, types.FunctionType)
|
||||
type_ = self._get_inferred_type(node, lookup_type)
|
||||
return val, type_
|
||||
|
||||
|
@ -211,6 +211,20 @@ def _tfr_shapes(x):
|
||||
return x
|
||||
|
||||
|
||||
#--- test fn for nested functions ---
|
||||
|
||||
|
||||
@composite.Composite('TestIdentityNOp')
|
||||
def _tfr_temp_op(x):
|
||||
return x
|
||||
|
||||
|
||||
@composite.Composite('TestIdentityOp')
|
||||
def _tfr_temp_use_op(x):
|
||||
y = _tfr_temp_op([x])
|
||||
return y[0]
|
||||
|
||||
|
||||
class TFRGenTestBase(test.TestCase):
|
||||
|
||||
def _check_code(self, tfr_code, exp_tfr_code):
|
||||
@ -557,6 +571,17 @@ class TFRGenTensorTest(TFRGenTestBase):
|
||||
"""
|
||||
self._check_code(mlir_code, mlir_code_exp)
|
||||
|
||||
def test_temp_function(self):
|
||||
mlir_code = tfr_gen(sys.modules[__name__], '_tfr_temp', [test_ops])
|
||||
mlir_code_exp = r"""
|
||||
CHECK-LABEL: tfr.func @tf__test_identity_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor_list)
|
||||
|
||||
CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) {
|
||||
CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%x) : (!tfr.tensor) -> !tfr.tensor_list
|
||||
CHECK-NEXT: %[[call:.*]] = tfr.call @tf__test_identity_n_op(%[[list]]) : (!tfr.tensor_list)
|
||||
"""
|
||||
self._check_code(mlir_code, mlir_code_exp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user