Run shape inference before any new non-tf ops are introduced
PiperOrigin-RevId: 310263391 Change-Id: Iff15999065f5ffa1c331646ae53c2f093c05b994
This commit is contained in:
parent
154044d0f2
commit
fac30b7a87
|
@ -73,16 +73,17 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||||
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
|
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (pass_config.shape_inference) {
|
||||||
|
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||||
|
}
|
||||||
|
// Keep this pass after the shape inference pass, which couldn't do shape
|
||||||
|
// inference for non-tf ops.
|
||||||
if (!pass_config.quant_specs.serialized_quant_stats.empty()) {
|
if (!pass_config.quant_specs.serialized_quant_stats.empty()) {
|
||||||
pass_manager->addPass(
|
pass_manager->addPass(
|
||||||
mlir::quant::CreateImportQuantStatsPassForTFControlDialect(
|
mlir::quant::CreateImportQuantStatsPassForTFControlDialect(
|
||||||
pass_config.quant_specs.serialized_quant_stats));
|
pass_config.quant_specs.serialized_quant_stats));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pass_config.shape_inference) {
|
|
||||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
|
||||||
}
|
|
||||||
|
|
||||||
// The conversion pipeline has to follow the following orders:
|
// The conversion pipeline has to follow the following orders:
|
||||||
// 1) Saved model related optimization like decompose resource ops
|
// 1) Saved model related optimization like decompose resource ops
|
||||||
// 2) Convert composite functions like lstm/rnns, along with proper function
|
// 2) Convert composite functions like lstm/rnns, along with proper function
|
||||||
|
|
Loading…
Reference in New Issue