Change TPUExtractOutsideCompilation pass for a Module pass.
This is needed for getting the devices from the module for assigning host device for outside compilation launch op. PiperOrigin-RevId: 317169244 Change-Id: I734e7eeef3fdb037045d070ffd736be4ef8edee1
This commit is contained in:
parent
8944a3eeb1
commit
2663edb669
@ -281,7 +281,8 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTPUHostComputationExpansionPass();
|
||||
|
||||
// Creates a pass that extract outside compilation (CPU ops inside TPU cluster)
|
||||
// ops to a separate parallel_execute region to run on CPU.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateTPUExtractOutsideCompilationPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateTPUExtractOutsideCompilationPass();
|
||||
|
||||
// Populates the supplied passmanager with the passes required to run the
|
||||
void CreateTPUBridgePipeline(OpPassManager& pm);
|
||||
|
@ -49,8 +49,9 @@ using OutsideClusterMap =
|
||||
// TODO(b/154363171): Add example tranformations.
|
||||
|
||||
struct TPUExtractOutsideCompilation
|
||||
: public PassWrapper<TPUExtractOutsideCompilation, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
: public PassWrapper<TPUExtractOutsideCompilation,
|
||||
OperationPass<ModuleOp>> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Collects and clusters ops in `block` with the same `_xla_outside_compilation`
|
||||
@ -305,9 +306,9 @@ void CreateParallelExecuteFromOutsideClusters(
|
||||
}
|
||||
}
|
||||
|
||||
void TPUExtractOutsideCompilation::runOnFunction() {
|
||||
void TPUExtractOutsideCompilation::runOnOperation() {
|
||||
auto extract_result =
|
||||
getFunction().walk([&](tf_device::ClusterOp tpu_cluster) {
|
||||
getOperation().walk([&](tf_device::ClusterOp tpu_cluster) {
|
||||
OutsideClusterMap clusters;
|
||||
if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
|
||||
&clusters)))
|
||||
@ -325,7 +326,7 @@ void TPUExtractOutsideCompilation::runOnFunction() {
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateTPUExtractOutsideCompilationPass() {
|
||||
return std::make_unique<TPUExtractOutsideCompilation>();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user