Add option to GraphPruning pass to allow it to skip main function.
PiperOrigin-RevId: 285514262 Change-Id: I62e1bbc4763727d87ecfa88e9a89d7a465cbd939
This commit is contained in:
parent
6c014754ee
commit
c49a5661db
|
@ -167,3 +167,16 @@ func @control_fetch(%arg0 : i32) {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check that @main function is pruned.
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() {
|
||||
tf_executor.graph {
|
||||
// CHECK-NOT: tf_executor.island
|
||||
%0 = tf_executor.island {
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
// RUN: tf-opt %s -tf-executor-graph-pruning=skip-main-func | FileCheck %s --dump-input=fail
|
||||
|
||||
// Check that @main function is skipped by default.
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() {
|
||||
tf_executor.graph {
|
||||
// CHECKT: tf_executor.island
|
||||
%0 = tf_executor.island {
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
|
@ -86,17 +86,36 @@ namespace {
|
|||
// This transformation pass prunes a TF graph eliminating dead-nodes.
|
||||
struct GraphPruning : public FunctionPass<GraphPruning> {
|
||||
void runOnFunction() override {
|
||||
getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
|
||||
FuncOp func = getFunction();
|
||||
if (func.getName() == "main" && skip_main_func) return;
|
||||
func.walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
|
||||
}
|
||||
|
||||
struct Options : public PassOptions<Options> {
|
||||
Option<bool> skip_main_func{
|
||||
*this, "skip-main-func",
|
||||
llvm::cl::desc("skip graph pruning for main function"),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
|
||||
explicit GraphPruning(bool skip_main_func)
|
||||
: FunctionPass<GraphPruning>(), skip_main_func(skip_main_func) {}
|
||||
|
||||
explicit GraphPruning(const Options& option)
|
||||
: GraphPruning(option.skip_main_func) {}
|
||||
|
||||
private:
|
||||
bool skip_main_func;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass() {
|
||||
return std::make_unique<GraphPruning>();
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
|
||||
bool skip_main_func) {
|
||||
return std::make_unique<GraphPruning>(skip_main_func);
|
||||
}
|
||||
|
||||
static PassRegistration<GraphPruning> pass(
|
||||
static PassRegistration<GraphPruning, GraphPruning::Options> pass(
|
||||
"tf-executor-graph-pruning",
|
||||
"Prune unreachable nodes in a TensorFlow Graph.");
|
||||
|
||||
|
|
|
@ -79,7 +79,8 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateSwitchFoldPass();
|
|||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass();
|
||||
|
||||
// Create a pass to prune tf_executor.graph from dead nodes.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass();
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
|
||||
bool skip_main_func = false);
|
||||
|
||||
// Prunes unreachable operations of a tf_executor.graph operation.
|
||||
void PruneGraph(GraphOp graph);
|
||||
|
|
Loading…
Reference in New Issue