diff --git a/tensorflow/stream_executor/rocm/rocm_fft.cc b/tensorflow/stream_executor/rocm/rocm_fft.cc
index 362105ce6a0..93aa789a451 100644
--- a/tensorflow/stream_executor/rocm/rocm_fft.cc
+++ b/tensorflow/stream_executor/rocm/rocm_fft.cc
@@ -230,12 +230,11 @@ port::Status ROCMFftPlan::Initialize(
         return port::Status{port::error::INTERNAL,
                             "Failed to set auto allocation for rocFFT plan."};
       }
-      size_t size_in_bytes;
       switch (rank) {
         case 1:
           ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0],
                                        ROCMFftType(type), /*batch=*/1,
-                                       &size_in_bytes);
+                                       &scratch_size_bytes_);
           if (ret != HIPFFT_SUCCESS) {
             LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret;
             return port::Status{port::error::INTERNAL,
@@ -245,7 +244,7 @@ port::Status ROCMFftPlan::Initialize(
         case 2:
           ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0],
                                        elem_count_[1], ROCMFftType(type),
-                                       &size_in_bytes);
+                                       &scratch_size_bytes_);
           if (ret != HIPFFT_SUCCESS) {
             LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret;
             return port::Status{port::error::INTERNAL,
@@ -255,7 +254,7 @@ port::Status ROCMFftPlan::Initialize(
         case 3:
           ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0],
                                        elem_count_[1], elem_count_[2],
-                                       ROCMFftType(type), &size_in_bytes);
+                                       ROCMFftType(type), &scratch_size_bytes_);
           if (ret != HIPFFT_SUCCESS) {
             LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret;
             return port::Status{port::error::INTERNAL,
@@ -269,23 +268,7 @@ port::Status ROCMFftPlan::Initialize(
           return port::Status{port::error::INVALID_ARGUMENT,
                               "hipfftPlan only takes rank 1, 2, or 3."};
       }
-      // TODO(yangzihao): refactor this code and the one with the same function
-      // in the batch mode.
-      if (size_in_bytes != 0) {
-        auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
-        if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
-          LOG(ERROR) << "failed to allocate work area.";
-          return allocated.status();
-        }
-      }
-      // Connect work area with allocated space.
-      ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
-      if (ret != HIPFFT_SUCCESS) {
-        LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
-        return port::Status{port::error::INTERNAL,
-                            "Failed to set work area for rocFFT plan."};
-      }
-      return port::Status::OK();
+      return UpdateScratchAllocator(stream, scratch_allocator);
     }
   } else {
     // For either multiple batches or rank higher than 3, use hipfftPlanMany().
@@ -315,31 +298,18 @@ port::Status ROCMFftPlan::Initialize(
             port::error::INTERNAL,
             "Failed to set auto allocation for rocFFT batched plan."};
       }
-      size_t size_in_bytes;
       ret = wrap::hipfftMakePlanMany(
           parent, plan_, rank, elem_count_,
           input_embed ? input_embed_ : nullptr, input_stride, input_distance,
           output_embed ? output_embed_ : nullptr, output_stride,
-          output_distance, ROCMFftType(type), batch_count, &size_in_bytes);
+          output_distance, ROCMFftType(type), batch_count,
+          &scratch_size_bytes_);
       if (ret != HIPFFT_SUCCESS) {
         LOG(ERROR) << "failed to make rocFFT batched plan:" << ret;
         return port::Status{port::error::INTERNAL,
                             "Failed to make rocFFT batched plan."};
       }
-      if (size_in_bytes != 0) {
-        auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
-        if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
-          LOG(ERROR) << "failed to allocate work area.";
-          return allocated.status();
-        }
-      }
-      // Connect work area with allocated space.
-      ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
-      if (ret != HIPFFT_SUCCESS) {
-        LOG(ERROR) << "failed to set work area for rocFFT batched plan:" << ret;
-        return port::Status{port::error::INTERNAL,
-                            "Failed to set work area for rocFFT batched plan."};
-      }
+      return UpdateScratchAllocator(stream, scratch_allocator);
     }
   }
   return port::Status::OK();
