[StreamExecutor] Work around the cuFFT bug in CUDA 10.1/10.2/11.
See https://github.com/google/jax/issues/2874 for details.

PiperOrigin-RevId: 319143928
Change-Id: I8c4759e90d6e9f6e134e5f2a241cb946d7db99b3
This commit is contained in:
parent
70d1d81d08
commit
1bf8f49335
@ -29,6 +29,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/platform/logging.h"
|
#include "tensorflow/stream_executor/platform/logging.h"
|
||||||
#include "tensorflow/stream_executor/platform/port.h"
|
#include "tensorflow/stream_executor/platform/port.h"
|
||||||
#include "tensorflow/stream_executor/plugin_registry.h"
|
#include "tensorflow/stream_executor/plugin_registry.h"
|
||||||
|
#include "tensorflow/stream_executor/stream.h"
|
||||||
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
@ -82,6 +83,7 @@ port::Status CUDAFftPlan::Initialize(
|
|||||||
LOG(FATAL) << "Try to repeatedly initialize.";
|
LOG(FATAL) << "Try to repeatedly initialize.";
|
||||||
}
|
}
|
||||||
is_initialized_ = true;
|
is_initialized_ = true;
|
||||||
|
scratch_allocator_ = scratch_allocator;
|
||||||
cuda::ScopedActivateExecutorContext sac(parent);
|
cuda::ScopedActivateExecutorContext sac(parent);
|
||||||
int elem_count_[3], input_embed_[3], output_embed_[3];
|
int elem_count_[3], input_embed_[3], output_embed_[3];
|
||||||
for (int i = 0; i < rank; ++i) {
|
for (int i = 0; i < rank; ++i) {
|
||||||
@ -243,6 +245,8 @@ port::Status CUDAFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
|
|||||||
|
|
||||||
port::Status CUDAFftPlan::UpdateScratchAllocator(
|
port::Status CUDAFftPlan::UpdateScratchAllocator(
|
||||||
Stream *stream, ScratchAllocator *scratch_allocator) {
|
Stream *stream, ScratchAllocator *scratch_allocator) {
|
||||||
|
scratch_allocator_ = scratch_allocator;
|
||||||
|
|
||||||
if (scratch_size_bytes_ != 0) {
|
if (scratch_size_bytes_ != 0) {
|
||||||
auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_);
|
auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_);
|
||||||
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
|
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
|
||||||
@ -455,6 +459,9 @@ bool CUDAFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufftExec,
|
|||||||
const DeviceMemory<InputT> &input,
|
const DeviceMemory<InputT> &input,
|
||||||
DeviceMemory<OutputT> *output) {
|
DeviceMemory<OutputT> *output) {
|
||||||
CUDAFftPlan *cuda_fft_plan = dynamic_cast<CUDAFftPlan *>(plan);
|
CUDAFftPlan *cuda_fft_plan = dynamic_cast<CUDAFftPlan *>(plan);
|
||||||
|
|
||||||
|
DeviceMemory<InputT> input_maybe_copy = input;
|
||||||
|
|
||||||
if (cuda_fft_plan == nullptr) {
|
if (cuda_fft_plan == nullptr) {
|
||||||
LOG(ERROR) << "the passed-in plan is not a CUDAFftPlan object.";
|
LOG(ERROR) << "the passed-in plan is not a CUDAFftPlan object.";
|
||||||
return false;
|
return false;
|
||||||
@ -464,10 +471,33 @@ bool CUDAFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufftExec,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Work around a cuFFT bug that mutates the input buffer when it shouldn't.
|
||||||
|
// See b/155276727 and go/nvbugs/2959622.
|
||||||
|
// TODO(b/155276727): refine the bounding condition.
|
||||||
|
if (input.opaque() != output->opaque() && CUDA_VERSION >= 10010 &&
|
||||||
|
CUDA_VERSION <= 11000 &&
|
||||||
|
std::is_same<InputT, std::complex<float>>::value &&
|
||||||
|
std::is_same<OutputT, float>::value && input.size() > 0) {
|
||||||
|
auto *allocator = cuda_fft_plan->GetScratchAllocator();
|
||||||
|
if (allocator) {
|
||||||
|
auto allocated = allocator->AllocateBytes(input.size());
|
||||||
|
if (allocated.ok()) {
|
||||||
|
if (stream->ThenMemcpy(&allocated.ValueOrDie(), input, input.size())
|
||||||
|
.ok()) {
|
||||||
|
input_maybe_copy = DeviceMemory<InputT>(allocated.ValueOrDie());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Keep going even if the workaround fails, since we don't have a good
|
||||||
|
// bounding box. We don't want to give up on a potentially correct
|
||||||
|
// execution just because the allocation for the incorrect case fails.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cuda::ScopedActivateExecutorContext sac(parent_);
|
cuda::ScopedActivateExecutorContext sac(parent_);
|
||||||
auto ret = cufftExec(cuda_fft_plan->GetPlan(),
|
auto ret =
|
||||||
GpuComplex(const_cast<InputT *>(GpuMemory(input))),
|
cufftExec(cuda_fft_plan->GetPlan(),
|
||||||
GpuComplex(GpuMemoryMutable(output)));
|
GpuComplex(const_cast<InputT *>(GpuMemory(input_maybe_copy))),
|
||||||
|
GpuComplex(GpuMemoryMutable(output)));
|
||||||
|
|
||||||
if (ret != CUFFT_SUCCESS) {
|
if (ret != CUFFT_SUCCESS) {
|
||||||
LOG(ERROR) << "failed to run cuFFT routine: " << ret;
|
LOG(ERROR) << "failed to run cuFFT routine: " << ret;
|
||||||
|
@ -50,7 +50,8 @@ class CUDAFftPlan : public fft::Plan {
|
|||||||
fft_type_(fft::Type::kInvalid),
|
fft_type_(fft::Type::kInvalid),
|
||||||
scratch_(nullptr),
|
scratch_(nullptr),
|
||||||
scratch_size_bytes_(0),
|
scratch_size_bytes_(0),
|
||||||
is_initialized_(false) {}
|
is_initialized_(false),
|
||||||
|
scratch_allocator_(nullptr) {}
|
||||||
~CUDAFftPlan() override;
|
~CUDAFftPlan() override;
|
||||||
|
|
||||||
// Get FFT direction in cuFFT based on FFT type.
|
// Get FFT direction in cuFFT based on FFT type.
|
||||||
@ -79,6 +80,8 @@ class CUDAFftPlan : public fft::Plan {
|
|||||||
port::Status UpdateScratchAllocator(Stream *stream,
|
port::Status UpdateScratchAllocator(Stream *stream,
|
||||||
ScratchAllocator *scratch_allocator);
|
ScratchAllocator *scratch_allocator);
|
||||||
|
|
||||||
|
// Returns the allocator pointer most recently stored by Initialize() or
// UpdateScratchAllocator(). The member starts out as nullptr, so callers
// must null-check the result before using it.
ScratchAllocator* GetScratchAllocator() const { return scratch_allocator_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
bool IsInitialized() const { return is_initialized_; }
|
bool IsInitialized() const { return is_initialized_; }
|
||||||
|
|
||||||
@ -89,6 +92,7 @@ class CUDAFftPlan : public fft::Plan {
|
|||||||
DeviceMemory<uint8> scratch_;
|
DeviceMemory<uint8> scratch_;
|
||||||
size_t scratch_size_bytes_;
|
size_t scratch_size_bytes_;
|
||||||
bool is_initialized_;
|
bool is_initialized_;
|
||||||
|
ScratchAllocator* scratch_allocator_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// FFT support for CUDA platform via cuFFT library.
|
// FFT support for CUDA platform via cuFFT library.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user