From 4a3cbea5f3d1f79eb2ff6bb2c64875e951ca3ce2 Mon Sep 17 00:00:00 2001
From: RJ Skerry-Ryan <rjryan@google.com>
Date: Thu, 12 Sep 2019 14:55:31 -0700
Subject: [PATCH] Add complex128 support to RFFT, RFFT2D, RFFT3D, IRFFT,
 IRFFT2D, and IRFFT3D.

Finishes support requested in:
- #10749
- #17332
- https://stackoverflow.com/questions/47214508/the-result-of-fft-in-tensorflow-is-different-from-numpy

PiperOrigin-RevId: 268772775
---
 tensorflow/compiler/tf2xla/kernels/fft_ops.cc |  31 +-
 .../core/api_def/base_api/api_def_FFT3D.pbtxt |   4 +-
 .../api_def/base_api/api_def_IFFT3D.pbtxt     |   4 +-
 .../core/api_def/base_api/api_def_IRFFT.pbtxt |   2 +-
 .../api_def/base_api/api_def_IRFFT2D.pbtxt    |   2 +-
 .../api_def/base_api/api_def_IRFFT3D.pbtxt    |   2 +-
 tensorflow/core/kernels/fft_ops.cc            | 405 +++++++++++-------
 tensorflow/core/ops/spectral_ops.cc           |  36 +-
 tensorflow/python/eager/backprop_test.py      |  18 -
 tensorflow/python/kernel_tests/signal/BUILD   |   3 +-
 .../kernel_tests/signal/fft_ops_test.py       | 243 ++++++-----
 tensorflow/python/ops/signal/fft_ops.py       |  49 ++-
 .../api/golden/v1/tensorflow.raw_ops.pbtxt    |  12 +-
 .../api/golden/v2/tensorflow.raw_ops.pbtxt    |  12 +-
 14 files changed, 492 insertions(+), 331 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
index 5ac288d8a34..e5e4e797cc5 100644
--- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
@@ -138,10 +138,20 @@ class RFFTOp : public GenericFftOp {
   explicit RFFTOp(OpKernelConstruction* ctx)
       : GenericFftOp(ctx, /*fft_type=*/FftType::RFFT, /*fft_rank=*/FFTRank) {}
 };
