Fix ophint-extraction passes order & also add an e2e test to prevent this kind of breakage.

PiperOrigin-RevId: 270754087
This commit is contained in:
Renjie Liu 2019-09-23 13:58:43 -07:00 committed by TensorFlower Gardener
parent 84ec37fb58
commit ad1644033a
2 changed files with 7838 additions and 4 deletions

File diff suppressed because it is too large Load Diff

View File

@ -54,6 +54,21 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
}
// The ophint extractions happen before lots of other passes:
// The assumption of ophint-extraction is each ophinted region is a black-box
// and nodes within this black-box is NOT connected to the nodes OUTSIDE the
// black-box.
// Some passes may merge nodes together (such as const nodes), however, this
// will break the ophint-extraction assumption. (The nodes within the black
// box is not isolated anymore).
// So ophint extraction and legalization needs to happen before
// the canonicalization pass.
if (pass_config.emit_builtin_tflite_ops) {
pass_manager->addPass(mlir::TFL::CreateExtractOphintPass());
// Convert composite op pass will happen after ophint extraction pass.
pass_manager->addPass(mlir::TFL::CreateLegalizeOphintFuncOpPass());
}
// TODO(jpienaar): Revise post dialect constants.
pass_manager->addPass(mlir::TF::CreateDecodeConstantPass());
// Canonicalization includes const folding, which is utilized here to optimize
@ -64,10 +79,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// The below passes only make sense if Builtin TFLite ops are enabled
// for emission.
if (pass_config.emit_builtin_tflite_ops) {
pass_manager->addPass(mlir::TFL::CreateExtractOphintPass());
// Convert composite op pass will happen after ophint extraction pass.
pass_manager->addPass(mlir::TFL::CreateLegalizeOphintFuncOpPass());
// Prepare for TFLite dialect, rerun canonicalization, and then legalize to
// the TFLite dialect.
pass_manager->addPass(mlir::TFL::CreatePrepareTFPass());