Merge pull request #45743 from ROCmSoftwarePlatform:google_upstream_rocm_rocfft_se

PiperOrigin-RevId: 348624235
Change-Id: Ib32ac683913a9031a693b06f62e27721d4d6c33c
This commit is contained in:
TensorFlower Gardener 2020-12-22 06:26:57 -08:00
commit 2fc25c73bd
2 changed files with 38 additions and 38 deletions

View File

@ -230,12 +230,11 @@ port::Status ROCMFftPlan::Initialize(
return port::Status{port::error::INTERNAL,
"Failed to set auto allocation for rocFFT plan."};
}
size_t size_in_bytes;
switch (rank) {
case 1:
ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0],
ROCMFftType(type), /*batch=*/1,
&size_in_bytes);
&scratch_size_bytes_);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret;
return port::Status{port::error::INTERNAL,
@ -245,7 +244,7 @@ port::Status ROCMFftPlan::Initialize(
case 2:
ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0],
elem_count_[1], ROCMFftType(type),
&size_in_bytes);
&scratch_size_bytes_);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret;
return port::Status{port::error::INTERNAL,
@ -255,7 +254,7 @@ port::Status ROCMFftPlan::Initialize(
case 3:
ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0],
elem_count_[1], elem_count_[2],
ROCMFftType(type), &size_in_bytes);
ROCMFftType(type), &scratch_size_bytes_);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret;
return port::Status{port::error::INTERNAL,
@ -269,23 +268,7 @@ port::Status ROCMFftPlan::Initialize(
return port::Status{port::error::INVALID_ARGUMENT,
"hipfftPlan only takes rank 1, 2, or 3."};
}
// TODO(yangzihao): refactor this code and the one with the same function
// in the batch mode.
if (size_in_bytes != 0) {
auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "failed to allocate work area.";
return allocated.status();
}
}
// Connect work area with allocated space.
ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to set work area for rocFFT plan."};
}
return port::Status::OK();
return UpdateScratchAllocator(stream, scratch_allocator);
}
} else {
// For either multiple batches or rank higher than 3, use hipfftPlanMany().
@ -315,31 +298,18 @@ port::Status ROCMFftPlan::Initialize(
port::error::INTERNAL,
"Failed to set auto allocation for rocFFT batched plan."};
}
size_t size_in_bytes;
ret = wrap::hipfftMakePlanMany(
parent, plan_, rank, elem_count_,
input_embed ? input_embed_ : nullptr, input_stride, input_distance,
output_embed ? output_embed_ : nullptr, output_stride,
output_distance, ROCMFftType(type), batch_count, &size_in_bytes);
output_distance, ROCMFftType(type), batch_count,
&scratch_size_bytes_);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to make rocFFT batched plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to make rocFFT batched plan."};
}
if (size_in_bytes != 0) {
auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "failed to allocate work area.";
return allocated.status();
}
}
// Connect work area with allocated space.
ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to set work area for rocFFT batched plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to set work area for rocFFT batched plan."};
}
return UpdateScratchAllocator(stream, scratch_allocator);
}
}
return port::Status::OK();
@ -356,6 +326,25 @@ port::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
/*output_distance=*/0, type, 1, scratch_allocator);
}
// Allocates the plan's scratch work area from `scratch_allocator` (only when
// scratch_size_bytes_ is non-zero) and attaches it to the hipFFT plan through
// hipfftSetWorkArea.
//
// Returns:
//   - the allocator's status if AllocateBytes fails;
//   - INTERNAL if the allocator reports success but yields a null buffer, or
//     if hipfftSetWorkArea does not return HIPFFT_SUCCESS;
//   - OK otherwise.
//
// `stream` is not used by this function.
port::Status ROCMFftPlan::UpdateScratchAllocator(
    Stream *stream, ScratchAllocator *scratch_allocator) {
  if (scratch_size_bytes_ != 0) {
    auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_);
    if (!allocated.ok()) {
      LOG(ERROR) << "failed to allocate work area.";
      return allocated.status();
    }
    scratch_ = allocated.ValueOrDie();
    if (scratch_ == nullptr) {
      // The allocator's status is OK in this branch, so returning it would
      // report success to the caller; surface an explicit error instead.
      LOG(ERROR) << "failed to allocate work area.";
      return port::Status(port::error::INTERNAL,
                          "Failed to allocate work area for rocFFT plan.");
    }
  }
  // Connect work area with allocated space.
  auto ret = wrap::hipfftSetWorkArea(parent_, plan_, scratch_.opaque());
  if (ret != HIPFFT_SUCCESS) {
    LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
    return port::Status(port::error::INTERNAL,
                        "Failed to set work area for rocFFT plan.");
  }
  return port::Status::OK();
}
// Destroys the hipFFT plan handle by calling wrap::hipfftDestroy on it.
// The result of the destroy call is not inspected.
ROCMFftPlan::~ROCMFftPlan() {
  wrap::hipfftDestroy(parent_, plan_);
}
int ROCMFftPlan::GetFftDirection() const {
@ -507,7 +496,13 @@ std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlanWithScratchAllocator(
// Re-attaches a scratch work area to `plan`, allocating it from
// `scratch_allocator` via ROCMFftPlan::UpdateScratchAllocator.
//
// `plan` must point to a ROCMFftPlan. This function has no way to report an
// error to the caller (it returns void), so both a plan of the wrong dynamic
// type and a failed update terminate the process through LOG(FATAL).
void ROCMFft::UpdatePlanWithScratchAllocator(
    Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
  ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
  if (rocm_fft_plan == nullptr) {
    // dynamic_cast yields nullptr for a null or foreign plan; fail loudly
    // here rather than dereferencing a null pointer below.
    LOG(FATAL) << "the passed-in plan is not a ROCMFftPlan object.";
  }
  port::Status status =
      rocm_fft_plan->UpdateScratchAllocator(stream, scratch_allocator);
  if (!status.ok()) {
    LOG(FATAL) << "failed to update custom allocator for hipfft plan: "
               << status.error_message();
  }
}
template <typename FuncT, typename InputT, typename OutputT>

View File

@ -49,6 +49,7 @@ class ROCMFftPlan : public fft::Plan {
plan_(),
fft_type_(fft::Type::kInvalid),
scratch_(nullptr),
scratch_size_bytes_(0),
is_initialized_(false) {}
~ROCMFftPlan() override;
@ -75,6 +76,9 @@ class ROCMFftPlan : public fft::Plan {
uint64 *elem_count, fft::Type type,
ScratchAllocator *scratch_allocator);
port::Status UpdateScratchAllocator(Stream *stream,
ScratchAllocator *scratch_allocator);
protected:
bool IsInitialized() const { return is_initialized_; }
@ -83,6 +87,7 @@ class ROCMFftPlan : public fft::Plan {
hipfftHandle plan_;
fft::Type fft_type_;
DeviceMemory<uint8> scratch_;
size_t scratch_size_bytes_;
bool is_initialized_;
};