[XLA:GPU] Allow using ChannelLayoutConstraints for XLA:GPU

PiperOrigin-RevId: 359178668
Change-Id: Iac7816ca263e6f0283f59f02d4f0844a76691f77
This commit is contained in:
George Karpenkov 2021-02-23 18:09:05 -08:00 committed by TensorFlower Gardener
parent 2b2e7b64ca
commit ea6f79938f
2 changed files with 11 additions and 6 deletions

View File

@ -296,9 +296,11 @@ Status GpuCompiler::OptimizeHloModule(
// Layout assignment uses alias analysis, which requires the call graph to
// be flattened.
pipeline.AddPass<FlattenCallGraph>();
ChannelLayoutConstraints layout_constraints;
pipeline.AddPass<GpuLayoutAssignment>(
hlo_module->mutable_entry_computation_layout(),
LayoutAssignment::InstructionCanChangeLayout, stream_exec);
LayoutAssignment::InstructionCanChangeLayout, stream_exec,
&layout_constraints);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}

View File

@ -29,12 +29,15 @@ namespace gpu {
// layout constraints for operands and results of library calls.
class GpuLayoutAssignment : public LayoutAssignment {
public:
explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout,
std::function<bool(const HloInstruction*)>
instruction_can_change_layout_func,
se::StreamExecutor* stream_executor)
explicit GpuLayoutAssignment(
ComputationLayout* entry_computation_layout,
std::function<bool(const HloInstruction*)>
instruction_can_change_layout_func,
se::StreamExecutor* stream_executor,
ChannelLayoutConstraints* channel_constraints = nullptr)
: LayoutAssignment(entry_computation_layout,
std::move(instruction_can_change_layout_func)),
std::move(instruction_can_change_layout_func),
channel_constraints),
stream_executor_(stream_executor) {}
~GpuLayoutAssignment() override {}