@@ -356,6 +326,36 @@ port::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
                     /*output_distance=*/0, type, 1, scratch_allocator);
 }
 
+// Allocates scratch_size_bytes_ bytes (the size written by hipfftMakePlan*)
+// from scratch_allocator and connects the buffer to plan_ as its work area.
+// No allocation is made when scratch_size_bytes_ is zero. `stream` is unused.
+port::Status ROCMFftPlan::UpdateScratchAllocator(
+    Stream *stream, ScratchAllocator *scratch_allocator) {
+  if (scratch_size_bytes_ != 0) {
+    auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_);
+    if (!allocated.ok()) {
+      LOG(ERROR) << "failed to allocate work area.";
+      return allocated.status();
+    }
+    scratch_ = allocated.ValueOrDie();
+    if (scratch_ == nullptr) {
+      // allocated.status() is OK on this path, so returning it would report
+      // success for a plan that has no work area.
+      LOG(ERROR) << "failed to allocate work area.";
+      return port::Status(port::error::INTERNAL,
+                          "Failed to allocate work area for rocFFT plan.");
+    }
+  }
+  // Connect work area with allocated space.
+  auto ret = wrap::hipfftSetWorkArea(parent_, plan_, scratch_.opaque());
+  if (ret != HIPFFT_SUCCESS) {
+    LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
+    return port::Status(port::error::INTERNAL,
+                        "Failed to set work area for rocFFT plan.");
+  }
+  return port::Status::OK();
+}
+
 ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); }
 
 int ROCMFftPlan::GetFftDirection() const {
@@ -507,7 +507,19 @@ std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlanWithScratchAllocator(
 
 void ROCMFft::UpdatePlanWithScratchAllocator(
     Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
-  LOG(ERROR) << "update plan with scratch allocator not implemented";
+  // The plan arrives as the generic fft::Plan interface; it must actually be
+  // a ROCMFftPlan for the scratch allocator update to apply.
+  ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
+  if (rocm_fft_plan == nullptr) {
+    LOG(FATAL) << "failed to update custom allocator for hipfft plan: "
+               << "plan is null or not a ROCMFftPlan";
+  }
+  port::Status status =
+      rocm_fft_plan->UpdateScratchAllocator(stream, scratch_allocator);
+  if (!status.ok()) {
+    LOG(FATAL) << "failed to update custom allocator for hipfft plan: "
+               << status.error_message();
+  }
 }
 
 template <typename FuncT, typename InputT, typename OutputT>
diff --git a/tensorflow/stream_executor/rocm/rocm_fft.h b/tensorflow/stream_executor/rocm/rocm_fft.h
index 7086d8a4b12..cf504aa56d0 100644
--- a/tensorflow/stream_executor/rocm/rocm_fft.h
+++ b/tensorflow/stream_executor/rocm/rocm_fft.h
@@ -49,6 +49,7 @@ class ROCMFftPlan : public fft::Plan {
         plan_(),
         fft_type_(fft::Type::kInvalid),
         scratch_(nullptr),
+        scratch_size_bytes_(0),
         is_initialized_(false) {}
 
   ~ROCMFftPlan() override;
@@ -75,6 +76,11 @@ class ROCMFftPlan : public fft::Plan {
                           uint64 *elem_count, fft::Type type,
                           ScratchAllocator *scratch_allocator);
 
+  // Allocates the work area for this plan from scratch_allocator and attaches
+  // it to the underlying hipfft plan.
+  port::Status UpdateScratchAllocator(Stream *stream,
+                                      ScratchAllocator *scratch_allocator);
+
  protected:
   bool IsInitialized() const { return is_initialized_; }
 
@@ -83,6 +89,7 @@ class ROCMFftPlan : public fft::Plan {
   hipfftHandle plan_;
   fft::Type fft_type_;
   DeviceMemory<uint8> scratch_;
+  size_t scratch_size_bytes_;
   bool is_initialized_;
 };
 