Add complex128 support to RFFT, RFFT2D, RFFT3D, IRFFT, IRFFT2D, and IRFFT3D.

Finishes support requested in:
- 
- 
- https://stackoverflow.com/questions/47214508/the-result-of-fft-in-tensorflow-is-different-from-numpy

PiperOrigin-RevId: 268772775
Authored by RJ Skerry-Ryan on 2019-09-12 14:55:31 -07:00; committed by TensorFlower Gardener
parent 731984bfd0
commit 4a3cbea5f3
14 changed files with 492 additions and 331 deletions
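
Usage sketch (illustrative, not part of the diff; assumes a TensorFlow build that includes this change and, for the Python wrappers, that the forward-compatibility window of 2019-10-12 has passed): a float64 input to tf.signal.rfft should yield a complex128 spectrum, and tf.signal.irfft on a complex128 spectrum should return float64.

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(16), dtype=tf.float64)
spectrum = tf.signal.rfft(x)           # expected dtype: complex128, shape [9]
recovered = tf.signal.irfft(spectrum)  # expected dtype: float64, shape [16]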

View File

@ -138,10 +138,20 @@ class RFFTOp : public GenericFftOp {
explicit RFFTOp(OpKernelConstruction* ctx)
: GenericFftOp(ctx, /*fft_type=*/FftType::RFFT, /*fft_rank=*/FFTRank) {}
};
REGISTER_XLA_OP(Name("RFFT")
                    .TypeConstraint("Treal", DT_FLOAT)
                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
                    .CompileTimeConstantInput("fft_length"),
                RFFTOp<1>);
REGISTER_XLA_OP(Name("RFFT2D")
                    .TypeConstraint("Treal", DT_FLOAT)
                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
                    .CompileTimeConstantInput("fft_length"),
                RFFTOp<2>);
REGISTER_XLA_OP(Name("RFFT3D")
                    .TypeConstraint("Treal", DT_FLOAT)
                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
                    .CompileTimeConstantInput("fft_length"),
                RFFTOp<3>);
template <int FFTRank>
@ -150,11 +160,20 @@ class IRFFTOp : public GenericFftOp {
explicit IRFFTOp(OpKernelConstruction* ctx)
: GenericFftOp(ctx, /*fft_type=*/FftType::IRFFT, /*fft_rank=*/FFTRank) {}
};
REGISTER_XLA_OP(Name("IRFFT")
                    .TypeConstraint("Treal", DT_FLOAT)
                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
                    .CompileTimeConstantInput("fft_length"),
                IRFFTOp<1>);
REGISTER_XLA_OP(Name("IRFFT2D")
                    .TypeConstraint("Treal", DT_FLOAT)
                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
                    .CompileTimeConstantInput("fft_length"),
                IRFFTOp<2>);
REGISTER_XLA_OP(Name("IRFFT3D")
                    .TypeConstraint("Treal", DT_FLOAT)
                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
                    .CompileTimeConstantInput("fft_length"),
                IRFFTOp<3>);
} // namespace

View File

@ -3,13 +3,13 @@ op {
in_arg {
name: "input"
description: <<END
A complex64 tensor.
A complex tensor.
END
}
out_arg {
name: "output"
description: <<END
A complex64 tensor of the same shape as `input`. The inner-most 3
A complex tensor of the same shape as `input`. The inner-most 3
dimensions of `input` are replaced with their 3D Fourier transform.
@compatibility(numpy)

View File

@ -3,13 +3,13 @@ op {
in_arg {
name: "input"
description: <<END
A complex64 tensor.
A complex tensor.
END
}
out_arg {
name: "output"
description: <<END
A complex64 tensor of the same shape as `input`. The inner-most 3
A complex tensor of the same shape as `input`. The inner-most 3
dimensions of `input` are replaced with their inverse 3D Fourier transform.
@compatibility(numpy)

View File

@ -3,7 +3,7 @@ op {
in_arg {
name: "input"
description: <<END
A complex64 tensor.
A complex tensor.
END
}
in_arg {

View File

@ -3,7 +3,7 @@ op {
in_arg {
name: "input"
description: <<END
A complex64 tensor.
A complex tensor.
END
}
in_arg {

View File

@ -3,7 +3,7 @@ op {
in_arg {
name: "input"
description: <<END
A complex64 tensor.
A complex tensor.
END
}
in_arg {

View File

@ -15,7 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
// See docs in ../ops/spectral_ops.cc.
// See docs in ../ops/fft_ops.cc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
@ -94,7 +95,34 @@ class FFTBase : public OpKernel {
}
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
if (IsReal()) {
if (IsForward()) {
OP_REQUIRES(
ctx,
(in.dtype() == DT_FLOAT && out->dtype() == DT_COMPLEX64) ||
(in.dtype() == DT_DOUBLE && out->dtype() == DT_COMPLEX128),
errors::InvalidArgument("Wrong types for forward real FFT: in=",
in.dtype(), " out=", out->dtype()));
} else {
OP_REQUIRES(
ctx,
(in.dtype() == DT_COMPLEX64 && out->dtype() == DT_FLOAT) ||
(in.dtype() == DT_COMPLEX128 && out->dtype() == DT_DOUBLE),
errors::InvalidArgument("Wrong types for backward real FFT: in=",
in.dtype(), " out=", out->dtype()));
}
} else {
OP_REQUIRES(
ctx,
(in.dtype() == DT_COMPLEX64 && out->dtype() == DT_COMPLEX64) ||
(in.dtype() == DT_COMPLEX128 && out->dtype() == DT_COMPLEX128),
errors::InvalidArgument("Wrong types for FFT: in=", in.dtype(),
" out=", out->dtype()));
}
if (input_shape.num_elements() == 0) {
DCHECK_EQ(0, output_shape.num_elements());
return;
}
@ -129,126 +157,159 @@ class FFTCPU : public FFTBase {
const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
auto device = ctx->eigen_device<CPUDevice>();
const bool is_complex128 =
in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;
if (!IsReal()) {
// Compute the FFT using Eigen.
constexpr auto direction =
Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
if (is_complex128) {
  DCHECK_EQ(in.dtype(), DT_COMPLEX128);
  DCHECK_EQ(out->dtype(), DT_COMPLEX128);
  auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
  auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
  output.device(device) =
      input.template fft<Eigen::BothParts, direction>(axes);
} else {
  DCHECK_EQ(in.dtype(), DT_COMPLEX64);
  DCHECK_EQ(out->dtype(), DT_COMPLEX64);
  auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
  auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
  output.device(device) =
      input.template fft<Eigen::BothParts, direction>(axes);
}
} else {
  if (IsForward()) {
    if (is_complex128) {
      DCHECK_EQ(in.dtype(), DT_DOUBLE);
      DCHECK_EQ(out->dtype(), DT_COMPLEX128);
      DoRealForwardFFT<double, complex128>(ctx, fft_shape, in, out);
    } else {
      DCHECK_EQ(in.dtype(), DT_FLOAT);
      DCHECK_EQ(out->dtype(), DT_COMPLEX64);
      DoRealForwardFFT<float, complex64>(ctx, fft_shape, in, out);
    }
  } else {
    if (is_complex128) {
      DCHECK_EQ(in.dtype(), DT_COMPLEX128);
      DCHECK_EQ(out->dtype(), DT_DOUBLE);
      DoRealBackwardFFT<complex128, double>(ctx, fft_shape, in, out);
    } else {
      DCHECK_EQ(in.dtype(), DT_COMPLEX64);
      DCHECK_EQ(out->dtype(), DT_FLOAT);
      DoRealBackwardFFT<complex64, float>(ctx, fft_shape, in, out);
    }
  }
}
}
template <typename RealT, typename ComplexT>
void DoRealForwardFFT(OpKernelContext* ctx, uint64* fft_shape,
const Tensor& in, Tensor* out) {
// Create the axes (which are always trailing).
const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
auto device = ctx->eigen_device<CPUDevice>();
auto input = Tensor(in).flat_inner_dims<RealT, FFTRank + 1>();
const auto input_dims = input.dimensions();
// Slice input to fft_shape on its inner-most dimensions.
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
input_slice_sizes[0] = input_dims[0];
TensorShape temp_shape{input_dims[0]};
for (int i = 1; i <= FFTRank; ++i) {
input_slice_sizes[i] = fft_shape[i - 1];
temp_shape.AddDim(fft_shape[i - 1]);
}
auto output = out->flat_inner_dims<ComplexT, FFTRank + 1>();
const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
// Compute the full FFT using a temporary tensor.
Tensor temp;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
temp_shape, &temp));
auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();
full_fft.device(device) =
input.slice(zero_start_indices, input_slice_sizes)
.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
// Slice away the negative frequency components.
output.device(device) =
full_fft.slice(zero_start_indices, output.dimensions());
}
template <typename ComplexT, typename RealT>
void DoRealBackwardFFT(OpKernelContext* ctx, uint64* fft_shape,
const Tensor& in, Tensor* out) {
auto device = ctx->eigen_device<CPUDevice>();
// Reconstruct the full FFT and take the inverse.
auto input = Tensor(in).flat_inner_dims<ComplexT, FFTRank + 1>();
auto output = out->flat_inner_dims<RealT, FFTRank + 1>();
const auto input_dims = input.dimensions();
// Calculate the shape of the temporary tensor for the full FFT and the
// region we will slice from input given fft_shape. We slice input to
// fft_shape on its inner-most dimensions, except the last (which we
// slice to fft_shape[-1] / 2 + 1).
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
input_slice_sizes[0] = input_dims[0];
TensorShape full_fft_shape;
full_fft_shape.AddDim(input_dims[0]);
for (auto i = 1; i <= FFTRank; i++) {
input_slice_sizes[i] =
i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
full_fft_shape.AddDim(fft_shape[i - 1]);
}
Tensor temp;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
full_fft_shape, &temp));
auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();
// Calculate the starting point and range of the source of
// negative frequency part.
auto neg_sizes = input_slice_sizes;
neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];
const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
neg_start_indices[FFTRank] = 1;
full_fft.slice(start_indices, input_slice_sizes).device(device) =
input.slice(start_indices, input_slice_sizes);
// First, conduct IFFTs on outer dimensions. We save computation (and
// avoid touching uninitialized memory) by slicing full_fft to the
// subregion we wrote input to.
if (FFTRank > 1) {
const auto outer_axes =
Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
full_fft.slice(start_indices, input_slice_sizes).device(device) =
full_fft.slice(start_indices, input_slice_sizes)
.template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
}
// Reconstruct the full FFT by appending reversed and conjugated
// spectrum as the negative frequency part.
Eigen::array<bool, FFTRank + 1> reverse_last_axis;
for (auto i = 0; i <= FFTRank; i++) {
reverse_last_axis[i] = i == FFTRank;
}
if (neg_sizes[FFTRank] != 0) {
full_fft.slice(neg_target_indices, neg_sizes).device(device) =
full_fft.slice(neg_start_indices, neg_sizes)
.reverse(reverse_last_axis)
.conjugate();
}
auto inner_axis = Eigen::array<int, 1>{FFTRank};
output.device(device) =
full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
}
};
REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_CPU), FFTCPU<true, false, 1>);
@ -390,16 +451,19 @@ class FFTGPUBase : public FFTBase {
}
constexpr bool kInPlaceFft = false;
const bool is_complex128 =
    in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;
const auto kFftType =
    IsReal() ? (IsForward()
                    ? (is_complex128 ? se::fft::Type::kD2Z : se::fft::Type::kR2C)
                    : (is_complex128 ? se::fft::Type::kZ2D : se::fft::Type::kC2R))
             : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
                                             : se::fft::Type::kC2CForward)
                            : (is_complex128 ? se::fft::Type::kZ2ZInverse
                                             : se::fft::Type::kC2CInverse));
CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
auto plan =
@ -410,67 +474,80 @@ class FFTGPUBase : public FFTBase {
if (IsReal()) {
  if (IsForward()) {
    if (is_complex128) {
      DCHECK_EQ(in.dtype(), DT_DOUBLE);
      DCHECK_EQ(out->dtype(), DT_COMPLEX128);
      DoFFTInternal<double, complex128>(ctx, stream, plan.get(), kFftType,
                                        output_distance, in, out);
    } else {
      DCHECK_EQ(in.dtype(), DT_FLOAT);
      DCHECK_EQ(out->dtype(), DT_COMPLEX64);
      DoFFTInternal<float, complex64>(ctx, stream, plan.get(), kFftType,
                                      output_distance, in, out);
    }
  } else {
    if (is_complex128) {
      DCHECK_EQ(in.dtype(), DT_COMPLEX128);
      DCHECK_EQ(out->dtype(), DT_DOUBLE);
      DoFFTInternal<complex128, double>(ctx, stream, plan.get(), kFftType,
                                        output_distance, in, out);
    } else {
      DCHECK_EQ(in.dtype(), DT_COMPLEX64);
      DCHECK_EQ(out->dtype(), DT_FLOAT);
      DoFFTInternal<complex64, float>(ctx, stream, plan.get(), kFftType,
                                      output_distance, in, out);
    }
  }
} else {
  if (is_complex128) {
    DCHECK_EQ(in.dtype(), DT_COMPLEX128);
    DCHECK_EQ(out->dtype(), DT_COMPLEX128);
    DoFFTInternal<complex128, complex128>(ctx, stream, plan.get(), kFftType,
                                          output_distance, in, out);
  } else {
    DCHECK_EQ(in.dtype(), DT_COMPLEX64);
    DCHECK_EQ(out->dtype(), DT_COMPLEX64);
    DoFFTInternal<complex64, complex64>(ctx, stream, plan.get(), kFftType,
                                        output_distance, in, out);
  }
}
}
private:
template <typename T>
struct RealTypeFromComplexType {
typedef T RealT;
};
template <typename T>
struct RealTypeFromComplexType<std::complex<T>> {
typedef T RealT;
};
template <typename InT, typename OutT>
void DoFFTInternal(OpKernelContext* ctx, se::Stream* stream,
se::fft::Plan* plan, const se::fft::Type fft_type,
const uint64 output_distance, const Tensor& in,
Tensor* out) {
auto src = AsDeviceMemory<InT>(in.flat<InT>().data());
auto dst = AsDeviceMemory<OutT>(out->flat<OutT>().data());
const TensorShape& input_shape = in.shape();
const TensorShape& output_shape = out->shape();
OP_REQUIRES(
ctx, stream->ThenFft(plan, src, &dst).ok(),
errors::Internal("fft failed : type=", static_cast<int>(fft_type),
" in.shape=", input_shape.DebugString()));
if (!IsForward()) {
typedef typename RealTypeFromComplexType<OutT>::RealT RealT;
RealT alpha = 1.0 / output_distance;
OP_REQUIRES(
ctx,
stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
.ok(),
errors::Internal("BlasScal failed : in.shape=",
input_shape.DebugString()));
}
}
};
int64 FFTGPUBase::CufftScratchSize = GetCufftWorkspaceLimit(

View File

@ -109,39 +109,51 @@ Status RFFTShape(InferenceContext* c, const bool forward, const int rank) {
}
REGISTER_OP("RFFT")
    .Input("input: Treal")
    .Input("fft_length: int32")
    .Output("output: Tcomplex")
    .Attr("Treal: {float32, float64} = DT_FLOAT")
    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 1); });

REGISTER_OP("IRFFT")
    .Input("input: Tcomplex")
    .Input("fft_length: int32")
    .Output("output: Treal")
    .Attr("Treal: {float32, float64} = DT_FLOAT")
    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 1); });

REGISTER_OP("RFFT2D")
    .Input("input: Treal")
    .Input("fft_length: int32")
    .Output("output: Tcomplex")
    .Attr("Treal: {float32, float64} = DT_FLOAT")
    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 2); });

REGISTER_OP("IRFFT2D")
    .Input("input: Tcomplex")
    .Input("fft_length: int32")
    .Output("output: Treal")
    .Attr("Treal: {float32, float64} = DT_FLOAT")
    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 2); });

REGISTER_OP("RFFT3D")
    .Input("input: Treal")
    .Input("fft_length: int32")
    .Output("output: Tcomplex")
    .Attr("Treal: {float32, float64} = DT_FLOAT")
    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 3); });

REGISTER_OP("IRFFT3D")
    .Input("input: Tcomplex")
    .Input("fft_length: int32")
    .Output("output: Treal")
    .Attr("Treal: {float32, float64} = DT_FLOAT")
    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 3); });
// Deprecated ops:

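How the new attrs surface at the op level (a sketch, not part of the diff; the tf.raw_ops argument names are taken from the API goldens below, and the behavior assumes this change is active): Treal and Tcomplex default to float32 and complex64, so existing GraphDefs keep their meaning, while callers can request the double-precision path explicitly.

import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0, 4.0], dtype=tf.float64)
# Tcomplex selects the output precision of the forward real FFT.
spectrum = tf.raw_ops.RFFT(input=x, fft_length=[4], Tcomplex=tf.complex128)
# Treal selects the output precision of the inverse real FFT.
signal = tf.raw_ops.IRFFT(input=spectrum, fft_length=[4], Treal=tf.float64)
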
View File

@ -48,7 +48,6 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import training
from tensorflow.python.util import nest
@ -1579,23 +1578,6 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(ValueError, 'ndarray'):
g.watch(np.array(1.))
def testOpWithNoAttrs(self):
@function.defun(autograph=False)
def f():
with backprop.GradientTape() as tape:
xs = random_ops.random_normal([10, 32])
tape.watch(xs)
# The `rfft()` op has no defined attrs, which exercises a different
# branch in the Python op wrapper code generator for recording
# gradients.
ys = fft_ops.rfft(xs)
self.assertEmpty(ys.op.node_def.attr)
gs = tape.gradient(ys, xs)
self.assertIsNotNone(gs)
f.get_concrete_function()
class JacobianTest(test.TestCase):

View File

@ -37,6 +37,7 @@ cuda_py_tests(
cuda_py_tests(
name = "fft_ops_test",
# TODO(rjryan): Parameterize the test to reduce the time it takes.
size = "medium",
srcs = ["fft_ops_test.py"],
additional_deps = [
@ -46,7 +47,7 @@ cuda_py_tests(
"//tensorflow/python:math_ops",
"//tensorflow/python/ops/signal",
],
shard_count = 4,
shard_count = 8,
tags = [
"no_rocm",
"optonly",

View File

@ -18,10 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -36,6 +39,18 @@ from tensorflow.python.platform import test
VALID_FFT_RANKS = (1, 2, 3)
def _forward_compat_context(np_dtype):
@contextlib.contextmanager
def null_context():
yield
if np_dtype in (np.float64, np.complex128):
return compat.forward_compatibility_horizon(2019, 10, 13)
else:
return null_context()
# TODO(rjryan): Investigate precision issues. We should be able to achieve
# better tolerances, at least for the complex128 tests.
class BaseFFTOpsTest(test.TestCase):
def _compare(self, x, rank, fft_length=None, use_placeholder=False,
@ -84,8 +99,9 @@ class BaseFFTOpsTest(test.TestCase):
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
return loss
((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
with _forward_compat_context(x.dtype):
((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
self.assertAllClose(y_jacob_t, y_jacob_n, rtol=rtol, atol=atol)
@ -99,8 +115,9 @@ class BaseFFTOpsTest(test.TestCase):
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
return loss
(x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
f, [x], delta=1e-2)
with _forward_compat_context(x.dtype):
(x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
f, [x], delta=1e-2)
self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
@ -174,13 +191,12 @@ class FFTOpsTest(BaseFFTOpsTest):
(4,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
def test_large_batch(self):
if test.is_gpu_available(cuda_only=True):
rank = 1
for dims in xrange(rank, rank + 3):
for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-5)):
self._compare(
np.mod(np.arange(np.power(128, dims)), 10).reshape(
(128,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
rank = 1
for dims in xrange(rank, rank + 3):
for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 5e-5)):
self._compare(
np.mod(np.arange(np.power(128, dims)), 10).reshape(
(128,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
# TODO(yangzihao): Disable before we can figure out a way to
# properly test memory fail for large batch fft.
@ -277,17 +293,15 @@ class FFTOpsTest(BaseFFTOpsTest):
@test_util.run_all_in_graph_and_eager_modes
class RFFTOpsTest(BaseFFTOpsTest):
def _compare_backward(self, x, rank, fft_length=None, use_placeholder=False):
super(RFFTOpsTest, self)._compare_backward(x, rank, fft_length,
use_placeholder)
def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
with self.cached_session(use_gpu=True) as sess:
with _forward_compat_context(x.dtype), self.cached_session(
use_gpu=True) as sess:
return sess.run(
self._tf_fft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
with self.cached_session(use_gpu=True) as sess:
with _forward_compat_context(x.dtype), self.cached_session(
use_gpu=True) as sess:
return sess.run(
self._tf_ifft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
@ -332,61 +346,74 @@ class RFFTOpsTest(BaseFFTOpsTest):
raise ValueError("invalid rank")
def test_empty(self):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
x = np.zeros((0,) * dims).astype(np.float32)
self.assertEqual(x.shape, self._tf_fft(x, rank).shape)
x = np.zeros((0,) * dims).astype(np.complex64)
self.assertEqual(x.shape, self._tf_ifft(x, rank).shape)
for np_rtype, np_ctype in ((np.float32, np.complex64),
(np.float64, np.complex128)):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
x = np.zeros((0,) * dims).astype(np_rtype)
self.assertEqual(x.shape, self._tf_fft(x, rank).shape)
x = np.zeros((0,) * dims).astype(np_ctype)
self.assertEqual(x.shape, self._tf_ifft(x, rank).shape)
def test_basic(self):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
for size in (5, 6):
inner_dim = size // 2 + 1
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
(size,) * dims)
self._compare_forward(r2c.astype(np.float32), rank, (size,) * rank)
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
10).reshape((size,) * (dims - 1) + (inner_dim,))
self._compare_backward(
c2r.astype(np.complex64), rank, (size,) * rank)
for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
(np.float64, np.complex128, 5e-5)):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
for size in (5, 6):
inner_dim = size // 2 + 1
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
(size,) * dims)
self._compare_forward(r2c.astype(np_rtype), rank, (size,) * rank,
rtol=tol, atol=tol)
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
10).reshape((size,) * (dims - 1) + (inner_dim,))
self._compare_backward(
c2r.astype(np_ctype), rank, (size,) * rank, rtol=tol, atol=tol)
def test_large_batch(self):
if test.is_gpu_available(cuda_only=True):
rank = 1
rank = 1
for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
(np.float64, np.complex128, 1e-5)):
for dims in xrange(rank, rank + 3):
for size in (64, 128):
inner_dim = size // 2 + 1
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
(size,) * dims)
self._compare_forward(r2c.astype(np.float32), rank, (size,) * rank)
self._compare_forward(r2c.astype(np_rtype), rank, (size,) * rank,
rtol=tol, atol=tol)
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
10).reshape((size,) * (dims - 1) + (inner_dim,))
self._compare_backward(c2r.astype(np.complex64), rank, (size,) * rank)
self._compare_backward(c2r.astype(np_ctype), rank, (size,) * rank,
rtol=tol, atol=tol)
def test_placeholder(self):
if context.executing_eagerly():
return
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
for size in (5, 6):
inner_dim = size // 2 + 1
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
(size,) * dims)
self._compare_forward(
r2c.astype(np.float32),
rank, (size,) * rank,
use_placeholder=True)
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
10).reshape((size,) * (dims - 1) + (inner_dim,))
self._compare_backward(
c2r.astype(np.complex64),
rank, (size,) * rank,
use_placeholder=True)
for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
(np.float64, np.complex128, 1e-8)):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
for size in (5, 6):
inner_dim = size // 2 + 1
r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
(size,) * dims)
self._compare_forward(
r2c.astype(np_rtype),
rank, (size,) * rank,
use_placeholder=True,
rtol=tol, atol=tol)
c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
10).reshape((size,) * (dims - 1) + (inner_dim,))
self._compare_backward(
c2r.astype(np_ctype),
rank, (size,) * rank,
use_placeholder=True,
rtol=tol, atol=tol)
def test_fft_lenth(self):
if test.is_gpu_available(cuda_only=True):
def test_fft_length(self):
for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
(np.float64, np.complex128, 8e-5)):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
for size in (5, 6):
@ -397,36 +424,44 @@ class RFFTOpsTest(BaseFFTOpsTest):
10).reshape((size,) * (dims - 1) + (inner_dim,))
# Test truncation (FFT size < dimensions).
fft_length = (size - 2,) * rank
self._compare_forward(r2c.astype(np.float32), rank, fft_length)
self._compare_backward(c2r.astype(np.complex64), rank, fft_length)
self._compare_forward(r2c.astype(np_rtype), rank, fft_length,
rtol=tol, atol=tol)
self._compare_backward(c2r.astype(np_ctype), rank, fft_length,
rtol=tol, atol=tol)
# Confirm it works with unknown shapes as well.
if not context.executing_eagerly():
self._compare_forward(
r2c.astype(np.float32),
r2c.astype(np_rtype),
rank,
fft_length,
use_placeholder=True)
use_placeholder=True,
rtol=tol, atol=tol)
self._compare_backward(
c2r.astype(np.complex64),
c2r.astype(np_ctype),
rank,
fft_length,
use_placeholder=True)
use_placeholder=True,
rtol=tol, atol=tol)
# Test padding (FFT size > dimensions).
fft_length = (size + 2,) * rank
self._compare_forward(r2c.astype(np.float32), rank, fft_length)
self._compare_backward(c2r.astype(np.complex64), rank, fft_length)
self._compare_forward(r2c.astype(np_rtype), rank, fft_length,
rtol=tol, atol=tol)
self._compare_backward(c2r.astype(np_ctype), rank, fft_length,
rtol=tol, atol=tol)
# Confirm it works with unknown shapes as well.
if not context.executing_eagerly():
self._compare_forward(
r2c.astype(np.float32),
r2c.astype(np_rtype),
rank,
fft_length,
use_placeholder=True)
use_placeholder=True,
rtol=tol, atol=tol)
self._compare_backward(
c2r.astype(np.complex64),
c2r.astype(np_ctype),
rank,
fft_length,
use_placeholder=True)
use_placeholder=True,
rtol=tol, atol=tol)
def test_random(self):
def gen_real(shape):
@ -442,14 +477,20 @@ class RFFTOpsTest(BaseFFTOpsTest):
ret = (re + im * 1j).reshape(shape)
return ret
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
for size in (5, 6):
inner_dim = size // 2 + 1
self._compare_forward(gen_real((size,) * dims), rank, (size,) * rank)
complex_dims = (size,) * (dims - 1) + (inner_dim,)
self._compare_backward(
gen_complex(complex_dims), rank, (size,) * rank)
for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
(np.float64, np.complex128, 1e-5)):
for rank in VALID_FFT_RANKS:
for dims in xrange(rank, rank + 3):
for size in (5, 6):
inner_dim = size // 2 + 1
self._compare_forward(gen_real((size,) * dims).astype(np_rtype),
rank, (size,) * rank,
rtol=tol, atol=tol)
complex_dims = (size,) * (dims - 1) + (inner_dim,)
self._compare_backward(
gen_complex(complex_dims).astype(np_ctype),
rank, (size,) * rank,
rtol=tol, atol=tol)
def test_error(self):
# TODO(rjryan): Fix this test under Eager.
@ -510,30 +551,36 @@ class RFFTOpsTest(BaseFFTOpsTest):
self.evaluate(irfft_fn(x, fft_length))
def test_grad_simple(self):
for rank in VALID_FFT_RANKS:
# rfft3d/irfft3d do not have gradients yet.
if rank == 3:
continue
for dims in xrange(rank, rank + 2):
for size in (5, 6):
re = np.ones(shape=(size,) * dims, dtype=np.float32)
im = -np.ones(shape=(size,) * dims, dtype=np.float32)
self._check_grad_real(self._tf_fft_for_rank(rank), re)
self._check_grad_complex(
self._tf_ifft_for_rank(rank), re, im, result_is_complex=False)
for np_rtype, tol in ((np.float32, 1e-3), (np.float64, 1e-10)):
for rank in VALID_FFT_RANKS:
# rfft3d/irfft3d do not have gradients yet.
if rank == 3:
continue
for dims in xrange(rank, rank + 2):
for size in (5, 6):
re = np.ones(shape=(size,) * dims, dtype=np_rtype)
im = -np.ones(shape=(size,) * dims, dtype=np_rtype)
self._check_grad_real(self._tf_fft_for_rank(rank), re,
rtol=tol, atol=tol)
self._check_grad_complex(
self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
rtol=tol, atol=tol)
def test_grad_random(self):
for rank in VALID_FFT_RANKS:
# rfft3d/irfft3d do not have gradients yet.
if rank == 3:
continue
for dims in xrange(rank, rank + 2):
for size in (5, 6):
re = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
im = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
self._check_grad_real(self._tf_fft_for_rank(rank), re)
self._check_grad_complex(
self._tf_ifft_for_rank(rank), re, im, result_is_complex=False)
for np_rtype, tol in ((np.float32, 1e-2), (np.float64, 1e-10)):
for rank in VALID_FFT_RANKS:
# rfft3d/irfft3d do not have gradients yet.
if rank == 3:
continue
for dims in xrange(rank, rank + 2):
for size in (5, 6):
re = np.random.rand(*((size,) * dims)).astype(np_rtype) * 2 - 1
im = np.random.rand(*((size,) * dims)).astype(np_rtype) * 2 - 1
self._check_grad_real(self._tf_fft_for_rank(rank), re,
rtol=tol, atol=tol)
self._check_grad_complex(
self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
rtol=tol, atol=tol)
@test_util.run_all_in_graph_and_eager_modes

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import tensor_util as _tensor_util
@ -115,14 +116,23 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
"""Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
with _ops.name_scope(name, default_name,
[input_tensor, fft_length]) as name:
input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.float32)
input_tensor = _ops.convert_to_tensor(input_tensor,
preferred_dtype=_dtypes.float32)
real_dtype = input_tensor.dtype
if real_dtype == _dtypes.float32:
complex_dtype = _dtypes.complex64
elif real_dtype == _dtypes.float64:
complex_dtype = _dtypes.complex128
input_tensor.shape.with_rank_at_least(fft_rank)
if fft_length is None:
fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)
else:
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
return fft_fn(input_tensor, fft_length, name)
if not compat.forward_compatible(2019, 10, 12):
return fft_fn(input_tensor, fft_length, name=name)
return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype, name=name)
_rfft.__doc__ = fft_fn.__doc__
return _rfft
@ -134,15 +144,20 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
"""Wrapper irfft* that infers fft_length argument."""
with _ops.name_scope(name, default_name,
[input_tensor, fft_length]) as name:
input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.complex64)
input_tensor = _ops.convert_to_tensor(input_tensor,
preferred_dtype=_dtypes.complex64)
input_tensor.shape.with_rank_at_least(fft_rank)
complex_dtype = input_tensor.dtype
real_dtype = complex_dtype.real_dtype
if fft_length is None:
fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)
else:
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
is_reverse=True)
return ifft_fn(input_tensor, fft_length, name)
if not compat.forward_compatible(2019, 10, 12):
return ifft_fn(input_tensor, fft_length, name=name)
return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name)
_irfft.__doc__ = ifft_fn.__doc__
return _irfft
@ -223,8 +238,10 @@ def _rfft_grad_helper(rank, irfft_fn):
def _grad(op, grad):
"""A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
fft_length = op.inputs[1]
complex_dtype = grad.dtype
real_dtype = complex_dtype.real_dtype
input_shape = _array_ops.shape(op.inputs[0])
is_even = _math_ops.cast(1 - (fft_length[-1] % 2), _dtypes.complex64)
is_even = _math_ops.cast(1 - (fft_length[-1] % 2), complex_dtype)
def _tile_for_broadcasting(matrix, t):
expanded = _array_ops.reshape(
@ -248,13 +265,13 @@ def _rfft_grad_helper(rank, irfft_fn):
_array_ops.expand_dims(_math_ops.range(length), 0), (length, 1))
b = _array_ops.transpose(a, [1, 0])
return _math_ops.exp(
-2j * np.pi * _math_ops.cast(a * b, _dtypes.complex64) /
_math_ops.cast(length, _dtypes.complex64))
-2j * np.pi * _math_ops.cast(a * b, complex_dtype) /
_math_ops.cast(length, complex_dtype))
def _ymask(length):
"""A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
return _math_ops.cast(1 - 2 * (_math_ops.range(length) % 2),
_dtypes.complex64)
complex_dtype)
y0 = grad[..., 0:1]
if rank == 1:
@ -288,7 +305,7 @@ def _rfft_grad_helper(rank, irfft_fn):
# factor, plus some additional terms to make up for the components dropped
# due to Hermitian symmetry.
input_size = _math_ops.cast(
_fft_size_for_grad(op.inputs[0], rank), _dtypes.float32)
_fft_size_for_grad(op.inputs[0], rank), real_dtype)
the_irfft = irfft_fn(grad, fft_length)
return 0.5 * (the_irfft * input_size + _math_ops.real(extra_terms)), None
@ -307,21 +324,27 @@ def _irfft_grad_helper(rank, rfft_fn):
# graph we special-case the situation where the FFT length and last
# dimension of the input are known at graph construction time.
fft_length = op.inputs[1]
real_dtype = grad.dtype
if real_dtype == _dtypes.float32:
complex_dtype = _dtypes.complex64
elif real_dtype == _dtypes.float64:
complex_dtype = _dtypes.complex128
is_odd = _math_ops.mod(fft_length[-1], 2)
input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
mask = _array_ops.concat(
[[1.0], 2.0 * _array_ops.ones([input_last_dimension - 2 + is_odd]),
_array_ops.ones([1 - is_odd])], 0)
[[1.0], 2.0 * _array_ops.ones(
[input_last_dimension - 2 + is_odd], real_dtype),
_array_ops.ones([1 - is_odd], real_dtype)], 0)
rsize = _math_ops.reciprocal(_math_ops.cast(
_fft_size_for_grad(grad, rank), _dtypes.float32))
_fft_size_for_grad(grad, rank), real_dtype))
# The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
# factor and a mask. The mask scales the gradient for the Hermitian
# symmetric components of the RFFT by a factor of two, since these
# components are de-duplicated in the RFFT.
the_rfft = rfft_fn(grad, fft_length)
return the_rfft * _math_ops.cast(rsize * mask, _dtypes.complex64), None
return the_rfft * _math_ops.cast(rsize * mask, complex_dtype), None
return _grad
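
A gradient sketch that exercises the dtype plumbing above (illustrative; assumes eager TF 2.x with this change active and the forward-compatibility window passed): the RFFT gradient is computed in the precision of its input, so a float64 signal should receive a float64 gradient. The loss mirrors the pattern used in fft_ops_test.py.

import tensorflow as tf

x = tf.Variable(tf.ones([8], dtype=tf.float64))
with tf.GradientTape() as tape:
  spectrum = tf.signal.rfft(x)  # complex128 spectrum under this change
  loss = tf.reduce_sum(tf.math.real(spectrum * tf.math.conj(spectrum)))
grad = tape.gradient(loss, x)   # expected dtype: float64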

View File

@ -1626,15 +1626,15 @@ tf_module {
}
member_method {
name: "IRFFT"
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "IRFFT2D"
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "IRFFT3D"
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "Identity"
@ -2858,15 +2858,15 @@ tf_module {
}
member_method {
name: "RFFT"
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
}
member_method {
name: "RFFT2D"
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
}
member_method {
name: "RFFT3D"
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
}
member_method {
name: "RGBToHSV"

View File

@ -1626,15 +1626,15 @@ tf_module {
}
member_method {
name: "IRFFT"
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "IRFFT2D"
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "IRFFT3D"
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "Identity"
@ -2858,15 +2858,15 @@ tf_module {
}
member_method {
name: "RFFT"
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
}
member_method {
name: "RFFT2D"
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
}
member_method {
name: "RFFT3D"
argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
}
member_method {
name: "RGBToHSV"