[StreamExecutor] Workaround the cuFFT bug in CUDA 10.1/10.2/11.
See https://github.com/google/jax/issues/2874 for details.

PiperOrigin-RevId: 319143928
Change-Id: I8c4759e90d6e9f6e134e5f2a241cb946d7db99b3
commit 1bf8f49335
parent 70d1d81d08
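For background, and not part of the commit itself: the bug being worked around is that, on CUDA 10.1 through 11.0, cuFFT's out-of-place complex-to-real (C2R) transform can mutate its input buffer even though the caller expects it to be left intact. The sketch below illustrates the same copy-to-scratch idea against the raw cuFFT API rather than StreamExecutor; the signal length and buffer names are hypothetical, and error handling is minimal. Build with nvcc and link against -lcufft.

// Illustrative sketch only -- not part of this commit. On affected CUDA
// toolkits, hand cuFFT a scratch copy of the C2R input so the caller's
// buffer is not clobbered.
#include <cuda.h>          // CUDA_VERSION
#include <cuda_runtime.h>
#include <cufft.h>
#include <cstdio>

int main() {
  const int n = 256;                  // real output length (hypothetical)
  const int complex_len = n / 2 + 1;  // packed C2R input length

  cufftComplex *input = nullptr;      // caller-owned input
  cufftComplex *scratch = nullptr;    // scratch copy fed to cuFFT
  cufftReal *output = nullptr;
  cudaMalloc(&input, sizeof(cufftComplex) * complex_len);
  cudaMalloc(&scratch, sizeof(cufftComplex) * complex_len);
  cudaMalloc(&output, sizeof(cufftReal) * n);
  cudaMemset(input, 0, sizeof(cufftComplex) * complex_len);

  cufftHandle plan;
  cufftPlan1d(&plan, n, CUFFT_C2R, /*batch=*/1);

  const cufftComplex *exec_input = input;
#if CUDA_VERSION >= 10010 && CUDA_VERSION <= 11000
  // Affected toolkits: run the transform on a copy so `input` stays intact.
  cudaMemcpy(scratch, input, sizeof(cufftComplex) * complex_len,
             cudaMemcpyDeviceToDevice);
  exec_input = scratch;
#endif

  cufftResult ret =
      cufftExecC2R(plan, const_cast<cufftComplex *>(exec_input), output);
  if (ret != CUFFT_SUCCESS) {
    std::fprintf(stderr, "cufftExecC2R failed: %d\n", ret);
  }

  cufftDestroy(plan);
  cudaFree(input);
  cudaFree(scratch);
  cudaFree(output);
  return 0;
}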
tensorflow/stream_executor/cuda/cuda_fft.cc
@@ -29,6 +29,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/logging.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/stream.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
 
 namespace stream_executor {
@@ -82,6 +83,7 @@ port::Status CUDAFftPlan::Initialize(
     LOG(FATAL) << "Try to repeatedly initialize.";
   }
   is_initialized_ = true;
+  scratch_allocator_ = scratch_allocator;
   cuda::ScopedActivateExecutorContext sac(parent);
   int elem_count_[3], input_embed_[3], output_embed_[3];
   for (int i = 0; i < rank; ++i) {
@@ -243,6 +245,8 @@ port::Status CUDAFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
 
 port::Status CUDAFftPlan::UpdateScratchAllocator(
     Stream *stream, ScratchAllocator *scratch_allocator) {
+  scratch_allocator_ = scratch_allocator;
+
   if (scratch_size_bytes_ != 0) {
     auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_);
     if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
@@ -455,6 +459,9 @@ bool CUDAFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufftExec,
                             const DeviceMemory<InputT> &input,
                             DeviceMemory<OutputT> *output) {
   CUDAFftPlan *cuda_fft_plan = dynamic_cast<CUDAFftPlan *>(plan);
+
+  DeviceMemory<InputT> input_maybe_copy = input;
+
   if (cuda_fft_plan == nullptr) {
     LOG(ERROR) << "the passed-in plan is not a CUDAFftPlan object.";
     return false;
@@ -464,10 +471,33 @@ bool CUDAFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufftExec,
     return false;
   }
 
+  // Workaround a cuFFT bug, which mutates the input buffer when it shouldn't.
+  // See b/155276727 and go/nvbugs/2959622.
+  // TODO(b/155276727): refine the bounding condition.
+  if (input.opaque() != output->opaque() && CUDA_VERSION >= 10010 &&
+      CUDA_VERSION <= 11000 &&
+      std::is_same<InputT, std::complex<float>>::value &&
+      std::is_same<OutputT, float>::value && input.size() > 0) {
+    auto *allocator = cuda_fft_plan->GetScratchAllocator();
+    if (allocator) {
+      auto allocated = allocator->AllocateBytes(input.size());
+      if (allocated.ok()) {
+        if (stream->ThenMemcpy(&allocated.ValueOrDie(), input, input.size())
+                .ok()) {
+          input_maybe_copy = DeviceMemory<InputT>(allocated.ValueOrDie());
+        }
+      }
+      // Keep going even the workaround fails, since we don't have a good
+      // bounding box. We don't want to give up on a potentially correct
+      // execution just because the allocation for the incorrect case fails.
+    }
+  }
+
   cuda::ScopedActivateExecutorContext sac(parent_);
-  auto ret = cufftExec(cuda_fft_plan->GetPlan(),
-                       GpuComplex(const_cast<InputT *>(GpuMemory(input))),
-                       GpuComplex(GpuMemoryMutable(output)));
+  auto ret =
+      cufftExec(cuda_fft_plan->GetPlan(),
+                GpuComplex(const_cast<InputT *>(GpuMemory(input_maybe_copy))),
+                GpuComplex(GpuMemoryMutable(output)));
 
   if (ret != CUFFT_SUCCESS) {
     LOG(ERROR) << "failed to run cuFFT routine: " << ret;
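Note that the workaround above only triggers for out-of-place complex<float>-to-float (C2R) transforms on CUDA 10.1 through 11.0, and it deliberately falls through to the normal execution path when no scratch allocator is available or when the defensive copy cannot be made. As a reading aid, here is a hypothetical helper (not in the commit) condensing that same gate:

// Hypothetical helper, not in the commit: the compile-time/runtime condition
// the workaround uses to decide when a defensive input copy is needed.
#include <complex>
#include <cstddef>
#include <type_traits>
#include <cuda.h>  // CUDA_VERSION

template <typename InputT, typename OutputT>
bool NeedsInputCopyWorkaround(const void *in, const void *out, size_t bytes) {
  return in != out &&              // out-of-place transform
         CUDA_VERSION >= 10010 &&  // CUDA 10.1 ...
         CUDA_VERSION <= 11000 &&  // ... through 11.0
         std::is_same<InputT, std::complex<float>>::value &&  // C2R only
         std::is_same<OutputT, float>::value && bytes > 0;
}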
tensorflow/stream_executor/cuda/cuda_fft.h
@@ -50,7 +50,8 @@ class CUDAFftPlan : public fft::Plan {
         fft_type_(fft::Type::kInvalid),
         scratch_(nullptr),
         scratch_size_bytes_(0),
-        is_initialized_(false) {}
+        is_initialized_(false),
+        scratch_allocator_(nullptr) {}
   ~CUDAFftPlan() override;
 
   // Get FFT direction in cuFFT based on FFT type.
@@ -79,6 +80,8 @@ class CUDAFftPlan : public fft::Plan {
   port::Status UpdateScratchAllocator(Stream *stream,
                                       ScratchAllocator *scratch_allocator);
 
+  ScratchAllocator* GetScratchAllocator() const { return scratch_allocator_; }
+
  protected:
   bool IsInitialized() const { return is_initialized_; }
 
@@ -89,6 +92,7 @@ class CUDAFftPlan : public fft::Plan {
   DeviceMemory<uint8> scratch_;
   size_t scratch_size_bytes_;
   bool is_initialized_;
+  ScratchAllocator* scratch_allocator_;
 };
 
 // FFT support for CUDA platform via cuFFT library.