Add the decompose resource ops pass to TFLite pass pipeline.

The pattern that TFLite cares about is the decomposition of tf.ResourceGather into
tf.ReadVariableOp and tf.Gather.

It is safe to add this pass by default.

PiperOrigin-RevId: 301303572
Change-Id: Ic53d9f4f84243bc615a80f1602b53520507d91a8
This commit is contained in:
Ashwin Murthy 2020-03-16 21:59:07 -07:00 committed by TensorFlower Gardener
parent 092ae742c3
commit 666f21add8

View File

@ -84,21 +84,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
}
// This pass does resource analysis of saved model global tensors and marks
// those deemed read-only as immutable.
pass_manager->addPass(
mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
// This pass marks non-exported functions as symbol visibility 'private'
// those deemed read-only as immutable.
pass_manager->addPass(
mlir::tf_saved_model::
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
// Enable fusing composite ops that can be lowered to built-in TFLite ops.
if (pass_config.emit_builtin_tflite_ops) {
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
}
// The ophint extractions happen before lots of other passes:
// The assumption of ophint-extraction is each ophinted region is a black-box
// and nodes within this black-box is NOT connected to the nodes OUTSIDE the
@ -114,6 +99,27 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(mlir::TFL::CreateLegalizeOphintFuncOpPass());
}
// This decomposes resource ops like ResourceGather into read-variable op
// followed by gather. This is used when the saved model import path is used
// during which resources dont get frozen in the python layer.
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFDevice::CreateDecomposeResourceOpsPass());
// This pass does resource analysis of saved model global tensors and marks
// those deemed read-only as immutable.
pass_manager->addPass(
mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
// This pass marks non-exported functions as symbol visibility 'private'
// those deemed read-only as immutable.
pass_manager->addPass(
mlir::tf_saved_model::
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
// Enable fusing composite ops that can be lowered to built-in TFLite ops.
if (pass_config.emit_builtin_tflite_ops) {
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
}
// Legalize while early to allow further constant folding.
// TODO(jpienaar): This may not actually matter as we do canonicalization
// after the legalize below, for now it needs to be below the above passes