Add TFL while op with regions
Add TFL while op that uses regions for the conditional and body of the loop. This op has the same semantics as TF while but uses regions instead. This is just the smallest start to using the TFL while op: * the LoopLike interface at the moment is sparse and needs to be filled in; * the legalization from TF while to TFL while is also pending (and so currently supports both forms as legal during translation to flatbuffer); * extra verification is needed for the regions of the while; * missing pretty form; * import from flatbuffer will not generate this form; The same export test is used for tfl.while as for tf.While op (modulo the names). PiperOrigin-RevId: 292557025 Change-Id: I3b1315361a51706273594aba6e0c6fe0fc321869
This commit is contained in:
parent
cbf2288d95
commit
e3130b5bbd
@ -30,6 +30,7 @@ filegroup(
|
||||
"ir/tfl_ops.td",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -211,8 +212,6 @@ cc_library(
|
||||
deps = [
|
||||
":tensorflow_lite_ops_inc_gen",
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
@ -221,6 +220,10 @@ cc_library(
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
# TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
|
||||
"@llvm-project//mlir:Transforms",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -400,6 +403,7 @@ genrule(
|
||||
srcs = [
|
||||
"ir/tfl_ops.td",
|
||||
"ir/tfl_op_interfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
],
|
||||
outs = [
|
||||
|
@ -90,6 +90,7 @@ using mlir::MLIRContext;
|
||||
using mlir::ModuleOp;
|
||||
using mlir::NoneType;
|
||||
using mlir::Operation;
|
||||
using mlir::Region;
|
||||
using mlir::StringAttr;
|
||||
using mlir::TensorType;
|
||||
using mlir::TranslateFromMLIRRegistration;
|
||||
@ -309,7 +310,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
|
||||
static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
|
||||
::mlir::Operation* inst) {
|
||||
// We pass empty string for the original node_def name since Flex runtime
|
||||
// does not care about this being set correctly on node_def. There is no
|
||||
@ -425,6 +426,11 @@ class Translator {
|
||||
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
|
||||
// Build while operator where cond & body are regions.
|
||||
BufferOffset<tflite::Operator> BuildWhileOperator(
|
||||
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
|
||||
// Builds custom operators.
|
||||
// Templated on a) data type of custom_option to be stored into flatbuffer,
|
||||
// and b) TFL custom op type.
|
||||
@ -472,7 +478,10 @@ class Translator {
|
||||
Operation* inst, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
|
||||
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn);
|
||||
// Build a subgraph with a given name out of the region either corresponding
|
||||
// to a function's body or while op.
|
||||
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
|
||||
const std::string& name, Region* region);
|
||||
|
||||
// Builds Metadata with the given `name` and buffer `content`.
|
||||
BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
|
||||
@ -494,6 +503,12 @@ class Translator {
|
||||
// Returns a unique name for `val`.
|
||||
std::string UniqueName(mlir::Value val);
|
||||
|
||||
// Returns the names of the subgraphs corresponding to the regions of the op. The
|
||||
// names are supposed to be unique as the op name is unique and the suffix is
|
||||
// not a valid name.
|
||||
std::string GetWhileBodyName(mlir::TFL::WhileOp while_op);
|
||||
std::string GetWhileCondName(mlir::TFL::WhileOp while_op);
|
||||
|
||||
ModuleOp module_;
|
||||
|
||||
tensorflow::OpOrArgNameMapper& name_mapper_;
|
||||
@ -687,6 +702,30 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
|
||||
builtin_options);
|
||||
}
|
||||
|
||||
std::string Translator::GetWhileBodyName(mlir::TFL::WhileOp while_op) {
|
||||
return (name_mapper_.GetUniqueName(while_op.getOperation()) + "$body").str();
|
||||
}
|
||||
|
||||
std::string Translator::GetWhileCondName(mlir::TFL::WhileOp while_op) {
|
||||
return (name_mapper_.GetUniqueName(while_op.getOperation()) + "$cond").str();
|
||||
}
|
||||
|
||||
BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
|
||||
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results) {
|
||||
auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
|
||||
int body_subgraph_index = subgraph_index_map_.at(GetWhileBodyName(op));
|
||||
int cond_subgraph_index = subgraph_index_map_.at(GetWhileCondName(op));
|
||||
auto builtin_options = tflite::CreateWhileOptions(
|
||||
builder_, cond_subgraph_index, body_subgraph_index)
|
||||
.Union();
|
||||
auto inputs = builder_.CreateVector(operands);
|
||||
auto outputs = builder_.CreateVector(results);
|
||||
return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
|
||||
tflite::BuiltinOptions_WhileOptions,
|
||||
builtin_options);
|
||||
}
|
||||
|
||||
template <typename CustomOptionType, typename TFLOp>
|
||||
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
|
||||
const CustomOptionType& custom_option, const std::string& opcode_name,
|
||||
@ -908,6 +947,10 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
|
||||
results);
|
||||
}
|
||||
if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
|
||||
return BuildWhileOperator(whileOp, operands, results);
|
||||
}
|
||||
|
||||
inst->emitOpError("is not a supported TFLite op");
|
||||
return llvm::None;
|
||||
}
|
||||
@ -944,7 +987,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
// we emit op as flex.
|
||||
// if custom is enabled
|
||||
// we emit the op as custom.
|
||||
auto node_def = getTensorFlowNodeDef(inst);
|
||||
auto node_def = GetTensorFlowNodeDef(inst);
|
||||
if (!node_def) {
|
||||
return llvm::None;
|
||||
}
|
||||
@ -1047,9 +1090,12 @@ bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
|
||||
return absl::c_find(operand_indices, operand_index) != operand_indices.end();
|
||||
}
|
||||
|
||||
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
||||
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
|
||||
const std::string& name, Region* region) {
|
||||
bool has_input_attr = false;
|
||||
InitializeNamesFromAttribute(fn, &has_input_attr);
|
||||
if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
|
||||
InitializeNamesFromAttribute(fn, &has_input_attr);
|
||||
}
|
||||
std::vector<BufferOffset<tflite::Tensor>> tensors;
|
||||
llvm::DenseMap<Value, int> tensor_index_map;
|
||||
|
||||
@ -1081,7 +1127,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
||||
};
|
||||
|
||||
std::vector<BufferOffset<tflite::Operator>> operators;
|
||||
auto& bb = fn.getBlocks().front();
|
||||
auto& bb = region->front();
|
||||
|
||||
// Main function's arguments are first passed to `input` op so they don't
|
||||
// have associated tensor and buffer. Build FlatBuffer tensor and buffer for
|
||||
@ -1141,7 +1187,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
||||
return tflite::CreateSubGraph(
|
||||
builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
|
||||
builder_.CreateVector(outputs), builder_.CreateVector(operators),
|
||||
/*name=*/builder_.CreateString(fn.getName().str()));
|
||||
/*name=*/builder_.CreateString(name));
|
||||
}
|
||||
|
||||
BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
|
||||
@ -1184,35 +1230,45 @@ Optional<std::string> Translator::Translate(
|
||||
}
|
||||
|
||||
Optional<std::string> Translator::TranslateInternal() {
|
||||
// Create a list of functions in the module with main function being the
|
||||
// first function in the list. This is required as the first subgraph in the
|
||||
// model is entry point for the model.
|
||||
std::vector<FuncOp> functions;
|
||||
functions.reserve(std::distance(module_.begin(), module_.end()));
|
||||
// A list of named regions in the module with main function being the first in
|
||||
// the list. The main function is required as the first subgraph in the model
|
||||
// is entry point for the model.
|
||||
std::vector<std::pair<std::string, Region*>> named_regions;
|
||||
named_regions.reserve(std::distance(module_.begin(), module_.end()));
|
||||
|
||||
int subgraph_idx = 0;
|
||||
FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
|
||||
subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++;
|
||||
functions.push_back(main_fn);
|
||||
for (auto fn : module_.getOps<FuncOp>()) {
|
||||
if (fn == main_fn) continue;
|
||||
named_regions.emplace_back("main", &main_fn.getBody());
|
||||
// Walk over the module collecting functions and while ops.
|
||||
module_.walk([&](Operation* op) {
|
||||
if (auto fn = dyn_cast<FuncOp>(op)) {
|
||||
if (fn != main_fn) {
|
||||
subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
|
||||
named_regions.emplace_back(fn.getName(), &fn.getBody());
|
||||
}
|
||||
} else if (auto wo = dyn_cast<mlir::TFL::WhileOp>(op)) {
|
||||
std::string name = GetWhileCondName(wo);
|
||||
subgraph_index_map_[name] = subgraph_idx++;
|
||||
named_regions.emplace_back(GetWhileCondName(wo), &wo.cond());
|
||||
name = GetWhileBodyName(wo);
|
||||
subgraph_index_map_[name] = subgraph_idx++;
|
||||
named_regions.emplace_back(name, &wo.body());
|
||||
}
|
||||
});
|
||||
|
||||
subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
|
||||
functions.push_back(fn);
|
||||
}
|
||||
|
||||
// Build subgraph for each of the functions.
|
||||
// Build subgraph for each of the named regions.
|
||||
std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
|
||||
subgraphs.reserve(functions.size());
|
||||
subgraphs.reserve(named_regions.size());
|
||||
int first_failed_func = -1;
|
||||
for (int i = 0; i < functions.size(); ++i) {
|
||||
auto subgraph_or = BuildSubGraph(functions[i]);
|
||||
for (auto it : llvm::enumerate(named_regions)) {
|
||||
auto subgraph_or = BuildSubGraph(it.value().first, it.value().second);
|
||||
if (!subgraph_or) {
|
||||
if (first_failed_func == -1)
|
||||
// Record the index of the first function that cannot be converted.
|
||||
// Record the index of the first region that cannot be converted.
|
||||
// Keep looping through all subgraphs in the module to make sure that
|
||||
// we collect the list of missing ops from the entire module.
|
||||
first_failed_func = i;
|
||||
first_failed_func = it.index();
|
||||
} else {
|
||||
subgraphs.push_back(*subgraph_or);
|
||||
}
|
||||
@ -1233,9 +1289,10 @@ Optional<std::string> Translator::TranslateInternal() {
|
||||
"-emit-custom-ops flag): " +
|
||||
failed_custom_ops_list;
|
||||
|
||||
return functions[first_failed_func].emitError("failed while converting: '")
|
||||
<< functions[first_failed_func].getName() << "\'\n"
|
||||
<< err,
|
||||
auto& failed_region = named_regions[first_failed_func];
|
||||
return failed_region.second->getParentOp()->emitError()
|
||||
<< "failed while converting: '" << failed_region.first
|
||||
<< "': " << err,
|
||||
llvm::None;
|
||||
}
|
||||
|
||||
|
@ -1736,6 +1736,20 @@ static LogicalResult Verify(TransposeOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
Region &WhileOp::getLoopBody() { return body(); }
|
||||
|
||||
bool WhileOp::isDefinedOutsideOfLoop(Value value) {
|
||||
// TODO(jpienaar): This is overly conservative and disables anything other
|
||||
// than constant hoisting initially.
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *>) {
|
||||
// TODO(jpienaar): Fail any hoisting until post test case and refining
|
||||
// isDefinedOutsideOfLoop.
|
||||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/LoopLikeInterface.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#define TFL_OPS
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Transforms/LoopLikeInterface.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
|
||||
include "tensorflow/compiler/mlir/lite/quantization/quantization.td"
|
||||
|
||||
@ -3370,4 +3371,46 @@ def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
|
||||
let results = (outs TensorOf<[F32, I32]>:$output);
|
||||
}
|
||||
|
||||
def TFL_YieldOp : Op<TFL_Dialect, "yield", [Terminator]> {
|
||||
let summary = "Yield operation";
|
||||
let description = [{
|
||||
The "yield" operation represents a return operation within the conditional
|
||||
and body of structured control flow (e.g., while). The operation takes
|
||||
variable number of operands and produces no results. The operand number and
|
||||
types must match the signature of the region that contains the operation.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyType>:$operands);
|
||||
}
|
||||
|
||||
def TFL_WhileOp : Op<TFL_Dialect, "while", [
|
||||
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
|
||||
SingleBlockImplicitTerminator<"YieldOp">,
|
||||
// Make isolated from above to force values through operands to simplify
|
||||
// exporting to subgraphs.
|
||||
IsolatedFromAbove]> {
|
||||
let summary = [{While loop}];
|
||||
|
||||
let description = [{
|
||||
output = input; while (cond(output)) { output = body(output) }
|
||||
|
||||
input: A list of input tensors whose types are T.
|
||||
output: A list of output tensors whose types are T.
|
||||
cond: A region takes 'input' and returns a boolean scalar tensor.
|
||||
body: A region that takes a list of tensors and returns another
|
||||
list of tensors. Both lists have the same types.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<AnyTensor>:$input,
|
||||
|
||||
// Used to map StatelessWhile and While op defined in TensorFlow to a common
|
||||
// op.
|
||||
DefaultValuedAttr<BoolAttr, "false">:$is_stateless
|
||||
);
|
||||
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
|
||||
|
||||
let results = (outs Variadic<AnyTensor>:$output);
|
||||
}
|
||||
|
||||
#endif // TFL_OPS
|
||||
|
@ -0,0 +1,214 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: WHILE,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: GREATER,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: SUB,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1 ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "WhileOp1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1 ],
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "WhileOp2",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 2, 3 ],
|
||||
// CHECK-NEXT: builtin_options_type: WhileOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-NEXT: cond_subgraph_index: 1,
|
||||
// CHECK-NEXT: body_subgraph_index: 2
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 5,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 6,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 7,
|
||||
// CHECK-NEXT: name: "Const",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: BOOL,
|
||||
// CHECK-NEXT: buffer: 8,
|
||||
// CHECK-NEXT: name: "tfl.greater",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: opcode_index: 1,
|
||||
// CHECK-NEXT: inputs: [ 0, 2 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "WhileOp$cond"
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 9,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 10,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 11,
|
||||
// CHECK-NEXT: name: "Const1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: type: INT32,
|
||||
// CHECK-NEXT: buffer: 12,
|
||||
// CHECK-NEXT: name: "tfl.sub",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ ],
|
||||
// CHECK-NEXT: buffer: 13,
|
||||
// CHECK-NEXT: name: "tfl.add",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 3, 4 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: opcode_index: 2,
|
||||
// CHECK-NEXT: inputs: [ 0, 2 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: builtin_options_type: SubOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 3,
|
||||
// CHECK-NEXT: inputs: [ 1, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 4 ],
|
||||
// CHECK-NEXT: builtin_options_type: AddOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "WhileOp$body"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 1, 0, 0, 0 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
func @main(%arg0 : tensor<i32>, %arg1 : tensor<1xf32>) -> tensor<1xf32> {
|
||||
%0:2 = "tfl.while"(%arg0, %arg1) (
|
||||
// cond
|
||||
{
|
||||
^bb0(%condArg0: tensor<*xi32>, %condArg1: tensor<*xf32>):
|
||||
%0 = "std.constant" () {value = dense<0> : tensor<i32>} : () -> tensor<i32> loc("Const")
|
||||
%1 = "tfl.greater"(%condArg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
|
||||
"tfl.yield"(%1) : (tensor<i1>) -> ()
|
||||
},
|
||||
// body
|
||||
{
|
||||
^bb0(%bodyArg0: tensor<*xi32>, %bodyArg1: tensor<*xf32>):
|
||||
%0 = "std.constant" () {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("Const")
|
||||
%1 = "tfl.sub"(%bodyArg0, %0) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%2 = tfl.add %bodyArg1, %bodyArg1 {fused_activation_function = "NONE"} : tensor<*xf32>
|
||||
"tfl.yield"(%1, %2) : (tensor<*xi32>, tensor<*xf32>) -> ()
|
||||
}
|
||||
) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>) loc("WhileOp")
|
||||
return %0#1 : tensor<1xf32>
|
||||
}
|
5
third_party/mlir/BUILD
vendored
5
third_party/mlir/BUILD
vendored
@ -2550,6 +2550,9 @@ exports_files(
|
||||
)
|
||||
|
||||
exports_files(
|
||||
["include/mlir/Analysis/InferTypeOpInterface.td"],
|
||||
[
|
||||
"include/mlir/Analysis/InferTypeOpInterface.td",
|
||||
"include/mlir/Transforms/LoopLikeInterface.td",
|
||||
],
|
||||
visibility = ["@llvm-project//mlir:friends"],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user