[XLA/GPU] Remove uses of BufferAssignment in nested computation lowering. NFC

BufferAssignment doesn't really do intelligent work in nested computations, because all allocations are either constants or local allocas.

Removing it helps with XLA/GPU -> LHLO migration.

PiperOrigin-RevId: 323653069
Change-Id: I50acf0bdf07072102b145e5ffb387d1215f3887c
This commit is contained in:
Tim Shen 2020-07-28 14:24:31 -07:00 committed by TensorFlower Gardener
parent 711e05bd78
commit a78f101f8e
5 changed files with 20 additions and 34 deletions

View File

@ -35,7 +35,7 @@ namespace gpu {
using absl::StrAppend;
using absl::StrCat;
Status HloToIrBindings::EmitBasePointersForHlos(
void HloToIrBindings::EmitBasePointersForHlos(
absl::Span<const HloInstruction* const> io_hlos,
absl::Span<const HloInstruction* const> non_io_hlos) {
CHECK(is_nested_);
@ -77,44 +77,23 @@ Status HloToIrBindings::EmitBasePointersForHlos(
continue;
}
if (!buffer_assignment_->HasTopLevelAllocation(non_io_hlo)) {
continue;
}
auto status = ShapeUtil::ForEachSubshapeWithStatus(
ShapeUtil::ForEachSubshape(
non_io_hlo->shape(),
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
// A non-IO HLO with a buffer is bound to an alloca if it is
// thread-local.
auto slice_result =
buffer_assignment_->GetUniqueSlice(non_io_hlo, index);
if (!slice_result.ok()) {
return Status::OK();
}
const BufferAllocation::Slice slice =
slice_result.ConsumeValueOrDie();
if (slice.allocation()->is_thread_local()) {
if (non_io_hlo->opcode() == HloOpcode::kConstant) {
llvm::Value* global_for_constant = module_->getGlobalVariable(
llvm_ir::ConstantHloToGlobalName(*non_io_hlo));
BindHloToIrValue(*non_io_hlo, global_for_constant);
} else {
llvm::Type* pointee_type =
llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
BindHloToIrValue(*non_io_hlo,
llvm_ir::EmitAllocaAtFunctionEntry(
pointee_type, /*name=*/"", b_),
index);
} else if (slice.allocation()->is_constant()) {
llvm::Value* global_for_constant = module_->getGlobalVariable(
llvm_ir::ConstantBufferAllocationToGlobalName(
*slice.allocation()));
BindHloToIrValue(*non_io_hlo, global_for_constant);
} else {
return InternalError(
"Nested computation are not expected to take the temporary "
"buffer. All buffers are either constant or thread-local.");
}
return Status::OK();
});
TF_RETURN_IF_ERROR(status);
}
return Status::OK();
}
llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
@ -214,7 +193,8 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
// Therefore if hlo's output buffer is not modified within consumer, and if
// consumer runs hlo only once (so that it doesn't create two different
// outputs), then we can mark ir_array as invariant over the whole program.
if (BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) {
if (!is_nested_ &&
BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) {
VLOG(2) << "Marking " << hlo.name() << " as invariant within "
<< consumer.name();
ir_array.MarkInvariantOverWholeProgram(&module_->getContext());

View File

@ -43,7 +43,7 @@ class HloToIrBindings {
b_(b),
module_(llvm_module) {}
Status EmitBasePointersForHlos(
void EmitBasePointersForHlos(
absl::Span<const HloInstruction* const> io_hlos,
absl::Span<const HloInstruction* const> non_io_hlos);

View File

@ -104,7 +104,7 @@ Status IrEmitterNested::CodegenNestedComputation() {
non_io_hlos.push_back(hlo);
}
}
TF_RETURN_IF_ERROR(bindings_.EmitBasePointersForHlos(io_hlos, non_io_hlos));
bindings_.EmitBasePointersForHlos(io_hlos, non_io_hlos);
TF_RETURN_IF_ERROR(nested_computation_.root_instruction()->Accept(this));
b_.SetInsertPoint(ret_instr);

View File

@ -54,9 +54,7 @@ string SanitizeConstantName(const HloInstruction& instr) {
return instr_name;
}
string ConstantBufferAllocationToGlobalName(
const BufferAllocation& allocation) {
const HloInstruction& instr = InstrForConstantBufferAllocation(allocation);
string ConstantHloToGlobalName(const HloInstruction& instr) {
string instr_name = instr.name();
// Check that names are sanitized and stored in the HLO instructions
// before constant buffer allocation.
@ -64,6 +62,11 @@ string ConstantBufferAllocationToGlobalName(
return absl::StrCat("buffer_for_", instr_name);
}
string ConstantBufferAllocationToGlobalName(
const BufferAllocation& allocation) {
return ConstantHloToGlobalName(InstrForConstantBufferAllocation(allocation));
}
const Literal& LiteralForConstantAllocation(
const BufferAllocation& allocation) {
return InstrForConstantBufferAllocation(allocation).literal();

View File

@ -24,6 +24,9 @@ namespace llvm_ir {
// name of the corresponding constant buffer. In particular, it replaces . and
// - with _.
string SanitizeConstantName(const HloInstruction& instr);
string ConstantHloToGlobalName(const HloInstruction& instr);
// In XLA:GPU we map constant buffer allocations to globals in the generated
// LLVM IR. This function gives us the name of the global variable a constant
// buffer is mapped to. Not used on XLA:CPU.