diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index bf2fbfe1973..05535c0dbe9 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -268,7 +268,6 @@ cc_library( "//tensorflow/compiler/mlir:name_utils", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:lhlo", - "//tensorflow/compiler/mlir/xla:hlo_module_importer", "//tensorflow/compiler/mlir/xla:hlo_utils", "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 71bb6f99061..5857fac93d2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -44,7 +44,6 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/utils/name_utils.h" -#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" @@ -547,39 +546,6 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { return IrEmitter::DefaultAction(hlo); } -Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) { - // Replace unnested op with a fused nested op. - // - // TODO(timshen): Ultimately this should be a pass. It's currently not a pass, - // because we don't have a fully functioning LMHLO graph yet. - - mlir::Location loc = input.op->getLoc(); - mlir::lmhlo::FusionOp fusion = nullptr; - Shape output_shape; - if (auto copy = mlir::dyn_cast(input.op)) { - fusion = mlir::OpBuilder(copy).create( - loc, llvm::ArrayRef()); - copy.getOperation()->moveBefore(&fusion.region().front().back()); - mlir::OpBuilder b(copy); - auto operand = b.create(loc, copy.operand()); - HloFunctionImporter::SetLayoutForMlir( - operand, TypeToShape(copy.operand().getType())); - auto fused_copy = b.create(loc, operand); - output_shape = TypeToShape(copy.output().getType()); - HloFunctionImporter::SetLayoutForMlir(fused_copy, output_shape); - b.create(loc, fused_copy, copy.output()); - copy.getOperation()->erase(); - } else { - input.op->dump(); - LOG(FATAL) << "Unimplemented default action for mlir op"; - } - input.op = fusion; - auto ret = EmitLoopFusionFromMlir( - input, output_shape, - ComputeMaxUnrollFactor(output_shape, hlo_module_config_)); - return ret; -} - Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { AddThunkToThunkSequence( BuildKernelThunk(dot, /*implements_whole_instruction=*/true)); @@ -1184,10 +1150,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop) << ": " << fusion->ToString(); - TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(fusion)); - TF_ASSIGN_OR_RETURN(const bool matched_021, - CheckAndEmitHloWithTile021(input)); - if (matched_021) { + if (CheckAndEmitHloWithTile021(fusion)) { return Status::OK(); } @@ -1196,46 +1159,35 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { unroll_factor = ComputeMaxUnrollFactor(fusion); } + TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(fusion)); return EmitLoopFusionFromMlir(input, fusion->shape(), unroll_factor); } Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { - TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(copy)); - return EmitCopyForMlir(input); -} - -Status IrEmitterUnnested::EmitCopyForMlir(MlirEmitterInput input) { - auto copy = mlir::cast(input.op); - auto operand_shape = TypeToShape(copy.operand().getType()); - auto output_shape = TypeToShape(copy.output().getType()); - - CHECK(ShapeUtil::Compatible(operand_shape, output_shape)); - absl::Span allocations( - ir_emitter_context_->buffer_assignment().Allocations()); - - auto maybe_slice = GetAllocationSliceForMlir(copy.operand(), allocations); - if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) && - maybe_slice.ok()) { + CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape())); + const BufferAssignment& buffer_assignment = + ir_emitter_context_->buffer_assignment(); + if (LayoutUtil::Equal(copy->operand(0)->shape().layout(), + copy->shape().layout()) && + buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) { // Copy the operand into the output if it's not the same buffer already. - auto operand_buffer = *maybe_slice; - auto destination_buffer = - *GetAllocationSliceForMlir(copy.output(), allocations); + auto operand_buffer = GetAllocationSlice(*copy->operand(0)); + auto destination_buffer = GetAllocationSlice(*copy); if (operand_buffer != destination_buffer) { AddThunkToThunkSequence(absl::make_unique( - input.thunk_info, + GetThunkInfo(copy), /*source_address=*/operand_buffer, /*destination_buffer=*/destination_buffer, /*mem_size=*/ - ByteSizeOf(operand_shape))); + ByteSizeOf(copy->operand(0)->shape()))); } return Status::OK(); } - TF_ASSIGN_OR_RETURN(bool matched_021, CheckAndEmitHloWithTile021(input)); - if (matched_021) { + if (CheckAndEmitHloWithTile021(copy)) { return Status::OK(); } - return DefaultActionForMlir(input); + return IrEmitter::HandleCopy(copy); } Status IrEmitterUnnested::EmitExtraOutputsForReduce( @@ -2552,15 +2504,11 @@ static void GetFusionOperandsAndOutputs(mlir::lmhlo::FusionOp fusion, std::vector* operands, std::vector* outputs) { fusion.region().walk([&](mlir::TensorLoadOp load) { - CHECK(load.memref().getParentRegion() != &fusion.region()) - << "TensorLoadOp shows should be only expected for accessing captured " - "memrefs."; + CHECK(load.memref().getParentRegion() != &fusion.region()); operands->push_back(load.memref()); }); fusion.region().walk([&](mlir::TensorStoreOp store) { - CHECK(store.memref().getParentRegion() != &fusion.region()) - << "TensorStoreOp shows should be only expected for accessing captured " - "memrefs."; + CHECK(store.memref().getParentRegion() != &fusion.region()); outputs->push_back(store.memref()); }); } @@ -3170,16 +3118,16 @@ void IrEmitterUnnested::EmitTile( // dimensions that play the same role in the transpose. // mapping_scheme: Kernel mapping scheme specifying the tiling void IrEmitterUnnested::EmitTileElementForCopy( - const Shape& output_shape, const llvm_ir::IrArray& output_array, - const llvm_ir::IrArray::Index& index, + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc, llvm::Value* x_loc, absl::Span param_shmem_buffers) { // TODO(jlebar): Add AA metadata to this load. llvm::Instruction* load_from_shmem_buffer = Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}), "output_element"); + llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo); Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout( - output_shape.element_type(), mapping_scheme.GetDimsInElems()); + hlo->shape().element_type(), mapping_scheme.GetDimsInElems()); // When the output_reduced_shape is a 0-2-1 transpose of the input shape, // the 0-2-1 transpose is achieved through EmitWriteArrayElement. output_array.CastToShape(output_reduced_shape, &b_) @@ -3217,19 +3165,14 @@ static IrArray::Index GetUnnormalizedIndex( // the same role in the transpose. // kernel_info: Other information to support the kernel code generation. void IrEmitterUnnested::EmitTileElementForFusion( - mlir::lmhlo::FusionOp fusion, - absl::Span operand_arrays, - absl::Span output_arrays, - const llvm_ir::IrArray::Index& index, + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc, llvm::Value* x_loc, absl::Span param_shmem_buffers) { - const HloComputation* fused_computation = - *GetOrCreateSubComputationFromRegion(&fusion.region(), - /*is_fusion*/ true); + std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, GetNestedComputer()); FusedIrEmitter fused_emitter(&elem_emitter); - for (int i = 0; i < operand_arrays.size(); i++) { + for (int i = 0; i < hlo->operand_count(); i++) { llvm_ir::ElementGenerator gen; if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) { gen = [this, param_tile_buffer, x_loc, @@ -3246,24 +3189,23 @@ void IrEmitterUnnested::EmitTileElementForFusion( "tiled_buffer"); }; } else { - auto array = operand_arrays[i]; - gen = [this, array](llvm_ir::IrArray::Index index) { - return array.EmitReadArrayElement(index, &b_); + const HloInstruction* operand = hlo->operand(i); + gen = [this, operand, hlo](llvm_ir::IrArray::Index index) { + return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_); }; } - fused_emitter.BindGenerator(fused_computation->parameter_instruction(i), - std::move(gen)); + fused_emitter.BindGenerator(hlo->fused_parameter(i), std::move(gen)); } IrArray::Index untiled_index = GetUnnormalizedIndex( index, output_arrays[0].GetShape(), &b_, mapping_scheme); llvm_ir::ElementGenerator output_generator = - *fused_emitter.GetGenerator(fused_computation->root_instruction()); + *fused_emitter.GetGenerator(hlo->fused_expression_root()); llvm::Value* output_value = output_generator(untiled_index).ValueOrDie(); - if (output_arrays.size() > 1) { + if (hlo->IsMultiOutputFusion()) { DCHECK(output_value->getType()->isStructTy()); DCHECK_EQ(output_value->getType()->getStructNumElements(), output_arrays.size()); - for (int64 i = 0; i < output_arrays.size() - 1; ++i) { + for (int64 i = 0; i < output_arrays.size(); ++i) { output_arrays[i].EmitWriteArrayElement( untiled_index, ExtractValue(output_value, i), &b_); } @@ -3867,15 +3809,10 @@ llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() { // TODO(b/33320379): Here each block transposes 1 tile. It may be more // efficient to launch fewer blocks so each transposes many tiles. void IrEmitterUnnested::EmitHlo021Tile( - mlir::Operation* op, Thunk* kernel_thunk, const MlirEmitterContext& context, - absl::Span operand_arrays, - absl::Span output_arrays, + HloInstruction* hlo, Thunk* kernel_thunk, absl::Span reduced_output_dims, absl::Span tiled_param_ids) { constexpr int kNumRows = 4; - - std::string name = mlir::GetNameFromLoc(op->getLoc()); - KernelMappingScheme mapping_scheme(reduced_output_dims, /*tile_sizes=*/{1, kWarpSize, kWarpSize}, /*num_threads_y=*/kNumRows, @@ -3885,16 +3822,14 @@ void IrEmitterUnnested::EmitHlo021Tile( /*is_row_contiguous=*/false); LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), mapping_scheme.GetThreadsPerBlock()); - llvm::Type* index_type = - GetIndexTypeForKernelFromMlir(op, launch_dimensions.launch_bound(), &b_); + GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_); std::vector param_arrays; // For each tiled parameter, cast its input IrArray to the corresponding // reduced shape and keep the reduced shape live during IR emission. std::vector param_in_reduced_shape_arrays; - std::vector param_shmem_buffers(context.operand_shapes.size(), - nullptr); + std::vector param_shmem_buffers(hlo->operand_count(), nullptr); auto get_shared_memory_buffer = [&](llvm::Type* elem_ty, absl::string_view buffer_name) { @@ -3911,18 +3846,20 @@ void IrEmitterUnnested::EmitHlo021Tile( buffer_type, buffer_name); }; - for (int64 id = 0; id < context.operand_shapes.size(); id++) { - const Shape& param_shape = context.operand_shapes[id]; - param_arrays.push_back(operand_arrays[id]); + for (int64 id = 0; id < hlo->operand_count(); id++) { + const HloInstruction* param = hlo->operand(id); + param_arrays.push_back(GetIrArray(*param, *hlo)); if (absl::c_linear_search(tiled_param_ids, id)) { - param_shmem_buffers[id] = get_shared_memory_buffer( - llvm_ir::PrimitiveTypeToIrType(param_shape.element_type(), module_), - IrName(name, StrCat("tile", id))); + param_shmem_buffers[id] = + get_shared_memory_buffer(llvm_ir::PrimitiveTypeToIrType( + param->shape().element_type(), module_), + IrName(hlo, StrCat("tile", id))); VLOG(3) << "Added shmem buffer for parameter " << id << ": " << llvm_ir::DumpToString(*param_shmem_buffers[id]); Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout( - param_shape.element_type(), Permute({0, 2, 1}, reduced_output_dims)); + param->shape().element_type(), + Permute({0, 2, 1}, reduced_output_dims)); param_in_reduced_shape_arrays.push_back( param_arrays[id].CastToShape(reduced_shape, &b_)); } else { @@ -3933,18 +3870,13 @@ void IrEmitterUnnested::EmitHlo021Tile( EmitElementFunction element_generator = [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, int64 x_iter_num) { - if (auto copy = mlir::dyn_cast(op)) { - CHECK_EQ(1, context.output_shapes.size()); - EmitTileElementForCopy(context.output_shapes[0], output_arrays[0], - index, mapping_scheme, y_loc, x_loc, + if (hlo->opcode() == HloOpcode::kCopy) { + EmitTileElementForCopy(hlo, index, mapping_scheme, y_loc, x_loc, param_shmem_buffers); - } else if (auto fusion = mlir::dyn_cast(op)) { - EmitTileElementForFusion(fusion, operand_arrays, output_arrays, index, - mapping_scheme, y_loc, x_loc, - param_shmem_buffers); } else { - op->dump(); - LOG(FATAL) << "Unexpected op type"; + CHECK_EQ(hlo->opcode(), HloOpcode::kFusion); + EmitTileElementForFusion(hlo, index, mapping_scheme, y_loc, x_loc, + param_shmem_buffers); } }; @@ -3971,9 +3903,9 @@ void IrEmitterUnnested::EmitHlo021Tile( llvm::Value* x_loc, int64 /*x_iter_num*/) { for (int64 id : tiled_param_ids) { IrArray& input_in_logical_shape = - param_in_reduced_shape_arrays.at(id); + param_in_reduced_shape_arrays[id]; - llvm::Value* shmem_buffer = param_shmem_buffers.at(id); + llvm::Value* shmem_buffer = param_shmem_buffers[id]; llvm::Value* zero = llvm::ConstantInt::get(index_type, 0); // TODO(jlebar): Add AA metadata to this store. Tile @@ -4010,11 +3942,10 @@ void IrEmitterUnnested::EmitHlo021Tile( // the hopes of reducing register pressure, since we touch // threadIdx.x and blockIdx.x at the beginning of the kernel // *anyway*. - if (output_arrays.size() > 1) { + if (hlo->IsMultiOutputFusion()) { KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { - llvm_ir::EmitTuple(output_arrays.back(), - output_arrays.subspan(0, output_arrays.size() - 1), - &b_); + llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), + ConstructIrArrayForOutputs(*hlo), &b_); }); } @@ -4024,7 +3955,6 @@ void IrEmitterUnnested::EmitHlo021Tile( } namespace { - // A recursive function to inspect the users of a parameter to determine // whether it's safe for a parameter to participate in a shared-memory // transpose. @@ -4063,29 +3993,14 @@ namespace { // a reduce operations. In this case, the above description on "output" apply // to the result of such a use-chain, which provides the input to the reduce // operation. -bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) { - if (mlir::isa(op)) { - return true; +bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) { + if (hlo->IsElementwise()) { + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return IsInstructionSafeForShmemTranspose(user); + }); } - HloOpcode opcode; - if (mlir::isa(op)) { - opcode = HloOpcode::kParameter; - } else { - opcode = *mlir::MhloToHloOpcode(op); - } - if (HloInstruction::IsOpElementwise(opcode)) { - for (mlir::Value v : op->getResults()) { - for (mlir::OpOperand use : v.getUsers()) { - if (!IsInstructionSafeForShmemTranspose(use.getOwner())) { - return false; - } - } - } - return true; - } - - switch (opcode) { + switch (hlo->opcode()) { // Non-elementwise instructions that don't cause the shmem transpose // to be unsafe, including the instructions that don't currently fuse. case HloOpcode::kGetDimensionSize: @@ -4097,14 +4012,9 @@ bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) { case HloOpcode::kParameter: case HloOpcode::kTuple: case HloOpcode::kTupleSelect: - for (mlir::Value v : op->getResults()) { - for (mlir::OpOperand use : v.getUsers()) { - if (!IsInstructionSafeForShmemTranspose(use.getOwner())) { - return false; - } - } - } - return true; + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return IsInstructionSafeForShmemTranspose(user); + }); default: return false; @@ -4123,21 +4033,15 @@ bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) { // preloaded tile. We inspect all the transitive users of the input parameter // up to the fusion root instruction to see if we can find any instruction // that can make preloading the input tile unsafe. -std::vector FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion, +std::vector FilterInputsForShmemTranspose(const HloInstruction* fusion, std::vector input_ids) { - std::vector params; - fusion.region().walk([&](mlir::TensorLoadOp load) { - CHECK(load.memref().getParentRegion() != &fusion.region()) - << "TensorLoadOp shows should be only expected for accessing captured " - "memrefs."; - params.push_back(load); - }); - std::vector filtered_input_ids; - for (int64 input_id : input_ids) { - mlir::Value input = params.at(input_id); - if (IsInstructionSafeForShmemTranspose(input.getDefiningOp())) { - filtered_input_ids.push_back(input_id); + for (int64 i = 0; i < input_ids.size(); ++i) { + const HloInstruction* input = fusion->fused_parameter(input_ids[i]); + if (IsInstructionSafeForShmemTranspose(input)) { + filtered_input_ids.push_back(input_ids[i]); + } else { + VLOG(10) << "Input not safe for shmem transpose " << input->ToString(); } } return filtered_input_ids; @@ -4145,23 +4049,24 @@ std::vector FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion, } // namespace -StatusOr IrEmitterUnnested::CheckAndEmitHloWithTile021( - MlirEmitterInput input) { - CHECK(mlir::isa(input.op) || - mlir::isa(input.op)); +bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { + HloOpcode opcode = hlo->opcode(); - MlirEmitterContext context; - context.SetOperation(input.op); + CHECK(hlo->IsLoopFusion() || opcode == HloOpcode::kCopy); + + const Shape& output_shape = hlo->IsMultiOutputFusion() + ? ShapeUtil::GetSubshape(hlo->shape(), {0}) + : hlo->shape(); // If the output_shape is reduced to 021 shape, find all the parameters of // the HLO that are in the corresponding 012 shape. std::vector params_012; optional> reduced_dims_021; - for (int64 operand_idx = 0; operand_idx < context.operand_shapes.size(); + for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); ++operand_idx) { - const Shape& operand_shape = context.operand_shapes[operand_idx]; + HloInstruction* operand = hlo->mutable_operand(operand_idx); auto find_transpose_result = - ShapeUtil::FindTranspose021(operand_shape, context.output_shapes[0]); + ShapeUtil::FindTranspose021(operand->shape(), output_shape); if (!find_transpose_result.has_value()) { continue; } @@ -4186,8 +4091,8 @@ StatusOr IrEmitterUnnested::CheckAndEmitHloWithTile021( return false; } - if (auto fusion_op = mlir::dyn_cast(input.op)) { - params_012 = FilterInputsForShmemTranspose(fusion_op, params_012); + if (opcode == HloOpcode::kFusion) { + params_012 = FilterInputsForShmemTranspose(hlo, params_012); if (params_012.empty()) { return false; } @@ -4215,10 +4120,10 @@ StatusOr IrEmitterUnnested::CheckAndEmitHloWithTile021( constexpr int64 kShmemPerCore = 48 * 1024; int64 shmem_used = 0; for (int64 i = 0; i < params_012.size(); ++i) { - const Shape& operand_shape = context.operand_shapes[params_012[i]]; + const HloInstruction* operand = hlo->operand(params_012[i]); shmem_used += 32 * 33 * - ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type()); + ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type()); if (kMinBlocksPerCore * shmem_used > kShmemPerCore) { // Erase this element and everything after it from params_012. @@ -4231,15 +4136,10 @@ StatusOr IrEmitterUnnested::CheckAndEmitHloWithTile021( return false; } - std::vector ir_arrays; - TF_ASSIGN_OR_RETURN(std::unique_ptr kernel_thunk, - BuildKernelThunkForMlir(input.op, input.thunk_info, - input.extra_slice, &ir_arrays)); - EmitHlo021Tile( - input.op, kernel_thunk.get(), context, - absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size()), - absl::MakeSpan(ir_arrays).subspan(context.operand_shapes.size()), - *reduced_dims_021, params_012); + VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString(); + std::unique_ptr kernel_thunk = + BuildKernelThunk(hlo, /*implements_whole_instruction=*/true); + EmitHlo021Tile(hlo, kernel_thunk.get(), *reduced_dims_021, params_012); AddThunkToThunkSequence(std::move(kernel_thunk)); return true; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index a5089111150..bb875e90450 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -73,7 +73,6 @@ struct MlirEmitterInput { }; // Convenience struct that contains useful data structures in MLIR emitter. -// Not all fields may be filled. It's entiredly dependent on the uses. struct MlirEmitterContext { void SetOperation(mlir::Operation* op); @@ -157,14 +156,11 @@ class IrEmitterUnnested : public IrEmitter, } Status DefaultAction(HloInstruction* hlo) override; - Status DefaultActionForMlir(MlirEmitterInput input); // IrEmitterUnnested handles the following instructions differently from // IrEmitter. It also mixes in some special handling for custom kernels // via the ThunkEmitter. Status HandleCopy(HloInstruction* copy) override; - Status EmitCopyForMlir(MlirEmitterInput input); - Status HandleConditional(HloInstruction* conditional) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleCustomCall(HloInstruction* custom_call) override; @@ -471,15 +467,12 @@ class IrEmitterUnnested : public IrEmitter, // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel // for the hlo instruction. - StatusOr CheckAndEmitHloWithTile021(MlirEmitterInput input); + bool CheckAndEmitHloWithTile021(HloInstruction* hlo); // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and // sets the corresponding launch dimensions. This is a helper to support // the implementation of CheckAndEmitHloWithTile021. - void EmitHlo021Tile(mlir::Operation* op, Thunk* kernel_thunk, - const MlirEmitterContext& context, - absl::Span operand_arrays, - absl::Span output_arrays, + void EmitHlo021Tile(HloInstruction* hlo, Thunk* kernel_thunk, absl::Span reduced_output_dims, absl::Span tiled_param_ids); @@ -534,8 +527,7 @@ class IrEmitterUnnested : public IrEmitter, // y_loc: The y coordinate within a tile. // x_loc: The x coordinate within a tile. void EmitTileElementForCopy( - const Shape& output_shape, const llvm_ir::IrArray& ir_array, - const llvm_ir::IrArray::Index& index, + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc, llvm::Value* x_loc, absl::Span param_shmem_buffers); @@ -544,10 +536,7 @@ class IrEmitterUnnested : public IrEmitter, // y_loc: The y coordinate within a tile. // x_loc: The x coordinate within a tile. void EmitTileElementForFusion( - mlir::lmhlo::FusionOp fusion, - absl::Span operand_arrays, - absl::Span output_arrays, - const llvm_ir::IrArray::Index& index, + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc, llvm::Value* x_loc, absl::Span param_shmem_buffers); diff --git a/tensorflow/compiler/xla/service/gpu/tests/copy-nested.hlo b/tensorflow/compiler/xla/service/gpu/tests/copy-nested.hlo index 483175c3a80..4365d4a800a 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/copy-nested.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/copy-nested.hlo @@ -2,12 +2,12 @@ // CHECK-LABEL: entry: // CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0 -// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [100 x [200 x [300 x float]]]* +// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [200 x [100 x [300 x float]]]* // CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0 -// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [200 x [100 x [300 x float]]]* +// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [100 x [200 x [300 x float]]]* // CHECK: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 // CHECK: %[[VAL_7:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 -// CHECK: %[[VAL_8:.*]] = mul nuw nsw i32 %[[VAL_6]], 256 +// CHECK: %[[VAL_8:.*]] = mul nuw nsw i32 %[[VAL_6]], 128 // CHECK: %[[VAL_9:.*]] = add nuw nsw i32 %[[VAL_8]], %[[VAL_7]] // CHECK: %[[VAL_10:.*]] = icmp ult i32 %[[VAL_9]], {{.*}} // CHECK: call void @llvm.assume(i1 %[[VAL_10]]) @@ -40,24 +40,24 @@ // CHECK: b.in_bounds-after: ; preds = %[[VAL_36]], %[[VAL_38:.*]] // CHECK: ret void // CHECK: b.in_bounds-true: ; preds = %[[VAL_38]] -// CHECK: %[[VAL_39:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_15]], i32 %[[VAL_16]], i32 %[[VAL_13]] +// CHECK: %[[VAL_39:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_15]], i32 %[[VAL_16]], i32 %[[VAL_13]] // CHECK: %[[VAL_40:.*]] = load float, float* %[[VAL_39]], align 4, !invariant.load !4 -// CHECK: %[[VAL_41:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float* +// CHECK: %[[VAL_41:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float* // CHECK: %[[VAL_42:.*]] = getelementptr inbounds float, float* %[[VAL_41]], i32 %[[VAL_11]] // CHECK: store float %[[VAL_40]], float* %[[VAL_42]], align 4 -// CHECK: %[[VAL_43:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_21]], i32 %[[VAL_22]], i32 %[[VAL_19]] +// CHECK: %[[VAL_43:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_21]], i32 %[[VAL_22]], i32 %[[VAL_19]] // CHECK: %[[VAL_44:.*]] = load float, float* %[[VAL_43]], align 4, !invariant.load !4 -// CHECK: %[[VAL_45:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float* +// CHECK: %[[VAL_45:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float* // CHECK: %[[VAL_46:.*]] = getelementptr inbounds float, float* %[[VAL_45]], i32 %[[VAL_17]] // CHECK: store float %[[VAL_44]], float* %[[VAL_46]], align 4 -// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_25]] +// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_25]] // CHECK: %[[VAL_48:.*]] = load float, float* %[[VAL_47]], align 4, !invariant.load !4 -// CHECK: %[[VAL_49:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float* +// CHECK: %[[VAL_49:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float* // CHECK: %[[VAL_50:.*]] = getelementptr inbounds float, float* %[[VAL_49]], i32 %[[VAL_23]] // CHECK: store float %[[VAL_48]], float* %[[VAL_50]], align 4 -// CHECK: %[[VAL_51:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_33]], i32 %[[VAL_34]], i32 %[[VAL_31]] +// CHECK: %[[VAL_51:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_33]], i32 %[[VAL_34]], i32 %[[VAL_31]] // CHECK: %[[VAL_52:.*]] = load float, float* %[[VAL_51]], align 4, !invariant.load !4 -// CHECK: %[[VAL_53:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float* +// CHECK: %[[VAL_53:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float* // CHECK: %[[VAL_54:.*]] = getelementptr inbounds float, float* %[[VAL_53]], i32 %[[VAL_29]] // CHECK: store float %[[VAL_52]], float* %[[VAL_54]], align 4 // CHECK: br label %[[VAL_37]] diff --git a/tensorflow/compiler/xla/service/gpu/tests/copy.hlo b/tensorflow/compiler/xla/service/gpu/tests/copy.hlo index 128378a9130..594b9a08d43 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/copy.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/copy.hlo @@ -4,10 +4,10 @@ // CHECK: %[[VAL_0:.*]] = alloca i32, align 4 // CHECK: %[[VAL_1:.*]] = alloca i32, align 4 // CHECK: %[[VAL_2:.*]] = getelementptr inbounds i8, i8* %[[VAL_3:.*]], i64 0 -// CHECK: %[[VAL_4:.*]] = bitcast i8* %[[VAL_2]] to [100 x [200 x float]]* +// CHECK: %[[VAL_4:.*]] = bitcast i8* %[[VAL_2]] to [200 x [100 x float]]* // CHECK: %[[VAL_5:.*]] = getelementptr inbounds i8, i8* %[[VAL_6:.*]], i64 0 -// CHECK: %[[VAL_7:.*]] = bitcast i8* %[[VAL_5]] to [200 x [100 x float]]* -// CHECK: %[[VAL_8:.*]] = bitcast [100 x [200 x float]]* %[[VAL_4]] to [1 x [100 x [200 x float]]]* +// CHECK: %[[VAL_7:.*]] = bitcast i8* %[[VAL_5]] to [100 x [200 x float]]* +// CHECK: %[[VAL_8:.*]] = bitcast [100 x [200 x float]]* %[[VAL_7]] to [1 x [100 x [200 x float]]]* // CHECK: %[[VAL_9:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 // CHECK: %[[VAL_10:.*]] = urem i32 %[[VAL_9]], 32 // CHECK: %[[VAL_11:.*]] = udiv i32 %[[VAL_9]], 32 @@ -88,7 +88,7 @@ // CHECK: output_x_in_tile-true: ; preds = %[[VAL_59]] // CHECK: %[[VAL_72:.*]] = getelementptr [32 x [33 x float]], [32 x [33 x float]] addrspace(3)* @b.tile0, i64 0, i32 %[[VAL_65]], i32 %[[VAL_63]] // CHECK: %[[VAL_73:.*]] = load float, float addrspace(3)* %[[VAL_72]], align 4 -// CHECK: %[[VAL_74:.*]] = bitcast [200 x [100 x float]]* %[[VAL_7]] to [1 x [200 x [100 x float]]]* +// CHECK: %[[VAL_74:.*]] = bitcast [200 x [100 x float]]* %[[VAL_4]] to [1 x [200 x [100 x float]]]* // CHECK: %[[VAL_75:.*]] = getelementptr inbounds [1 x [200 x [100 x float]]], [1 x [200 x [100 x float]]]* %[[VAL_74]], i32 0, i32 0, i32 %[[VAL_64]], i32 %[[VAL_66]] // CHECK: store float %[[VAL_73]], float* %[[VAL_75]], align 4 // CHECK: br label %[[VAL_55]]