Add option to GraphPruning pass to allow it to skip main function.
PiperOrigin-RevId: 285514262 Change-Id: I62e1bbc4763727d87ecfa88e9a89d7a465cbd939
This commit is contained in:
parent
6c014754ee
commit
c49a5661db
|
@ -167,3 +167,16 @@ func @control_fetch(%arg0 : i32) {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check that @main function is pruned.
|
||||||
|
// CHECK-LABEL: func @main
|
||||||
|
func @main() {
|
||||||
|
tf_executor.graph {
|
||||||
|
// CHECK-NOT: tf_executor.island
|
||||||
|
%0 = tf_executor.island {
|
||||||
|
tf_executor.yield
|
||||||
|
}
|
||||||
|
tf_executor.fetch
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
// RUN: tf-opt %s -tf-executor-graph-pruning=skip-main-func | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// Check that @main function is skipped by default.
|
||||||
|
// CHECK-LABEL: func @main
|
||||||
|
func @main() {
|
||||||
|
tf_executor.graph {
|
||||||
|
// CHECKT: tf_executor.island
|
||||||
|
%0 = tf_executor.island {
|
||||||
|
tf_executor.yield
|
||||||
|
}
|
||||||
|
tf_executor.fetch
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
|
@ -86,17 +86,36 @@ namespace {
|
||||||
// This transformation pass prunes a TF graph eliminating dead-nodes.
|
// This transformation pass prunes a TF graph eliminating dead-nodes.
|
||||||
struct GraphPruning : public FunctionPass<GraphPruning> {
|
struct GraphPruning : public FunctionPass<GraphPruning> {
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
|
FuncOp func = getFunction();
|
||||||
|
if (func.getName() == "main" && skip_main_func) return;
|
||||||
|
func.walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Options : public PassOptions<Options> {
|
||||||
|
Option<bool> skip_main_func{
|
||||||
|
*this, "skip-main-func",
|
||||||
|
llvm::cl::desc("skip graph pruning for main function"),
|
||||||
|
llvm::cl::init(false)};
|
||||||
|
};
|
||||||
|
|
||||||
|
explicit GraphPruning(bool skip_main_func)
|
||||||
|
: FunctionPass<GraphPruning>(), skip_main_func(skip_main_func) {}
|
||||||
|
|
||||||
|
explicit GraphPruning(const Options& option)
|
||||||
|
: GraphPruning(option.skip_main_func) {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool skip_main_func;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass() {
|
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
|
||||||
return std::make_unique<GraphPruning>();
|
bool skip_main_func) {
|
||||||
|
return std::make_unique<GraphPruning>(skip_main_func);
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<GraphPruning> pass(
|
static PassRegistration<GraphPruning, GraphPruning::Options> pass(
|
||||||
"tf-executor-graph-pruning",
|
"tf-executor-graph-pruning",
|
||||||
"Prune unreachable nodes in a TensorFlow Graph.");
|
"Prune unreachable nodes in a TensorFlow Graph.");
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,8 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateSwitchFoldPass();
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass();
|
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass();
|
||||||
|
|
||||||
// Create a pass to prune tf_executor.graph from dead nodes.
|
// Create a pass to prune tf_executor.graph from dead nodes.
|
||||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass();
|
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
|
||||||
|
bool skip_main_func = false);
|
||||||
|
|
||||||
// Prunes unreachable operations of a tf_executor.graph operation.
|
// Prunes unreachable operations of a tf_executor.graph operation.
|
||||||
void PruneGraph(GraphOp graph);
|
void PruneGraph(GraphOp graph);
|
||||||
|
|
Loading…
Reference in New Issue