[XLA:CPU] [XLA:GPU] Add support for double precision FFTs on CPU and GPU.

PiperOrigin-RevId: 314250560
Change-Id: Ib9b4a7ea2ec2cc480db09e62bc35cfdcaf1c3b9a
This commit is contained in:
Peter Hawkins 2020-06-01 19:20:02 -07:00 committed by TensorFlower Gardener
parent ef41a8e100
commit 37aaafb0c1
9 changed files with 165 additions and 69 deletions

View File

@ -1217,7 +1217,7 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
auto operand = fft->operand(0);
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*fft, /*operands=*/{operand},
/*supported_types=*/{F32, C64}));
/*supported_types=*/{F32, F64, C64, C128}));
TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout()));
VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape());
@ -1239,7 +1239,7 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
llvm::FunctionType* fft_type = llvm::FunctionType::get(
b_.getVoidTy(),
{int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type,
int64_type, int64_type, int64_type, int64_type},
int32_type, int64_type, int64_type, int64_type, int64_type},
/*isVarArg=*/false);
bool multi_threaded_eigen =
@ -1258,6 +1258,8 @@ Status IrEmitter::HandleFft(HloInstruction* fft) {
{GetExecutableRunOptionsArgument(),
BitCast(GetEmittedValueFor(fft), int8_ptr_type),
BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()),
b_.getInt32(operand->shape().element_type() == F64 ||
operand->shape().element_type() == C128),
b_.getInt32(fft_rank), b_.getInt64(input_batch),
b_.getInt64(fft_rank > 0 ? fft_length[0] : 0),
b_.getInt64(fft_rank > 1 ? fft_length[1] : 0),

View File

@ -28,13 +28,14 @@ using tensorflow::int64;
// Runtime entry point that runs the Eigen FFT using the intra-op thread pool
// obtained from the ExecutableRunOptions passed (type-erased) as
// `run_options_ptr`. `out` and `operand` are untyped buffers; `fft_type` is
// cast to tensorflow::xla::FftType and `double_precision` is cast to bool
// before everything is forwarded to tensorflow::xla::EigenFftImpl.
// NOTE(review): this span is a rendered diff hunk, so each superseded
// parameter/argument line appears directly above the line(s) replacing it.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenFft(
const void* run_options_ptr, void* out, void* operand, int32 fft_type,
int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1,
int64 fft_length2) {
int32 double_precision, int32 fft_rank, int64 input_batch,
int64 fft_length0, int64 fft_length1, int64 fft_length2) {
const xla::ExecutableRunOptions* run_options =
static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
// The thread pool is dereferenced in the call below, so it must be non-null.
XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr);
tensorflow::xla::EigenFftImpl(
*run_options->intra_op_thread_pool(), out, operand,
static_cast<tensorflow::xla::FftType>(fft_type), fft_rank, input_batch,
fft_length0, fft_length1, fft_length2);
static_cast<tensorflow::xla::FftType>(fft_type),
static_cast<bool>(double_precision), fft_rank, input_batch, fft_length0,
fft_length1, fft_length2);
}

View File

