[XLA:GPU] Don't share FFT plans across devices.

Concurrent access to the same FFT plan is undefined behavior.

Should fix https://github.com/google/jax/issues/3518 when merged into a jaxlib release.

PiperOrigin-RevId: 340953140
Change-Id: I7701ab8f1a0783d9bdec9bb16e20536be454616c
Author: Peter Hawkins  2020-11-05 16:47:49 -08:00
Committed by: TensorFlower Gardener
Commit: 7d185bb63b (parent a26cc18a8b)
2 changed files with 33 additions and 20 deletions
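
For readers skimming the diff, here is a minimal standalone sketch of the locking scheme this change adopts: one lazily created FFT plan per device ordinal, one mutex guarding only the map of plans, and a second per-plan mutex held while the plan is created, updated, or launched. The names (PerDevicePlanCache, Slot, RunFft) and the use of std::mutex in place of absl::Mutex are illustrative assumptions, not the XLA code below.

#include <map>
#include <memory>
#include <mutex>

// Hypothetical stand-in for a cuFFT plan handle; the real thunk stores a
// se::fft::Plan created through StreamExecutor.
struct Plan {
  int rank = 0;
};

// One cached plan per device ordinal. map_mu_ guards only the map; each Slot
// carries its own mutex, so different devices can run FFTs concurrently while
// no single plan is ever touched by two threads at once.
class PerDevicePlanCache {
 public:
  struct Slot {
    std::mutex mu;               // held while creating/updating/using `plan`
    std::unique_ptr<Plan> plan;  // lazily created under `mu`
  };

  // Returns the slot for `device_ordinal`, creating an empty one on first use.
  Slot* GetOrCreateSlot(int device_ordinal) {
    std::lock_guard<std::mutex> lock(map_mu_);
    std::unique_ptr<Slot>& slot = slots_[device_ordinal];
    if (!slot) slot = std::make_unique<Slot>();
    return slot.get();
  }

 private:
  std::mutex map_mu_;
  std::map<int, std::unique_ptr<Slot>> slots_;
};

// Usage mirroring FftThunk::ExecuteOnStream: look up the per-device slot, then
// hold its mutex for the whole create-or-update-then-launch sequence.
void RunFft(PerDevicePlanCache& cache, int device_ordinal) {
  PerDevicePlanCache::Slot* slot = cache.GetOrCreateSlot(device_ordinal);
  std::lock_guard<std::mutex> lock(slot->mu);
  if (!slot->plan) {
    slot->plan = std::make_unique<Plan>();  // CreateBatchedPlan... in the thunk
  }
  // ... enqueue the FFT on this device's stream using slot->plan ...
}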

tensorflow/compiler/xla/service/gpu/fft_thunk.cc

@@ -128,7 +128,21 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
   auto op_profiler =
       params.profiler->MakeScopedInstructionProfiler(profile_index());
-  if (fft_plan_ == nullptr) {
+  FftPlan* fft_plan_ptr;
+  {
+    absl::MutexLock lock(&mu_);
+    std::unique_ptr<FftPlan>& plan =
+        fft_plans_[buffer_allocations.device_ordinal()];
+    if (!plan) {
+      plan = std::make_unique<FftPlan>();
+    }
+    fft_plan_ptr = plan.get();
+  }
+  // CuFFT thread-safety requires that separate host threads not share plans;
+  // protect each plan with a mutex.
+  absl::MutexLock lock(&fft_plan_ptr->mu);
+  std::unique_ptr<se::fft::Plan>& fft_plan = fft_plan_ptr->plan;
+  if (fft_plan == nullptr) {
     const int64 fft_rank = fft_length_.size();
     CHECK_LE(fft_rank, 3);
     int batch_size = 1;
@@ -153,14 +167,14 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
     }
     constexpr bool kInPlaceFft = false;
-    fft_plan_ = stream.parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
+    fft_plan = stream.parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
         &stream, fft_rank, fft_length, input_embed, input_stride,
         input_distance, output_embed, output_stride, output_distance, fft_type_,
         kInPlaceFft, batch_size, &scratch_allocator);
     scale_factor_ = 1.0f / output_distance;
   } else {
     stream.parent()->AsFft()->UpdatePlanWithScratchAllocator(
-        &stream, fft_plan_.get(), &scratch_allocator);
+        &stream, fft_plan.get(), &scratch_allocator);
   }
   bool launch_ok;
@@ -170,8 +184,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex64> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       break;
     }
     case se::fft::Type::kZ2ZForward: {
@@ -179,8 +192,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex128> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       break;
     }
     case se::fft::Type::kC2CInverse: {
@@ -188,8 +200,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex64> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       if (launch_ok) {
         launch_ok = stream
                         .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
@@ -203,8 +214,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex128> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       if (launch_ok) {
         launch_ok =
             stream
@@ -219,8 +229,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex64> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       break;
     }
     case se::fft::Type::kD2Z: {
@@ -228,8 +237,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex128> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       break;
     }
     case se::fft::Type::kC2R: {
@@ -237,8 +245,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<float> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       if (launch_ok) {
         launch_ok = stream
                         .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
@@ -252,8 +259,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<double> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       if (launch_ok) {
         launch_ok = stream
                         .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),

tensorflow/compiler/xla/service/gpu/fft_thunk.h

@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
+#include "absl/container/flat_hash_map.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
@@ -80,7 +81,13 @@ class FftThunk : public Thunk {
   float scale_factor_;
-  std::unique_ptr<se::fft::Plan> fft_plan_;
+  // One plan per device ordinal.
+  absl::Mutex mu_;
+  struct FftPlan {
+    absl::Mutex mu;
+    std::unique_ptr<se::fft::Plan> plan;
+  };
+  absl::flat_hash_map<int, std::unique_ptr<FftPlan>> fft_plans_ GUARDED_BY(mu_);
   const BufferAllocation::Slice input_buffer_;
   const BufferAllocation::Slice output_buffer_;