[XLA/GPU] Remove non-cublas gemm implementation. It's already covered by ElementalIrEmitter.

The overall logic is the same, and benchmarks don't seem to regress.

PiperOrigin-RevId: 343013260
Change-Id: I1dcde18399a06a7541842640f9802f96d3c85746
Tim Shen 2020-11-17 22:09:19 -08:00 committed by TensorFlower Gardener
parent a6b779c17f
commit 8d06472606
4 changed files with 0 additions and 178 deletions
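
For context, the removed IrEmitter::HandleDot in the diff below lowers a dot to an LLVM IR loop nest: loops over the batch and free dimensions, plus an innermost reduction loop that accumulates a sum of products in a stack slot before storing it into the output array. The C++ sketch below shows the equivalent scalar computation; per the commit message, ElementalIrEmitter already emits the same logic for dots that are not handled by cuBLAS. This is a minimal illustration only, assuming row-major f32 buffers and a single contracting dimension; the function and parameter names are hypothetical, not XLA APIs.

// Minimal sketch (not part of this change) of the computation the removed
// HandleDot emits as an LLVM IR loop nest. Assumes row-major f32 buffers and
// a single contracting dimension; all names here are illustrative.
#include <cstddef>
#include <vector>

// out[b, m, n] = sum over k of lhs[b, m, k] * rhs[b, k, n]
void NaiveBatchedDot(const std::vector<float>& lhs,  // shape [B, M, K]
                     const std::vector<float>& rhs,  // shape [B, K, N]
                     std::vector<float>& out,        // shape [B, M, N]
                     std::size_t B, std::size_t M, std::size_t K,
                     std::size_t N) {
  for (std::size_t b = 0; b < B; ++b) {        // shared batch dimension
    for (std::size_t m = 0; m < M; ++m) {      // lhs free dimension
      for (std::size_t n = 0; n < N; ++n) {    // rhs free dimension
        float accum = 0.0f;                    // mirrors "accum_address"
        for (std::size_t k = 0; k < K; ++k) {  // the reduction loop
          accum += lhs[(b * M + m) * K + k] * rhs[(b * K + k) * N + n];
        }
        out[(b * M + m) * N + n] = accum;      // store after the reduction loop
      }
    }
  }
}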


@@ -562,176 +562,6 @@ std::pair<llvm::Value*, llvm::Value*> MultiplyComplex(llvm::Value* lhs_value,
}
}  // namespace
Status IrEmitter::HandleDot(HloInstruction* dot) {
  auto lhs_instruction = dot->operand(0);
  auto rhs_instruction = dot->operand(1);
  const llvm_ir::IrArray& target_array = GetIrArray(*dot, *dot);
  const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction, *dot);
  const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction, *dot);
  const Shape& lhs_shape = lhs_instruction->shape();
  const Shape& rhs_shape = rhs_instruction->shape();
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  CHECK_EQ(dnums.lhs_batch_dimensions_size(),
           dnums.rhs_batch_dimensions_size());
  // TODO(b/110211620): Convert to use i32 index_type when it is possible.
  llvm::Type* index_type = b_.getInt64Ty();
  llvm_ir::IrArray::Index element_index(index_type);
  if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) {
    // If the operands are scalar, don't emit any loops.
    llvm::Value* lhs_value =
        lhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
    llvm::Value* rhs_value =
        rhs_array.EmitReadArrayElement(/*index=*/element_index, &b_);
    llvm::Value* result;
    if (ShapeUtil::ElementIsComplex(lhs_shape)) {
      auto value = MultiplyComplex(lhs_value, rhs_value, &b_);
      result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType());
      result = InsertValue(result, value.first, {0});
      result = InsertValue(result, value.second, {1});
    } else if (ShapeUtil::ElementIsFloating(lhs_shape)) {
      result = FMul(lhs_value, rhs_value);
    } else {
      TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape));
      result = Mul(lhs_value, rhs_value);
    }
    target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_);
    return Status::OK();
  }
  // "Scalar dot non-scalar" or "non-scalar dot scalar" is invalid. See
  // the semantics of Dot in the XLA documentation for details.
  TF_RET_CHECK(!ShapeUtil::IsScalar(lhs_shape) &&
               !ShapeUtil::IsScalar(rhs_shape));
  const int64 lhs_reduction_dimension = dnums.lhs_contracting_dimensions(0);
  const int64 rhs_reduction_dimension = dnums.rhs_contracting_dimensions(0);
  // Check that the batch dims don't cover the reduction dimensions.
  for (int64 batch_dim : dnums.lhs_batch_dimensions()) {
    CHECK_NE(lhs_reduction_dimension, batch_dim);
    CHECK_NE(rhs_reduction_dimension, batch_dim);
  }
  // Verify the reduction dimension in the two operands are the same size.
  TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
               rhs_shape.dimensions(rhs_reduction_dimension))
      << "lhs_shape.dimensions(" << lhs_reduction_dimension
      << ") = " << lhs_shape.dimensions(lhs_reduction_dimension)
      << ", and rhs_shape.dimensions(" << rhs_reduction_dimension
      << ") = " << rhs_shape.dimensions(rhs_reduction_dimension);
  // Create loop nests which loop through the LHS operand dimensions and the RHS
  // operand dimensions. The reduction dimension of the LHS and RHS are handled
  // in a separate innermost loop which performs the sum of products.
  llvm_ir::ForLoopNest loop_nest(IrName(dot), &b_);
  std::vector<llvm::Value*> lhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs");
  std::vector<llvm::Value*> rhs_multi_index =
      loop_nest.EmitOperandArrayLoopNest(
          rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
  // We don't have to iterate over the batch dimensions in both arrays, simplify
  // the loop nest of the rhs.
  for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) {
    DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i));
    rhs_multi_index[i] = lhs_multi_index[i];
  }
  // Create the reduction loop which does the sum of products reduction.
  std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
      /*start_index=*/0,
      /*end_index=*/lhs_shape.dimensions(lhs_reduction_dimension),
      /*suffix=*/"reduction");
  // The final entry in the rhs and lhs indexes is the indvar of the reduction
  // loop.
  lhs_multi_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  rhs_multi_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
  // For computing the sum of products we alloca a single location to store the
  // dot product result as we accumulate it within the reduction loop. After the
  // reduction loop we load the result and store into the output array.
  llvm::Type* accum_type = target_array.GetElementLlvmType();
  llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry(
      accum_type,       // The pointee type of the alloca instruction.
      "accum_address",  // The name of the alloca instruction.
      &b_);
  // Initialize the accumulator in the preheader to zero.
  new llvm::StoreInst(
      llvm::Constant::getNullValue(lhs_array.GetElementLlvmType()),  // init 0
      accum_address,  // The address.
      reduction_loop->GetPreheaderBasicBlock()
          ->getTerminator());  // The instruction this store is inserted before.
  // Emit the body of the reduction loop:
  //   accum = *accum_address
  //   updated_accum = accum + lhs_element * rhs_element
  //   *accum_address = updated_accum
  TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty());
  b_.SetInsertPoint(
      &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt());
  llvm_ir::IrArray::Index lhs_index(lhs_multi_index, lhs_array.GetShape(),
                                    b_.getInt64Ty());
  llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_);
  llvm_ir::IrArray::Index rhs_index(rhs_multi_index, rhs_array.GetShape(),
                                    b_.getInt64Ty());
  llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_);
  llvm::Value* accum = Load(accum_address);
  llvm::Value* updated_accum;
  if (ShapeUtil::ElementIsComplex(lhs_shape)) {
    auto value = MultiplyComplex(lhs_element, rhs_element, &b_);
    llvm::Value* accum_real = Real(accum, &b_);
    llvm::Value* real_sum = FAdd(accum_real, value.first);
    updated_accum = InsertValue(accum, real_sum, {0});
    llvm::Value* accum_imag = Imag(accum, &b_);
    llvm::Value* imag_sum = FAdd(accum_imag, value.second);
    updated_accum = InsertValue(updated_accum, imag_sum, {1});
  } else if (ShapeUtil::ElementIsFloating(lhs_shape)) {
    llvm::Value* product = FMul(lhs_element, rhs_element);
    updated_accum = FAdd(accum, product);
  } else {
    TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape));
    llvm::Value* product = Mul(lhs_element, rhs_element);
    updated_accum = Add(accum, product);
  }
  Store(updated_accum, accum_address);
  // After the reduction loop exits, store the accumulator into the target
  // address. The index into the target address is the concatenation of the rhs
  // and lhs indexes with the reduction dimensions removed. The terms from the
  // rhs index are the lower dimensions in the index so we add them first.
  std::vector<llvm::Value*> target_multi_index;
  for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) {
    if (dimension != lhs_reduction_dimension) {
      target_multi_index.push_back(lhs_index[dimension]);
    }
  }
  // Skip over the batch dimensions to not have them in the index twice.
  for (size_t dimension = dnums.lhs_batch_dimensions_size();
       dimension < rhs_index.size(); ++dimension) {
    if (dimension != rhs_reduction_dimension) {
      target_multi_index.push_back(rhs_index[dimension]);
    }
  }
  SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_);
  llvm_ir::IrArray::Index target_index(target_multi_index,
                                       target_array.GetShape(), index_type);
  target_array.EmitWriteArrayElement(
      target_index,
      Load(accum_address),  // The value written to the target array.
      &b_);
  // Set the IR builder insert point to the exit basic block of the outer most
  // loop. This ensures later instructions are inserted after this loop nest.
  b_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
  return Status::OK();
}
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
  if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
    // Emit no code for an empty output.


@ -79,7 +79,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
Status HandleConstant(HloInstruction* constant) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleDot(HloInstruction* dot) override;
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleFft(HloInstruction* fft) override;
Status HandleAllReduce(HloInstruction* crs) override;


@@ -580,12 +580,6 @@ Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) {
  return ret;
}
Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
  AddThunkToThunkSequence(
      BuildKernelThunk(dot, /*implements_whole_instruction=*/true));
  return IrEmitter::HandleDot(dot);
}
Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
  TF_ASSIGN_OR_RETURN(auto thunk, BuildConditionalThunk(conditional));
  AddThunkToThunkSequence(std::move(thunk));


@@ -168,7 +168,6 @@ class IrEmitterUnnested : public IrEmitter,
  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status EmitLoopFusionFromMlir(MlirEmitterInput input,