From 37aaafb0c1baa7acd0607748326cc12faf556277 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 1 Jun 2020 19:20:02 -0700 Subject: [PATCH] [XLA:CPU] [XLA:GPU] Add support for double precision FFTs on CPU and GPU. PiperOrigin-RevId: 314250560 Change-Id: Ib9b4a7ea2ec2cc480db09e62bc35cfdcaf1c3b9a --- .../compiler/xla/service/cpu/ir_emitter.cc | 6 +- .../compiler/xla/service/cpu/runtime_fft.cc | 9 +- .../compiler/xla/service/cpu/runtime_fft.h | 3 +- .../xla/service/cpu/runtime_fft_impl.h | 109 +++++++++++------- .../cpu/runtime_single_threaded_fft.cc | 7 +- .../service/cpu/runtime_single_threaded_fft.h | 3 +- .../compiler/xla/service/gpu/fft_thunk.cc | 69 ++++++++++- .../compiler/xla/service/shape_inference.cc | 14 ++- .../xla/service/shape_inference_test.cc | 14 +-- 9 files changed, 165 insertions(+), 69 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 1e204afb001..998b9db132c 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1217,7 +1217,7 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { auto operand = fft->operand(0); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*fft, /*operands=*/{operand}, - /*supported_types=*/{F32, C64})); + /*supported_types=*/{F32, F64, C64, C128})); TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout())); VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape()); @@ -1239,7 +1239,7 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { llvm::FunctionType* fft_type = llvm::FunctionType::get( b_.getVoidTy(), {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, - int64_type, int64_type, int64_type, int64_type}, + int32_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); bool multi_threaded_eigen = @@ -1258,6 +1258,8 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { {GetExecutableRunOptionsArgument(), BitCast(GetEmittedValueFor(fft), int8_ptr_type), BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()), + b_.getInt32(operand->shape().element_type() == F64 || + operand->shape().element_type() == C128), b_.getInt32(fft_rank), b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc index 051120be324..0c1e9dae751 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc @@ -28,13 +28,14 @@ using tensorflow::int64; TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenFft( const void* run_options_ptr, void* out, void* operand, int32 fft_type, - int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, - int64 fft_length2) { + int32 double_precision, int32 fft_rank, int64 input_batch, + int64 fft_length0, int64 fft_length1, int64 fft_length2) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); tensorflow::xla::EigenFftImpl( *run_options->intra_op_thread_pool(), out, operand, - static_cast(fft_type), fft_rank, input_batch, - fft_length0, fft_length1, fft_length2); + static_cast(fft_type), + static_cast(double_precision), fft_rank, input_batch, fft_length0, + fft_length1, fft_length2); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_fft.h index f20c5aa0aa2..d95da172116 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.h @@ -22,7 +22,8 @@ extern "C" { extern void __xla_cpu_runtime_EigenFft( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, - void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + void* operand, tensorflow::int32 fft_type, + tensorflow::int32 double_precision, tensorflow::int32 fft_rank, tensorflow::int64 input_batch, tensorflow::int64 fft_length0, tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 04dea120a8d..124e7d589a0 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -39,8 +39,8 @@ static constexpr int kFftTypeArraySize = 4; namespace internal { // Computes either a forward or reverse complex-to-complex FFT. -template -void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand, +template +void EigenFftC2C(const EigenDevice& device, Complex* out, Complex* operand, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { // Create the axes (which are always trailing). @@ -55,10 +55,10 @@ void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand, for (int i = 0; i < FFTRank; i++) { dims[i + 1] = fft_shape[i]; } - const Eigen::TensorMap, + const Eigen::TensorMap, Eigen::Aligned> input(operand, dims); - Eigen::TensorMap, + Eigen::TensorMap, Eigen::Aligned> output(out, dims); output.device(device) = input.template fft(axes); @@ -66,8 +66,8 @@ void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand, // Computes a forward real->complex FFT, slicing out redundant negative // frequencies from the innermost dimension. -template -void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, +template +void EigenFftR2C(const EigenDevice& device, Complex* out, Real* operand, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { const std::array fft_shape = { @@ -81,10 +81,10 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, in_dims[i + 1] = fft_shape[i]; out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; } - const Eigen::TensorMap, + const Eigen::TensorMap, Eigen::Aligned> input(operand, in_dims); - Eigen::TensorMap, + Eigen::TensorMap, Eigen::Aligned> output(out, out_dims); @@ -92,7 +92,7 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); // Compute the full FFT using a temporary tensor. - Eigen::Tensor full_fft(in_dims); + Eigen::Tensor full_fft(in_dims); const Eigen::DSizes zero_start_indices; full_fft.device(device) = @@ -105,8 +105,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, // Computes a reverse complex->real FFT, reconstructing redundant negative // frequencies using reverse conjugate on innermost dimension after doing IFFT // on outer dimensions. -template -void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, +template +void EigenFftC2R(const EigenDevice& device, Real* out, Complex* operand, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { const std::array fft_shape = { @@ -120,10 +120,10 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; out_dims[i + 1] = fft_shape[i]; } - const Eigen::TensorMap, + const Eigen::TensorMap, Eigen::Aligned> input(operand, in_dims); - Eigen::TensorMap, + Eigen::TensorMap, Eigen::Aligned> output(out, out_dims); @@ -131,7 +131,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, // region we will slice from input given fft_shape. We slice input to // fft_shape on its inner-most dimensions, except the last (which we // slice to fft_shape[-1] / 2 + 1). - Eigen::Tensor full_fft(out_dims); + Eigen::Tensor full_fft(out_dims); // Calculate the starting point and range of the source of // negative frequency part. @@ -178,30 +178,59 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, template void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, - FftType fft_type, int64 input_batch, int64 fft_length0, - int64 fft_length1, int64 fft_length2) { + FftType fft_type, bool double_precision, + int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { switch (fft_type) { case FftType::FFT: - EigenFftC2C( - device, static_cast(out), - static_cast(operand), input_batch, fft_length0, - fft_length1, fft_length2); + if (double_precision) { + EigenFftC2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + } else { + EigenFftC2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + } break; case FftType::IFFT: - EigenFftC2C( - device, static_cast(out), - static_cast(operand), input_batch, fft_length0, - fft_length1, fft_length2); + if (double_precision) { + EigenFftC2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + } else { + EigenFftC2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + } break; case FftType::RFFT: - EigenFftR2C( - device, static_cast(out), static_cast(operand), - input_batch, fft_length0, fft_length1, fft_length2); + if (double_precision) { + EigenFftR2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + } else { + EigenFftR2C( + device, static_cast(out), static_cast(operand), + input_batch, fft_length0, fft_length1, fft_length2); + } break; case FftType::IRFFT: - EigenFftC2R( - device, static_cast(out), static_cast(operand), - input_batch, fft_length0, fft_length1, fft_length2); + if (double_precision) { + EigenFftC2R( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + } else { + EigenFftC2R( + device, static_cast(out), static_cast(operand), + input_batch, fft_length0, fft_length1, fft_length2); + } break; default: // Unsupported FFT type @@ -213,22 +242,24 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, template void EigenFftImpl(const EigenDevice& device, void* out, void* operand, - FftType fft_type, int32 fft_rank, int64 input_batch, - int64 fft_length0, int64 fft_length1, int64 fft_length2) { + FftType fft_type, bool double_precision, int32 fft_rank, + int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { switch (fft_rank) { case 1: - internal::EigenFftWithRank<1, EigenDevice>( - device, out, operand, fft_type, input_batch, fft_length0, 0, 0); + internal::EigenFftWithRank<1, EigenDevice>(device, out, operand, fft_type, + double_precision, input_batch, + fft_length0, 0, 0); break; case 2: internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type, - input_batch, fft_length0, - fft_length1, 0); + double_precision, input_batch, + fft_length0, fft_length1, 0); break; case 3: - internal::EigenFftWithRank<3, EigenDevice>(device, out, operand, fft_type, - input_batch, fft_length0, - fft_length1, fft_length2); + internal::EigenFftWithRank<3, EigenDevice>( + device, out, operand, fft_type, double_precision, input_batch, + fft_length0, fft_length1, fft_length2); break; default: // Unsupported FFT rank diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc index d2780dd694e..9476dce5ced 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc @@ -24,10 +24,11 @@ using tensorflow::int64; TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft( const void* run_options_ptr, void* out, void* operand, int32 fft_type, - int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, - int64 fft_length2) { + int32 double_precision, int32 fft_rank, int64 input_batch, + int64 fft_length0, int64 fft_length1, int64 fft_length2) { tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand, static_cast(fft_type), - fft_rank, input_batch, fft_length0, fft_length1, + static_cast(double_precision), fft_rank, + input_batch, fft_length0, fft_length1, fft_length2); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h index dcd133d012c..2f0ccda2d10 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h @@ -22,7 +22,8 @@ extern "C" { extern void __xla_cpu_runtime_EigenSingleThreadedFft( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, - void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + void* operand, tensorflow::int32 fft_type, + tensorflow::int32 double_precision, tensorflow::int32 fft_rank, tensorflow::int64 input_batch, tensorflow::int64 fft_length0, tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 991a463f2a0..9d6be3c78ea 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -60,16 +60,18 @@ StatusOr> FftScratchAllocator::AllocateBytes( namespace { -se::fft::Type FftTypeToSeType(FftType type) { +se::fft::Type FftTypeToSeType(FftType type, bool double_precision) { switch (type) { case FftType::FFT: - return se::fft::Type::kC2CForward; + return double_precision ? se::fft::Type::kZ2ZForward + : se::fft::Type::kC2CForward; case FftType::IFFT: - return se::fft::Type::kC2CInverse; + return double_precision ? se::fft::Type::kZ2ZInverse + : se::fft::Type::kC2CInverse; case FftType::IRFFT: - return se::fft::Type::kC2R; + return double_precision ? se::fft::Type::kZ2D : se::fft::Type::kC2R; case FftType::RFFT: - return se::fft::Type::kR2C; + return double_precision ? se::fft::Type::kD2Z : se::fft::Type::kR2C; default: LOG(FATAL) << "unsupported fft type"; } @@ -78,12 +80,16 @@ se::fft::Type FftTypeToSeType(FftType type) { string FftTypeToString(se::fft::Type type) { switch (type) { case se::fft::Type::kC2CForward: + case se::fft::Type::kZ2ZForward: return "FFT"; case se::fft::Type::kC2CInverse: + case se::fft::Type::kZ2ZInverse: return "IFFT"; case se::fft::Type::kC2R: + case se::fft::Type::kZ2D: return "IRFFT"; case se::fft::Type::kR2C: + case se::fft::Type::kD2Z: return "RFFT"; default: LOG(FATAL) << "unknown fft type"; @@ -98,7 +104,9 @@ FftThunk::FftThunk(FftType fft_type, absl::Span fft_length, const Shape& input_shape, const Shape& output_shape, const HloInstruction* hlo) : Thunk(Kind::kFft, hlo), - fft_type_(FftTypeToSeType(fft_type)), + fft_type_( + FftTypeToSeType(fft_type, input_shape.element_type() == F64 || + input_shape.element_type() == C128)), fft_length_(fft_length.begin(), fft_length.end()), scale_factor_(1.0f), input_buffer_(input_buffer), @@ -166,6 +174,15 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) { stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok(); break; } + case se::fft::Type::kZ2ZForward: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + break; + } case se::fft::Type::kC2CInverse: { se::DeviceMemory input_data( buffer_allocations.GetDeviceAddress(input_buffer_)); @@ -181,6 +198,22 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) { } break; } + case se::fft::Type::kZ2ZInverse: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + if (launch_ok) { + launch_ok = + stream + .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_), + complex128(scale_factor_), &output_data, 1) + .ok(); + } + break; + } case se::fft::Type::kR2C: { se::DeviceMemory input_data( buffer_allocations.GetDeviceAddress(input_buffer_)); @@ -190,6 +223,15 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) { stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok(); break; } + case se::fft::Type::kD2Z: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + break; + } case se::fft::Type::kC2R: { se::DeviceMemory input_data( buffer_allocations.GetDeviceAddress(input_buffer_)); @@ -205,6 +247,21 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) { } break; } + case se::fft::Type::kZ2D: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + if (launch_ok) { + launch_ok = stream + .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_), + scale_factor_, &output_data, 1) + .ok(); + } + break; + } default: LOG(FATAL) << "unsupported fft type"; } diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 0ea7912c95c..75a80747c1d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -1856,7 +1857,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, switch (fft_type) { case FFT: case IFFT: - if (in.element_type() != C64) { + if (!primitive_util::IsComplexType(in.element_type())) { return InvalidArgument("%s requires complex input type, found %s.", FftType_Name(fft_type), PrimitiveType_Name(in.element_type())); @@ -1864,8 +1865,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, RET_CHECK_RANK(in); return in; case RFFT: { - if (in.element_type() != F32) { - return InvalidArgument("RFFT requires F32 input type, found %s.", + if (in.element_type() != F32 && in.element_type() != F64) { + return InvalidArgument("RFFT requires F32 or F64 input type, found %s.", PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); @@ -1880,7 +1881,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]); } } - Shape result = ShapeUtil::ChangeElementType(in, C64); + Shape result = ShapeUtil::ChangeElementType( + in, in.element_type() == F32 ? C64 : C128); // Preserve the size of zero-sized dimensions. if (fft_length[fft_rank - 1] != 0) { result.set_dimensions(result.dimensions_size() - 1, @@ -1889,8 +1891,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result; } case IRFFT: { - if (in.element_type() != C64) { - return InvalidArgument("IRFFT requires C64 input type, found %s.", + if (!primitive_util::IsComplexType(in.element_type())) { + return InvalidArgument("IRFFT requires complex input type, found %s.", PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 448f5119546..b5ecf6e583e 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -615,8 +615,7 @@ namespace fft { static const char* unsupported_rank = "only supports ranks 1-3"; static const char* invalid_rank = "requires input of at least same rank"; static const char* requires_complex_input = "requires complex input type"; -static const char* requires_f32_input = "requires F32 input type"; -static const char* requires_c64_input = "requires C64 input type"; +static const char* requires_f32_input = "requires F32 or F64 input type"; static const char* dimensions_match = "innermost dimensions match fft_length"; static const char* innermost_dimension_matches = "innermost dimension matches fft_length/2+1"; @@ -654,7 +653,7 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) { Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input); - fft::Fail(shape_c128, type, {16, 8}, fft::requires_complex_input); + fft::Pass(shape_c128, type, {16, 8}, shape_c128); } TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) { @@ -672,7 +671,7 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) { Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input); - fft::Fail(shape_c128, type, {16, 8}, fft::requires_complex_input); + fft::Pass(shape_c128, type, {16, 8}, shape_c128); } TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) { @@ -747,9 +746,10 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) { TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) { FftType type = FftType::IRFFT; Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); - Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); - fft::Fail(shape_f32, type, {16, 8}, fft::requires_c64_input); - fft::Fail(shape_c128, type, {16, 8}, fft::requires_c64_input); + Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 5}); + Shape shape_f64_out = ShapeUtil::MakeShape(F64, {16, 8}); + fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input); + fft::Pass(shape_c128, type, {16, 8}, shape_f64_out); } TEST_F(ShapeInferenceTest, MapThatChangesElementType) {