-REGISTER_XLA_OP(Name("RFFT").CompileTimeConstantInput("fft_length"), RFFTOp<1>);
-REGISTER_XLA_OP(Name("RFFT2D").CompileTimeConstantInput("fft_length"),
+REGISTER_XLA_OP(Name("RFFT")
+                    .TypeConstraint("Treal", DT_FLOAT)
+                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
+                    .CompileTimeConstantInput("fft_length"),
+                RFFTOp<1>);
+REGISTER_XLA_OP(Name("RFFT2D")
+                    .TypeConstraint("Treal", DT_FLOAT)
+                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
+                    .CompileTimeConstantInput("fft_length"),
                 RFFTOp<2>);
-REGISTER_XLA_OP(Name("RFFT3D").CompileTimeConstantInput("fft_length"),
+REGISTER_XLA_OP(Name("RFFT3D")
+                    .TypeConstraint("Treal", DT_FLOAT)
+                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
+                    .CompileTimeConstantInput("fft_length"),
                 RFFTOp<3>);
 
 template <int FFTRank>
@@ -150,11 +160,20 @@ class IRFFTOp : public GenericFftOp {
   explicit IRFFTOp(OpKernelConstruction* ctx)
       : GenericFftOp(ctx, /*fft_type=*/FftType::IRFFT, /*fft_rank=*/FFTRank) {}
 };
-REGISTER_XLA_OP(Name("IRFFT").CompileTimeConstantInput("fft_length"),
+REGISTER_XLA_OP(Name("IRFFT")
+                    .TypeConstraint("Treal", DT_FLOAT)
+                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
+                    .CompileTimeConstantInput("fft_length"),
                 IRFFTOp<1>);
-REGISTER_XLA_OP(Name("IRFFT2D").CompileTimeConstantInput("fft_length"),
+REGISTER_XLA_OP(Name("IRFFT2D")
+                    .TypeConstraint("Treal", DT_FLOAT)
+                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
+                    .CompileTimeConstantInput("fft_length"),
                 IRFFTOp<2>);
-REGISTER_XLA_OP(Name("IRFFT3D").CompileTimeConstantInput("fft_length"),
+REGISTER_XLA_OP(Name("IRFFT3D")
+                    .TypeConstraint("Treal", DT_FLOAT)
+                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
+                    .CompileTimeConstantInput("fft_length"),
                 IRFFTOp<3>);
 
 }  // namespace
diff --git a/tensorflow/core/api_def/base_api/api_def_FFT3D.pbtxt b/tensorflow/core/api_def/base_api/api_def_FFT3D.pbtxt
index abd2e67bceb..33de5f424c9 100644
--- a/tensorflow/core/api_def/base_api/api_def_FFT3D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_FFT3D.pbtxt
@@ -3,13 +3,13 @@ op {
   in_arg {
     name: "input"
     description: <<END
-A complex64 tensor.
+A complex tensor.
 END
   }
   out_arg {
     name: "output"
     description: <<END
-A complex64 tensor of the same shape as `input`. The inner-most 3
+A complex tensor of the same shape as `input`. The inner-most 3
   dimensions of `input` are replaced with their 3D Fourier transform.
 
 @compatibility(numpy)
diff --git a/tensorflow/core/api_def/base_api/api_def_IFFT3D.pbtxt b/tensorflow/core/api_def/base_api/api_def_IFFT3D.pbtxt
index 52f1118775b..65857c5661b 100644
--- a/tensorflow/core/api_def/base_api/api_def_IFFT3D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_IFFT3D.pbtxt
@@ -3,13 +3,13 @@ op {
   in_arg {
     name: "input"
     description: <<END
-A complex64 tensor.
+A complex tensor.
 END
   }
   out_arg {
     name: "output"
     description: <<END
-A complex64 tensor of the same shape as `input`. The inner-most 3
+A complex tensor of the same shape as `input`. The inner-most 3
   dimensions of `input` are replaced with their inverse 3D Fourier transform.
 
 @compatibility(numpy)
diff --git a/tensorflow/core/api_def/base_api/api_def_IRFFT.pbtxt b/tensorflow/core/api_def/base_api/api_def_IRFFT.pbtxt
index 1e1caa9eade..f0cd4129a88 100644
--- a/tensorflow/core/api_def/base_api/api_def_IRFFT.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_IRFFT.pbtxt
@@ -3,7 +3,7 @@ op {
   in_arg {
     name: "input"
     description: <<END
-A complex64 tensor.
+A complex tensor.
 END
   }
   in_arg {
diff --git a/tensorflow/core/api_def/base_api/api_def_IRFFT2D.pbtxt b/tensorflow/core/api_def/base_api/api_def_IRFFT2D.pbtxt
index 9b7390a3857..15183a87c18 100644
--- a/tensorflow/core/api_def/base_api/api_def_IRFFT2D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_IRFFT2D.pbtxt
@@ -3,7 +3,7 @@ op {
   in_arg {
     name: "input"
     description: <<END
-A complex64 tensor.
+A complex tensor.
 END
   }
   in_arg {
diff --git a/tensorflow/core/api_def/base_api/api_def_IRFFT3D.pbtxt b/tensorflow/core/api_def/base_api/api_def_IRFFT3D.pbtxt
index 1cee2ceeff0..068bac2fca3 100644
--- a/tensorflow/core/api_def/base_api/api_def_IRFFT3D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_IRFFT3D.pbtxt
@@ -3,7 +3,7 @@ op {
   in_arg {
     name: "input"
     description: <<END
-A complex64 tensor.
+A complex tensor.
 END
   }
   in_arg {
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc
index f386d6e9990..05843594839 100644
--- a/tensorflow/core/kernels/fft_ops.cc
+++ b/tensorflow/core/kernels/fft_ops.cc
@@ -15,7 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-// See docs in ../ops/spectral_ops.cc.
+// See docs in ../ops/spectral_ops.cc.
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op.h"
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/env_var.h"
@@ -94,7 +95,34 @@ class FFTBase : public OpKernel {
     }
 
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
+
+    if (IsReal()) {
+      if (IsForward()) {
+        OP_REQUIRES(
+            ctx,
+            (in.dtype() == DT_FLOAT && out->dtype() == DT_COMPLEX64) ||
+                (in.dtype() == DT_DOUBLE && out->dtype() == DT_COMPLEX128),
+            errors::InvalidArgument("Wrong types for forward real FFT: in=",
+                                    in.dtype(), " out=", out->dtype()));
+      } else {
+        OP_REQUIRES(
+            ctx,
+            (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_FLOAT) ||
+                (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_DOUBLE),
+            errors::InvalidArgument("Wrong types for backward real FFT: in=",
+                                    in.dtype(), " out=", out->dtype()));
+      }
+    } else {
+      OP_REQUIRES(
+          ctx,
+          (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_COMPLEX64) ||
+              (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_COMPLEX128),
+          errors::InvalidArgument("Wrong types for FFT: in=", in.dtype(),
+                                  " out=", out->dtype()));
+    }
+
     if (input_shape.num_elements() == 0) {
+      DCHECK_EQ(0, output_shape.num_elements());
       return;
     }
 
@@ -129,126 +157,159 @@ class FFTCPU : public FFTBase {
     const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
     auto device = ctx->eigen_device<CPUDevice>();
 
+    const bool is_complex128 =
+        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;
+
     if (!IsReal()) {
       // Compute the FFT using Eigen.
       constexpr auto direction =
           Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
-      if (in.dtype() == DT_COMPLEX64) {
+      if (is_complex128) {
+        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
+        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
+        auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
+        auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
+        output.device(device) =
+            input.template fft<Eigen::BothParts, direction>(axes);
+      } else {
+        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
         DCHECK_EQ(out->dtype(), DT_COMPLEX64);
         auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
         auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
         output.device(device) =
             input.template fft<Eigen::BothParts, direction>(axes);
-      } else {
-        DCHECK_EQ(DT_COMPLEX128, in.dtype());
-        DCHECK_EQ(DT_COMPLEX128, out->dtype());
-        auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
-        auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
-        output.device(device) =
-            input.template fft<Eigen::BothParts, direction>(axes);
       }
     } else {
       if (IsForward()) {
-        auto input = Tensor(in).flat_inner_dims<float, FFTRank + 1>();
-        const auto input_dims = input.dimensions();
-
-        // Slice input to fft_shape on its inner-most dimensions.
-        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
-        input_slice_sizes[0] = input_dims[0];
-        TensorShape temp_shape{input_dims[0]};
-        for (int i = 1; i <= FFTRank; ++i) {
-          input_slice_sizes[i] = fft_shape[i - 1];
-          temp_shape.AddDim(fft_shape[i - 1]);
+        if (is_complex128) {
+          DCHECK_EQ(in.dtype(), DT_DOUBLE);
+          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
+          DoRealForwardFFT<double, complex128>(ctx, fft_shape, in, out);
+        } else {
+          DCHECK_EQ(in.dtype(), DT_FLOAT);
+          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
+          DoRealForwardFFT<float, complex64>(ctx, fft_shape, in, out);
         }
-
-        auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
-        const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
-
-        // Compute the full FFT using a temporary tensor.
-        Tensor temp;
-        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
-                                               temp_shape, &temp));
-        auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
-        full_fft.device(device) =
-            input.slice(zero_start_indices, input_slice_sizes)
-                .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
-
-        // Slice away the negative frequency components.
-        output.device(device) =
-            full_fft.slice(zero_start_indices, output.dimensions());
       } else {
-        // Reconstruct the full FFT and take the inverse.
-        auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
-        auto output = out->flat_inner_dims<float, FFTRank + 1>();
-        const auto input_dims = input.dimensions();
-
-        // Calculate the shape of the temporary tensor for the full FFT and the
-        // region we will slice from input given fft_shape. We slice input to
-        // fft_shape on its inner-most dimensions, except the last (which we
-        // slice to fft_shape[-1] / 2 + 1).
-        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
-        input_slice_sizes[0] = input_dims[0];
-        TensorShape full_fft_shape;
-        full_fft_shape.AddDim(input_dims[0]);
-        for (auto i = 1; i <= FFTRank; i++) {
-          input_slice_sizes[i] =
-              i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
-          full_fft_shape.AddDim(fft_shape[i - 1]);
+        if (is_complex128) {
+          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
+          DCHECK_EQ(out->dtype(), DT_DOUBLE);
+          DoRealBackwardFFT<complex128, double>(ctx, fft_shape, in, out);
+        } else {
+          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
+          DCHECK_EQ(out->dtype(), DT_FLOAT);
+          DoRealBackwardFFT<complex64, float>(ctx, fft_shape, in, out);
         }
-
-        Tensor temp;
-        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
-                                               full_fft_shape, &temp));
-        auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
-
-        // Calculate the starting point and range of the source of
-        // negative frequency part.
-        auto neg_sizes = input_slice_sizes;
-        neg_sizes[FFTRank] =
-            fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
-        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
-        neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];
-
-        const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
-        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
-        neg_start_indices[FFTRank] = 1;
-
-        full_fft.slice(start_indices, input_slice_sizes).device(device) =
-            input.slice(start_indices, input_slice_sizes);
-
-        // First, conduct IFFTs on outer dimensions. We save computation (and
-        // avoid touching uninitialized memory) by slicing full_fft to the
-        // subregion we wrote input to.
-        if (FFTRank > 1) {
-          const auto outer_axes =
-              Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
-          full_fft.slice(start_indices, input_slice_sizes).device(device) =
-              full_fft.slice(start_indices, input_slice_sizes)
-                  .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(
-                      outer_axes);
-        }
-
-        // Reconstruct the full FFT by appending reversed and conjugated
-        // spectrum as the negative frequency part.
-        Eigen::array<bool, FFTRank + 1> reverse_last_axis;
-        for (auto i = 0; i <= FFTRank; i++) {
-          reverse_last_axis[i] = i == FFTRank;
-        }
-
-        if (neg_sizes[FFTRank] != 0) {
-          full_fft.slice(neg_target_indices, neg_sizes).device(device) =
-              full_fft.slice(neg_start_indices, neg_sizes)
-                  .reverse(reverse_last_axis)
-                  .conjugate();
-        }
-
-        auto inner_axis = Eigen::array<int, 1>{FFTRank};
-        output.device(device) =
-            full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(
-                inner_axis);
       }
     }
   }
+
+  template <typename RealT, typename ComplexT>
+  void DoRealForwardFFT(OpKernelContext* ctx, uint64* fft_shape,
+                        const Tensor& in, Tensor* out) {
+    // Create the axes (which are always trailing).
+    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
+    auto device = ctx->eigen_device<CPUDevice>();
+    auto input = Tensor(in).flat_inner_dims<RealT, FFTRank + 1>();
+    const auto input_dims = input.dimensions();
+
+    // Slice input to fft_shape on its inner-most dimensions.
+    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
+    input_slice_sizes[0] = input_dims[0];
+    TensorShape temp_shape{input_dims[0]};
+    for (int i = 1; i <= FFTRank; ++i) {
+      input_slice_sizes[i] = fft_shape[i - 1];
+      temp_shape.AddDim(fft_shape[i - 1]);
+    }
+
+    auto output = out->flat_inner_dims<ComplexT, FFTRank + 1>();
+    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
+
+    // Compute the full FFT using a temporary tensor.
+    Tensor temp;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
+                                           temp_shape, &temp));
+    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();
+    full_fft.device(device) =
+        input.slice(zero_start_indices, input_slice_sizes)
+            .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
+
+    // Slice away the negative frequency components.
+    output.device(device) =
+        full_fft.slice(zero_start_indices, output.dimensions());
+  }
+
+  template <typename ComplexT, typename RealT>
+  void DoRealBackwardFFT(OpKernelContext* ctx, uint64* fft_shape,
+                         const Tensor& in, Tensor* out) {
+    auto device = ctx->eigen_device<CPUDevice>();
+    // Reconstruct the full FFT and take the inverse.
+    auto input = Tensor(in).flat_inner_dims<ComplexT, FFTRank + 1>();
+    auto output = out->flat_inner_dims<RealT, FFTRank + 1>();
+    const auto input_dims = input.dimensions();
+
+    // Calculate the shape of the temporary tensor for the full FFT and the
+    // region we will slice from input given fft_shape. We slice input to
+    // fft_shape on its inner-most dimensions, except the last (which we
+    // slice to fft_shape[-1] / 2 + 1).
+    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
+    input_slice_sizes[0] = input_dims[0];
+    TensorShape full_fft_shape;
+    full_fft_shape.AddDim(input_dims[0]);
+    for (auto i = 1; i <= FFTRank; i++) {
+      input_slice_sizes[i] =
+          i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
+      full_fft_shape.AddDim(fft_shape[i - 1]);
+    }
+
+    Tensor temp;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
+                                           full_fft_shape, &temp));
+    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();
+
+    // Calculate the starting point and range of the source of
+    // negative frequency part.
+    auto neg_sizes = input_slice_sizes;
+    neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
+    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
+    neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];
+
+    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
+    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
+    neg_start_indices[FFTRank] = 1;
+
+    full_fft.slice(start_indices, input_slice_sizes).device(device) =
+        input.slice(start_indices, input_slice_sizes);
+
+    // First, conduct IFFTs on outer dimensions. We save computation (and
+    // avoid touching uninitialized memory) by slicing full_fft to the
+    // subregion we wrote input to.
+    if (FFTRank > 1) {
+      const auto outer_axes =
+          Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
+      full_fft.slice(start_indices, input_slice_sizes).device(device) =
+          full_fft.slice(start_indices, input_slice_sizes)
+              .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
+    }
+
+    // Reconstruct the full FFT by appending reversed and conjugated
+    // spectrum as the negative frequency part.
+    Eigen::array<bool, FFTRank + 1> reverse_last_axis;
+    for (auto i = 0; i <= FFTRank; i++) {
+      reverse_last_axis[i] = i == FFTRank;
+    }
+
+    if (neg_sizes[FFTRank] != 0) {
+      full_fft.slice(neg_target_indices, neg_sizes).device(device) =
+          full_fft.slice(neg_start_indices, neg_sizes)
+              .reverse(reverse_last_axis)
+              .conjugate();
+    }
+
+    auto inner_axis = Eigen::array<int, 1>{FFTRank};
+    output.device(device) =
+        full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
+  }
 };
 
 REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_CPU), FFTCPU<true, false, 1>);
@@ -390,16 +451,19 @@ class FFTGPUBase : public FFTBase {
     }
 
     constexpr bool kInPlaceFft = false;
-    const bool is_complex128 = in.dtype() == DT_COMPLEX128;
-    // complex128 real FFT is not supported yet.
-    DCHECK(!IsReal() || !is_complex128);
+    const bool is_complex128 =
+        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;
 
     const auto kFftType =
-        IsReal() ? (IsForward() ? se::fft::Type::kR2C : se::fft::Type::kC2R)
-                 : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
-                                                 : se::fft::Type::kC2CForward)
-                                : (is_complex128 ? se::fft::Type::kZ2ZInverse
-                                                 : se::fft::Type::kC2CInverse));
+        IsReal()
+            ? (IsForward()
+                   ? (is_complex128 ? se::fft::Type::kD2Z : se::fft::Type::kR2C)
+                   : (is_complex128 ? se::fft::Type::kZ2D
+                                    : se::fft::Type::kC2R))
+            : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
+                                            : se::fft::Type::kC2CForward)
+                           : (is_complex128 ? se::fft::Type::kZ2ZInverse
+                                            : se::fft::Type::kC2CInverse));
 
     CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
     auto plan =
