Import while from flatbuffer as "tfl.while"

Rather than importing as TF while, use TFL while instead. This is follow up from previously where TFL while didn't exist.

PiperOrigin-RevId: 350682295
Change-Id: I9decd235a80417e6b4b37a2c4ae1a59bcb4b4faa
This commit is contained in:
Chuan He 2021-01-07 19:16:15 -08:00 committed by TensorFlower Gardener
parent 7719705955
commit beff361b55
3 changed files with 63 additions and 14 deletions

View File

@ -301,7 +301,7 @@ StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
return std::string("tf.If");
}
if (builtin_code == tflite::BuiltinOperator_WHILE) {
return std::string("tf.While");
return std::string("tfl.while");
}
llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code));
@ -590,9 +590,7 @@ llvm::SmallVector<mlir::NamedAttribute, 4> ConvertSubgraphIdxsToFunctionAttrs(
auto body_attr = builder.getSymbolRefAttr(func_names.at(body_idx));
return {builder.getNamedAttr("cond", cond_attr),
builder.getNamedAttr("body", body_attr),
// TODO(b/139667752): Analyze statelessness correctly
builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))};
builder.getNamedAttr("body", body_attr)};
}
return {};
}
@ -714,6 +712,14 @@ StatusOr<Operation*> ConvertOp(
TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
builder));
}
if (op_name == "tfl.while") {
// Adds two empty regions for "tfl.while". We will fill the regions after
// creating the callee functions because the "tfl.while" input/output types
// may be different with the callee functions, and the call ops need to sync
// with callee function types.
op_state.addRegion();
op_state.addRegion();
}
if (op_name == "tfl.unidirectional_sequence_lstm") {
TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
builder));
@ -755,7 +761,8 @@ StatusOr<Operation*> ConvertOp(
}
op_state.addAttributes(attrs);
// Handle the conversion from subgraph index to functions for If and While
// Handle the conversion from subgraph index to functions for If and While. We
// will add CallOps in the region to call the functions later for While.
auto function_ref_attrs = ConvertSubgraphIdxsToFunctionAttrs(
op.builtin_options, func_names, builder);
op_state.addAttributes(function_ref_attrs);
@ -1163,6 +1170,35 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
}
return subgraph.name;
}
// Adds a CallOp in `region` to call the `func` and returns the results of
// CallOp.
void AddCallOpInWhileOpRegion(mlir::Region& region, mlir::FuncOp func) {
OpBuilder op_builder{region};
region.push_back(new mlir::Block());
region.addArguments(func.getType().getInputs());
op_builder.setInsertionPointToStart(&region.front());
auto call_op = op_builder.create<mlir::CallOp>(
region.getLoc(), func.getType().getResults(), func.sym_name(),
region.getArguments());
op_builder.create<mlir::TFL::YieldOp>(region.getLoc(), call_op.getResults());
}
// TFL::WhileOp has regions, so we add CallOp to call the FuncOp in the regions
// if we have while ops.
void AddRegionsForTflWhileOp(mlir::ModuleOp module) {
mlir::SymbolTable symbol_table(module);
module.walk([&](mlir::TFL::WhileOp while_op) {
auto cond = symbol_table.lookup<mlir::FuncOp>(
while_op->getAttr("cond").cast<mlir::FlatSymbolRefAttr>().getValue());
AddCallOpInWhileOpRegion(while_op.cond(), cond);
while_op.removeAttr("cond");
auto body = symbol_table.lookup<mlir::FuncOp>(
while_op->getAttr("body").cast<mlir::FlatSymbolRefAttr>().getValue());
AddCallOpInWhileOpRegion(while_op.body(), body);
while_op.removeAttr("body");
});
}
} // namespace
OwningModuleRef tflite::FlatBufferToMlir(
@ -1219,5 +1255,6 @@ OwningModuleRef tflite::FlatBufferToMlir(
}
module.push_back(func_or_error.ConsumeValueOrDie());
}
AddRegionsForTflWhileOp(module);
return OwningModuleRef(module);
}

View File

@ -2,12 +2,18 @@
// This test is to test for unranked function output from input, the output type should be compatible with input type.
// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32>
// CHECK: %0 = "tf.While"(%arg0) {body = @body, cond = @cond, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
// CHECK: return %0 : tensor<*xf32>
// CHECK: func private @cond(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func private @body(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tf.While"}} {
// CHECK: %0 = "tfl.while"(%arg0) ( {
// CHECK: ^bb0(%arg1: tensor<*xf32>):
// CHECK: %[[RES0:.*]] = call @cond(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: "tfl.yield"(%[[RES0]]) : (tensor<*xf32>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%arg1: tensor<*xf32>):
// CHECK: %[[RES1:.*]] = call @body(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: "tfl.yield"(%[[RES1]]) : (tensor<*xf32>) -> ()
// CHECK: }) : (tensor<1xf32>) -> tensor<*xf32>
// CHECK: return %0 : tensor<*xf32>
// CHECK: }
func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> {
%0 = "tf.While"(%arg0) {cond = @cond, body = @body, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>

View File

@ -1,8 +1,14 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Check to see if function references in while loops are preserved
// TODO(b/138222071) Expect first output to be a scalar
// CHECK: %{{.*}}:2 = "tf.While"(%{{.*}}, %{{.*}}) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<1xf32>) -> (tensor<*xi32>, tensor<1xf32>)
// Check to see if nested regions in while loops are preserved
// CHECK: %{{.*}}:2 = "tfl.while"(%{{.*}}, %{{.*}}) ( {
// CHECK: ^bb0(%{{.*}}: tensor<*xi32>, %{{.*}}: tensor<*xf32>):
// CHECK: "tfl.yield"(%{{.*}}) : (tensor<*xi1>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%{{.*}}: tensor<*xi32>, %{{.*}}: tensor<*xf32>):
// CHECK: "tfl.yield"(%{{.*}}, %{{.*}}) : (tensor<*xi32>, tensor<*xf32>) -> ()
// CHECK: }) : (tensor<i32>, tensor<1xf32>) -> (tensor<*xi32>, tensor<1xf32>)
func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// While %arg0 is greater than zero, element wise add %arg1 with itself.
%0:2 = "tfl.while"(%arg0, %arg1) ( {