diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 789d06b8ac9..697b161c16d 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -307,7 +307,7 @@ cc_library(
         "transforms/optimize_functional_ops.cc",
         "transforms/prepare_composite_functions_tf.cc",
         "transforms/prepare_tf.cc",
-        "transforms/runtime_type_verify.cc",
+        "transforms/runtime_verify.cc",
         "transforms/split_merged_operands.cc",
         "transforms/trim_functions_tf.cc",
         "transforms/while_loop_outline.cc",
diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
index 82d058964cb..2ed63fcc794 100644
--- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
+++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
@@ -36,7 +36,8 @@ struct PassConfig {
         form_clusters(false),
         unfold_batch_matmul(true),
         legalize_tf_while(true),
-        shape_inference(true) {}
+        shape_inference(true),
+        runtime_verification(true) {}
 
   // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
   // added, which produces TF Lite ops.
@@ -65,6 +66,8 @@ struct PassConfig {
   bool legalize_tf_while;
   // Whether to do shape inference.
   bool shape_inference;
+  // Whether to do TFLite runtime verification.
+  bool runtime_verification;
 };
 
 }  // namespace TFL
diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc
index bc894d36e75..f1ed97cb8e7 100644
--- a/tensorflow/compiler/mlir/lite/converter_gen.cc
+++ b/tensorflow/compiler/mlir/lite/converter_gen.cc
@@ -441,7 +441,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
     mlir::tblgen::FmtContext verify_ctx;
     os << "::mlir::LogicalResult " << op.getCppClassName()
-       << "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
+       << "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
          "failure_on_operand_type_mismatch) {\n";
     os << "  auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
     verify_ctx.withOp("top");
@@ -466,6 +466,25 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
                              "operand");
     GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
                              "result");
+
+    for (auto &trait : op.getTraits()) {
+      if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) {
+        continue;
+      }
+      if (trait.getDef().getValueAsString("trait") !=
+          "OpTrait::TFLRuntimeOpTrait") {
+        continue;
+      }
+
+      auto *val = trait.getDef().getValue("tflRuntimePredicate");
+      if (!val) continue;
+
+      mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
+      os << tgfmt(
+          "  if (!($0)) {\n    "
+          "  return ::mlir::LogicalResult::Failure;\n  }\n",
+          &verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
+    }
     os << "  return top.verify();\n}\n";
   }
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
index b20e81aefa9..ccad3cbb79e 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
@@ -86,7 +86,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
   let methods = [
     StaticInterfaceMethod<
      [{Returns whether the op's operands/results are supported by runtime.}],
-      "LogicalResult", "VerifyTflRuntimeTypes",
+      "LogicalResult", "VerifyTflRuntimeConstraints",
      (ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
    >,
  ];
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 45efe8f72f7..dc47c1efd29 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -46,6 +46,30 @@ namespace mlir {
 #include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
 namespace TFL {
 
+// Returns true when the two given types have the same shape or a
+// broadcastable shape within the given rank. If either of the given shapes
+// is non-static, this method returns true.
+bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
+                                                        int max_bcast_rank) {
+  // Ignore shape checking on the non-static shapes for model compatibility.
+  auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
+  if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
+  auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
+  if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
+
+  if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
+    return true;
+
+  SmallVector<int64_t, 4> result_shape;
+  if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
+                                          rhs_shaped_type.getShape(),
+                                          result_shape)) {
+    return false;
+  }
+  return lhs_shaped_type.getRank() <= max_bcast_rank &&
+         rhs_shaped_type.getRank() <= max_bcast_rank;
+}
+
 //===----------------------------------------------------------------------===//
 // TensorFlowLiteDialect
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 2bf6ca2ab89..cb1f8c63b4c 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -106,6 +106,22 @@ class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
 class DerivedTFLiteTypeAttr<code body> : DerivedAttr<"tflite::TensorType", body>;
 
+// TFL Runtime op trait predicate.
+class TFL_RuntimePredOpTrait<string desc, Pred pred> :
+    GenInternalOpTrait<"TFLRuntimeOpTrait"> {
+  Pred tflRuntimePredicate = pred;
+  string tflRuntimeDescription = desc;
+}
+
+class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
+    int i, int j, int max_bcast_rank> :
+  TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
+      " have the same shape or broadcastable shapes within the rank " #
+      max_bcast_rank,
+    CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
+      "$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
+      ").getType(), " # max_bcast_rank # ")">>;
+
 // These additional types/type constraints here are used to decouple the ops
 // from runtime support for the ops. Prefer to use these types when defining
 // new TF_Ops for uniformity.
@@ -360,10 +376,9 @@ an output element, this operation computes \\(y = |x|\\).
   let hasFolder = 1;
 }
 
-def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
-                               NoSideEffect,
-                               Commutative,
-                               TFL_GpuTargetOp]> {
+def TFL_AddOp : TFL_Op<"add", [
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>,
+    ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> {
   let summary = "Addition operator";
 
   let description = [{
@@ -371,11 +386,11 @@ def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
   }];
 
   let arguments = (
-    ins AnyTensor:$lhs,
-    AnyTensor:$rhs,
+    ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
+    TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
     TFL_AFAttr:$fused_activation_function);
 
-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
 
   let hasFolder = 1;
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
index a17cdda2a39..6dd44e666fb 100644
--- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
+++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
@@ -285,7 +285,7 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
   if (pass_config.legalize_tf_while) {
     pm.addPass(mlir::TFL::CreateWhileOutlinePass());
   }
-  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
 
   auto status = ConvertTFExecutorToTFLOrFlatbuffer(
       module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index 7e9b1bdb711..239f4537920 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -9,6 +9,20 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
 // CHECK: return
 }
 
+// CHECK-LABEL: testAddHighDimsHaveSameShape
+func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
+  // CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
+  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6x7x8xi32>, tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32>
+  return %0 : tensor<1x2x3x4x5x6x7x8xi32>
+}
+
+// CHECK-LABEL: testAddTooHighBroadcastableDims
+func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+  // expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 5}}
+  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+  return %0 : tensor<1x2x3x4x5x6xi32>
+}
+
 func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
   %2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
   return %2: tensor<1xf32>
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
index bff289d004a..2cb269b5a3d 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
@@ -173,7 +173,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
     pass_manager->addPass(
         mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
     pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
-    pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
+    pass_manager->addPass(
+        mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
     pass_manager->addPass(mlir::TFL::CreateOptimizePass());
     // This pass operates on TensorFlow ops but is triggered after legalization
     // so that it can target constants introduced once TensorFlow Identity ops
@@ -255,7 +256,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
   // TFLite dialect passes.
   pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
   pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::TFL::CreateLegalizeTFPass());
+  pm.addPass(
+      mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
   pm.addPass(mlir::TFL::CreateOptimizePass());
   pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
 
@@ -268,7 +270,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
 
   pm.addPass(mlir::TFL::CreateWhileOutlinePass());
 
-  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
 }
 
 // Registers a pass pipeline for the standard TFL passes.
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
index 038adebabef..35f9b24f807 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
@@ -214,7 +214,7 @@ int main(int argc, char **argv) {
   if (pass_config.legalize_tf_while) {
     pm.addPass(mlir::TFL::CreateWhileOutlinePass());
   }
-  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
 
   std::string result;
   auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index 6a50ad4fce0..90549ef8264 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -70,8 +70,21 @@ constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
 constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
 
 // Legalize operations in functions.
-struct LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
+class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
+ public:
+  LegalizeTF() = default;
+  LegalizeTF(const LegalizeTF&) {}
+  explicit LegalizeTF(bool run_tfl_runtime_verification) {
+    run_tfl_runtime_verification_ = run_tfl_runtime_verification;
+  }
+
+  /// Performs the lowering to TFLite dialect.
   void runOnFunction() override;
+
+ private:
+  Option<bool> run_tfl_runtime_verification_{
+      *this, "run-tfl-runtime-verification",
+      llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)};
 };
 
 // Returns true if all tensor value in `values` has static shape and same shape.
@@ -741,13 +754,19 @@ void LegalizeTF::runOnFunction() {
   // graph.
   target.addLegalOp<mlir::ConstantOp>();
   target.addLegalOp<ConstOp>();
-  target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
-      Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
-        auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
-        if (!tfl_op) return false;
-        return succeeded(tfl_op.VerifyTflRuntimeTypes(
-            tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
-      }));
+  if (run_tfl_runtime_verification_) {
+    target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
+        Optional<ConversionTarget::DynamicLegalityCallbackFn>(
+            [](Operation* op) {
+              auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
+              if (!tfl_op) return false;
+              return succeeded(tfl_op.VerifyTflRuntimeConstraints(
+                  tfl_op.getOperation(),
+                  /*failure_on_operand_type_mismatch=*/false));
+            }));
+  } else {
+    target.addLegalDialect<TensorFlowLiteDialect>();
+  }
 
   // Keep trying to convert.
   // TODO(karimnosseir): This is similar to what apply greedy patterns does.
   // Look if there is a function that tries until it converge.
@@ -763,8 +782,9 @@ void LegalizeTF::runOnFunction() {
 
 }  // namespace
 
 // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
-std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass() {
-  return std::make_unique<LegalizeTF>();
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
+    bool run_tfl_runtime_verification) {
+  return std::make_unique<LegalizeTF>(run_tfl_runtime_verification);
 }
 
 static PassRegistration<LegalizeTF> pass(
diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h
index c86ac567661..a744a570929 100644
--- a/tensorflow/compiler/mlir/lite/transforms/passes.h
+++ b/tensorflow/compiler/mlir/lite/transforms/passes.h
@@ -30,7 +30,11 @@ namespace TFL {
 class QuantizationSpecs;
 
 // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
-std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass();
+// When `run_tfl_runtime_verification` is true, each TFL builtin op is checked
+// against the TFL runtime's capabilities, and TF ops that are incompatible
+// with the runtime are left in the graph without being legalized.
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
+    bool run_tfl_runtime_verification);
 
 // Creates an instance of the TensorFlow Lite dialect Optimize pass.
 std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
@@ -91,8 +95,8 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFWhilePass();
 // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
 std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
 
-// Verifies runtime supports types used.
-std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass();
+// Verifies runtime constraints.
+std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
 
 }  // namespace TFL
diff --git a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc b/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc
similarity index 63%
rename from tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc
rename to tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc
index 3cb26a5a11e..3268329b1c1 100644
--- a/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc
@@ -22,34 +22,32 @@ namespace mlir {
 namespace TFL {
 namespace {
 
-// This pass verifies that the operands and results types are supported by
-// TFLite runtime.
-class RuntimeTypeVerifyPass
-    : public mlir::PassWrapper<RuntimeTypeVerifyPass, FunctionPass> {
+// This pass verifies that the TFL ops meet the TFL runtime constraints.
+class RuntimeVerifyPass
+    : public mlir::PassWrapper<RuntimeVerifyPass, FunctionPass> {
  public:
-  explicit RuntimeTypeVerifyPass() {}
+  explicit RuntimeVerifyPass() {}
 
  private:
   void runOnFunction() override;
 };
 
-void RuntimeTypeVerifyPass::runOnFunction() {
+void RuntimeVerifyPass::runOnFunction() {
   getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
-    if (failed(op.VerifyTflRuntimeTypes(
-            op.getOperation(),
-            /*failure_on_operand_type_mismatch=*/true)))
+    if (failed(op.VerifyTflRuntimeConstraints(
+            op.getOperation(), /*failure_on_operand_type_mismatch=*/true)))
       signalPassFailure();
   });
 }
 }  // namespace
 
-// Verifies runtime supports types used.
-std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass() {
-  return std::make_unique<RuntimeTypeVerifyPass>();
+// Verifies TFL runtime constraints.
+std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass() {
+  return std::make_unique<RuntimeVerifyPass>();
 }
 
-static PassRegistration<RuntimeTypeVerifyPass> pass(
-    "tfl-runtime-verify", "TFLite runtime verification");
+static PassRegistration<RuntimeVerifyPass> pass("tfl-runtime-verify",
+                                                "TFLite runtime verification");
 
 }  // namespace TFL
 }  // namespace mlir
diff --git a/tensorflow/lite/testing/op_tests/hardswish.py b/tensorflow/lite/testing/op_tests/hardswish.py
index 2816fe5bd64..97dad804f3b 100644
--- a/tensorflow/lite/testing/op_tests/hardswish.py
+++ b/tensorflow/lite/testing/op_tests/hardswish.py
@@ -48,10 +48,17 @@ def make_hardswish_tests(options):
   """Make a set of tests to do hardswish."""
 
   # Chose a set of parameters
-  test_parameters = [{
-      "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
-                      [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
-  }]
+  if options.run_with_flex:
+    # Only the Flex delegate can execute on data with more than four dimensions.
+    test_parameters = [{
+        "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
+                        [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+    }]
+  else:
+    test_parameters = [{
+        "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
+                        [3, 15, 14, 3]],
+    }]
 
   def build_graph(parameters):
     inp = tf.compat.v1.placeholder(
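
Note on the new shape constraint: the helper added in tfl_ops.cc delegates the compatibility check to `OpTrait::util::getBroadcastedShape` and then bounds both operand ranks by `max_bcast_rank`. The standalone sketch below mirrors that numpy-style broadcast rule outside of MLIR; `GetBroadcastedShape` and the shapes in `main` are hypothetical illustration names, not part of this patch.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Numpy-style broadcast compatibility: align shapes at the trailing
// dimension; each aligned pair of dimensions must be equal or contain a 1.
// Returns true and fills `result` with the broadcasted shape on success.
bool GetBroadcastedShape(const std::vector<int64_t>& lhs,
                         const std::vector<int64_t>& rhs,
                         std::vector<int64_t>* result) {
  const size_t rank = std::max(lhs.size(), rhs.size());
  result->assign(rank, 1);
  for (size_t i = 0; i < rank; ++i) {
    // Dimensions missing on the left are treated as 1.
    const int64_t l = i < rank - lhs.size() ? 1 : lhs[i - (rank - lhs.size())];
    const int64_t r = i < rank - rhs.size() ? 1 : rhs[i - (rank - rhs.size())];
    if (l != r && l != 1 && r != 1) return false;  // Incompatible pair.
    (*result)[i] = std::max(l, r);
  }
  return true;
}

int main() {
  std::vector<int64_t> shape;
  // Mirrors testAddTooHighBroadcastableDims: the shapes broadcast cleanly to
  // 1x2x3x4x5x6, but their rank (6) exceeds the max_bcast_rank of 5 given to
  // tfl.add, so the runtime trait would still reject the op.
  const bool compatible =
      GetBroadcastedShape({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 1}, &shape);
  std::cout << (compatible ? "broadcastable" : "incompatible")
            << ", rank = " << shape.size() << "\n";
  return 0;
}
```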