@@ -410,67 +474,80 @@ class FFTGPUBase : public FFTBase {
 
     if (IsReal()) {
       if (IsForward()) {
-        auto src = AsDeviceMemory<float>(in.flat<float>().data());
-        auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
-        OP_REQUIRES(
-            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
-            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
-                             " in.shape=", input_shape.DebugString()));
+        if (is_complex128) {
+          DCHECK_EQ(in.dtype(), DT_DOUBLE);
+          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
+          DoFFTInternal<double, complex128>(ctx, stream, plan.get(), kFftType,
+                                            output_distance, in, out);
+        } else {
+          DCHECK_EQ(in.dtype(), DT_FLOAT);
+          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
+          DoFFTInternal<float, complex64>(ctx, stream, plan.get(), kFftType,
+                                          output_distance, in, out);
+        }
       } else {
-        auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
-        auto dst = AsDeviceMemory<float>(out->flat<float>().data());
-        OP_REQUIRES(
-            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
-            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
-                             " in.shape=", input_shape.DebugString()));
-        auto alpha = 1.f / output_distance;
-        OP_REQUIRES(
-            ctx,
-            stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
-                .ok(),
-            errors::Internal("BlasScal failed : in.shape=",
-                             input_shape.DebugString()));
+        if (is_complex128) {
+          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
+          DCHECK_EQ(out->dtype(), DT_DOUBLE);
+          DoFFTInternal<complex128, double>(ctx, stream, plan.get(), kFftType,
+                                            output_distance, in, out);
+        } else {
+          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
+          DCHECK_EQ(out->dtype(), DT_FLOAT);
+          DoFFTInternal<complex64, float>(ctx, stream, plan.get(), kFftType,
+                                          output_distance, in, out);
+        }
       }
     } else {
-      if (!is_complex128) {
-        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
-        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
-        auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
-        auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
-        OP_REQUIRES(
-            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
-            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
-                             " in.shape=", input_shape.DebugString()));
-        if (!IsForward()) {
-          float alpha = 1.f / output_distance;
-          OP_REQUIRES(
-              ctx,
-              stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
-                  .ok(),
-              errors::Internal("BlasScal failed : in.shape=",
-                               input_shape.DebugString()));
-        }
-      } else {
+      if (is_complex128) {
         DCHECK_EQ(in.dtype(), DT_COMPLEX128);
         DCHECK_EQ(out->dtype(), DT_COMPLEX128);
-        auto src = AsDeviceMemory<complex128>(in.flat<complex128>().data());
-        auto dst = AsDeviceMemory<complex128>(out->flat<complex128>().data());
-        OP_REQUIRES(
-            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
-            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
-                             " in.shape=", input_shape.DebugString()));
-        if (!IsForward()) {
-          double alpha = 1.0 / output_distance;
-          OP_REQUIRES(
-              ctx,
-              stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
-                  .ok(),
-              errors::Internal("BlasScal failed : in.shape=",
-                               input_shape.DebugString()));
-        }
+        DoFFTInternal<complex128, complex128>(ctx, stream, plan.get(), kFftType,
+                                              output_distance, in, out);
+      } else {
+        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
+        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
+        DoFFTInternal<complex64, complex64>(ctx, stream, plan.get(), kFftType,
+                                            output_distance, in, out);
       }
     }
   }
