From 834cadde2ee5f7e935e05f8e836acde3337b3f98 Mon Sep 17 00:00:00 2001
From: Rahul Joshi
Date: Mon, 14 Dec 2020 09:10:50 -0800
Subject: [PATCH] [XLA:GPU] Eliminate support for Convolution/GEMM/Cholesky
 from ThunkEmitter

- Since these have been migrated to MLIR, we don't need to support them in
  the ThunkEmitter

PiperOrigin-RevId: 347397940
Change-Id: I31c088008b1f4fdad82ead96313fb92324c206fe
---
 .../compiler/xla/service/gpu/thunk_emitter.cc | 76 -------------------
 1 file changed, 76 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
index 058aad76777..215fdb56ce8 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
@@ -298,83 +298,7 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
     return Status::OK();
   }
 
-  if (IsCustomCallToDnnConvolution(*custom_call)) {
-    std::vector<BufferAllocation::Slice> operand_slices;
-    operand_slices.reserve(custom_call->operand_count());
-    for (const auto* operand : custom_call->operands()) {
-      operand_slices.push_back(GetAllocationSlice(*operand));
-    }
-    auto conv_result_slice = GetAllocationSlice(*custom_call, {0});
-    auto scratch_slice = GetAllocationSlice(*custom_call, {1});
-
-    // Assert that the tuple slice is not used by anyone directly. That is, all
-    // users of the tuple output are get-tuple-element. Also assert that the
-    // second element of the tuple (the scratch buffer) is not used by anyone.
-    for (const HloInstruction* user : custom_call->users()) {
-      TF_RET_CHECK(user->opcode() == HloOpcode::kGetTupleElement &&
-                   user->tuple_index() == 0);
-    }
-
-    TF_ASSIGN_OR_RETURN(
-        GpuConvConfig config,
-        GetGpuConvConfig(Cast<HloCustomCallInstruction>(custom_call)));
-    AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
-        context_->GetThunkInfo(custom_call), std::move(config),
-        std::move(operand_slices), conv_result_slice, scratch_slice));
-    return Status::OK();
-  }
-
-  if (IsCublasGemm(*custom_call)) {
-    AddThunkToThunkSequence(BuildGemmThunk(custom_call));
-    return Status::OK();
-  }
-
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
-  if (custom_call->custom_call_target() == kCusolverCholeskyCallTarget) {
-    TF_ASSIGN_OR_RETURN(CholeskyOptions options,
-                        custom_call->backend_config<CholeskyOptions>());
-
-    const Shape& shape = custom_call->operand(0)->shape();
-    int ndim = shape.dimensions_size();
-    CHECK_GE(ndim, 2);
-    int64 n = shape.dimensions(ndim - 1);
-
-    const auto& dims = shape.dimensions();
-    int64 batch_size = std::accumulate(dims.begin(), dims.end() - 2, int64{1},
-                                       [](int64 a, int64 b) { return a * b; });
-
-    auto operand_buffer = GetAllocationSlice(*custom_call->operand(0));
-
-    auto a_buffer = GetAllocationSlice(*custom_call, {0});
-    auto workspace_buffer = GetAllocationSlice(*custom_call, {1});
-    auto info_buffer = GetAllocationSlice(*custom_call, {2});
-
-    std::vector<std::unique_ptr<Thunk>> thunks;
-
-    if (operand_buffer != a_buffer) {
-      thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
-          context_->GetThunkInfo(custom_call),
-          /*source_address=*/operand_buffer,
-          /*destination_buffer=*/a_buffer,
-          /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
-    }
-
-    thunks.push_back(absl::make_unique<CholeskyThunk>(
-        context_->GetThunkInfo(custom_call), options, a_buffer,
-        workspace_buffer, info_buffer,
-        custom_call->operand(0)->shape().element_type(), batch_size, n));
-
-    // Elide the sequential thunk if there's no copy.
-    if (thunks.size() == 1) {
-      AddThunkToThunkSequence(std::move(thunks[0]));
-    } else {
-      AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
-          context_->GetThunkInfo(custom_call), std::move(thunks)));
-    }
-
-    return Status::OK();
-  }
-
   if (void* call_target = CustomCallTargetRegistry::Global()->Lookup(
           custom_call->custom_call_target(), std::string(platform_name()))) {
     auto get_slices_for_instr = [&](const HloInstruction* instr) {