Add switch fold pass to tf_tfl_translate passes
PiperOrigin-RevId: 267004531
This commit is contained in:
parent
30485bb7b9
commit
1479da3cb2
@ -521,6 +521,7 @@ cc_library(
|
||||
":tensorflow_lite_quantize",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_fold_switch",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
|
@ -41,6 +41,7 @@ bool ShouldRunQuantizePasses(mlir::ModuleOp m) {
|
||||
|
||||
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
mlir::PassManager* pass_manager) {
|
||||
pass_manager->addPass(mlir::tf_executor::CreateSwitchFoldPass());
|
||||
pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion());
|
||||
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
|
||||
// Ophint extraction will happen after island extraction pass.
|
||||
|
@ -58,7 +58,7 @@ limitations under the License.
|
||||
namespace mlir {
|
||||
namespace {
|
||||
|
||||
class SwitchFold : public mlir::FunctionPass<SwitchFold> {
|
||||
class SwitchFoldPass : public mlir::FunctionPass<SwitchFoldPass> {
|
||||
public:
|
||||
void runOnFunction() override;
|
||||
};
|
||||
@ -266,7 +266,7 @@ bool HasSendOrReceive(FuncOp function) {
|
||||
.wasInterrupted();
|
||||
}
|
||||
|
||||
void SwitchFold::runOnFunction() {
|
||||
void SwitchFoldPass::runOnFunction() {
|
||||
if (HasSendOrReceive(getFunction())) return;
|
||||
DeadQueue queue;
|
||||
// Initialize dead queue with dead outputs of foldable SwitchOps.
|
||||
@ -277,7 +277,13 @@ void SwitchFold::runOnFunction() {
|
||||
if (failed(FoldMergeNodes(getFunction(), queue))) return signalPassFailure();
|
||||
} // namespace mlir
|
||||
|
||||
static PassRegistration<SwitchFold> pass(
|
||||
namespace tf_executor {
|
||||
std::unique_ptr<FunctionPassBase> CreateSwitchFoldPass() {
|
||||
return std::make_unique<SwitchFoldPass>();
|
||||
}
|
||||
} // namespace tf_executor
|
||||
|
||||
static PassRegistration<SwitchFoldPass> pass(
|
||||
"tf-switch-fold", "Fold switch nodes with constant predicates");
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -41,6 +41,9 @@ std::unique_ptr<FunctionPassBase> CreateRaiseTFControlFlowPass();
|
||||
namespace tf_executor {
|
||||
class GraphOp;
|
||||
|
||||
// Returns a pass that folds switch nodes with constant predicates.
|
||||
std::unique_ptr<FunctionPassBase> CreateSwitchFoldPass();
|
||||
|
||||
// Create a pass to merge IslandOps from TFExecutor dialect.
|
||||
std::unique_ptr<FunctionPassBase> CreateTFExecutorIslandCoarseningPass();
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user