+
+ private:
+  template <typename T>
+  struct RealTypeFromComplexType {
+    typedef T RealT;
+  };
+
+  template <typename T>
+  struct RealTypeFromComplexType<std::complex<T>> {
+    typedef T RealT;
+  };
+
+  template <typename InT, typename OutT>
+  void DoFFTInternal(OpKernelContext* ctx, se::Stream* stream,
+                     se::fft::Plan* plan, const se::fft::Type fft_type,
+                     const uint64 output_distance, const Tensor& in,
+                     Tensor* out) {
+    auto src = AsDeviceMemory<InT>(in.flat<InT>().data());
+    auto dst = AsDeviceMemory<OutT>(out->flat<OutT>().data());
+    const TensorShape& input_shape = in.shape();
+    const TensorShape& output_shape = out->shape();
+    OP_REQUIRES(
+        ctx, stream->ThenFft(plan, src, &dst).ok(),
+        errors::Internal("fft failed : type=", static_cast<int>(fft_type),
+                         " in.shape=", input_shape.DebugString()));
+    if (!IsForward()) {
+      typedef typename RealTypeFromComplexType<OutT>::RealT RealT;
+      RealT alpha = 1.0 / output_distance;
+      OP_REQUIRES(
+          ctx,
+          stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
+              .ok(),
+          errors::Internal("BlasScal failed : in.shape=",
+                           input_shape.DebugString()));
+    }
+  }
 };
 
 int64 FFTGPUBase::CufftScratchSize = GetCufftWorkspaceLimit(
diff --git a/tensorflow/core/ops/spectral_ops.cc b/tensorflow/core/ops/spectral_ops.cc
index b1ae7040f02..3b9b962143b 100644
--- a/tensorflow/core/ops/spectral_ops.cc
+++ b/tensorflow/core/ops/spectral_ops.cc
@@ -109,39 +109,51 @@ Status RFFTShape(InferenceContext* c, const bool forward, const int rank) {
 }
 
 REGISTER_OP("RFFT")
-    .Input("input: float")
+    .Input("input: Treal")
     .Input("fft_length: int32")
-    .Output("output: complex64")
+    .Output("output: Tcomplex")
+    .Attr("Treal: {float32, float64} = DT_FLOAT")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 1); });
 
 REGISTER_OP("IRFFT")
