[XLA:GPU] Allow using ChannelLayoutConstraints for XLA:GPU
PiperOrigin-RevId: 359178668 Change-Id: Iac7816ca263e6f0283f59f02d4f0844a76691f77
This commit is contained in:
parent
2b2e7b64ca
commit
ea6f79938f
@ -296,9 +296,11 @@ Status GpuCompiler::OptimizeHloModule(
|
|||||||
// Layout assignment uses alias analysis, which requires the call graph to
|
// Layout assignment uses alias analysis, which requires the call graph to
|
||||||
// be flattened.
|
// be flattened.
|
||||||
pipeline.AddPass<FlattenCallGraph>();
|
pipeline.AddPass<FlattenCallGraph>();
|
||||||
|
ChannelLayoutConstraints layout_constraints;
|
||||||
pipeline.AddPass<GpuLayoutAssignment>(
|
pipeline.AddPass<GpuLayoutAssignment>(
|
||||||
hlo_module->mutable_entry_computation_layout(),
|
hlo_module->mutable_entry_computation_layout(),
|
||||||
LayoutAssignment::InstructionCanChangeLayout, stream_exec);
|
LayoutAssignment::InstructionCanChangeLayout, stream_exec,
|
||||||
|
&layout_constraints);
|
||||||
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
|
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -29,12 +29,15 @@ namespace gpu {
|
|||||||
// layout constraints for operands and results of library calls.
|
// layout constraints for operands and results of library calls.
|
||||||
class GpuLayoutAssignment : public LayoutAssignment {
|
class GpuLayoutAssignment : public LayoutAssignment {
|
||||||
public:
|
public:
|
||||||
explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout,
|
explicit GpuLayoutAssignment(
|
||||||
std::function<bool(const HloInstruction*)>
|
ComputationLayout* entry_computation_layout,
|
||||||
instruction_can_change_layout_func,
|
std::function<bool(const HloInstruction*)>
|
||||||
se::StreamExecutor* stream_executor)
|
instruction_can_change_layout_func,
|
||||||
|
se::StreamExecutor* stream_executor,
|
||||||
|
ChannelLayoutConstraints* channel_constraints = nullptr)
|
||||||
: LayoutAssignment(entry_computation_layout,
|
: LayoutAssignment(entry_computation_layout,
|
||||||
std::move(instruction_can_change_layout_func)),
|
std::move(instruction_can_change_layout_func),
|
||||||
|
channel_constraints),
|
||||||
stream_executor_(stream_executor) {}
|
stream_executor_(stream_executor) {}
|
||||||
~GpuLayoutAssignment() override {}
|
~GpuLayoutAssignment() override {}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user