diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index da58e9608b1..0320979102f 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "absl/container/inlined_vector.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/string_view.h"
+#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/index_util.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
@@ -779,6 +780,545 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
   return Status::OK();
 }
 
+namespace {
+
+// Straightforward implementation of 1D DFT transform. Uses passed-in start
+// index and stride to gather inputs from the data vector into the preallocated
+// buffer, computes the result, and writes it back to the same locations in the
+// data vector. Runs in O(length^2) time.
+//
+// Parameters contract_output and expand_input are used to avoid unnecessary
+// calculations. When contract_output is set to true, then only (length / 2) + 1
+// output values are computed. When expand_input is set to true, then
+// (length / 2) + 1 values from the data set are used to re-create the full set
+// of size 'length', on which the transform is then performed.
+//
+void NaiveDft1D(int64 length, int64 start, int64 stride, bool inverse,
+                bool contract_output, bool expand_input,
+                absl::Span<complex128> data, absl::Span<complex128> buffer) {
+  CHECK_GT(data.size(), start + (length - 1) * stride);
+  CHECK_GT(buffer.size(), length - 1);
+
+  // Copy input data to 1D vector.
+  bool input_is_zero = true;
+  const int64 ub = expand_input ? length / 2 + 1 : length;
+  for (int64 k = 0; k < ub; k++) {
+    complex128 value = data[start + k * stride];
+    input_is_zero &= value == complex128(0.0, 0.0);
+    buffer[k] = value;
+    if (expand_input) {
+      // Use conjugates of the values at indices [1 ... (ub - 2)] when the
+      // length is even and at indices [1 ... (ub - 1)] when the length is odd
+      // to calculate missing values at indices [(length - 1) ... ub].
+      if (k > 0 && k < (length - ub + 1)) {
+        buffer[length - k] = std::conj(value);
+      }
+    }
+  }
+
+  // Do 1D transformation with double precision.
+  if (!input_is_zero) {
+    const int64 ub = contract_output ? length / 2 + 1 : length;
+    for (int64 k = 0; k < ub; k++) {
+      complex128 value = complex128(0.0, 0.0);
+      for (int64 n = 0; n < length; n++) {
+        auto coeff = std::exp(complex128(0.0, -2.0 * M_PI * n * k / length));
+        value += (inverse ? std::conj(buffer[n]) : buffer[n]) * coeff;
+      }
+      data[start + k * stride] =
+          inverse ? std::conj(value) / complex128(length, 0.0) : value;
+    }
+  }
+}
+
+// Helper to reverse the order of dimension lengths in the passed-in literal.
+std::vector<int64> GetDimensionLengths(const Literal& literal) {
+  std::vector<int64> lengths = literal.shape().dimensions();
+  absl::c_reverse(lengths);
+  return lengths;
+}
+
+// Helper to compute strides for creating linear indices into multidimensional
+// data from the dimension lengths and the layout. Returns a new vector of size
+// lengths.size() + 1. The last element of the returned vector at index
+// [lengths.size()] contains the product of all dimension lengths.
+std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths,
+                                  const Layout& layout) {
+  const int64 num_dimensions = lengths.size();
+
+  // Make sure that the layout length matches the number of dimensions.
+  CHECK_EQ(num_dimensions, layout.minor_to_major_size());
+
+  // Calculate strides using layout-specified ordering of the dimensions and
+  // place the stride for axis 0 at index 0, for axis 1 at index 1, etc.
+  std::vector<int64> strides(num_dimensions + 1);
+  int64 stride = 1;
+  for (int64 i = 0; i < num_dimensions; i++) {
+    // Reverse the ordering of the dimensions in the layout.
+    const int64 index = (num_dimensions - 1) - layout.minor_to_major(i);
+    strides[index] = stride;
+    stride *= lengths[index];
+  }
+  strides[num_dimensions] = stride;
+
+  return strides;
+}
+
+// Compute strides as above using the default layout.
+std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths) {
+  return ComputeStrides(lengths,
+                        LayoutUtil::GetDefaultLayoutForRank(lengths.size()));
+}
+
+// Compute strides as above using the layout from the literal, if available.
+std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths,
+                                  const Literal& literal) {
+  return literal.shape().has_layout()
+             ? ComputeStrides(lengths, literal.shape().layout())
+             : ComputeStrides(lengths);
+}
+
+// Make 1D sweeps along each transform axis.
+void Sweep(int64 fft_rank, FftType fft_type,
+           const absl::Span<const int64> fft_lengths,
+           const absl::Span<const int64> fft_strides,
+           absl::Span<complex128> data, absl::Span<complex128> buffer) {
+  const bool inverse = fft_type == FftType::IFFT || fft_type == FftType::IRFFT;
+  const bool input_is_truncated = fft_type == FftType::IRFFT;
+  const bool output_is_truncated = fft_type == FftType::RFFT;
+
+  // Recursively visit each column of the data along the sweep_axis. Calculate
+  // linearized index of that column's first element and the stride, then invoke
+  // 1D transform.
+  // For RFFT, avoid calculating unused output values: first, compute only
+  // (length_x / 2) + 1 values along the X axis, then limit the X coordinate to
+  // [0 ... (length / 2)] during the sweeps along other axes. Similarly, for
+  // IRFFT sweep along higher dimensions first, while keeping the X coordinate
+  // in the [0 ... (length / 2)] range, then re-create negative frequencies
+  // omitted in the input and perform the full-length transform along the X axis
+  // in the last sweep.
+  std::function<void(int64, int64, int64)> sweep = [&](int64 sweep_axis,
+                                                       int64 axis,
+                                                       int64 start) {
+    if (axis < 0) {
+      // Base case: invoke 1D transform.
+      const int64 length = fft_lengths[sweep_axis];
+      const int64 stride = fft_strides[sweep_axis];
+      const bool expand_input = input_is_truncated && sweep_axis == 0;
+      const bool contract_output = output_is_truncated && sweep_axis == 0;
+      NaiveDft1D(length, start, stride, inverse, contract_output, expand_input,
+                 data, buffer);
+    } else if (axis == sweep_axis) {
+      // Visit only the elements with coordinate 0 along the sweep axis.
+      sweep(sweep_axis, axis - 1, start);
+    } else {
+      const int64 length = fft_lengths[axis];
+      const bool is_truncated = input_is_truncated || output_is_truncated;
+      const int64 ub = is_truncated && axis == 0 ? (length / 2) + 1 : length;
+      for (int64 i = 0; i < ub; i++) {
+        sweep(sweep_axis, axis - 1, start + i * fft_strides[axis]);
+      }
+    }
+  };
+  if (input_is_truncated) {
+    // Sweep along the X axis last for IRFFT.
+    for (int64 sweep_axis = fft_rank - 1; sweep_axis >= 0; sweep_axis--) {
+      sweep(sweep_axis, fft_rank - 1, 0);
+    }
+  } else {
+    // Sweep along the X axis first for RFFT. The order does not matter for FFT
+    // and IFFT types; handle them here as well.
+    for (int64 sweep_axis = 0; sweep_axis < fft_rank; sweep_axis++) {
+      sweep(sweep_axis, fft_rank - 1, 0);
+    }
+  }
+}
+
+// These templates convert the data from the input data type to the type used in
+// calculations and then to the output data type. They are intended to be used
+// only within the DFT implementation. One special case is IRFFT, where the
+// specialization drops imaginary parts of complex values (which are expected to
+// be 0) and returns real numbers.
+template <typename ToType, typename FromType>
+ToType GetAs(FromType value) {
+  return static_cast<ToType>(value);
+}
+
+template <>
+float GetAs<float, complex128>(complex128 value) {
+  return static_cast<float>(value.real());
+}
+
+// This template generates two linearized indices, which can be used to access
+// multidimensional arrays. It uses a recursive function, which passes the
+// indices to the user-supplied callback function.
The destination index is +// always within dst_lengths[] bounds. The boolean parameter within_src_bounds +// indicates whether the source index is within src_lengths[] bounds. +// +// The value returned from the callback function controls the recursion depth. +// Returning true indicates that the base case had been hit and the recursion +// stops. Otherwise, the recursion proceeds along the next less-major axis. +// +// For example, the base case when the axis value becomes negative invokes the +// callback function for each possible index within dst_lengths[] bounds. The +// base case when the axis value is equal to zero limits the indices to point +// only to first elements along the minor-most dimension, allowing the callback +// function to handle all values along the X axis. +// +template +void GenerateIndices(const absl::Span dst_lengths, + const absl::Span dst_strides, + const absl::Span src_lengths, + const absl::Span src_strides, int64 fft_rank, + int64 dst_start, int64 src_start, BaseFn&& base) { + CHECK_EQ(dst_lengths.size() + 1, dst_strides.size()); + CHECK_GE(dst_lengths.size(), fft_rank); + CHECK_EQ(src_lengths.size() + 1, src_strides.size()); + CHECK_GE(src_lengths.size(), fft_rank); + + std::function generate = + [&](int64 axis, int64 dst_index, int64 src_index, + bool within_src_bounds) { + if (!base(axis, dst_index, src_index, within_src_bounds)) { + for (int64 i = 0; i < dst_lengths[axis]; i++) { + // Because the loop goes over dst_lengths[], the source index may be + // out of src_lengths[] bounds. In this case, within_src_bounds is + // false. + within_src_bounds &= i < src_lengths[axis]; + generate(axis - 1, dst_index, src_index, within_src_bounds); + dst_index += dst_strides[axis]; + src_index += src_strides[axis]; + } + } + }; + generate(fft_rank - 1, dst_start, src_start, true); +} + +// Copies the input data from a literal to a pre-allocated vector. The sizes of +// the input and the transform do not need to match. 
For each axis of the +// transform, any extra input values beyond the transform length are ignored. +// Conversely, if the input does not contain enough elements along any axis, the +// data is padded with zeroes. +// +// For IRFFT transforms, we use (length_x / 2) + 1 elements from the input, +// where length_x is the size of the full transform along the X axis. +// +// The input literal may have a rank higher than the rank of the transform. +// Passed-in input_index value points to the first element of the input literal +// to be copied. +// +// Returns true if all values in the work data set are zeroes. +// +template +bool CopyDataFromInput(const Literal& input_literal, int64 input_start, + int64 fft_rank, FftType fft_type, int64 fft_size, + const absl::Span fft_lengths, + const absl::Span fft_strides, + const absl::Span input_lengths, + const absl::Span input_strides, + absl::Span data) { + CHECK_GE(data.size(), fft_size); + + const bool input_is_truncated = fft_type == FftType::IRFFT; + + // Recursively visit each transform dimension to copy input values to the + // working data set. The base case handles inputs along the X axis. + bool input_is_zero = true; + const InputType* input_data = input_literal.data().data(); + auto base_case = [&](int64 axis, int64 dst_index, int64 src_index, + bool within_src_bounds) { + if (axis == 0) { + // For IRFFT, the negavie frequencies are only needed for the sweep along + // the X axis, which is performed last. Leave this part of the working set + // uninitialized until then. + const int64 length = fft_lengths[axis]; + const int64 ub = input_is_truncated ? (length / 2) + 1 : length; + for (int64 i = 0; i < ub; i++) { + complex128 value = InputType(0); + // Read input value only if the index is within bounds. 
+ if (within_src_bounds && i < input_lengths[axis]) { + value = GetAs( + input_data[src_index + i * input_strides[axis]]); + input_is_zero &= value == complex128(0.0, 0.0); + } + data[dst_index + i * fft_strides[axis]] = value; + } + return true; + } + return false; + }; + GenerateIndices(fft_lengths, fft_strides, input_lengths, input_strides, + fft_rank, 0, input_start, base_case); + return input_is_zero; +} + +// Copies the result of the transform to the literal output. The sizes of the +// transform and output must match. +// +// For RFFT transforms, we copy (length_x / 2) + 1 elements, where length_x is +// the size of the full transform along the X axis (the most minor dimension). +// +// The output literal may have a rank higher than the rank of the transform. +// Passed-in output_index value points to the first element of the output +// literal to be filled in. +// +template +void CopyDataToOutput(const absl::Span data, int64 output_start, + int64 fft_rank, FftType fft_type, + const absl::Span fft_lengths, + const absl::Span fft_strides, + const absl::Span output_lengths, + const absl::Span output_strides, + Literal* output_literal) { + const bool output_is_truncated = fft_type == FftType::RFFT; + + // Base case for recursive copy of the results to the output. The code avoids + // making a recursive call for each output element by handling axis 0 in the + // loop (as opposed to making "axis < 0" to be the base case). + OutputType* output_data = output_literal->data().data(); + auto base_case = [&](int64 axis, int64 dst_index, int64 src_index, + bool within_src_bounds) { + if (axis == 0) { + // Drop negative frequencies for RFFT. + const int64 length = fft_lengths[axis]; + const int64 ub = output_is_truncated ? (length / 2) + 1 : length; + for (int64 i = 0; i < output_lengths[axis]; i++) { + OutputType value = OutputType(0); + // Read data only if the index is within bounds. 
+ if (within_src_bounds && i < ub) { + value = GetAs( + data[src_index + i * fft_strides[axis]]); + } + output_data[dst_index + i * output_strides[axis]] = value; + } + return true; + } + return false; + }; + GenerateIndices(output_lengths, output_strides, fft_lengths, fft_strides, + fft_rank, output_start, 0, base_case); +} + +// Determine the type to use with the CopyDataFromInput<> template above. +bool CopyDataFromInput(const Literal& input_literal, int64 input_start, + int64 fft_rank, FftType fft_type, int64 fft_size, + const absl::Span fft_lengths, + const absl::Span fft_strides, + const absl::Span input_lengths, + const absl::Span input_strides, + absl::Span data) { + const bool input_is_float = fft_type == FftType::RFFT; + if (input_is_float) { + return CopyDataFromInput( + input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths, + fft_strides, input_lengths, input_strides, data); + } else { + return CopyDataFromInput( + input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths, + fft_strides, input_lengths, input_strides, data); + } +} + +// Determine the type to use with the CopyDataToOutput<> template above. +void CopyDataToOutput(const absl::Span data, int64 output_start, + int64 fft_rank, FftType fft_type, + const absl::Span fft_lengths, + const absl::Span fft_strides, + const absl::Span output_lengths, + const absl::Span output_strides, + Literal* output_literal) { + const bool output_is_float = fft_type == FftType::IRFFT; + if (output_is_float) { + CopyDataToOutput(data, output_start, fft_rank, fft_type, fft_lengths, + fft_strides, output_lengths, output_strides, + output_literal); + } else { + CopyDataToOutput(data, output_start, fft_rank, fft_type, + fft_lengths, fft_strides, output_lengths, + output_strides, output_literal); + } +} + +Status CheckParameters(const Shape& input_shape, const Shape& output_shape, + int64 fft_rank, FftType fft_type, + const absl::Span fft_lengths) { + // Check FFT parameters. 
+ if (fft_rank <= 0) { + return InvalidArgument("Zero or negative FFT rank."); + } + if (*absl::c_min_element(fft_lengths) < 0) { + return InvalidArgument("Negative FFT length."); + } + + // Check input-related values. + TF_CHECK_OK(ShapeUtil::ValidateShape(input_shape)); + if (!input_shape.IsArray()) { + return Unimplemented("Only array input shapes are supported."); + } + auto input_elt_type = input_shape.element_type(); + if (fft_type == FftType::RFFT && input_elt_type != PrimitiveType::F32) { + return InvalidArgument("Invalid input type: %d, must be %d (float).", + input_elt_type, PrimitiveType::F32); + } + if (fft_type != FftType::RFFT && input_elt_type != PrimitiveType::C64) { + return InvalidArgument("Invalid input type: %d, must be %d (complex64).", + input_elt_type, PrimitiveType::C64); + } + const int64 input_rank = input_shape.rank(); + if (input_rank < fft_rank) { + return InvalidArgument("Input shape rank is smaller than FFT rank."); + } + + // Check output-related values. + TF_CHECK_OK(ShapeUtil::ValidateShape(output_shape)); + if (!output_shape.IsArray()) { + return Unimplemented("Only array output shapes are supported."); + } + auto output_elt_type = output_shape.element_type(); + if (fft_type == FftType::IRFFT && output_elt_type != PrimitiveType::F32) { + return InvalidArgument("Invalid output type: %d, must be %d (float).", + output_elt_type, PrimitiveType::F32); + } + if (fft_type != FftType::IRFFT && output_elt_type != PrimitiveType::C64) { + return InvalidArgument("Invalid output type: %d, must be %d (complex64).", + output_elt_type, PrimitiveType::C64); + } + const int64 output_rank = output_shape.rank(); + if (output_rank < fft_rank) { + return InvalidArgument("Output shape rank is smaller than FFT rank."); + } + + // Consistency of input and output parameters. 
+ if (input_rank != output_rank) { + return InvalidArgument( + "Ranks of input shape and output shape do not match."); + } + for (int64 dim = 0; dim < input_rank - fft_rank; dim++) { + if (ShapeUtil::GetDimension(input_shape, dim) != + ShapeUtil::GetDimension(output_shape, dim)) { + return InvalidArgument( + "Higher dimension lengths of input shape and output shape do not " + "match."); + } + } + + return Status::OK(); +} + +} // namespace + +// Flexible but slow implementation of the discrete Fourier transform. All +// transform types (FFT, IFFT, RFFT, and IRFFT) are supported, as well as the +// arbitrary rank and length of each dimension of the transform, and arbitrary +// layouts of the input and output literals. +// +// The input literal in operand 0 provides input data, which must be complex64 +// for FFT, IFFT, IRFFT transforms and float for RFFT. The transform is computed +// over the innermost dimensions of the input, thus the rank of the input data +// must be same as fft_rank or larger. The input is expected to provide Ni +// values along each transform axis with one exception: for IRFFT, only +// (N0 / 2) + 1 values are needed along the X axis (the innermost index). To +// increase flexibility, this implementation can handle mismatches between the +// input size and transform lengths by either dropping extra input values or +// using zeroes in place of missing input values as necessary. If the input data +// has rank higher than the transform, the transform is applied for each valid +// combination of the higher-ranking indices. +// +// The output contains complex64 values for FFT, IFFT, RFFT, and float values +// for IRFFT. The rank of the output as well as the sizes of the dimensions +// above the rank of the transform must match those of the input. 
Sizes of the +// output's "fft_rank" innermost dimensions are expected to match the length of +// the transform along respective axes with one exception: for RFFT, the output +// is trimmed along the X axis to have only (N0 / 2) + 1 values. In case the +// length(s) mismatch, the FFT output is trimmed to fit into the provided output +// shape, or the output is padded with zero values appropriately. +// +// For example, 2D FFT transform of size 16x16 applied to complex64[2][15][17] +// input array will perform two transforms over the [][15][17] data in the sub +// arrays [0][][] and [1][][], dropping the values along axis X and padding axis +// Y with zeroes to create 16x16 working sets, and generating +// complex64[2][16][16] output. 3D IRFFT transform of size 64x16x16 applied to +// complex64[64][16][9] input array will use all input values and will produce +// float[64][16][16] output. +// +// The implementation of the 1D transform is a straightforward loop nest. The +// transforms of higher ranks apply sets of 1D transforms along each axis. For +// example, the 2D transform is computed by applying 1D transforms to each +// column followed by applying 1D transforms to each row. +// +// In general, a transform of rank n runs in O(N0*N1*...*Nn*(N0+N1+...+Nn)) +// time, where Ni is the length of the transform's i-th dimension. It is +// possible to reduce the run time to O(N0*N1*...(log(N0)+log(N1)+...)) by +// plugging in a more efficient 1D implementation. +// +Status HloEvaluator::HandleFft(HloInstruction* fft) { + const FftType fft_type = fft->fft_type(); + std::vector fft_lengths = fft->fft_length(); + const int64 fft_rank = fft_lengths.size(); + const Literal& input_literal = GetEvaluatedLiteralFor(fft->operand(0)); + const Shape& input_shape = input_literal.shape(); + const Shape& output_shape = fft->shape(); + Literal output_literal = Literal::CreateFromShape(output_shape); + + // Make fft_lengths[0] the minor-most dimension. 
+ absl::c_reverse(fft_lengths); + + TF_RETURN_IF_ERROR(CheckParameters(input_shape, output_shape, fft_rank, + fft_type, fft_lengths)); + + const auto fft_strides = ComputeStrides(fft_lengths); + + // Working set size. + const int64 fft_size = fft_strides[fft_rank]; + + if (fft_size > 0) { + // Linearized working data set. + std::vector data(fft_size); + + // Temporary buffer allocated once and used in 1D sweeps. + std::vector buffer(*absl::c_max_element(fft_lengths)); + + // Sizes of each axis of input and output literals. + const auto input_lengths = GetDimensionLengths(input_literal); + const auto output_lengths = GetDimensionLengths(output_literal); + + // Strides for generating linearized indices into multidimensional arrays. + const auto input_strides = ComputeStrides(input_lengths, input_literal); + const auto output_strides = ComputeStrides(output_lengths, output_literal); + + // Visit all elements in the dimensions with ranks above the FFT rank. For + // each such element invoke the transform. Use separate indices for the + // input and the output to allow different layouts. + auto base_case = [&](int64 axis, int64 output_index, int64 input_index, + bool within_src_bounds) { + if (axis == fft_rank - 1) { + // Base case: copy the data from the input literal, apply the + // transform, and copy the result to the output literal. + CHECK(within_src_bounds); + bool input_is_zero = + CopyDataFromInput(input_literal, input_index, fft_rank, fft_type, + fft_size, fft_lengths, fft_strides, input_lengths, + input_strides, absl::MakeSpan(data)); + if (!input_is_zero) { + // Make 1D sweeps along each transform axis. 
+ Sweep(fft_rank, fft_type, fft_lengths, fft_strides, + absl::MakeSpan(data), absl::MakeSpan(buffer)); + } + CopyDataToOutput(absl::MakeSpan(data), output_index, fft_rank, fft_type, + fft_lengths, fft_strides, output_lengths, + output_strides, &output_literal); + return true; + } + return false; + }; + GenerateIndices(output_lengths, output_strides, input_lengths, + input_strides, input_shape.rank(), 0, 0, base_case); + } + + evaluated_[fft] = std::move(output_literal); + return Status::OK(); +} + // Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch // dimensions while keeping the rest of the output dimensions clamped to 0. ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 357975a131d..45b6a2754d6 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -204,6 +204,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; + Status HandleFft(HloInstruction* fft) override; + Status HandleGather(HloInstruction* gather) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index eb0ed82eac8..c3ca6f72a39 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -56,7 +56,7 @@ static std::array use_bf16_params{true, false}; // In bf16 mode, all f32 shapes are converted to bf16 before running. 
class HloEvaluatorTest : public HloTestBase { public: - HloEvaluatorTest() : use_bfloat16_(false) {} + HloEvaluatorTest() : use_bfloat16_(false) { InitializeFftData(); } StatusOr Evaluate( absl::Span arg_literals = {}) { @@ -130,11 +130,24 @@ class HloEvaluatorTest : public HloTestBase { } protected: - explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {} + explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) { + InitializeFftData(); + } + + // Initializes data sets used in FFT tests below. + void InitializeFftData(); + HloEvaluator evaluator_; const bool use_bfloat16_; std::unique_ptr m_ = CreateNewVerifiedModule(); + + // Data sets used in FFT tests below. + ErrorSpec fft_error_ = ErrorSpec(1e-4, 1e-5); + Literal fft_c64x2x4x8_; + Literal fft_c64x2x4x8_1d_; + Literal fft_c64x2x4x8_2d_; + Literal fft_c64x2x4x8_3d_; }; // Lets you write TEST_Ps that run twice, once with and once without bf16. @@ -1423,6 +1436,1015 @@ TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +// Initialization of data sets for FFT tests: + +void HloEvaluatorTest::InitializeFftData() { + // clang-format off + fft_c64x2x4x8_ = LiteralUtil::CreateR3({ + {{{0.0, 0.0}, {1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, + {4.0, 0.0}, {5.0, 0.0}, {6.0, 0.0}, {7.0, 0.0}}, + {{0.0, 0.0}, {0.0, 1.0}, {0.0, 2.0}, {0.0, 3.0}, + {0.0, 4.0}, {0.0, 5.0}, {0.0, 6.0}, {0.0, 7.0}}, + {{0.0, 7.0}, {1.0, 6.0}, {2.0, 5.0}, {3.0, 4.0}, + {4.0, 3.0}, {5.0, 2.0}, {6.0, 1.0}, {7.0, 0.0}}, + {{7.0, 0.0}, {6.0, 1.0}, {5.0, 2.0}, {4.0, 3.0}, + {3.0, 4.0}, {2.0, 5.0}, {1.0, 6.0}, {0.0, 7.0}}}, + {{{-4.0, 0.0}, {-3.0, 0.0}, {-2.0, 0.0}, {-1.0, 0.0}, + {1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}}, + {{0.0, -4.0}, {0.0, -3.0}, {0.0, -2.0}, {0.0, -1.0}, + {0.0, 1.0}, {0.0, 2.0}, {0.0, 3.0}, {0.0, 4.0}}, + {{3.5, 3.5}, {-1.707107, -0.707107}, {-1.0, -0.0}, {-0.707107, 0.292893}, + {-0.5, 0.5}, {-0.292893, 0.707107}, {0.0, 1.0}, 
{0.707107, 1.707107}}, + {{3.5, 3.5}, {1.707107, 0.707107}, {1.0, 0.0}, {0.707107, -0.292893}, + {0.5, -0.5}, {0.292893, -0.707107}, {-0.0, -1.0}, {-0.707107, -1.707107}}} + }); + fft_c64x2x4x8_1d_ = LiteralUtil::CreateR3({ + {{{28.0, 0.0}, {-4.0, 9.656854}, {-4.0, 4.0}, {-4.0, 1.656854}, + {-4.0, 0.0}, {-4.0, -1.656854}, {-4.0, -4.0}, {-4.0, -9.656854}}, + {{0.0, 28.0}, {-9.656854, -4.0}, {-4.0, -4.0}, {-1.656854, -4.0}, + {0.0, -4.0}, {1.656854, -4.0}, {4.0, -4.0}, {9.656854, -4.0}}, + {{28.0, 28.0}, {5.656854, 13.656854}, {0.0, 8.0}, {-2.343146, 5.656854}, + {-4.0, 4.0}, {-5.656854, 2.343146}, {-8.0, -0.0}, {-13.656854, -5.656854}}, // NOLINT + {{28.0, 28.0}, {-5.656854, -13.656854}, {-0.0, -8.0}, {2.343146, -5.656854}, // NOLINT + {4.0, -4.0}, {5.656854, -2.343146}, {8.0, 0.0}, {13.656854, 5.656854}}}, + {{{0.0, 0.0}, {-5.0, 12.071068}, {-4.0, 4.0}, {-5.0, 2.071068}, + {-4.0, 0.0}, {-5.0, -2.071068}, {-4.0, -4.0}, {-5.0, -12.071068}}, + {{0.0, 0.0}, {-12.071068, -5.0}, {-4.0, -4.0}, {-2.071068, -5.0}, + {0.0, -4.0}, {2.071068, -5.0}, {4.0, -4.0}, {12.071068, -5.0}}, + {{0.0, 7.0}, {1.0, 6.0}, {2.0, 5.0}, {3.0, 4.0}, + {4.0, 3.0}, {5.0, 2.0}, {6.0, 1.0}, {7.0, 0.0}}, + {{7.0, 0.0}, {6.0, 1.0}, {5.0, 2.0}, {4.0, 3.0}, + {3.0, 4.0}, {2.0, 5.0}, {1.0, 6.0}, {0.0, 7.0}}} + }); + fft_c64x2x4x8_2d_ = LiteralUtil::CreateR3({ + {{{84.0, 84.0}, {-13.656854, 5.656854}, {-8.0, 0.0}, {-5.656854, -2.343146}, + {-4.0, -4.0}, {-2.343146, -5.656854}, {0.0, -8.0}, {5.656854, -13.656854}}, // NOLINT + {{0.0, 0.0}, {0.0, -0.0}, {0.0, 0.0}, {0.0, 0.0}, + {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{28.0, -28.0}, {16.970562, 40.970562}, {0.0, 24.0}, {-7.029438, 16.970562}, // NOLINT + {-12.0, 12.0}, {-16.970562, 7.029438}, {-24.0, 0.0}, {-40.970562, -16.970562}}, // NOLINT + {{0.0, -56.0}, {-19.313708, -8.0}, {-8.0, -8.0}, {-3.313708, -8.0}, + {0.0, -8.0}, {3.313708, -8.0}, {8.0, -8.0}, {19.313708, -8.0}}}, + {{{7.0, 7.0}, {-10.071068, 14.071068}, {-1.0, 7.0}, {-0.071068, 
4.071068}, + {3.0, 3.0}, {4.071068, -0.071068}, {7.0, -1.0}, {14.071068, -10.071068}}, + {{0.0, 0.0}, {-12.0, 24.142136}, {-12.0, 8.0}, {-16.0, 4.142136}, + {-16.0, 0.0}, {-20.0, -4.142136}, {-20.0, -8.0}, {-24.0, -24.142136}}, + {{-7.0, 7.0}, {2.071068, 22.071068}, {-3.0, 11.0}, {-3.928932, 8.071068}, + {-3.0, 3.0}, {-4.071068, -0.071068}, {-3.0, -5.0}, {-10.071068, -14.071068}}, // NOLINT + {{0.0, -14.0}, {0.0, -12.0}, {0.0, -10.0}, {0.0, -8.0}, + {0.0, -6.0}, {0.0, -4.0}, {0.0, -2.0}, {0.0, 0.0}}} + }); + fft_c64x2x4x8_3d_ = LiteralUtil::CreateR3({ + {{{91.0, 91.0}, {-23.727922, 19.727922}, {-9.0, 7.0}, {-5.727922, 1.727922}, + {-1.0, -1.0}, {1.727922, -5.727922}, {7.0, -9}, {19.727922, -23.727922}}, + {{0.0, 0.0}, {-12.0, 24.142136}, {-12.0, 8.0}, {-16.0, 4.142136}, + {-16.0, 0.0}, {-20.0, -4.142136}, {-20.0, -8.0}, {-24.0, -24.142136}}, + {{21.0, -21.0}, {19.041630, 63.041630}, {-3.0, 35.0}, {-10.958370, 25.041630}, // NOLINT + {-15.0, 15.0}, {-21.041630, 6.958370}, {-27.0, -5.0}, {-51.041630, -31.041630}}, // NOLINT + {{0.0, -70.0}, {-19.313708, -20.0}, {-8.0, -18.0}, {-3.313708, -16.0}, + {0.0, -14.0}, {3.313708, -12.0}, {8.0, -10.0}, {19.313708, -8.0}}}, + {{{77.0, 77.0}, {-3.585786, -8.414214}, {-7.0, -7.0}, {-5.585786, -6.414214}, // NOLINT + {-7.0, -7.0}, {-6.414214, -5.585786}, {-7.0, -7.0}, {-8.414214, -3.585786}}, // NOLINT + {{0.0, 0.0}, {12.0, -24.142136}, {12.0, -8.0}, {16.0, -4.142136}, + {16.0, 0.0}, {20.0, 4.142136}, {20.0, 8.0}, {24.0, 24.142136}}, + {{35.0, -35.0}, {14.899494, 18.899494}, {3.0, 13.0}, {-3.100506, 8.899494}, + {-9.0, 9.0}, {-12.899494, 7.100506}, {-21.0, 5.0}, {-30.899494, -2.899494}}, // NOLINT + {{0.0, -42.0}, {-19.313708, 4.0}, {-8.0, 2.0}, {-3.313708, 0.0}, + {0.0, -2.0}, {3.313708, -4.0}, {8.0, -6.0}, {19.313708, -8.0}}} + }); + // clang-format on +} + +// Simple FFT tests: + +TEST_F(HloEvaluatorTest, 1D_FFT_4_on_c64x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[4] parameter(0) + ROOT fft 
= c64[4] fft(operand), fft_type=FFT, fft_length={4} +} +)"; + auto input = LiteralUtil::CreateR1( + {{1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}}); + auto expected = LiteralUtil::CreateR1( + {{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}, {-2.0, -2.0}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_IFFT_4_on_c64x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[4] parameter(0) + ROOT ifft = c64[4] fft(operand), fft_type=IFFT, fft_length={4} +} +)"; + auto input = LiteralUtil::CreateR1( + {{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}, {-2.0, -2.0}}); + auto expected = LiteralUtil::CreateR1( + {{1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_RFFT_4_on_f32x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[4] parameter(0) + ROOT rfft = c64[3] fft(operand), fft_type=RFFT, fft_length={4} +} +)"; + auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); + auto expected = + LiteralUtil::CreateR1({{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_IRFFT_4_on_c64x3) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3] parameter(0) + ROOT 
irfft = f32[4] fft(operand), fft_type=IRFFT, fft_length={4} +} +)"; + auto input = + LiteralUtil::CreateR1({{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}}); + auto expected = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// 1D FFT tests: + +TEST_F(HloEvaluatorTest, 1D_FFT_8_on_c64x2x4x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8] parameter(0) + ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={8} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_1d_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_1d_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_IFFT_8_on_c64x2x4x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8] parameter(0) + ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={8} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_1d_})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_RFFT_8_on_f32x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[8] parameter(0) + ROOT rfft = c64[5] fft(operand), fft_type=RFFT, fft_length={8} +} +)"; + auto input = + LiteralUtil::CreateR1({1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1}); + auto expected = LiteralUtil::CreateR1({{39.6, 0.0}, + {-3.6, 8.691169}, + {-3.6, 3.6}, + {-3.6, 1.491169}, + {-3.6, 0.0}}); + 
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_IRFFT_8_on_c64x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[5] parameter(0) + ROOT irfft = f32[8] fft(operand), fft_type=IRFFT, fft_length={8} +} +)"; + auto input = LiteralUtil::CreateR1({{39.6, 0.0}, + {-3.6, 8.691169}, + {-3.6, 3.6}, + {-3.6, 1.491169}, + {-3.6, 0.0}}); + auto expected = + LiteralUtil::CreateR1({1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_RFFT_9_on_f32x9) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[9] parameter(0) + ROOT rfft = c64[5] fft(operand), fft_type=RFFT, fft_length={9} +} +)"; + auto input = LiteralUtil::CreateR1( + {1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9.9}); + auto expected = LiteralUtil::CreateR1({{49.5, 0.0}, + {-3.360560, 11.705792}, + {-3.893717, 5.712929}, + {-4.5, 3.117691}, + {-4.895723, 1.021942}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 1D_IRFFT_9_on_c64x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[5] parameter(0) + ROOT irfft = f32[9] fft(operand), fft_type=IRFFT, fft_length={9} +} +)"; + auto input = LiteralUtil::CreateR1({{49.5, 0.0}, + {-3.360560, 
11.705792}, + {-3.893717, 5.712929}, + {-4.5, 3.117691}, + {-4.895723, 1.021942}}); + auto expected = LiteralUtil::CreateR1( + {1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9.9}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// 2D FFT tests: + +TEST_F(HloEvaluatorTest, 2D_FFT_4x8_on_c64x2x4x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8] parameter(0) + ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={4, 8} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_2d_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_2d_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 2D_IFFT_4x8_on_c64x2x4x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8] parameter(0) + ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={4, 8} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_2d_})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 2D_RFFT_3x8_on_f32x3x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[3, 8] parameter(0) + ROOT rfft = c64[3, 5] fft(operand), fft_type=RFFT, fft_length={3, 8} +} +)"; + auto input = + LiteralUtil::CreateR2({{1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1}, + {8.1, 7.2, 6.3, 5.4, 4.5, 3.6, 2.7, 1.8}, + {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}}); + auto expected = LiteralUtil::CreateR2({{{118.8, 
0.0}, + {-4.4, 10.622540}, + {-4.4, 4.4}, + {-4.4, 1.822540}, + {-4.4, 0.0}}, + {{0.0, 0.0}, + {-19.926162, 0.797280}, + {-10.128203, -3.728203}, + {-6.069756, -5.602720}, + {-3.2, -6.928203}}, + {{0.0, 0.0}, + {13.526162, 14.653687}, + {3.728203, 10.128203}, + {-0.330244, 8.253687}, + {-3.2, 6.928203}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 2D_IRFFT_3x8_on_c64x3x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 5] parameter(0) + ROOT irfft = f32[3, 8] fft(operand), fft_type=IRFFT, fft_length={3, 8} +} +)"; + auto input = LiteralUtil::CreateR2({{{118.8, 0.0}, + {-4.4, 10.622540}, + {-4.4, 4.4}, + {-4.4, 1.822540}, + {-4.4, 0.0}}, + {{0.0, 0.0}, + {-19.926162, 0.797280}, + {-10.128203, -3.728203}, + {-6.069756, -5.602720}, + {-3.2, -6.928203}}, + {{0.0, 0.0}, + {13.526162, 14.653687}, + {3.728203, 10.128203}, + {-0.330244, 8.253687}, + {-3.2, 6.928203}}}); + auto expected = + LiteralUtil::CreateR2({{1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1}, + {8.1, 7.2, 6.3, 5.4, 4.5, 3.6, 2.7, 1.8}, + {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 2D_RFFT_3x9_on_f32x3x9) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[3, 9] parameter(0) + ROOT rfft = c64[3, 5] fft(operand), fft_type=RFFT, fft_length={3, 9} +} +)"; + auto input = LiteralUtil::CreateR2( + {{1.9, 2.8, 3.7, 4.6, 5.5, 6.4, 7.3, 8.2, 9.1}, + {9.1, 8.2, 7.3, 6.4, 5.5, 4.6, 3.7, 2.8, 1.9}, + {1.1, 
2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}}); + auto expected = LiteralUtil::CreateR2({{{148.5, 0.0}, + {-4.95, 13.600013}, + {-4.95, 5.899180}, + {-4.95, 2.857884}, + {-4.95, 0.872819}}, + {{0.0, 0.0}, + {-25.014467, 2.096690}, + {-12.888800, -3.503916}, + {-8.1, -5.715768}, + {-4.974333, -7.159452}}, + {{0.0, 0.0}, + {17.814467, 17.685147}, + {5.688800, 12.084542}, + {0.9, 9.872690}, + {-2.225667, 8.429006}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 2D_IRFFT_3x9_on_c64x3x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 5] parameter(0) + ROOT irfft = f32[3, 9] fft(operand), fft_type=IRFFT, fft_length={3, 9} +} +)"; + auto input = LiteralUtil::CreateR2({{{148.5, 0.0}, + {-4.95, 13.600013}, + {-4.95, 5.899180}, + {-4.95, 2.857884}, + {-4.95, 0.872819}}, + {{0.0, 0.0}, + {-25.014467, 2.096690}, + {-12.888800, -3.503916}, + {-8.1, -5.715768}, + {-4.974333, -7.159452}}, + {{0.0, 0.0}, + {17.814467, 17.685147}, + {5.688800, 12.084542}, + {0.9, 9.872690}, + {-2.225667, 8.429006}}}); + auto expected = LiteralUtil::CreateR2( + {{1.9, 2.8, 3.7, 4.6, 5.5, 6.4, 7.3, 8.2, 9.1}, + {9.1, 8.2, 7.3, 6.4, 5.5, 4.6, 3.7, 2.8, 1.9}, + {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// 3D FFT tests: + +TEST_F(HloEvaluatorTest, 3D_FFT_2x4x8_on_c64x2x4x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8] parameter(0) + ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, 
fft_length={2, 4, 8} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_3d_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_3d_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 3D_IFFT_2x4x8_on_c64x2x4x8) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8] parameter(0) + ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={2, 4, 8} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_3d_})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x4_on_f32x3x3x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[3, 3, 4] parameter(0) + ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 4} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{1.8, 2.7, 3.6, 4.5}, {8.1, 7.2, 6.3, 5.4}, {1.1, 2.2, 3.3, 4.4}}, + {{5.4, 6.3, 7.2, 8.1}, {4.5, 3.6, 2.7, 1.8}, {5.5, 6.6, 7.7, 8.8}}, + {{-1.8, -2.7, -3.6, -4.5}, + {-5.4, -6.3, -7.2, -8.1}, + {1.9, 2.9, 3.9, 4.9}}}); + auto expected = LiteralUtil::CreateR3( + {{{{92.8, 0.0}, {-2.8, 2.8}, {-2.8, 0.0}}, + {{-5.9, 35.160631}, {-11.519100, -8.919100}, {-1.3, -10.219100}}, + {{-5.9, -35.160631}, {8.919100, 11.519100}, {-1.3, 10.219100}}}, + {{{29.5, -81.579593}, {1.390897, 5.190897}, {-1.9, 3.290897}}, + {{-25.1, -49.017038}, {1.044486, 4.844486}, {-1.9, 2.944486}}, + {{11.8, 27.712813}, {1.517691, 4.717691}, {-1.6, 3.117691}}}, + {{{29.5, 81.579593}, {-5.190897, -1.390897}, {-1.9, -3.290897}}, + {{11.8, -27.712813}, {-4.717691, -1.517691}, {-1.6, -3.117691}}, + {{-25.1, 49.017038}, {-4.844486, -1.044486}, {-1.9, -2.944486}}}}); + 
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x4_on_c64x3x3x3) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 3, 3] parameter(0) + ROOT irfft = f32[3, 3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 3, 4} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{{92.8, 0.0}, {-2.8, 2.8}, {-2.8, 0.0}}, + {{-5.9, 35.160631}, {-11.519100, -8.919100}, {-1.3, -10.219100}}, + {{-5.9, -35.160631}, {8.919100, 11.519100}, {-1.3, 10.219100}}}, + {{{29.5, -81.579593}, {1.390897, 5.190897}, {-1.9, 3.290897}}, + {{-25.1, -49.017038}, {1.044486, 4.844486}, {-1.9, 2.944486}}, + {{11.8, 27.712813}, {1.517691, 4.717691}, {-1.6, 3.117691}}}, + {{{29.5, 81.579593}, {-5.190897, -1.390897}, {-1.9, -3.290897}}, + {{11.8, -27.712813}, {-4.717691, -1.517691}, {-1.6, -3.117691}}, + {{-25.1, 49.017038}, {-4.844486, -1.044486}, {-1.9, -2.944486}}}}); + auto expected = LiteralUtil::CreateR3( + {{{1.8, 2.7, 3.6, 4.5}, {8.1, 7.2, 6.3, 5.4}, {1.1, 2.2, 3.3, 4.4}}, + {{5.4, 6.3, 7.2, 8.1}, {4.5, 3.6, 2.7, 1.8}, {5.5, 6.6, 7.7, 8.8}}, + {{-1.8, -2.7, -3.6, -4.5}, + {-5.4, -6.3, -7.2, -8.1}, + {1.9, 2.9, 3.9, 4.9}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x5_on_f32x3x3x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[3, 3, 5] parameter(0) + ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 5} +} +)"; + auto input = LiteralUtil::CreateR3({{{1.8, 2.7, 3.6, 4.5, 5.4}, + {8.1, 7.2, 
6.3, 5.4, 4.5}, + {1.1, 2.2, 3.3, 4.4, 5.5}}, + {{5.4, 6.3, 7.2, 8.1, 9.0}, + {4.5, 3.6, 2.7, 1.8, 0.9}, + {5.5, 6.6, 7.7, 8.8, 9.9}}, + {{-1.8, -2.7, -3.6, -4.5, -5.4}, + {-5.4, -6.3, -7.2, -8.1, -9.0}, + {1.9, 2.9, 3.9, 4.9, 5.9}}}); + auto expected = LiteralUtil::CreateR3( + {{{{119.5, 0.0}, {-3.5, 4.817337}, {-3.5, 1.137219}}, + {{-5.75, 56.724664}, {-19.206730, -10.537254}, {-5.775483, -12.245880}}, + {{-5.75, -56.724664}, {15.956730, 15.010495}, {2.525483, 13.301869}}}, + {{{39.25, -106.088112}, {3.286913, 7.382528}, {-1.038404, 4.885305}}, + {{-29.0, -64.951905}, {2.690922, 6.949515}, {-1.179098, 4.452292}}, + {{16.75, 30.743902}, {3.363918, 6.649878}, {-0.733751, 4.546954}}}, + {{{39.25, 106.088112}, {-8.036913, -0.844714}, {-3.711596, -3.341936}}, + {{16.75, -30.743902}, {-7.363918, -1.144350}, {-3.266249, -3.247275}}, + {{-29.0, 64.951905}, {-7.440922, -0.411701}, {-3.570902, -2.908924}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x5_on_c64x3x3x3) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 3, 3] parameter(0) + ROOT irfft = f32[3, 3, 5] fft(operand), fft_type=IRFFT, fft_length={3, 3, 5} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{{119.5, 0.0}, {-3.5, 4.817337}, {-3.5, 1.137219}}, + {{-5.75, 56.724664}, {-19.206730, -10.537254}, {-5.775483, -12.245880}}, + {{-5.75, -56.724664}, {15.956730, 15.010495}, {2.525483, 13.301869}}}, + {{{39.25, -106.088112}, {3.286913, 7.382528}, {-1.038404, 4.885305}}, + {{-29.0, -64.951905}, {2.690922, 6.949515}, {-1.179098, 4.452292}}, + {{16.75, 30.743902}, {3.363918, 6.649878}, {-0.733751, 4.546954}}}, + {{{39.25, 106.088112}, {-8.036913, -0.844714}, {-3.711596, -3.341936}}, + {{16.75, -30.743902}, {-7.363918, 
-1.144350}, {-3.266249, -3.247275}}, + {{-29.0, 64.951905}, {-7.440922, -0.411701}, {-3.570902, -2.908924}}}}); + auto expected = LiteralUtil::CreateR3({{{1.8, 2.7, 3.6, 4.5, 5.4}, + {8.1, 7.2, 6.3, 5.4, 4.5}, + {1.1, 2.2, 3.3, 4.4, 5.5}}, + {{5.4, 6.3, 7.2, 8.1, 9.0}, + {4.5, 3.6, 2.7, 1.8, 0.9}, + {5.5, 6.6, 7.7, 8.8, 9.9}}, + {{-1.8, -2.7, -3.6, -4.5, -5.4}, + {-5.4, -6.3, -7.2, -8.1, -9.0}, + {1.9, 2.9, 3.9, 4.9, 5.9}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// FFT tests with non-default data layout: + +TEST_F(HloEvaluatorTest, 1D_FFT_8_on_c64x2x4x8_with_layout) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8]{0, 2, 1} parameter(0) + ROOT fft = c64[2, 4, 8]{1, 2, 0} fft(operand), fft_type=FFT, fft_length={8} +} +)"; + auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({0, 2, 1})); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_1d_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_1d_, result, fft_error_)); +} + +TEST_F(HloEvaluatorTest, 2D_FFT_4x8_on_c64x2x4x8_with_layout) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8]{2, 0, 1} parameter(0) + ROOT fft = c64[2, 4, 8]{1, 0, 2} fft(operand), fft_type=FFT, fft_length={4, 8} +} +)"; + auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({2, 0, 1})); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_2d_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_2d_, result, 
fft_error_)); +} + +TEST_F(HloEvaluatorTest, 3D_FFT_2x4x8_on_c64x2x4x8_with_layout) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[2, 4, 8]{1, 2, 0} parameter(0) + ROOT fft = + c64[2, 4, 8]{0, 2, 1} fft(operand), fft_type=FFT, fft_length={2, 4, 8} +} +)"; + auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({1, 2, 0})); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_3d_.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_3d_, result, fft_error_)); +} + +// FFT tests with unusual parameters: + +// Zero-length transform. +TEST_F(HloEvaluatorTest, 1D_FFT_0_on_c64x1x1x1x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 1, 1, 1] parameter(0) + ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={0} +} +)"; + auto input = LiteralUtil::CreateR4({{{{{42.24, 24.42}}}}}); + auto expected = LiteralUtil::CreateR4({{{{{0.0, 0.0}}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// Zero-length axis. 
+TEST_F(HloEvaluatorTest, 1D_FFT_1_on_c64x1x1x1x0) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 1, 1, 0] parameter(0) + ROOT fft = c64[1, 1, 1, 0] fft(operand), fft_type=FFT, fft_length={1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto input, + LiteralUtil::CreateR4({{{{}}}}).Reshape({1, 1, 1, 0})); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// Some/all dimensions have length 1. +TEST_F(HloEvaluatorTest, 1D_FFT_1_on_c64x1x1x1x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 1, 1, 1] parameter(0) + ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1} +} +)"; + auto input = LiteralUtil::CreateR4({{{{{42.24, 24.42}}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// Zero-length transform. +TEST_F(HloEvaluatorTest, 3D_FFT_1x0x1_on_c64x1x1x1x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 1, 1, 1] parameter(0) + ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1, 0, 1} +} +)"; + auto input = LiteralUtil::CreateR4({{{{{42.24, 24.42}}}}}); + auto expected = LiteralUtil::CreateR4({{{{{0.0, 0.0}}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// Zero-length axis. 
+TEST_F(HloEvaluatorTest, 3D_FFT_1x1x1_on_c64x0x1x0x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[0, 1, 0, 1] parameter(0) + ROOT fft = c64[0, 1, 0, 1] fft(operand), fft_type=FFT, fft_length={1, 1, 1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto input, + LiteralUtil::CreateR4({{{{}}}}).Reshape({0, 1, 0, 1})); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// Some/all dimensions have length 1. +TEST_F(HloEvaluatorTest, 3D_FFT_1x1x1_on_c64x1x1x1x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 1, 1, 1] parameter(0) + ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1, 1, 1} +} +)"; + auto input = LiteralUtil::CreateR4({{{{{42.24, 24.42}}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// Some/all dimensions have length 1. 
+TEST_F(HloEvaluatorTest, 3D_FFT_3x1x1_on_c64x1x3x1x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 3, 1, 1] parameter(0) + ROOT fft = c64[1, 3, 1, 1] fft(operand), fft_type=FFT, fft_length={3, 1, 1} +} +)"; + auto input = LiteralUtil::CreateR4( + {{{{{42.24, 24.42}}}, {{{-42.24, 24.42}}}, {{{42.24, -24.42}}}}}); + auto expected = + LiteralUtil::CreateR4({{{{{42.24, 24.42}}}, + {{{84.5367, 97.5818}}}, + {{{-0.0566792, -48.7418}}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// Some/all dimensions have length 1. +TEST_F(HloEvaluatorTest, 3D_IFFT_3x1x1_on_c64x1x3x1x1) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[1, 3, 1, 1] parameter(0) + ROOT ifft = c64[1, 3, 1, 1] fft(operand), fft_type=IFFT, fft_length={3, 1, 1} +} +)"; + auto input = LiteralUtil::CreateR4({{{{{42.24, 24.42}}}, + {{{84.5367, 97.5818}}}, + {{{-0.0566792, -48.7418}}}}}); + auto expected = LiteralUtil::CreateR4( + {{{{{42.24, 24.42}}}, {{{-42.24, 24.42}}}, {{{42.24, -24.42}}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// Odd transform length. 
+TEST_F(HloEvaluatorTest, 1D_FFT_5_on_c64x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[5] parameter(0) + ROOT fft = c64[5] fft(operand), fft_type=FFT, fft_length={5} +} +)"; + auto input = LiteralUtil::CreateR1( + {{1.0, 5.0}, {2.0, 4.0}, {3.0, 3.0}, {4.0, 2.0}, {5.0, 1.0}}); + auto expected = LiteralUtil::CreateR1({{15.0, 15.0}, + {0.940955, 5.94095}, + {-1.6877, 3.3123}, + {-3.3123, 1.6877}, + {-5.94095, -0.940955}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// Odd transform length. +TEST_F(HloEvaluatorTest, 1D_IFFT_5_on_c64x5) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[5] parameter(0) + ROOT ifft = c64[5] fft(operand), fft_type=IFFT, fft_length={5} +} +)"; + auto input = LiteralUtil::CreateR1({{15.0, 15.0}, + {0.940955, 5.94095}, + {-1.6877, 3.3123}, + {-3.3123, 1.6877}, + {-5.94095, -0.940955}}); + auto expected = LiteralUtil::CreateR1( + {{1.0, 5.0}, {2.0, 4.0}, {3.0, 3.0}, {4.0, 2.0}, {5.0, 1.0}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// All input values are zero. 
+TEST_F(HloEvaluatorTest, 1D_FFT_4_on_zero_c64x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[4] parameter(0) + ROOT fft = c64[4] fft(operand), fft_type=FFT, fft_length={4} +} +)"; + auto input = LiteralUtil::CreateR1( + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// All input values are zero. +TEST_F(HloEvaluatorTest, 3D_FFT_3x3x4_on_zero_c64x3x3x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 3, 4] parameter(0) + ROOT fft = c64[3, 3, 4] fft(operand), fft_type=FFT, fft_length={3, 3, 4} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// All input values are zero. 
+TEST_F(HloEvaluatorTest, 3D_IFFT_3x3x4_on_zero_c64x3x3x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 3, 4] parameter(0) + ROOT ifft = c64[3, 3, 4] fft(operand), fft_type=IFFT, fft_length={3, 3, 4} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_)); +} + +// All input values are zero. 
+TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x4_on_zero_f32x3x3x4) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = f32[3, 3, 4] parameter(0) + ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 4} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}, + {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}, + {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}}); + auto expected = LiteralUtil::CreateR3( + {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// All input values are zero. 
+TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x4_on_zero_c64x3x3x3) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 3, 3] parameter(0) + ROOT irfft = f32[3, 3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 3, 4} +} +)"; + auto input = LiteralUtil::CreateR3( + {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}, + {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}, + {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}}); + auto expected = LiteralUtil::CreateR3( + {{{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}, + {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}, + {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + +// Input values, for which IRFFT discards non-zero imaginary parts. 
+TEST_F(HloEvaluatorTest, 2D_IRFFT_3x4_on_c64x3x3) { + const char* hlo_text = R"( +HloModule Fft + +ENTRY main { + operand = c64[3, 3] parameter(0) + ROOT irfft = f32[3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 4} +} +)"; + auto input = + LiteralUtil::CreateR2({{{0.0, 0.0}, {1.0, 0.0}, {2.0, 0.0}}, + {{3.0, 0.0}, {4.0, 0.0}, {5.0, 0.0}}, + {{6.0, 0.0}, {7.0, 0.0}, {8.0, 0.0}}}); + auto expected = + LiteralUtil::CreateR2({{4.0, -0.5, 0.0, -0.5}, + {-1.5, 0.433013, 0.0, -0.433013}, + {-1.5, -0.433013, 0.0, 0.433013}}); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input})); + EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape())); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); +} + class HloEvaluatorPreciseReduceTest : public HloTestBase {}; // Tests that Reduce doesn't lose precision when adding many numbers (because diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index ab27ac82722..53ecfceb08b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -68,8 +68,8 @@ T ToArithmeticSafeType(T t) { // Templated DfsHloVisitor for use by HloEvaluator. // // Typically ReturnT here indicates the resulting literal type of each evaluated -// Handle* method of a TypedVisitor. There are however a few notable exceptions -// to this rule, notably: +// Handle* method of a TypedVisitor. There are however a few exceptions to this +// rule, notably: // - HandleCompare and HandleIsFinite: where the resulting literal type is // always boolean. // - HandleImag and HandleReal: where the resulting literal type is always float @@ -81,7 +81,7 @@ T ToArithmeticSafeType(T t) { // - ReturnT: The type of input and output of each operation. 
// - ElementwiseT: The type in which internal computation are done. // -// This a logically a private part of HloEvaluator. It lives in this header +// This is logically a private part of HloEvaluator. It lives in this header // file rather than in hlo_evaluator.cc because we use extern templates and a // bunch of independent cc files to speed up compiling the many instantiations // of this class.