diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 17fd1820acd..495a801c182 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -562,176 +562,6 @@ std::pair MultiplyComplex(llvm::Value* lhs_value, } } // namespace -Status IrEmitter::HandleDot(HloInstruction* dot) { - auto lhs_instruction = dot->operand(0); - auto rhs_instruction = dot->operand(1); - const llvm_ir::IrArray& target_array = GetIrArray(*dot, *dot); - const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction, *dot); - const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction, *dot); - - const Shape& lhs_shape = lhs_instruction->shape(); - const Shape& rhs_shape = rhs_instruction->shape(); - const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); - CHECK_EQ(dnums.lhs_batch_dimensions_size(), - dnums.rhs_batch_dimensions_size()); - - // TODO(b/110211620): Convert to use i32 index_type when it is possible. - llvm::Type* index_type = b_.getInt64Ty(); - llvm_ir::IrArray::Index element_index(index_type); - if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) { - // If the operands are scalar, don't emit any loops. - llvm::Value* lhs_value = - lhs_array.EmitReadArrayElement(/*index=*/element_index, &b_); - llvm::Value* rhs_value = - rhs_array.EmitReadArrayElement(/*index=*/element_index, &b_); - llvm::Value* result; - if (ShapeUtil::ElementIsComplex(lhs_shape)) { - auto value = MultiplyComplex(lhs_value, rhs_value, &b_); - result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); - result = InsertValue(result, value.first, {0}); - result = InsertValue(result, value.second, {1}); - } else if (ShapeUtil::ElementIsFloating(lhs_shape)) { - result = FMul(lhs_value, rhs_value); - } else { - TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape)); - result = Mul(lhs_value, rhs_value); - } - target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_); - return Status::OK(); - } - - // "Scalar dot non-scalar" or "non-scalar dot scalar" is invalid. See - // the semantics of Dot in the XLA documentation for details. - TF_RET_CHECK(!ShapeUtil::IsScalar(lhs_shape) && - !ShapeUtil::IsScalar(rhs_shape)); - - const int64 lhs_reduction_dimension = dnums.lhs_contracting_dimensions(0); - const int64 rhs_reduction_dimension = dnums.rhs_contracting_dimensions(0); - - // Check that the batch dims don't cover the reduction dimensions. - for (int64 batch_dim : dnums.lhs_batch_dimensions()) { - CHECK_NE(lhs_reduction_dimension, batch_dim); - CHECK_NE(rhs_reduction_dimension, batch_dim); - } - - // Verify the reduction dimension in the two operands are the same size. - TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == - rhs_shape.dimensions(rhs_reduction_dimension)) - << "lhs_shape.dimensions(" << lhs_reduction_dimension - << ") = " << lhs_shape.dimensions(lhs_reduction_dimension) - << ", and rhs_shape.dimensions(" << rhs_reduction_dimension - << ") = " << rhs_shape.dimensions(rhs_reduction_dimension); - - // Create loop nests which loop through the LHS operand dimensions and the RHS - // operand dimensions. The reduction dimension of the LHS and RHS are handled - // in a separate innermost loop which performs the sum of products. - llvm_ir::ForLoopNest loop_nest(IrName(dot), &b_); - std::vector lhs_multi_index = - loop_nest.EmitOperandArrayLoopNest( - lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); - std::vector rhs_multi_index = - loop_nest.EmitOperandArrayLoopNest( - rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); - - // We don't have to iterate over the batch dimensions in both arrays, simplify - // the loop nest of the rhs. - for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) { - DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i)); - rhs_multi_index[i] = lhs_multi_index[i]; - } - - // Create the reduction loop which does the sum of products reduction. - std::unique_ptr reduction_loop = loop_nest.AddLoop( - /*start_index=*/0, - /*end_index=*/lhs_shape.dimensions(lhs_reduction_dimension), - /*suffix=*/"reduction"); - - // The final entry in the rhs and lhs indexes is the indvar of the reduction - // loop. - lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue(); - rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue(); - - // For computing the sum of products we alloca a single location to store the - // dot product result as we accumulate it within the reduction loop. After the - // reduction loop we load the result and store into the output array. - llvm::Type* accum_type = target_array.GetElementLlvmType(); - llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry( - accum_type, // The pointee type of the alloca instruction. - "accum_address", // The name of the alloca instruction. - &b_); - - // Initialize the accumulator in the preheader to zero. - new llvm::StoreInst( - llvm::Constant::getNullValue(lhs_array.GetElementLlvmType()), // init 0 - accum_address, // The address. - reduction_loop->GetPreheaderBasicBlock() - ->getTerminator()); // The instruction this store is inserted before. - - // Emit the body of the reduction loop: - // accum = *accum_address - // updated_accum = accum + lhs_element * rhs_element - // *accum_address = updated_accum - TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty()); - b_.SetInsertPoint( - &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt()); - llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_array.GetShape(), - b_.getInt64Ty()); - llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_); - llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_array.GetShape(), - b_.getInt64Ty()); - llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_); - llvm::Value* accum = Load(accum_address); - llvm::Value* updated_accum; - if (ShapeUtil::ElementIsComplex(lhs_shape)) { - auto value = MultiplyComplex(lhs_element, rhs_element, &b_); - llvm::Value* accum_real = Real(accum, &b_); - llvm::Value* real_sum = FAdd(accum_real, value.first); - updated_accum = InsertValue(accum, real_sum, {0}); - llvm::Value* accum_imag = Imag(accum, &b_); - llvm::Value* imag_sum = FAdd(accum_imag, value.second); - updated_accum = InsertValue(updated_accum, imag_sum, {1}); - } else if (ShapeUtil::ElementIsFloating(lhs_shape)) { - llvm::Value* product = FMul(lhs_element, rhs_element); - updated_accum = FAdd(accum, product); - } else { - TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape)); - llvm::Value* product = Mul(lhs_element, rhs_element); - updated_accum = Add(accum, product); - } - Store(updated_accum, accum_address); - - // After the reduction loop exits, store the accumulator into the target - // address. The index into the target address is the concatenation of the rhs - // and lhs indexes with the reduction dimensions removed. The terms from the - // rhs index are the lower dimensions in the index so we add them first. - std::vector target_multi_index; - for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) { - if (dimension != lhs_reduction_dimension) { - target_multi_index.push_back(lhs_index[dimension]); - } - } - // Skip over the batch dimensions to not have them in the index twice. - for (size_t dimension = dnums.lhs_batch_dimensions_size(); - dimension < rhs_index.size(); ++dimension) { - if (dimension != rhs_reduction_dimension) { - target_multi_index.push_back(rhs_index[dimension]); - } - } - SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_); - llvm_ir::IrArray::Index target_index(target_multi_index, - target_array.GetShape(), index_type); - target_array.EmitWriteArrayElement( - target_index, - Load(accum_address), // The value written to the target array. - &b_); - - // Set the IR builder insert point to the exit basic block of the outer most - // loop. This ensures later instructions are inserted after this loop nest. - b_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - - return Status::OK(); -} - Status IrEmitter::HandleConvolution(HloInstruction* convolution) { if (ShapeUtil::IsZeroElementArray(convolution->shape())) { // Emit no code for an empty output. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 894f1401e0d..df30395b7f1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -79,7 +79,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleConstant(HloInstruction* constant) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; - Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; Status HandleAllReduce(HloInstruction* crs) override; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index b22d6a0c810..68e29bf68c6 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -580,12 +580,6 @@ Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) { return ret; } -Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { - AddThunkToThunkSequence( - BuildKernelThunk(dot, /*implements_whole_instruction=*/true)); - return IrEmitter::HandleDot(dot); -} - Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { TF_ASSIGN_OR_RETURN(auto thunk, BuildConditionalThunk(conditional)); AddThunkToThunkSequence(std::move(thunk)); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index a5089111150..197a25e2cc1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -168,7 +168,6 @@ class IrEmitterUnnested : public IrEmitter, Status HandleConditional(HloInstruction* conditional) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleCustomCall(HloInstruction* custom_call) override; - Status HandleDot(HloInstruction* dot) override; Status HandleFft(HloInstruction* fft) override; Status HandleFusion(HloInstruction* fusion) override; Status EmitLoopFusionFromMlir(MlirEmitterInput input,