Change TPUExtractOutsideCompilation pass for a Module pass.
This is needed for getting the devices from the module for assigning host device for outside compilation launch op. PiperOrigin-RevId: 317169244 Change-Id: I734e7eeef3fdb037045d070ffd736be4ef8edee1
This commit is contained in:
parent
8944a3eeb1
commit
2663edb669
@ -281,7 +281,8 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTPUHostComputationExpansionPass();
|
|||||||
|
|
||||||
// Creates a pass that extract outside compilation (CPU ops inside TPU cluster)
|
// Creates a pass that extract outside compilation (CPU ops inside TPU cluster)
|
||||||
// ops to a separate parallel_execute region to run on CPU.
|
// ops to a separate parallel_execute region to run on CPU.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> CreateTPUExtractOutsideCompilationPass();
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
CreateTPUExtractOutsideCompilationPass();
|
||||||
|
|
||||||
// Populates the supplied passmanager with the passes required to run the
|
// Populates the supplied passmanager with the passes required to run the
|
||||||
void CreateTPUBridgePipeline(OpPassManager& pm);
|
void CreateTPUBridgePipeline(OpPassManager& pm);
|
||||||
|
@ -49,8 +49,9 @@ using OutsideClusterMap =
|
|||||||
// TODO(b/154363171): Add example tranformations.
|
// TODO(b/154363171): Add example tranformations.
|
||||||
|
|
||||||
struct TPUExtractOutsideCompilation
|
struct TPUExtractOutsideCompilation
|
||||||
: public PassWrapper<TPUExtractOutsideCompilation, FunctionPass> {
|
: public PassWrapper<TPUExtractOutsideCompilation,
|
||||||
void runOnFunction() override;
|
OperationPass<ModuleOp>> {
|
||||||
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Collects and clusters ops in `block` with the same `_xla_outside_compilation`
|
// Collects and clusters ops in `block` with the same `_xla_outside_compilation`
|
||||||
@ -305,9 +306,9 @@ void CreateParallelExecuteFromOutsideClusters(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TPUExtractOutsideCompilation::runOnFunction() {
|
void TPUExtractOutsideCompilation::runOnOperation() {
|
||||||
auto extract_result =
|
auto extract_result =
|
||||||
getFunction().walk([&](tf_device::ClusterOp tpu_cluster) {
|
getOperation().walk([&](tf_device::ClusterOp tpu_cluster) {
|
||||||
OutsideClusterMap clusters;
|
OutsideClusterMap clusters;
|
||||||
if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
|
if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
|
||||||
&clusters)))
|
&clusters)))
|
||||||
@ -325,7 +326,7 @@ void TPUExtractOutsideCompilation::runOnFunction() {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
CreateTPUExtractOutsideCompilationPass() {
|
CreateTPUExtractOutsideCompilationPass() {
|
||||||
return std::make_unique<TPUExtractOutsideCompilation>();
|
return std::make_unique<TPUExtractOutsideCompilation>();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user