Merge pull request #45743 from ROCmSoftwarePlatform:google_upstream_rocm_rocfft_se
PiperOrigin-RevId: 348624235 Change-Id: Ib32ac683913a9031a693b06f62e27721d4d6c33c
This commit is contained in:
commit
2fc25c73bd
@ -230,12 +230,11 @@ port::Status ROCMFftPlan::Initialize(
|
||||
return port::Status{port::error::INTERNAL,
|
||||
"Failed to set auto allocation for rocFFT plan."};
|
||||
}
|
||||
size_t size_in_bytes;
|
||||
switch (rank) {
|
||||
case 1:
|
||||
ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0],
|
||||
ROCMFftType(type), /*batch=*/1,
|
||||
&size_in_bytes);
|
||||
&scratch_size_bytes_);
|
||||
if (ret != HIPFFT_SUCCESS) {
|
||||
LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret;
|
||||
return port::Status{port::error::INTERNAL,
|
||||
@ -245,7 +244,7 @@ port::Status ROCMFftPlan::Initialize(
|
||||
case 2:
|
||||
ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0],
|
||||
elem_count_[1], ROCMFftType(type),
|
||||
&size_in_bytes);
|
||||
&scratch_size_bytes_);
|
||||
if (ret != HIPFFT_SUCCESS) {
|
||||
LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret;
|
||||
return port::Status{port::error::INTERNAL,
|
||||
@ -255,7 +254,7 @@ port::Status ROCMFftPlan::Initialize(
|
||||
case 3:
|
||||
ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0],
|
||||
elem_count_[1], elem_count_[2],
|
||||
ROCMFftType(type), &size_in_bytes);
|
||||
ROCMFftType(type), &scratch_size_bytes_);
|
||||
if (ret != HIPFFT_SUCCESS) {
|
||||
LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret;
|
||||
return port::Status{port::error::INTERNAL,
|
||||
@ -269,23 +268,7 @@ port::Status ROCMFftPlan::Initialize(
|
||||
return port::Status{port::error::INVALID_ARGUMENT,
|
||||
"hipfftPlan only takes rank 1, 2, or 3."};
|
||||
}
|
||||
// TODO(yangzihao): refactor this code and the one with the same function
|
||||
// in the batch mode.
|
||||
if (size_in_bytes != 0) {
|
||||
auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
|
||||
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
|
||||
LOG(ERROR) << "failed to allocate work area.";
|
||||
return allocated.status();
|
||||
}
|
||||
}
|
||||
// Connect work area with allocated space.
|
||||
ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
|
||||
if (ret != HIPFFT_SUCCESS) {
|
||||
LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
|
||||
return port::Status{port::error::INTERNAL,
|
||||
"Failed to set work area for rocFFT plan."};
|
||||
}
|
||||
return port::Status::OK();
|
||||
return UpdateScratchAllocator(stream, scratch_allocator);
|
||||
}
|
||||
} else {
|
||||
// For either multiple batches or rank higher than 3, use hipfftPlanMany().
|
||||
@ -315,31 +298,18 @@ port::Status ROCMFftPlan::Initialize(
|
||||
port::error::INTERNAL,
|
||||
"Failed to set auto allocation for rocFFT batched plan."};
|
||||
}
|
||||
size_t size_in_bytes;
|
||||
ret = wrap::hipfftMakePlanMany(
|
||||
parent, plan_, rank, elem_count_,
|
||||
input_embed ? input_embed_ : nullptr, input_stride, input_distance,
|
||||
output_embed ? output_embed_ : nullptr, output_stride,
|
||||
output_distance, ROCMFftType(type), batch_count, &size_in_bytes);
|
||||
output_distance, ROCMFftType(type), batch_count,
|
||||
&scratch_size_bytes_);
|
||||
if (ret != HIPFFT_SUCCESS) {
|
||||
LOG(ERROR) << "failed to make rocFFT batched plan:" << ret;
|
||||
return port::Status{port::error::INTERNAL,
|
||||
"Failed to make rocFFT batched plan."};
|
||||
}
|
||||
if (size_in_bytes != 0) {
|
||||
auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
|
||||
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
|
||||
LOG(ERROR) << "failed to allocate work area.";
|
||||
return allocated.status();
|
||||
}
|
||||
}
|
||||
// Connect work area with allocated space.
|
||||
ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
|
||||
if (ret != HIPFFT_SUCCESS) {
|
||||
LOG(ERROR) << "failed to set work area for rocFFT batched plan:" << ret;
|
||||
return port::Status{port::error::INTERNAL,
|
||||
"Failed to set work area for rocFFT batched plan."};
|
||||
}
|
||||
return UpdateScratchAllocator(stream, scratch_allocator);
|
||||
}
|
||||
}
|
||||
return port::Status::OK();
|
||||
@ -356,6 +326,25 @@ port::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
|
||||
/*output_distance=*/0, type, 1, scratch_allocator);
|
||||
}
|
||||
|
||||
port::Status ROCMFftPlan::UpdateScratchAllocator(
|
||||
Stream *stream, ScratchAllocator *scratch_allocator) {
|
||||
if (scratch_size_bytes_ != 0) {
|
||||
auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_);
|
||||
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
|
||||
LOG(ERROR) << "failed to allocate work area.";
|
||||
return allocated.status();
|
||||
}
|
||||
}
|
||||
// Connect work area with allocated space.
|
||||
auto ret = wrap::hipfftSetWorkArea(parent_, plan_, scratch_.opaque());
|
||||
if (ret != HIPFFT_SUCCESS) {
|
||||
LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
|
||||
return port::Status(port::error::INTERNAL,
|
||||
"Failed to set work area for rocFFT plan.");
|
||||
}
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); }
|
||||
|
||||
int ROCMFftPlan::GetFftDirection() const {
|
||||
@ -507,7 +496,13 @@ std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlanWithScratchAllocator(
|
||||
|
||||
void ROCMFft::UpdatePlanWithScratchAllocator(
|
||||
Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
|
||||
LOG(ERROR) << "update plan with scratch allocator not implemented";
|
||||
ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
|
||||
port::Status status =
|
||||
rocm_fft_plan->UpdateScratchAllocator(stream, scratch_allocator);
|
||||
if (!status.ok()) {
|
||||
LOG(FATAL) << "failed to update custom allocator for hipfft plan: "
|
||||
<< status.error_message();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FuncT, typename InputT, typename OutputT>
|
||||
|
@ -49,6 +49,7 @@ class ROCMFftPlan : public fft::Plan {
|
||||
plan_(),
|
||||
fft_type_(fft::Type::kInvalid),
|
||||
scratch_(nullptr),
|
||||
scratch_size_bytes_(0),
|
||||
is_initialized_(false) {}
|
||||
~ROCMFftPlan() override;
|
||||
|
||||
@ -75,6 +76,9 @@ class ROCMFftPlan : public fft::Plan {
|
||||
uint64 *elem_count, fft::Type type,
|
||||
ScratchAllocator *scratch_allocator);
|
||||
|
||||
port::Status UpdateScratchAllocator(Stream *stream,
|
||||
ScratchAllocator *scratch_allocator);
|
||||
|
||||
protected:
|
||||
bool IsInitialized() const { return is_initialized_; }
|
||||
|
||||
@ -83,6 +87,7 @@ class ROCMFftPlan : public fft::Plan {
|
||||
hipfftHandle plan_;
|
||||
fft::Type fft_type_;
|
||||
DeviceMemory<uint8> scratch_;
|
||||
size_t scratch_size_bytes_;
|
||||
bool is_initialized_;
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user