-    .Input("input: complex64")
+    .Input("input: Tcomplex")
     .Input("fft_length: int32")
-    .Output("output: float")
+    .Output("output: Treal")
+    .Attr("Treal: {float32, float64} = DT_FLOAT")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 1); });
 
 REGISTER_OP("RFFT2D")
-    .Input("input: float")
+    .Input("input: Treal")
     .Input("fft_length: int32")
-    .Output("output: complex64")
+    .Output("output: Tcomplex")
+    .Attr("Treal: {float32, float64} = DT_FLOAT")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 2); });
 
 REGISTER_OP("IRFFT2D")
-    .Input("input: complex64")
+    .Input("input: Tcomplex")
     .Input("fft_length: int32")
-    .Output("output: float")
+    .Output("output: Treal")
+    .Attr("Treal: {float32, float64} = DT_FLOAT")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 2); });
 
 REGISTER_OP("RFFT3D")
-    .Input("input: float")
+    .Input("input: Treal")
     .Input("fft_length: int32")
-    .Output("output: complex64")
+    .Output("output: Tcomplex")
+    .Attr("Treal: {float32, float64} = DT_FLOAT")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 3); });
 
 REGISTER_OP("IRFFT3D")
-    .Input("input: complex64")
+    .Input("input: Tcomplex")
     .Input("fft_length: int32")
-    .Output("output: float")
+    .Output("output: Treal")
+    .Attr("Treal: {float32, float64} = DT_FLOAT")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 3); });
 
 // Deprecated ops:
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 54ab48053e9..86b5b9ebd8b 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -48,7 +48,6 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops.signal import fft_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training import training
 from tensorflow.python.util import nest
@@ -1579,23 +1578,6 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
     with self.assertRaisesRegexp(ValueError, 'ndarray'):
       g.watch(np.array(1.))
 
-  def testOpWithNoAttrs(self):
-
-    @function.defun(autograph=False)
-    def f():
-      with backprop.GradientTape() as tape:
-        xs = random_ops.random_normal([10, 32])
-        tape.watch(xs)
-        # The `rfft()` op has no defined attrs, which exercises a different
-        # branch in the Python op wrapper code generator for recording
-        # gradients.
-        ys = fft_ops.rfft(xs)
-        self.assertEmpty(ys.op.node_def.attr)
-      gs = tape.gradient(ys, xs)
-      self.assertIsNotNone(gs)
-
-    f.get_concrete_function()
-
 
 class JacobianTest(test.TestCase):
 
