Add type and shape constraints to TFLite Add op
PiperOrigin-RevId: 305591811 Change-Id: I12a9b832fc9ee385c77dc1a19f24aa9debcca641
This commit is contained in:
parent
c916d840ed
commit
72600af909
|
@ -307,7 +307,7 @@ cc_library(
|
|||
"transforms/optimize_functional_ops.cc",
|
||||
"transforms/prepare_composite_functions_tf.cc",
|
||||
"transforms/prepare_tf.cc",
|
||||
"transforms/runtime_verify.cc",
|
||||
"transforms/runtime_type_verify.cc",
|
||||
"transforms/split_merged_operands.cc",
|
||||
"transforms/trim_functions_tf.cc",
|
||||
"transforms/while_loop_outline.cc",
|
||||
|
|
|
@ -36,8 +36,7 @@ struct PassConfig {
|
|||
form_clusters(false),
|
||||
unfold_batch_matmul(true),
|
||||
legalize_tf_while(true),
|
||||
shape_inference(true),
|
||||
runtime_verification(true) {}
|
||||
shape_inference(true) {}
|
||||
|
||||
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
|
||||
// added, which produces TF Lite ops.
|
||||
|
@ -66,8 +65,6 @@ struct PassConfig {
|
|||
bool legalize_tf_while;
|
||||
// Whether to do shape inference.
|
||||
bool shape_inference;
|
||||
// Whether to do TFLite runtime verification.
|
||||
bool runtime_verification;
|
||||
};
|
||||
|
||||
} // namespace TFL
|
||||
|
|
|
@ -441,7 +441,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
|||
|
||||
mlir::tblgen::FmtContext verify_ctx;
|
||||
os << "::mlir::LogicalResult " << op.getCppClassName()
|
||||
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
|
||||
<< "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
|
||||
"failure_on_operand_type_mismatch) {\n";
|
||||
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
|
||||
verify_ctx.withOp("top");
|
||||
|
@ -466,25 +466,6 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
|
|||
"operand");
|
||||
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
|
||||
"result");
|
||||
|
||||
for (auto &trait : op.getTraits()) {
|
||||
if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) {
|
||||
continue;
|
||||
}
|
||||
if (trait.getDef().getValueAsString("trait") !=
|
||||
"OpTrait::TFLRuntimeOpTrait") {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto *val = trait.getDef().getValue("tflRuntimePredicate");
|
||||
if (!val) continue;
|
||||
|
||||
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
|
||||
os << tgfmt(
|
||||
" if (!($0)) {\n "
|
||||
" return ::mlir::LogicalResult::Failure;\n }\n",
|
||||
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
|
||||
}
|
||||
os << " return top.verify();\n}\n";
|
||||
}
|
||||
|
||||
|
|
|
@ -86,7 +86,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
|
|||
let methods = [
|
||||
StaticInterfaceMethod<
|
||||
[{Returns whether the op's operands/results are supported by runtime.}],
|
||||
"LogicalResult", "VerifyTflRuntimeConstraints",
|
||||
"LogicalResult", "VerifyTflRuntimeTypes",
|
||||
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
|
||||
>,
|
||||
];
|
||||
|
|
|
@ -46,30 +46,6 @@ namespace mlir {
|
|||
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
|
||||
namespace TFL {
|
||||
|
||||
// Returns true when the given two types have the same shape or broadcastable
|
||||
// shape within the given rank. If any given shapes are non-static, this method
|
||||
// returns true.
|
||||
bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
|
||||
int max_bcast_rank) {
|
||||
// Ignore shape checking on the non-static shapes for model compatibility.
|
||||
auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
|
||||
if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
|
||||
auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
|
||||
if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
|
||||
|
||||
if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
|
||||
return true;
|
||||
|
||||
SmallVector<int64_t, 4> result_shape;
|
||||
if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
|
||||
rhs_shaped_type.getShape(),
|
||||
result_shape)) {
|
||||
return false;
|
||||
}
|
||||
return lhs_shaped_type.getRank() <= max_bcast_rank &&
|
||||
rhs_shaped_type.getRank() <= max_bcast_rank;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlowLiteDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -106,22 +106,6 @@ class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
|
|||
class DerivedTFLiteTypeAttr<code body> :
|
||||
DerivedAttr<"tflite::TensorType", body>;
|
||||
|
||||
// TFL Runtime op trait predicate.
|
||||
class TFL_RuntimePredOpTrait<string desc, Pred pred> :
|
||||
GenInternalOpTrait<"TFLRuntimeOpTrait"> {
|
||||
Pred tflRuntimePredicate = pred;
|
||||
string tflRuntimeDescription = desc;
|
||||
}
|
||||
|
||||
class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
|
||||
int i, int j, int max_bcast_rank> :
|
||||
TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
|
||||
" have the same shape or broadcastable shapes within the rank " #
|
||||
max_bcast_rank,
|
||||
CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
|
||||
"$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
|
||||
").getType(), " # max_bcast_rank # ")">>;
|
||||
|
||||
// These additional types/type constraints here are used to decouple the ops
|
||||
// from runtime support for the ops. Prefer to use these types when defining
|
||||
// new TF_Ops for uniformity.
|
||||
|
@ -376,9 +360,10 @@ an output element, this operation computes \\(y = |x|\\).
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TFL_AddOp : TFL_Op<"add", [
|
||||
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>,
|
||||
ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> {
|
||||
def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
|
||||
NoSideEffect,
|
||||
Commutative,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Addition operator";
|
||||
|
||||
let description = [{
|
||||
|
@ -386,11 +371,11 @@ def TFL_AddOp : TFL_Op<"add", [
|
|||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
|
||||
TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
|
||||
ins AnyTensor:$lhs,
|
||||
AnyTensor:$rhs,
|
||||
TFL_AFAttr:$fused_activation_function);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
|
||||
let results = (outs AnyTensor:$output);
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
|
|
|
@ -285,7 +285,7 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
|
|||
if (pass_config.legalize_tf_while) {
|
||||
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
|
||||
}
|
||||
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
|
||||
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
|
||||
|
||||
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
||||
|
|
|
@ -9,20 +9,6 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
|||
// CHECK: return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testAddHighDimsHaveSameShape
|
||||
func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
|
||||
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6x7x8xi32>, tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6x7x8xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testAddTooHighBroadcastableDims
|
||||
func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
|
||||
// expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
|
||||
return %0 : tensor<1x2x3x4x5x6xi32>
|
||||
}
|
||||
|
||||
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
|
||||
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
|
||||
return %2: tensor<1xf32>
|
||||
|
|
|
@ -173,8 +173,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
|||
pass_manager->addPass(
|
||||
mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
|
||||
pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
|
||||
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
|
||||
// This pass operates on TensorFlow ops but is triggered after legalization
|
||||
// so that it can target constants introduced once TensorFlow Identity ops
|
||||
|
@ -256,8 +255,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
|
|||
// TFLite dialect passes.
|
||||
pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
pm.addPass(
|
||||
mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
|
||||
pm.addPass(mlir::TFL::CreateLegalizeTFPass());
|
||||
pm.addPass(mlir::TFL::CreateOptimizePass());
|
||||
pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
|
||||
|
||||
|
@ -270,7 +268,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
|
|||
|
||||
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
|
||||
|
||||
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
|
||||
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
|
||||
}
|
||||
|
||||
// Registers a pass pipeline for the standard TFL passes.
|
||||
|
|
|
@ -214,7 +214,7 @@ int main(int argc, char **argv) {
|
|||
if (pass_config.legalize_tf_while) {
|
||||
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
|
||||
}
|
||||
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass());
|
||||
pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
|
||||
|
||||
std::string result;
|
||||
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
|
|
|
@ -70,21 +70,8 @@ constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
|
|||
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
|
||||
|
||||
// Legalize operations in functions.
|
||||
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||
public:
|
||||
LegalizeTF() = default;
|
||||
LegalizeTF(const LegalizeTF&) {}
|
||||
explicit LegalizeTF(bool run_tfl_runtime_verification) {
|
||||
run_tfl_runtime_verification_ = run_tfl_runtime_verification;
|
||||
}
|
||||
|
||||
/// Performs the lowering to TFLite dialect.
|
||||
struct LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
|
||||
private:
|
||||
Option<bool> run_tfl_runtime_verification_{
|
||||
*this, "run-tfl-runtime-verification",
|
||||
llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)};
|
||||
};
|
||||
|
||||
// Returns true if all tensor value in `values` has static shape and same shape.
|
||||
|
@ -754,19 +741,13 @@ void LegalizeTF::runOnFunction() {
|
|||
// graph.
|
||||
target.addLegalOp<mlir::ConstantOp>();
|
||||
target.addLegalOp<ConstOp>();
|
||||
if (run_tfl_runtime_verification_) {
|
||||
target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
|
||||
Optional<ConversionTarget::DynamicLegalityCallbackFn>(
|
||||
[](Operation* op) {
|
||||
Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
|
||||
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
|
||||
if (!tfl_op) return false;
|
||||
return succeeded(tfl_op.VerifyTflRuntimeConstraints(
|
||||
tfl_op.getOperation(),
|
||||
/*failure_on_operand_type_mismatch=*/false));
|
||||
return succeeded(tfl_op.VerifyTflRuntimeTypes(
|
||||
tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
|
||||
}));
|
||||
} else {
|
||||
target.addLegalDialect<TensorFlowLiteDialect>();
|
||||
}
|
||||
// Keep trying to convert.
|
||||
// TODO(karimnosseir): This is similar to what apply greedy patterns does.
|
||||
// Look if there is a function that tries until it converge.
|
||||
|
@ -782,9 +763,8 @@ void LegalizeTF::runOnFunction() {
|
|||
} // namespace
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
|
||||
bool run_tfl_runtime_verification) {
|
||||
return std::make_unique<LegalizeTF>(run_tfl_runtime_verification);
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass() {
|
||||
return std::make_unique<LegalizeTF>();
|
||||
}
|
||||
|
||||
static PassRegistration<LegalizeTF> pass(
|
||||
|
|
|
@ -30,11 +30,7 @@ namespace TFL {
|
|||
class QuantizationSpecs;
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
|
||||
// When the given run_tfl_runtime_verification value is true, it will check each
|
||||
// TFL builtin op towards the TFL runtime capability and the incompatible TF ops
|
||||
// will be left in the graph without getting legalized.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
|
||||
bool run_tfl_runtime_verification);
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass();
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
|
||||
|
@ -95,8 +91,8 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass();
|
|||
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
|
||||
|
||||
// Verifies runtime constraints.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
|
||||
// Verifies runtime supports types used.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass();
|
||||
|
||||
} // namespace TFL
|
||||
|
||||
|
|
|
@ -22,32 +22,34 @@ namespace mlir {
|
|||
namespace TFL {
|
||||
namespace {
|
||||
|
||||
// This pass verifies that the TFL ops meet the TFL runtime constraints.
|
||||
class RuntimeVerifyPass
|
||||
: public mlir::PassWrapper<RuntimeVerifyPass, FunctionPass> {
|
||||
// This pass verifies that the operands and results types are supported by
|
||||
// TFLite runtime.
|
||||
class RuntimeTypeVerifyPass
|
||||
: public mlir::PassWrapper<RuntimeTypeVerifyPass, FunctionPass> {
|
||||
public:
|
||||
explicit RuntimeVerifyPass() {}
|
||||
explicit RuntimeTypeVerifyPass() {}
|
||||
|
||||
private:
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
void RuntimeVerifyPass::runOnFunction() {
|
||||
void RuntimeTypeVerifyPass::runOnFunction() {
|
||||
getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
|
||||
if (failed(op.VerifyTflRuntimeConstraints(
|
||||
op.getOperation(), /*failure_on_operand_type_mismatch=*/true)))
|
||||
if (failed(op.VerifyTflRuntimeTypes(
|
||||
op.getOperation(),
|
||||
/*failure_on_operand_type_mismatch=*/true)))
|
||||
signalPassFailure();
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Verifies TFL runtime constraints.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass() {
|
||||
return std::make_unique<RuntimeVerifyPass>();
|
||||
// Verifies runtime supports types used.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass() {
|
||||
return std::make_unique<RuntimeTypeVerifyPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<RuntimeVerifyPass> pass("tfl-runtime-verify",
|
||||
"TFLite runtime verification");
|
||||
static PassRegistration<RuntimeTypeVerifyPass> pass(
|
||||
"tfl-runtime-verify", "TFLite runtime verification");
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
|
@ -48,17 +48,10 @@ def make_hardswish_tests(options):
|
|||
"""Make a set of tests to do hardswish."""
|
||||
|
||||
# Chose a set of parameters
|
||||
if options.run_with_flex:
|
||||
# Only Flex is able to execute on the data bigger than four dimension.
|
||||
test_parameters = [{
|
||||
"input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
|
||||
[3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
|
||||
}]
|
||||
else:
|
||||
test_parameters = [{
|
||||
"input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
|
||||
[3, 15, 14, 3]],
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
inp = tf.compat.v1.placeholder(
|
||||
|
|
Loading…
Reference in New Issue