[XLA/GPU] Add optional BufferAllocation fields to IrEmitterContext, and replace some of the IrEmitterUnnested BufferAssignment uses with BufferAllocation uses.

Also move the MLIR dialect registration to where MLIRContext gets created.

PiperOrigin-RevId: 349604730
Change-Id: I016fca6fbc20c8a0bc6ce8219b32ac38fe862d64
commit c08face2b4 (parent e5f34fd4cc)
Tim Shen, 2020-12-30 14:48:22 -08:00; committed by TensorFlower Gardener
5 changed files with 51 additions and 41 deletions
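In effect, IrEmitterContext now has a second, BufferAssignment-free way to receive allocations, and downstream emitters read them back through a single accessor. A minimal usage sketch (not part of this change; the helper name and the include path are assumptions for illustration, only set_allocations() and allocations() come from this commit):

#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"

namespace xla {
namespace gpu {

// Hypothetical helper: hand a standalone span of BufferAllocations to a
// context that was constructed without a BufferAssignment, then read it back
// the way the emitters now do.
absl::Span<const BufferAllocation> InjectStandaloneAllocations(
    IrEmitterContext* ctx,
    absl::Span<const BufferAllocation> standalone_allocations) {
  // set_allocations() CHECK-fails if the context already owns a
  // BufferAssignment, so the two allocation sources can never disagree.
  ctx->set_allocations(standalone_allocations);
  // Callers such as GetAllocationSliceForMlir now go through allocations()
  // instead of buffer_assignment().Allocations().
  return ctx->allocations();
}

}  // namespace gpu
}  // namespace xla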

@@ -578,6 +578,9 @@ static Status CompileModuleToLlvmIrImpl(
                              "after_optimizations");
   mlir::MLIRContext mlir_context;
+  mlir_context.loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
+                           mlir::StandardOpsDialect,
+                           mlir::lmhlo_gpu::LmhloGpuDialect>();
   IrEmitterContext ir_emitter_context(
       hlo_module, buffer_assignment->get(), platform_name, gpu_device_info,

@@ -50,11 +50,7 @@ class IrEmitterContext {
         cuda_compute_capability_(cuda_compute_capability),
         profile_index_map_(profile_index_map),
         mlir_context_(mlir_context),
-        llvm_module_(llvm_module) {
-    mlir_context_->loadDialect<
-        mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
-        mlir::StandardOpsDialect, mlir::lmhlo_gpu::LmhloGpuDialect>();
-  }
+        llvm_module_(llvm_module) {}
   // Disallow copy and assign.
   IrEmitterContext(const IrEmitterContext&) = delete;
   IrEmitterContext& operator=(const IrEmitterContext&) = delete;
@@ -76,9 +72,22 @@ class IrEmitterContext {
   std::vector<GpuExecutable::ConstantInfo>& constants() { return constants_; }
 
+  absl::Span<const BufferAllocation> allocations() const {
+    if (buffer_assignment_) {
+      return buffer_assignment_->Allocations();
+    }
+    return allocations_;
+  }
+
+  void set_allocations(absl::Span<const BufferAllocation> allocations) {
+    CHECK_EQ(nullptr, buffer_assignment_);
+    allocations_ = allocations;
+  }
+
  private:
   const HloModule* hlo_module_;
   const BufferAssignment* buffer_assignment_;
+  absl::Span<const BufferAllocation> allocations_;
   std::string platform_name_;
   GpuDeviceInfo gpu_device_info_;
   absl::optional<CudaComputeCapability> cuda_compute_capability_;

@@ -566,12 +566,7 @@ StatusOr<Shape> GetConsistentInputShapeForRootSlices(
 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
                                      const HloComputation* hlo_computation,
                                      IrEmitterContext* ir_emitter_context)
-    : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false),
-      hlo_computation_(hlo_computation),
-      mlir_scratch_module_(mlir::ModuleOp::create(
-          mlir::Builder(ir_emitter_context->mlir_context()).getUnknownLoc())),
-      lhlo_scratch_emitter_(ir_emitter_context_->buffer_assignment(),
-                            *hlo_computation, mlir_scratch_module_.get()) {}
+    : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false) {}
 
 StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create(
     const HloModuleConfig& hlo_module_config,
@@ -579,8 +574,15 @@ StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create(
     IrEmitterContext* ir_emitter_context) {
   auto emitter = std::unique_ptr<IrEmitterUnnested>(new IrEmitterUnnested(
       hlo_module_config, hlo_computation, ir_emitter_context));
-  TF_RETURN_IF_ERROR(emitter->lhlo_scratch_emitter_.Initialize());
-  TF_RETURN_IF_ERROR(emitter->EmitConstants(*hlo_computation, true));
+  if (hlo_computation) {
+    emitter->mlir_scratch_module_.emplace(mlir::ModuleOp::create(
+        mlir::Builder(ir_emitter_context->mlir_context()).getUnknownLoc()));
+    emitter->lhlo_scratch_emitter_.emplace(
+        emitter->ir_emitter_context_->buffer_assignment(), *hlo_computation,
+        emitter->mlir_scratch_module_->get());
+    TF_RETURN_IF_ERROR(emitter->lhlo_scratch_emitter_->Initialize());
+    TF_RETURN_IF_ERROR(emitter->EmitConstants(*hlo_computation, true));
+  }
 
   return std::move(emitter);
 }
@@ -656,9 +658,8 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
 StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSliceForMlir(
     ::mlir::Value v) {
-  absl::Span<const BufferAllocation> allocations(
-      ir_emitter_context_->buffer_assignment().Allocations());
-  return xla::gpu::GetAllocationSliceForMlir(v, allocations);
+  return xla::gpu::GetAllocationSliceForMlir(
+      v, ir_emitter_context_->allocations());
 }
 
 Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
@@ -1583,7 +1584,7 @@ static Status ProcessFusionForConversion(mlir::Region* region,
 StatusOr<MlirEmitterInput> IrEmitterUnnested::GetMlirEmitterInput(
     HloInstruction* hlo) {
   MlirEmitterInput input;
-  TF_ASSIGN_OR_RETURN(input.op, lhlo_scratch_emitter_.EmitOp(hlo));
+  TF_ASSIGN_OR_RETURN(input.op, lhlo_scratch_emitter_->EmitOp(hlo));
   input.thunk_info = GetThunkInfo(hlo);
   if (hlo->shape().IsTuple()) {
     const auto& buffer_assignment = ir_emitter_context_->buffer_assignment();
@@ -1751,7 +1752,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
       TF_ASSIGN_OR_RETURN(
           const auto dim_numbers,
-          lhlo_scratch_emitter_.GetScatterDimensionNumbers(root));
+          lhlo_scratch_emitter_->GetScatterDimensionNumbers(root));
       ScatterDescriptor desc;
       desc.name = IrName(root);
@@ -2647,7 +2648,7 @@ IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region,
 Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
   MlirEmitterInput result;
-  TF_ASSIGN_OR_RETURN(auto sort_op, lhlo_scratch_emitter_.EmitOp(sort));
+  TF_ASSIGN_OR_RETURN(auto sort_op, lhlo_scratch_emitter_->EmitOp(sort));
   result.op = sort_op;
   const auto& buffer_assignment = ir_emitter_context_->buffer_assignment();
   auto& slice = result.extra_slice.emplace();
@@ -2852,14 +2853,6 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
   AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
       mlir_input.thunk_info, std::move(thunks)));
-  if (context.operand_shapes.size() > 1) {
-    // Emit the tuple as part of the last stage of sorting.
-    // We are currently in the block sorted.in_bounds.after.
-    b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
-    llvm_ir::EmitTuple(
-        ir_arrays.back(),
-        absl::MakeSpan(ir_arrays).subspan(0, ir_arrays.size() - 1), &b_);
-  }
 
   return Status::OK();
 }
@@ -3300,8 +3293,6 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
     absl::Span<const BufferSlice* const> slices,
     std::function<void(const BufferSlice*, llvm::Value*)>
         bind_slice_to_ir_value) {
-  const auto& buffer_assn = ir_emitter_context_->buffer_assignment();
-
   // Figure out which buffer allocations need to be passed as arguments to our
   // kernel. This is simply all of the allocations referenced in slices,
   // plus the XLA temp buffer (if we have it). We always include the temp
@@ -3312,7 +3303,7 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
     buffers_needed.insert(slice->buffer_slice.allocation());
   }
   absl::optional<const BufferAllocation*> temp_buffer;
-  for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
+  for (const BufferAllocation& alloc : ir_emitter_context_->allocations()) {
     if (alloc.IsPreallocatedTempBuffer()) {
       if (!temp_buffer.has_value()) {
         // Retrieve the first seen temp buffer.
@@ -5858,6 +5849,15 @@ Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(
   return info;
 }
 
+Status IrEmitterUnnested::EmitOp(MlirEmitterInput mlir_input) {
+  if (mlir::isa<mlir::lmhlo::SortOp>(mlir_input.op)) {
+    return EmitSortFromMlir(mlir_input);
+  }
+  LOG(FATAL)
+      << "This function is for test only, and the op is not implemented: "
+      << MlirToString(mlir_input.op);
+}
+
 void MlirEmitterContext::SetOperation(mlir::Operation* op) {
   this->name = mlir::GetNameFromLoc(op->getLoc());

@@ -201,6 +201,8 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandleReplicaId(HloInstruction* hlo) override;
   Status HandleCollectivePermute(HloInstruction* hlo) override;
 
+  Status EmitOp(MlirEmitterInput mlir_input);
+
   Status EmitTargetElementLoop(
       const HloInstruction& hlo,
       const llvm_ir::ElementGenerator& body_emitter) override;
@@ -715,16 +717,18 @@ class IrEmitterUnnested : public IrEmitter,
   // The thunk sequence this IrEmitter generates for the input computation.
   ThunkSequence thunk_sequence_;
 
-  // The HloComputation that this IrEmitter emits code for.
-  const HloComputation* hlo_computation_;
-
-  mlir::OwningModuleRef mlir_scratch_module_;
+  // Begin optional members for XLA HLO -> LMHLO:
+  // TODO(timshen): Once XLA HLO -> LMHLO converter is complete,
+  // IrEmitterUnnested should take LMHLO only, and won't require a scratch
+  // module.
+  absl::optional<mlir::OwningModuleRef> mlir_scratch_module_;
 
   // This is for cache-purpose only. It has no significant semantics.
-  mlir::LhloDialectEmitter lhlo_scratch_emitter_;
+  absl::optional<mlir::LhloDialectEmitter> lhlo_scratch_emitter_;
 
   absl::flat_hash_map<const mlir::Region*, std::unique_ptr<HloModule>>
       scratch_nested_computations_;
+  // End optional members for XLA HLO -> LMHLO.
 };
 
 }  // namespace gpu

@@ -334,12 +334,6 @@ compare {
 // CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
 // CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
 // CHECK: sort.in_bounds-after:
-// CHECK-NEXT: [[TMP13:%.*]] = bitcast [2 x [3 x i32]]* [[TMP1]] to i8*
-// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], i64 0, i64 0
-// CHECK-NEXT: store i8* [[TMP13]], i8** [[TMP14]], align 8
-// CHECK-NEXT: [[TMP15:%.*]] = bitcast [2 x [3 x float]]* [[TMP3]] to i8*
-// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], i64 0, i64 1
-// CHECK-NEXT: store i8* [[TMP15]], i8** [[TMP16]], align 8
 // CHECK-NEXT: ret void
 // CHECK: sort.in_bounds-true:
 // CHECK-NEXT: [[TMP17:%.*]] = mul i64 [[TMP10]], 2