From 7d185bb63ba346ac955b787540518f732e2ffb53 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Thu, 5 Nov 2020 16:47:49 -0800
Subject: [PATCH] [XLA:GPU] Don't share FFT plans across devices.

Concurrent access to the same FFT plan is undefined behavior.

Should fix https://github.com/google/jax/issues/3518 when merged into a
jaxlib.

PiperOrigin-RevId: 340953140
Change-Id: I7701ab8f1a0783d9bdec9bb16e20536be454616c
---
 .../compiler/xla/service/gpu/fft_thunk.cc     | 44 +++++++++++--------
 .../compiler/xla/service/gpu/fft_thunk.h      |  9 +++-
 2 files changed, 33 insertions(+), 20 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
index d3800c7e6b4..46226c5e8b3 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
@@ -128,7 +128,21 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
   auto op_profiler =
       params.profiler->MakeScopedInstructionProfiler(profile_index());
 
-  if (fft_plan_ == nullptr) {
+  FftPlan* fft_plan_ptr;
+  {
+    absl::MutexLock lock(&mu_);
+    std::unique_ptr<FftPlan>& plan =
+        fft_plans_[buffer_allocations.device_ordinal()];
+    if (!plan) {
+      plan = std::make_unique<FftPlan>();
+    }
+    fft_plan_ptr = plan.get();
+  }
+  // CuFFT thread-safety requires that separate host threads not share plans;
+  // protect each plan with a mutex.
+  absl::MutexLock lock(&fft_plan_ptr->mu);
+  std::unique_ptr<se::fft::Plan>& fft_plan = fft_plan_ptr->plan;
+  if (fft_plan == nullptr) {
     const int64 fft_rank = fft_length_.size();
     CHECK_LE(fft_rank, 3);
     int batch_size = 1;
@@ -153,14 +167,14 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
     }
 
     constexpr bool kInPlaceFft = false;
-    fft_plan_ = stream.parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
+    fft_plan = stream.parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
         &stream, fft_rank, fft_length, input_embed, input_stride,
         input_distance, output_embed, output_stride, output_distance,
         fft_type_, kInPlaceFft, batch_size, &scratch_allocator);
     scale_factor_ = 1.0f / output_distance;
   } else {
     stream.parent()->AsFft()->UpdatePlanWithScratchAllocator(
-        &stream, fft_plan_.get(), &scratch_allocator);
+        &stream, fft_plan.get(), &scratch_allocator);
   }
 
   bool launch_ok;
@@ -170,8 +184,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex64> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       break;
     }
     case se::fft::Type::kZ2ZForward: {
@@ -179,8 +192,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex128> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       break;
     }
     case se::fft::Type::kC2CInverse: {
@@ -188,8 +200,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex64> output_data(
          buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       if (launch_ok) {
         launch_ok = stream
                         .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
@@ -203,8 +214,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex128> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       if (launch_ok) {
         launch_ok =
             stream
@@ -219,8 +229,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex64> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       break;
     }
     case se::fft::Type::kD2Z: {
@@ -228,8 +237,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
      se::DeviceMemory<complex128> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       break;
     }
     case se::fft::Type::kC2R: {
@@ -237,8 +245,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<float> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       if (launch_ok) {
         launch_ok = stream
                         .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
@@ -252,8 +259,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<double> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       if (launch_ok) {
         launch_ok = stream
                         .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
index bde271216b5..05b052c83b7 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
@@ -80,7 +81,13 @@ class FftThunk : public Thunk {
 
   float scale_factor_;
 
-  std::unique_ptr<se::fft::Plan> fft_plan_;
+  // One plan per device ordinal.
+  absl::Mutex mu_;
+  struct FftPlan {
+    absl::Mutex mu;
+    std::unique_ptr<se::fft::Plan> plan;
+  };
+  absl::flat_hash_map<int, std::unique_ptr<FftPlan>> fft_plans_ GUARDED_BY(mu_);
 
   const BufferAllocation::Slice input_buffer_;
   const BufferAllocation::Slice output_buffer_;
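
Note on the locking scheme: for readers outside the XLA tree, the pattern this
patch introduces can be illustrated with a small standalone sketch. Everything
below is hypothetical scaffolding (Plan, PlanSlot, FftThunkSketch and
ExecuteOnDevice are not StreamExecutor or XLA names); only the structure
mirrors the patch: a thunk-level mutex guards a map from device ordinal to a
per-device plan slot, and each slot carries its own mutex that is held while
its plan is created or used, so no two host threads ever touch the same plan.

  // Minimal standalone sketch of the per-device plan pattern; Plan and
  // PlanSlot are hypothetical stand-ins, not StreamExecutor types.
  #include <iostream>
  #include <memory>
  #include <mutex>
  #include <thread>
  #include <unordered_map>
  #include <vector>

  struct Plan {
    int device_ordinal = -1;  // Device this plan was built for.
  };

  class FftThunkSketch {
   public:
    // Called concurrently, one host thread per device ordinal.
    void ExecuteOnDevice(int device_ordinal) {
      PlanSlot* slot;
      {
        // The thunk-level mutex only guards the map, mirroring mu_.
        std::lock_guard<std::mutex> map_lock(mu_);
        std::unique_ptr<PlanSlot>& entry = slots_[device_ordinal];
        if (!entry) entry = std::make_unique<PlanSlot>();
        slot = entry.get();
      }
      // The per-plan mutex stays held for both creation and use of the plan,
      // mirroring FftPlan::mu; a plan is never shared between host threads.
      std::lock_guard<std::mutex> plan_lock(slot->mu);
      if (!slot->plan) {
        slot->plan = std::make_unique<Plan>();
        slot->plan->device_ordinal = device_ordinal;
      }
      // ... the real thunk would build/update and launch the FFT here ...
    }

   private:
    struct PlanSlot {
      std::mutex mu;
      std::unique_ptr<Plan> plan;
    };
    std::mutex mu_;
    // Indirection through unique_ptr mirrors the patch: the slot's address
    // stays stable after the outer lock is released, whatever the map does.
    std::unordered_map<int, std::unique_ptr<PlanSlot>> slots_;
  };

  int main() {
    FftThunkSketch thunk;
    std::vector<std::thread> workers;
    for (int device = 0; device < 4; ++device) {
      workers.emplace_back([&thunk, device] { thunk.ExecuteOnDevice(device); });
    }
    for (std::thread& t : workers) t.join();
    std::cout << "each device used its own plan\n";
    return 0;
  }

Holding the per-plan mutex across plan creation and execution is what removes
the concurrent plan access the commit message calls out. In the patch itself
the map is an absl::flat_hash_map, which does not keep value addresses stable
across rehashes, so storing std::unique_ptr<FftPlan> rather than FftPlan by
value is what makes it safe to keep using fft_plan_ptr after mu_ is released.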