Add type and shape constraints to TFLite Add op

PiperOrigin-RevId: 305591811
Change-Id: I12a9b832fc9ee385c77dc1a19f24aa9debcca641
This commit is contained in:
A. Unique TensorFlower 2020-04-08 17:53:47 -07:00 committed by TensorFlower Gardener
parent c916d840ed
commit 72600af909
14 changed files with 47 additions and 153 deletions

View File

@ -307,7 +307,7 @@ cc_library(
"transforms/optimize_functional_ops.cc", "transforms/optimize_functional_ops.cc",
"transforms/prepare_composite_functions_tf.cc", "transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc", "transforms/prepare_tf.cc",
"transforms/runtime_verify.cc", "transforms/runtime_type_verify.cc",
"transforms/split_merged_operands.cc", "transforms/split_merged_operands.cc",
"transforms/trim_functions_tf.cc", "transforms/trim_functions_tf.cc",
"transforms/while_loop_outline.cc", "transforms/while_loop_outline.cc",

View File

@ -36,8 +36,7 @@ struct PassConfig {
form_clusters(false), form_clusters(false),
unfold_batch_matmul(true), unfold_batch_matmul(true),
legalize_tf_while(true), legalize_tf_while(true),
shape_inference(true), shape_inference(true) {}
runtime_verification(true) {}
// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be
// added, which produces TF Lite ops. // added, which produces TF Lite ops.
@ -66,8 +65,6 @@ struct PassConfig {
bool legalize_tf_while; bool legalize_tf_while;
// Whether to do shape inference. // Whether to do shape inference.
bool shape_inference; bool shape_inference;
// Whether to do TFLite runtime verification.
bool runtime_verification;
}; };
} // namespace TFL } // namespace TFL

View File

@ -441,7 +441,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
mlir::tblgen::FmtContext verify_ctx; mlir::tblgen::FmtContext verify_ctx;
os << "::mlir::LogicalResult " << op.getCppClassName() os << "::mlir::LogicalResult " << op.getCppClassName()
<< "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool " << "::VerifyTflRuntimeTypes(::mlir::Operation *op, bool "
"failure_on_operand_type_mismatch) {\n"; "failure_on_operand_type_mismatch) {\n";
os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n"; os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
verify_ctx.withOp("top"); verify_ctx.withOp("top");
@ -466,25 +466,6 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
"operand"); "operand");
GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(), GenOperandResultVerifier(os, def->getValueAsDag("results")->getArgs(),
"result"); "result");
for (auto &trait : op.getTraits()) {
if (!trait.getDef().isSubClassOf("GenInternalOpTrait")) {
continue;
}
if (trait.getDef().getValueAsString("trait") !=
"OpTrait::TFLRuntimeOpTrait") {
continue;
}
auto *val = trait.getDef().getValue("tflRuntimePredicate");
if (!val) continue;
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
os << tgfmt(
" if (!($0)) {\n "
" return ::mlir::LogicalResult::Failure;\n }\n",
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
}
os << " return top.verify();\n}\n"; os << " return top.verify();\n}\n";
} }

View File

