[XLA/GPU] Remove non-cublas gemm implementation. It's already covered by ElementalIrEmitter.
The overall logic are the same, and benchmarks don't seem to regress. PiperOrigin-RevId: 343013260 Change-Id: I1dcde18399a06a7541842640f9802f96d3c85746
This commit is contained in:
parent
a6b779c17f
commit
8d06472606
@ -562,176 +562,6 @@ std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(llvm::Value* lhs_value,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status IrEmitter::HandleDot(HloInstruction* dot) {
|
||||
auto lhs_instruction = dot->operand(0);
|
||||
auto rhs_instruction = dot->operand(1);
|
||||
const llvm_ir::IrArray& target_array = GetIrArray(*dot, *dot);
|
||||
const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction, *dot);
|
||||
const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction, *dot);
|
||||
|
||||
const Shape& lhs_shape = lhs_instruction->shape();
|
||||
const Shape& rhs_shape = rhs_instruction->shape();
|
||||
const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
|
||||
CHECK_EQ(dnums.lhs_batch_dimensions_size(),
|
||||
dnums.rhs_batch_dimensions_size());
|
||||
|
||||
// TODO(b/110211620): Convert to use i32 index_type when it is possible.
|
||||
llvm::Type* index_type = b_.getInt64Ty();
|
||||
llvm_ir::IrArray::Index element_index(index_type);
|
||||
if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) {
|
||||
// If the operands are scalar, don't emit any loops.
|
||||
llvm::Value* lhs_value =
|
||||
lhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
|
||||
llvm::Value* rhs_value =
|
||||
rhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
|
||||
llvm::Value* result;
|
||||
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
|
||||
auto value = MultiplyComplex(lhs_value, rhs_value, &b_);
|
||||
result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType());
|
||||
result = InsertValue(result, value.first, {0});
|
||||
result = InsertValue(result, value.second, {1});
|
||||
} else if (ShapeUtil::ElementIsFloating(lhs_shape)) {
|
||||
result = FMul(lhs_value, rhs_value);
|
||||
} else {
|
||||
TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape));
|
||||
result = Mul(lhs_value, rhs_value);
|
||||
}
|
||||
target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// "Scalar dot non-scalar" or "non-scalar dot scalar" is invalid. See
|
||||
// the semantics of Dot in the XLA documentation for details.
|
||||
TF_RET_CHECK(!ShapeUtil::IsScalar(lhs_shape) &&
|
||||
!ShapeUtil::IsScalar(rhs_shape));
|
||||
|
||||
const int64 lhs_reduction_dimension = dnums.lhs_contracting_dimensions(0);
|
||||
const int64 rhs_reduction_dimension = dnums.rhs_contracting_dimensions(0);
|
||||
|
||||
// Check that the batch dims don't cover the reduction dimensions.
|
||||
for (int64 batch_dim : dnums.lhs_batch_dimensions()) {
|
||||
CHECK_NE(lhs_reduction_dimension, batch_dim);
|
||||
CHECK_NE(rhs_reduction_dimension, batch_dim);
|
||||
}
|
||||
|
||||
// Verify the reduction dimension in the two operands are the same size.
|
||||
TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
|
||||
rhs_shape.dimensions(rhs_reduction_dimension))
|
||||
<< "lhs_shape.dimensions(" << lhs_reduction_dimension
|
||||
<< ") = " << lhs_shape.dimensions(lhs_reduction_dimension)
|
||||
<< ", and rhs_shape.dimensions(" << rhs_reduction_dimension
|
||||
<< ") = " << rhs_shape.dimensions(rhs_reduction_dimension);
|
||||
|
||||
// Create loop nests which loop through the LHS operand dimensions and the RHS
|
||||
// operand dimensions. The reduction dimension of the LHS and RHS are handled
|
||||
// in a separate innermost loop which performs the sum of products.
|
||||
llvm_ir::ForLoopNest loop_nest(IrName(dot), &b_);
|
||||
std::vector<llvm::Value*> lhs_multi_index =
|
||||
loop_nest.EmitOperandArrayLoopNest(
|
||||
lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
|
||||
std::vector<llvm::Value*> rhs_multi_index =
|
||||
loop_nest.EmitOperandArrayLoopNest(
|
||||
rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
|
||||
|
||||
// We don't have to iterate over the batch dimensions in both arrays, simplify
|
||||
// the loop nest of the rhs.
|
||||
for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) {
|
||||
DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i));
|
||||
rhs_multi_index[i] = lhs_multi_index[i];
|
||||
}
|
||||
|
||||
// Create the reduction loop which does the sum of products reduction.
|
||||
std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
|
||||
/*start_index=*/0,
|
||||
/*end_index=*/lhs_shape.dimensions(lhs_reduction_dimension),
|
||||
/*suffix=*/"reduction");
|
||||
|
||||
// The final entry in the rhs and lhs indexes is the indvar of the reduction
|
||||
// loop.
|
||||
lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
|
||||
rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
|
||||
|
||||
// For computing the sum of products we alloca a single location to store the
|
||||
// dot product result as we accumulate it within the reduction loop. After the
|
||||
// reduction loop we load the result and store into the output array.
|
||||
llvm::Type* accum_type = target_array.GetElementLlvmType();
|
||||
llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
accum_type, // The pointee type of the alloca instruction.
|
||||
"accum_address", // The name of the alloca instruction.
|
||||
&b_);
|
||||
|
||||
// Initialize the accumulator in the preheader to zero.
|
||||
new llvm::StoreInst(
|
||||
llvm::Constant::getNullValue(lhs_array.GetElementLlvmType()), // init 0
|
||||
accum_address, // The address.
|
||||
reduction_loop->GetPreheaderBasicBlock()
|
||||
->getTerminator()); // The instruction this store is inserted before.
|
||||
|
||||
// Emit the body of the reduction loop:
|
||||
// accum = *accum_address
|
||||
// updated_accum = accum + lhs_element * rhs_element
|
||||
// *accum_address = updated_accum
|
||||
TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty());
|
||||
b_.SetInsertPoint(
|
||||
&*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt());
|
||||
llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_array.GetShape(),
|
||||
b_.getInt64Ty());
|
||||
llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_);
|
||||
llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_array.GetShape(),
|
||||
b_.getInt64Ty());
|
||||
llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_);
|
||||
llvm::Value* accum = Load(accum_address);
|
||||
llvm::Value* updated_accum;
|
||||
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
|
||||
auto value = MultiplyComplex(lhs_element, rhs_element, &b_);
|
||||
llvm::Value* accum_real = Real(accum, &b_);
|
||||
llvm::Value* real_sum = FAdd(accum_real, value.first);
|
||||
updated_accum = InsertValue(accum, real_sum, {0});
|
||||
llvm::Value* accum_imag = Imag(accum, &b_);
|
||||
llvm::Value* imag_sum = FAdd(accum_imag, value.second);
|
||||
updated_accum = InsertValue(updated_accum, imag_sum, {1});
|
||||
} else if (ShapeUtil::ElementIsFloating(lhs_shape)) {
|
||||
llvm::Value* product = FMul(lhs_element, rhs_element);
|
||||
updated_accum = FAdd(accum, product);
|
||||
} else {
|
||||
TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape));
|
||||
llvm::Value* product = Mul(lhs_element, rhs_element);
|
||||
updated_accum = Add(accum, product);
|
||||
}
|
||||
Store(updated_accum, accum_address);
|
||||
|
||||
// After the reduction loop exits, store the accumulator into the target
|
||||
// address. The index into the target address is the concatenation of the rhs
|
||||
// and lhs indexes with the reduction dimensions removed. The terms from the
|
||||
// rhs index are the lower dimensions in the index so we add them first.
|
||||
std::vector<llvm::Value*> target_multi_index;
|
||||
for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) {
|
||||
if (dimension != lhs_reduction_dimension) {
|
||||
target_multi_index.push_back(lhs_index[dimension]);
|
||||
}
|
||||
}
|
||||
// Skip over the batch dimensions to not have them in the index twice.
|
||||
for (size_t dimension = dnums.lhs_batch_dimensions_size();
|
||||
dimension < rhs_index.size(); ++dimension) {
|
||||
if (dimension != rhs_reduction_dimension) {
|
||||
target_multi_index.push_back(rhs_index[dimension]);
|
||||
}
|
||||
}
|
||||
SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_);
|
||||
llvm_ir::IrArray::Index target_index(target_multi_index,
|
||||
target_array.GetShape(), index_type);
|
||||
target_array.EmitWriteArrayElement(
|
||||
target_index,
|
||||
Load(accum_address), // The value written to the target array.
|
||||
&b_);
|
||||
|
||||
// Set the IR builder insert point to the exit basic block of the outer most
|
||||
// loop. This ensures later instructions are inserted after this loop nest.
|
||||
b_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
|
||||
if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
|
||||
// Emit no code for an empty output.
|
||||
|
||||
@ -79,7 +79,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
Status HandleConstant(HloInstruction* constant) override;
|
||||
Status HandleBitcast(HloInstruction* bitcast) override;
|
||||
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
|
||||
Status HandleDot(HloInstruction* dot) override;
|
||||
Status HandleConvolution(HloInstruction* convolution) override;
|
||||
Status HandleFft(HloInstruction* fft) override;
|
||||
Status HandleAllReduce(HloInstruction* crs) override;
|
||||
|
||||
@ -580,12 +580,6 @@ Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
|
||||
AddThunkToThunkSequence(
|
||||
BuildKernelThunk(dot, /*implements_whole_instruction=*/true));
|
||||
return IrEmitter::HandleDot(dot);
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
|
||||
TF_ASSIGN_OR_RETURN(auto thunk, BuildConditionalThunk(conditional));
|
||||
AddThunkToThunkSequence(std::move(thunk));
|
||||
|
||||
@ -168,7 +168,6 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
Status HandleConditional(HloInstruction* conditional) override;
|
||||
Status HandleConvolution(HloInstruction* convolution) override;
|
||||
Status HandleCustomCall(HloInstruction* custom_call) override;
|
||||
Status HandleDot(HloInstruction* dot) override;
|
||||
Status HandleFft(HloInstruction* fft) override;
|
||||
Status HandleFusion(HloInstruction* fusion) override;
|
||||
Status EmitLoopFusionFromMlir(MlirEmitterInput input,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user