Add switch fold pass to tf_tfl_translate passes
PiperOrigin-RevId: 267004531
This commit is contained in:
parent
30485bb7b9
commit
1479da3cb2
@ -521,6 +521,7 @@ cc_library(
|
|||||||
":tensorflow_lite_quantize",
|
":tensorflow_lite_quantize",
|
||||||
"//tensorflow/compiler/mlir/tensorflow",
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:tensorflow_fold_switch",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
|
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||||
|
@ -41,6 +41,7 @@ bool ShouldRunQuantizePasses(mlir::ModuleOp m) {
|
|||||||
|
|
||||||
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||||
mlir::PassManager* pass_manager) {
|
mlir::PassManager* pass_manager) {
|
||||||
|
pass_manager->addPass(mlir::tf_executor::CreateSwitchFoldPass());
|
||||||
pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion());
|
pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion());
|
||||||
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
|
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
|
||||||
// Ophint extraction will happen after island extraction pass.
|
// Ophint extraction will happen after island extraction pass.
|
||||||
|
@ -58,7 +58,7 @@ limitations under the License.
|
|||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class SwitchFold : public mlir::FunctionPass<SwitchFold> {
|
class SwitchFoldPass : public mlir::FunctionPass<SwitchFoldPass> {
|
||||||
public:
|
public:
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
@ -266,7 +266,7 @@ bool HasSendOrReceive(FuncOp function) {
|
|||||||
.wasInterrupted();
|
.wasInterrupted();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SwitchFold::runOnFunction() {
|
void SwitchFoldPass::runOnFunction() {
|
||||||
if (HasSendOrReceive(getFunction())) return;
|
if (HasSendOrReceive(getFunction())) return;
|
||||||
DeadQueue queue;
|
DeadQueue queue;
|
||||||
// Initialize dead queue with dead outputs of foldable SwitchOps.
|
// Initialize dead queue with dead outputs of foldable SwitchOps.
|
||||||
@ -277,7 +277,13 @@ void SwitchFold::runOnFunction() {
|
|||||||
if (failed(FoldMergeNodes(getFunction(), queue))) return signalPassFailure();
|
if (failed(FoldMergeNodes(getFunction(), queue))) return signalPassFailure();
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
static PassRegistration<SwitchFold> pass(
|
namespace tf_executor {
|
||||||
|
std::unique_ptr<FunctionPassBase> CreateSwitchFoldPass() {
|
||||||
|
return std::make_unique<SwitchFoldPass>();
|
||||||
|
}
|
||||||
|
} // namespace tf_executor
|
||||||
|
|
||||||
|
static PassRegistration<SwitchFoldPass> pass(
|
||||||
"tf-switch-fold", "Fold switch nodes with constant predicates");
|
"tf-switch-fold", "Fold switch nodes with constant predicates");
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@ -41,6 +41,9 @@ std::unique_ptr<FunctionPassBase> CreateRaiseTFControlFlowPass();
|
|||||||
namespace tf_executor {
|
namespace tf_executor {
|
||||||
class GraphOp;
|
class GraphOp;
|
||||||
|
|
||||||
|
// Returns a pass that folds switch nodes with constant predicates.
|
||||||
|
std::unique_ptr<FunctionPassBase> CreateSwitchFoldPass();
|
||||||
|
|
||||||
// Create a pass to merge IslandOps from TFExecutor dialect.
|
// Create a pass to merge IslandOps from TFExecutor dialect.
|
||||||
std::unique_ptr<FunctionPassBase> CreateTFExecutorIslandCoarseningPass();
|
std::unique_ptr<FunctionPassBase> CreateTFExecutorIslandCoarseningPass();
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user