Change TPUExtractOutsideCompilation pass for a Module pass.

This is needed for getting the devices from the module for assigning host device for outside compilation launch op.

PiperOrigin-RevId: 317169244
Change-Id: I734e7eeef3fdb037045d070ffd736be4ef8edee1
This commit is contained in:
Ken Franko 2020-06-18 13:29:28 -07:00 committed by TensorFlower Gardener
parent 8944a3eeb1
commit 2663edb669
2 changed files with 8 additions and 6 deletions

View File

@ -281,7 +281,8 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTPUHostComputationExpansionPass();
// Creates a pass that extract outside compilation (CPU ops inside TPU cluster) // Creates a pass that extract outside compilation (CPU ops inside TPU cluster)
// ops to a separate parallel_execute region to run on CPU. // ops to a separate parallel_execute region to run on CPU.
std::unique_ptr<OperationPass<FuncOp>> CreateTPUExtractOutsideCompilationPass(); std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass();
// Populates the supplied passmanager with the passes required to run the // Populates the supplied passmanager with the passes required to run the
void CreateTPUBridgePipeline(OpPassManager& pm); void CreateTPUBridgePipeline(OpPassManager& pm);

View File

@ -49,8 +49,9 @@ using OutsideClusterMap =
// TODO(b/154363171): Add example tranformations. // TODO(b/154363171): Add example tranformations.
struct TPUExtractOutsideCompilation struct TPUExtractOutsideCompilation
: public PassWrapper<TPUExtractOutsideCompilation, FunctionPass> { : public PassWrapper<TPUExtractOutsideCompilation,
void runOnFunction() override; OperationPass<ModuleOp>> {
void runOnOperation() override;
}; };
// Collects and clusters ops in `block` with the same `_xla_outside_compilation` // Collects and clusters ops in `block` with the same `_xla_outside_compilation`
@ -305,9 +306,9 @@ void CreateParallelExecuteFromOutsideClusters(
} }
} }
void TPUExtractOutsideCompilation::runOnFunction() { void TPUExtractOutsideCompilation::runOnOperation() {
auto extract_result = auto extract_result =
getFunction().walk([&](tf_device::ClusterOp tpu_cluster) { getOperation().walk([&](tf_device::ClusterOp tpu_cluster) {
OutsideClusterMap clusters; OutsideClusterMap clusters;
if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(), if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
&clusters))) &clusters)))
@ -325,7 +326,7 @@ void TPUExtractOutsideCompilation::runOnFunction() {
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass() { CreateTPUExtractOutsideCompilationPass() {
return std::make_unique<TPUExtractOutsideCompilation>(); return std::make_unique<TPUExtractOutsideCompilation>();
} }