[XLA:GPU] Don't share FFT plans across devices.

Concurrent access to the same FFT plan is undefined behavior. Should fix
https://github.com/google/jax/issues/3518 when merged into a jaxlib.

PiperOrigin-RevId: 340953140
Change-Id: I7701ab8f1a0783d9bdec9bb16e20536be454616c
parent a26cc18a8b
commit 7d185bb63b
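To summarize the fix before the diff: the FFT thunk now keeps one plan per device ordinal in a mutex-guarded map, and each plan carries its own mutex so that no two host threads ever operate on the same cuFFT plan. A minimal standalone sketch of that pattern follows (the names PlanCache, Entry, GetOrCreate, and FftPlanHandle are illustrative, not identifiers from the XLA sources):

#include <memory>

#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"

// Stand-in for se::fft::Plan; the real type is opaque here.
struct FftPlanHandle {};

class PlanCache {
 public:
  struct Entry {
    absl::Mutex mu;                       // serializes use of `plan`
    std::unique_ptr<FftPlanHandle> plan;  // built lazily under `mu`
  };

  // Returns the entry for `device_ordinal`, creating it on first use.
  // Callers must lock entry->mu before building or running the plan.
  Entry* GetOrCreate(int device_ordinal) {
    absl::MutexLock lock(&mu_);  // protects only the map, not the plans
    std::unique_ptr<Entry>& entry = plans_[device_ordinal];
    if (!entry) entry = std::make_unique<Entry>();
    return entry.get();
  }

 private:
  absl::Mutex mu_;
  absl::flat_hash_map<int, std::unique_ptr<Entry>> plans_;
};

Under this scheme, executions on different GPUs build and use distinct plans, while two executions on the same GPU serialize on the per-entry mutex instead of racing on shared cuFFT state.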
@@ -128,7 +128,21 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {

   auto op_profiler =
       params.profiler->MakeScopedInstructionProfiler(profile_index());
-  if (fft_plan_ == nullptr) {
+  FftPlan* fft_plan_ptr;
+  {
+    absl::MutexLock lock(&mu_);
+    std::unique_ptr<FftPlan>& plan =
+        fft_plans_[buffer_allocations.device_ordinal()];
+    if (!plan) {
+      plan = std::make_unique<FftPlan>();
+    }
+    fft_plan_ptr = plan.get();
+  }
+  // CuFFT thread-safety requires that separate host threads not share plans;
+  // protect each plan with a mutex.
+  absl::MutexLock lock(&fft_plan_ptr->mu);
+  std::unique_ptr<se::fft::Plan>& fft_plan = fft_plan_ptr->plan;
+  if (fft_plan == nullptr) {
     const int64 fft_rank = fft_length_.size();
     CHECK_LE(fft_rank, 3);
     int batch_size = 1;
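A subtlety the hunk above relies on (my note, not part of the commit): fft_plan_ptr is dereferenced after mu_ is released, which is safe because the map stores each FftPlan behind a std::unique_ptr, so the pointee's address survives rehashing when entries for other device ordinals are inserted later. A standalone sketch of that property:

#include <cassert>
#include <memory>

#include "absl/container/flat_hash_map.h"

struct Plan { int id = 0; };

int main() {
  absl::flat_hash_map<int, std::unique_ptr<Plan>> plans;
  plans[0] = std::make_unique<Plan>();
  Plan* held = plans[0].get();  // pointer handed out while the map lock was held

  // Force growth and rehashing; the map's slots move, the heap-allocated Plan does not.
  for (int i = 1; i < 1024; ++i) plans[i] = std::make_unique<Plan>();

  assert(held == plans[0].get());  // still valid after the rehash
  return 0;
}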
@@ -153,14 +167,14 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
     }

     constexpr bool kInPlaceFft = false;
-    fft_plan_ = stream.parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
+    fft_plan = stream.parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
         &stream, fft_rank, fft_length, input_embed, input_stride,
         input_distance, output_embed, output_stride, output_distance, fft_type_,
         kInPlaceFft, batch_size, &scratch_allocator);
     scale_factor_ = 1.0f / output_distance;
   } else {
     stream.parent()->AsFft()->UpdatePlanWithScratchAllocator(
-        &stream, fft_plan_.get(), &scratch_allocator);
+        &stream, fft_plan.get(), &scratch_allocator);
   }

   bool launch_ok;
@@ -170,8 +184,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex64> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       break;
     }
     case se::fft::Type::kZ2ZForward: {
@@ -179,8 +192,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex128> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       break;
     }
     case se::fft::Type::kC2CInverse: {
@@ -188,8 +200,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
      se::DeviceMemory<complex64> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       if (launch_ok) {
         launch_ok = stream
                         .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
@@ -203,8 +214,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex128> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       if (launch_ok) {
         launch_ok =
             stream
@@ -219,8 +229,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex64> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       break;
     }
     case se::fft::Type::kD2Z: {
@@ -228,8 +237,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<complex128> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       break;
     }
     case se::fft::Type::kC2R: {
@@ -237,8 +245,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<float> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       if (launch_ok) {
         launch_ok = stream
                         .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
@@ -252,8 +259,7 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
           buffer_allocations.GetDeviceAddress(input_buffer_));
       se::DeviceMemory<double> output_data(
           buffer_allocations.GetDeviceAddress(output_buffer_));
-      launch_ok =
-          stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
+      launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
       if (launch_ok) {
         launch_ok = stream
                         .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_

+#include "absl/container/flat_hash_map.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
@@ -80,7 +81,13 @@ class FftThunk : public Thunk {

   float scale_factor_;

-  std::unique_ptr<se::fft::Plan> fft_plan_;
+  // One plan per device ordinal.
+  absl::Mutex mu_;
+  struct FftPlan {
+    absl::Mutex mu;
+    std::unique_ptr<se::fft::Plan> plan;
+  };
+  absl::flat_hash_map<int, std::unique_ptr<FftPlan>> fft_plans_ GUARDED_BY(mu_);

   const BufferAllocation::Slice input_buffer_;
   const BufferAllocation::Slice output_buffer_;
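One aside on the header change (again my note, not the commit's): the GUARDED_BY(mu_) annotation on fft_plans_ is what lets Clang's -Wthread-safety analysis flag any access to the map made without holding mu_. Abseil spells the macro ABSL_GUARDED_BY; the tiny Counter class below is a hypothetical example of that annotation paired with absl::MutexLock, not code from XLA:

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"

class Counter {
 public:
  void Increment() {
    absl::MutexLock lock(&mu_);  // satisfies the GUARDED_BY requirement
    ++value_;
  }

  int Get() const {
    absl::MutexLock lock(&mu_);
    return value_;
  }

 private:
  mutable absl::Mutex mu_;
  int value_ ABSL_GUARDED_BY(mu_) = 0;  // unlocked access triggers a -Wthread-safety warning
};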