diff --git a/tensorflow/python/kernel_tests/signal/BUILD b/tensorflow/python/kernel_tests/signal/BUILD
index 3806783ca11..b260fff573e 100644
--- a/tensorflow/python/kernel_tests/signal/BUILD
+++ b/tensorflow/python/kernel_tests/signal/BUILD
@@ -37,6 +37,7 @@ cuda_py_tests(
 
 cuda_py_tests(
     name = "fft_ops_test",
+    # TODO(rjryan): Parameterize the test to reduce the time it takes.
     size = "medium",
     srcs = ["fft_ops_test.py"],
     additional_deps = [
@@ -46,7 +47,7 @@ cuda_py_tests(
         "//tensorflow/python:math_ops",
         "//tensorflow/python/ops/signal",
     ],
-    shard_count = 4,
+    shard_count = 8,
     tags = [
         "no_rocm",
         "optonly",
diff --git a/tensorflow/python/kernel_tests/signal/fft_ops_test.py b/tensorflow/python/kernel_tests/signal/fft_ops_test.py
index 5745da73045..be7e5c18abc 100644
--- a/tensorflow/python/kernel_tests/signal/fft_ops_test.py
+++ b/tensorflow/python/kernel_tests/signal/fft_ops_test.py
@@ -18,10 +18,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import contextlib
+
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -36,6 +39,18 @@ from tensorflow.python.platform import test
 VALID_FFT_RANKS = (1, 2, 3)
 
 
+def _forward_compat_context(np_dtype):
+  """Forward-compat horizon context for float64/complex128, else a no-op."""
+  @contextlib.contextmanager
+  def null_context():
+    yield
+  if np_dtype not in (np.float64, np.complex128):
+    return null_context()
+  return compat.forward_compatibility_horizon(2019, 10, 13)
+
+
+# TODO(rjryan): Investigate precision issues. We should be able to achieve
+# better tolerances, at least for the complex128 tests.
 class BaseFFTOpsTest(test.TestCase):
 
   def _compare(self, x, rank, fft_length=None, use_placeholder=False,
@@ -84,8 +99,9 @@ class BaseFFTOpsTest(test.TestCase):
         loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
         return loss
 
-      ((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
-          gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
+      with _forward_compat_context(x.dtype):
+        ((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
+            gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
 
     self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
     self.assertAllClose(y_jacob_t, y_jacob_n, rtol=rtol, atol=atol)
@@ -99,8 +115,9 @@ class BaseFFTOpsTest(test.TestCase):
       loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
       return loss
 
-    (x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
-        f, [x], delta=1e-2)
+    with _forward_compat_context(x.dtype):
+      (x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
+          f, [x], delta=1e-2)
     self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
 
 
@@ -174,13 +191,12 @@ class FFTOpsTest(BaseFFTOpsTest):
                   (4,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
 
   def test_large_batch(self):
-    if test.is_gpu_available(cuda_only=True):
-      rank = 1
-      for dims in xrange(rank, rank + 3):
-        for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-5)):
-          self._compare(
-              np.mod(np.arange(np.power(128, dims)), 10).reshape(
-                  (128,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
+    rank = 1
+    for dims in xrange(rank, rank + 3):
+      for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 5e-5)):
+        self._compare(
+            np.mod(np.arange(np.power(128, dims)), 10).reshape(
+                (128,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
 
   # TODO(yangzihao): Disable before we can figure out a way to
   # properly test memory fail for large batch fft.
@@ -277,17 +293,15 @@ class FFTOpsTest(BaseFFTOpsTest):
 @test_util.run_all_in_graph_and_eager_modes
 class RFFTOpsTest(BaseFFTOpsTest):
 
-  def _compare_backward(self, x, rank, fft_length=None, use_placeholder=False):
-    super(RFFTOpsTest, self)._compare_backward(x, rank, fft_length,
-                                               use_placeholder)
-
   def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
-    with self.cached_session(use_gpu=True) as sess:
+    with _forward_compat_context(x.dtype), self.cached_session(
+        use_gpu=True) as sess:
       return sess.run(
           self._tf_fft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
 
   def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
-    with self.cached_session(use_gpu=True) as sess:
+    with _forward_compat_context(x.dtype), self.cached_session(
+        use_gpu=True) as sess:
       return sess.run(
           self._tf_ifft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
 
@@ -332,61 +346,74 @@ class RFFTOpsTest(BaseFFTOpsTest):
       raise ValueError("invalid rank")
 
   def test_empty(self):
-    for rank in VALID_FFT_RANKS:
-      for dims in xrange(rank, rank + 3):
-        x = np.zeros((0,) * dims).astype(np.float32)
-        self.assertEqual(x.shape, self._tf_fft(x, rank).shape)
-        x = np.zeros((0,) * dims).astype(np.complex64)
-        self.assertEqual(x.shape, self._tf_ifft(x, rank).shape)
+    for np_rtype, np_ctype in ((np.float32, np.complex64),
+                               (np.float64, np.complex128)):
+      for rank in VALID_FFT_RANKS:
+        for dims in xrange(rank, rank + 3):
+          x = np.zeros((0,) * dims).astype(np_rtype)
+          self.assertEqual(x.shape, self._tf_fft(x, rank).shape)
+          x = np.zeros((0,) * dims).astype(np_ctype)
+          self.assertEqual(x.shape, self._tf_ifft(x, rank).shape)
 
   def test_basic(self):
-    for rank in VALID_FFT_RANKS:
-      for dims in xrange(rank, rank + 3):
-        for size in (5, 6):
-          inner_dim = size // 2 + 1
-          r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
-              (size,) * dims)
-          self._compare_forward(r2c.astype(np.float32), rank, (size,) * rank)
-          c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
-                       10).reshape((size,) * (dims - 1) + (inner_dim,))
-          self._compare_backward(
-              c2r.astype(np.complex64), rank, (size,) * rank)
+    for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
+                                    (np.float64, np.complex128, 5e-5)):
+      for rank in VALID_FFT_RANKS:
+        for dims in xrange(rank, rank + 3):
+          for size in (5, 6):
+            inner_dim = size // 2 + 1
+            r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
+                (size,) * dims)
+            self._compare_forward(r2c.astype(np_rtype), rank, (size,) * rank,
+                                  rtol=tol, atol=tol)
+            c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
+                         10).reshape((size,) * (dims - 1) + (inner_dim,))
+            self._compare_backward(
+                c2r.astype(np_ctype), rank, (size,) * rank, rtol=tol, atol=tol)
 
   def test_large_batch(self):
-    if test.is_gpu_available(cuda_only=True):
-      rank = 1
+    rank = 1
+    for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
+                                    (np.float64, np.complex128, 1e-5)):
       for dims in xrange(rank, rank + 3):
         for size in (64, 128):
           inner_dim = size // 2 + 1
           r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
               (size,) * dims)
-          self._compare_forward(r2c.astype(np.float32), rank, (size,) * rank)
+          self._compare_forward(r2c.astype(np_rtype), rank, (size,) * rank,
+                                rtol=tol, atol=tol)
           c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
                        10).reshape((size,) * (dims - 1) + (inner_dim,))
-          self._compare_backward(c2r.astype(np.complex64), rank, (size,) * rank)
+          self._compare_backward(c2r.astype(np_ctype), rank, (size,) * rank,
+                                 rtol=tol, atol=tol)
 
   def test_placeholder(self):
     if context.executing_eagerly():
       return
-    for rank in VALID_FFT_RANKS:
-      for dims in xrange(rank, rank + 3):
-        for size in (5, 6):
-          inner_dim = size // 2 + 1
-          r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
-              (size,) * dims)
-          self._compare_forward(
-              r2c.astype(np.float32),
-              rank, (size,) * rank,
-              use_placeholder=True)
-          c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
-                       10).reshape((size,) * (dims - 1) + (inner_dim,))
-          self._compare_backward(
-              c2r.astype(np.complex64),
-              rank, (size,) * rank,
-              use_placeholder=True)
+    for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
+                                    (np.float64, np.complex128, 1e-8)):
+      for rank in VALID_FFT_RANKS:
+        for dims in xrange(rank, rank + 3):
+          for size in (5, 6):
+            inner_dim = size // 2 + 1
+            r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
+                (size,) * dims)
+            self._compare_forward(
+                r2c.astype(np_rtype),
+                rank, (size,) * rank,
+                use_placeholder=True,
+                rtol=tol, atol=tol)
+            c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
+                         10).reshape((size,) * (dims - 1) + (inner_dim,))
+            self._compare_backward(
+                c2r.astype(np_ctype),
+                rank, (size,) * rank,
+                use_placeholder=True,
+                rtol=tol, atol=tol)
 
-  def test_fft_lenth(self):
-    if test.is_gpu_available(cuda_only=True):
+  def test_fft_length(self):
+    for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
+                                    (np.float64, np.complex128, 8e-5)):
       for rank in VALID_FFT_RANKS:
         for dims in xrange(rank, rank + 3):
           for size in (5, 6):
@@ -397,36 +424,44 @@ class RFFTOpsTest(BaseFFTOpsTest):
                          10).reshape((size,) * (dims - 1) + (inner_dim,))
             # Test truncation (FFT size < dimensions).
             fft_length = (size - 2,) * rank
-            self._compare_forward(r2c.astype(np.float32), rank, fft_length)
-            self._compare_backward(c2r.astype(np.complex64), rank, fft_length)
+            self._compare_forward(r2c.astype(np_rtype), rank, fft_length,
+                                  rtol=tol, atol=tol)
+            self._compare_backward(c2r.astype(np_ctype), rank, fft_length,
+                                   rtol=tol, atol=tol)
             # Confirm it works with unknown shapes as well.
             if not context.executing_eagerly():
               self._compare_forward(
-                  r2c.astype(np.float32),
+                  r2c.astype(np_rtype),
                   rank,
                   fft_length,
-                  use_placeholder=True)
+                  use_placeholder=True,
+                  rtol=tol, atol=tol)
               self._compare_backward(
-                  c2r.astype(np.complex64),
+                  c2r.astype(np_ctype),
                   rank,
                   fft_length,
-                  use_placeholder=True)
+                  use_placeholder=True,
+                  rtol=tol, atol=tol)
             # Test padding (FFT size > dimensions).
             fft_length = (size + 2,) * rank
