[XLA:GPU] Eliminate support for Convolution/GEMM/Cholesky from ThunkEmitter
- Since these have been migrated to MLIR, we don't need to support them in the ThunkEmitter.

PiperOrigin-RevId: 348525890
Change-Id: If57be7969a29263c87d31bb7a2dd095cb39a382a
parent 4e7e6df7d7
commit 40f207fdb9
@@ -302,83 +302,6 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
     return Status::OK();
   }
 
-  if (IsCustomCallToDnnConvolution(*custom_call)) {
-    std::vector<BufferAllocation::Slice> operand_slices;
-    operand_slices.reserve(custom_call->operand_count());
-    for (const auto* operand : custom_call->operands()) {
-      operand_slices.push_back(GetAllocationSlice(*operand));
-    }
-    auto conv_result_slice = GetAllocationSlice(*custom_call, {0});
-    auto scratch_slice = GetAllocationSlice(*custom_call, {1});
-
-    // Assert that the tuple slice is not used by anyone directly. That is, all
-    // users of the tuple output are get-tuple-element. Also assert that the
-    // second element of the tuple (the scratch buffer) is not used by anyone.
-    for (const HloInstruction* user : custom_call->users()) {
-      TF_RET_CHECK(user->opcode() == HloOpcode::kGetTupleElement &&
-                   user->tuple_index() == 0);
-    }
-
-    TF_ASSIGN_OR_RETURN(
-        GpuConvConfig config,
-        GetGpuConvConfig(Cast<HloCustomCallInstruction>(custom_call)));
-    AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
-        context_->GetThunkInfo(custom_call), std::move(config),
-        std::move(operand_slices), conv_result_slice, scratch_slice));
-    return Status::OK();
-  }
-
-  if (IsCublasGemm(*custom_call)) {
-    AddThunkToThunkSequence(BuildGemmThunk(custom_call));
-    return Status::OK();
-  }
-
-#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
-  if (custom_call->custom_call_target() == kCusolverCholeskyCallTarget) {
-    TF_ASSIGN_OR_RETURN(CholeskyOptions options,
-                        custom_call->backend_config<CholeskyOptions>());
-
-    const Shape& shape = custom_call->operand(0)->shape();
-    int ndim = shape.dimensions_size();
-    CHECK_GE(ndim, 2);
-    int64 n = shape.dimensions(ndim - 1);
-
-    const auto& dims = shape.dimensions();
-    int64 batch_size = std::accumulate(dims.begin(), dims.end() - 2, int64{1},
-                                       [](int64 a, int64 b) { return a * b; });
-
-    auto operand_buffer = GetAllocationSlice(*custom_call->operand(0));
-
-    auto a_buffer = GetAllocationSlice(*custom_call, {0});
-    auto workspace_buffer = GetAllocationSlice(*custom_call, {1});
-    auto info_buffer = GetAllocationSlice(*custom_call, {2});
-
-    std::vector<std::unique_ptr<Thunk>> thunks;
-
-    if (operand_buffer != a_buffer) {
-      thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
-          context_->GetThunkInfo(custom_call),
-          /*source_address=*/operand_buffer,
-          /*destination_buffer=*/a_buffer,
-          /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
-    }
-
-    thunks.push_back(absl::make_unique<CholeskyThunk>(
-        context_->GetThunkInfo(custom_call), options, a_buffer,
-        workspace_buffer, info_buffer,
-        custom_call->operand(0)->shape().element_type(), batch_size, n));
-
-    // Elide the sequential thunk if there's no copy.
-    if (thunks.size() == 1) {
-      AddThunkToThunkSequence(std::move(thunks[0]));
-    } else {
-      AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
-          context_->GetThunkInfo(custom_call), std::move(thunks)));
-    }
-
-    return Status::OK();
-  }
-#endif
 
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
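For context on the removed Cholesky path: it treats the last dimension of the operand shape as the size n of each square matrix and folds every leading dimension into the batch size before constructing the CholeskyThunk. Below is a minimal standalone sketch of that computation only, using hypothetical shape values and plain std:: types rather than XLA's Shape API; it is not part of this commit.

    // Sketch: mirrors the batch_size/n derivation from the removed code above.
    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      // Hypothetical operand dimensions, e.g. f32[3, 4, 8, 8].
      std::vector<int64_t> dims = {3, 4, 8, 8};
      assert(dims.size() >= 2);  // same precondition as CHECK_GE(ndim, 2)

      // The trailing dimension is the size of each square matrix.
      const int64_t n = dims.back();
      // Everything before the last two dimensions is folded into the batch.
      const int64_t batch_size =
          std::accumulate(dims.begin(), dims.end() - 2, int64_t{1},
                          [](int64_t a, int64_t b) { return a * b; });

      std::cout << "n = " << n << ", batch_size = " << batch_size << "\n";
      // Prints: n = 8, batch_size = 12
    }

With the hypothetical f32[3, 4, 8, 8] operand this yields n = 8 and batch_size = 12, i.e. twelve independent 8x8 factorizations handed to a single thunk.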