FFT support in HloEvaluator.

PiperOrigin-RevId: 247056664
This commit is contained in:
A. Unique TensorFlower 2019-05-07 11:23:28 -07:00 committed by TensorFlower Gardener
parent f4cb4eb43f
commit 4259bfb42b
4 changed files with 1569 additions and 5 deletions

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
@ -779,6 +780,545 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
namespace {
// Straightforward implementation of the 1D DFT. Gathers the inputs from the
// data vector (first element at `start`, consecutive elements `stride` apart)
// into the preallocated buffer, computes the transform, and writes the result
// back to the same locations in the data vector. Runs in O(length^2) time.
//
// Parameters:
//   length           Number of points of the transform.
//   start, stride    Linear index of the first element in `data` and the
//                    distance between consecutive elements.
//   inverse          Selects the inverse transform, computed as
//                    conj(DFT(conj(x))) / length.
//   contract_output  When true, only (length / 2) + 1 output values are
//                    computed and written.
//   expand_input     When true, only (length / 2) + 1 values are read from
//                    `data`; the remaining inputs are re-created as complex
//                    conjugates of the values that were read.
//   data             Working data set; must hold the last addressed element.
//   buffer           Scratch space of at least `length` elements.
//
// An all-zero input yields an all-zero output. In that case the arithmetic is
// skipped, but every output location is still guaranteed to hold zero on
// return, including the locations that were not read because of expand_input.
//
void NaiveDft1D(int64 length, int64 start, int64 stride, bool inverse,
                bool contract_output, bool expand_input,
                absl::Span<complex128> data, absl::Span<complex128> buffer) {
  CHECK_GT(data.size(), start + (length - 1) * stride);
  CHECK_GT(buffer.size(), length - 1);

  // Copy input data to 1D vector.
  bool input_is_zero = true;
  const int64 input_ub = expand_input ? length / 2 + 1 : length;
  for (int64 k = 0; k < input_ub; k++) {
    const complex128 value = data[start + k * stride];
    input_is_zero &= value == complex128(0.0, 0.0);
    buffer[k] = value;
    // Use conjugates of the values at indices [1 ... (input_ub - 2)] when the
    // length is even and at indices [1 ... (input_ub - 1)] when the length is
    // odd to calculate missing values at indices [(length - 1) ... input_ub].
    if (expand_input && k > 0 && k < (length - input_ub + 1)) {
      buffer[length - k] = std::conj(value);
    }
  }

  const int64 output_ub = contract_output ? length / 2 + 1 : length;
  if (input_is_zero) {
    // The locations that were read are already zero. Locations past the
    // stored half were never read when the input is expanded and may hold
    // unrelated values, so clear them explicitly.
    for (int64 k = input_ub; k < output_ub; k++) {
      data[start + k * stride] = complex128(0.0, 0.0);
    }
    return;
  }

  // Do 1D transformation with double precision.
  for (int64 k = 0; k < output_ub; k++) {
    complex128 value = complex128(0.0, 0.0);
    // The index type matches `length` so that the loop cannot wrap around.
    for (int64 n = 0; n < length; n++) {
      auto coeff = std::exp(complex128(0.0, -2.0 * M_PI * n * k / length));
      value += (inverse ? std::conj(buffer[n]) : buffer[n]) * coeff;
    }
    data[start + k * stride] =
        inverse ? std::conj(value) / complex128(length, 0.0) : value;
  }
}
// Returns the dimension lengths of the literal's shape in reversed order, i.e.
// the length of the last dimension of the shape is placed at index 0.
std::vector<int64> GetDimensionLengths(const Literal& literal) {
  const auto& dimensions = literal.shape().dimensions();
  return std::vector<int64>(dimensions.rbegin(), dimensions.rend());
}
// Computes the strides used to form linear indices into multidimensional data,
// given the dimension lengths and the layout. The result has
// lengths.size() + 1 elements: entry [i] is the stride of axis i, and the
// final entry at index [lengths.size()] is the product of all dimension
// lengths.
std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths,
                                  const Layout& layout) {
  const int64 num_dimensions = lengths.size();
  // The layout must describe exactly as many dimensions as there are lengths.
  CHECK_EQ(num_dimensions, layout.minor_to_major_size());

  std::vector<int64> strides(num_dimensions + 1);
  int64 running_product = 1;
  // Walk the layout from the minor-most to the major-most dimension and
  // accumulate the element count covered so far.
  for (int64 minor_pos = 0; minor_pos < num_dimensions; ++minor_pos) {
    // The lengths are indexed in reversed dimension order, so mirror the
    // dimension number taken from the layout.
    const int64 axis = (num_dimensions - 1) - layout.minor_to_major(minor_pos);
    strides[axis] = running_product;
    running_product *= lengths[axis];
  }
  strides.back() = running_product;
  return strides;
}
// Same as above, with the default layout for the rank given by lengths.size().
std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths) {
  const Layout default_layout =
      LayoutUtil::GetDefaultLayoutForRank(lengths.size());
  return ComputeStrides(lengths, default_layout);
}
// Same as above, taking the layout from the literal's shape when the shape has
// one and falling back to the default layout otherwise.
std::vector<int64> ComputeStrides(const absl::Span<const int64> lengths,
                                  const Literal& literal) {
  const auto& shape = literal.shape();
  if (!shape.has_layout()) {
    return ComputeStrides(lengths);
  }
  return ComputeStrides(lengths, shape.layout());
}
// Applies 1D transforms along every transform axis of the working data set.
void Sweep(int64 fft_rank, FftType fft_type,
           const absl::Span<const int64> fft_lengths,
           const absl::Span<const int64> fft_strides,
           absl::Span<complex128> data, absl::Span<complex128> buffer) {
  const bool is_irfft = fft_type == FftType::IRFFT;
  const bool is_rfft = fft_type == FftType::RFFT;
  const bool inverse = fft_type == FftType::IFFT || is_irfft;
  // Both real-valued transform types keep only half of the X axis outside of
  // the sweep along X itself.
  const bool x_axis_is_halved = is_irfft || is_rfft;

  // Recursively enumerates every column of the data that runs along
  // sweep_axis: the recursion walks the axes from the major-most one down,
  // accumulating the linear index of the column's first element in `start`,
  // and runs the 1D transform once all axes have been consumed.
  //
  // For RFFT only (length_x / 2) + 1 values are computed along the X axis, and
  // the X coordinate is limited to [0 ... (length_x / 2)] during the sweeps
  // along the other axes. For IRFFT the X coordinate is limited the same way
  // during the sweeps along the other axes; the omitted negative frequencies
  // are re-created during the sweep along X.
  std::function<void(int64, int64, int64)> visit = [&](int64 sweep_axis,
                                                       int64 axis,
                                                       int64 start) {
    if (axis < 0) {
      // All axes consumed: transform the column starting at `start`.
      const bool along_x = sweep_axis == 0;
      const bool expand_input = is_irfft && along_x;
      const bool contract_output = is_rfft && along_x;
      NaiveDft1D(fft_lengths[sweep_axis], start, fft_strides[sweep_axis],
                 inverse, contract_output, expand_input, data, buffer);
      return;
    }
    if (axis == sweep_axis) {
      // Only coordinate 0 is visited along the sweep axis itself.
      visit(sweep_axis, axis - 1, start);
      return;
    }
    int64 count = fft_lengths[axis];
    if (x_axis_is_halved && axis == 0) {
      count = (count / 2) + 1;
    }
    for (int64 i = 0; i < count; ++i) {
      visit(sweep_axis, axis - 1, start + i * fft_strides[axis]);
    }
  };

  // IRFFT sweeps along the X axis last, so the axes are visited from the
  // highest one down. RFFT sweeps along the X axis first. The order does not
  // matter for FFT and IFFT, which share the ascending order.
  const int64 top_axis = fft_rank - 1;
  for (int64 step = 0; step < fft_rank; ++step) {
    const int64 sweep_axis = is_irfft ? top_axis - step : step;
    visit(sweep_axis, top_axis, 0);
  }
}
// Conversion helpers for the DFT implementation: they take a value of the
// input data type to the type used in calculations, and a calculated value to
// the output data type. The generic version is a plain static_cast. IRFFT is
// handled by a specialization that drops the imaginary part of a complex value
// (which is expected to be 0) and returns a real number.
template <typename ToType, typename FromType>
ToType GetAs(FromType value) {
  ToType converted = static_cast<ToType>(value);
  return converted;
}
// Complex-to-float conversion: keeps the real part and discards the imaginary
// part.
template <>
float GetAs<float, complex128>(complex128 value) {
  const auto real_part = value.real();
  return static_cast<float>(real_part);
}
// Generates pairs of linearized indices (destination, source) for accessing
// two multidimensional arrays in lockstep, and hands them to the user-supplied
// callback. The iteration space is taken from dst_lengths[], so the
// destination index is always within dst_lengths[] bounds. The boolean
// argument passed to the callback tells whether the source index is within
// src_lengths[] bounds as well.
//
// The callback is invoked as base(axis, dst_index, src_index,
// within_src_bounds), starting with axis == fft_rank - 1. Its return value
// controls the depth of the recursion: true means that the callback has
// handled this position (the base case) and the recursion stops there; false
// makes the recursion iterate over `axis` and invoke the callback again for
// the next less-major axis.
//
// For example, a callback that returns true once the axis becomes negative is
// invoked for each possible index within dst_lengths[] bounds. A callback that
// returns true when the axis is equal to zero only sees the first elements
// along the minor-most dimension and can handle all values along the X axis
// itself.
//
template <typename BaseFn>
void GenerateIndices(const absl::Span<const int64> dst_lengths,
                     const absl::Span<const int64> dst_strides,
                     const absl::Span<const int64> src_lengths,
                     const absl::Span<const int64> src_strides, int64 fft_rank,
                     int64 dst_start, int64 src_start, BaseFn&& base) {
  CHECK_EQ(dst_lengths.size() + 1, dst_strides.size());
  CHECK_GE(dst_lengths.size(), fft_rank);
  CHECK_EQ(src_lengths.size() + 1, src_strides.size());
  CHECK_GE(src_lengths.size(), fft_rank);

  std::function<void(int64, int64, int64, bool)> descend =
      [&](int64 axis, int64 dst_offset, int64 src_offset,
          bool src_in_bounds) {
        if (base(axis, dst_offset, src_offset, src_in_bounds)) {
          return;
        }
        const int64 dst_extent = dst_lengths[axis];
        for (int64 i = 0; i < dst_extent; ++i) {
          // The loop covers the destination extent, so the coordinate may run
          // past the source extent. Once that happens the flag stays false for
          // the rest of this axis.
          if (i >= src_lengths[axis]) {
            src_in_bounds = false;
          }
          descend(axis - 1, dst_offset + i * dst_strides[axis],
                  src_offset + i * src_strides[axis], src_in_bounds);
        }
      };
  descend(fft_rank - 1, dst_start, src_start, true);
}
// Copies the input data from a literal to a pre-allocated vector. The sizes of
// the input and the transform do not need to match. For each axis of the
// transform, any extra input values beyond the transform length are ignored.
// Conversely, if the input does not contain enough elements along any axis, the
// data is padded with zeroes.
//
// For IRFFT transforms, we use (length_x / 2) + 1 elements from the input,
// where length_x is the size of the full transform along the X axis. The
// remaining elements along the X axis of the working set are set to zero.
//
// The input literal may have a rank higher than the rank of the transform.
// Passed-in input_start value points to the first element of the input literal
// to be copied.
//
// Every element of the working set (the first fft_size elements of `data`) is
// overwritten, so no values from a previous use of `data` remain.
//
// Returns true if all values read from the input are zeroes.
//
template <typename InputType>
bool CopyDataFromInput(const Literal& input_literal, int64 input_start,
                       int64 fft_rank, FftType fft_type, int64 fft_size,
                       const absl::Span<const int64> fft_lengths,
                       const absl::Span<const int64> fft_strides,
                       const absl::Span<const int64> input_lengths,
                       const absl::Span<const int64> input_strides,
                       absl::Span<complex128> data) {
  CHECK_GE(data.size(), fft_size);
  const bool input_is_truncated = fft_type == FftType::IRFFT;
  // Recursively visit each transform dimension to copy input values to the
  // working data set. The base case handles inputs along the X axis.
  bool input_is_zero = true;
  const InputType* input_data = input_literal.data<InputType>().data();
  auto base_case = [&](int64 axis, int64 dst_index, int64 src_index,
                       bool within_src_bounds) {
    if (axis == 0) {
      // For IRFFT, the negative frequencies are not taken from the input; they
      // are re-created from the stored half during the sweep along the X axis.
      // Their slots are still cleared here: the working set is reused between
      // transforms, and the sweeps are skipped for all-zero input, in which
      // case these slots are copied to the output as they are.
      const int64 length = fft_lengths[axis];
      const int64 ub = input_is_truncated ? (length / 2) + 1 : length;
      for (int64 i = 0; i < length; i++) {
        complex128 value = InputType(0);
        // Read input value only if the index is within bounds.
        if (i < ub && within_src_bounds && i < input_lengths[axis]) {
          value = GetAs<complex128, InputType>(
              input_data[src_index + i * input_strides[axis]]);
          input_is_zero &= value == complex128(0.0, 0.0);
        }
        data[dst_index + i * fft_strides[axis]] = value;
      }
      return true;
    }
    return false;
  };
  GenerateIndices(fft_lengths, fft_strides, input_lengths, input_strides,
                  fft_rank, 0, input_start, base_case);
  return input_is_zero;
}
// Copies the result of the transform from the working set to the output
// literal. Output positions that have no counterpart in the working set are
// filled with zeroes.
//
// For RFFT transforms, at most (length_x / 2) + 1 elements are taken along the
// X axis (the most minor dimension), where length_x is the size of the full
// transform along that axis.
//
// The output literal may have a rank higher than the rank of the transform.
// Passed-in output_start value points to the first element of the output
// literal to be filled in.
//
template <typename OutputType>
void CopyDataToOutput(const absl::Span<complex128> data, int64 output_start,
                      int64 fft_rank, FftType fft_type,
                      const absl::Span<const int64> fft_lengths,
                      const absl::Span<const int64> fft_strides,
                      const absl::Span<const int64> output_lengths,
                      const absl::Span<const int64> output_strides,
                      Literal* output_literal) {
  const bool drop_negative_frequencies = fft_type == FftType::RFFT;
  OutputType* const out = output_literal->data<OutputType>().data();

  // Handles one complete row along axis 0 per invocation, which avoids a
  // recursive call for each output element (as opposed to making "axis < 0"
  // the base case).
  auto copy_row = [&](int64 axis, int64 dst_index, int64 src_index,
                      bool within_src_bounds) -> bool {
    if (axis != 0) {
      return false;
    }
    // Number of working set values available along the X axis.
    int64 available = fft_lengths[axis];
    if (drop_negative_frequencies) {
      available = (available / 2) + 1;
    }
    for (int64 i = 0; i < output_lengths[axis]; ++i) {
      // Read the working set only if the index is within bounds.
      const bool readable = within_src_bounds && i < available;
      out[dst_index + i * output_strides[axis]] =
          readable ? GetAs<OutputType, complex128>(
                         data[src_index + i * fft_strides[axis]])
                   : OutputType(0);
    }
    return true;
  };
  GenerateIndices(output_lengths, output_strides, fft_lengths, fft_strides,
                  fft_rank, output_start, 0, copy_row);
}
// Selects the element type for the CopyDataFromInput<> template above: float
// for RFFT, complex64 for every other transform type.
bool CopyDataFromInput(const Literal& input_literal, int64 input_start,
                       int64 fft_rank, FftType fft_type, int64 fft_size,
                       const absl::Span<const int64> fft_lengths,
                       const absl::Span<const int64> fft_strides,
                       const absl::Span<const int64> input_lengths,
                       const absl::Span<const int64> input_strides,
                       absl::Span<complex128> data) {
  if (fft_type != FftType::RFFT) {
    return CopyDataFromInput<complex64>(
        input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths,
        fft_strides, input_lengths, input_strides, data);
  }
  return CopyDataFromInput<float>(
      input_literal, input_start, fft_rank, fft_type, fft_size, fft_lengths,
      fft_strides, input_lengths, input_strides, data);
}
// Selects the element type for the CopyDataToOutput<> template above: float
// for IRFFT, complex64 for every other transform type.
void CopyDataToOutput(const absl::Span<complex128> data, int64 output_start,
                      int64 fft_rank, FftType fft_type,
                      const absl::Span<const int64> fft_lengths,
                      const absl::Span<const int64> fft_strides,
                      const absl::Span<const int64> output_lengths,
                      const absl::Span<const int64> output_strides,
                      Literal* output_literal) {
  if (fft_type != FftType::IRFFT) {
    CopyDataToOutput<complex64>(data, output_start, fft_rank, fft_type,
                                fft_lengths, fft_strides, output_lengths,
                                output_strides, output_literal);
    return;
  }
  CopyDataToOutput<float>(data, output_start, fft_rank, fft_type, fft_lengths,
                          fft_strides, output_lengths, output_strides,
                          output_literal);
}
// Validates the FFT parameters and the input and output shapes against each
// other. Returns an error status describing the first problem found, or OK.
//
// The checks cover: a positive fft_rank with one non-negative length per
// transform axis; valid array shapes; element types (F32 input for RFFT and
// C64 otherwise, F32 output for IRFFT and C64 otherwise); input and output
// ranks that are equal and not smaller than fft_rank; and equal lengths of the
// input and output dimensions above the transform dimensions.
Status CheckParameters(const Shape& input_shape, const Shape& output_shape,
                       int64 fft_rank, FftType fft_type,
                       const absl::Span<const int64> fft_lengths) {
  // Check FFT parameters.
  if (fft_rank <= 0) {
    return InvalidArgument("Zero or negative FFT rank.");
  }
  // Guarantees that fft_lengths is not empty before its minimum element is
  // dereferenced below.
  if (static_cast<int64>(fft_lengths.size()) != fft_rank) {
    return InvalidArgument("Number of FFT lengths does not match FFT rank.");
  }
  if (*absl::c_min_element(fft_lengths) < 0) {
    return InvalidArgument("Negative FFT length.");
  }

  // Check input-related values. An invalid shape is reported to the caller
  // through the returned status rather than aborting the process.
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(input_shape));
  if (!input_shape.IsArray()) {
    return Unimplemented("Only array input shapes are supported.");
  }
  auto input_elt_type = input_shape.element_type();
  if (fft_type == FftType::RFFT && input_elt_type != PrimitiveType::F32) {
    return InvalidArgument("Invalid input type: %d, must be %d (float).",
                           input_elt_type, PrimitiveType::F32);
  }
  if (fft_type != FftType::RFFT && input_elt_type != PrimitiveType::C64) {
    return InvalidArgument("Invalid input type: %d, must be %d (complex64).",
                           input_elt_type, PrimitiveType::C64);
  }
  const int64 input_rank = input_shape.rank();
  if (input_rank < fft_rank) {
    return InvalidArgument("Input shape rank is smaller than FFT rank.");
  }

  // Check output-related values.
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(output_shape));
  if (!output_shape.IsArray()) {
    return Unimplemented("Only array output shapes are supported.");
  }
  auto output_elt_type = output_shape.element_type();
  if (fft_type == FftType::IRFFT && output_elt_type != PrimitiveType::F32) {
    return InvalidArgument("Invalid output type: %d, must be %d (float).",
                           output_elt_type, PrimitiveType::F32);
  }
  if (fft_type != FftType::IRFFT && output_elt_type != PrimitiveType::C64) {
    return InvalidArgument("Invalid output type: %d, must be %d (complex64).",
                           output_elt_type, PrimitiveType::C64);
  }
  const int64 output_rank = output_shape.rank();
  if (output_rank < fft_rank) {
    return InvalidArgument("Output shape rank is smaller than FFT rank.");
  }

  // Consistency of input and output parameters.
  if (input_rank != output_rank) {
    return InvalidArgument(
        "Ranks of input shape and output shape do not match.");
  }
  // Only the dimensions above the transform dimensions have to agree.
  for (int64 dim = 0; dim < input_rank - fft_rank; dim++) {
    if (ShapeUtil::GetDimension(input_shape, dim) !=
        ShapeUtil::GetDimension(output_shape, dim)) {
      return InvalidArgument(
          "Higher dimension lengths of input shape and output shape do not "
          "match.");
    }
  }
  return Status::OK();
}
} // namespace
// Flexible but slow implementation of the discrete Fourier transform. All
// transform types (FFT, IFFT, RFFT, and IRFFT) are supported, as well as the
// arbitrary rank and length of each dimension of the transform, and arbitrary
// layouts of the input and output literals.
//
// The input literal in operand 0 provides input data, which must be complex64
// for FFT, IFFT, IRFFT transforms and float for RFFT. The transform is computed
// over the innermost dimensions of the input, thus the rank of the input data
// must be same as fft_rank or larger. The input is expected to provide Ni
// values along each transform axis with one exception: for IRFFT, only
// (N0 / 2) + 1 values are needed along the X axis (the innermost index). To
// increase flexibility, this implementation can handle mismatches between the
// input size and transform lengths by either dropping extra input values or
// using zeroes in place of missing input values as necessary. If the input data
// has rank higher than the transform, the transform is applied for each valid
// combination of the higher-ranking indices.
//
// The output contains complex64 values for FFT, IFFT, RFFT, and float values
// for IRFFT. The rank of the output as well as the sizes of the dimensions
// above the rank of the transform must match those of the input. Sizes of the
// output's "fft_rank" innermost dimensions are expected to match the length of
// the transform along respective axes with one exception: for RFFT, the output
// is trimmed along the X axis to have only (N0 / 2) + 1 values. In case the
// length(s) mismatch, the FFT output is trimmed to fit into the provided output
// shape, or the output is padded with zero values appropriately.
//
// For example, 2D FFT transform of size 16x16 applied to complex64[2][15][17]
// input array will perform two transforms over the [][15][17] data in the sub
// arrays [0][][] and [1][][], dropping the values along axis X and padding axis
// Y with zeroes to create 16x16 working sets, and generating
// complex64[2][16][16] output. 3D IRFFT transform of size 64x16x16 applied to
// complex64[64][16][9] input array will use all input values and will produce
// float[64][16][16] output.
//
// The implementation of the 1D transform is a straightforward loop nest. The
// transforms of higher ranks apply sets of 1D transforms along each axis. For
// example, the 2D transform is computed by applying 1D transforms to each
// column followed by applying 1D transforms to each row.
//
// In general, a transform of rank n runs in O(N0*N1*...*Nn*(N0+N1+...+Nn))
// time, where Ni is the length of the transform's i-th dimension. It is
// possible to reduce the run time to O(N0*N1*...(log(N0)+log(N1)+...)) by
// plugging in a more efficient 1D implementation.
//
Status HloEvaluator::HandleFft(HloInstruction* fft) {
  const FftType fft_type = fft->fft_type();
  std::vector<int64> fft_lengths = fft->fft_length();
  const int64 fft_rank = fft_lengths.size();
  const Literal& input_literal = GetEvaluatedLiteralFor(fft->operand(0));
  const Shape& input_shape = input_literal.shape();
  const Shape& output_shape = fft->shape();
  Literal output_literal = Literal::CreateFromShape(output_shape);

  // From here on fft_lengths[0] is the length of the minor-most dimension.
  absl::c_reverse(fft_lengths);

  TF_RETURN_IF_ERROR(CheckParameters(input_shape, output_shape, fft_rank,
                                     fft_type, fft_lengths));

  const std::vector<int64> fft_strides = ComputeStrides(fft_lengths);
  // Number of elements in the working set of a single transform.
  const int64 fft_size = fft_strides[fft_rank];
  if (fft_size <= 0) {
    // Nothing to compute: the freshly created output literal is the result.
    evaluated_[fft] = std::move(output_literal);
    return Status::OK();
  }

  // Linearized working data set, and the scratch space for the 1D sweeps. Both
  // are allocated once and shared by all transforms performed below.
  std::vector<complex128> work_set(fft_size);
  std::vector<complex128> scratch(*absl::c_max_element(fft_lengths));
  const absl::Span<complex128> work_span = absl::MakeSpan(work_set);
  const absl::Span<complex128> scratch_span = absl::MakeSpan(scratch);

  // Sizes of each axis of the input and output literals, and the strides for
  // generating linearized indices into them. The input and the output are
  // indexed separately, which allows them to have different layouts.
  const std::vector<int64> input_lengths = GetDimensionLengths(input_literal);
  const std::vector<int64> output_lengths =
      GetDimensionLengths(output_literal);
  const std::vector<int64> input_strides =
      ComputeStrides(input_lengths, input_literal);
  const std::vector<int64> output_strides =
      ComputeStrides(output_lengths, output_literal);

  // Invoked while iterating over the dimensions above the FFT rank. Once the
  // recursion reaches the major-most transform axis, runs one complete
  // transform: copy the data from the input literal, apply the transform, and
  // copy the result to the output literal.
  auto transform_one = [&](int64 axis, int64 output_index, int64 input_index,
                           bool within_src_bounds) -> bool {
    if (axis != fft_rank - 1) {
      return false;
    }
    CHECK(within_src_bounds);
    const bool all_zero = CopyDataFromInput(
        input_literal, input_index, fft_rank, fft_type, fft_size, fft_lengths,
        fft_strides, input_lengths, input_strides, work_span);
    // The 1D sweeps along each transform axis are skipped for all-zero input.
    if (!all_zero) {
      Sweep(fft_rank, fft_type, fft_lengths, fft_strides, work_span,
            scratch_span);
    }
    CopyDataToOutput(work_span, output_index, fft_rank, fft_type, fft_lengths,
                     fft_strides, output_lengths, output_strides,
                     &output_literal);
    return true;
  };
  GenerateIndices(output_lengths, output_strides, input_lengths, input_strides,
                  input_shape.rank(), 0, 0, transform_one);

  evaluated_[fft] = std::move(output_literal);
  return Status::OK();
}
// Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch
// dimensions while keeping the rest of the output dimensions clamped to 0.
ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices(

View File

@ -204,6 +204,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleTuple(HloInstruction* tuple) override;
Status HandleFft(HloInstruction* fft) override;
Status HandleGather(HloInstruction* gather) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;

File diff suppressed because it is too large Load Diff

View File

@ -68,8 +68,8 @@ T ToArithmeticSafeType(T t) {
// Templated DfsHloVisitor for use by HloEvaluator.
//
// Typically ReturnT here indicates the resulting literal type of each evaluated
// Handle* method of a TypedVisitor. There are however a few notable exceptions
// to this rule, notably:
// Handle* method of a TypedVisitor. There are however a few exceptions to this
// rule, notably:
// - HandleCompare and HandleIsFinite: where the resulting literal type is
// always boolean.
// - HandleImag and HandleReal: where the resulting literal type is always float
@ -81,7 +81,7 @@ T ToArithmeticSafeType(T t) {
// - ReturnT: The type of input and output of each operation.
// - ElementwiseT: The type in which internal computation are done.
//
// This a logically a private part of HloEvaluator. It lives in this header
// This is logically a private part of HloEvaluator. It lives in this header
// file rather than in hlo_evaluator.cc because we use extern templates and a
// bunch of independent cc files to speed up compiling the many instantiations
// of this class.