-            self._compare_forward(r2c.astype(np.float32), rank, fft_length)
-            self._compare_backward(c2r.astype(np.complex64), rank, fft_length)
+            self._compare_forward(r2c.astype(np_rtype), rank, fft_length,
+                                  rtol=tol, atol=tol)
+            self._compare_backward(c2r.astype(np_ctype), rank, fft_length,
+                                   rtol=tol, atol=tol)
             # Confirm it works with unknown shapes as well.
             if not context.executing_eagerly():
               self._compare_forward(
-                  r2c.astype(np.float32),
+                  r2c.astype(np_rtype),
                   rank,
                   fft_length,
-                  use_placeholder=True)
+                  use_placeholder=True,
+                  rtol=tol, atol=tol)
               self._compare_backward(
-                  c2r.astype(np.complex64),
+                  c2r.astype(np_ctype),
                   rank,
                   fft_length,
-                  use_placeholder=True)
+                  use_placeholder=True,
+                  rtol=tol, atol=tol)
 
   def test_random(self):
     def gen_real(shape):
@@ -442,14 +477,20 @@ class RFFTOpsTest(BaseFFTOpsTest):
       ret = (re + im * 1j).reshape(shape)
       return ret
 
-    for rank in VALID_FFT_RANKS:
-      for dims in xrange(rank, rank + 3):
-        for size in (5, 6):
-          inner_dim = size // 2 + 1
-          self._compare_forward(gen_real((size,) * dims), rank, (size,) * rank)
-          complex_dims = (size,) * (dims - 1) + (inner_dim,)
-          self._compare_backward(
-              gen_complex(complex_dims), rank, (size,) * rank)
+    for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
+                                    (np.float64, np.complex128, 1e-5)):
+      for rank in VALID_FFT_RANKS:
+        for dims in xrange(rank, rank + 3):
+          for size in (5, 6):
+            inner_dim = size // 2 + 1
+            self._compare_forward(gen_real((size,) * dims).astype(np_rtype),
+                                  rank, (size,) * rank,
+                                  rtol=tol, atol=tol)
+            complex_dims = (size,) * (dims - 1) + (inner_dim,)
+            self._compare_backward(
+                gen_complex(complex_dims).astype(np_ctype),
+                rank, (size,) * rank,
+                rtol=tol, atol=tol)
 
   def test_error(self):
     # TODO(rjryan): Fix this test under Eager.
@@ -510,30 +551,36 @@ class RFFTOpsTest(BaseFFTOpsTest):
           self.evaluate(irfft_fn(x, fft_length))
 
   def test_grad_simple(self):
-    for rank in VALID_FFT_RANKS:
-      # rfft3d/irfft3d do not have gradients yet.
-      if rank == 3:
-        continue
-      for dims in xrange(rank, rank + 2):
-        for size in (5, 6):
-          re = np.ones(shape=(size,) * dims, dtype=np.float32)
-          im = -np.ones(shape=(size,) * dims, dtype=np.float32)
-          self._check_grad_real(self._tf_fft_for_rank(rank), re)
-          self._check_grad_complex(
-              self._tf_ifft_for_rank(rank), re, im, result_is_complex=False)
+    for np_rtype, tol in ((np.float32, 1e-3), (np.float64, 1e-10)):
+      for rank in VALID_FFT_RANKS:
+        # rfft3d/irfft3d do not have gradients yet.
+        if rank == 3:
+          continue
+        for dims in xrange(rank, rank + 2):
+          for size in (5, 6):
+            re = np.ones(shape=(size,) * dims, dtype=np_rtype)
+            im = -np.ones(shape=(size,) * dims, dtype=np_rtype)
+            self._check_grad_real(self._tf_fft_for_rank(rank), re,
+                                  rtol=tol, atol=tol)
+            self._check_grad_complex(
+                self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
+                rtol=tol, atol=tol)
 
   def test_grad_random(self):
-    for rank in VALID_FFT_RANKS:
-      # rfft3d/irfft3d do not have gradients yet.
-      if rank == 3:
-        continue
-      for dims in xrange(rank, rank + 2):
-        for size in (5, 6):
-          re = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
-          im = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
-          self._check_grad_real(self._tf_fft_for_rank(rank), re)
-          self._check_grad_complex(
-              self._tf_ifft_for_rank(rank), re, im, result_is_complex=False)
+    for np_rtype, tol in ((np.float32, 1e-2), (np.float64, 1e-10)):
+      for rank in VALID_FFT_RANKS:
+        # rfft3d/irfft3d do not have gradients yet.
+        if rank == 3:
+          continue
+        for dims in xrange(rank, rank + 2):
+          for size in (5, 6):
+            re = np.random.rand(*((size,) * dims)).astype(np_rtype) * 2 - 1
+            im = np.random.rand(*((size,) * dims)).astype(np_rtype) * 2 - 1
+            self._check_grad_real(self._tf_fft_for_rank(rank), re,
+                                  rtol=tol, atol=tol)
+            self._check_grad_complex(
+                self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
+                rtol=tol, atol=tol)
 
 
 @test_util.run_all_in_graph_and_eager_modes
diff --git a/tensorflow/python/ops/signal/fft_ops.py b/tensorflow/python/ops/signal/fft_ops.py
index 0e18c217fc5..cfe7799ac57 100644
--- a/tensorflow/python/ops/signal/fft_ops.py
+++ b/tensorflow/python/ops/signal/fft_ops.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import dtypes as _dtypes
 from tensorflow.python.framework import ops as _ops
 from tensorflow.python.framework import tensor_util as _tensor_util
