Import while from flatbuffer as "tfl.while"
Rather than importing a flatbuffer WHILE operator as a TF while, import it as a TFL while. This is a follow-up to earlier work done when the TFL while op did not yet exist.

PiperOrigin-RevId: 350682295
Change-Id: I9decd235a80417e6b4b37a2c4ae1a59bcb4b4faa
commit beff361b55
parent 7719705955
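For orientation, here is an illustrative before/after sketch of the imported IR, distilled from the updated FileCheck expectations in the tests below (not verbatim converter output; the names @cond and @body and the tensor types are taken from those tests). Previously a flatbuffer WHILE was imported as a functional tf.While that only referenced its cond/body functions by symbol:

  %0 = "tf.While"(%arg0) {body = @body, cond = @cond, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>

With this change it is imported as a region-based tfl.while whose two regions call the imported functions and forward their results through tfl.yield, which lets the op's operand/result types differ from the callee signatures:

  %0 = "tfl.while"(%arg0) ( {
  ^bb0(%arg1: tensor<*xf32>):
    %1 = call @cond(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
    "tfl.yield"(%1) : (tensor<*xf32>) -> ()
  }, {
  ^bb0(%arg1: tensor<*xf32>):
    %1 = call @body(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
    "tfl.yield"(%1) : (tensor<*xf32>) -> ()
  }) : (tensor<1xf32>) -> tensor<*xf32>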
@@ -301,7 +301,7 @@ StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
     return std::string("tf.If");
   }
   if (builtin_code == tflite::BuiltinOperator_WHILE) {
-    return std::string("tf.While");
+    return std::string("tfl.while");
   }

   llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code));
@@ -590,9 +590,7 @@ llvm::SmallVector<mlir::NamedAttribute, 4> ConvertSubgraphIdxsToFunctionAttrs(
     auto body_attr = builder.getSymbolRefAttr(func_names.at(body_idx));

     return {builder.getNamedAttr("cond", cond_attr),
-            builder.getNamedAttr("body", body_attr),
-            // TODO(b/139667752): Analyze statelessness correctly
-            builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))};
+            builder.getNamedAttr("body", body_attr)};
   }
   return {};
 }
@@ -714,6 +712,14 @@ StatusOr<Operation*> ConvertOp(
     TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
                                           builder));
   }
+  if (op_name == "tfl.while") {
+    // Adds two empty regions for "tfl.while". We will fill the regions after
+    // creating the callee functions, because the "tfl.while" input/output types
+    // may differ from the callee function types, and the call ops need to stay
+    // in sync with the callee function types.
+    op_state.addRegion();
+    op_state.addRegion();
+  }
   if (op_name == "tfl.unidirectional_sequence_lstm") {
     TF_CHECK_OK(AddOpIntermediatesForLstm(op, intermediate_types, op_state, loc,
                                           builder));
@@ -755,7 +761,8 @@ StatusOr<Operation*> ConvertOp(
   }
   op_state.addAttributes(attrs);

-  // Handle the conversion from subgraph index to functions for If and While
+  // Handle the conversion from subgraph index to functions for If and While. We
+  // will add CallOps in the region to call the functions later for While.
   auto function_ref_attrs = ConvertSubgraphIdxsToFunctionAttrs(
       op.builtin_options, func_names, builder);
   op_state.addAttributes(function_ref_attrs);
@@ -1163,6 +1170,35 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
   }
   return subgraph.name;
 }
+
+// Adds a CallOp in `region` to call `func`, and yields the CallOp's results
+// through a TFL::YieldOp.
+void AddCallOpInWhileOpRegion(mlir::Region& region, mlir::FuncOp func) {
+  OpBuilder op_builder{region};
+  region.push_back(new mlir::Block());
+  region.addArguments(func.getType().getInputs());
+  op_builder.setInsertionPointToStart(&region.front());
+  auto call_op = op_builder.create<mlir::CallOp>(
+      region.getLoc(), func.getType().getResults(), func.sym_name(),
+      region.getArguments());
+  op_builder.create<mlir::TFL::YieldOp>(region.getLoc(), call_op.getResults());
+}
+
+// TFL::WhileOp has regions, so for every while op we add a CallOp in each
+// region to call the corresponding cond/body FuncOp.
+void AddRegionsForTflWhileOp(mlir::ModuleOp module) {
+  mlir::SymbolTable symbol_table(module);
+  module.walk([&](mlir::TFL::WhileOp while_op) {
+    auto cond = symbol_table.lookup<mlir::FuncOp>(
+        while_op->getAttr("cond").cast<mlir::FlatSymbolRefAttr>().getValue());
+    AddCallOpInWhileOpRegion(while_op.cond(), cond);
+    while_op.removeAttr("cond");
+    auto body = symbol_table.lookup<mlir::FuncOp>(
+        while_op->getAttr("body").cast<mlir::FlatSymbolRefAttr>().getValue());
+    AddCallOpInWhileOpRegion(while_op.body(), body);
+    while_op.removeAttr("body");
+  });
+}
 }  // namespace

 OwningModuleRef tflite::FlatBufferToMlir(
@@ -1219,5 +1255,6 @@ OwningModuleRef tflite::FlatBufferToMlir(
     }
     module.push_back(func_or_error.ConsumeValueOrDie());
   }
+  AddRegionsForTflWhileOp(module);
   return OwningModuleRef(module);
 }

@@ -2,12 +2,18 @@

 // This test is to test for unranked function output from input, the output type should be compatible with input type.

-// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32>
-// CHECK: %0 = "tf.While"(%arg0) {body = @body, cond = @cond, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
-// CHECK: return %0 : tensor<*xf32>
-// CHECK: func private @cond(%arg0: tensor<*xf32>) -> tensor<*xf32>
-// CHECK: func private @body(%arg0: tensor<*xf32>) -> tensor<*xf32>
+// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tf.While"}} {
+// CHECK: %0 = "tfl.while"(%arg0) ( {
+// CHECK: ^bb0(%arg1: tensor<*xf32>):
+// CHECK: %[[RES0:.*]] = call @cond(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: "tfl.yield"(%[[RES0]]) : (tensor<*xf32>) -> ()
+// CHECK: }, {
+// CHECK: ^bb0(%arg1: tensor<*xf32>):
+// CHECK: %[[RES1:.*]] = call @body(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK: "tfl.yield"(%[[RES1]]) : (tensor<*xf32>) -> ()
+// CHECK: }) : (tensor<1xf32>) -> tensor<*xf32>
+// CHECK: return %0 : tensor<*xf32>
+// CHECK: }
 func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> {
   %0 = "tf.While"(%arg0) {cond = @cond, body = @body, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>

@@ -1,8 +1,14 @@
 // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s

-// Check to see if function references in while loops are preserved
+// Check to see if nested regions in while loops are preserved
 // TODO(b/138222071) Expect first output to be a scalar
-// CHECK: %{{.*}}:2 = "tf.While"(%{{.*}}, %{{.*}}) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<1xf32>) -> (tensor<*xi32>, tensor<1xf32>)
+// CHECK: %{{.*}}:2 = "tfl.while"(%{{.*}}, %{{.*}}) ( {
+// CHECK: ^bb0(%{{.*}}: tensor<*xi32>, %{{.*}}: tensor<*xf32>):
+// CHECK: "tfl.yield"(%{{.*}}) : (tensor<*xi1>) -> ()
+// CHECK: }, {
+// CHECK: ^bb0(%{{.*}}: tensor<*xi32>, %{{.*}}: tensor<*xf32>):
+// CHECK: "tfl.yield"(%{{.*}}, %{{.*}}) : (tensor<*xi32>, tensor<*xf32>) -> ()
+// CHECK: }) : (tensor<i32>, tensor<1xf32>) -> (tensor<*xi32>, tensor<1xf32>)

 func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
   // While %arg0 is greater than zero, element wise add %arg1 with itself.
   %0:2 = "tfl.while"(%arg0, %arg1) ( {