@ -86,7 +86,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> {
let methods = [ let methods = [
StaticInterfaceMethod< StaticInterfaceMethod<
[{Returns whether the op's operands/results are supported by runtime.}], [{Returns whether the op's operands/results are supported by runtime.}],
"LogicalResult", "VerifyTflRuntimeConstraints", "LogicalResult", "VerifyTflRuntimeTypes",
(ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch) (ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch)
>, >,
]; ];

View File

@ -46,30 +46,6 @@ namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc" #include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
namespace TFL { namespace TFL {
// Returns true when the given two types have the same shape or broadcastable
// shape within the given rank. If any given shapes are non-static, this method
// returns true.
bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
int max_bcast_rank) {
// Ignore shape checking on the non-static shapes for model compatibility.
auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
return true;
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
rhs_shaped_type.getShape(),
result_shape)) {
return false;
}
return lhs_shaped_type.getRank() <= max_bcast_rank &&
rhs_shaped_type.getRank() <= max_bcast_rank;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TensorFlowLiteDialect // TensorFlowLiteDialect
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -106,22 +106,6 @@ class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
class DerivedTFLiteTypeAttr<code body> : class DerivedTFLiteTypeAttr<code body> :
DerivedAttr<"tflite::TensorType", body>; DerivedAttr<"tflite::TensorType", body>;
// TFL Runtime op trait predicate.
class TFL_RuntimePredOpTrait<string desc, Pred pred> :
GenInternalOpTrait<"TFLRuntimeOpTrait"> {
Pred tflRuntimePredicate = pred;
string tflRuntimeDescription = desc;
}
class TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<
int i, int j, int max_bcast_rank> :
TFL_RuntimePredOpTrait<"operand #" # i # " and operand #" # j #
" have the same shape or broadcastable shapes within the rank " #
max_bcast_rank,
CPred<"TFL::IsBinaryOperandsHaveSameShapesOrBroadcastableShape("
"$_op.getOperand(" # i # ").getType(), $_op.getOperand(" # j #
").getType(), " # max_bcast_rank # ")">>;
// These additional types/type constraints here are used to decouple the ops // These additional types/type constraints here are used to decouple the ops
// from runtime support for the ops. Prefer to use these types when defining // from runtime support for the ops. Prefer to use these types when defining
// new TF_Ops for uniformity. // new TF_Ops for uniformity.
@ -376,9 +360,10 @@ an output element, this operation computes \\(y = |x|\\).
let hasFolder = 1; let hasFolder = 1;
} }
def TFL_AddOp : TFL_Op<"add", [ def TFL_AddOp : TFL_Op<"add", [ResultsBroadcastableShape,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, NoSideEffect,
ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> { Commutative,
TFL_GpuTargetOp]> {
let summary = "Addition operator"; let summary = "Addition operator";
let description = [{ let description = [{
@ -386,11 +371,11 @@ def TFL_AddOp : TFL_Op<"add", [
}]; }];
let arguments = ( let arguments = (
ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs, ins AnyTensor:$lhs,
TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs, AnyTensor:$rhs,
TFL_AFAttr:$fused_activation_function); TFL_AFAttr:$fused_activation_function);
let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output); let results = (outs AnyTensor:$output);
let hasFolder = 1; let hasFolder = 1;

View File

@ -285,7 +285,7 @@ Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
if (pass_config.legalize_tf_while) { if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass()); pm.addPass(mlir::TFL::CreateWhileOutlinePass());
} }
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass()); pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
auto status = ConvertTFExecutorToTFLOrFlatbuffer( auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops, module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,

View File

@ -9,20 +9,6 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: return // CHECK: return
} }
// CHECK-LABEL: testAddHighDimsHaveSameShape
func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6x7x8xi32>, tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32>
return %0 : tensor<1x2x3x4x5x6x7x8xi32>
}
// CHECK-LABEL: testAddTooHighBroadcastableDims
func @testAddTooHighBroadcastableDims(%arg0: tensor<1x2x3x4x5x6xi32>, %arg1: tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32> {
// expected-error @+1 {{'tfl.add' op failed to verify that operand #0 and operand #1 have the same shape or broadcastable shapes within the rank 4}}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xi32>, tensor<1x2x3x4x5x1xi32>) -> tensor<1x2x3x4x5x6xi32>
return %0 : tensor<1x2x3x4x5x6xi32>
}
func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> { func @LeakyRelu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
%2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32> %2 = "tf.LeakyRelu"(%arg0) {alpha = 0.1 : f32} : (tensor<1xf32>) -> tensor<1xf32>
return %2: tensor<1xf32> return %2: tensor<1xf32>

View File

@ -173,8 +173,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass( pass_manager->addPass(
mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul)); mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pass_manager->addPass( pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass());
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
pass_manager->addPass(mlir::TFL::CreateOptimizePass()); pass_manager->addPass(mlir::TFL::CreateOptimizePass());
// This pass operates on TensorFlow ops but is triggered after legalization // This pass operates on TensorFlow ops but is triggered after legalization
// so that it can target constants introduced once TensorFlow Identity ops // so that it can target constants introduced once TensorFlow Identity ops
@ -256,8 +255,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
// TFLite dialect passes. // TFLite dialect passes.
pm.addPass(mlir::TFL::CreatePrepareTFPass(true)); pm.addPass(mlir::TFL::CreatePrepareTFPass(true));
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pm.addPass( pm.addPass(mlir::TFL::CreateLegalizeTFPass());
mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
pm.addPass(mlir::TFL::CreateOptimizePass()); pm.addPass(mlir::TFL::CreateOptimizePass());
pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass()); pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
@ -270,7 +268,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
pm.addPass(mlir::TFL::CreateWhileOutlinePass()); pm.addPass(mlir::TFL::CreateWhileOutlinePass());
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass()); pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
} }
// Registers a pass pipeline for the standard TFL passes. // Registers a pass pipeline for the standard TFL passes.

View File