@@ -115,14 +116,23 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
     """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
     with _ops.name_scope(name, default_name,
                          [input_tensor, fft_length]) as name:
-      input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.float32)
+      input_tensor = _ops.convert_to_tensor(input_tensor,
+                                            preferred_dtype=_dtypes.float32)
+      real_dtype = input_tensor.dtype
+      if real_dtype not in (_dtypes.float32, _dtypes.float64):
+        raise ValueError("Unsupported RFFT input dtype: %s" % real_dtype)
+      is_single = real_dtype == _dtypes.float32
+      complex_dtype = _dtypes.complex64 if is_single else _dtypes.complex128
       input_tensor.shape.with_rank_at_least(fft_rank)
       if fft_length is None:
         fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)
       else:
         fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
       input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
-      return fft_fn(input_tensor, fft_length, name)
+
+      if not compat.forward_compatible(2019, 10, 12):
+        return fft_fn(input_tensor, fft_length, name=name)
+      return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype, name=name)
   _rfft.__doc__ = fft_fn.__doc__
   return _rfft
 
@@ -134,15 +144,20 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
     """Wrapper irfft* that infers fft_length argument."""
     with _ops.name_scope(name, default_name,
                          [input_tensor, fft_length]) as name:
-      input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.complex64)
+      input_tensor = _ops.convert_to_tensor(input_tensor,
+                                            preferred_dtype=_dtypes.complex64)
       input_tensor.shape.with_rank_at_least(fft_rank)
+      complex_dtype = input_tensor.dtype
+      real_dtype = complex_dtype.real_dtype
       if fft_length is None:
         fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)
       else:
         fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
       input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
                                          is_reverse=True)
-      return ifft_fn(input_tensor, fft_length, name)
+      if not compat.forward_compatible(2019, 10, 12):
+        return ifft_fn(input_tensor, fft_length, name=name)
+      return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name)
   _irfft.__doc__ = ifft_fn.__doc__
   return _irfft
 
@@ -223,8 +238,10 @@ def _rfft_grad_helper(rank, irfft_fn):
   def _grad(op, grad):
     """A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
     fft_length = op.inputs[1]
+    complex_dtype = grad.dtype
+    real_dtype = complex_dtype.real_dtype
     input_shape = _array_ops.shape(op.inputs[0])
-    is_even = _math_ops.cast(1 - (fft_length[-1] % 2), _dtypes.complex64)
+    is_even = _math_ops.cast(1 - (fft_length[-1] % 2), complex_dtype)
 
     def _tile_for_broadcasting(matrix, t):
       expanded = _array_ops.reshape(
@@ -248,13 +265,13 @@ def _rfft_grad_helper(rank, irfft_fn):
           _array_ops.expand_dims(_math_ops.range(length), 0), (length, 1))
       b = _array_ops.transpose(a, [1, 0])
       return _math_ops.exp(
-          -2j * np.pi * _math_ops.cast(a * b, _dtypes.complex64) /
-          _math_ops.cast(length, _dtypes.complex64))
+          -2j * np.pi * _math_ops.cast(a * b, complex_dtype) /
+          _math_ops.cast(length, complex_dtype))
 
     def _ymask(length):
       """A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
       return _math_ops.cast(1 - 2 * (_math_ops.range(length) % 2),
-                            _dtypes.complex64)
+                            complex_dtype)
 
     y0 = grad[..., 0:1]
     if rank == 1:
@@ -288,7 +305,7 @@ def _rfft_grad_helper(rank, irfft_fn):
     # factor, plus some additional terms to make up for the components dropped
     # due to Hermitian symmetry.
     input_size = _math_ops.cast(
-        _fft_size_for_grad(op.inputs[0], rank), _dtypes.float32)
+        _fft_size_for_grad(op.inputs[0], rank), real_dtype)
     the_irfft = irfft_fn(grad, fft_length)
     return 0.5 * (the_irfft * input_size + _math_ops.real(extra_terms)), None
 
@@ -307,21 +324,27 @@ def _irfft_grad_helper(rank, rfft_fn):
     # graph we special-case the situation where the FFT length and last
     # dimension of the input are known at graph construction time.
     fft_length = op.inputs[1]
+    real_dtype = grad.dtype
+    if real_dtype not in (_dtypes.float32, _dtypes.float64):
+      raise ValueError("Unsupported IRFFT gradient dtype: %s" % real_dtype)
+    is_single = real_dtype == _dtypes.float32
+    complex_dtype = _dtypes.complex64 if is_single else _dtypes.complex128
     is_odd = _math_ops.mod(fft_length[-1], 2)
     input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
     mask = _array_ops.concat(
-        [[1.0], 2.0 * _array_ops.ones([input_last_dimension - 2 + is_odd]),
-         _array_ops.ones([1 - is_odd])], 0)
+        [[1.0], 2.0 * _array_ops.ones(
+            [input_last_dimension - 2 + is_odd], real_dtype),
+         _array_ops.ones([1 - is_odd], real_dtype)], 0)
 
     rsize = _math_ops.reciprocal(_math_ops.cast(
-        _fft_size_for_grad(grad, rank), _dtypes.float32))
+        _fft_size_for_grad(grad, rank), real_dtype))
 
     # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
     # factor and a mask. The mask scales the gradient for the Hermitian
     # symmetric components of the RFFT by a factor of two, since these
     # components are de-duplicated in the RFFT.
     the_rfft = rfft_fn(grad, fft_length)
-    return the_rfft * _math_ops.cast(rsize * mask, _dtypes.complex64), None
+    return the_rfft * _math_ops.cast(rsize * mask, complex_dtype), None
 
   return _grad
 
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index ba8bb05df48..f6d953d5df7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1626,15 +1626,15 @@ tf_module {
   }
   member_method {
     name: "IRFFT"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "IRFFT2D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "IRFFT3D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "Identity"
@@ -2858,15 +2858,15 @@ tf_module {
   }
   member_method {
     name: "RFFT"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
   }
   member_method {
     name: "RFFT2D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
   }
   member_method {
     name: "RFFT3D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
   }
   member_method {
     name: "RGBToHSV"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index ba8bb05df48..f6d953d5df7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1626,15 +1626,15 @@ tf_module {
   }
   member_method {
     name: "IRFFT"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "IRFFT2D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "IRFFT3D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "Identity"
@@ -2858,15 +2858,15 @@ tf_module {
   }
   member_method {
     name: "RFFT"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
   }
   member_method {
     name: "RFFT2D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
   }
   member_method {
     name: "RFFT3D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
   }
   member_method {
     name: "RGBToHSV"