@ -22,7 +22,8 @@ extern "C" {
// Declaration of the thread-pool FFT runtime entry point.
// Argument order after the change: run options, output buffer, operand
// buffer, FFT type, double-precision flag, FFT rank, input batch, and three
// FFT lengths.
// NOTE(review): rendered diff hunk; the first `void* operand, ...` line is the
// pre-change parameter line and the following two lines replace it.
extern void __xla_cpu_runtime_EigenFft(
const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out,
void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank,
void* operand, tensorflow::int32 fft_type,
tensorflow::int32 double_precision, tensorflow::int32 fft_rank,
tensorflow::int64 input_batch, tensorflow::int64 fft_length0,
tensorflow::int64 fft_length1, tensorflow::int64 fft_length2);

View File

@ -39,8 +39,8 @@ static constexpr int kFftTypeArraySize = 4;
namespace internal {
// Computes either a forward or reverse complex-to-complex FFT.
template <bool Forward, int FFTRank, typename EigenDevice>
void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand,
template <bool Forward, int FFTRank, typename EigenDevice, typename Complex>
void EigenFftC2C(const EigenDevice& device, Complex* out, Complex* operand,
int64 input_batch, int64 fft_length0, int64 fft_length1,
int64 fft_length2) {
// Create the axes (which are always trailing).
@ -55,10 +55,10 @@ void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand,
for (int i = 0; i < FFTRank; i++) {
dims[i + 1] = fft_shape[i];
}
const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
const Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
Eigen::Aligned>
input(operand, dims);
Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
Eigen::Aligned>
output(out, dims);
output.device(device) = input.template fft<Eigen::BothParts, direction>(axes);
@ -66,8 +66,8 @@ void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand,
// Computes a forward real->complex FFT, slicing out redundant negative
// frequencies from the innermost dimension.
template <int FFTRank, typename EigenDevice>
void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
template <int FFTRank, typename EigenDevice, typename Real, typename Complex>
void EigenFftR2C(const EigenDevice& device, Complex* out, Real* operand,
int64 input_batch, int64 fft_length0, int64 fft_length1,
int64 fft_length2) {
const std::array<int64, 3> fft_shape = {
@ -81,10 +81,10 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
in_dims[i + 1] = fft_shape[i];
out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
}
const Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>,
const Eigen::TensorMap<Eigen::Tensor<Real, FFTRank + 1, Eigen::RowMajor>,
Eigen::Aligned>
input(operand, in_dims);
Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
Eigen::Aligned>
output(out, out_dims);
@ -92,7 +92,7 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
// Compute the full FFT using a temporary tensor.
Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor> full_fft(in_dims);
Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor> full_fft(in_dims);
const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
full_fft.device(device) =
@ -105,8 +105,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
// Computes a reverse complex->real FFT, reconstructing redundant negative
// frequencies using reverse conjugate on innermost dimension after doing IFFT
// on outer dimensions.
template <int FFTRank, typename EigenDevice>
void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
template <int FFTRank, typename EigenDevice, typename Complex, typename Real>
void EigenFftC2R(const EigenDevice& device, Real* out, Complex* operand,
int64 input_batch, int64 fft_length0, int64 fft_length1,
int64 fft_length2) {
const std::array<int64, 3> fft_shape = {
@ -120,10 +120,10 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
out_dims[i + 1] = fft_shape[i];
}
const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
const Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
Eigen::Aligned>
input(operand, in_dims);
Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>,
Eigen::TensorMap<Eigen::Tensor<Real, FFTRank + 1, Eigen::RowMajor>,
Eigen::Aligned>
output(out, out_dims);
@ -131,7 +131,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
// region we will slice from input given fft_shape. We slice input to
// fft_shape on its inner-most dimensions, except the last (which we
// slice to fft_shape[-1] / 2 + 1).
Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor> full_fft(out_dims);
Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor> full_fft(out_dims);
// Calculate the starting point and range of the source of
// negative frequency part.
@ -178,30 +178,59 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
template <int FFTRank, typename EigenDevice>
void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
FftType fft_type, int64 input_batch, int64 fft_length0,
int64 fft_length1, int64 fft_length2) {
FftType fft_type, bool double_precision,
int64 input_batch, int64 fft_length0, int64 fft_length1,
int64 fft_length2) {
switch (fft_type) {
case FftType::FFT:
EigenFftC2C<true, FFTRank, EigenDevice>(
device, static_cast<complex64*>(out),
static_cast<complex64*>(operand), input_batch, fft_length0,
fft_length1, fft_length2);
if (double_precision) {
EigenFftC2C<true, FFTRank, EigenDevice, complex128>(
device, static_cast<complex128*>(out),
static_cast<complex128*>(operand), input_batch, fft_length0,
fft_length1, fft_length2);
} else {
EigenFftC2C<true, FFTRank, EigenDevice, complex64>(
device, static_cast<complex64*>(out),
static_cast<complex64*>(operand), input_batch, fft_length0,
fft_length1, fft_length2);
}
break;
case FftType::IFFT:
EigenFftC2C<false, FFTRank, EigenDevice>(
device, static_cast<complex64*>(out),
static_cast<complex64*>(operand), input_batch, fft_length0,
fft_length1, fft_length2);
if (double_precision) {
EigenFftC2C<false, FFTRank, EigenDevice, complex128>(
device, static_cast<complex128*>(out),
static_cast<complex128*>(operand), input_batch, fft_length0,
fft_length1, fft_length2);
} else {
EigenFftC2C<false, FFTRank, EigenDevice, complex64>(
device, static_cast<complex64*>(out),
static_cast<complex64*>(operand), input_batch, fft_length0,
fft_length1, fft_length2);
}
break;
case FftType::RFFT:
EigenFftR2C<FFTRank, EigenDevice>(
device, static_cast<complex64*>(out), static_cast<float*>(operand),
input_batch, fft_length0, fft_length1, fft_length2);
if (double_precision) {
EigenFftR2C<FFTRank, EigenDevice, double, complex128>(
device, static_cast<complex128*>(out),
static_cast<double*>(operand), input_batch, fft_length0,
fft_length1, fft_length2);
} else {
EigenFftR2C<FFTRank, EigenDevice, float, complex64>(
device, static_cast<complex64*>(out), static_cast<float*>(operand),
input_batch, fft_length0, fft_length1, fft_length2);
}
break;
case FftType::IRFFT:
EigenFftC2R<FFTRank, EigenDevice>(
device, static_cast<float*>(out), static_cast<complex64*>(operand),
input_batch, fft_length0, fft_length1, fft_length2);
if (double_precision) {
EigenFftC2R<FFTRank, EigenDevice, complex128, double>(
device, static_cast<double*>(out),
static_cast<complex128*>(operand), input_batch, fft_length0,
fft_length1, fft_length2);
} else {
EigenFftC2R<FFTRank, EigenDevice, complex64, float>(
device, static_cast<float*>(out), static_cast<complex64*>(operand),
input_batch, fft_length0, fft_length1, fft_length2);
}
break;
default:
// Unsupported FFT type
@ -213,22 +242,24 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
template <typename EigenDevice>
void EigenFftImpl(const EigenDevice& device, void* out, void* operand,
FftType fft_type, int32 fft_rank, int64 input_batch,
int64 fft_length0, int64 fft_length1, int64 fft_length2) {
FftType fft_type, bool double_precision, int32 fft_rank,
int64 input_batch, int64 fft_length0, int64 fft_length1,
int64 fft_length2) {
switch (fft_rank) {
case 1:
internal::EigenFftWithRank<1, EigenDevice>(
device, out, operand, fft_type, input_batch, fft_length0, 0, 0);
internal::EigenFftWithRank<1, EigenDevice>(device, out, operand, fft_type,
double_precision, input_batch,
fft_length0, 0, 0);
break;
case 2:
internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type,
input_batch, fft_length0,
fft_length1, 0);
double_precision, input_batch,
fft_length0, fft_length1, 0);
break;
case 3:
internal::EigenFftWithRank<3, EigenDevice>(device, out, operand, fft_type,
input_batch, fft_length0,
fft_length1, fft_length2);
internal::EigenFftWithRank<3, EigenDevice>(
device, out, operand, fft_type, double_precision, input_batch,
fft_length0, fft_length1, fft_length2);
break;
default:
// Unsupported FFT rank

View File

@ -24,10 +24,11 @@ using tensorflow::int64;
// Runtime entry point that runs the Eigen FFT on Eigen::DefaultDevice().
// `run_options_ptr` is accepted but not read in the visible body. `fft_type`
// is cast to tensorflow::xla::FftType and `double_precision` to bool before
// forwarding to tensorflow::xla::EigenFftImpl.
// NOTE(review): rendered diff hunk; superseded parameter/argument lines appear
// directly above the lines replacing them.
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft(
const void* run_options_ptr, void* out, void* operand, int32 fft_type,
int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1,
int64 fft_length2) {
int32 double_precision, int32 fft_rank, int64 input_batch,
int64 fft_length0, int64 fft_length1, int64 fft_length2) {
tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand,
static_cast<tensorflow::xla::FftType>(fft_type),
fft_rank, input_batch, fft_length0, fft_length1,
static_cast<bool>(double_precision), fft_rank,
input_batch, fft_length0, fft_length1,
fft_length2);
}

View File

@ -22,7 +22,8 @@ extern "C" {
// Declaration of the default-device FFT runtime entry point.
// Argument order after the change: run options, output buffer, operand
// buffer, FFT type, double-precision flag, FFT rank, input batch, and three
// FFT lengths.
// NOTE(review): rendered diff hunk; the first `void* operand, ...` line is the
// pre-change parameter line and the following two lines replace it.
extern void __xla_cpu_runtime_EigenSingleThreadedFft(
const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out,
void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank,
void* operand, tensorflow::int32 fft_type,
tensorflow::int32 double_precision, tensorflow::int32 fft_rank,
tensorflow::int64 input_batch, tensorflow::int64 fft_length0,
tensorflow::int64 fft_length1, tensorflow::int64 fft_length2);

View File

@ -60,16 +60,18 @@ StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
namespace {
se::fft::Type FftTypeToSeType(FftType type) {
se::fft::Type FftTypeToSeType(FftType type, bool double_precision) {
switch (type) {
case FftType::FFT:
return se::fft::Type::kC2CForward;
return double_precision ? se::fft::Type::kZ2ZForward
: se::fft::Type::kC2CForward;
case FftType::IFFT:
return se::fft::Type::kC2CInverse;
return double_precision ? se::fft::Type::kZ2ZInverse
: se::fft::Type::kC2CInverse;
case FftType::IRFFT:
return se::fft::Type::kC2R;
return double_precision ? se::fft::Type::kZ2D : se::fft::Type::kC2R;
case FftType::RFFT:
return se::fft::Type::kR2C;
return double_precision ? se::fft::Type::kD2Z : se::fft::Type::kR2C;
default:
LOG(FATAL) << "unsupported fft type";
}
@ -78,12 +80,16 @@ se::fft::Type FftTypeToSeType(FftType type) {
string FftTypeToString(se::fft::Type type) {
switch (type) {
case se::fft::Type::kC2CForward:
case se::fft::Type::kZ2ZForward:
return "FFT";
case se::fft::Type::kC2CInverse:
case se::fft::Type::kZ2ZInverse:
return "IFFT";
case se::fft::Type::kC2R:
case se::fft::Type::kZ2D:
return "IRFFT";
case se::fft::Type::kR2C:
case se::fft::Type::kD2Z:
return "RFFT";
default:
LOG(FATAL) << "unknown fft type";
@ -98,7 +104,9 @@ FftThunk::FftThunk(FftType fft_type, absl::Span<const int64> fft_length,
const Shape& input_shape, const Shape& output_shape,
const HloInstruction* hlo)
: Thunk(Kind::kFft, hlo),
fft_type_(FftTypeToSeType(fft_type)),
fft_type_(
FftTypeToSeType(fft_type, input_shape.element_type() == F64 ||
input_shape.element_type() == C128)),
fft_length_(fft_length.begin(), fft_length.end()),
scale_factor_(1.0f),
input_buffer_(input_buffer),
@ -166,6 +174,15 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
break;
}
case se::fft::Type::kZ2ZForward: {
se::DeviceMemory<complex128> input_data(
buffer_allocations.GetDeviceAddress(input_buffer_));
se::DeviceMemory<complex128> output_data(
buffer_allocations.GetDeviceAddress(output_buffer_));
launch_ok =
stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
break;
}
case se::fft::Type::kC2CInverse: {
se::DeviceMemory<complex64> input_data(
buffer_allocations.GetDeviceAddress(input_buffer_));
@ -181,6 +198,22 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
}
break;
}
case se::fft::Type::kZ2ZInverse: {
se::DeviceMemory<complex128> input_data(
buffer_allocations.GetDeviceAddress(input_buffer_));
se::DeviceMemory<complex128> output_data(
buffer_allocations.GetDeviceAddress(output_buffer_));
launch_ok =
stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
if (launch_ok) {
launch_ok =
stream
.ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
complex128(scale_factor_), &output_data, 1)
.ok();
}
break;
}
case se::fft::Type::kR2C: {
se::DeviceMemory<float> input_data(
buffer_allocations.GetDeviceAddress(input_buffer_));
@ -190,6 +223,15 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
break;
}
case se::fft::Type::kD2Z: {
se::DeviceMemory<double> input_data(
buffer_allocations.GetDeviceAddress(input_buffer_));
se::DeviceMemory<complex128> output_data(
buffer_allocations.GetDeviceAddress(output_buffer_));
launch_ok =
stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
break;
}
case se::fft::Type::kC2R: {
se::DeviceMemory<complex64> input_data(
buffer_allocations.GetDeviceAddress(input_buffer_));
@ -205,6 +247,21 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
}
break;
}
case se::fft::Type::kZ2D: {
se::DeviceMemory<complex128> input_data(
buffer_allocations.GetDeviceAddress(input_buffer_));
se::DeviceMemory<double> output_data(
buffer_allocations.GetDeviceAddress(output_buffer_));
launch_ok =
stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok();
if (launch_ok) {
launch_ok = stream
.ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
scale_factor_, &output_data, 1)
.ok();
}
break;
}
default:
LOG(FATAL) << "unsupported fft type";
}

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@ -1856,7 +1857,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
switch (fft_type) {
case FFT:
case IFFT:
if (in.element_type() != C64) {
if (!primitive_util::IsComplexType(in.element_type())) {
return InvalidArgument("%s requires complex input type, found %s.",
FftType_Name(fft_type),
PrimitiveType_Name(in.element_type()));
@ -1864,8 +1865,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
RET_CHECK_RANK(in);
return in;
case RFFT: {
if (in.element_type() != F32) {
return InvalidArgument("RFFT requires F32 input type, found %s.",
if (in.element_type() != F32 && in.element_type() != F64) {
return InvalidArgument("RFFT requires F32 or F64 input type, found %s.",
PrimitiveType_Name(in.element_type()));
}
RET_CHECK_RANK(in);
@ -1880,7 +1881,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
fft_length[i]);
}
}
Shape result = ShapeUtil::ChangeElementType(in, C64);
Shape result = ShapeUtil::ChangeElementType(
in, in.element_type() == F32 ? C64 : C128);
// Preserve the size of zero-sized dimensions.
if (fft_length[fft_rank - 1] != 0) {
result.set_dimensions(result.dimensions_size() - 1,
@ -1889,8 +1891,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return result;
}
case IRFFT: {
if (in.element_type() != C64) {
return InvalidArgument("IRFFT requires C64 input type, found %s.",
if (!primitive_util::IsComplexType(in.element_type())) {
return InvalidArgument("IRFFT requires complex input type, found %s.",
PrimitiveType_Name(in.element_type()));
}
RET_CHECK_RANK(in);

View File

@ -615,8 +615,7 @@ namespace fft {
static const char* unsupported_rank = "only supports ranks 1-3";
static const char* invalid_rank = "requires input of at least same rank";
static const char* requires_complex_input = "requires complex input type";
static const char* requires_f32_input = "requires F32 input type";
static const char* requires_c64_input = "requires C64 input type";
static const char* requires_f32_input = "requires F32 or F64 input type";
static const char* dimensions_match = "innermost dimensions match fft_length";
static const char* innermost_dimension_matches =
"innermost dimension matches fft_length/2+1";
@ -654,7 +653,7 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) {
Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
fft::Fail(shape_c128, type, {16, 8}, fft::requires_complex_input);
fft::Pass(shape_c128, type, {16, 8}, shape_c128);
}
TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) {
@ -672,7 +671,7 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) {
Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
fft::Fail(shape_c128, type, {16, 8}, fft::requires_complex_input);
fft::Pass(shape_c128, type, {16, 8}, shape_c128);
}
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) {
@ -747,9 +746,10 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) {
// Checks IRFFT element-type handling in shape inference. After the change, an
// F32 {16, 8} input must fail with the "requires complex input type" message,
// and a C128 {16, 5} input with fft_length {16, 8} must pass and produce an
// F64 {16, 8} result.
// NOTE(review): rendered diff hunk; the C128 {16, 8} shape and the two
// `requires_c64_input` expectations are the pre-change lines, shown directly
// above the lines replacing them.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) {
FftType type = FftType::IRFFT;
Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
fft::Fail(shape_f32, type, {16, 8}, fft::requires_c64_input);
fft::Fail(shape_c128, type, {16, 8}, fft::requires_c64_input);
Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 5});
Shape shape_f64_out = ShapeUtil::MakeShape(F64, {16, 8});
fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
fft::Pass(shape_c128, type, {16, 8}, shape_f64_out);
}
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {