diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index 83c95c03c8b..bc894d36e75 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -409,10 +409,14 @@ static void GenOperandResultVerifier(raw_ostream &os, os << " (void)v;\n" << " if (!(" << tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n" + << " if (failure_on_operand_type_mismatch) {\n" << formatv( " return op->emitOpError(\"{0} #\") << index " "<< \" must be {1}, but got \" << v.getType();\n", valueKind, desc) + << " } else {\n" + << " return ::mlir::LogicalResult::Failure;\n" + << " }\n" << " }\n" // if << " ++index;\n" << " }\n"; // for @@ -437,7 +441,8 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { mlir::tblgen::FmtContext verify_ctx; os << "::mlir::LogicalResult " << op.getCppClassName() - << "::VerifyTflRuntimeTypes(::mlir::Operation *op) {\n"; + << "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool " + "failure_on_operand_type_mismatch) {\n"; os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n"; verify_ctx.withOp("top"); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td index db0bef39358..b20e81aefa9 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td @@ -86,7 +86,8 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> { let methods = [ StaticInterfaceMethod< [{Returns whether the op's operands/results are supported by runtime.}], - "LogicalResult", "VerifyTflRuntimeTypes", (ins "Operation*":$op) + "LogicalResult", "VerifyTflRuntimeTypes", + (ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch) >, ]; } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 7db46f778fa..7e9b1bdb711 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -829,6 +829,14 @@ func @pack3Tensors(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2x // CHECK: "tfl.pack"(%arg0, %arg1, %arg2) {axis = 1 : i32, values_count = 3 : i32} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32> } +func @packStringWithFlex(%arg0: tensor<2x!tf.string>, %arg1: tensor<2x!tf.string>) -> tensor<2x2x!tf.string> { + %0 = "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string> + return %0 : tensor<2x2x!tf.string> + +// CHECK-LABEL: packStringWithFlex +// CHECK: "tf.Pack"(%arg0, %arg1) : (tensor<2x!tf.string>, tensor<2x!tf.string>) -> tensor<2x2x!tf.string> +} + func @packNegAxis(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2 : tensor<2xi32>) -> tensor<2x3xi32> { %0 = "tf.Pack"(%arg0, %arg1, %arg2) {axis = -1 : i64} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 4d40eec7a1b..98501aaa803 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -745,7 +745,8 @@ void LegalizeTF::runOnFunction() { Optional([](Operation* op) { auto tfl_op = dyn_cast_or_null(op); if (!tfl_op) return false; - return succeeded(tfl_op.VerifyTflRuntimeTypes(tfl_op.getOperation())); + return succeeded(tfl_op.VerifyTflRuntimeTypes( + tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false)); })); // Keep trying to convert. // TODO(karimnosseir): This is similar to what apply greedy patterns does. diff --git a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc b/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc index 92eb7023438..d103209ffd9 100644 --- a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc +++ b/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc @@ -34,7 +34,9 @@ class RuntimeTypeVerifyPass : public mlir::FunctionPass { void RuntimeTypeVerifyPass::runOnFunction() { getFunction().walk([&](TflRuntimeVerifyOpInterface op) { - if (failed(op.VerifyTflRuntimeTypes(op.getOperation()))) + if (failed(op.VerifyTflRuntimeTypes( + op.getOperation(), + /*failure_on_operand_type_mismatch=*/true))) signalPassFailure(); }); }