diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index cd32dd9cd9d..1d21f5c79b8 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -301,7 +301,7 @@ StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op, return std::string("tf.If"); } if (builtin_code == tflite::BuiltinOperator_WHILE) { - return std::string("tf.While"); + return std::string("tfl.while"); } llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code)); @@ -590,9 +590,7 @@ llvm::SmallVector<mlir::NamedAttribute, 4> ConvertSubgraphIdxsToFunctionAttrs( auto body_attr = builder.getSymbolRefAttr(func_names.at(body_idx)); return {builder.getNamedAttr("cond", cond_attr), - builder.getNamedAttr("body", body_attr), - // TODO(b/139667752): Analyze statelessness correctly - builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))}; + builder.getNamedAttr("body", body_attr)}; } return {}; } @@ -714,6 +712,14 @@ StatusOr<Operation*> ConvertOp( TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc, builder)); } + if (op_name == "tfl.while") { + // Adds two empty regions for "tfl.while". We will fill the regions after + // creating the callee functions because the "tfl.while" input/output types + // may be different with the callee functions, and the call ops need to sync + // with callee function types. + op_state.addRegion(); + op_state.addRegion(); + } if (op_name == "tfl.unidirectional_sequence_lstm") { TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc, builder)); @@ -755,7 +761,8 @@ StatusOr<Operation*> ConvertOp( } op_state.addAttributes(attrs); - // Handle the conversion from subgraph index to functions for If and While + // Handle the conversion from subgraph index to functions for If and While. We + // will add CallOps in the region to call the functions later for While. auto function_ref_attrs = ConvertSubgraphIdxsToFunctionAttrs( op.builtin_options, func_names, builder); op_state.addAttributes(function_ref_attrs); @@ -1163,6 +1170,35 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) { } return subgraph.name; } + +// Adds a CallOp in `region` to call the `func` and returns the results of +// CallOp. +void AddCallOpInWhileOpRegion(mlir::Region& region, mlir::FuncOp func) { + OpBuilder op_builder{region}; + region.push_back(new mlir::Block()); + region.addArguments(func.getType().getInputs()); + op_builder.setInsertionPointToStart(®ion.front()); + auto call_op = op_builder.create<mlir::CallOp>( + region.getLoc(), func.getType().getResults(), func.sym_name(), + region.getArguments()); + op_builder.create<mlir::TFL::YieldOp>(region.getLoc(), call_op.getResults()); +} + +// TFL::WhileOp has regions, so we add CallOp to call the FuncOp in the regions +// if we have while ops. +void AddRegionsForTflWhileOp(mlir::ModuleOp module) { + mlir::SymbolTable symbol_table(module); + module.walk([&](mlir::TFL::WhileOp while_op) { + auto cond = symbol_table.lookup<mlir::FuncOp>( + while_op->getAttr("cond").cast<mlir::FlatSymbolRefAttr>().getValue()); + AddCallOpInWhileOpRegion(while_op.cond(), cond); + while_op.removeAttr("cond"); + auto body = symbol_table.lookup<mlir::FuncOp>( + while_op->getAttr("body").cast<mlir::FlatSymbolRefAttr>().getValue()); + AddCallOpInWhileOpRegion(while_op.body(), body); + while_op.removeAttr("body"); + }); +} } // namespace OwningModuleRef tflite::FlatBufferToMlir( @@ -1219,5 +1255,6 @@ OwningModuleRef tflite::FlatBufferToMlir( } module.push_back(func_or_error.ConsumeValueOrDie()); } + AddRegionsForTflWhileOp(module); return OwningModuleRef(module); } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/unranked_function_output.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/unranked_function_output.mlir index 8a97a83064f..ba9b52c54a2 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/unranked_function_output.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/unranked_function_output.mlir @@ -2,12 +2,18 @@ // This test is to test for unranked function output from input, the output type should be compatible with input type. -// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> -// CHECK: %0 = "tf.While"(%arg0) {body = @body, cond = @cond, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32> -// CHECK: return %0 : tensor<*xf32> -// CHECK: func private @cond(%arg0: tensor<*xf32>) -> tensor<*xf32> -// CHECK: func private @body(%arg0: tensor<*xf32>) -> tensor<*xf32> - +// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tf.While"}} { +// CHECK: %0 = "tfl.while"(%arg0) ( { +// CHECK: ^bb0(%arg1: tensor<*xf32>): +// CHECK: %[[RES0:.*]] = call @cond(%arg1) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: "tfl.yield"(%[[RES0]]) : (tensor<*xf32>) -> () +// CHECK: }, { +// CHECK: ^bb0(%arg1: tensor<*xf32>): +// CHECK: %[[RES1:.*]] = call @body(%arg1) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: "tfl.yield"(%[[RES1]]) : (tensor<*xf32>) -> () +// CHECK: }) : (tensor<1xf32>) -> tensor<*xf32> +// CHECK: return %0 : tensor<*xf32> +// CHECK: } func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> { %0 = "tf.While"(%arg0) {cond = @cond, body = @body, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/while_op.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/while_op.mlir index d29c9e661ad..4728802b33f 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/while_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/while_op.mlir @@ -1,8 +1,14 @@ // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s -// Check to see if function references in while loops are preserved -// TODO(b/138222071) Expect first output to be a scalar -// CHECK: %{{.*}}:2 = "tf.While"(%{{.*}}, %{{.*}}) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<1xf32>) -> (tensor<*xi32>, tensor<1xf32>) +// Check to see if nested regions in while loops are preserved +// CHECK: %{{.*}}:2 = "tfl.while"(%{{.*}}, %{{.*}}) ( { +// CHECK: ^bb0(%{{.*}}: tensor<*xi32>, %{{.*}}: tensor<*xf32>): +// CHECK: "tfl.yield"(%{{.*}}) : (tensor<*xi1>) -> () +// CHECK: }, { +// CHECK: ^bb0(%{{.*}}: tensor<*xi32>, %{{.*}}: tensor<*xf32>): +// CHECK: "tfl.yield"(%{{.*}}, %{{.*}}) : (tensor<*xi32>, tensor<*xf32>) -> () +// CHECK: }) : (tensor<i32>, tensor<1xf32>) -> (tensor<*xi32>, tensor<1xf32>) + func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { // While %arg0 is greater than zero, element wise add %arg1 with itself. %0:2 = "tfl.while"(%arg0, %arg1) ( {