Fix conversion failures in the tf-op-legalization pass

This CL fixes the broken fallback path to flex ops when the given operand types are mismatched.

PiperOrigin-RevId: 304313364
Change-Id: Idffa1cc34dae15b0b18a68621d05700230b2a4c2
This commit is contained in:
Jaesung Chung 2020-04-01 19:57:12 -07:00 committed by TensorFlower Gardener
parent 0e531f62a9
commit 79e3d6dad5
5 changed files with 21 additions and 4 deletions

View File

@ -409,10 +409,14 @@ static void GenOperandResultVerifier(raw_ostream &os,
os << " (void)v;\n"
<< " if (!("
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
<< " if (failure_on_operand_type_mismatch) {\n"
<< formatv(
" return op->emitOpError(\"{0} #\") << index "
"<< \" must be {1}, but got \" << v.getType();\n",
valueKind, desc)
<< " } else {\n"
<< " return ::mlir::LogicalResult::Failure;\n"
<< " }\n"
<< " }\n" // if
<< " ++index;\n"
<< " }\n"; // for
@ -437,7 +441,8 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
mlir::tblgen::FmtContext verify_ctx;
os << "::mlir::LogicalResult " << op.getCppClassName()
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n";
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
"failure_on_operand_type_mismatch) {\n";
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
verify_ctx.withOp("top");

View File

@ -86,7 +86,8 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
let methods = [
StaticInterfaceMethod<
[{Returns whether the op's operands/results are supported by runtime.}],
"LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op)
"LogicalResult", "VerifyTflRuntimeTypes",
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
>,
];
}

View File

@ -829,6 +829,14 @@ func @pack3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2x
// CHECK: "tfl.pack"(%arg0, %arg1, %arg2) {axis = 1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
}
func @packStringWithFlex(%arg0: tensor<2x!tf.string>, %arg1: tensor<2x!tf.string>) -> tensor<2x2x!tf.string> {
%0 = "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string>
return %0 : tensor<2x2x!tf.string>
// CHECK-LABEL: packStringWithFlex
// CHECK: "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string>
}
func @packNegAxis(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> {
%0 = "tf.Pack"(%arg0, %arg1, %arg2) {axis = -1 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>

View File

@ -745,7 +745,8 @@ void LegalizeTF::runOnFunction() {
Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
if (!tfl_op) return false;
return succeeded(tfl_op.VerifyTflRuntimeTypes(tfl_op.getOperation()));
return succeeded(tfl_op.VerifyTflRuntimeTypes(
tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
}));
// Keep trying to convert.
// TODO(karimnosseir): This is similar to what apply greedy patterns does.

View File

@ -34,7 +34,9 @@ class RuntimeTypeVerifyPass : public mlir::FunctionPass<RuntimeTypeVerifyPass> {
void RuntimeTypeVerifyPass::runOnFunction() {
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
if (failed(op.VerifyTflRuntimeTypes(op.getOperation())))
if (failed(op.VerifyTflRuntimeTypes(
op.getOperation(),
/*failure_on_operand_type_mismatch=*/true)))
signalPassFailure();
});
}