[XLA:CPU] [XLA:GPU] Add support for double precision FFTs on CPU and GPU.
PiperOrigin-RevId: 314250560 Change-Id: Ib9b4a7ea2ec2cc480db09e62bc35cfdcaf1c3b9a
This commit is contained in:
parent
ef41a8e100
commit
37aaafb0c1
@ -1217,7 +1217,7 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
|
||||
auto operand = fft->operand(0);
|
||||
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
|
||||
/*instruction=*/*fft, /*operands=*/{operand},
|
||||
/*supported_types=*/{F32, C64}));
|
||||
/*supported_types=*/{F32, F64, C64, C128}));
|
||||
TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
|
||||
TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
|
||||
VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
|
||||
@ -1239,7 +1239,7 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
|
||||
llvm::FunctionType* fft_type = llvm::FunctionType::get(
|
||||
b_.getVoidTy(),
|
||||
{int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type,
|
||||
int64_type, int64_type, int64_type, int64_type},
|
||||
int32_type, int64_type, int64_type, int64_type, int64_type},
|
||||
/*isVarArg=*/false);
|
||||
|
||||
bool multi_threaded_eigen =
|
||||
@ -1258,6 +1258,8 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
|
||||
{GetExecutableRunOptionsArgument(),
|
||||
BitCast(GetEmittedValueFor(fft), int8_ptr_type),
|
||||
BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
|
||||
b_.getInt32(operand->shape().element_type() == F64 ||
|
||||
operand->shape().element_type() == C128),
|
||||
b_.getInt32(fft_rank), b_.getInt64(input_batch),
|
||||
b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
|
||||
b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),
|
||||
|
@ -28,13 +28,14 @@ using tensorflow::int64;
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenFft(
|
||||
const void* run_options_ptr, void* out, void* operand, int32 fft_type,
|
||||
int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1,
|
||||
int64 fft_length2) {
|
||||
int32 double_precision, int32 fft_rank, int64 input_batch,
|
||||
int64 fft_length0, int64 fft_length1, int64 fft_length2) {
|
||||
const xla::ExecutableRunOptions* run_options =
|
||||
static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
|
||||
XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr);
|
||||
tensorflow::xla::EigenFftImpl(
|
||||
*run_options->intra_op_thread_pool(), out, operand,
|
||||
static_cast<tensorflow::xla::FftType>(fft_type), fft_rank, input_batch,
|
||||
fft_length0, fft_length1, fft_length2);
|
||||
static_cast<tensorflow::xla::FftType>(fft_type),
|
||||
static_cast<bool>(double_precision), fft_rank, input_batch, fft_length0,
|
||||
fft_length1, fft_length2);
|
||||
}
|
||||
|
@ -22,7 +22,8 @@ extern "C" {
|
||||
|
||||
extern void __xla_cpu_runtime_EigenFft(
|
||||
const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out,
|
||||
void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank,
|
||||
void* operand, tensorflow::int32 fft_type,
|
||||
tensorflow::int32 double_precision, tensorflow::int32 fft_rank,
|
||||
tensorflow::int64 input_batch, tensorflow::int64 fft_length0,
|
||||
tensorflow::int64 fft_length1, tensorflow::int64 fft_length2);
|
||||
|
||||
|
@ -39,8 +39,8 @@ static constexpr int kFftTypeArraySize = 4;
|
||||
namespace internal {
|
||||
|
||||
// Computes either a forward or reverse complex-to-complex FFT.
|
||||
template <bool Forward, int FFTRank, typename EigenDevice>
|
||||
void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand,
|
||||
template <bool Forward, int FFTRank, typename EigenDevice, typename Complex>
|
||||
void EigenFftC2C(const EigenDevice& device, Complex* out, Complex* operand,
|
||||
int64 input_batch, int64 fft_length0, int64 fft_length1,
|
||||
int64 fft_length2) {
|
||||
// Create the axes (which are always trailing).
|
||||
@ -55,10 +55,10 @@ void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand,
|
||||
for (int i = 0; i < FFTRank; i++) {
|
||||
dims[i + 1] = fft_shape[i];
|
||||
}
|
||||
const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
|
||||
const Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
|
||||
Eigen::Aligned>
|
||||
input(operand, dims);
|
||||
Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
|
||||
Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
|
||||
Eigen::Aligned>
|
||||
output(out, dims);
|
||||
output.device(device) = input.template fft<Eigen::BothParts, direction>(axes);
|
||||
@ -66,8 +66,8 @@ void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand,
|
||||
|
||||
// Computes a forward real->complex FFT, slicing out redundant negative
|
||||
// frequencies from the innermost dimension.
|
||||
template <int FFTRank, typename EigenDevice>
|
||||
void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
|
||||
template <int FFTRank, typename EigenDevice, typename Real, typename Complex>
|
||||
void EigenFftR2C(const EigenDevice& device, Complex* out, Real* operand,
|
||||
int64 input_batch, int64 fft_length0, int64 fft_length1,
|
||||
int64 fft_length2) {
|
||||
const std::array<int64, 3> fft_shape = {
|
||||
@ -81,10 +81,10 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
|
||||
in_dims[i + 1] = fft_shape[i];
|
||||
out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
|
||||
}
|
||||
const Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>,
|
||||
const Eigen::TensorMap<Eigen::Tensor<Real, FFTRank + 1, Eigen::RowMajor>,
|
||||
Eigen::Aligned>
|
||||
input(operand, in_dims);
|
||||
Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
|
||||
Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
|
||||
Eigen::Aligned>
|
||||
output(out, out_dims);
|
||||
|
||||
@ -92,7 +92,7 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
|
||||
const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
|
||||
|
||||
// Compute the full FFT using a temporary tensor.
|
||||
Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor> full_fft(in_dims);
|
||||
Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor> full_fft(in_dims);
|
||||
|
||||
const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
|
||||
full_fft.device(device) =
|
||||
@ -105,8 +105,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
|
||||
// Computes a reverse complex->real FFT, reconstructing redundant negative
|
||||
// frequencies using reverse conjugate on innermost dimension after doing IFFT
|
||||
// on outer dimensions.
|
||||
template <int FFTRank, typename EigenDevice>
|
||||
void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
|
||||
template <int FFTRank, typename EigenDevice, typename Complex, typename Real>
|
||||
void EigenFftC2R(const EigenDevice& device, Real* out, Complex* operand,
|
||||
int64 input_batch, int64 fft_length0, int64 fft_length1,
|
||||
int64 fft_length2) {
|
||||
const std::array<int64, 3> fft_shape = {
|
||||
@ -120,10 +120,10 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
|
||||
in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
|
||||
out_dims[i + 1] = fft_shape[i];
|
||||
}
|
||||
const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
|
||||
const Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
|
||||
Eigen::Aligned>
|
||||
input(operand, in_dims);
|
||||
Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>,
|
||||
Eigen::TensorMap<Eigen::Tensor<Real, FFTRank + 1, Eigen::RowMajor>,
|
||||
Eigen::Aligned>
|
||||
output(out, out_dims);
|
||||
|
||||
@ -131,7 +131,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
|
||||
// region we will slice from input given fft_shape. We slice input to
|
||||
// fft_shape on its inner-most dimensions, except the last (which we
|
||||
// slice to fft_shape[-1] / 2 + 1).
|
||||
Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor> full_fft(out_dims);
|
||||
Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor> full_fft(out_dims);
|
||||
|
||||
// Calculate the starting point and range of the source of
|
||||
// negative frequency part.
|
||||
@ -178,30 +178,59 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
|
||||
|
||||
template <int FFTRank, typename EigenDevice>
|
||||
void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
|
||||
FftType fft_type, int64 input_batch, int64 fft_length0,
|
||||
int64 fft_length1, int64 fft_length2) {
|
||||
FftType fft_type, bool double_precision,
|
||||
int64 input_batch, int64 fft_length0, int64 fft_length1,
|
||||
int64 fft_length2) {
|
||||
switch (fft_type) {
|
||||
case FftType::FFT:
|
||||
EigenFftC2C<true, FFTRank, EigenDevice>(
|
||||
device, static_cast<complex64*>(out),
|
||||
static_cast<complex64*>(operand), input_batch, fft_length0,
|
||||
fft_length1, fft_length2);
|
||||
if (double_precision) {
|
||||
EigenFftC2C<true, FFTRank, EigenDevice, complex128>(
|
||||
device, static_cast<complex128*>(out),
|
||||
static_cast<complex128*>(operand), input_batch, fft_length0,
|
||||
fft_length1, fft_length2);
|
||||
} else {
|
||||
EigenFftC2C<true, FFTRank, EigenDevice, complex64>(
|
||||
device, static_cast<complex64*>(out),
|
||||
static_cast<complex64*>(operand), input_batch, fft_length0,
|
||||
fft_length1, fft_length2);
|
||||
}
|
||||
break;
|
||||
case FftType::IFFT:
|
||||
EigenFftC2C<false, FFTRank, EigenDevice>(
|
||||
device, static_cast<complex64*>(out),
|
||||
static_cast<complex64*>(operand), input_batch, fft_length0,
|
||||
fft_length1, fft_length2);
|
||||
if (double_precision) {
|
||||
EigenFftC2C<false, FFTRank, EigenDevice, complex128>(
|
||||
device, static_cast<complex128*>(out),
|
||||
static_cast<complex128*>(operand), input_batch, fft_length0,
|
||||
fft_length1, fft_length2);
|
||||
} else {
|
||||
EigenFftC2C<false, FFTRank, EigenDevice, complex64>(
|
||||
device, static_cast<complex64*>(out),
|
||||
static_cast<complex64*>(operand), input_batch, fft_length0,
|
||||
fft_length1, fft_length2);
|
||||
}
|
||||
break;
|
||||
case FftType::RFFT:
|
||||
EigenFftR2C<FFTRank, EigenDevice>(
|
||||
device, static_cast<complex64*>(out), static_cast<float*>(operand),
|
||||
input_batch, fft_length0, fft_length1, fft_length2);
|
||||
if (double_precision) {
|
||||
EigenFftR2C<FFTRank, EigenDevice, double, complex128>(
|
||||
device, static_cast<complex128*>(out),
|
||||
static_cast<double*>(operand), input_batch, fft_length0,
|
||||
fft_length1, fft_length2);
|
||||
} else {
|
||||
EigenFftR2C<FFTRank, EigenDevice, float, complex64>(
|
||||
device, static_cast<complex64*>(out), static_cast<float*>(operand),
|
||||
input_batch, fft_length0, fft_length1, fft_length2);
|
||||
}
|
||||
break;
|
||||
case FftType::IRFFT:
|
||||
EigenFftC2R<FFTRank, EigenDevice>(
|
||||
device, static_cast<float*>(out), static_cast<complex64*>(operand),
|
||||
input_batch, fft_length0, fft_length1, fft_length2);
|
||||
if (double_precision) {
|
||||
EigenFftC2R<FFTRank, EigenDevice, complex128, double>(
|
||||
device, static_cast<double*>(out),
|
||||
static_cast<complex128*>(operand), input_batch, fft_length0,
|
||||
fft_length1, fft_length2);
|
||||
} else {
|
||||
EigenFftC2R<FFTRank, EigenDevice, complex64, float>(
|
||||
device, static_cast<float*>(out), static_cast<complex64*>(operand),
|
||||
input_batch, fft_length0, fft_length1, fft_length2);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
// Unsupported FFT type
|
||||
@ -213,22 +242,24 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
|
||||
|
||||
template <typename EigenDevice>
|
||||
void EigenFftImpl(const EigenDevice& device, void* out, void* operand,
|
||||
FftType fft_type, int32 fft_rank, int64 input_batch,
|
||||
int64 fft_length0, int64 fft_length1, int64 fft_length2) {
|
||||
FftType fft_type, bool double_precision, int32 fft_rank,
|
||||
int64 input_batch, int64 fft_length0, int64 fft_length1,
|
||||
int64 fft_length2) {
|
||||
switch (fft_rank) {
|
||||
case 1:
|
||||
internal::EigenFftWithRank<1, EigenDevice>(
|
||||
device, out, operand, fft_type, input_batch, fft_length0, 0, 0);
|
||||
internal::EigenFftWithRank<1, EigenDevice>(device, out, operand, fft_type,
|
||||
double_precision, input_batch,
|
||||
fft_length0, 0, 0);
|
||||
break;
|
||||
case 2:
|
||||
internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type,
|
||||
input_batch, fft_length0,
|
||||
fft_length1, 0);
|
||||
double_precision, input_batch,
|
||||
fft_length0, fft_length1, 0);
|
||||
break;
|
||||
case 3:
|
||||
internal::EigenFftWithRank<3, EigenDevice>(device, out, operand, fft_type,
|
||||
input_batch, fft_length0,
|
||||
fft_length1, fft_length2);
|
||||
internal::EigenFftWithRank<3, EigenDevice>(
|
||||
device, out, operand, fft_type, double_precision, input_batch,
|
||||
fft_length0, fft_length1, fft_length2);
|
||||
break;
|
||||
default:
|
||||
// Unsupported FFT rank
|
||||
|
@ -24,10 +24,11 @@ using tensorflow::int64;
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft(
|
||||
const void* run_options_ptr, void* out, void* operand, int32 fft_type,
|
||||
int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1,
|
||||
int64 fft_length2) {
|
||||
int32 double_precision, int32 fft_rank, int64 input_batch,
|
||||
int64 fft_length0, int64 fft_length1, int64 fft_length2) {
|
||||
tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand,
|
||||
static_cast<tensorflow::xla::FftType>(fft_type),
|
||||
fft_rank, input_batch, fft_length0, fft_length1,
|
||||
static_cast<bool>(double_precision), fft_rank,
|
||||
input_batch, fft_length0, fft_length1,
|
||||
fft_length2);
|
||||
}
|
||||
|
@ -22,7 +22,8 @@ extern "C" {
|
||||
|
||||
extern void __xla_cpu_runtime_EigenSingleThreadedFft(
|
||||
const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out,
|
||||
void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank,
|
||||
void* operand, tensorflow::int32 fft_type,
|
||||
tensorflow::int32 double_precision, tensorflow::int32 fft_rank,
|
||||
tensorflow::int64 input_batch, tensorflow::int64 fft_length0,
|
||||
tensorflow::int64 fft_length1, tensorflow::int64 fft_length2);
|
||||
|
||||
|
@ -60,16 +60,18 @@ StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
|
||||
|
||||
namespace {
|
||||
|
||||
se::fft::Type FftTypeToSeType(FftType type) {
|
||||
se::fft::Type FftTypeToSeType(FftType type, bool double_precision) {
|
||||
switch (type) {
|
||||
case FftType::FFT:
|
||||
return se::fft::Type::kC2CForward;
|
||||
return double_precision ? se::fft::Type::kZ2ZForward
|
||||
: se::fft::Type::kC2CForward;
|
||||
case FftType::IFFT:
|
||||
return se::fft::Type::kC2CInverse;
|
||||
return double_precision ? se::fft::Type::kZ2ZInverse
|
||||
: se::fft::Type::kC2CInverse;
|
||||
case FftType::IRFFT:
|
||||
return se::fft::Type::kC2R;
|
||||
return double_precision ? se::fft::Type::kZ2D : se::fft::Type::kC2R;
|
||||
case FftType::RFFT:
|
||||
return se::fft::Type::kR2C;
|
||||
return double_precision ? se::fft::Type::kD2Z : se::fft::Type::kR2C;
|
||||
default:
|
||||
LOG(FATAL) << "unsupported fft type";
|
||||
}
|
||||
@ -78,12 +80,16 @@ se::fft::Type FftTypeToSeType(FftType type) {
|
||||
string FftTypeToString(se::fft::Type type) {
|
||||
switch (type) {
|
||||
case se::fft::Type::kC2CForward:
|
||||
case se::fft::Type::kZ2ZForward:
|
||||
return "FFT";
|
||||
case se::fft::Type::kC2CInverse:
|
||||
case se::fft::Type::kZ2ZInverse:
|
||||
return "IFFT";
|
||||
case se::fft::Type::kC2R:
|
||||
case se::fft::Type::kZ2D:
|
||||
return "IRFFT";
|
||||
case se::fft::Type::kR2C:
|
||||
case se::fft::Type::kD2Z:
|
||||
return "RFFT";
|
||||
default:
|
||||
LOG(FATAL) << "unknown fft type";
|
||||
@ -98,7 +104,9 @@ FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
|
||||
const Shape& input_shape, const Shape& output_shape,
|
||||
const HloInstruction* hlo)
|
||||
: Thunk(Kind::kFft, hlo),
|
||||
fft_type_(FftTypeToSeType(fft_type)),
|
||||
fft_type_(
|
||||
FftTypeToSeType(fft_type, input_shape.element_type() == F64 ||
|
||||
input_shape.element_type() == C128)),
|
||||
fft_length_(fft_length.begin(), fft_length.end()),
|
||||
scale_factor_(1.0f),
|
||||
input_buffer_(input_buffer),
|
||||
@ -166,6 +174,15 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
|
||||
break;
|
||||
}
|
||||
case se::fft::Type::kZ2ZForward: {
|
||||
se::DeviceMemory<complex128> input_data(
|
||||
buffer_allocations.GetDeviceAddress(input_buffer_));
|
||||
se::DeviceMemory<complex128> output_data(
|
||||
buffer_allocations.GetDeviceAddress(output_buffer_));
|
||||
launch_ok =
|
||||
stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
|
||||
break;
|
||||
}
|
||||
case se::fft::Type::kC2CInverse: {
|
||||
se::DeviceMemory<complex64> input_data(
|
||||
buffer_allocations.GetDeviceAddress(input_buffer_));
|
||||
@ -181,6 +198,22 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
}
|
||||
break;
|
||||
}
|
||||
case se::fft::Type::kZ2ZInverse: {
|
||||
se::DeviceMemory<complex128> input_data(
|
||||
buffer_allocations.GetDeviceAddress(input_buffer_));
|
||||
se::DeviceMemory<complex128> output_data(
|
||||
buffer_allocations.GetDeviceAddress(output_buffer_));
|
||||
launch_ok =
|
||||
stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
|
||||
if (launch_ok) {
|
||||
launch_ok =
|
||||
stream
|
||||
.ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
|
||||
complex128(scale_factor_), &output_data, 1)
|
||||
.ok();
|
||||
}
|
||||
break;
|
||||
}
|
||||
case se::fft::Type::kR2C: {
|
||||
se::DeviceMemory<float> input_data(
|
||||
buffer_allocations.GetDeviceAddress(input_buffer_));
|
||||
@ -190,6 +223,15 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
|
||||
break;
|
||||
}
|
||||
case se::fft::Type::kD2Z: {
|
||||
se::DeviceMemory<double> input_data(
|
||||
buffer_allocations.GetDeviceAddress(input_buffer_));
|
||||
se::DeviceMemory<complex128> output_data(
|
||||
buffer_allocations.GetDeviceAddress(output_buffer_));
|
||||
launch_ok =
|
||||
stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
|
||||
break;
|
||||
}
|
||||
case se::fft::Type::kC2R: {
|
||||
se::DeviceMemory<complex64> input_data(
|
||||
buffer_allocations.GetDeviceAddress(input_buffer_));
|
||||
@ -205,6 +247,21 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
}
|
||||
break;
|
||||
}
|
||||
case se::fft::Type::kZ2D: {
|
||||
se::DeviceMemory<complex128> input_data(
|
||||
buffer_allocations.GetDeviceAddress(input_buffer_));
|
||||
se::DeviceMemory<double> output_data(
|
||||
buffer_allocations.GetDeviceAddress(output_buffer_));
|
||||
launch_ok =
|
||||
stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
|
||||
if (launch_ok) {
|
||||
launch_ok = stream
|
||||
.ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
|
||||
scale_factor_, &output_data, 1)
|
||||
.ok();
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "unsupported fft type";
|
||||
}
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -1856,7 +1857,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
switch (fft_type) {
|
||||
case FFT:
|
||||
case IFFT:
|
||||
if (in.element_type() != C64) {
|
||||
if (!primitive_util::IsComplexType(in.element_type())) {
|
||||
return InvalidArgument("%s requires complex input type, found %s.",
|
||||
FftType_Name(fft_type),
|
||||
PrimitiveType_Name(in.element_type()));
|
||||
@ -1864,8 +1865,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
RET_CHECK_RANK(in);
|
||||
return in;
|
||||
case RFFT: {
|
||||
if (in.element_type() != F32) {
|
||||
return InvalidArgument("RFFT requires F32 input type, found %s.",
|
||||
if (in.element_type() != F32 && in.element_type() != F64) {
|
||||
return InvalidArgument("RFFT requires F32 or F64 input type, found %s.",
|
||||
PrimitiveType_Name(in.element_type()));
|
||||
}
|
||||
RET_CHECK_RANK(in);
|
||||
@ -1880,7 +1881,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
fft_length[i]);
|
||||
}
|
||||
}
|
||||
Shape result = ShapeUtil::ChangeElementType(in, C64);
|
||||
Shape result = ShapeUtil::ChangeElementType(
|
||||
in, in.element_type() == F32 ? C64 : C128);
|
||||
// Preserve the size of zero-sized dimensions.
|
||||
if (fft_length[fft_rank - 1] != 0) {
|
||||
result.set_dimensions(result.dimensions_size() - 1,
|
||||
@ -1889,8 +1891,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
return result;
|
||||
}
|
||||
case IRFFT: {
|
||||
if (in.element_type() != C64) {
|
||||
return InvalidArgument("IRFFT requires C64 input type, found %s.",
|
||||
if (!primitive_util::IsComplexType(in.element_type())) {
|
||||
return InvalidArgument("IRFFT requires complex input type, found %s.",
|
||||
PrimitiveType_Name(in.element_type()));
|
||||
}
|
||||
RET_CHECK_RANK(in);
|
||||
|
@ -615,8 +615,7 @@ namespace fft {
|
||||
static const char* unsupported_rank = "only supports ranks 1-3";
|
||||
static const char* invalid_rank = "requires input of at least same rank";
|
||||
static const char* requires_complex_input = "requires complex input type";
|
||||
static const char* requires_f32_input = "requires F32 input type";
|
||||
static const char* requires_c64_input = "requires C64 input type";
|
||||
static const char* requires_f32_input = "requires F32 or F64 input type";
|
||||
static const char* dimensions_match = "innermost dimensions match fft_length";
|
||||
static const char* innermost_dimension_matches =
|
||||
"innermost dimension matches fft_length/2+1";
|
||||
@ -654,7 +653,7 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) {
|
||||
Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
|
||||
Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
|
||||
fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
|
||||
fft::Fail(shape_c128, type, {16, 8}, fft::requires_complex_input);
|
||||
fft::Pass(shape_c128, type, {16, 8}, shape_c128);
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) {
|
||||
@ -672,7 +671,7 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) {
|
||||
Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
|
||||
Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
|
||||
fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
|
||||
fft::Fail(shape_c128, type, {16, 8}, fft::requires_complex_input);
|
||||
fft::Pass(shape_c128, type, {16, 8}, shape_c128);
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) {
|
||||
@ -747,9 +746,10 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) {
|
||||
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) {
|
||||
FftType type = FftType::IRFFT;
|
||||
Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
|
||||
Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
|
||||
fft::Fail(shape_f32, type, {16, 8}, fft::requires_c64_input);
|
||||
fft::Fail(shape_c128, type, {16, 8}, fft::requires_c64_input);
|
||||
Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 5});
|
||||
Shape shape_f64_out = ShapeUtil::MakeShape(F64, {16, 8});
|
||||
fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
|
||||
fft::Pass(shape_c128, type, {16, 8}, shape_f64_out);
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user