[XLA/GPU] Add optional BufferAllocation fields to IrEmitterContext, and
replace some of the IrEmitterUnnested BufferAssignment uses with
BufferAllocation uses. Also move the MLIR dialect registration to where
the MLIRContext gets created.

PiperOrigin-RevId: 349604730
Change-Id: I016fca6fbc20c8a0bc6ce8219b32ac38fe862d64
commit c08face2b4
parent e5f34fd4cc
@@ -578,6 +578,9 @@ static Status CompileModuleToLlvmIrImpl(
       "after_optimizations");
 
   mlir::MLIRContext mlir_context;
+  mlir_context.loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
+                           mlir::StandardOpsDialect,
+                           mlir::lmhlo_gpu::LmhloGpuDialect>();
 
   IrEmitterContext ir_emitter_context(
       hlo_module, buffer_assignment->get(), platform_name, gpu_device_info,
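Note: the dialect registration above now happens once, where the MLIRContext itself is created; the next hunk removes the same registration from the IrEmitterContext constructor. A minimal sketch of the pattern, reduced to a single dialect and with a hypothetical helper name:

    #include "mlir/Dialect/StandardOps/IR/Ops.h"
    #include "mlir/IR/MLIRContext.h"

    // Load every dialect the pipeline will emit ops from, right after the
    // context is constructed; ops from an unloaded dialect cannot be built.
    void LoadGpuPipelineDialects(mlir::MLIRContext& context) {
      context.loadDialect<mlir::StandardOpsDialect>();
    }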
@@ -50,11 +50,7 @@ class IrEmitterContext {
         cuda_compute_capability_(cuda_compute_capability),
         profile_index_map_(profile_index_map),
         mlir_context_(mlir_context),
-        llvm_module_(llvm_module) {
-    mlir_context_->loadDialect<
-        mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
-        mlir::StandardOpsDialect, mlir::lmhlo_gpu::LmhloGpuDialect>();
-  }
+        llvm_module_(llvm_module) {}
 
   // Disallow copy and assign.
   IrEmitterContext(const IrEmitterContext&) = delete;
   IrEmitterContext& operator=(const IrEmitterContext&) = delete;
@@ -76,9 +72,22 @@ class IrEmitterContext {
 
   std::vector<GpuExecutable::ConstantInfo>& constants() { return constants_; }
 
+  absl::Span<const BufferAllocation> allocations() const {
+    if (buffer_assignment_) {
+      return buffer_assignment_->Allocations();
+    }
+    return allocations_;
+  }
+
+  void set_allocations(absl::Span<const BufferAllocation> allocations) {
+    CHECK_EQ(nullptr, buffer_assignment_);
+    allocations_ = allocations;
+  }
+
  private:
   const HloModule* hlo_module_;
   const BufferAssignment* buffer_assignment_;
+  absl::Span<const BufferAllocation> allocations_;
   std::string platform_name_;
   GpuDeviceInfo gpu_device_info_;
   absl::optional<CudaComputeCapability> cuda_compute_capability_;
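Note: the new accessor pair gives IrEmitterContext two backing modes: a full BufferAssignment (the XLA HLO path) or a bare allocation span installed via set_allocations() (the MLIR path). A stripped-down sketch of the same fallback pattern, using hypothetical stand-in types rather than the XLA ones:

    #include <cassert>
    #include <vector>
    #include "absl/types/span.h"

    struct Allocation { int index; };
    struct Assignment {
      std::vector<Allocation> allocations;
      absl::Span<const Allocation> Allocations() const { return allocations; }
    };

    class EmitterContext {
     public:
      explicit EmitterContext(const Assignment* assignment)
          : assignment_(assignment) {}

      // Prefer the assignment when present, else the externally set span.
      absl::Span<const Allocation> allocations() const {
        if (assignment_) return assignment_->Allocations();
        return allocations_;
      }

      // Only legal when no assignment backs this context.
      void set_allocations(absl::Span<const Allocation> allocations) {
        assert(assignment_ == nullptr);
        allocations_ = allocations;
      }

     private:
      const Assignment* assignment_;              // may be null
      absl::Span<const Allocation> allocations_;  // non-owning fallback
    };

Either way the span is non-owning, so a caller of set_allocations() must keep the allocation storage alive for the context's lifetime.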
@@ -566,12 +566,7 @@ StatusOr<Shape> GetConsistentInputShapeForRootSlices(
 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
                                      const HloComputation* hlo_computation,
                                      IrEmitterContext* ir_emitter_context)
-    : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false),
-      hlo_computation_(hlo_computation),
-      mlir_scratch_module_(mlir::ModuleOp::create(
-          mlir::Builder(ir_emitter_context->mlir_context()).getUnknownLoc())),
-      lhlo_scratch_emitter_(ir_emitter_context_->buffer_assignment(),
-                            *hlo_computation, mlir_scratch_module_.get()) {}
+    : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false) {}
 
 StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create(
     const HloModuleConfig& hlo_module_config,
@@ -579,8 +574,15 @@ StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create(
     IrEmitterContext* ir_emitter_context) {
   auto emitter = std::unique_ptr<IrEmitterUnnested>(new IrEmitterUnnested(
       hlo_module_config, hlo_computation, ir_emitter_context));
-  TF_RETURN_IF_ERROR(emitter->lhlo_scratch_emitter_.Initialize());
-  TF_RETURN_IF_ERROR(emitter->EmitConstants(*hlo_computation, true));
+  if (hlo_computation) {
+    emitter->mlir_scratch_module_.emplace(mlir::ModuleOp::create(
+        mlir::Builder(ir_emitter_context->mlir_context()).getUnknownLoc()));
+    emitter->lhlo_scratch_emitter_.emplace(
+        emitter->ir_emitter_context_->buffer_assignment(), *hlo_computation,
+        emitter->mlir_scratch_module_->get());
+    TF_RETURN_IF_ERROR(emitter->lhlo_scratch_emitter_->Initialize());
+    TF_RETURN_IF_ERROR(emitter->EmitConstants(*hlo_computation, true));
+  }
   return std::move(emitter);
 }
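Note: the work that used to live in the constructor's initializer list moves into Create(), guarded by `if (hlo_computation)`, because the scratch module and LHLO emitter are only meaningful on the XLA HLO path. A sketch of the optional-member, two-phase-construction pattern with hypothetical types:

    #include <memory>
    #include "absl/types/optional.h"

    struct Scratch {
      explicit Scratch(int seed) { /* expensive setup */ }
    };

    class Emitter {
     public:
      // The constructor stays trivial; the factory engages optional
      // members only on the code path that needs them.
      static std::unique_ptr<Emitter> Create(bool needs_scratch) {
        auto emitter = std::unique_ptr<Emitter>(new Emitter());
        if (needs_scratch) {
          emitter->scratch_.emplace(/*seed=*/42);
        }
        return emitter;
      }

     private:
      Emitter() = default;
      absl::optional<Scratch> scratch_;  // engaged on one path only
    };

Every later use of such a member goes through operator->, which is why the hunks below change lhlo_scratch_emitter_. to lhlo_scratch_emitter_->.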
@@ -656,9 +658,8 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
 
 StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSliceForMlir(
     ::mlir::Value v) {
-  absl::Span<const BufferAllocation> allocations(
-      ir_emitter_context_->buffer_assignment().Allocations());
-  return xla::gpu::GetAllocationSliceForMlir(v, allocations);
+  return xla::gpu::GetAllocationSliceForMlir(
+      v, ir_emitter_context_->allocations());
 }
 
 Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
@@ -1583,7 +1584,7 @@ static Status ProcessFusionForConversion(mlir::Region* region,
 StatusOr<MlirEmitterInput> IrEmitterUnnested::GetMlirEmitterInput(
     HloInstruction* hlo) {
   MlirEmitterInput input;
-  TF_ASSIGN_OR_RETURN(input.op, lhlo_scratch_emitter_.EmitOp(hlo));
+  TF_ASSIGN_OR_RETURN(input.op, lhlo_scratch_emitter_->EmitOp(hlo));
   input.thunk_info = GetThunkInfo(hlo);
   if (hlo->shape().IsTuple()) {
     const auto& buffer_assignment = ir_emitter_context_->buffer_assignment();
@@ -1751,7 +1752,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
 
   TF_ASSIGN_OR_RETURN(
       const auto dim_numbers,
-      lhlo_scratch_emitter_.GetScatterDimensionNumbers(root));
+      lhlo_scratch_emitter_->GetScatterDimensionNumbers(root));
 
   ScatterDescriptor desc;
   desc.name = IrName(root);
@@ -2647,7 +2648,7 @@ IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region,
 Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
   MlirEmitterInput result;
 
-  TF_ASSIGN_OR_RETURN(auto sort_op, lhlo_scratch_emitter_.EmitOp(sort));
+  TF_ASSIGN_OR_RETURN(auto sort_op, lhlo_scratch_emitter_->EmitOp(sort));
   result.op = sort_op;
   const auto& buffer_assignment = ir_emitter_context_->buffer_assignment();
   auto& slice = result.extra_slice.emplace();
@@ -2852,14 +2853,6 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
 
   AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
       mlir_input.thunk_info, std::move(thunks)));
-  if (context.operand_shapes.size() > 1) {
-    // Emit the tuple as part of the last stage of sorting.
-    // We are currently in the block sorted.in_bounds.after.
-    b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
-    llvm_ir::EmitTuple(
-        ir_arrays.back(),
-        absl::MakeSpan(ir_arrays).subspan(0, ir_arrays.size() - 1), &b_);
-  }
   return Status::OK();
 }
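Note: the tuple emission deleted here is what produced the bitcast/getelementptr/store sequence in the sort.in_bounds-after block; the matching CHECK-NEXT lines are removed from the sorting test in the final hunk of this diff.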
@@ -3300,8 +3293,6 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
     absl::Span<const BufferSlice* const> slices,
     std::function<void(const BufferSlice*, llvm::Value*)>
         bind_slice_to_ir_value) {
-  const auto& buffer_assn = ir_emitter_context_->buffer_assignment();
-
   // Figure out which buffer allocations need to be passed as arguments to our
   // kernel. This is simply all of the allocations referenced in slices,
   // plus the XLA temp buffer (if we have it). We always include the temp
@@ -3312,7 +3303,7 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
     buffers_needed.insert(slice->buffer_slice.allocation());
   }
   absl::optional<const BufferAllocation*> temp_buffer;
-  for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
+  for (const BufferAllocation& alloc : ir_emitter_context_->allocations()) {
     if (alloc.IsPreallocatedTempBuffer()) {
       if (!temp_buffer.has_value()) {
         // Retrieve the first seen temp buffer.
@@ -5858,6 +5849,15 @@ Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(
   return info;
 }
 
+Status IrEmitterUnnested::EmitOp(MlirEmitterInput mlir_input) {
+  if (mlir::isa<mlir::lmhlo::SortOp>(mlir_input.op)) {
+    return EmitSortFromMlir(mlir_input);
+  }
+  LOG(FATAL)
+      << "This function is for test only, and the op is not implemented: "
+      << MlirToString(mlir_input.op);
+}
+
 void MlirEmitterContext::SetOperation(mlir::Operation* op) {
   this->name = mlir::GetNameFromLoc(op->getLoc());
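Note: EmitOp is a test-only dispatcher: it handles the one op kind the tests exercise and fails loudly on everything else. A dependency-free sketch of that shape, with hypothetical types and dynamic_cast standing in for mlir::isa:

    #include <cstdlib>
    #include <iostream>

    struct Op { virtual ~Op() = default; };
    struct SortOp : Op {};

    void Emit(const Op& op) {
      if (dynamic_cast<const SortOp*>(&op) != nullptr) {
        // ... emit the sort thunk ...
        return;
      }
      std::cerr << "test-only dispatcher, op not implemented\n";
      std::abort();  // analogous to LOG(FATAL)
    }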
@@ -201,6 +201,8 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandleReplicaId(HloInstruction* hlo) override;
   Status HandleCollectivePermute(HloInstruction* hlo) override;
 
+  Status EmitOp(MlirEmitterInput mlir_input);
+
   Status EmitTargetElementLoop(
       const HloInstruction& hlo,
       const llvm_ir::ElementGenerator& body_emitter) override;
@@ -715,16 +717,18 @@ class IrEmitterUnnested : public IrEmitter,
   // The thunk sequence this IrEmitter generates for the input computation.
   ThunkSequence thunk_sequence_;
 
   // The HloComputation that this IrEmitter emits code for.
   const HloComputation* hlo_computation_;
 
-  mlir::OwningModuleRef mlir_scratch_module_;
+  // Begin optional members for XLA HLO -> LMHLO:
+  // TODO(timshen): Once XLA HLO -> LMHLO converter is complete,
+  // IrEmitterUnnested should take LMHLO only, and won't require a scratch
+  // module.
+  absl::optional<mlir::OwningModuleRef> mlir_scratch_module_;
 
   // This is for cache-purpose only. It has no significant semantics.
-  mlir::LhloDialectEmitter lhlo_scratch_emitter_;
+  absl::optional<mlir::LhloDialectEmitter> lhlo_scratch_emitter_;
 
   absl::flat_hash_map<const mlir::Region*, std::unique_ptr<HloModule>>
       scratch_nested_computations_;
+  // End optional members for XLA HLO -> LMHLO.
 };
 
 }  // namespace gpu
@@ -334,12 +334,6 @@ compare {
 // CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
 // CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
 // CHECK: sort.in_bounds-after:
-// CHECK-NEXT: [[TMP13:%.*]] = bitcast [2 x [3 x i32]]* [[TMP1]] to i8*
-// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], i64 0, i64 0
-// CHECK-NEXT: store i8* [[TMP13]], i8** [[TMP14]], align 8
-// CHECK-NEXT: [[TMP15:%.*]] = bitcast [2 x [3 x float]]* [[TMP3]] to i8*
-// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], i64 0, i64 1
-// CHECK-NEXT: store i8* [[TMP15]], i8** [[TMP16]], align 8
 // CHECK-NEXT: ret void
 // CHECK: sort.in_bounds-true:
 // CHECK-NEXT: [[TMP17:%.*]] = mul i64 [[TMP10]], 2