Merge pull request #45743 from ROCmSoftwarePlatform:google_upstream_rocm_rocfft_se
PiperOrigin-RevId: 348624235 Change-Id: Ib32ac683913a9031a693b06f62e27721d4d6c33c
This commit is contained in:
commit
2fc25c73bd
@ -230,12 +230,11 @@ port::Status ROCMFftPlan::Initialize(
|
|||||||
return port::Status{port::error::INTERNAL,
|
return port::Status{port::error::INTERNAL,
|
||||||
"Failed to set auto allocation for rocFFT plan."};
|
"Failed to set auto allocation for rocFFT plan."};
|
||||||
}
|
}
|
||||||
size_t size_in_bytes;
|
|
||||||
switch (rank) {
|
switch (rank) {
|
||||||
case 1:
|
case 1:
|
||||||
ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0],
|
ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0],
|
||||||
ROCMFftType(type), /*batch=*/1,
|
ROCMFftType(type), /*batch=*/1,
|
||||||
&size_in_bytes);
|
&scratch_size_bytes_);
|
||||||
if (ret != HIPFFT_SUCCESS) {
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret;
|
LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret;
|
||||||
return port::Status{port::error::INTERNAL,
|
return port::Status{port::error::INTERNAL,
|
||||||
@ -245,7 +244,7 @@ port::Status ROCMFftPlan::Initialize(
|
|||||||
case 2:
|
case 2:
|
||||||
ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0],
|
ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0],
|
||||||
elem_count_[1], ROCMFftType(type),
|
elem_count_[1], ROCMFftType(type),
|
||||||
&size_in_bytes);
|
&scratch_size_bytes_);
|
||||||
if (ret != HIPFFT_SUCCESS) {
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret;
|
LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret;
|
||||||
return port::Status{port::error::INTERNAL,
|
return port::Status{port::error::INTERNAL,
|
||||||
@ -255,7 +254,7 @@ port::Status ROCMFftPlan::Initialize(
|
|||||||
case 3:
|
case 3:
|
||||||
ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0],
|
ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0],
|
||||||
elem_count_[1], elem_count_[2],
|
elem_count_[1], elem_count_[2],
|
||||||
ROCMFftType(type), &size_in_bytes);
|
ROCMFftType(type), &scratch_size_bytes_);
|
||||||
if (ret != HIPFFT_SUCCESS) {
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret;
|
LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret;
|
||||||
return port::Status{port::error::INTERNAL,
|
return port::Status{port::error::INTERNAL,
|
||||||
@ -269,23 +268,7 @@ port::Status ROCMFftPlan::Initialize(
|
|||||||
return port::Status{port::error::INVALID_ARGUMENT,
|
return port::Status{port::error::INVALID_ARGUMENT,
|
||||||
"hipfftPlan only takes rank 1, 2, or 3."};
|
"hipfftPlan only takes rank 1, 2, or 3."};
|
||||||
}
|
}
|
||||||
// TODO(yangzihao): refactor this code and the one with the same function
|
return UpdateScratchAllocator(stream, scratch_allocator);
|
||||||
// in the batch mode.
|
|
||||||
if (size_in_bytes != 0) {
|
|
||||||
auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
|
|
||||||
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
|
|
||||||
LOG(ERROR) << "failed to allocate work area.";
|
|
||||||
return allocated.status();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Connect work area with allocated space.
|
|
||||||
ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
|
|
||||||
if (ret != HIPFFT_SUCCESS) {
|
|
||||||
LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
|
|
||||||
return port::Status{port::error::INTERNAL,
|
|
||||||
"Failed to set work area for rocFFT plan."};
|
|
||||||
}
|
|
||||||
return port::Status::OK();
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// For either multiple batches or rank higher than 3, use hipfftPlanMany().
|
// For either multiple batches or rank higher than 3, use hipfftPlanMany().
|
||||||
@ -315,31 +298,18 @@ port::Status ROCMFftPlan::Initialize(
|
|||||||
port::error::INTERNAL,
|
port::error::INTERNAL,
|
||||||
"Failed to set auto allocation for rocFFT batched plan."};
|
"Failed to set auto allocation for rocFFT batched plan."};
|
||||||
}
|
}
|
||||||
size_t size_in_bytes;
|
|
||||||
ret = wrap::hipfftMakePlanMany(
|
ret = wrap::hipfftMakePlanMany(
|
||||||
parent, plan_, rank, elem_count_,
|
parent, plan_, rank, elem_count_,
|
||||||
input_embed ? input_embed_ : nullptr, input_stride, input_distance,
|
input_embed ? input_embed_ : nullptr, input_stride, input_distance,
|
||||||
output_embed ? output_embed_ : nullptr, output_stride,
|
output_embed ? output_embed_ : nullptr, output_stride,
|
||||||
output_distance, ROCMFftType(type), batch_count, &size_in_bytes);
|
output_distance, ROCMFftType(type), batch_count,
|
||||||
|
&scratch_size_bytes_);
|
||||||
if (ret != HIPFFT_SUCCESS) {
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
LOG(ERROR) << "failed to make rocFFT batched plan:" << ret;
|
LOG(ERROR) << "failed to make rocFFT batched plan:" << ret;
|
||||||
return port::Status{port::error::INTERNAL,
|
return port::Status{port::error::INTERNAL,
|
||||||
"Failed to make rocFFT batched plan."};
|
"Failed to make rocFFT batched plan."};
|
||||||
}
|
}
|
||||||
if (size_in_bytes != 0) {
|
return UpdateScratchAllocator(stream, scratch_allocator);
|
||||||
auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
|
|
||||||
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
|
|
||||||
LOG(ERROR) << "failed to allocate work area.";
|
|
||||||
return allocated.status();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Connect work area with allocated space.
|
|
||||||
ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
|
|
||||||
if (ret != HIPFFT_SUCCESS) {
|
|
||||||
LOG(ERROR) << "failed to set work area for rocFFT batched plan:" << ret;
|
|
||||||
return port::Status{port::error::INTERNAL,
|
|
||||||
"Failed to set work area for rocFFT batched plan."};
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return port::Status::OK();
|
return port::Status::OK();
|
||||||
@ -356,6 +326,25 @@ port::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
|
|||||||
/*output_distance=*/0, type, 1, scratch_allocator);
|
/*output_distance=*/0, type, 1, scratch_allocator);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
port::Status ROCMFftPlan::UpdateScratchAllocator(
|
||||||
|
Stream *stream, ScratchAllocator *scratch_allocator) {
|
||||||
|
if (scratch_size_bytes_ != 0) {
|
||||||
|
auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_);
|
||||||
|
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
|
||||||
|
LOG(ERROR) << "failed to allocate work area.";
|
||||||
|
return allocated.status();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Connect work area with allocated space.
|
||||||
|
auto ret = wrap::hipfftSetWorkArea(parent_, plan_, scratch_.opaque());
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
|
||||||
|
return port::Status(port::error::INTERNAL,
|
||||||
|
"Failed to set work area for rocFFT plan.");
|
||||||
|
}
|
||||||
|
return port::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); }
|
ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); }
|
||||||
|
|
||||||
int ROCMFftPlan::GetFftDirection() const {
|
int ROCMFftPlan::GetFftDirection() const {
|
||||||
@ -507,7 +496,13 @@ std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlanWithScratchAllocator(
|
|||||||
|
|
||||||
void ROCMFft::UpdatePlanWithScratchAllocator(
|
void ROCMFft::UpdatePlanWithScratchAllocator(
|
||||||
Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
|
Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
|
||||||
LOG(ERROR) << "update plan with scratch allocator not implemented";
|
ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
|
||||||
|
port::Status status =
|
||||||
|
rocm_fft_plan->UpdateScratchAllocator(stream, scratch_allocator);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(FATAL) << "failed to update custom allocator for hipfft plan: "
|
||||||
|
<< status.error_message();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename FuncT, typename InputT, typename OutputT>
|
template <typename FuncT, typename InputT, typename OutputT>
|
||||||
|
@ -49,6 +49,7 @@ class ROCMFftPlan : public fft::Plan {
|
|||||||
plan_(),
|
plan_(),
|
||||||
fft_type_(fft::Type::kInvalid),
|
fft_type_(fft::Type::kInvalid),
|
||||||
scratch_(nullptr),
|
scratch_(nullptr),
|
||||||
|
scratch_size_bytes_(0),
|
||||||
is_initialized_(false) {}
|
is_initialized_(false) {}
|
||||||
~ROCMFftPlan() override;
|
~ROCMFftPlan() override;
|
||||||
|
|
||||||
@ -75,6 +76,9 @@ class ROCMFftPlan : public fft::Plan {
|
|||||||
uint64 *elem_count, fft::Type type,
|
uint64 *elem_count, fft::Type type,
|
||||||
ScratchAllocator *scratch_allocator);
|
ScratchAllocator *scratch_allocator);
|
||||||
|
|
||||||
|
port::Status UpdateScratchAllocator(Stream *stream,
|
||||||
|
ScratchAllocator *scratch_allocator);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
bool IsInitialized() const { return is_initialized_; }
|
bool IsInitialized() const { return is_initialized_; }
|
||||||
|
|
||||||
@ -83,6 +87,7 @@ class ROCMFftPlan : public fft::Plan {
|
|||||||
hipfftHandle plan_;
|
hipfftHandle plan_;
|
||||||
fft::Type fft_type_;
|
fft::Type fft_type_;
|
||||||
DeviceMemory<uint8> scratch_;
|
DeviceMemory<uint8> scratch_;
|
||||||
|
size_t scratch_size_bytes_;
|
||||||
bool is_initialized_;
|
bool is_initialized_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user