[StreamExecutor] Workaround the cuFFT bug in CUDA 10.1/10.2/11.

See https://github.com/google/jax/issues/2874 for details.

PiperOrigin-RevId: 319143928
Change-Id: I8c4759e90d6e9f6e134e5f2a241cb946d7db99b3
This commit is contained in:
Tim Shen 2020-06-30 19:04:16 -07:00 committed by TensorFlower Gardener
parent 70d1d81d08
commit 1bf8f49335
2 changed files with 38 additions and 4 deletions

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
namespace stream_executor {
@ -82,6 +83,7 @@ port::Status CUDAFftPlan::Initialize(
LOG(FATAL) << "Try to repeatedly initialize.";
}
is_initialized_ = true;
scratch_allocator_ = scratch_allocator;
cuda::ScopedActivateExecutorContext sac(parent);
int elem_count_[3], input_embed_[3], output_embed_[3];
for (int i = 0; i < rank; ++i) {
@ -243,6 +245,8 @@ port::Status CUDAFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
port::Status CUDAFftPlan::UpdateScratchAllocator(
Stream *stream, ScratchAllocator *scratch_allocator) {
scratch_allocator_ = scratch_allocator;
if (scratch_size_bytes_ != 0) {
auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_);
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
@ -455,6 +459,9 @@ bool CUDAFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufftExec,
const DeviceMemory<InputT> &input,
DeviceMemory<OutputT> *output) {
CUDAFftPlan *cuda_fft_plan = dynamic_cast<CUDAFftPlan *>(plan);
DeviceMemory<InputT> input_maybe_copy = input;
if (cuda_fft_plan == nullptr) {
LOG(ERROR) << "the passed-in plan is not a CUDAFftPlan object.";
return false;
@ -464,10 +471,33 @@ bool CUDAFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufftExec,
return false;
}
// Workaround a cuFFT bug, which mutates the input buffer when it shouldn't.
// See b/155276727 and go/nvbugs/2959622.
// TODO(b/155276727): refine the bounding condition.
if (input.opaque() != output->opaque() && CUDA_VERSION >= 10010 &&
CUDA_VERSION <= 11000 &&
std::is_same<InputT, std::complex<float>>::value &&
std::is_same<OutputT, float>::value && input.size() > 0) {
auto *allocator = cuda_fft_plan->GetScratchAllocator();
if (allocator) {
auto allocated = allocator->AllocateBytes(input.size());
if (allocated.ok()) {
if (stream->ThenMemcpy(&allocated.ValueOrDie(), input, input.size())
.ok()) {
input_maybe_copy = DeviceMemory<InputT>(allocated.ValueOrDie());
}
}
// Keep going even the workaround fails, since we don't have a good
// bounding box. We don't want to give up on a potentially correct
// execution just because the allocation for the incorrect case fails.
}
}
cuda::ScopedActivateExecutorContext sac(parent_);
auto ret = cufftExec(cuda_fft_plan->GetPlan(),
GpuComplex(const_cast<InputT *>(GpuMemory(input))),
GpuComplex(GpuMemoryMutable(output)));
auto ret =
cufftExec(cuda_fft_plan->GetPlan(),
GpuComplex(const_cast<InputT *>(GpuMemory(input_maybe_copy))),
GpuComplex(GpuMemoryMutable(output)));
if (ret != CUFFT_SUCCESS) {
LOG(ERROR) << "failed to run cuFFT routine: " << ret;

View File

@ -50,7 +50,8 @@ class CUDAFftPlan : public fft::Plan {
fft_type_(fft::Type::kInvalid),
scratch_(nullptr),
scratch_size_bytes_(0),
is_initialized_(false) {}
is_initialized_(false),
scratch_allocator_(nullptr) {}
~CUDAFftPlan() override;
// Get FFT direction in cuFFT based on FFT type.
@ -79,6 +80,8 @@ class CUDAFftPlan : public fft::Plan {
port::Status UpdateScratchAllocator(Stream *stream,
ScratchAllocator *scratch_allocator);
ScratchAllocator* GetScratchAllocator() const { return scratch_allocator_; }
protected:
bool IsInitialized() const { return is_initialized_; }
@ -89,6 +92,7 @@ class CUDAFftPlan : public fft::Plan {
DeviceMemory<uint8> scratch_;
size_t scratch_size_bytes_;
bool is_initialized_;
ScratchAllocator* scratch_allocator_;
};
// FFT support for CUDA platform via cuFFT library.