Add option to GraphPruning pass to allow it to skip main function.

This knob isn't needed for encapsulated TPU subgraph

PiperOrigin-RevId: 286092252
Change-Id: I21d07b2ccc41ab6fac5cc61413e0171588dd9570
This commit is contained in:
Yanan Cao 2019-12-17 17:08:11 -08:00 committed by TensorFlower Gardener
parent d0fe2f87f8
commit a51291d9b4
5 changed files with 6 additions and 54 deletions

View File

@ -167,16 +167,3 @@ func @control_fetch(%arg0 : i32) {
}
return
}
// Check that @main function is pruned.
// CHECK-LABEL: func @main
func @main() {
tf_executor.graph {
// CHECK-NOT: tf_executor.island
%0 = tf_executor.island {
tf_executor.yield
}
tf_executor.fetch
}
return
}

View File

@ -1,14 +0,0 @@
// RUN: tf-opt %s -tf-executor-graph-pruning=skip-main-func | FileCheck %s --dump-input=fail
// Check that @main function is skipped by default.
// CHECK-LABEL: func @main
func @main() {
tf_executor.graph {
// CHECKT: tf_executor.island
%0 = tf_executor.island {
tf_executor.yield
}
tf_executor.fetch
}
return
}

View File

@ -86,36 +86,17 @@ namespace {
// This transformation pass prunes a TF graph eliminating dead-nodes.
struct GraphPruning : public FunctionPass<GraphPruning> {
void runOnFunction() override {
FuncOp func = getFunction();
if (func.getName() == "main" && skip_main_func) return;
func.walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
getFunction().walk([](tf_executor::GraphOp graph) { PruneGraph(graph); });
}
struct Options : public PassOptions<Options> {
Option<bool> skip_main_func{
*this, "skip-main-func",
llvm::cl::desc("skip graph pruning for main function"),
llvm::cl::init(false)};
};
explicit GraphPruning(bool skip_main_func)
: FunctionPass<GraphPruning>(), skip_main_func(skip_main_func) {}
explicit GraphPruning(const Options& option)
: GraphPruning(option.skip_main_func) {}
private:
bool skip_main_func;
};
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
bool skip_main_func) {
return std::make_unique<GraphPruning>(skip_main_func);
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass() {
return std::make_unique<GraphPruning>();
}
static PassRegistration<GraphPruning, GraphPruning::Options> pass(
static PassRegistration<GraphPruning> pass(
"tf-executor-graph-pruning",
"Prune unreachable nodes in a TensorFlow Graph.");

View File

@ -79,8 +79,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateSwitchFoldPass();
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorIslandCoarseningPass();
// Create a pass to prune tf_executor.graph from dead nodes.
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass(
bool skip_main_func = false);
std::unique_ptr<OpPassBase<FuncOp>> CreateTFExecutorGraphPruningPass();
// Prunes unreachable operations of a tf_executor.graph operation.
void PruneGraph(GraphOp graph);

View File

@ -216,8 +216,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
// and canonicalization opportunities that are necessary for the second
// LegalizeTFPass(allow_partial_conversion=false) invocation.
tf2xla.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass(true));
tf2xla.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass(
/*skip_main_func=*/true));
tf2xla.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass());
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
tf2xla.addNestedPass<mlir::FuncOp>(
mlir::xla_hlo::createLegalizeTFPass(false));