Change TPUExtractOutsideCompilation pass for a Module pass.

This is needed for getting the devices from the module for assigning host device for outside compilation launch op.

PiperOrigin-RevId: 317169244
Change-Id: I734e7eeef3fdb037045d070ffd736be4ef8edee1
This commit is contained in:
Ken Franko 2020-06-18 13:29:28 -07:00 committed by TensorFlower Gardener
parent 8944a3eeb1
commit 2663edb669
2 changed files with 8 additions and 6 deletions

View File

@ -281,7 +281,8 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTPUHostComputationExpansionPass();
// Creates a pass that extract outside compilation (CPU ops inside TPU cluster)
// ops to a separate parallel_execute region to run on CPU.
std::unique_ptr<OperationPass<FuncOp>> CreateTPUExtractOutsideCompilationPass();
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass();
// Populates the supplied passmanager with the passes required to run the
void CreateTPUBridgePipeline(OpPassManager& pm);

View File

@ -49,8 +49,9 @@ using OutsideClusterMap =
// TODO(b/154363171): Add example tranformations.
struct TPUExtractOutsideCompilation
: public PassWrapper<TPUExtractOutsideCompilation, FunctionPass> {
void runOnFunction() override;
: public PassWrapper<TPUExtractOutsideCompilation,
OperationPass<ModuleOp>> {
void runOnOperation() override;
};
// Collects and clusters ops in `block` with the same `_xla_outside_compilation`
@ -305,9 +306,9 @@ void CreateParallelExecuteFromOutsideClusters(
}
}
void TPUExtractOutsideCompilation::runOnFunction() {
void TPUExtractOutsideCompilation::runOnOperation() {
auto extract_result =
getFunction().walk([&](tf_device::ClusterOp tpu_cluster) {
getOperation().walk([&](tf_device::ClusterOp tpu_cluster) {
OutsideClusterMap clusters;
if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
&clusters)))
@ -325,7 +326,7 @@ void TPUExtractOutsideCompilation::runOnFunction() {
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass() {
return std::make_unique<TPUExtractOutsideCompilation>();
}