Add type and shape constraints to TFLite Add op

PiperOrigin-RevId: 305591811
Change-Id: I12a9b832fc9ee385c77dc1a19f24aa9debcca641

commit 72600af909 (parent c916d840ed)

@@ -307,7 +307,7 @@ cc_library(
         "transforms/optimize_functional_ops.cc",
         "transforms/prepare_composite_functions_tf.cc",
         "transforms/prepare_tf.cc",
-        "transforms/runtime_type_verify.cc",
+        "transforms/runtime_verify.cc",
         "transforms/split_merged_operands.cc",
         "transforms/trim_functions_tf.cc",
         "transforms/while_loop_outline.cc",

@@ -36,7 +36,8 @@ struct PassConfig {
         form_clusters(false),
         unfold_batch_matmul(true),
         legalize_tf_while(true),
-        shape_inference(true) {}
+        shape_inference(true),
+        runtime_verification(true) {}

   // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
   // added, which produces TF Lite ops.

@@ -65,6 +66,8 @@ struct PassConfig {
   bool legalize_tf_while;
   // Whether to do shape inference.
   bool shape_inference;
+  // Whether to do TFLite runtime verification.
+  bool runtime_verification;
 };

 }  // namespace TFL

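[Note] The new `runtime_verification` flag defaults to true. A converter that needs to keep runtime-incompatible ops in TensorFlow form (for example, a Flex/select-TF-ops flow) can clear it before building the pipeline. A minimal sketch under that assumption — the `quant_specs` setup and the exact `PassConfig` constructor are illustrative, not part of this change:

    // Hypothetical caller: disable TFL runtime verification so unsupported
    // ops stay as TF ops instead of failing conversion.
    mlir::TFL::QuantizationSpecs quant_specs;        // illustrative
    mlir::TFL::PassConfig pass_config(quant_specs);  // existing ctor assumed
    pass_config.runtime_verification = false;        // new flag from this change
    mlir::TFL::AddTFToTFLConversionPasses(pass_config, &pass_manager);
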
@@ -441,7 +441,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {

   mlir::tblgen::FmtContext verify_ctx;
   os << "::mlir::LogicalResult " << op.getCppClassName()
-     << "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
+     << "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
        "failure_on_operand_type_mismatch) {\n";
   os << "  auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
   verify_ctx.withOp("top");

@@ -466,6 +466,25 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
                            "operand");
   GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
                            "result");

+  for (auto &trait : op.getTraits()) {
+    if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) {
+      continue;
+    }
+    if (trait.getDef().getValueAsString("trait") !=
+        "OpTrait::TFLRuntimeOpTrait") {
+      continue;
+    }
+
+    auto *val = trait.getDef().getValue("tflRuntimePredicate");
+    if (!val) continue;
+
+    mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
+    os << tgfmt(
+        "  if (!($0)) {\n    "
+        "  return ::mlir::LogicalResult::Failure;\n  }\n",
+        &verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
+  }
   os << "  return top.verify();\n}\n";
 }

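[Note] The new loop splices every `TFLRuntimeOpTrait` predicate into the op's generated `VerifyTflRuntimeConstraints` method. As a hand-written approximation (not the literal tblgen output), the code emitted for `tfl.add` after this change looks roughly like:

    ::mlir::LogicalResult TFL::AddOp::VerifyTflRuntimeConstraints(
        ::mlir::Operation *op, bool failure_on_operand_type_mismatch) {
      auto top = cast<TFL::AddOp>(op); (void)top;
      // ... per-operand/per-result type checks from GenOperandResultVerifier ...
      // Spliced from TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>:
      if (!TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape(
              top.getOperand(0).getType(), top.getOperand(1).getType(), 4)) {
        return ::mlir::LogicalResult::Failure;
      }
      return top.verify();
    }
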
@@ -86,7 +86,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
   let methods = [
     StaticInterfaceMethod<
       [{Returns whether the op's operands/results are supported by runtime.}],
-      "LogicalResult", "VerifyTflRuntimeTypes",
+      "LogicalResult", "VerifyTflRuntimeConstraints",
       (ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
     >,
   ];

@@ -46,6 +46,30 @@ namespace mlir {
 #include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
 namespace TFL {

+// Returns true when the given two types have the same shape or broadcastable
+// shape within the given rank. If any given shapes are non-static, this method
+// returns true.
+bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
+                                                        int max_bcast_rank) {
+  // Ignore shape checking on the non-static shapes for model compatibility.
+  auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
+  if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
+  auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
+  if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
+
+  if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
+    return true;
+
+  SmallVector<int64_t, 4> result_shape;
+  if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
+                                          rhs_shaped_type.getShape(),
+                                          result_shape)) {
+    return false;
+  }
+  return lhs_shaped_type.getRank() <= max_bcast_rank &&
+         rhs_shaped_type.getRank() <= max_bcast_rank;
+}
+
 //===----------------------------------------------------------------------===//
 // TensorFlowLiteDialect
 //===----------------------------------------------------------------------===//

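[Note] Mind the order of the checks above: identical static shapes are accepted at any rank, dynamic shapes skip checking entirely, and only genuine broadcasts are capped at `max_bcast_rank`. A standalone sketch of the same rule without MLIR types — assuming NumPy-style right-aligned broadcasting, which is what `OpTrait::util::getBroadcastedShape` implements:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Same rule as the helper above, on plain shape vectors.
    bool SameShapesOrBroadcastableWithinRank(const std::vector<int64_t>& lhs,
                                             const std::vector<int64_t>& rhs,
                                             int max_bcast_rank) {
      if (lhs == rhs) return true;  // Identical shapes pass at any rank.
      size_t rank = std::max(lhs.size(), rhs.size());
      for (size_t i = 0; i < rank; ++i) {
        // Right-align the shapes; missing leading dims act as 1.
        int64_t l = i < lhs.size() ? lhs[lhs.size() - 1 - i] : 1;
        int64_t r = i < rhs.size() ? rhs[rhs.size() - 1 - i] : 1;
        if (l != r && l != 1 && r != 1) return false;  // Not broadcastable.
      }
      // A genuine broadcast is only allowed up to max_bcast_rank dimensions.
      return lhs.size() <= static_cast<size_t>(max_bcast_rank) &&
             rhs.size() <= static_cast<size_t>(max_bcast_rank);
    }

    int main() {
      // Same 8-D shapes: accepted (cf. testAddHighDimsHaveSameShape below).
      std::cout << SameShapesOrBroadcastableWithinRank(
          {1, 2, 3, 4, 5, 6, 7, 8}, {1, 2, 3, 4, 5, 6, 7, 8}, 4) << "\n";  // 1
      // 6-D broadcast: rejected, rank 6 > 4 (cf. testAddTooHighBroadcastableDims).
      std::cout << SameShapesOrBroadcastableWithinRank(
          {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 1}, 4) << "\n";  // 0
    }
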
@@ -106,6 +106,22 @@ class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
 class DerivedTFLiteTypeAttr<code body> :
   DerivedAttr<"tflite::TensorType", body>;

+// TFL Runtime op trait predicate.
+class TFL_RuntimePredOpTrait<string desc, Pred pred> :
+    GenInternalOpTrait<"TFLRuntimeOpTrait"> {
+  Pred tflRuntimePredicate = pred;
+  string tflRuntimeDescription = desc;
+}
+
+class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
+    int i, int j, int max_bcast_rank> :
+  TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
+      " have the same shape or broadcastable shapes within the rank " #
+      max_bcast_rank,
+    CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
+          "$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
+          ").getType(), " # max_bcast_rank # ")">>;
+
 // These additional types/type constraints here are used to decouple the ops
 // from runtime support for the ops. Prefer to use these types when defining
 // new TF_Ops for uniformity.

@@ -360,10 +376,9 @@ an output element, this operation computes \\(y = |x|\\).
   let hasFolder = 1;
 }

-def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
-                               NoSideEffect,
-                               Commutative,
-                               TFL_GpuTargetOp]> {
+def TFL_AddOp : TFL_Op<"add", [
+    TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
+    ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> {
   let summary = "Addition operator";

   let description = [{

@@ -371,11 +386,11 @@ def TFL_AddOp : TFL_Op<"add", [
   }];

   let arguments = (
-    ins AnyTensor:$lhs,
-    AnyTensor:$rhs,
+    ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
+    TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
     TFL_AFAttr:$fused_activation_function);

-  let results = (outs AnyTensor:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);

   let hasFolder = 1;

@@ -285,7 +285,7 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
   if (pass_config.legalize_tf_while) {
     pm.addPass(mlir::TFL::CreateWhileOutlinePass());
   }
-  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());

   auto status = ConvertTFExecutorToTFLOrFlatbuffer(
       module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,

@@ -9,6 +9,20 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
 // CHECK: return
 }

+// CHECK-LABEL: testAddHighDimsHaveSameShape
+func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
+  // CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
+  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6x7x8xi32>, tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32>
+  return %0 : tensor<1x2x3x4x5x6x7x8xi32>
+}
+
+// CHECK-LABEL: testAddTooHighBroadcastableDims
+func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
+  // expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
+  %0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
+  return %0 : tensor<1x2x3x4x5x6xi32>
+}
+
 func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
   %2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
   return %2: tensor<1xf32>

@@ -173,7 +173,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
   pass_manager->addPass(
       mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
   pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
-  pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
+  pass_manager->addPass(
+      mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
   pass_manager->addPass(mlir::TFL::CreateOptimizePass());
   // This pass operates on TensorFlow ops but is triggered after legalization
   // so that it can target constants introduced once TensorFlow Identity ops

@@ -255,7 +256,8 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
   // TFLite dialect passes.
   pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
   pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::TFL::CreateLegalizeTFPass());
+  pm.addPass(
+      mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
   pm.addPass(mlir::TFL::CreateOptimizePass());
   pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());

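[Note] All the pipelines touched here thread the verification switch into legalization and finish with the renamed verify pass. Condensed, the tail of the standard pipeline now reads (names from this change; surrounding passes elided):

    // Legalize with runtime verification on, then hard-verify at the end.
    pm.addPass(
        mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
    pm.addPass(mlir::TFL::CreateOptimizePass());
    // ... remaining TFL passes ...
    pm.addPass(mlir::TFL::CreateWhileOutlinePass());
    pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());  // hard-fails on violations
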
@@ -268,7 +270,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
   pm.addPass(mlir::TFL::CreateWhileOutlinePass());

-  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
 }

 // Registers a pass pipeline for the standard TFL passes.

@@ -214,7 +214,7 @@ int main(int argc, char **argv) {
   if (pass_config.legalize_tf_while) {
     pm.addPass(mlir::TFL::CreateWhileOutlinePass());
   }
-  pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
+  pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());

   std::string result;
   auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(

@@ -70,8 +70,21 @@ constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
 constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";

 // Legalize operations in functions.
-struct LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
+class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
+ public:
+  LegalizeTF() = default;
+  LegalizeTF(const LegalizeTF&) {}
+  explicit LegalizeTF(bool run_tfl_runtime_verification) {
+    run_tfl_runtime_verification_ = run_tfl_runtime_verification;
+  }
+
+  /// Performs the lowering to TFLite dialect.
   void runOnFunction() override;
+
+ private:
+  Option<bool> run_tfl_runtime_verification_{
+      *this, "run-tfl-runtime-verification",
+      llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)};
 };

 // Returns true if all tensor value in `values` has static shape and same shape.

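[Note] The empty copy constructor is the usual MLIR idiom here: `Option` members of a pass are not copyable, so a copyable `PassWrapper` re-creates its options at their defaults instead of copying them, and the `bool` constructor then overrides the default. The same pattern in isolation — `MyPass` and `my-flag` are illustrative names, not part of this change:

    class MyPass : public mlir::PassWrapper<MyPass, mlir::FunctionPass> {
     public:
      MyPass() = default;
      MyPass(const MyPass&) {}  // Options are non-copyable; keep fresh defaults.
      explicit MyPass(bool flag) { my_flag_ = flag; }

     private:
      void runOnFunction() override {}
      // Also registered as a textual pass option (settable from the command line).
      Option<bool> my_flag_{*this, "my-flag", llvm::cl::desc("Example flag."),
                            llvm::cl::init(true)};
    };
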
@@ -741,13 +754,19 @@ void LegalizeTF::runOnFunction() {
   // graph.
   target.addLegalOp<mlir::ConstantOp>();
   target.addLegalOp<ConstOp>();
-  target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
-      Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
-        auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
-        if (!tfl_op) return false;
-        return succeeded(tfl_op.VerifyTflRuntimeTypes(
-            tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
-      }));
+  if (run_tfl_runtime_verification_) {
+    target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
+        Optional<ConversionTarget::DynamicLegalityCallbackFn>(
+            [](Operation* op) {
+              auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
+              if (!tfl_op) return false;
+              return succeeded(tfl_op.VerifyTflRuntimeConstraints(
+                  tfl_op.getOperation(),
+                  /*failure_on_operand_type_mismatch=*/false));
+            }));
+  } else {
+    target.addLegalDialect<TensorFlowLiteDialect>();
+  }
   // Keep trying to convert.
   // TODO(karimnosseir): This is similar to what apply greedy patterns does.
   // Look if there is a function that tries until it converge.

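[Note] The effect of the dynamic-legality hook: a TFL op only counts as legal while its runtime constraints hold, so a legalization pattern whose output would violate them is rejected and the original TF op remains in the graph. Restated as a free function for readability (a sketch, not code from this change):

    // Soft check: report illegality rather than erroring out, so conversion
    // can simply skip ops the TFL runtime cannot execute.
    static bool IsLegalForTflRuntime(Operation* op) {
      auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
      if (!tfl_op) return false;
      return succeeded(tfl_op.VerifyTflRuntimeConstraints(
          tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
    }
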
@@ -763,8 +782,9 @@ void LegalizeTF::runOnFunction() {
 }  // namespace

 // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
-std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass() {
-  return std::make_unique<LegalizeTF>();
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
+    bool run_tfl_runtime_verification) {
+  return std::make_unique<LegalizeTF>(run_tfl_runtime_verification);
 }

 static PassRegistration<LegalizeTF> pass(

@@ -30,7 +30,11 @@ namespace TFL {
 class QuantizationSpecs;

 // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
-std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass();
+// When the given run_tfl_runtime_verification value is true, it will check each
+// TFL builtin op towards the TFL runtime capability and the incompatible TF ops
+// will be left in the graph without getting legalized.
+std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
+    bool run_tfl_runtime_verification);

 // Creates an instance of the TensorFlow Lite dialect Optimize pass.
 std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();

@@ -91,8 +95,8 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass();
 // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
 std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();

-// Verifies runtime supports types used.
-std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass();
+// Verifies runtime constraints.
+std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();

 }  // namespace TFL

@@ -22,34 +22,32 @@ namespace mlir {
 namespace TFL {
 namespace {

-// This pass verifies that the operands and results types are supported by
-// TFLite runtime.
-class RuntimeTypeVerifyPass
-    : public mlir::PassWrapper<RuntimeTypeVerifyPass, FunctionPass> {
+// This pass verifies that the TFL ops meet the TFL runtime constraints.
+class RuntimeVerifyPass
+    : public mlir::PassWrapper<RuntimeVerifyPass, FunctionPass> {
  public:
-  explicit RuntimeTypeVerifyPass() {}
+  explicit RuntimeVerifyPass() {}

  private:
   void runOnFunction() override;
 };

-void RuntimeTypeVerifyPass::runOnFunction() {
+void RuntimeVerifyPass::runOnFunction() {
   getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
-    if (failed(op.VerifyTflRuntimeTypes(
-            op.getOperation(),
-            /*failure_on_operand_type_mismatch=*/true)))
+    if (failed(op.VerifyTflRuntimeConstraints(
+            op.getOperation(), /*failure_on_operand_type_mismatch=*/true)))
       signalPassFailure();
   });
 }
 }  // namespace

-// Verifies runtime supports types used.
-std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass() {
-  return std::make_unique<RuntimeTypeVerifyPass>();
+// Verifies TFL runtime constraints.
+std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass() {
+  return std::make_unique<RuntimeVerifyPass>();
 }

-static PassRegistration<RuntimeTypeVerifyPass> pass(
-    "tfl-runtime-verify", "TFLite runtime verification");
+static PassRegistration<RuntimeVerifyPass> pass("tfl-runtime-verify",
+                                                "TFLite runtime verification");

 }  // namespace TFL
 }  // namespace mlir

@@ -48,10 +48,17 @@ def make_hardswish_tests(options):
   """Make a set of tests to do hardswish."""

   # Chose a set of parameters
-  test_parameters = [{
-      "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
-                      [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
-  }]
+  if options.run_with_flex:
+    # Only Flex is able to execute on the data bigger than four dimension.
+    test_parameters = [{
+        "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
+                        [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
+    }]
+  else:
+    test_parameters = [{
+        "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
+                        [3, 15, 14, 3]],
+    }]

   def build_graph(parameters):
     inp = tf.compat.v1.placeholder(