@ -214,7 +214,7 @@ int main(int argc, char **argv) {
if (pass_config.legalize_tf_while) { if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass()); pm.addPass(mlir::TFL::CreateWhileOutlinePass());
} }
pm.addPass(mlir::TFL::CreateRuntimeVerifyPass()); pm.addPass(mlir::TFL::CreateRuntimeTypeVerifyPass());
std::string result; std::string result;
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer( auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(

View File

@ -70,21 +70,8 @@ constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices"; constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
// Legalize operations in functions. // Legalize operations in functions.
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> { struct LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
public:
LegalizeTF() = default;
LegalizeTF(const LegalizeTF&) {}
explicit LegalizeTF(bool run_tfl_runtime_verification) {
run_tfl_runtime_verification_ = run_tfl_runtime_verification;
}
/// Performs the lowering to TFLite dialect.
void runOnFunction() override; void runOnFunction() override;
private:
Option<bool> run_tfl_runtime_verification_{
*this, "run-tfl-runtime-verification",
llvm::cl::desc("Allow tfl runtime verification."), llvm::cl::init(true)};
}; };
// Returns true if all tensor value in `values` has static shape and same shape. // Returns true if all tensor value in `values` has static shape and same shape.
@ -754,19 +741,13 @@ void LegalizeTF::runOnFunction() {
// graph. // graph.
target.addLegalOp<mlir::ConstantOp>(); target.addLegalOp<mlir::ConstantOp>();
target.addLegalOp<ConstOp>(); target.addLegalOp<ConstOp>();
if (run_tfl_runtime_verification_) { target.addDynamicallyLegalDialect<TensorFlowLiteDialect>(
target.addDynamicallyLegalDialect<TensorFlowLiteDialect>( Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation* op) {
Optional<ConversionTarget::DynamicLegalityCallbackFn>( auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op);
[](Operation* op) { if (!tfl_op) return false;
auto tfl_op = dyn_cast_or_null<TflRuntimeVerifyOpInterface>(op); return succeeded(tfl_op.VerifyTflRuntimeTypes(
if (!tfl_op) return false; tfl_op.getOperation(), /*failure_on_operand_type_mismatch=*/false));
return succeeded(tfl_op.VerifyTflRuntimeConstraints( }));
tfl_op.getOperation(),
/*failure_on_operand_type_mismatch=*/false));
}));
} else {
target.addLegalDialect<TensorFlowLiteDialect>();
}
// Keep trying to convert. // Keep trying to convert.
// TODO(karimnosseir): This is similar to what apply greedy patterns does. // TODO(karimnosseir): This is similar to what apply greedy patterns does.
// Look if there is a function that tries until it converge. // Look if there is a function that tries until it converge.
@ -782,9 +763,8 @@ void LegalizeTF::runOnFunction() {
} // namespace } // namespace
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass( std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass() {
bool run_tfl_runtime_verification) { return std::make_unique<LegalizeTF>();
return std::make_unique<LegalizeTF>(run_tfl_runtime_verification);
} }
static PassRegistration<LegalizeTF> pass( static PassRegistration<LegalizeTF> pass(

View File

@ -30,11 +30,7 @@ namespace TFL {
class QuantizationSpecs; class QuantizationSpecs;
// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass.
// When the given run_tfl_runtime_verification value is true, it will check each std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass();
// TFL builtin op towards the TFL runtime capability and the incompatible TF ops
// will be left in the graph without getting legalized.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeTFPass(
bool run_tfl_runtime_verification);
// Creates an instance of the TensorFlow Lite dialect Optimize pass. // Creates an instance of the TensorFlow Lite dialect Optimize pass.
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass(); std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
@ -95,8 +91,8 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFWhilePass();
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass. // Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass(); std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
// Verifies runtime constraints. // Verifies runtime supports types used.
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass(); std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass();
} // namespace TFL } // namespace TFL

View File

@ -22,32 +22,34 @@ namespace mlir {
namespace TFL { namespace TFL {
namespace { namespace {
// This pass verifies that the TFL ops meet the TFL runtime constraints. // This pass verifies that the operands and results types are supported by
class RuntimeVerifyPass // TFLite runtime.
: public mlir::PassWrapper<RuntimeVerifyPass, FunctionPass> { class RuntimeTypeVerifyPass
: public mlir::PassWrapper<RuntimeTypeVerifyPass, FunctionPass> {
public: public:
explicit RuntimeVerifyPass() {} explicit RuntimeTypeVerifyPass() {}
private: private:
void runOnFunction() override; void runOnFunction() override;
}; };
void RuntimeVerifyPass::runOnFunction() { void RuntimeTypeVerifyPass::runOnFunction() {
getFunction().walk([&](TflRuntimeVerifyOpInterface op) { getFunction().walk([&](TflRuntimeVerifyOpInterface op) {
if (failed(op.VerifyTflRuntimeConstraints( if (failed(op.VerifyTflRuntimeTypes(
op.getOperation(), /*failure_on_operand_type_mismatch=*/true))) op.getOperation(),
/*failure_on_operand_type_mismatch=*/true)))
signalPassFailure(); signalPassFailure();
}); });
} }
} // namespace } // namespace
// Verifies TFL runtime constraints. // Verifies runtime supports types used.
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass() { std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeTypeVerifyPass() {
return std::make_unique<RuntimeVerifyPass>(); return std::make_unique<RuntimeTypeVerifyPass>();
} }
static PassRegistration<RuntimeVerifyPass> pass("tfl-runtime-verify", static PassRegistration<RuntimeTypeVerifyPass> pass(
"TFLite runtime verification"); "tfl-runtime-verify", "TFLite runtime verification");
} // namespace TFL } // namespace TFL
} // namespace mlir } // namespace mlir

View File

@ -48,17 +48,10 @@ def make_hardswish_tests(options):
"""Make a set of tests to do hardswish.""" """Make a set of tests to do hardswish."""
# Chose a set of parameters # Chose a set of parameters
if options.run_with_flex: test_parameters = [{
# Only Flex is able to execute on the data bigger than four dimension. "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
test_parameters = [{ [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
"input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3], }]
[3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
}]
else:
test_parameters = [{
"input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
[3, 15, 14, 3]],
}]
def build_graph(parameters): def build_graph(parameters):
inp = tf.compat.v1.placeholder( inp = tf.compat.v1.placeholder(