Add option to GraphPruning pass to allow it to skip main function.

PiperOrigin-RevId: 285514262
Change-Id: I62e1bbc4763727d87ecfa88e9a89d7a465cbd939
This commit is contained in:
Yanan Cao 2019-12-13 19:31:07 -08:00 committed by TensorFlower Gardener
parent 6c014754ee
commit c49a5661db
4 changed files with 52 additions and 5 deletions

View File

@ -167,3 +167,16 @@ func @control_fetch(%arg0 : i32) {
} }
return return
} }
// Check that @main function is pruned.
// CHECK-LABEL: func @main
func @main() {
tf_executor.graph {
// CHECK-NOT: tf_executor.island
%0 = tf_executor.island {
tf_executor.yield
}
tf_executor.fetch
}
return
}

View File

@ -0,0 +1,14 @@
// RUN: tf-opt %s -tf-executor-graph-pruning=skip-main-func | FileCheck %s --dump-input=fail
// Check that @main function is skipped by default.
// CHECK-LABEL: func @main
func @main() {
tf_executor.graph {
// CHECKT: tf_executor.island
%0 = tf_executor.island {
tf_executor.yield
}
tf_executor.fetch
}
return
}

View File

@ -86,17 +86,36 @@ namespace {
// This transformation pass prunes a TF graph eliminating dead-nodes. // This transformation pass prunes a TF graph eliminating dead-nodes.
struct GraphPruning : public FunctionPass<GraphPruning> { struct GraphPruning : public FunctionPass<GraphPruning> {
void runOnFunction() override { void runOnFunction() override {
getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); }); FuncOp func = getFunction();
if (func.getName() == "main" && skip_main_func) return;
func.walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
} }
struct Options : public PassOptions<Options> {
Option<bool> skip_main_func{
*this, "skip-main-func",
llvm::cl::desc("skip graph pruning for main function"),
llvm::cl::init(false)};
};
explicit GraphPruning(bool skip_main_func)
: FunctionPass<GraphPruning>(), skip_main_func(skip_main_func) {}
explicit GraphPruning(const Options& option)
: GraphPruning(option.skip_main_func) {}
private:
bool skip_main_func;
}; };
} // namespace } // namespace
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass() { std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
return std::make_unique<GraphPruning>(); bool skip_main_func) {
return std::make_unique<GraphPruning>(skip_main_func);
} }
static PassRegistration<GraphPruning> pass( static PassRegistration<GraphPruning, GraphPruning::Options> pass(
"tf-executor-graph-pruning", "tf-executor-graph-pruning",
"Prune unreachable nodes in a TensorFlow Graph."); "Prune unreachable nodes in a TensorFlow Graph.");

View File

@ -79,7 +79,8 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateSwitchFoldPass();
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass(); std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass();
// Create a pass to prune tf_executor.graph from dead nodes. // Create a pass to prune tf_executor.graph from dead nodes.
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(); std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
bool skip_main_func = false);
// Prunes unreachable operations of a tf_executor.graph operation. // Prunes unreachable operations of a tf_executor.graph operation.
void PruneGraph(GraphOp graph); void PruneGraph(GraphOp graph);