[XLA/GPU] Remove uses of BufferAssignment in nested computation lowering. NFC
BufferAssignment doesn't really do intelligent work in nested computations, because all allocations are either constants or local allocas. Removing it helps with the XLA/GPU -> LHLO migration.

PiperOrigin-RevId: 323653069
Change-Id: I50acf0bdf07072102b145e5ffb387d1215f3887c
This commit is contained in:
parent
711e05bd78
commit
a78f101f8e
@ -35,7 +35,7 @@ namespace gpu {
|
||||
using absl::StrAppend;
|
||||
using absl::StrCat;
|
||||
|
||||
Status HloToIrBindings::EmitBasePointersForHlos(
|
||||
void HloToIrBindings::EmitBasePointersForHlos(
|
||||
absl::Span<const HloInstruction* const> io_hlos,
|
||||
absl::Span<const HloInstruction* const> non_io_hlos) {
|
||||
CHECK(is_nested_);
|
||||
@ -77,44 +77,23 @@ Status HloToIrBindings::EmitBasePointersForHlos(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!buffer_assignment_->HasTopLevelAllocation(non_io_hlo)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto status = ShapeUtil::ForEachSubshapeWithStatus(
|
||||
ShapeUtil::ForEachSubshape(
|
||||
non_io_hlo->shape(),
|
||||
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
|
||||
// A non-IO HLO with a buffer is bound to an alloca if it is
|
||||
// thread-local.
|
||||
auto slice_result =
|
||||
buffer_assignment_->GetUniqueSlice(non_io_hlo, index);
|
||||
if (!slice_result.ok()) {
|
||||
return Status::OK();
|
||||
}
|
||||
const BufferAllocation::Slice slice =
|
||||
slice_result.ConsumeValueOrDie();
|
||||
if (slice.allocation()->is_thread_local()) {
|
||||
if (non_io_hlo->opcode() == HloOpcode::kConstant) {
|
||||
llvm::Value* global_for_constant = module_->getGlobalVariable(
|
||||
llvm_ir::ConstantHloToGlobalName(*non_io_hlo));
|
||||
BindHloToIrValue(*non_io_hlo, global_for_constant);
|
||||
} else {
|
||||
llvm::Type* pointee_type =
|
||||
llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
|
||||
BindHloToIrValue(*non_io_hlo,
|
||||
llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
pointee_type, /*name=*/"", b_),
|
||||
index);
|
||||
} else if (slice.allocation()->is_constant()) {
|
||||
llvm::Value* global_for_constant = module_->getGlobalVariable(
|
||||
llvm_ir::ConstantBufferAllocationToGlobalName(
|
||||
*slice.allocation()));
|
||||
BindHloToIrValue(*non_io_hlo, global_for_constant);
|
||||
} else {
|
||||
return InternalError(
|
||||
"Nested computation are not expected to take the temporary "
|
||||
"buffer. All buffers are either constant or thread-local.");
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
|
||||
@ -214,7 +193,8 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
|
||||
// Therefore if hlo's output buffer is not modified within consumer, and if
|
||||
// consumer runs hlo only once (so that it doesn't create two different
|
||||
// outputs), then we can mark ir_array as invariant over the whole program.
|
||||
if (BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) {
|
||||
if (!is_nested_ &&
|
||||
BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) {
|
||||
VLOG(2) << "Marking " << hlo.name() << " as invariant within "
|
||||
<< consumer.name();
|
||||
ir_array.MarkInvariantOverWholeProgram(&module_->getContext());
|
||||
|
||||
@ -43,7 +43,7 @@ class HloToIrBindings {
|
||||
b_(b),
|
||||
module_(llvm_module) {}
|
||||
|
||||
Status EmitBasePointersForHlos(
|
||||
void EmitBasePointersForHlos(
|
||||
absl::Span<const HloInstruction* const> io_hlos,
|
||||
absl::Span<const HloInstruction* const> non_io_hlos);
|
||||
|
||||
|
||||
@ -104,7 +104,7 @@ Status IrEmitterNested::CodegenNestedComputation() {
|
||||
non_io_hlos.push_back(hlo);
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(bindings_.EmitBasePointersForHlos(io_hlos, non_io_hlos));
|
||||
bindings_.EmitBasePointersForHlos(io_hlos, non_io_hlos);
|
||||
|
||||
TF_RETURN_IF_ERROR(nested_computation_.root_instruction()->Accept(this));
|
||||
b_.SetInsertPoint(ret_instr);
|
||||
|
||||
@ -54,9 +54,7 @@ string SanitizeConstantName(const HloInstruction& instr) {
|
||||
return instr_name;
|
||||
}
|
||||
|
||||
string ConstantBufferAllocationToGlobalName(
|
||||
const BufferAllocation& allocation) {
|
||||
const HloInstruction& instr = InstrForConstantBufferAllocation(allocation);
|
||||
string ConstantHloToGlobalName(const HloInstruction& instr) {
|
||||
string instr_name = instr.name();
|
||||
// Check that names are sanitized and stored in the HLO instructions
|
||||
// before constant buffer allocation.
|
||||
@ -64,6 +62,11 @@ string ConstantBufferAllocationToGlobalName(
|
||||
return absl::StrCat("buffer_for_", instr_name);
|
||||
}
|
||||
|
||||
string ConstantBufferAllocationToGlobalName(
|
||||
const BufferAllocation& allocation) {
|
||||
return ConstantHloToGlobalName(InstrForConstantBufferAllocation(allocation));
|
||||
}
|
||||
|
||||
const Literal& LiteralForConstantAllocation(
|
||||
const BufferAllocation& allocation) {
|
||||
return InstrForConstantBufferAllocation(allocation).literal();
|
||||
|
||||
@ -24,6 +24,9 @@ namespace llvm_ir {
|
||||
// name of the corresponding constant buffer. In particular, it replaces . and
|
||||
// - with _.
|
||||
string SanitizeConstantName(const HloInstruction& instr);
|
||||
|
||||
string ConstantHloToGlobalName(const HloInstruction& instr);
|
||||
|
||||
// In XLA:GPU we map constant buffer allocations to globals in the generated
|
||||
// LLVM IR. This function gives us the name of the global variable a constant
|
||||
// buffer is mapped to. Not used on XLA:CPU.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user