Fix conversion failures in the tf-op-legalization pass
This CL fixes the broken fallback path to flex ops when the given operand types are mismatched. PiperOrigin-RevId: 304313364 Change-Id: Idffa1cc34dae15b0b18a68621d05700230b2a4c2
This commit is contained in:
parent
0e531f62a9
commit
79e3d6dad5
|
@ -409,10 +409,14 @@ static void GenOperandResultVerifier(raw_ostream &os,
|
|||
os << " (void)v;\n"
|
||||
<< " if (!("
|
||||
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
|
||||
<< " if (failure_on_operand_type_mismatch) {\n"
|
||||
<< formatv(
|
||||
" return op->emitOpError(\"{0} #\") << index "
|
||||
"<< \" must be {1}, but got \" << v.getType();\n",
|
||||
valueKind, desc)
|
||||
<< " } else {\n"
|
||||
<< " return ::mlir::LogicalResult::Failure;\n"
|
||||
<< " }\n"
|
||||
<< " }\n" // if
|
||||
<< " ++index;\n"
|
||||
<< " }\n"; // for
|
||||
|
@ -437,7 +441,8 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
|||
|
||||
mlir::tblgen::FmtContext verify_ctx;
|
||||
os << "::mlir::LogicalResult " << op.getCppClassName()
|
||||
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
|
||||
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
|
||||
"failure_on_operand_type_mismatch) {\n";
|
||||
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
|
||||
verify_ctx.withOp("top");
|
||||
|
||||
|
|
|
@ -86,7 +86,8 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
|
|||
let methods = [
|
||||
StaticInterfaceMethod<
|
||||
[{Returns whether the op's operands/results are supported by runtime.}],
|
||||
"LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op)
|
||||
"LogicalResult", "VerifyTflRuntimeTypes",
|
||||
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
|
|
@ -829,6 +829,14 @@ func @pack3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2x
|
|||
// CHECK: "tfl.pack"(%arg0, %arg1, %arg2) {axis = 1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @packStringWithFlex(%arg0: tensor<2x!tf.string>, %arg1: tensor<2x!tf.string>) -> tensor<2x2x!tf.string> {
|
||||
%0 = "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string>
|
||||
return %0 : tensor<2x2x!tf.string>
|
||||
|
||||
// CHECK-LABEL: packStringWithFlex
|
||||
// CHECK: "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string>
|
||||
}
|
||||
|
||||
func @packNegAxis(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "tf.Pack"(%arg0, %arg1, %arg2) {axis = -1 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
||||
return %0 : tensor<2x3xi32>
|
||||
|
|
|
@ -745,7 +745,8 @@ void LegalizeTF::runOnFunction() {
|
|||
Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
|
||||
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
|
||||
if (!tfl_op) return false;
|
||||
return succeeded(tfl_op.VerifyTflRuntimeTypes(tfl_op.getOperation()));
|
||||
return succeeded(tfl_op.VerifyTflRuntimeTypes(
|
||||
tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
|
||||
}));
|
||||
// Keep trying to convert.
|
||||
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
|
||||
|
|
|
@ -34,7 +34,9 @@ class RuntimeTypeVerifyPass : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
|
|||
|
||||
void RuntimeTypeVerifyPass::runOnFunction() {
|
||||
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
|
||||
if (failed(op.VerifyTflRuntimeTypes(op.getOperation())))
|
||||
if (failed(op.VerifyTflRuntimeTypes(
|
||||
op.getOperation(),
|
||||
/*failure_on_operand_type_mismatch=*/true)))
|
||||
signalPassFailure();
|
||||
});
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue