FFT support in HloEvaluator.
PiperOrigin-RevId: 247056664
parent f4cb4eb43f
commit 4259bfb42b
@@ -27,6 +27,7 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -779,6 +780,545 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
  return Status::OK();
}

namespace {

// Straightforward implementation of 1D DFT transform. Uses passed-in start
// index and stride to gather inputs from the data vector into the preallocated
// buffer, computes the result, and writes it back to the same locations in the
// data vector. Runs in O(length^2) time.
//
// Parameters contract_output and expand_input are used to avoid unnecessary
// calculations. When contract_output is set to true, then only (length / 2) + 1
// output values are computed. When expand_input is set to true, then
// (length / 2) + 1 values from the data set are used to re-create the full set
// of size 'length', on which the transform is then performed.
//
void NaiveDft1D(int64 length, int64 start, int64 stride, bool inverse,
                bool contract_output, bool expand_input,
                absl::Span<complex128> data, absl::Span<complex128> buffer) {
  CHECK_GT(data.size(), start + (length - 1) * stride);
  CHECK_GT(buffer.size(), length - 1);

  // Copy input data to 1D vector.
  bool input_is_zero = true;
  const int64 ub = expand_input ? length / 2 + 1 : length;
  for (int64 k = 0; k < ub; k++) {
    complex128 value = data[start + k * stride];
    input_is_zero &= value == complex128(0.0, 0.0);
    buffer[k] = value;
    if (expand_input) {
      // Use conjugates of the values at indices [1 ... (ub - 2)] when the
      // length is even and at indices [1 ... (ub - 1)] when the length is odd
      // to calculate missing values at indices [(length - 1) ... ub].
      if (k > 0 && k < (length - ub + 1)) {
        buffer[length - k] = std::conj(value);
      }
    }
  }

  // Do 1D transformation with double precision.
  if (!input_is_zero) {
    const int64 ub = contract_output ? length / 2 + 1 : length;
    for (int64 k = 0; k < ub; k++) {
      complex128 value = complex128(0.0, 0.0);
      for (int n = 0; n < length; n++) {
        auto coeff = std::exp(complex128(0.0, -2.0 * M_PI * n * k / length));
        value += (inverse ? std::conj(buffer[n]) : buffer[n]) * coeff;
      }
      data[start + k * stride] =
          inverse ? std::conj(value) / complex128(length, 0.0) : value;
    }
  }
}
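
As a reference for the kernel above (an editor's note, not part of the commit): the loop computes the standard DFT with a negative exponent, and the inverse branch relies on a conjugation identity so the same kernel serves both directions:

  X_k = \sum_{n=0}^{N-1} x_n \, e^{-2\pi i n k / N}, \qquad k = 0, \dots, N-1

  x_n = \frac{1}{N} \sum_{k=0}^{N-1} X_k \, e^{+2\pi i n k / N}
      = \frac{1}{N} \overline{\sum_{k=0}^{N-1} \overline{X_k} \, e^{-2\pi i n k / N}}

which matches the std::conj calls and the division by complex128(length, 0.0) in the inverse path.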

// Helper to reverse the order of dimension lengths in the passed-in literal.
std::vector<int64> GetDimensionLengths(const Literal& literal) {
  std::vector<int64> lengths = literal.shape().dimensions();
  absl::c_reverse(lengths);
  return lengths;
}

// Helper to compute strides for creating linear indices into multidimensional
// data from the dimension lengths and the layout. Returns a new vector of size
// lengths.size() + 1. The last element of the returned vector at index
// [lengths.size()] contains the product of all dimension lengths.
std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths,
                                  const Layout& layout) {
  const int64 num_dimensions = lengths.size();

  // Make sure that the layout length matches the number of dimensions.
  CHECK_EQ(num_dimensions, layout.minor_to_major_size());

  // Calculate strides using layout-specified ordering of the dimensions and
  // place the stride for axis 0 at index 0, for axis 1 at index 1, etc.
  std::vector<int64> strides(num_dimensions + 1);
  int64 stride = 1;
  for (int64 i = 0; i < num_dimensions; i++) {
    // Reverse the ordering of the dimensions in the layout.
    const int64 index = (num_dimensions - 1) - layout.minor_to_major(i);
    strides[index] = stride;
    stride *= lengths[index];
  }
  strides[num_dimensions] = stride;

  return strides;
}

// Compute strides as above using the default layout.
std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths) {
  return ComputeStrides(lengths,
                        LayoutUtil::GetDefaultLayoutForRank(lengths.size()));
}

// Compute strides as above using the layout from the literal, if available.
std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths,
                                  const Literal& literal) {
  return literal.shape().has_layout()
             ? ComputeStrides(lengths, literal.shape().layout())
             : ComputeStrides(lengths);
}
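
To make the stride convention concrete, here is a small standalone sketch (an editor's illustration, not part of the commit; it assumes the default descending layout returned by GetDefaultLayoutForRank): for a complex64[2][15][17] literal, GetDimensionLengths yields {17, 15, 2} and ComputeStrides returns {1, 17, 255, 510}, the last entry being the total element count.

// Editor's standalone sketch: row-major strides for reversed lengths {17, 15, 2},
// mirroring ComputeStrides under the default layout.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<int64_t> lengths = {17, 15, 2};  // complex64[2][15][17], reversed
  std::vector<int64_t> strides(lengths.size() + 1);
  int64_t stride = 1;
  for (size_t i = 0; i < lengths.size(); ++i) {
    strides[i] = stride;  // stride of axis i (axis 0 is the minor-most)
    stride *= lengths[i];
  }
  strides[lengths.size()] = stride;  // total number of elements
  assert((strides == std::vector<int64_t>{1, 17, 255, 510}));
  return 0;
}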

// Make 1D sweeps along each transform axis.
void Sweep(int64 fft_rank, FftType fft_type,
           const absl::Span<const int64> fft_lengths,
           const absl::Span<const int64> fft_strides,
           absl::Span<complex128> data, absl::Span<complex128> buffer) {
  const bool inverse = fft_type == FftType::IFFT || fft_type == FftType::IRFFT;
  const bool input_is_truncated = fft_type == FftType::IRFFT;
  const bool output_is_truncated = fft_type == FftType::RFFT;

  // Recursively visit each column of the data along the sweep_axis. Calculate
  // linearized index of that column's first element and the stride, then invoke
  // 1D transform.
  // For RFFT, avoid calculating unused output values: first, compute only
  // (length_x / 2) + 1 values along the X axis, then limit the X coordinate to
  // [0 ... (length / 2)] during the sweeps along other axes. Similarly, for
  // IRFFT sweep along higher dimensions first, while keeping the X coordinate
  // in the [0 ... (length / 2)] range, then re-create negative frequencies
  // omitted in the input and perform the full-length transform along the X axis
  // in the last sweep.
  std::function<void(int64, int64, int64)> sweep = [&](int64 sweep_axis,
                                                       int64 axis,
                                                       int64 start) {
    if (axis < 0) {
      // Base case: invoke 1D transform.
      const int64 length = fft_lengths[sweep_axis];
      const int64 stride = fft_strides[sweep_axis];
      const bool expand_input = input_is_truncated && sweep_axis == 0;
      const bool contract_output = output_is_truncated && sweep_axis == 0;
      NaiveDft1D(length, start, stride, inverse, contract_output, expand_input,
                 data, buffer);
    } else if (axis == sweep_axis) {
      // Visit only the elements with coordinate 0 along the sweep axis.
      sweep(sweep_axis, axis - 1, start);
    } else {
      const int64 length = fft_lengths[axis];
      const bool is_truncated = input_is_truncated || output_is_truncated;
      const int64 ub = is_truncated && axis == 0 ? (length / 2) + 1 : length;
      for (int64 i = 0; i < ub; i++) {
        sweep(sweep_axis, axis - 1, start + i * fft_strides[axis]);
      }
    }
  };
  if (input_is_truncated) {
    // Sweep along the X axis last for IRFFT.
    for (int64 sweep_axis = fft_rank - 1; sweep_axis >= 0; sweep_axis--) {
      sweep(sweep_axis, fft_rank - 1, 0);
    }
  } else {
    // Sweep along the X axis first for RFFT. The order does not matter for FFT
    // and IFFT types; handle them here as well.
    for (int64 sweep_axis = 0; sweep_axis < fft_rank; sweep_axis++) {
      sweep(sweep_axis, fft_rank - 1, 0);
    }
  }
}
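
A worked illustration of the sweep order (an editor's note, not part of the commit): for a rank-2 RFFT with fft_lengths = {16, 16}, the first pass sweeps along the X axis (sweep_axis 0) and, because contract_output is true there, computes only 16 / 2 + 1 = 9 values per row; the second pass sweeps along the Y axis while the loop over axis 0 is bounded to those same 9 X positions. IRFFT reverses this: the higher axes are swept first over the truncated X range, and the final X-axis sweep expands the (length / 2) + 1 stored values back to the full length via conjugate symmetry.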

// These templates convert the data from the input data type to the type used in
// calculations and then to the output data type. They are intended to be used
// only within the DFT implementation. One special case is IRFFT, where the
// specialization drops imaginary parts of complex values (which are expected
// to be 0) and returns real numbers.
template <typename ToType, typename FromType>
ToType GetAs(FromType value) {
  return static_cast<ToType>(value);
}

template <>
float GetAs<float, complex128>(complex128 value) {
  return static_cast<float>(value.real());
}

// This template generates two linearized indices, which can be used to access
// multidimensional arrays. It uses a recursive function, which passes the
// indices to the user-supplied callback function. The destination index is
// always within dst_lengths[] bounds. The boolean parameter within_src_bounds
// indicates whether the source index is within src_lengths[] bounds.
//
// The value returned from the callback function controls the recursion depth.
// Returning true indicates that the base case has been hit and the recursion
// stops. Otherwise, the recursion proceeds along the next less-major axis.
//
// For example, the base case when the axis value becomes negative invokes the
// callback function for each possible index within dst_lengths[] bounds. The
// base case when the axis value is equal to zero limits the indices to point
// only to first elements along the minor-most dimension, allowing the callback
// function to handle all values along the X axis.
//
template <typename BaseFn>
void GenerateIndices(const absl::Span<const int64> dst_lengths,
                     const absl::Span<const int64> dst_strides,
                     const absl::Span<const int64> src_lengths,
                     const absl::Span<const int64> src_strides, int64 fft_rank,
                     int64 dst_start, int64 src_start, BaseFn&& base) {
  CHECK_EQ(dst_lengths.size() + 1, dst_strides.size());
  CHECK_GE(dst_lengths.size(), fft_rank);
  CHECK_EQ(src_lengths.size() + 1, src_strides.size());
  CHECK_GE(src_lengths.size(), fft_rank);

  std::function<void(int64, int64, int64, bool)> generate =
      [&](int64 axis, int64 dst_index, int64 src_index,
          bool within_src_bounds) {
        if (!base(axis, dst_index, src_index, within_src_bounds)) {
          for (int64 i = 0; i < dst_lengths[axis]; i++) {
            // Because the loop goes over dst_lengths[], the source index may be
            // out of src_lengths[] bounds. In this case, within_src_bounds is
            // false.
            within_src_bounds &= i < src_lengths[axis];
            generate(axis - 1, dst_index, src_index, within_src_bounds);
            dst_index += dst_strides[axis];
            src_index += src_strides[axis];
          }
        }
      };
  generate(fft_rank - 1, dst_start, src_start, true);
}
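
To make the index generation concrete (an editor's note, not part of the commit): when CopyDataFromInput below walks a 16x16 transform over one [15][17] slice of the complex64[2][15][17] example discussed before HandleFft, the transform lengths are {16, 16} and the first two reversed input lengths are {17, 15}. The outer loop runs i = 0..15 along the Y axis; for i = 15, within_src_bounds becomes false because the input has only 15 rows, so the base case writes zeroes for that row, while along the X axis the extra 17th input column is simply never read.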

// Copies the input data from a literal to a pre-allocated vector. The sizes of
// the input and the transform do not need to match. For each axis of the
// transform, any extra input values beyond the transform length are ignored.
// Conversely, if the input does not contain enough elements along any axis, the
// data is padded with zeroes.
//
// For IRFFT transforms, we use (length_x / 2) + 1 elements from the input,
// where length_x is the size of the full transform along the X axis.
//
// The input literal may have a rank higher than the rank of the transform.
// Passed-in input_index value points to the first element of the input literal
// to be copied.
//
// Returns true if all values in the work data set are zeroes.
//
template <typename InputType>
bool CopyDataFromInput(const Literal& input_literal, int64 input_start,
                       int64 fft_rank, FftType fft_type, int64 fft_size,
                       const absl::Span<const int64> fft_lengths,
                       const absl::Span<const int64> fft_strides,
                       const absl::Span<const int64> input_lengths,
                       const absl::Span<const int64> input_strides,
                       absl::Span<complex128> data) {
  CHECK_GE(data.size(), fft_size);

  const bool input_is_truncated = fft_type == FftType::IRFFT;

  // Recursively visit each transform dimension to copy input values to the
  // working data set. The base case handles inputs along the X axis.
  bool input_is_zero = true;
  const InputType* input_data = input_literal.data<InputType>().data();
  auto base_case = [&](int64 axis, int64 dst_index, int64 src_index,
                       bool within_src_bounds) {
    if (axis == 0) {
      // For IRFFT, the negative frequencies are only needed for the sweep along
      // the X axis, which is performed last. Leave this part of the working set
      // uninitialized until then.
      const int64 length = fft_lengths[axis];
      const int64 ub = input_is_truncated ? (length / 2) + 1 : length;
      for (int64 i = 0; i < ub; i++) {
        complex128 value = InputType(0);
        // Read input value only if the index is within bounds.
        if (within_src_bounds && i < input_lengths[axis]) {
          value = GetAs<complex128, InputType>(
              input_data[src_index + i * input_strides[axis]]);
          input_is_zero &= value == complex128(0.0, 0.0);
        }
        data[dst_index + i * fft_strides[axis]] = value;
      }
      return true;
    }
    return false;
  };
  GenerateIndices(fft_lengths, fft_strides, input_lengths, input_strides,
                  fft_rank, 0, input_start, base_case);
  return input_is_zero;
}

// Copies the result of the transform to the literal output. The sizes of the
// transform and output must match.
//
// For RFFT transforms, we copy (length_x / 2) + 1 elements, where length_x is
// the size of the full transform along the X axis (the most minor dimension).
//
// The output literal may have a rank higher than the rank of the transform.
// Passed-in output_index value points to the first element of the output
// literal to be filled in.
//
template <typename OutputType>
void CopyDataToOutput(const absl::Span<complex128> data, int64 output_start,
                      int64 fft_rank, FftType fft_type,
                      const absl::Span<const int64> fft_lengths,
                      const absl::Span<const int64> fft_strides,
                      const absl::Span<const int64> output_lengths,
                      const absl::Span<const int64> output_strides,
                      Literal* output_literal) {
  const bool output_is_truncated = fft_type == FftType::RFFT;

  // Base case for recursive copy of the results to the output. The code avoids
  // making a recursive call for each output element by handling axis 0 in the
  // loop (as opposed to making "axis < 0" the base case).
  OutputType* output_data = output_literal->data<OutputType>().data();
  auto base_case = [&](int64 axis, int64 dst_index, int64 src_index,
                       bool within_src_bounds) {
    if (axis == 0) {
      // Drop negative frequencies for RFFT.
      const int64 length = fft_lengths[axis];
      const int64 ub = output_is_truncated ? (length / 2) + 1 : length;
      for (int64 i = 0; i < output_lengths[axis]; i++) {
        OutputType value = OutputType(0);
        // Read data only if the index is within bounds.
        if (within_src_bounds && i < ub) {
          value = GetAs<OutputType, complex128>(
              data[src_index + i * fft_strides[axis]]);
        }
        output_data[dst_index + i * output_strides[axis]] = value;
      }
      return true;
    }
    return false;
  };
  GenerateIndices(output_lengths, output_strides, fft_lengths, fft_strides,
                  fft_rank, output_start, 0, base_case);
}
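
A short worked example of the truncation above (an editor's note, not part of the commit): with a transform length of 16 along X, an RFFT output literal carries 16 / 2 + 1 = 9 complex values per row; CopyDataToOutput copies exactly those 9 values and zero-fills any extra output columns if the output literal happens to be wider.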

// Determine the type to use with the CopyDataFromInput<> template above.
bool CopyDataFromInput(const Literal& input_literal, int64 input_start,
                       int64 fft_rank, FftType fft_type, int64 fft_size,
                       const absl::Span<const int64> fft_lengths,
                       const absl::Span<const int64> fft_strides,
                       const absl::Span<const int64> input_lengths,
                       const absl::Span<const int64> input_strides,
                       absl::Span<complex128> data) {
  const bool input_is_float = fft_type == FftType::RFFT;
  if (input_is_float) {
    return CopyDataFromInput<float>(
        input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths,
        fft_strides, input_lengths, input_strides, data);
  } else {
    return CopyDataFromInput<complex64>(
        input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths,
        fft_strides, input_lengths, input_strides, data);
  }
}

// Determine the type to use with the CopyDataToOutput<> template above.
void CopyDataToOutput(const absl::Span<complex128> data, int64 output_start,
                      int64 fft_rank, FftType fft_type,
                      const absl::Span<const int64> fft_lengths,
                      const absl::Span<const int64> fft_strides,
                      const absl::Span<const int64> output_lengths,
                      const absl::Span<const int64> output_strides,
                      Literal* output_literal) {
  const bool output_is_float = fft_type == FftType::IRFFT;
  if (output_is_float) {
    CopyDataToOutput<float>(data, output_start, fft_rank, fft_type, fft_lengths,
                            fft_strides, output_lengths, output_strides,
                            output_literal);
  } else {
    CopyDataToOutput<complex64>(data, output_start, fft_rank, fft_type,
                                fft_lengths, fft_strides, output_lengths,
                                output_strides, output_literal);
  }
}

Status CheckParameters(const Shape& input_shape, const Shape& output_shape,
                       int64 fft_rank, FftType fft_type,
                       const absl::Span<const int64> fft_lengths) {
  // Check FFT parameters.
  if (fft_rank <= 0) {
    return InvalidArgument("Zero or negative FFT rank.");
  }
  if (*absl::c_min_element(fft_lengths) < 0) {
    return InvalidArgument("Negative FFT length.");
  }

  // Check input-related values.
  TF_CHECK_OK(ShapeUtil::ValidateShape(input_shape));
  if (!input_shape.IsArray()) {
    return Unimplemented("Only array input shapes are supported.");
  }
  auto input_elt_type = input_shape.element_type();
  if (fft_type == FftType::RFFT && input_elt_type != PrimitiveType::F32) {
    return InvalidArgument("Invalid input type: %d, must be %d (float).",
                           input_elt_type, PrimitiveType::F32);
  }
  if (fft_type != FftType::RFFT && input_elt_type != PrimitiveType::C64) {
    return InvalidArgument("Invalid input type: %d, must be %d (complex64).",
                           input_elt_type, PrimitiveType::C64);
  }
  const int64 input_rank = input_shape.rank();
  if (input_rank < fft_rank) {
    return InvalidArgument("Input shape rank is smaller than FFT rank.");
  }

  // Check output-related values.
  TF_CHECK_OK(ShapeUtil::ValidateShape(output_shape));
  if (!output_shape.IsArray()) {
    return Unimplemented("Only array output shapes are supported.");
  }
  auto output_elt_type = output_shape.element_type();
  if (fft_type == FftType::IRFFT && output_elt_type != PrimitiveType::F32) {
    return InvalidArgument("Invalid output type: %d, must be %d (float).",
                           output_elt_type, PrimitiveType::F32);
  }
  if (fft_type != FftType::IRFFT && output_elt_type != PrimitiveType::C64) {
    return InvalidArgument("Invalid output type: %d, must be %d (complex64).",
                           output_elt_type, PrimitiveType::C64);
  }
  const int64 output_rank = output_shape.rank();
  if (output_rank < fft_rank) {
    return InvalidArgument("Output shape rank is smaller than FFT rank.");
  }

  // Consistency of input and output parameters.
  if (input_rank != output_rank) {
    return InvalidArgument(
        "Ranks of input shape and output shape do not match.");
  }
  for (int64 dim = 0; dim < input_rank - fft_rank; dim++) {
    if (ShapeUtil::GetDimension(input_shape, dim) !=
        ShapeUtil::GetDimension(output_shape, dim)) {
      return InvalidArgument(
          "Higher dimension lengths of input shape and output shape do not "
          "match.");
    }
  }

  return Status::OK();
}

} // namespace

// Flexible but slow implementation of the discrete Fourier transform. All
// transform types (FFT, IFFT, RFFT, and IRFFT) are supported, as well as the
// arbitrary rank and length of each dimension of the transform, and arbitrary
// layouts of the input and output literals.
//
// The input literal in operand 0 provides input data, which must be complex64
// for FFT, IFFT, IRFFT transforms and float for RFFT. The transform is computed
// over the innermost dimensions of the input, thus the rank of the input data
// must be the same as fft_rank or larger. The input is expected to provide Ni
// values along each transform axis with one exception: for IRFFT, only
// (N0 / 2) + 1 values are needed along the X axis (the innermost index). To
// increase flexibility, this implementation can handle mismatches between the
// input size and transform lengths by either dropping extra input values or
// using zeroes in place of missing input values as necessary. If the input data
// has rank higher than the transform, the transform is applied for each valid
// combination of the higher-ranking indices.
//
// The output contains complex64 values for FFT, IFFT, RFFT, and float values
// for IRFFT. The rank of the output as well as the sizes of the dimensions
// above the rank of the transform must match those of the input. Sizes of the
// output's "fft_rank" innermost dimensions are expected to match the length of
// the transform along respective axes with one exception: for RFFT, the output
// is trimmed along the X axis to have only (N0 / 2) + 1 values. If the lengths
// do not match, the FFT output is trimmed to fit into the provided output
// shape, or the output is padded with zero values appropriately.
//
// For example, a 2D FFT transform of size 16x16 applied to a
// complex64[2][15][17] input array will perform two transforms over the
// [][15][17] data in the sub arrays [0][][] and [1][][], dropping the values
// along axis X and padding axis Y with zeroes to create 16x16 working sets,
// and generating complex64[2][16][16] output. A 3D IRFFT transform of size
// 64x16x16 applied to a complex64[64][16][9] input array will use all input
// values and will produce float[64][16][16] output.
//
// The implementation of the 1D transform is a straightforward loop nest. The
// transforms of higher ranks apply sets of 1D transforms along each axis. For
// example, the 2D transform is computed by applying 1D transforms to each
// column followed by applying 1D transforms to each row.
//
// In general, a transform of rank n runs in O(N0*N1*...*Nn*(N0+N1+...+Nn))
// time, where Ni is the length of the transform's i-th dimension. It is
// possible to reduce the run time to O(N0*N1*...(log(N0)+log(N1)+...)) by
// plugging in a more efficient 1D implementation.
//
Status HloEvaluator::HandleFft(HloInstruction* fft) {
  const FftType fft_type = fft->fft_type();
  std::vector<int64> fft_lengths = fft->fft_length();
  const int64 fft_rank = fft_lengths.size();
  const Literal& input_literal = GetEvaluatedLiteralFor(fft->operand(0));
  const Shape& input_shape = input_literal.shape();
  const Shape& output_shape = fft->shape();
  Literal output_literal = Literal::CreateFromShape(output_shape);

  // Make fft_lengths[0] the minor-most dimension.
  absl::c_reverse(fft_lengths);

  TF_RETURN_IF_ERROR(CheckParameters(input_shape, output_shape, fft_rank,
                                     fft_type, fft_lengths));

  const auto fft_strides = ComputeStrides(fft_lengths);

  // Working set size.
  const int64 fft_size = fft_strides[fft_rank];

  if (fft_size > 0) {
    // Linearized working data set.
    std::vector<complex128> data(fft_size);

    // Temporary buffer allocated once and used in 1D sweeps.
    std::vector<complex128> buffer(*absl::c_max_element(fft_lengths));

    // Sizes of each axis of input and output literals.
    const auto input_lengths = GetDimensionLengths(input_literal);
    const auto output_lengths = GetDimensionLengths(output_literal);

    // Strides for generating linearized indices into multidimensional arrays.
    const auto input_strides = ComputeStrides(input_lengths, input_literal);
    const auto output_strides = ComputeStrides(output_lengths, output_literal);

    // Visit all elements in the dimensions with ranks above the FFT rank. For
    // each such element invoke the transform. Use separate indices for the
    // input and the output to allow different layouts.
    auto base_case = [&](int64 axis, int64 output_index, int64 input_index,
                         bool within_src_bounds) {
      if (axis == fft_rank - 1) {
        // Base case: copy the data from the input literal, apply the
        // transform, and copy the result to the output literal.
        CHECK(within_src_bounds);
        bool input_is_zero =
            CopyDataFromInput(input_literal, input_index, fft_rank, fft_type,
                              fft_size, fft_lengths, fft_strides, input_lengths,
                              input_strides, absl::MakeSpan(data));
        if (!input_is_zero) {
          // Make 1D sweeps along each transform axis.
          Sweep(fft_rank, fft_type, fft_lengths, fft_strides,
                absl::MakeSpan(data), absl::MakeSpan(buffer));
        }
        CopyDataToOutput(absl::MakeSpan(data), output_index, fft_rank, fft_type,
                         fft_lengths, fft_strides, output_lengths,
                         output_strides, &output_literal);
        return true;
      }
      return false;
    };
    GenerateIndices(output_lengths, output_strides, input_lengths,
                    input_strides, input_shape.rank(), 0, 0, base_case);
  }

  evaluated_[fft] = std::move(output_literal);
  return Status::OK();
}
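
For intuition about what the evaluation path above produces, here is a standalone sketch (an editor's illustration, not part of the commit, using only the C++ standard library): the naive O(N^2) kernel applied to a unit impulse yields an all-ones spectrum, the textbook sanity check for a DFT.

// Editor's sketch: naive DFT of a unit impulse; every output bin equals 1.
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

int main() {
  const int kLength = 8;
  std::vector<std::complex<double>> input(kLength, {0.0, 0.0});
  input[0] = {1.0, 0.0};  // unit impulse
  for (int k = 0; k < kLength; ++k) {
    std::complex<double> value(0.0, 0.0);
    for (int n = 0; n < kLength; ++n) {
      const double angle = -2.0 * M_PI * n * k / kLength;
      value += input[n] * std::exp(std::complex<double>(0.0, angle));
    }
    assert(std::abs(value - std::complex<double>(1.0, 0.0)) < 1e-9);
  }
  return 0;
}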

// Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch
// dimensions while keeping the rest of the output dimensions clamped to 0.
ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(
@@ -204,6 +204,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {

  Status HandleTuple(HloInstruction* tuple) override;

  Status HandleFft(HloInstruction* fft) override;

  Status HandleGather(HloInstruction* gather) override;

  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
File diff suppressed because it is too large
@@ -68,8 +68,8 @@ T ToArithmeticSafeType(T t) {
// Templated DfsHloVisitor for use by HloEvaluator.
//
// Typically ReturnT here indicates the resulting literal type of each evaluated
-// Handle* method of a TypedVisitor. There are however a few notable exceptions
-// to this rule, notably:
+// Handle* method of a TypedVisitor. There are however a few exceptions to this
+// rule, notably:
// - HandleCompare and HandleIsFinite: where the resulting literal type is
//   always boolean.
// - HandleImag and HandleReal: where the resulting literal type is always float
@@ -81,7 +81,7 @@ T ToArithmeticSafeType(T t) {
// - ReturnT: The type of input and output of each operation.
// - ElementwiseT: The type in which internal computation are done.
//
-// This a logically a private part of HloEvaluator. It lives in this header
+// This is logically a private part of HloEvaluator. It lives in this header
// file rather than in hlo_evaluator.cc because we use extern templates and a
// bunch of independent cc files to speed up compiling the many instantiations
// of this class.