diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 697b161c16d..789d06b8ac9 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -307,7 +307,7 @@ cc_library(
         "transforms/optimize_functional_ops.cc",
         "transforms/prepare_composite_functions_tf.cc",
         "transforms/prepare_tf.cc",
-        "transforms/runtime_verify.cc",
+        "transforms/runtime_type_verify.cc",
         "transforms/split_merged_operands.cc",
         "transforms/trim_functions_tf.cc",
         "transforms/while_loop_outline.cc",
diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
index 2ed63fcc794..82d058964cb 100644
--- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
+++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
@@ -36,8 +36,7 @@ struct PassConfig {
         form_clusters(false),
         unfold_batch_matmul(true),
         legalize_tf_while(true),
-        shape_inference(true),
-        runtime_verification(true) {}
+        shape_inference(true) {}

   // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
   // added, which produces TF Lite ops.
@@ -66,8 +65,6 @@ struct PassConfig {
   bool legalize_tf_while;
   // Whether to do shape inference.
   bool shape_inference;
-  // Whether to do TFLite runtime verification.
-  bool runtime_verification;
 };

 }  // namespace TFL
diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc
index f1ed97cb8e7..bc894d36e75 100644
--- a/tensorflow/compiler/mlir/lite/converter_gen.cc
+++ b/tensorflow/compiler/mlir/lite/converter_gen.cc
@@ -441,7 +441,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
     mlir::tblgen::FmtContext verify_ctx;
     os << "::mlir::LogicalResult " << op.getCppClassName()
-       << "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
+       << "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
          "failure_on_operand_type_mismatch) {\n";
     os << "  auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
     verify_ctx.withOp("top");
@@ -466,25 +466,6 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
                              "operand");
     GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
                              "result");
-
-    for (auto &trait : op.getTraits()) {
-      if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) {
-        continue;
-      }
-      if (trait.getDef().getValueAsString("trait") !=
-          "OpTrait::TFLRuntimeOpTrait") {
-        continue;
-      }
-
-      auto *val = trait.getDef().getValue("tflRuntimePredicate");
-      if (!val) continue;
-
-      mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
-      os << tgfmt(
-          "  if (!($0)) {\n    "
-          "    return ::mlir::LogicalResult::Failure;\n  }\n",
-          &verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
-    }
     os << "  return top.verify();\n}\n";
   }
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
index ccad3cbb79e..b20e81aefa9 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td
@@ -86,7 +86,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
   let methods = [
     StaticInterfaceMethod<
       [{Returns whether the op's operands/results are supported by runtime.}],
-      "LogicalResult", "VerifyTflRuntimeConstraints",
+      "LogicalResult", "VerifyTflRuntimeTypes",
      (ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
    >,
  ];
}
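Taken together, the converter_gen.cc and tfl_op_interfaces.td changes above rename the generated entry point and stop emitting trait-based predicate checks, so the TableGen backend now produces only operand/result type checks plus the standard verifier. As a rough sketch of the emitted shape, for a hypothetical `FooOp` (illustrative only; the type predicate here is invented, and real output depends on each op's TableGen operand and result definitions):

```cpp
// Sketch of generated output after this change. Not actual generated code.
::mlir::LogicalResult FooOp::VerifyTflRuntimeTypes(
    ::mlir::Operation *op, bool failure_on_operand_type_mismatch) {
  auto top = cast<FooOp>(op);
  (void)top;
  // Checks emitted by GenOperandResultVerifier: on a type mismatch they
  // either raise an op error or fail silently, depending on the flag.
  for (::mlir::Value operand : op->getOperands()) {
    if (!operand.getType().isa<::mlir::TensorType>()) {  // invented predicate
      if (failure_on_operand_type_mismatch)
        return top.emitOpError("operand must be a tensor");
      return ::mlir::LogicalResult::Failure;
    }
  }
  // The TFLRuntimeOpTrait predicate block is no longer appended here, so
  // control falls through to the op's normal verifier.
  return top.verify();
}
```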
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index dc47c1efd29..45efe8f72f7 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -46,30 +46,6 @@ namespace mlir {
 #include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
 namespace TFL {

-// Returns true when the given two types have the same shape or broadcastable
-// shape within the given rank. If any given shapes are non-static, this method
-// returns true.
-bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
-                                                        int max_bcast_rank) {
-  // Ignore shape checking on the non-static shapes for model compatibility.
-  auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
-  if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
-  auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
-  if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
-
-  if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
-    return true;
-
-  SmallVector<int64_t, 4> result_shape;
-  if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
-                                          rhs_shaped_type.getShape(),
-                                          result_shape)) {
-    return false;
-  }
-  return lhs_shaped_type.getRank() <= max_bcast_rank &&
-         rhs_shaped_type.getRank() <= max_bcast_rank;
-}
-
 //===----------------------------------------------------------------------===//
 // TensorFlowLiteDialect
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index d3726225f2a..2bf6ca2ab89 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -106,22 +106,6 @@ class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
 class DerivedTFLiteTypeAttr<code body> : DerivedAttr<"tflite::TensorType", body>;

-// TFL Runtime op trait predicate.
-class TFL_RuntimePredOpTrait<string desc, Pred pred> :
-    GenInternalOpTrait<"TFLRuntimeOpTrait"> {
-  Pred tflRuntimePredicate = pred;
-  string tflRuntimeDescription = desc;
-}
-
-class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
-    int i, int j, int max_bcast_rank> :
-  TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
-      " have the same shape or broadcastable shapes within the rank " #
-      max_bcast_rank,
-    CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
-            "$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
-            ").getType(), " # max_bcast_rank # ")">>;
-
 // These additional types/type constraints here are used to decouple the ops
 // from runtime support for the ops. Prefer to use these types when defining
 // new TF_Ops for uniformity.
@@ -376,9 +360,10 @@ an output element, this operation computes \\(y = |x|\\).

   let hasFolder = 1;
 }

-def TFL_AddOp : TFL_Op<"add", [
-    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
-    ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> {
+def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
+                               NoSideEffect,
+                               Commutative,
+                               TFL_GpuTargetOp]> {
   let summary = "Addition operator";

   let description = [{
@@ -386,11 +371,11 @@ def TFL_AddOp : TFL_Op<"add", [
   }];

   let arguments = (
-    ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
-    TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
+    ins AnyTensor:$lhs,
+    AnyTensor:$rhs,
     TFL_AFAttr:$fused_activation_function);

-  let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
+  let results = (outs AnyTensor:$output);

   let hasFolder = 1;
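To make the deleted constraint concrete: it only rejected an operand pair when both shapes were fully static, unequal, and either non-broadcastable or broadcastable at a rank above the cap. The following standalone sketch restates that decision logic over plain shape vectors (it mirrors the removed MLIR helper but is not TFLite code; dynamic shapes, which the original accepted unconditionally, are omitted):

```cpp
#include <cstdint>
#include <vector>

// Standalone restatement of the removed
// IsBinaryOperandsHaveSameShapesOrBroadcastableShape semantics for
// fully static shapes.
bool SameShapesOrBroadcastableWithinRank(const std::vector<int64_t>& lhs,
                                         const std::vector<int64_t>& rhs,
                                         int max_bcast_rank) {
  if (lhs == rhs) return true;  // identical shapes pass at any rank
  // NumPy-style broadcast compatibility, checked from the trailing dims.
  auto l = lhs.rbegin();
  auto r = rhs.rbegin();
  for (; l != lhs.rend() && r != rhs.rend(); ++l, ++r) {
    if (*l != *r && *l != 1 && *r != 1) return false;
  }
  // Broadcasting is only tolerated up to the rank cap (4 for tfl.add).
  return static_cast<int>(lhs.size()) <= max_bcast_rank &&
         static_cast<int>(rhs.size()) <= max_bcast_rank;
}
```

This is why, in the legalize-tf.mlir tests removed further down, `testAddHighDimsHaveSameShape` legalized at rank 8 (equal shapes short-circuit the rank cap) while `testAddTooHighBroadcastableDims` was rejected: `{1,2,3,4,5,6}` and `{1,2,3,4,5,1}` broadcast, but both ranks exceed 4.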
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
index 6dd44e666fb..a17cdda2a39 100644
--- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
+++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
@@ -285,7 +285,7 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
   if (pass_config.legalize_tf_while) {
     pm.addPass(mlir::TFL::CreateWhileOutlinePass());
   }
-  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
+  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());

   auto status = ConvertTFExecutorToTFLOrFlatbuffer(
       module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index 239f4537920..7e9b1bdb711 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -9,20 +9,6 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
   // CHECK: return
 }

-// CHECK-LABEL: testAddHighDimsHaveSameShape
-func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
-  // CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
-  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6x7x8xi32>, tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32>
-  return %0 : tensor<1x2x3x4x5x6x7x8xi32>
-}
-
-// CHECK-LABEL: testAddTooHighBroadcastableDims
-func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
-  // expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
-  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
-  return %0 : tensor<1x2x3x4x5x6xi32>
-}
-
 func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
   %2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
   return %2: tensor<1xf32>
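With the flag gone from `PassConfig`, every call site now wires the conversion tail the same way. A minimal sketch of that wiring, mirroring the call sites in this patch rather than reproducing any one of them exactly (assumes the usual pass registrations; error handling elided):

```cpp
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

// Sketch: the tail of a TF-to-TFLite pipeline after this change.
void AddConversionTail(mlir::PassManager& pm) {
  pm.addPass(mlir::TFL::CreateLegalizeTFPass());  // no bool argument anymore
  pm.addPass(mlir::TFL::CreateOptimizePass());
  pm.addPass(mlir::TFL::CreateWhileOutlinePass());
  // Hard type verification now happens only in this final pass.
  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
}
```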
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
index 2cb269b5a3d..bff289d004a 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
@@ -173,8 +173,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
     pass_manager->addPass(
         mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
     pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
-    pass_manager->addPass(
-        mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
+    pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
     pass_manager->addPass(mlir::TFL::CreateOptimizePass());
     // This pass operates on TensorFlow ops but is triggered after legalization
     // so that it can target constants introduced once TensorFlow Identity ops
@@ -256,8 +255,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
   // TFLite dialect passes.
   pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
   pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
-  pm.addPass(
-      mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
+  pm.addPass(mlir::TFL::CreateLegalizeTFPass());
   pm.addPass(mlir::TFL::CreateOptimizePass());
   pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());

@@ -270,7 +268,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,

   pm.addPass(mlir::TFL::CreateWhileOutlinePass());

-  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
+  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
 }

 // Registers a pass pipeline for the standard TFL passes.
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
index 35f9b24f807..038adebabef 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
@@ -214,7 +214,7 @@ int main(int argc, char **argv) {
   if (pass_config.legalize_tf_while) {
     pm.addPass(mlir::TFL::CreateWhileOutlinePass());
   }
-  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
+  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());

   std::string result;
   auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
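The legalize_tf.cc hunk below hard-wires what used to be flag-guarded: the conversion target marks a TFL op legal only when `VerifyTflRuntimeTypes` accepts its types, so (per the comment removed from passes.h) incompatible TF ops are simply left in the graph rather than erroring out. Deleting the `Option<bool>` member also drops the pass's `run-tfl-runtime-verification` command-line option. The general shape of the mechanism, as a sketch (the wrapper function is invented; dialect, interface, and callback body follow the diff):

```cpp
#include "llvm/ADT/Optional.h"
#include "mlir/Transforms/DialectConversion.h"         // mlir::ConversionTarget
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"  // dialect + interface

// Sketch of the dynamic-legality setup used by LegalizeTF below. Legality
// is re-queried as the conversion driver rewrites the function, so ops
// whose types the runtime rejects stay "illegal" and keep being retried.
void ConfigureTarget(mlir::ConversionTarget& target) {
  target.addDynamicallyLegalDialect<mlir::TFL::TensorFlowLiteDialect>(
      llvm::Optional<mlir::ConversionTarget::DynamicLegalityCallbackFn>(
          [](mlir::Operation* op) {
            auto tfl_op =
                llvm::dyn_cast_or_null<mlir::TFL::TflRuntimeVerifyOpInterface>(
                    op);
            if (!tfl_op) return false;
            return mlir::succeeded(tfl_op.VerifyTflRuntimeTypes(
                tfl_op.getOperation(),
                /*failure_on_operand_type_mismatch=*/false));
          }));
}
```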
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index 90549ef8264..6a50ad4fce0 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -70,21 +70,8 @@ constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
 constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";

 // Legalize operations in functions.
-class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
- public:
-  LegalizeTF() = default;
-  LegalizeTF(const LegalizeTF&) {}
-  explicit LegalizeTF(bool run_tfl_runtime_verification) {
-    run_tfl_runtime_verification_ = run_tfl_runtime_verification;
-  }
-
-  /// Performs the lowering to TFLite dialect.
+struct LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
   void runOnFunction() override;
-
- private:
-  Option<bool> run_tfl_runtime_verification_{
-      *this, "run-tfl-runtime-verification",
-      llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)};
 };

 // Returns true if all tensor value in `values` has static shape and same shape.
@@ -754,19 +741,13 @@ void LegalizeTF::runOnFunction() {
   // graph.
   target.addLegalOp<mlir::ConstantOp>();
   target.addLegalOp<ConstOp>();
-  if (run_tfl_runtime_verification_) {
-    target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
-        Optional<ConversionTarget::DynamicLegalityCallbackFn>(
-            [](Operation* op) {
-              auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
-              if (!tfl_op) return false;
-              return succeeded(tfl_op.VerifyTflRuntimeConstraints(
-                  tfl_op.getOperation(),
-                  /*failure_on_operand_type_mismatch=*/false));
-            }));
-  } else {
-    target.addLegalDialect<TensorFlowLiteDialect>();
-  }
+  target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
+      Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
+        auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
+        if (!tfl_op) return false;
+        return succeeded(tfl_op.VerifyTflRuntimeTypes(
+            tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
+      }));
   // Keep trying to convert.
   // TODO(karimnosseir): This is similar to what apply greedy patterns does.
   // Look if there is a function that tries until it converge.
@@ -782,9 +763,8 @@ void LegalizeTF::runOnFunction() {
 }

 }  // namespace

 // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
-std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
-    bool run_tfl_runtime_verification) {
-  return std::make_unique<LegalizeTF>(run_tfl_runtime_verification);
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass() {
+  return std::make_unique<LegalizeTF>();
 }

 static PassRegistration<LegalizeTF> pass(
diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h
index a744a570929..c86ac567661 100644
--- a/tensorflow/compiler/mlir/lite/transforms/passes.h
+++ b/tensorflow/compiler/mlir/lite/transforms/passes.h
@@ -30,11 +30,7 @@ namespace TFL {
 class QuantizationSpecs;

 // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
-// When the given run_tfl_runtime_verification value is true, it will check each
-// TFL builtin op towards the TFL runtime capability and the incompatible TF ops
-// will be left in the graph without getting legalized.
-std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
-    bool run_tfl_runtime_verification);
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass();

 // Creates an instance of the TensorFlow Lite dialect Optimize pass.
 std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
@@ -95,8 +91,8 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFWhilePass();
 // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
 std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();

-// Verifies runtime constraints.
-std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
+// Verifies runtime supports types used.
+std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass();

 }  // namespace TFL
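Since the `PassRegistration` in the rename below keeps the `tfl-runtime-verify` flag, `mlir-opt`-style tools can still invoke the pass by name; programmatically, the renamed factory is a drop-in replacement. A sketch of standalone use on an already-parsed module (the wrapper function is invented; setup and registration assumed):

```cpp
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

// Sketch: run only the runtime type verification over a module. The pass
// emits a per-op diagnostic and signals failure for any op whose
// operand/result types the TFLite runtime does not support.
mlir::LogicalResult VerifyTypesForRuntime(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
  return pm.run(module);
}
```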
diff --git a/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc b/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc
similarity index 63%
rename from tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc
rename to tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc
index 3268329b1c1..3cb26a5a11e 100644
--- a/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/runtime_type_verify.cc
@@ -22,32 +22,34 @@ namespace mlir {
 namespace TFL {
 namespace {

-// This pass verifies that the TFL ops meet the TFL runtime constraints.
-class RuntimeVerifyPass
-    : public mlir::PassWrapper<RuntimeVerifyPass, FunctionPass> {
+// This pass verifies that the operands and results types are supported by
+// TFLite runtime.
+class RuntimeTypeVerifyPass
+    : public mlir::PassWrapper<RuntimeTypeVerifyPass, FunctionPass> {
  public:
-  explicit RuntimeVerifyPass() {}
+  explicit RuntimeTypeVerifyPass() {}

  private:
   void runOnFunction() override;
 };

-void RuntimeVerifyPass::runOnFunction() {
+void RuntimeTypeVerifyPass::runOnFunction() {
   getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
-    if (failed(op.VerifyTflRuntimeConstraints(
-            op.getOperation(), /*failure_on_operand_type_mismatch=*/true)))
+    if (failed(op.VerifyTflRuntimeTypes(
+            op.getOperation(),
+            /*failure_on_operand_type_mismatch=*/true)))
       signalPassFailure();
   });
 }
 }  // namespace

-// Verifies TFL runtime constraints.
-std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass() {
-  return std::make_unique<RuntimeVerifyPass>();
+// Verifies runtime supports types used.
+std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass() {
+  return std::make_unique<RuntimeTypeVerifyPass>();
 }

-static PassRegistration<RuntimeVerifyPass> pass("tfl-runtime-verify",
-                                                "TFLite runtime verification");
+static PassRegistration<RuntimeTypeVerifyPass> pass(
+    "tfl-runtime-verify", "TFLite runtime verification");

 }  // namespace TFL
 }  // namespace mlir
diff --git a/tensorflow/lite/testing/op_tests/hardswish.py b/tensorflow/lite/testing/op_tests/hardswish.py
index 97dad804f3b..2816fe5bd64 100644
--- a/tensorflow/lite/testing/op_tests/hardswish.py
+++ b/tensorflow/lite/testing/op_tests/hardswish.py
@@ -48,17 +48,10 @@ def make_hardswish_tests(options):
   """Make a set of tests to do hardswish."""

   # Chose a set of parameters
-  if options.run_with_flex:
-    # Only Flex is able to execute on the data bigger than four dimension.
-    test_parameters = [{
-        "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
-                        [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
-    }]
-  else:
-    test_parameters = [{
-        "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
-                        [3, 15, 14, 3]],
-    }]
+  test_parameters = [{
+      "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
+                      [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+  }]

   def build_graph(parameters):
     inp = tf.compat.v1.placeholder(