Fix conversion failures in the tf-op-legalization pass
This CL fixes the broken fallback path to flex ops when the given operand types are mismatched. PiperOrigin-RevId: 304313364 Change-Id: Idffa1cc34dae15b0b18a68621d05700230b2a4c2
This commit is contained in:
parent
0e531f62a9
commit
79e3d6dad5
|
@ -409,10 +409,14 @@ static void GenOperandResultVerifier(raw_ostream &os,
|
||||||
os << " (void)v;\n"
|
os << " (void)v;\n"
|
||||||
<< " if (!("
|
<< " if (!("
|
||||||
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
|
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
|
||||||
|
<< " if (failure_on_operand_type_mismatch) {\n"
|
||||||
<< formatv(
|
<< formatv(
|
||||||
" return op->emitOpError(\"{0} #\") << index "
|
" return op->emitOpError(\"{0} #\") << index "
|
||||||
"<< \" must be {1}, but got \" << v.getType();\n",
|
"<< \" must be {1}, but got \" << v.getType();\n",
|
||||||
valueKind, desc)
|
valueKind, desc)
|
||||||
|
<< " } else {\n"
|
||||||
|
<< " return ::mlir::LogicalResult::Failure;\n"
|
||||||
|
<< " }\n"
|
||||||
<< " }\n" // if
|
<< " }\n" // if
|
||||||
<< " ++index;\n"
|
<< " ++index;\n"
|
||||||
<< " }\n"; // for
|
<< " }\n"; // for
|
||||||
|
@ -437,7 +441,8 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||||
|
|
||||||
mlir::tblgen::FmtContext verify_ctx;
|
mlir::tblgen::FmtContext verify_ctx;
|
||||||
os << "::mlir::LogicalResult " << op.getCppClassName()
|
os << "::mlir::LogicalResult " << op.getCppClassName()
|
||||||
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
|
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
|
||||||
|
"failure_on_operand_type_mismatch) {\n";
|
||||||
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
|
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
|
||||||
verify_ctx.withOp("top");
|
verify_ctx.withOp("top");
|
||||||
|
|
||||||
|
|
|
@ -86,7 +86,8 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
|
||||||
let methods = [
|
let methods = [
|
||||||
StaticInterfaceMethod<
|
StaticInterfaceMethod<
|
||||||
[{Returns whether the op's operands/results are supported by runtime.}],
|
[{Returns whether the op's operands/results are supported by runtime.}],
|
||||||
"LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op)
|
"LogicalResult", "VerifyTflRuntimeTypes",
|
||||||
|
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
|
||||||
>,
|
>,
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
|
@ -829,6 +829,14 @@ func @pack3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2x
|
||||||
// CHECK: "tfl.pack"(%arg0, %arg1, %arg2) {axis = 1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
// CHECK: "tfl.pack"(%arg0, %arg1, %arg2) {axis = 1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @packStringWithFlex(%arg0: tensor<2x!tf.string>, %arg1: tensor<2x!tf.string>) -> tensor<2x2x!tf.string> {
|
||||||
|
%0 = "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string>
|
||||||
|
return %0 : tensor<2x2x!tf.string>
|
||||||
|
|
||||||
|
// CHECK-LABEL: packStringWithFlex
|
||||||
|
// CHECK: "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string>
|
||||||
|
}
|
||||||
|
|
||||||
func @packNegAxis(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> {
|
func @packNegAxis(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> {
|
||||||
%0 = "tf.Pack"(%arg0, %arg1, %arg2) {axis = -1 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
%0 = "tf.Pack"(%arg0, %arg1, %arg2) {axis = -1 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
|
||||||
return %0 : tensor<2x3xi32>
|
return %0 : tensor<2x3xi32>
|
||||||
|
|
|
@ -745,7 +745,8 @@ void LegalizeTF::runOnFunction() {
|
||||||
Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
|
Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
|
||||||
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
|
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
|
||||||
if (!tfl_op) return false;
|
if (!tfl_op) return false;
|
||||||
return succeeded(tfl_op.VerifyTflRuntimeTypes(tfl_op.getOperation()));
|
return succeeded(tfl_op.VerifyTflRuntimeTypes(
|
||||||
|
tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
|
||||||
}));
|
}));
|
||||||
// Keep trying to convert.
|
// Keep trying to convert.
|
||||||
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
|
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
|
||||||
|
|
|
@ -34,7 +34,9 @@ class RuntimeTypeVerifyPass : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
|
||||||
|
|
||||||
void RuntimeTypeVerifyPass::runOnFunction() {
|
void RuntimeTypeVerifyPass::runOnFunction() {
|
||||||
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
|
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
|
||||||
if (failed(op.VerifyTflRuntimeTypes(op.getOperation())))
|
if (failed(op.VerifyTflRuntimeTypes(
|
||||||
|
op.getOperation(),
|
||||||
|
/*failure_on_operand_type_mismatch=*/true)))
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue