Do not emit op errors on TF -> TFL legalization
Applying legalization patterns will emit unwanted, transient errors when the replaced TFLite ops do not meet the sanity checks. In order to ignore the transient errors, the following lines override a diagnostic handler with an no-op handler only while this pass runs. PiperOrigin-RevId: 316980278 Change-Id: Idef14e13f36ff0ee3c4bb1a401f92ba217042dbb
This commit is contained in:
parent
b186ba0334
commit
19c51afbf1
tensorflow/compiler/mlir/lite
@ -446,7 +446,7 @@ static void GenOperandResultVerifier(raw_ostream &os,
|
||||
auto desc =
|
||||
definit->getDef()->getValueAsString("tflRuntimeTypeDescription");
|
||||
|
||||
// Emit a loop to check all the dynamic values in the pack.
|
||||
// Emit a loop to check all operands.
|
||||
os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n",
|
||||
// Capitalize the first letter to match the function name
|
||||
valueKind.substr(0, 1).upper(), valueKind.substr(1),
|
||||
@ -455,14 +455,10 @@ static void GenOperandResultVerifier(raw_ostream &os,
|
||||
os << " (void)v;\n"
|
||||
<< " if (!("
|
||||
<< tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n"
|
||||
<< " if (failure_on_operand_type_mismatch) {\n"
|
||||
<< formatv(
|
||||
" return op->emitOpError(\"{0} #\") << index "
|
||||
" return op->emitOpError(\"{0} #\") << index "
|
||||
"<< \" must be {1}, but got \" << v.getType();\n",
|
||||
valueKind, desc)
|
||||
<< " } else {\n"
|
||||
<< " return ::mlir::LogicalResult::Failure;\n"
|
||||
<< " }\n"
|
||||
<< " }\n" // if
|
||||
<< " ++index;\n"
|
||||
<< " }\n"; // for
|
||||
@ -487,8 +483,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||
|
||||
mlir::tblgen::FmtContext verify_ctx;
|
||||
os << "::mlir::LogicalResult " << op.getCppClassName()
|
||||
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
|
||||
"failure_on_operand_type_mismatch) {\n";
|
||||
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op) {\n";
|
||||
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
|
||||
verify_ctx.withOp("top");
|
||||
|
||||
@ -529,11 +524,8 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
||||
|
||||
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
|
||||
os << tgfmt(
|
||||
" if (!($0)) {\n "
|
||||
" if (failure_on_operand_type_mismatch) {\n"
|
||||
" return top.emitOpError(\"failed to verify that $1\");\n"
|
||||
" } else {\n"
|
||||
" return ::mlir::LogicalResult::Failure;\n }\n }\n",
|
||||
" if (!($0))\n"
|
||||
" return top.emitOpError(\"failed to verify that $1\");\n",
|
||||
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx), desc);
|
||||
}
|
||||
os << " return top.verify();\n}\n";
|
||||
|
@ -94,8 +94,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
|
||||
let methods = [
|
||||
StaticInterfaceMethod<
|
||||
[{Returns whether the op's operands/results are supported by runtime.}],
|
||||
"LogicalResult", "VerifyTflRuntimeConstraints",
|
||||
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
|
||||
"LogicalResult", "VerifyTflRuntimeConstraints", (ins "Operation*":$op)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
@ -990,6 +990,13 @@ func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2:
|
||||
// CHECK: "tfl.batch_to_space_nd"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @batch_to_space_nd_unsupported(%arg0: tensor<?x1x1x1x4xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3x2xi32>) -> tensor<?x3x3x3x4xf32> {
|
||||
%0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<?x1x1x1x4xf32>, tensor<3xi32>, tensor<3x2xi32>) -> tensor<?x3x3x3x4xf32>
|
||||
return %0 : tensor<?x3x3x3x4xf32>
|
||||
// CHECK-LABEL: batch_to_space_nd_unsupported
|
||||
// CHECK: "tf.BatchToSpaceND"
|
||||
}
|
||||
|
||||
func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
@ -28,9 +28,11 @@ limitations under the License.
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Threading.h"
|
||||
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
@ -767,13 +769,26 @@ void LegalizeTF::runOnFunction() {
|
||||
[](Operation* op) {
|
||||
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
|
||||
if (!tfl_op) return false;
|
||||
return succeeded(tfl_op.VerifyTflRuntimeConstraints(
|
||||
tfl_op.getOperation(),
|
||||
/*failure_on_operand_type_mismatch=*/false));
|
||||
return succeeded(tfl_op.VerifyTflRuntimeConstraints(op));
|
||||
}));
|
||||
} else {
|
||||
target.addLegalDialect<TensorFlowLiteDialect>();
|
||||
}
|
||||
|
||||
// Ignore transient errors by registering an no-op handler.
|
||||
// Applying legalization patterns will emit unwanted, transient errors when
|
||||
// the replaced TFLite ops do not meet the sanity checks. In order to ignore
|
||||
// the transient errors, the following lines override a diagnostic handler
|
||||
// with an no-op handler only while this pass runs.
|
||||
uint64_t current_thread_id = llvm::get_threadid();
|
||||
ScopedDiagnosticHandler scoped_diag_handler(
|
||||
context, [¤t_thread_id](Diagnostic&) -> LogicalResult {
|
||||
// Consume only errors that are coming from the same thread in order not
|
||||
// to ignore errors from other passes that are running. Things running
|
||||
// in the pass manager can be multi-threaded.
|
||||
return success(current_thread_id == llvm::get_threadid());
|
||||
});
|
||||
|
||||
// Keep trying to convert.
|
||||
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
|
||||
// Look if there is a function that tries until it converge.
|
||||
|
@ -34,8 +34,7 @@ class RuntimeVerifyPass
|
||||
|
||||
void RuntimeVerifyPass::runOnFunction() {
|
||||
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
|
||||
if (failed(op.VerifyTflRuntimeConstraints(
|
||||
op.getOperation(), /*failure_on_operand_type_mismatch=*/true)))
|
||||
if (failed(op.VerifyTflRuntimeConstraints(op.getOperation())))
|
||||
signalPassFailure();
|
||||
});
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user