From 5f1da7ab8c5ff813da9a907a01125b35ada94399 Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Fri, 13 Nov 2020 15:02:01 -0800 Subject: [PATCH] [XLA/GPU] Roll-forward previous CL that migrates Copy emitters to take LMHLO. PiperOrigin-RevId: 342345053 Change-Id: I94849d62cbd84e4b43b789bf5e5925159b067a28 --- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/ir_emitter_unnested.cc | 272 ++++++++++++------ .../xla/service/gpu/ir_emitter_unnested.h | 19 +- .../xla/service/gpu/tests/copy-nested.hlo | 22 +- .../compiler/xla/service/gpu/tests/copy.hlo | 8 +- 5 files changed, 217 insertions(+), 105 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 05535c0dbe9..bf2fbfe1973 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -268,6 +268,7 @@ cc_library( "//tensorflow/compiler/mlir:name_utils", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:lhlo", + "//tensorflow/compiler/mlir/xla:hlo_module_importer", "//tensorflow/compiler/mlir/xla:hlo_utils", "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 5857fac93d2..cf31f1ca1f2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -44,6 +44,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/utils/name_utils.h" +#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" @@ -546,6 +547,39 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { return IrEmitter::DefaultAction(hlo); } +Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) { + // Replace unnested op with a fused nested op. + // + // TODO(timshen): Ultimately this should be a pass. It's currently not a pass, + // because we don't have a fully functioning LMHLO graph yet. 
+ + mlir::Location loc = input.op->getLoc(); + mlir::lmhlo::FusionOp fusion = nullptr; + Shape output_shape; + if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(input.op)) { + fusion = mlir::OpBuilder(copy).create<mlir::lmhlo::FusionOp>( + loc, llvm::ArrayRef<mlir::NamedAttribute>()); + copy.getOperation()->moveBefore(&fusion.region().front().back()); + mlir::OpBuilder b(copy); + auto operand = b.create<mlir::TensorLoadOp>(loc, copy.operand()); + HloFunctionImporter::SetLayoutForMlir( + operand, TypeToShape(copy.operand().getType())); + auto fused_copy = b.create<mlir::mhlo::CopyOp>(loc, operand); + output_shape = TypeToShape(copy.output().getType()); + HloFunctionImporter::SetLayoutForMlir(fused_copy, output_shape); + b.create<mlir::TensorStoreOp>(loc, fused_copy, copy.output()); + copy.getOperation()->erase(); + } else { + input.op->dump(); + LOG(FATAL) << "Unimplemented default action for mlir op"; + } + input.op = fusion; + auto ret = EmitLoopFusionFromMlir( + input, output_shape, + ComputeMaxUnrollFactor(output_shape, hlo_module_config_)); + return ret; +} + Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { AddThunkToThunkSequence( BuildKernelThunk(dot, /*implements_whole_instruction=*/true)); @@ -1150,7 +1184,10 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop) << ": " << fusion->ToString(); - if (CheckAndEmitHloWithTile021(fusion)) { + TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(fusion)); + TF_ASSIGN_OR_RETURN(const bool matched_021, + CheckAndEmitHloWithTile021(input)); + if (matched_021) { return Status::OK(); } @@ -1159,35 +1196,46 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { unroll_factor = ComputeMaxUnrollFactor(fusion); } - TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(fusion)); return EmitLoopFusionFromMlir(input, fusion->shape(), unroll_factor); } Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { - CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape())); - const BufferAssignment& buffer_assignment = - ir_emitter_context_->buffer_assignment(); - if (LayoutUtil::Equal(copy->operand(0)->shape().layout(), - copy->shape().layout()) && - buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) { + TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(copy)); + return EmitCopyForMlir(input); +} + +Status IrEmitterUnnested::EmitCopyForMlir(MlirEmitterInput input) { + auto copy = mlir::cast<mlir::lmhlo::CopyOp>(input.op); + auto operand_shape = TypeToShape(copy.operand().getType()); + auto output_shape = TypeToShape(copy.output().getType()); + + CHECK(ShapeUtil::Compatible(operand_shape, output_shape)); + absl::Span<const BufferAllocation> allocations( + ir_emitter_context_->buffer_assignment().Allocations()); + + auto maybe_slice = GetAllocationSliceForMlir(copy.operand(), allocations); + if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) && + maybe_slice.ok()) { // Copy the operand into the output if it's not the same buffer already.
- auto operand_buffer = GetAllocationSlice(*copy->operand(0)); - auto destination_buffer = GetAllocationSlice(*copy); + auto operand_buffer = *maybe_slice; + auto destination_buffer = + *GetAllocationSliceForMlir(copy.output(), allocations); if (operand_buffer != destination_buffer) { AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>( - GetThunkInfo(copy), + input.thunk_info, /*source_address=*/operand_buffer, /*destination_buffer=*/destination_buffer, /*mem_size=*/ - ByteSizeOf(copy->operand(0)->shape()))); + ByteSizeOf(operand_shape))); } return Status::OK(); } - if (CheckAndEmitHloWithTile021(copy)) { + TF_ASSIGN_OR_RETURN(bool matched_021, CheckAndEmitHloWithTile021(input)); + if (matched_021) { return Status::OK(); } - return IrEmitter::HandleCopy(copy); + return DefaultActionForMlir(input); } Status IrEmitterUnnested::EmitExtraOutputsForReduce( @@ -2504,11 +2552,15 @@ static void GetFusionOperandsAndOutputs(mlir::lmhlo::FusionOp fusion, std::vector<mlir::Value>* operands, std::vector<mlir::Value>* outputs) { fusion.region().walk([&](mlir::TensorLoadOp load) { - CHECK(load.memref().getParentRegion() != &fusion.region()); + CHECK(load.memref().getParentRegion() != &fusion.region()) + << "TensorLoadOp is only expected for accessing captured " + "memrefs."; operands->push_back(load.memref()); }); fusion.region().walk([&](mlir::TensorStoreOp store) { - CHECK(store.memref().getParentRegion() != &fusion.region()); + CHECK(store.memref().getParentRegion() != &fusion.region()) + << "TensorStoreOp is only expected for accessing captured " + "memrefs."; outputs->push_back(store.memref()); }); } @@ -3118,16 +3170,16 @@ void IrEmitterUnnested::EmitTile( // dimensions that play the same role in the transpose. // mapping_scheme: Kernel mapping scheme specifying the tiling void IrEmitterUnnested::EmitTileElementForCopy( - HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const Shape& output_shape, const llvm_ir::IrArray& output_array, + const llvm_ir::IrArray::Index& index, const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc, llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) { // TODO(jlebar): Add AA metadata to this load. llvm::Instruction* load_from_shmem_buffer = Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}), "output_element"); - llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo); Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout( - hlo->shape().element_type(), mapping_scheme.GetDimsInElems()); + output_shape.element_type(), mapping_scheme.GetDimsInElems()); // When the output_reduced_shape is a 0-2-1 transpose of the input shape, // the 0-2-1 transpose is achieved through EmitWriteArrayElement. output_array.CastToShape(output_reduced_shape, &b_) @@ -3165,14 +3217,19 @@ static IrArray::Index GetUnnormalizedIndex( // the same role in the transpose. // kernel_info: Other information to support the kernel code generation.
void IrEmitterUnnested::EmitTileElementForFusion( - HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + mlir::lmhlo::FusionOp fusion, + absl::Span operand_arrays, + absl::Span output_arrays, + const llvm_ir::IrArray::Index& index, const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc, llvm::Value* x_loc, absl::Span param_shmem_buffers) { - std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); + const HloComputation* fused_computation = + *GetOrCreateSubComputationFromRegion(&fusion.region(), + /*is_fusion*/ true); GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, GetNestedComputer()); FusedIrEmitter fused_emitter(&elem_emitter); - for (int i = 0; i < hlo->operand_count(); i++) { + for (int i = 0; i < operand_arrays.size(); i++) { llvm_ir::ElementGenerator gen; if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) { gen = [this, param_tile_buffer, x_loc, @@ -3189,23 +3246,24 @@ void IrEmitterUnnested::EmitTileElementForFusion( "tiled_buffer"); }; } else { - const HloInstruction* operand = hlo->operand(i); - gen = [this, operand, hlo](llvm_ir::IrArray::Index index) { - return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_); + auto array = operand_arrays[i]; + gen = [this, array](llvm_ir::IrArray::Index index) { + return array.EmitReadArrayElement(index, &b_); }; } - fused_emitter.BindGenerator(hlo->fused_parameter(i), std::move(gen)); + fused_emitter.BindGenerator(fused_computation->parameter_instruction(i), + std::move(gen)); } IrArray::Index untiled_index = GetUnnormalizedIndex( index, output_arrays[0].GetShape(), &b_, mapping_scheme); llvm_ir::ElementGenerator output_generator = - *fused_emitter.GetGenerator(hlo->fused_expression_root()); + *fused_emitter.GetGenerator(fused_computation->root_instruction()); llvm::Value* output_value = output_generator(untiled_index).ValueOrDie(); - if (hlo->IsMultiOutputFusion()) { + if (output_arrays.size() > 1) { DCHECK(output_value->getType()->isStructTy()); DCHECK_EQ(output_value->getType()->getStructNumElements(), - output_arrays.size()); - for (int64 i = 0; i < output_arrays.size(); ++i) { + output_arrays.size() - 1); + for (int64 i = 0; i < output_arrays.size() - 1; ++i) { output_arrays[i].EmitWriteArrayElement( untiled_index, ExtractValue(output_value, i), &b_); } @@ -3809,10 +3867,15 @@ llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() { // TODO(b/33320379): Here each block transposes 1 tile. It may be more // efficient to launch fewer blocks so each transposes many tiles. 
void IrEmitterUnnested::EmitHlo021Tile( - HloInstruction* hlo, Thunk* kernel_thunk, + mlir::Operation* op, Thunk* kernel_thunk, const MlirEmitterContext& context, + absl::Span operand_arrays, + absl::Span output_arrays, absl::Span reduced_output_dims, absl::Span tiled_param_ids) { constexpr int kNumRows = 4; + + std::string name = mlir::GetNameFromLoc(op->getLoc()); + KernelMappingScheme mapping_scheme(reduced_output_dims, /*tile_sizes=*/{1, kWarpSize, kWarpSize}, /*num_threads_y=*/kNumRows, @@ -3822,14 +3885,16 @@ void IrEmitterUnnested::EmitHlo021Tile( /*is_row_contiguous=*/false); LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), mapping_scheme.GetThreadsPerBlock()); + llvm::Type* index_type = - GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_); + GetIndexTypeForKernelFromMlir(op, launch_dimensions.launch_bound(), &b_); std::vector param_arrays; // For each tiled parameter, cast its input IrArray to the corresponding // reduced shape and keep the reduced shape live during IR emission. std::vector param_in_reduced_shape_arrays; - std::vector param_shmem_buffers(hlo->operand_count(), nullptr); + std::vector param_shmem_buffers(context.operand_shapes.size(), + nullptr); auto get_shared_memory_buffer = [&](llvm::Type* elem_ty, absl::string_view buffer_name) { @@ -3846,20 +3911,18 @@ void IrEmitterUnnested::EmitHlo021Tile( buffer_type, buffer_name); }; - for (int64 id = 0; id < hlo->operand_count(); id++) { - const HloInstruction* param = hlo->operand(id); - param_arrays.push_back(GetIrArray(*param, *hlo)); + for (int64 id = 0; id < context.operand_shapes.size(); id++) { + const Shape& param_shape = context.operand_shapes[id]; + param_arrays.push_back(operand_arrays[id]); if (absl::c_linear_search(tiled_param_ids, id)) { - param_shmem_buffers[id] = - get_shared_memory_buffer(llvm_ir::PrimitiveTypeToIrType( - param->shape().element_type(), module_), - IrName(hlo, StrCat("tile", id))); + param_shmem_buffers[id] = get_shared_memory_buffer( + llvm_ir::PrimitiveTypeToIrType(param_shape.element_type(), module_), + IrName(name, StrCat("tile", id))); VLOG(3) << "Added shmem buffer for parameter " << id << ": " << llvm_ir::DumpToString(*param_shmem_buffers[id]); Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout( - param->shape().element_type(), - Permute({0, 2, 1}, reduced_output_dims)); + param_shape.element_type(), Permute({0, 2, 1}, reduced_output_dims)); param_in_reduced_shape_arrays.push_back( param_arrays[id].CastToShape(reduced_shape, &b_)); } else { @@ -3870,13 +3933,18 @@ void IrEmitterUnnested::EmitHlo021Tile( EmitElementFunction element_generator = [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, int64 x_iter_num) { - if (hlo->opcode() == HloOpcode::kCopy) { - EmitTileElementForCopy(hlo, index, mapping_scheme, y_loc, x_loc, + if (auto copy = mlir::dyn_cast(op)) { + CHECK_EQ(1, context.output_shapes.size()); + EmitTileElementForCopy(context.output_shapes[0], output_arrays[0], + index, mapping_scheme, y_loc, x_loc, param_shmem_buffers); - } else { - CHECK_EQ(hlo->opcode(), HloOpcode::kFusion); - EmitTileElementForFusion(hlo, index, mapping_scheme, y_loc, x_loc, + } else if (auto fusion = mlir::dyn_cast(op)) { + EmitTileElementForFusion(fusion, operand_arrays, output_arrays, index, + mapping_scheme, y_loc, x_loc, param_shmem_buffers); + } else { + op->dump(); + LOG(FATAL) << "Unexpected op type"; } }; @@ -3903,9 +3971,9 @@ void IrEmitterUnnested::EmitHlo021Tile( llvm::Value* x_loc, int64 /*x_iter_num*/) { for 
(int64 id : tiled_param_ids) { IrArray& input_in_logical_shape = - param_in_reduced_shape_arrays[id]; + param_in_reduced_shape_arrays.at(id); - llvm::Value* shmem_buffer = param_shmem_buffers[id]; + llvm::Value* shmem_buffer = param_shmem_buffers.at(id); llvm::Value* zero = llvm::ConstantInt::get(index_type, 0); // TODO(jlebar): Add AA metadata to this store. Tile @@ -3942,10 +4010,11 @@ void IrEmitterUnnested::EmitHlo021Tile( // the hopes of reducing register pressure, since we touch // threadIdx.x and blockIdx.x at the beginning of the kernel // *anyway*. - if (hlo->IsMultiOutputFusion()) { + if (output_arrays.size() > 1) { KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { - llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), - ConstructIrArrayForOutputs(*hlo), &b_); + llvm_ir::EmitTuple(output_arrays.back(), + output_arrays.subspan(0, output_arrays.size() - 1), + &b_); }); } @@ -3955,6 +4024,7 @@ void IrEmitterUnnested::EmitHlo021Tile( } namespace { + // A recursive function to inspect the users of a parameter to determine // whether it's safe for a parameter to participate in a shared-memory // transpose. @@ -3993,14 +4063,29 @@ namespace { // a reduce operations. In this case, the above description on "output" apply // to the result of such a use-chain, which provides the input to the reduce // operation. -bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) { - if (hlo->IsElementwise()) { - return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { - return IsInstructionSafeForShmemTranspose(user); - }); +bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) { + if (mlir::isa(op)) { + return true; } - switch (hlo->opcode()) { + HloOpcode opcode; + if (mlir::isa(op)) { + opcode = HloOpcode::kParameter; + } else { + opcode = *mlir::MhloToHloOpcode(op); + } + if (HloInstruction::IsOpElementwise(opcode)) { + for (mlir::Value v : op->getResults()) { + for (mlir::OpOperand use : v.getUsers()) { + if (!IsInstructionSafeForShmemTranspose(use.getOwner())) { + return false; + } + } + } + return true; + } + + switch (opcode) { // Non-elementwise instructions that don't cause the shmem transpose // to be unsafe, including the instructions that don't currently fuse. case HloOpcode::kGetDimensionSize: @@ -4012,9 +4097,14 @@ bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) { case HloOpcode::kParameter: case HloOpcode::kTuple: case HloOpcode::kTupleSelect: - return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { - return IsInstructionSafeForShmemTranspose(user); - }); + for (mlir::Value v : op->getResults()) { + for (mlir::OpOperand use : v.getUsers()) { + if (!IsInstructionSafeForShmemTranspose(use.getOwner())) { + return false; + } + } + } + return true; default: return false; @@ -4033,15 +4123,21 @@ bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) { // preloaded tile. We inspect all the transitive users of the input parameter // up to the fusion root instruction to see if we can find any instruction // that can make preloading the input tile unsafe. 
-std::vector<int64> FilterInputsForShmemTranspose(const HloInstruction* fusion, +std::vector<int64> FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion, std::vector<int64> input_ids) { + std::vector<mlir::Value> params; + fusion.region().walk([&](mlir::TensorLoadOp load) { + CHECK(load.memref().getParentRegion() != &fusion.region()) + << "TensorLoadOp is only expected for accessing captured " + "memrefs."; + params.push_back(load); + }); + std::vector<int64> filtered_input_ids; - for (int64 i = 0; i < input_ids.size(); ++i) { - const HloInstruction* input = fusion->fused_parameter(input_ids[i]); - if (IsInstructionSafeForShmemTranspose(input)) { - filtered_input_ids.push_back(input_ids[i]); - } else { - VLOG(10) << "Input not safe for shmem transpose " << input->ToString(); + for (int64 input_id : input_ids) { + mlir::Value input = params.at(input_id); + if (IsInstructionSafeForShmemTranspose(input.getDefiningOp())) { + filtered_input_ids.push_back(input_id); } } return filtered_input_ids; @@ -4049,24 +4145,23 @@ std::vector<int64> FilterInputsForShmemTranspose(const HloInstruction* fusion, } // namespace -bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { - HloOpcode opcode = hlo->opcode(); +StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021( + MlirEmitterInput input) { + CHECK(mlir::isa<mlir::lmhlo::FusionOp>(input.op) || + mlir::isa<mlir::lmhlo::CopyOp>(input.op)); - CHECK(hlo->IsLoopFusion() || opcode == HloOpcode::kCopy); - - const Shape& output_shape = hlo->IsMultiOutputFusion() - ? ShapeUtil::GetSubshape(hlo->shape(), {0}) - : hlo->shape(); + MlirEmitterContext context; + context.SetOperation(input.op); // If the output_shape is reduced to 021 shape, find all the parameters of // the HLO that are in the corresponding 012 shape. std::vector<int64> params_012; optional<std::vector<int64>> reduced_dims_021; - for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); + for (int64 operand_idx = 0; operand_idx < context.operand_shapes.size(); ++operand_idx) { - HloInstruction* operand = hlo->mutable_operand(operand_idx); + const Shape& operand_shape = context.operand_shapes[operand_idx]; auto find_transpose_result = - ShapeUtil::FindTranspose021(operand->shape(), output_shape); + ShapeUtil::FindTranspose021(operand_shape, context.output_shapes[0]); if (!find_transpose_result.has_value()) { continue; } @@ -4091,8 +4186,8 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return false; } - if (opcode == HloOpcode::kFusion) { - params_012 = FilterInputsForShmemTranspose(hlo, params_012); + if (auto fusion_op = mlir::dyn_cast<mlir::lmhlo::FusionOp>(input.op)) { + params_012 = FilterInputsForShmemTranspose(fusion_op, params_012); if (params_012.empty()) { return false; } @@ -4120,10 +4215,10 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { constexpr int64 kShmemPerCore = 48 * 1024; int64 shmem_used = 0; for (int64 i = 0; i < params_012.size(); ++i) { - const HloInstruction* operand = hlo->operand(params_012[i]); + const Shape& operand_shape = context.operand_shapes[params_012[i]]; shmem_used += 32 * 33 * - ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type()); + ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type()); if (kMinBlocksPerCore * shmem_used > kShmemPerCore) { // Erase this element and everything after it from params_012.
@@ -4136,10 +4231,15 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return false; } - VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString(); - std::unique_ptr<KernelThunk> kernel_thunk = - BuildKernelThunk(hlo, /*implements_whole_instruction=*/true); - EmitHlo021Tile(hlo, kernel_thunk.get(), *reduced_dims_021, params_012); + std::vector<llvm_ir::IrArray> ir_arrays; + TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk, + BuildKernelThunkForMlir(input.op, input.thunk_info, + input.extra_slice, &ir_arrays)); + EmitHlo021Tile( + input.op, kernel_thunk.get(), context, + absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size()), + absl::MakeSpan(ir_arrays).subspan(context.operand_shapes.size()), + *reduced_dims_021, params_012); AddThunkToThunkSequence(std::move(kernel_thunk)); return true; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index bb875e90450..a5089111150 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -73,6 +73,7 @@ struct MlirEmitterInput { }; // Convenience struct that contains useful data structures in MLIR emitter. +// Not all fields may be filled. It's entirely dependent on the use case. struct MlirEmitterContext { void SetOperation(mlir::Operation* op); @@ -156,11 +157,14 @@ class IrEmitterUnnested : public IrEmitter, } Status DefaultAction(HloInstruction* hlo) override; + Status DefaultActionForMlir(MlirEmitterInput input); // IrEmitterUnnested handles the following instructions differently from // IrEmitter. It also mixes in some special handling for custom kernels // via the ThunkEmitter. Status HandleCopy(HloInstruction* copy) override; + Status EmitCopyForMlir(MlirEmitterInput input); + Status HandleConditional(HloInstruction* conditional) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleCustomCall(HloInstruction* custom_call) override; @@ -467,12 +471,15 @@ class IrEmitterUnnested : public IrEmitter, // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel // for the hlo instruction. - bool CheckAndEmitHloWithTile021(HloInstruction* hlo); + StatusOr<bool> CheckAndEmitHloWithTile021(MlirEmitterInput input); // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and // sets the corresponding launch dimensions. This is a helper to support // the implementation of CheckAndEmitHloWithTile021. - void EmitHlo021Tile(HloInstruction* hlo, Thunk* kernel_thunk, + void EmitHlo021Tile(mlir::Operation* op, Thunk* kernel_thunk, + const MlirEmitterContext& context, + absl::Span<const llvm_ir::IrArray> operand_arrays, + absl::Span<const llvm_ir::IrArray> output_arrays, absl::Span<const int64> reduced_output_dims, absl::Span<const int64> tiled_param_ids); @@ -527,7 +534,8 @@ class IrEmitterUnnested : public IrEmitter, // y_loc: The y coordinate within a tile. // x_loc: The x coordinate within a tile. void EmitTileElementForCopy( - HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const Shape& output_shape, const llvm_ir::IrArray& ir_array, + const llvm_ir::IrArray::Index& index, const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc, llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers); @@ -536,7 +544,10 @@ class IrEmitterUnnested : public IrEmitter, // y_loc: The y coordinate within a tile. // x_loc: The x coordinate within a tile.
void EmitTileElementForFusion( - HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + mlir::lmhlo::FusionOp fusion, + absl::Span operand_arrays, + absl::Span output_arrays, + const llvm_ir::IrArray::Index& index, const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc, llvm::Value* x_loc, absl::Span param_shmem_buffers); diff --git a/tensorflow/compiler/xla/service/gpu/tests/copy-nested.hlo b/tensorflow/compiler/xla/service/gpu/tests/copy-nested.hlo index 4365d4a800a..483175c3a80 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/copy-nested.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/copy-nested.hlo @@ -2,12 +2,12 @@ // CHECK-LABEL: entry: // CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0 -// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [200 x [100 x [300 x float]]]* +// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [100 x [200 x [300 x float]]]* // CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0 -// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [100 x [200 x [300 x float]]]* +// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [200 x [100 x [300 x float]]]* // CHECK: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 // CHECK: %[[VAL_7:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 -// CHECK: %[[VAL_8:.*]] = mul nuw nsw i32 %[[VAL_6]], 128 +// CHECK: %[[VAL_8:.*]] = mul nuw nsw i32 %[[VAL_6]], 256 // CHECK: %[[VAL_9:.*]] = add nuw nsw i32 %[[VAL_8]], %[[VAL_7]] // CHECK: %[[VAL_10:.*]] = icmp ult i32 %[[VAL_9]], {{.*}} // CHECK: call void @llvm.assume(i1 %[[VAL_10]]) @@ -40,24 +40,24 @@ // CHECK: b.in_bounds-after: ; preds = %[[VAL_36]], %[[VAL_38:.*]] // CHECK: ret void // CHECK: b.in_bounds-true: ; preds = %[[VAL_38]] -// CHECK: %[[VAL_39:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_15]], i32 %[[VAL_16]], i32 %[[VAL_13]] +// CHECK: %[[VAL_39:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_15]], i32 %[[VAL_16]], i32 %[[VAL_13]] // CHECK: %[[VAL_40:.*]] = load float, float* %[[VAL_39]], align 4, !invariant.load !4 -// CHECK: %[[VAL_41:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float* +// CHECK: %[[VAL_41:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float* // CHECK: %[[VAL_42:.*]] = getelementptr inbounds float, float* %[[VAL_41]], i32 %[[VAL_11]] // CHECK: store float %[[VAL_40]], float* %[[VAL_42]], align 4 -// CHECK: %[[VAL_43:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_21]], i32 %[[VAL_22]], i32 %[[VAL_19]] +// CHECK: %[[VAL_43:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_21]], i32 %[[VAL_22]], i32 %[[VAL_19]] // CHECK: %[[VAL_44:.*]] = load float, float* %[[VAL_43]], align 4, !invariant.load !4 -// CHECK: %[[VAL_45:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float* +// CHECK: %[[VAL_45:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float* // CHECK: %[[VAL_46:.*]] = getelementptr inbounds float, float* %[[VAL_45]], i32 %[[VAL_17]] // CHECK: store float %[[VAL_44]], float* %[[VAL_46]], align 4 -// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_25]] +// CHECK: %[[VAL_47:.*]] = 
getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_25]] // CHECK: %[[VAL_48:.*]] = load float, float* %[[VAL_47]], align 4, !invariant.load !4 -// CHECK: %[[VAL_49:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float* +// CHECK: %[[VAL_49:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float* // CHECK: %[[VAL_50:.*]] = getelementptr inbounds float, float* %[[VAL_49]], i32 %[[VAL_23]] // CHECK: store float %[[VAL_48]], float* %[[VAL_50]], align 4 -// CHECK: %[[VAL_51:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_33]], i32 %[[VAL_34]], i32 %[[VAL_31]] +// CHECK: %[[VAL_51:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_33]], i32 %[[VAL_34]], i32 %[[VAL_31]] // CHECK: %[[VAL_52:.*]] = load float, float* %[[VAL_51]], align 4, !invariant.load !4 -// CHECK: %[[VAL_53:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float* +// CHECK: %[[VAL_53:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float* // CHECK: %[[VAL_54:.*]] = getelementptr inbounds float, float* %[[VAL_53]], i32 %[[VAL_29]] // CHECK: store float %[[VAL_52]], float* %[[VAL_54]], align 4 // CHECK: br label %[[VAL_37]] diff --git a/tensorflow/compiler/xla/service/gpu/tests/copy.hlo b/tensorflow/compiler/xla/service/gpu/tests/copy.hlo index 594b9a08d43..128378a9130 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/copy.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/copy.hlo @@ -4,10 +4,10 @@ // CHECK: %[[VAL_0:.*]] = alloca i32, align 4 // CHECK: %[[VAL_1:.*]] = alloca i32, align 4 // CHECK: %[[VAL_2:.*]] = getelementptr inbounds i8, i8* %[[VAL_3:.*]], i64 0 -// CHECK: %[[VAL_4:.*]] = bitcast i8* %[[VAL_2]] to [200 x [100 x float]]* +// CHECK: %[[VAL_4:.*]] = bitcast i8* %[[VAL_2]] to [100 x [200 x float]]* // CHECK: %[[VAL_5:.*]] = getelementptr inbounds i8, i8* %[[VAL_6:.*]], i64 0 -// CHECK: %[[VAL_7:.*]] = bitcast i8* %[[VAL_5]] to [100 x [200 x float]]* -// CHECK: %[[VAL_8:.*]] = bitcast [100 x [200 x float]]* %[[VAL_7]] to [1 x [100 x [200 x float]]]* +// CHECK: %[[VAL_7:.*]] = bitcast i8* %[[VAL_5]] to [200 x [100 x float]]* +// CHECK: %[[VAL_8:.*]] = bitcast [100 x [200 x float]]* %[[VAL_4]] to [1 x [100 x [200 x float]]]* // CHECK: %[[VAL_9:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 // CHECK: %[[VAL_10:.*]] = urem i32 %[[VAL_9]], 32 // CHECK: %[[VAL_11:.*]] = udiv i32 %[[VAL_9]], 32 @@ -88,7 +88,7 @@ // CHECK: output_x_in_tile-true: ; preds = %[[VAL_59]] // CHECK: %[[VAL_72:.*]] = getelementptr [32 x [33 x float]], [32 x [33 x float]] addrspace(3)* @b.tile0, i64 0, i32 %[[VAL_65]], i32 %[[VAL_63]] // CHECK: %[[VAL_73:.*]] = load float, float addrspace(3)* %[[VAL_72]], align 4 -// CHECK: %[[VAL_74:.*]] = bitcast [200 x [100 x float]]* %[[VAL_4]] to [1 x [200 x [100 x float]]]* +// CHECK: %[[VAL_74:.*]] = bitcast [200 x [100 x float]]* %[[VAL_7]] to [1 x [200 x [100 x float]]]* // CHECK: %[[VAL_75:.*]] = getelementptr inbounds [1 x [200 x [100 x float]]], [1 x [200 x [100 x float]]]* %[[VAL_74]], i32 0, i32 0, i32 %[[VAL_64]], i32 %[[VAL_66]] // CHECK: store float %[[VAL_73]], float* %[[VAL_75]], align 4 // CHECK: br label %[[VAL_55]]
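
For illustration only, not part of the patch: the 0-2-1 kernels checked by copy.hlo stage each 32x32 tile in a shared-memory buffer padded to 32x33 floats (the "[32 x [33 x float]] addrspace(3)* @b.tile0" buffer above) and use kNumRows = 4 thread rows per block, mirroring what EmitHlo021Tile generates. The hypothetical standalone CUDA sketch below shows the same tiling pattern for the plain 2-D case (leading batch dimension dropped); the kernel name, launch shape, and row-major indexing are assumptions, not code emitted by XLA.

#include <cstdint>

// Sketch of the shared-memory 0-2-1 transpose tiling (assumed names/layout):
// a kWarpSize x kWarpSize tile is staged in shared memory, padded to 33
// columns to avoid bank conflicts, with kNumRows = 4 thread rows per block.
// Input is row-major dim1 x dim2; output is row-major dim2 x dim1.
__global__ void Transpose021Tile(const float* __restrict__ in,
                                 float* __restrict__ out,
                                 int64_t dim1, int64_t dim2) {
  constexpr int kWarpSize = 32;
  constexpr int kNumRows = 4;
  __shared__ float tile[kWarpSize][kWarpSize + 1];  // 32x33: conflict-free columns.

  int64_t tile_x = static_cast<int64_t>(blockIdx.x) * kWarpSize;  // along dim2
  int64_t tile_y = static_cast<int64_t>(blockIdx.y) * kWarpSize;  // along dim1

  // Cooperative, coalesced load of the tile into shared memory.
  for (int j = threadIdx.y; j < kWarpSize; j += kNumRows) {
    int64_t x = tile_x + threadIdx.x;
    int64_t y = tile_y + j;
    if (x < dim2 && y < dim1) tile[j][threadIdx.x] = in[y * dim2 + x];
  }
  __syncthreads();

  // Coalesced write of the transposed tile; the swap happens through shared memory.
  for (int j = threadIdx.y; j < kWarpSize; j += kNumRows) {
    int64_t out_col = tile_y + threadIdx.x;  // indexes dim1
    int64_t out_row = tile_x + j;            // indexes dim2
    if (out_col < dim1 && out_row < dim2)
      out[out_row * dim1 + out_col] = tile[threadIdx.x][j];
  }
}

// Possible launch (assumed): dim3 block(32, 4);
// dim3 grid((dim2 + 31) / 32, (dim1 + 31) / 32);
// Transpose021Tile<<<grid, block>>>(in, out, dim1, dim2);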