Add switch fold pass to tf_tfl_translate passes

PiperOrigin-RevId: 267004531
This commit is contained in:
Jacques Pienaar 2019-09-03 13:49:23 -07:00 committed by TensorFlower Gardener
parent 30485bb7b9
commit 1479da3cb2
4 changed files with 14 additions and 3 deletions

View File

@ -521,6 +521,7 @@ cc_library(
":tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_fold_switch",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",

View File

@ -41,6 +41,7 @@ bool ShouldRunQuantizePasses(mlir::ModuleOp m) {
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
mlir::PassManager* pass_manager) {
pass_manager->addPass(mlir::tf_executor::CreateSwitchFoldPass());
pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion());
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass());
// Ophint extraction will happen after island extraction pass.

View File

@ -58,7 +58,7 @@ limitations under the License.
namespace mlir {
namespace {
class SwitchFold : public mlir::FunctionPass<SwitchFold> {
class SwitchFoldPass : public mlir::FunctionPass<SwitchFoldPass> {
public:
void runOnFunction() override;
};
@ -266,7 +266,7 @@ bool HasSendOrReceive(FuncOp function) {
.wasInterrupted();
}
void SwitchFold::runOnFunction() {
void SwitchFoldPass::runOnFunction() {
if (HasSendOrReceive(getFunction())) return;
DeadQueue queue;
// Initialize dead queue with dead outputs of foldable SwitchOps.
@ -277,7 +277,13 @@ void SwitchFold::runOnFunction() {
if (failed(FoldMergeNodes(getFunction(), queue))) return signalPassFailure();
} // namespace mlir
static PassRegistration<SwitchFold> pass(
namespace tf_executor {
std::unique_ptr<FunctionPassBase> CreateSwitchFoldPass() {
return std::make_unique<SwitchFoldPass>();
}
} // namespace tf_executor
static PassRegistration<SwitchFoldPass> pass(
"tf-switch-fold", "Fold switch nodes with constant predicates");
} // namespace mlir

View File

@ -41,6 +41,9 @@ std::unique_ptr<FunctionPassBase> CreateRaiseTFControlFlowPass();
namespace tf_executor {
class GraphOp;
// Returns a pass that folds switch nodes with constant predicates.
std::unique_ptr<FunctionPassBase> CreateSwitchFoldPass();
// Create a pass to merge IslandOps from TFExecutor dialect.
std::unique_ptr<FunctionPassBase> CreateTFExecutorIslandCoarseningPass();