Merge pull request #45743 from ROCmSoftwarePlatform:google_upstream_rocm_rocfft_se

PiperOrigin-RevId: 348624235
Change-Id: Ib32ac683913a9031a693b06f62e27721d4d6c33c
This commit is contained in:
TensorFlower Gardener 2020-12-22 06:26:57 -08:00
commit 2fc25c73bd
2 changed files with 38 additions and 38 deletions

View File

@ -230,12 +230,11 @@ port::Status ROCMFftPlan::Initialize(
return port::Status{port::error::INTERNAL, return port::Status{port::error::INTERNAL,
"Failed to set auto allocation for rocFFT plan."}; "Failed to set auto allocation for rocFFT plan."};
} }
size_t size_in_bytes;
switch (rank) { switch (rank) {
case 1: case 1:
ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0], ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0],
ROCMFftType(type), /*batch=*/1, ROCMFftType(type), /*batch=*/1,
&size_in_bytes); &scratch_size_bytes_);
if (ret != HIPFFT_SUCCESS) { if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret; LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret;
return port::Status{port::error::INTERNAL, return port::Status{port::error::INTERNAL,
@ -245,7 +244,7 @@ port::Status ROCMFftPlan::Initialize(
case 2: case 2:
ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0], ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0],
elem_count_[1], ROCMFftType(type), elem_count_[1], ROCMFftType(type),
&size_in_bytes); &scratch_size_bytes_);
if (ret != HIPFFT_SUCCESS) { if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret; LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret;
return port::Status{port::error::INTERNAL, return port::Status{port::error::INTERNAL,
@ -255,7 +254,7 @@ port::Status ROCMFftPlan::Initialize(
case 3: case 3:
ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0], ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0],
elem_count_[1], elem_count_[2], elem_count_[1], elem_count_[2],
ROCMFftType(type), &size_in_bytes); ROCMFftType(type), &scratch_size_bytes_);
if (ret != HIPFFT_SUCCESS) { if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret; LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret;
return port::Status{port::error::INTERNAL, return port::Status{port::error::INTERNAL,
@ -269,23 +268,7 @@ port::Status ROCMFftPlan::Initialize(
return port::Status{port::error::INVALID_ARGUMENT, return port::Status{port::error::INVALID_ARGUMENT,
"hipfftPlan only takes rank 1, 2, or 3."}; "hipfftPlan only takes rank 1, 2, or 3."};
} }
// TODO(yangzihao): refactor this code and the one with the same function return UpdateScratchAllocator(stream, scratch_allocator);
// in the batch mode.
if (size_in_bytes != 0) {
auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "failed to allocate work area.";
return allocated.status();
}
}
// Connect work area with allocated space.
ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to set work area for rocFFT plan."};
}
return port::Status::OK();
} }
} else { } else {
// For either multiple batches or rank higher than 3, use hipfftPlanMany(). // For either multiple batches or rank higher than 3, use hipfftPlanMany().
@ -315,31 +298,18 @@ port::Status ROCMFftPlan::Initialize(
port::error::INTERNAL, port::error::INTERNAL,
"Failed to set auto allocation for rocFFT batched plan."}; "Failed to set auto allocation for rocFFT batched plan."};
} }
size_t size_in_bytes;
ret = wrap::hipfftMakePlanMany( ret = wrap::hipfftMakePlanMany(
parent, plan_, rank, elem_count_, parent, plan_, rank, elem_count_,
input_embed ? input_embed_ : nullptr, input_stride, input_distance, input_embed ? input_embed_ : nullptr, input_stride, input_distance,
output_embed ? output_embed_ : nullptr, output_stride, output_embed ? output_embed_ : nullptr, output_stride,
output_distance, ROCMFftType(type), batch_count, &size_in_bytes); output_distance, ROCMFftType(type), batch_count,
&scratch_size_bytes_);
if (ret != HIPFFT_SUCCESS) { if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to make rocFFT batched plan:" << ret; LOG(ERROR) << "failed to make rocFFT batched plan:" << ret;
return port::Status{port::error::INTERNAL, return port::Status{port::error::INTERNAL,
"Failed to make rocFFT batched plan."}; "Failed to make rocFFT batched plan."};
} }
if (size_in_bytes != 0) { return UpdateScratchAllocator(stream, scratch_allocator);
auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "failed to allocate work area.";
return allocated.status();
}
}
// Connect work area with allocated space.
ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to set work area for rocFFT batched plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to set work area for rocFFT batched plan."};
}
} }
} }
return port::Status::OK(); return port::Status::OK();
@ -356,6 +326,25 @@ port::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
/*output_distance=*/0, type, 1, scratch_allocator); /*output_distance=*/0, type, 1, scratch_allocator);
} }
// Allocates the plan's work area from `scratch_allocator` and attaches it to
// the hipFFT plan.
//
// When scratch_size_bytes_ is non-zero, that many bytes are requested from
// `scratch_allocator` and stored in scratch_. The (possibly still empty)
// scratch_ buffer is then registered with the plan via hipfftSetWorkArea.
//
// `stream` is not used by this function.
//
// Returns OK on success; the allocator's error status if the allocation call
// fails; INTERNAL if the allocator reports success but hands back null memory
// or if hipfftSetWorkArea fails.
port::Status ROCMFftPlan::UpdateScratchAllocator(
    Stream *stream, ScratchAllocator *scratch_allocator) {
  if (scratch_size_bytes_ != 0) {
    auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_);
    if (!allocated.ok()) {
      LOG(ERROR) << "failed to allocate work area.";
      return allocated.status();
    }
    scratch_ = allocated.ValueOrDie();
    if (scratch_ == nullptr) {
      // The allocator's status is OK here, so returning it would report
      // success for a failed allocation; surface an explicit error instead.
      LOG(ERROR) << "failed to allocate work area.";
      return port::Status(port::error::INTERNAL,
                          "Failed to allocate work area for rocFFT plan.");
    }
  }
  // Connect work area with allocated space.
  auto ret = wrap::hipfftSetWorkArea(parent_, plan_, scratch_.opaque());
  if (ret != HIPFFT_SUCCESS) {
    LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
    return port::Status(port::error::INTERNAL,
                        "Failed to set work area for rocFFT plan.");
  }
  return port::Status::OK();
}
// Destructor: hands the plan handle (plan_) to the wrapped hipfftDestroy call.
ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); } ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); }
int ROCMFftPlan::GetFftDirection() const { int ROCMFftPlan::GetFftDirection() const {
@ -507,7 +496,13 @@ std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlanWithScratchAllocator(
void ROCMFft::UpdatePlanWithScratchAllocator( void ROCMFft::UpdatePlanWithScratchAllocator(
Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) { Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
LOG(ERROR) << "update plan with scratch allocator not implemented"; ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
port::Status status =
rocm_fft_plan->UpdateScratchAllocator(stream, scratch_allocator);
if (!status.ok()) {
LOG(FATAL) << "failed to update custom allocator for hipfft plan: "
<< status.error_message();
}
} }
template <typename FuncT, typename InputT, typename OutputT> template <typename FuncT, typename InputT, typename OutputT>

View File

@ -49,6 +49,7 @@ class ROCMFftPlan : public fft::Plan {
plan_(), plan_(),
fft_type_(fft::Type::kInvalid), fft_type_(fft::Type::kInvalid),
scratch_(nullptr), scratch_(nullptr),
scratch_size_bytes_(0),
is_initialized_(false) {} is_initialized_(false) {}
~ROCMFftPlan() override; ~ROCMFftPlan() override;
@ -75,6 +76,9 @@ class ROCMFftPlan : public fft::Plan {
uint64 *elem_count, fft::Type type, uint64 *elem_count, fft::Type type,
ScratchAllocator *scratch_allocator); ScratchAllocator *scratch_allocator);
port::Status UpdateScratchAllocator(Stream *stream,
ScratchAllocator *scratch_allocator);
protected: protected:
bool IsInitialized() const { return is_initialized_; } bool IsInitialized() const { return is_initialized_; }
@ -83,6 +87,7 @@ class ROCMFftPlan : public fft::Plan {
hipfftHandle plan_; hipfftHandle plan_;
fft::Type fft_type_; fft::Type fft_type_;
DeviceMemory<uint8> scratch_; DeviceMemory<uint8> scratch_;
size_t scratch_size_bytes_;
bool is_initialized_; bool is_initialized_;
}; };