Add complex128 support to RFFT, RFFT2D, RFFT3D, IRFFT, IRFFT2D, and IRFFT3D.
Finishes support requested in: - #10749 - #17332 - https://stackoverflow.com/questions/47214508/the-result-of-fft-in-tensorflow-is-different-from-numpy PiperOrigin-RevId: 268772775
This commit is contained in:
parent
731984bfd0
commit
4a3cbea5f3
tensorflow
compiler/tf2xla/kernels
core
api_def/base_api
api_def_FFT3D.pbtxtapi_def_IFFT3D.pbtxtapi_def_IRFFT.pbtxtapi_def_IRFFT2D.pbtxtapi_def_IRFFT3D.pbtxt
kernels
ops
python
tools/api/golden
@ -138,10 +138,20 @@ class RFFTOp : public GenericFftOp {
|
||||
explicit RFFTOp(OpKernelConstruction* ctx)
|
||||
: GenericFftOp(ctx, /*fft_type=*/FftType::RFFT, /*fft_rank=*/FFTRank) {}
|
||||
};
|
||||
REGISTER_XLA_OP(Name("RFFT").CompileTimeConstantInput("fft_length"), RFFTOp<1>);
|
||||
REGISTER_XLA_OP(Name("RFFT2D").CompileTimeConstantInput("fft_length"),
|
||||
REGISTER_XLA_OP(Name("RFFT")
|
||||
.TypeConstraint("Treal", DT_FLOAT)
|
||||
.TypeConstraint("Tcomplex", DT_COMPLEX64)
|
||||
.CompileTimeConstantInput("fft_length"),
|
||||
RFFTOp<1>);
|
||||
REGISTER_XLA_OP(Name("RFFT2D")
|
||||
.TypeConstraint("Treal", DT_FLOAT)
|
||||
.TypeConstraint("Tcomplex", DT_COMPLEX64)
|
||||
.CompileTimeConstantInput("fft_length"),
|
||||
RFFTOp<2>);
|
||||
REGISTER_XLA_OP(Name("RFFT3D").CompileTimeConstantInput("fft_length"),
|
||||
REGISTER_XLA_OP(Name("RFFT3D")
|
||||
.TypeConstraint("Treal", DT_FLOAT)
|
||||
.TypeConstraint("Tcomplex", DT_COMPLEX64)
|
||||
.CompileTimeConstantInput("fft_length"),
|
||||
RFFTOp<3>);
|
||||
|
||||
template <int FFTRank>
|
||||
@ -150,11 +160,20 @@ class IRFFTOp : public GenericFftOp {
|
||||
explicit IRFFTOp(OpKernelConstruction* ctx)
|
||||
: GenericFftOp(ctx, /*fft_type=*/FftType::IRFFT, /*fft_rank=*/FFTRank) {}
|
||||
};
|
||||
REGISTER_XLA_OP(Name("IRFFT").CompileTimeConstantInput("fft_length"),
|
||||
REGISTER_XLA_OP(Name("IRFFT")
|
||||
.TypeConstraint("Treal", DT_FLOAT)
|
||||
.TypeConstraint("Tcomplex", DT_COMPLEX64)
|
||||
.CompileTimeConstantInput("fft_length"),
|
||||
IRFFTOp<1>);
|
||||
REGISTER_XLA_OP(Name("IRFFT2D").CompileTimeConstantInput("fft_length"),
|
||||
REGISTER_XLA_OP(Name("IRFFT2D")
|
||||
.TypeConstraint("Treal", DT_FLOAT)
|
||||
.TypeConstraint("Tcomplex", DT_COMPLEX64)
|
||||
.CompileTimeConstantInput("fft_length"),
|
||||
IRFFTOp<2>);
|
||||
REGISTER_XLA_OP(Name("IRFFT3D").CompileTimeConstantInput("fft_length"),
|
||||
REGISTER_XLA_OP(Name("IRFFT3D")
|
||||
.TypeConstraint("Treal", DT_FLOAT)
|
||||
.TypeConstraint("Tcomplex", DT_COMPLEX64)
|
||||
.CompileTimeConstantInput("fft_length"),
|
||||
IRFFTOp<3>);
|
||||
|
||||
} // namespace
|
||||
|
@ -3,13 +3,13 @@ op {
|
||||
in_arg {
|
||||
name: "input"
|
||||
description: <<END
|
||||
A complex64 tensor.
|
||||
A complex tensor.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
A complex64 tensor of the same shape as `input`. The inner-most 3
|
||||
A complex tensor of the same shape as `input`. The inner-most 3
|
||||
dimensions of `input` are replaced with their 3D Fourier transform.
|
||||
|
||||
@compatibility(numpy)
|
||||
|
@ -3,13 +3,13 @@ op {
|
||||
in_arg {
|
||||
name: "input"
|
||||
description: <<END
|
||||
A complex64 tensor.
|
||||
A complex tensor.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
A complex64 tensor of the same shape as `input`. The inner-most 3
|
||||
A complex tensor of the same shape as `input`. The inner-most 3
|
||||
dimensions of `input` are replaced with their inverse 3D Fourier transform.
|
||||
|
||||
@compatibility(numpy)
|
||||
|
@ -3,7 +3,7 @@ op {
|
||||
in_arg {
|
||||
name: "input"
|
||||
description: <<END
|
||||
A complex64 tensor.
|
||||
A complex tensor.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
|
@ -3,7 +3,7 @@ op {
|
||||
in_arg {
|
||||
name: "input"
|
||||
description: <<END
|
||||
A complex64 tensor.
|
||||
A complex tensor.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
|
@ -3,7 +3,7 @@ op {
|
||||
in_arg {
|
||||
name: "input"
|
||||
description: <<END
|
||||
A complex64 tensor.
|
||||
A complex tensor.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
// See docs in ../ops/spectral_ops.cc.
|
||||
// See docs in ../ops/fft_ops.cc.
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
@ -94,7 +95,34 @@ class FFTBase : public OpKernel {
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
|
||||
|
||||
if (IsReal()) {
|
||||
if (IsForward()) {
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
(in.dtype() == DT_FLOAT && out->dtype() == DT_COMPLEX64) ||
|
||||
(in.dtype() == DT_DOUBLE && out->dtype() == DT_COMPLEX128),
|
||||
errors::InvalidArgument("Wrong types for forward real FFT: in=",
|
||||
in.dtype(), " out=", out->dtype()));
|
||||
} else {
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
(in.dtype() == DT_COMPLEX64 && out->dtype() == DT_FLOAT) ||
|
||||
(in.dtype() == DT_COMPLEX128 && out->dtype() == DT_DOUBLE),
|
||||
errors::InvalidArgument("Wrong types for backward real FFT: in=",
|
||||
in.dtype(), " out=", out->dtype()));
|
||||
}
|
||||
} else {
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
(in.dtype() == DT_COMPLEX64 && out->dtype() == DT_COMPLEX64) ||
|
||||
(in.dtype() == DT_COMPLEX128 && out->dtype() == DT_COMPLEX128),
|
||||
errors::InvalidArgument("Wrong types for FFT: in=", in.dtype(),
|
||||
" out=", out->dtype()));
|
||||
}
|
||||
|
||||
if (input_shape.num_elements() == 0) {
|
||||
DCHECK_EQ(0, output_shape.num_elements());
|
||||
return;
|
||||
}
|
||||
|
||||
@ -129,126 +157,159 @@ class FFTCPU : public FFTBase {
|
||||
const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
|
||||
auto device = ctx->eigen_device<CPUDevice>();
|
||||
|
||||
const bool is_complex128 =
|
||||
in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;
|
||||
|
||||
if (!IsReal()) {
|
||||
// Compute the FFT using Eigen.
|
||||
constexpr auto direction =
|
||||
Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
|
||||
if (in.dtype() == DT_COMPLEX64) {
|
||||
if (is_complex128) {
|
||||
DCHECK_EQ(in.dtype(), DT_COMPLEX128);
|
||||
DCHECK_EQ(out->dtype(), DT_COMPLEX128);
|
||||
auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
|
||||
auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
|
||||
output.device(device) =
|
||||
input.template fft<Eigen::BothParts, direction>(axes);
|
||||
} else {
|
||||
DCHECK_EQ(in.dtype(), DT_COMPLEX64);
|
||||
DCHECK_EQ(out->dtype(), DT_COMPLEX64);
|
||||
auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
|
||||
auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
|
||||
output.device(device) =
|
||||
input.template fft<Eigen::BothParts, direction>(axes);
|
||||
} else {
|
||||
DCHECK_EQ(DT_COMPLEX128, in.dtype());
|
||||
DCHECK_EQ(DT_COMPLEX128, out->dtype());
|
||||
auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
|
||||
auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
|
||||
output.device(device) =
|
||||
input.template fft<Eigen::BothParts, direction>(axes);
|
||||
}
|
||||
} else {
|
||||
if (IsForward()) {
|
||||
auto input = Tensor(in).flat_inner_dims<float, FFTRank + 1>();
|
||||
const auto input_dims = input.dimensions();
|
||||
|
||||
// Slice input to fft_shape on its inner-most dimensions.
|
||||
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
|
||||
input_slice_sizes[0] = input_dims[0];
|
||||
TensorShape temp_shape{input_dims[0]};
|
||||
for (int i = 1; i <= FFTRank; ++i) {
|
||||
input_slice_sizes[i] = fft_shape[i - 1];
|
||||
temp_shape.AddDim(fft_shape[i - 1]);
|
||||
if (is_complex128) {
|
||||
DCHECK_EQ(in.dtype(), DT_DOUBLE);
|
||||
DCHECK_EQ(out->dtype(), DT_COMPLEX128);
|
||||
DoRealForwardFFT<double, complex128>(ctx, fft_shape, in, out);
|
||||
} else {
|
||||
DCHECK_EQ(in.dtype(), DT_FLOAT);
|
||||
DCHECK_EQ(out->dtype(), DT_COMPLEX64);
|
||||
DoRealForwardFFT<float, complex64>(ctx, fft_shape, in, out);
|
||||
}
|
||||
|
||||
auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
|
||||
const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
|
||||
|
||||
// Compute the full FFT using a temporary tensor.
|
||||
Tensor temp;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
|
||||
temp_shape, &temp));
|
||||
auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
|
||||
full_fft.device(device) =
|
||||
input.slice(zero_start_indices, input_slice_sizes)
|
||||
.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
|
||||
|
||||
// Slice away the negative frequency components.
|
||||
output.device(device) =
|
||||
full_fft.slice(zero_start_indices, output.dimensions());
|
||||
} else {
|
||||
// Reconstruct the full FFT and take the inverse.
|
||||
auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
|
||||
auto output = out->flat_inner_dims<float, FFTRank + 1>();
|
||||
const auto input_dims = input.dimensions();
|
||||
|
||||
// Calculate the shape of the temporary tensor for the full FFT and the
|
||||
// region we will slice from input given fft_shape. We slice input to
|
||||
// fft_shape on its inner-most dimensions, except the last (which we
|
||||
// slice to fft_shape[-1] / 2 + 1).
|
||||
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
|
||||
input_slice_sizes[0] = input_dims[0];
|
||||
TensorShape full_fft_shape;
|
||||
full_fft_shape.AddDim(input_dims[0]);
|
||||
for (auto i = 1; i <= FFTRank; i++) {
|
||||
input_slice_sizes[i] =
|
||||
i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
|
||||
full_fft_shape.AddDim(fft_shape[i - 1]);
|
||||
if (is_complex128) {
|
||||
DCHECK_EQ(in.dtype(), DT_COMPLEX128);
|
||||
DCHECK_EQ(out->dtype(), DT_DOUBLE);
|
||||
DoRealBackwardFFT<complex128, double>(ctx, fft_shape, in, out);
|
||||
} else {
|
||||
DCHECK_EQ(in.dtype(), DT_COMPLEX64);
|
||||
DCHECK_EQ(out->dtype(), DT_FLOAT);
|
||||
DoRealBackwardFFT<complex64, float>(ctx, fft_shape, in, out);
|
||||
}
|
||||
|
||||
Tensor temp;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
|
||||
full_fft_shape, &temp));
|
||||
auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
|
||||
|
||||
// Calculate the starting point and range of the source of
|
||||
// negative frequency part.
|
||||
auto neg_sizes = input_slice_sizes;
|
||||
neg_sizes[FFTRank] =
|
||||
fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
|
||||
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
|
||||
neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];
|
||||
|
||||
const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
|
||||
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
|
||||
neg_start_indices[FFTRank] = 1;
|
||||
|
||||
full_fft.slice(start_indices, input_slice_sizes).device(device) =
|
||||
input.slice(start_indices, input_slice_sizes);
|
||||
|
||||
// First, conduct IFFTs on outer dimensions. We save computation (and
|
||||
// avoid touching uninitialized memory) by slicing full_fft to the
|
||||
// subregion we wrote input to.
|
||||
if (FFTRank > 1) {
|
||||
const auto outer_axes =
|
||||
Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
|
||||
full_fft.slice(start_indices, input_slice_sizes).device(device) =
|
||||
full_fft.slice(start_indices, input_slice_sizes)
|
||||
.template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(
|
||||
outer_axes);
|
||||
}
|
||||
|
||||
// Reconstruct the full FFT by appending reversed and conjugated
|
||||
// spectrum as the negative frequency part.
|
||||
Eigen::array<bool, FFTRank + 1> reverse_last_axis;
|
||||
for (auto i = 0; i <= FFTRank; i++) {
|
||||
reverse_last_axis[i] = i == FFTRank;
|
||||
}
|
||||
|
||||
if (neg_sizes[FFTRank] != 0) {
|
||||
full_fft.slice(neg_target_indices, neg_sizes).device(device) =
|
||||
full_fft.slice(neg_start_indices, neg_sizes)
|
||||
.reverse(reverse_last_axis)
|
||||
.conjugate();
|
||||
}
|
||||
|
||||
auto inner_axis = Eigen::array<int, 1>{FFTRank};
|
||||
output.device(device) =
|
||||
full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(
|
||||
inner_axis);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RealT, typename ComplexT>
|
||||
void DoRealForwardFFT(OpKernelContext* ctx, uint64* fft_shape,
|
||||
const Tensor& in, Tensor* out) {
|
||||
// Create the axes (which are always trailing).
|
||||
const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
|
||||
auto device = ctx->eigen_device<CPUDevice>();
|
||||
auto input = Tensor(in).flat_inner_dims<RealT, FFTRank + 1>();
|
||||
const auto input_dims = input.dimensions();
|
||||
|
||||
// Slice input to fft_shape on its inner-most dimensions.
|
||||
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
|
||||
input_slice_sizes[0] = input_dims[0];
|
||||
TensorShape temp_shape{input_dims[0]};
|
||||
for (int i = 1; i <= FFTRank; ++i) {
|
||||
input_slice_sizes[i] = fft_shape[i - 1];
|
||||
temp_shape.AddDim(fft_shape[i - 1]);
|
||||
}
|
||||
|
||||
auto output = out->flat_inner_dims<ComplexT, FFTRank + 1>();
|
||||
const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
|
||||
|
||||
// Compute the full FFT using a temporary tensor.
|
||||
Tensor temp;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
|
||||
temp_shape, &temp));
|
||||
auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();
|
||||
full_fft.device(device) =
|
||||
input.slice(zero_start_indices, input_slice_sizes)
|
||||
.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
|
||||
|
||||
// Slice away the negative frequency components.
|
||||
output.device(device) =
|
||||
full_fft.slice(zero_start_indices, output.dimensions());
|
||||
}
|
||||
|
||||
template <typename ComplexT, typename RealT>
|
||||
void DoRealBackwardFFT(OpKernelContext* ctx, uint64* fft_shape,
|
||||
const Tensor& in, Tensor* out) {
|
||||
auto device = ctx->eigen_device<CPUDevice>();
|
||||
// Reconstruct the full FFT and take the inverse.
|
||||
auto input = Tensor(in).flat_inner_dims<ComplexT, FFTRank + 1>();
|
||||
auto output = out->flat_inner_dims<RealT, FFTRank + 1>();
|
||||
const auto input_dims = input.dimensions();
|
||||
|
||||
// Calculate the shape of the temporary tensor for the full FFT and the
|
||||
// region we will slice from input given fft_shape. We slice input to
|
||||
// fft_shape on its inner-most dimensions, except the last (which we
|
||||
// slice to fft_shape[-1] / 2 + 1).
|
||||
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
|
||||
input_slice_sizes[0] = input_dims[0];
|
||||
TensorShape full_fft_shape;
|
||||
full_fft_shape.AddDim(input_dims[0]);
|
||||
for (auto i = 1; i <= FFTRank; i++) {
|
||||
input_slice_sizes[i] =
|
||||
i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
|
||||
full_fft_shape.AddDim(fft_shape[i - 1]);
|
||||
}
|
||||
|
||||
Tensor temp;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
|
||||
full_fft_shape, &temp));
|
||||
auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();
|
||||
|
||||
// Calculate the starting point and range of the source of
|
||||
// negative frequency part.
|
||||
auto neg_sizes = input_slice_sizes;
|
||||
neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
|
||||
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
|
||||
neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];
|
||||
|
||||
const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
|
||||
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
|
||||
neg_start_indices[FFTRank] = 1;
|
||||
|
||||
full_fft.slice(start_indices, input_slice_sizes).device(device) =
|
||||
input.slice(start_indices, input_slice_sizes);
|
||||
|
||||
// First, conduct IFFTs on outer dimensions. We save computation (and
|
||||
// avoid touching uninitialized memory) by slicing full_fft to the
|
||||
// subregion we wrote input to.
|
||||
if (FFTRank > 1) {
|
||||
const auto outer_axes =
|
||||
Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
|
||||
full_fft.slice(start_indices, input_slice_sizes).device(device) =
|
||||
full_fft.slice(start_indices, input_slice_sizes)
|
||||
.template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
|
||||
}
|
||||
|
||||
// Reconstruct the full FFT by appending reversed and conjugated
|
||||
// spectrum as the negative frequency part.
|
||||
Eigen::array<bool, FFTRank + 1> reverse_last_axis;
|
||||
for (auto i = 0; i <= FFTRank; i++) {
|
||||
reverse_last_axis[i] = i == FFTRank;
|
||||
}
|
||||
|
||||
if (neg_sizes[FFTRank] != 0) {
|
||||
full_fft.slice(neg_target_indices, neg_sizes).device(device) =
|
||||
full_fft.slice(neg_start_indices, neg_sizes)
|
||||
.reverse(reverse_last_axis)
|
||||
.conjugate();
|
||||
}
|
||||
|
||||
auto inner_axis = Eigen::array<int, 1>{FFTRank};
|
||||
output.device(device) =
|
||||
full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_CPU), FFTCPU<true, false, 1>);
|
||||
@ -390,16 +451,19 @@ class FFTGPUBase : public FFTBase {
|
||||
}
|
||||
|
||||
constexpr bool kInPlaceFft = false;
|
||||
const bool is_complex128 = in.dtype() == DT_COMPLEX128;
|
||||
// complex128 real FFT is not supported yet.
|
||||
DCHECK(!IsReal() || !is_complex128);
|
||||
const bool is_complex128 =
|
||||
in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;
|
||||
|
||||
const auto kFftType =
|
||||
IsReal() ? (IsForward() ? se::fft::Type::kR2C : se::fft::Type::kC2R)
|
||||
: (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
|
||||
: se::fft::Type::kC2CForward)
|
||||
: (is_complex128 ? se::fft::Type::kZ2ZInverse
|
||||
: se::fft::Type::kC2CInverse));
|
||||
IsReal()
|
||||
? (IsForward()
|
||||
? (is_complex128 ? se::fft::Type::kD2Z : se::fft::Type::kR2C)
|
||||
: (is_complex128 ? se::fft::Type::kZ2D
|
||||
: se::fft::Type::kC2R))
|
||||
: (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
|
||||
: se::fft::Type::kC2CForward)
|
||||
: (is_complex128 ? se::fft::Type::kZ2ZInverse
|
||||
: se::fft::Type::kC2CInverse));
|
||||
|
||||
CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
|
||||
auto plan =
|
||||
@ -410,67 +474,80 @@ class FFTGPUBase : public FFTBase {
|
||||
|
||||
if (IsReal()) {
|
||||
if (IsForward()) {
|
||||
auto src = AsDeviceMemory<float>(in.flat<float>().data());
|
||||
auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
|
||||
OP_REQUIRES(
|
||||
ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
|
||||
errors::Internal("fft failed : type=", static_cast<int>(kFftType),
|
||||
" in.shape=", input_shape.DebugString()));
|
||||
if (is_complex128) {
|
||||
DCHECK_EQ(in.dtype(), DT_DOUBLE);
|
||||
DCHECK_EQ(out->dtype(), DT_COMPLEX128);
|
||||
DoFFTInternal<double, complex128>(ctx, stream, plan.get(), kFftType,
|
||||
output_distance, in, out);
|
||||
} else {
|
||||
DCHECK_EQ(in.dtype(), DT_FLOAT);
|
||||
DCHECK_EQ(out->dtype(), DT_COMPLEX64);
|
||||
DoFFTInternal<float, complex64>(ctx, stream, plan.get(), kFftType,
|
||||
output_distance, in, out);
|
||||
}
|
||||
} else {
|
||||
auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
|
||||
auto dst = AsDeviceMemory<float>(out->flat<float>().data());
|
||||
OP_REQUIRES(
|
||||
ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
|
||||
errors::Internal("fft failed : type=", static_cast<int>(kFftType),
|
||||
" in.shape=", input_shape.DebugString()));
|
||||
auto alpha = 1.f / output_distance;
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
|
||||
.ok(),
|
||||
errors::Internal("BlasScal failed : in.shape=",
|
||||
input_shape.DebugString()));
|
||||
if (is_complex128) {
|
||||
DCHECK_EQ(in.dtype(), DT_COMPLEX128);
|
||||
DCHECK_EQ(out->dtype(), DT_DOUBLE);
|
||||
DoFFTInternal<complex128, double>(ctx, stream, plan.get(), kFftType,
|
||||
output_distance, in, out);
|
||||
} else {
|
||||
DCHECK_EQ(in.dtype(), DT_COMPLEX64);
|
||||
DCHECK_EQ(out->dtype(), DT_FLOAT);
|
||||
DoFFTInternal<complex64, float>(ctx, stream, plan.get(), kFftType,
|
||||
output_distance, in, out);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!is_complex128) {
|
||||
DCHECK_EQ(in.dtype(), DT_COMPLEX64);
|
||||
DCHECK_EQ(out->dtype(), DT_COMPLEX64);
|
||||
auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
|
||||
auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
|
||||
OP_REQUIRES(
|
||||
ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
|
||||
errors::Internal("fft failed : type=", static_cast<int>(kFftType),
|
||||
" in.shape=", input_shape.DebugString()));
|
||||
if (!IsForward()) {
|
||||
float alpha = 1.f / output_distance;
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
|
||||
.ok(),
|
||||
errors::Internal("BlasScal failed : in.shape=",
|
||||
input_shape.DebugString()));
|
||||
}
|
||||
} else {
|
||||
if (is_complex128) {
|
||||
DCHECK_EQ(in.dtype(), DT_COMPLEX128);
|
||||
DCHECK_EQ(out->dtype(), DT_COMPLEX128);
|
||||
auto src = AsDeviceMemory<complex128>(in.flat<complex128>().data());
|
||||
auto dst = AsDeviceMemory<complex128>(out->flat<complex128>().data());
|
||||
OP_REQUIRES(
|
||||
ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
|
||||
errors::Internal("fft failed : type=", static_cast<int>(kFftType),
|
||||
" in.shape=", input_shape.DebugString()));
|
||||
if (!IsForward()) {
|
||||
double alpha = 1.0 / output_distance;
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
|
||||
.ok(),
|
||||
errors::Internal("BlasScal failed : in.shape=",
|
||||
input_shape.DebugString()));
|
||||
}
|
||||
DoFFTInternal<complex128, complex128>(ctx, stream, plan.get(), kFftType,
|
||||
output_distance, in, out);
|
||||
} else {
|
||||
DCHECK_EQ(in.dtype(), DT_COMPLEX64);
|
||||
DCHECK_EQ(out->dtype(), DT_COMPLEX64);
|
||||
DoFFTInternal<complex64, complex64>(ctx, stream, plan.get(), kFftType,
|
||||
output_distance, in, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
struct RealTypeFromComplexType {
|
||||
typedef T RealT;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct RealTypeFromComplexType<std::complex<T>> {
|
||||
typedef T RealT;
|
||||
};
|
||||
|
||||
template <typename InT, typename OutT>
|
||||
void DoFFTInternal(OpKernelContext* ctx, se::Stream* stream,
|
||||
se::fft::Plan* plan, const se::fft::Type fft_type,
|
||||
const uint64 output_distance, const Tensor& in,
|
||||
Tensor* out) {
|
||||
auto src = AsDeviceMemory<InT>(in.flat<InT>().data());
|
||||
auto dst = AsDeviceMemory<OutT>(out->flat<OutT>().data());
|
||||
const TensorShape& input_shape = in.shape();
|
||||
const TensorShape& output_shape = out->shape();
|
||||
OP_REQUIRES(
|
||||
ctx, stream->ThenFft(plan, src, &dst).ok(),
|
||||
errors::Internal("fft failed : type=", static_cast<int>(fft_type),
|
||||
" in.shape=", input_shape.DebugString()));
|
||||
if (!IsForward()) {
|
||||
typedef typename RealTypeFromComplexType<OutT>::RealT RealT;
|
||||
RealT alpha = 1.0 / output_distance;
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
|
||||
.ok(),
|
||||
errors::Internal("BlasScal failed : in.shape=",
|
||||
input_shape.DebugString()));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int64 FFTGPUBase::CufftScratchSize = GetCufftWorkspaceLimit(
|
||||
|
@ -109,39 +109,51 @@ Status RFFTShape(InferenceContext* c, const bool forward, const int rank) {
|
||||
}
|
||||
|
||||
REGISTER_OP("RFFT")
|
||||
.Input("input: float")
|
||||
.Input("input: Treal")
|
||||
.Input("fft_length: int32")
|
||||
.Output("output: complex64")
|
||||
.Output("output: Tcomplex")
|
||||
.Attr("Treal: {float32, float64} = DT_FLOAT")
|
||||
.Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
|
||||
.SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 1); });
|
||||
|
||||
REGISTER_OP("IRFFT")
|
||||
.Input("input: complex64")
|
||||
.Input("input: Tcomplex")
|
||||
.Input("fft_length: int32")
|
||||
.Output("output: float")
|
||||
.Output("output: Treal")
|
||||
.Attr("Treal: {float32, float64} = DT_FLOAT")
|
||||
.Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
|
||||
.SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 1); });
|
||||
|
||||
REGISTER_OP("RFFT2D")
|
||||
.Input("input: float")
|
||||
.Input("input: Treal")
|
||||
.Input("fft_length: int32")
|
||||
.Output("output: complex64")
|
||||
.Output("output: Tcomplex")
|
||||
.Attr("Treal: {float32, float64} = DT_FLOAT")
|
||||
.Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
|
||||
.SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 2); });
|
||||
|
||||
REGISTER_OP("IRFFT2D")
|
||||
.Input("input: complex64")
|
||||
.Input("input: Tcomplex")
|
||||
.Input("fft_length: int32")
|
||||
.Output("output: float")
|
||||
.Output("output: Treal")
|
||||
.Attr("Treal: {float32, float64} = DT_FLOAT")
|
||||
.Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
|
||||
.SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 2); });
|
||||
|
||||
REGISTER_OP("RFFT3D")
|
||||
.Input("input: float")
|
||||
.Input("input: Treal")
|
||||
.Input("fft_length: int32")
|
||||
.Output("output: complex64")
|
||||
.Output("output: Tcomplex")
|
||||
.Attr("Treal: {float32, float64} = DT_FLOAT")
|
||||
.Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
|
||||
.SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 3); });
|
||||
|
||||
REGISTER_OP("IRFFT3D")
|
||||
.Input("input: complex64")
|
||||
.Input("input: Tcomplex")
|
||||
.Input("fft_length: int32")
|
||||
.Output("output: float")
|
||||
.Output("output: Treal")
|
||||
.Attr("Treal: {float32, float64} = DT_FLOAT")
|
||||
.Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
|
||||
.SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 3); });
|
||||
|
||||
// Deprecated ops:
|
||||
|
@ -48,7 +48,6 @@ from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops.signal import fft_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import training
|
||||
from tensorflow.python.util import nest
|
||||
@ -1579,23 +1578,6 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
with self.assertRaisesRegexp(ValueError, 'ndarray'):
|
||||
g.watch(np.array(1.))
|
||||
|
||||
def testOpWithNoAttrs(self):
|
||||
|
||||
@function.defun(autograph=False)
|
||||
def f():
|
||||
with backprop.GradientTape() as tape:
|
||||
xs = random_ops.random_normal([10, 32])
|
||||
tape.watch(xs)
|
||||
# The `rfft()` op has no defined attrs, which exercises a different
|
||||
# branch in the Python op wrapper code generator for recording
|
||||
# gradients.
|
||||
ys = fft_ops.rfft(xs)
|
||||
self.assertEmpty(ys.op.node_def.attr)
|
||||
gs = tape.gradient(ys, xs)
|
||||
self.assertIsNotNone(gs)
|
||||
|
||||
f.get_concrete_function()
|
||||
|
||||
|
||||
class JacobianTest(test.TestCase):
|
||||
|
||||
|
@ -37,6 +37,7 @@ cuda_py_tests(
|
||||
|
||||
cuda_py_tests(
|
||||
name = "fft_ops_test",
|
||||
# TODO(rjryan): Parameterize the test to reduce the time it takes.
|
||||
size = "medium",
|
||||
srcs = ["fft_ops_test.py"],
|
||||
additional_deps = [
|
||||
@ -46,7 +47,7 @@ cuda_py_tests(
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/ops/signal",
|
||||
],
|
||||
shard_count = 4,
|
||||
shard_count = 8,
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"optonly",
|
||||
|
@ -18,10 +18,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
@ -36,6 +39,18 @@ from tensorflow.python.platform import test
|
||||
VALID_FFT_RANKS = (1, 2, 3)
|
||||
|
||||
|
||||
def _forward_compat_context(np_dtype):
|
||||
@contextlib.contextmanager
|
||||
def null_context():
|
||||
yield
|
||||
if np_dtype in (np.float64, np.complex128):
|
||||
return compat.forward_compatibility_horizon(2019, 10, 13)
|
||||
else:
|
||||
return null_context()
|
||||
|
||||
|
||||
# TODO(rjryan): Investigate precision issues. We should be able to achieve
|
||||
# better tolerances, at least for the complex128 tests.
|
||||
class BaseFFTOpsTest(test.TestCase):
|
||||
|
||||
def _compare(self, x, rank, fft_length=None, use_placeholder=False,
|
||||
@ -84,8 +99,9 @@ class BaseFFTOpsTest(test.TestCase):
|
||||
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
|
||||
return loss
|
||||
|
||||
((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
|
||||
gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
|
||||
with _forward_compat_context(x.dtype):
|
||||
((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
|
||||
gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
|
||||
|
||||
self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
|
||||
self.assertAllClose(y_jacob_t, y_jacob_n, rtol=rtol, atol=atol)
|
||||
@ -99,8 +115,9 @@ class BaseFFTOpsTest(test.TestCase):
|
||||
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
|
||||
return loss
|
||||
|
||||
(x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
|
||||
f, [x], delta=1e-2)
|
||||
with _forward_compat_context(x.dtype):
|
||||
(x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
|
||||
f, [x], delta=1e-2)
|
||||
self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@ -174,13 +191,12 @@ class FFTOpsTest(BaseFFTOpsTest):
|
||||
(4,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
|
||||
|
||||
def test_large_batch(self):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
rank = 1
|
||||
for dims in xrange(rank, rank + 3):
|
||||
for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-5)):
|
||||
self._compare(
|
||||
np.mod(np.arange(np.power(128, dims)), 10).reshape(
|
||||
(128,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
|
||||
rank = 1
|
||||
for dims in xrange(rank, rank + 3):
|
||||
for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 5e-5)):
|
||||
self._compare(
|
||||
np.mod(np.arange(np.power(128, dims)), 10).reshape(
|
||||
(128,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
|
||||
|
||||
# TODO(yangzihao): Disable before we can figure out a way to
|
||||
# properly test memory fail for large batch fft.
|
||||
@ -277,17 +293,15 @@ class FFTOpsTest(BaseFFTOpsTest):
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RFFTOpsTest(BaseFFTOpsTest):
|
||||
|
||||
def _compare_backward(self, x, rank, fft_length=None, use_placeholder=False):
|
||||
super(RFFTOpsTest, self)._compare_backward(x, rank, fft_length,
|
||||
use_placeholder)
|
||||
|
||||
def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
with _forward_compat_context(x.dtype), self.cached_session(
|
||||
use_gpu=True) as sess:
|
||||
return sess.run(
|
||||
self._tf_fft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
|
||||
|
||||
def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
with _forward_compat_context(x.dtype), self.cached_session(
|
||||
use_gpu=True) as sess:
|
||||
return sess.run(
|
||||
self._tf_ifft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
|
||||
|
||||
@ -332,61 +346,74 @@ class RFFTOpsTest(BaseFFTOpsTest):
|
||||
raise ValueError("invalid rank")
|
||||
|
||||
def test_empty(self):
|
||||
for rank in VALID_FFT_RANKS:
|
||||
for dims in xrange(rank, rank + 3):
|
||||
x = np.zeros((0,) * dims).astype(np.float32)
|
||||
self.assertEqual(x.shape, self._tf_fft(x, rank).shape)
|
||||
x = np.zeros((0,) * dims).astype(np.complex64)
|
||||
self.assertEqual(x.shape, self._tf_ifft(x, rank).shape)
|
||||
for np_rtype, np_ctype in ((np.float32, np.complex64),
|
||||
(np.float64, np.complex128)):
|
||||
for rank in VALID_FFT_RANKS:
|
||||
for dims in xrange(rank, rank + 3):
|
||||
x = np.zeros((0,) * dims).astype(np_rtype)
|
||||
self.assertEqual(x.shape, self._tf_fft(x, rank).shape)
|
||||
x = np.zeros((0,) * dims).astype(np_ctype)
|
||||
self.assertEqual(x.shape, self._tf_ifft(x, rank).shape)
|
||||
|
||||
def test_basic(self):
|
||||
for rank in VALID_FFT_RANKS:
|
||||
for dims in xrange(rank, rank + 3):
|
||||
for size in (5, 6):
|
||||
inner_dim = size // 2 + 1
|
||||
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
|
||||
(size,) * dims)
|
||||
self._compare_forward(r2c.astype(np.float32), rank, (size,) * rank)
|
||||
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
|
||||
10).reshape((size,) * (dims - 1) + (inner_dim,))
|
||||
self._compare_backward(
|
||||
c2r.astype(np.complex64), rank, (size,) * rank)
|
||||
for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
|
||||
(np.float64, np.complex128, 5e-5)):
|
||||
for rank in VALID_FFT_RANKS:
|
||||
for dims in xrange(rank, rank + 3):
|
||||
for size in (5, 6):
|
||||
inner_dim = size // 2 + 1
|
||||
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
|
||||
(size,) * dims)
|
||||
self._compare_forward(r2c.astype(np_rtype), rank, (size,) * rank,
|
||||
rtol=tol, atol=tol)
|
||||
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
|
||||
10).reshape((size,) * (dims - 1) + (inner_dim,))
|
||||
self._compare_backward(
|
||||
c2r.astype(np_ctype), rank, (size,) * rank, rtol=tol, atol=tol)
|
||||
|
||||
def test_large_batch(self):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
rank = 1
|
||||
rank = 1
|
||||
for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
|
||||
(np.float64, np.complex128, 1e-5)):
|
||||
for dims in xrange(rank, rank + 3):
|
||||
for size in (64, 128):
|
||||
inner_dim = size // 2 + 1
|
||||
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
|
||||
(size,) * dims)
|
||||
self._compare_forward(r2c.astype(np.float32), rank, (size,) * rank)
|
||||
self._compare_forward(r2c.astype(np_rtype), rank, (size,) * rank,
|
||||
rtol=tol, atol=tol)
|
||||
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
|
||||
10).reshape((size,) * (dims - 1) + (inner_dim,))
|
||||
self._compare_backward(c2r.astype(np.complex64), rank, (size,) * rank)
|
||||
self._compare_backward(c2r.astype(np_ctype), rank, (size,) * rank,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
def test_placeholder(self):
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
for rank in VALID_FFT_RANKS:
|
||||
for dims in xrange(rank, rank + 3):
|
||||
for size in (5, 6):
|
||||
inner_dim = size // 2 + 1
|
||||
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
|
||||
(size,) * dims)
|
||||
self._compare_forward(
|
||||
r2c.astype(np.float32),
|
||||
rank, (size,) * rank,
|
||||
use_placeholder=True)
|
||||
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
|
||||
10).reshape((size,) * (dims - 1) + (inner_dim,))
|
||||
self._compare_backward(
|
||||
c2r.astype(np.complex64),
|
||||
rank, (size,) * rank,
|
||||
use_placeholder=True)
|
||||
for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
|
||||
(np.float64, np.complex128, 1e-8)):
|
||||
for rank in VALID_FFT_RANKS:
|
||||
for dims in xrange(rank, rank + 3):
|
||||
for size in (5, 6):
|
||||
inner_dim = size // 2 + 1
|
||||
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
|
||||
(size,) * dims)
|
||||
self._compare_forward(
|
||||
r2c.astype(np_rtype),
|
||||
rank, (size,) * rank,
|
||||
use_placeholder=True,
|
||||
rtol=tol, atol=tol)
|
||||
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
|
||||
10).reshape((size,) * (dims - 1) + (inner_dim,))
|
||||
self._compare_backward(
|
||||
c2r.astype(np_ctype),
|
||||
rank, (size,) * rank,
|
||||
use_placeholder=True,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
def test_fft_lenth(self):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
def test_fft_length(self):
|
||||
for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
|
||||
(np.float64, np.complex128, 8e-5)):
|
||||
for rank in VALID_FFT_RANKS:
|
||||
for dims in xrange(rank, rank + 3):
|
||||
for size in (5, 6):
|
||||
@ -397,36 +424,44 @@ class RFFTOpsTest(BaseFFTOpsTest):
|
||||
10).reshape((size,) * (dims - 1) + (inner_dim,))
|
||||
# Test truncation (FFT size < dimensions).
|
||||
fft_length = (size - 2,) * rank
|
||||
self._compare_forward(r2c.astype(np.float32), rank, fft_length)
|
||||
self._compare_backward(c2r.astype(np.complex64), rank, fft_length)
|
||||
self._compare_forward(r2c.astype(np_rtype), rank, fft_length,
|
||||
rtol=tol, atol=tol)
|
||||
self._compare_backward(c2r.astype(np_ctype), rank, fft_length,
|
||||
rtol=tol, atol=tol)
|
||||
# Confirm it works with unknown shapes as well.
|
||||
if not context.executing_eagerly():
|
||||
self._compare_forward(
|
||||
r2c.astype(np.float32),
|
||||
r2c.astype(np_rtype),
|
||||
rank,
|
||||
fft_length,
|
||||
use_placeholder=True)
|
||||
use_placeholder=True,
|
||||
rtol=tol, atol=tol)
|
||||
self._compare_backward(
|
||||
c2r.astype(np.complex64),
|
||||
c2r.astype(np_ctype),
|
||||
rank,
|
||||
fft_length,
|
||||
use_placeholder=True)
|
||||
use_placeholder=True,
|
||||
rtol=tol, atol=tol)
|
||||
# Test padding (FFT size > dimensions).
|
||||
fft_length = (size + 2,) * rank
|
||||
self._compare_forward(r2c.astype(np.float32), rank, fft_length)
|
||||
self._compare_backward(c2r.astype(np.complex64), rank, fft_length)
|
||||
self._compare_forward(r2c.astype(np_rtype), rank, fft_length,
|
||||
rtol=tol, atol=tol)
|
||||
self._compare_backward(c2r.astype(np_ctype), rank, fft_length,
|
||||
rtol=tol, atol=tol)
|
||||
# Confirm it works with unknown shapes as well.
|
||||
if not context.executing_eagerly():
|
||||
self._compare_forward(
|
||||
r2c.astype(np.float32),
|
||||
r2c.astype(np_rtype),
|
||||
rank,
|
||||
fft_length,
|
||||
use_placeholder=True)
|
||||
use_placeholder=True,
|
||||
rtol=tol, atol=tol)
|
||||
self._compare_backward(
|
||||
c2r.astype(np.complex64),
|
||||
c2r.astype(np_ctype),
|
||||
rank,
|
||||
fft_length,
|
||||
use_placeholder=True)
|
||||
use_placeholder=True,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
def test_random(self):
|
||||
def gen_real(shape):
|
||||
@ -442,14 +477,20 @@ class RFFTOpsTest(BaseFFTOpsTest):
|
||||
ret = (re + im * 1j).reshape(shape)
|
||||
return ret
|
||||
|
||||
for rank in VALID_FFT_RANKS:
|
||||
for dims in xrange(rank, rank + 3):
|
||||
for size in (5, 6):
|
||||
inner_dim = size // 2 + 1
|
||||
self._compare_forward(gen_real((size,) * dims), rank, (size,) * rank)
|
||||
complex_dims = (size,) * (dims - 1) + (inner_dim,)
|
||||
self._compare_backward(
|
||||
gen_complex(complex_dims), rank, (size,) * rank)
|
||||
for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
|
||||
(np.float64, np.complex128, 1e-5)):
|
||||
for rank in VALID_FFT_RANKS:
|
||||
for dims in xrange(rank, rank + 3):
|
||||
for size in (5, 6):
|
||||
inner_dim = size // 2 + 1
|
||||
self._compare_forward(gen_real((size,) * dims).astype(np_rtype),
|
||||
rank, (size,) * rank,
|
||||
rtol=tol, atol=tol)
|
||||
complex_dims = (size,) * (dims - 1) + (inner_dim,)
|
||||
self._compare_backward(
|
||||
gen_complex(complex_dims).astype(np_ctype),
|
||||
rank, (size,) * rank,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
def test_error(self):
|
||||
# TODO(rjryan): Fix this test under Eager.
|
||||
@ -510,30 +551,36 @@ class RFFTOpsTest(BaseFFTOpsTest):
|
||||
self.evaluate(irfft_fn(x, fft_length))
|
||||
|
||||
def test_grad_simple(self):
|
||||
for rank in VALID_FFT_RANKS:
|
||||
# rfft3d/irfft3d do not have gradients yet.
|
||||
if rank == 3:
|
||||
continue
|
||||
for dims in xrange(rank, rank + 2):
|
||||
for size in (5, 6):
|
||||
re = np.ones(shape=(size,) * dims, dtype=np.float32)
|
||||
im = -np.ones(shape=(size,) * dims, dtype=np.float32)
|
||||
self._check_grad_real(self._tf_fft_for_rank(rank), re)
|
||||
self._check_grad_complex(
|
||||
self._tf_ifft_for_rank(rank), re, im, result_is_complex=False)
|
||||
for np_rtype, tol in ((np.float32, 1e-3), (np.float64, 1e-10)):
|
||||
for rank in VALID_FFT_RANKS:
|
||||
# rfft3d/irfft3d do not have gradients yet.
|
||||
if rank == 3:
|
||||
continue
|
||||
for dims in xrange(rank, rank + 2):
|
||||
for size in (5, 6):
|
||||
re = np.ones(shape=(size,) * dims, dtype=np_rtype)
|
||||
im = -np.ones(shape=(size,) * dims, dtype=np_rtype)
|
||||
self._check_grad_real(self._tf_fft_for_rank(rank), re,
|
||||
rtol=tol, atol=tol)
|
||||
self._check_grad_complex(
|
||||
self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
def test_grad_random(self):
|
||||
for rank in VALID_FFT_RANKS:
|
||||
# rfft3d/irfft3d do not have gradients yet.
|
||||
if rank == 3:
|
||||
continue
|
||||
for dims in xrange(rank, rank + 2):
|
||||
for size in (5, 6):
|
||||
re = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
|
||||
im = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
|
||||
self._check_grad_real(self._tf_fft_for_rank(rank), re)
|
||||
self._check_grad_complex(
|
||||
self._tf_ifft_for_rank(rank), re, im, result_is_complex=False)
|
||||
for np_rtype, tol in ((np.float32, 1e-2), (np.float64, 1e-10)):
|
||||
for rank in VALID_FFT_RANKS:
|
||||
# rfft3d/irfft3d do not have gradients yet.
|
||||
if rank == 3:
|
||||
continue
|
||||
for dims in xrange(rank, rank + 2):
|
||||
for size in (5, 6):
|
||||
re = np.random.rand(*((size,) * dims)).astype(np_rtype) * 2 - 1
|
||||
im = np.random.rand(*((size,) * dims)).astype(np_rtype) * 2 - 1
|
||||
self._check_grad_real(self._tf_fft_for_rank(rank), re,
|
||||
rtol=tol, atol=tol)
|
||||
self._check_grad_complex(
|
||||
self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
|
||||
rtol=tol, atol=tol)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import dtypes as _dtypes
|
||||
from tensorflow.python.framework import ops as _ops
|
||||
from tensorflow.python.framework import tensor_util as _tensor_util
|
||||
@ -115,14 +116,23 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
|
||||
"""Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
|
||||
with _ops.name_scope(name, default_name,
|
||||
[input_tensor, fft_length]) as name:
|
||||
input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.float32)
|
||||
input_tensor = _ops.convert_to_tensor(input_tensor,
|
||||
preferred_dtype=_dtypes.float32)
|
||||
real_dtype = input_tensor.dtype
|
||||
if real_dtype == _dtypes.float32:
|
||||
complex_dtype = _dtypes.complex64
|
||||
elif real_dtype == _dtypes.float64:
|
||||
complex_dtype = _dtypes.complex128
|
||||
input_tensor.shape.with_rank_at_least(fft_rank)
|
||||
if fft_length is None:
|
||||
fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)
|
||||
else:
|
||||
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
|
||||
input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
|
||||
return fft_fn(input_tensor, fft_length, name)
|
||||
|
||||
if not compat.forward_compatible(2019, 10, 12):
|
||||
return fft_fn(input_tensor, fft_length, name=name)
|
||||
return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype, name=name)
|
||||
_rfft.__doc__ = fft_fn.__doc__
|
||||
return _rfft
|
||||
|
||||
@ -134,15 +144,20 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
|
||||
"""Wrapper irfft* that infers fft_length argument."""
|
||||
with _ops.name_scope(name, default_name,
|
||||
[input_tensor, fft_length]) as name:
|
||||
input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.complex64)
|
||||
input_tensor = _ops.convert_to_tensor(input_tensor,
|
||||
preferred_dtype=_dtypes.complex64)
|
||||
input_tensor.shape.with_rank_at_least(fft_rank)
|
||||
complex_dtype = input_tensor.dtype
|
||||
real_dtype = complex_dtype.real_dtype
|
||||
if fft_length is None:
|
||||
fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)
|
||||
else:
|
||||
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
|
||||
input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
|
||||
is_reverse=True)
|
||||
return ifft_fn(input_tensor, fft_length, name)
|
||||
if not compat.forward_compatible(2019, 10, 12):
|
||||
return ifft_fn(input_tensor, fft_length, name=name)
|
||||
return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name)
|
||||
_irfft.__doc__ = ifft_fn.__doc__
|
||||
return _irfft
|
||||
|
||||
@ -223,8 +238,10 @@ def _rfft_grad_helper(rank, irfft_fn):
|
||||
def _grad(op, grad):
|
||||
"""A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
|
||||
fft_length = op.inputs[1]
|
||||
complex_dtype = grad.dtype
|
||||
real_dtype = complex_dtype.real_dtype
|
||||
input_shape = _array_ops.shape(op.inputs[0])
|
||||
is_even = _math_ops.cast(1 - (fft_length[-1] % 2), _dtypes.complex64)
|
||||
is_even = _math_ops.cast(1 - (fft_length[-1] % 2), complex_dtype)
|
||||
|
||||
def _tile_for_broadcasting(matrix, t):
|
||||
expanded = _array_ops.reshape(
|
||||
@ -248,13 +265,13 @@ def _rfft_grad_helper(rank, irfft_fn):
|
||||
_array_ops.expand_dims(_math_ops.range(length), 0), (length, 1))
|
||||
b = _array_ops.transpose(a, [1, 0])
|
||||
return _math_ops.exp(
|
||||
-2j * np.pi * _math_ops.cast(a * b, _dtypes.complex64) /
|
||||
_math_ops.cast(length, _dtypes.complex64))
|
||||
-2j * np.pi * _math_ops.cast(a * b, complex_dtype) /
|
||||
_math_ops.cast(length, complex_dtype))
|
||||
|
||||
def _ymask(length):
|
||||
"""A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
|
||||
return _math_ops.cast(1 - 2 * (_math_ops.range(length) % 2),
|
||||
_dtypes.complex64)
|
||||
complex_dtype)
|
||||
|
||||
y0 = grad[..., 0:1]
|
||||
if rank == 1:
|
||||
@ -288,7 +305,7 @@ def _rfft_grad_helper(rank, irfft_fn):
|
||||
# factor, plus some additional terms to make up for the components dropped
|
||||
# due to Hermitian symmetry.
|
||||
input_size = _math_ops.cast(
|
||||
_fft_size_for_grad(op.inputs[0], rank), _dtypes.float32)
|
||||
_fft_size_for_grad(op.inputs[0], rank), real_dtype)
|
||||
the_irfft = irfft_fn(grad, fft_length)
|
||||
return 0.5 * (the_irfft * input_size + _math_ops.real(extra_terms)), None
|
||||
|
||||
@ -307,21 +324,27 @@ def _irfft_grad_helper(rank, rfft_fn):
|
||||
# graph we special-case the situation where the FFT length and last
|
||||
# dimension of the input are known at graph construction time.
|
||||
fft_length = op.inputs[1]
|
||||
real_dtype = grad.dtype
|
||||
if real_dtype == _dtypes.float32:
|
||||
complex_dtype = _dtypes.complex64
|
||||
elif real_dtype == _dtypes.float64:
|
||||
complex_dtype = _dtypes.complex128
|
||||
is_odd = _math_ops.mod(fft_length[-1], 2)
|
||||
input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
|
||||
mask = _array_ops.concat(
|
||||
[[1.0], 2.0 * _array_ops.ones([input_last_dimension - 2 + is_odd]),
|
||||
_array_ops.ones([1 - is_odd])], 0)
|
||||
[[1.0], 2.0 * _array_ops.ones(
|
||||
[input_last_dimension - 2 + is_odd], real_dtype),
|
||||
_array_ops.ones([1 - is_odd], real_dtype)], 0)
|
||||
|
||||
rsize = _math_ops.reciprocal(_math_ops.cast(
|
||||
_fft_size_for_grad(grad, rank), _dtypes.float32))
|
||||
_fft_size_for_grad(grad, rank), real_dtype))
|
||||
|
||||
# The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
|
||||
# factor and a mask. The mask scales the gradient for the Hermitian
|
||||
# symmetric components of the RFFT by a factor of two, since these
|
||||
# components are de-duplicated in the RFFT.
|
||||
the_rfft = rfft_fn(grad, fft_length)
|
||||
return the_rfft * _math_ops.cast(rsize * mask, _dtypes.complex64), None
|
||||
return the_rfft * _math_ops.cast(rsize * mask, complex_dtype), None
|
||||
|
||||
return _grad
|
||||
|
||||
|
@ -1626,15 +1626,15 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "IRFFT"
|
||||
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "IRFFT2D"
|
||||
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "IRFFT3D"
|
||||
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Identity"
|
||||
@ -2858,15 +2858,15 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "RFFT"
|
||||
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RFFT2D"
|
||||
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RFFT3D"
|
||||
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RGBToHSV"
|
||||
|
@ -1626,15 +1626,15 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "IRFFT"
|
||||
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "IRFFT2D"
|
||||
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "IRFFT3D"
|
||||
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Identity"
|
||||
@ -2858,15 +2858,15 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "RFFT"
|
||||
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RFFT2D"
|
||||
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RFFT3D"
|
||||
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RGBToHSV"
|
||||
|
Loading…
Reference in New Issue
Block a user