diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
index 5ac288d8a34..e5e4e797cc5 100644
--- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
@@ -138,10 +138,20 @@ class RFFTOp : public GenericFftOp {
   explicit RFFTOp(OpKernelConstruction* ctx)
       : GenericFftOp(ctx, /*fft_type=*/FftType::RFFT, /*fft_rank=*/FFTRank) {}
 };
-REGISTER_XLA_OP(Name("RFFT").CompileTimeConstantInput("fft_length"), RFFTOp<1>);
-REGISTER_XLA_OP(Name("RFFT2D").CompileTimeConstantInput("fft_length"),
+REGISTER_XLA_OP(Name("RFFT")
+                    .TypeConstraint("Treal", DT_FLOAT)
+                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
+                    .CompileTimeConstantInput("fft_length"),
+                RFFTOp<1>);
+REGISTER_XLA_OP(Name("RFFT2D")
+                    .TypeConstraint("Treal", DT_FLOAT)
+                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
+                    .CompileTimeConstantInput("fft_length"),
                 RFFTOp<2>);
-REGISTER_XLA_OP(Name("RFFT3D").CompileTimeConstantInput("fft_length"),
+REGISTER_XLA_OP(Name("RFFT3D")
+                    .TypeConstraint("Treal", DT_FLOAT)
+                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
+                    .CompileTimeConstantInput("fft_length"),
                 RFFTOp<3>);
 
 template <int FFTRank>
@@ -150,11 +160,20 @@ class IRFFTOp : public GenericFftOp {
   explicit IRFFTOp(OpKernelConstruction* ctx)
       : GenericFftOp(ctx, /*fft_type=*/FftType::IRFFT, /*fft_rank=*/FFTRank) {}
 };
-REGISTER_XLA_OP(Name("IRFFT").CompileTimeConstantInput("fft_length"),
+REGISTER_XLA_OP(Name("IRFFT")
+                    .TypeConstraint("Treal", DT_FLOAT)
+                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
+                    .CompileTimeConstantInput("fft_length"),
                 IRFFTOp<1>);
-REGISTER_XLA_OP(Name("IRFFT2D").CompileTimeConstantInput("fft_length"),
+REGISTER_XLA_OP(Name("IRFFT2D")
+                    .TypeConstraint("Treal", DT_FLOAT)
+                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
+                    .CompileTimeConstantInput("fft_length"),
                 IRFFTOp<2>);
-REGISTER_XLA_OP(Name("IRFFT3D").CompileTimeConstantInput("fft_length"),
+REGISTER_XLA_OP(Name("IRFFT3D")
+                    .TypeConstraint("Treal", DT_FLOAT)
+                    .TypeConstraint("Tcomplex", DT_COMPLEX64)
+                    .CompileTimeConstantInput("fft_length"),
                 IRFFTOp<3>);
 
 }  // namespace
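Note: the XLA registrations above pin the new `Treal`/`Tcomplex` attrs to the existing float32/complex64 pair, so the XLA path is unchanged by this patch; double precision is served only by the regular CPU/GPU kernels changed below. A minimal sketch of what the patch enables at the Python level (illustrative only, not part of the patch):

```python
import tensorflow as tf

# Existing behavior: a float32 signal produces a complex64 spectrum.
x32 = tf.random.normal([8], dtype=tf.float32)
assert tf.signal.rfft(x32).dtype == tf.complex64

# With this patch (once the forward-compatibility window below has passed),
# a float64 signal produces a complex128 spectrum rather than requiring a
# cast down to float32 first.
x64 = tf.random.normal([8], dtype=tf.float64)
assert tf.signal.rfft(x64).dtype == tf.complex128
```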
diff --git a/tensorflow/core/api_def/base_api/api_def_FFT3D.pbtxt b/tensorflow/core/api_def/base_api/api_def_FFT3D.pbtxt
index abd2e67bceb..33de5f424c9 100644
--- a/tensorflow/core/api_def/base_api/api_def_FFT3D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_FFT3D.pbtxt
@@ -3,13 +3,13 @@ op {
   in_arg {
     name: "input"
     description: <<END
[... the api_def description changes and the header of the next file diff were lost in extraction; the diff resumes inside tensorflow/core/kernels/fft_ops.cc, in FFTBase::Compute ...]
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
+
+    if (IsReal()) {
+      if (IsForward()) {
+        OP_REQUIRES(
+            ctx,
+            (in.dtype() == DT_FLOAT && out->dtype() == DT_COMPLEX64) ||
+                (in.dtype() == DT_DOUBLE && out->dtype() == DT_COMPLEX128),
+            errors::InvalidArgument("Wrong types for forward real FFT: in=",
+                                    in.dtype(), " out=", out->dtype()));
+      } else {
+        OP_REQUIRES(
+            ctx,
+            (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_FLOAT) ||
+                (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_DOUBLE),
+            errors::InvalidArgument("Wrong types for backward real FFT: in=",
+                                    in.dtype(), " out=", out->dtype()));
+      }
+    } else {
+      OP_REQUIRES(
+          ctx,
+          (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_COMPLEX64) ||
+              (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_COMPLEX128),
+          errors::InvalidArgument("Wrong types for FFT: in=", in.dtype(),
+                                  " out=", out->dtype()));
+    }
+
     if (input_shape.num_elements() == 0) {
+      DCHECK_EQ(0, output_shape.num_elements());
       return;
     }
@@ -129,126 +157,159 @@ class FFTCPU : public FFTBase {
     const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
     auto device = ctx->eigen_device<CPUDevice>();
 
+    const bool is_complex128 =
+        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;
+
     if (!IsReal()) {
       // Compute the FFT using Eigen.
       constexpr auto direction =
           Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
-      if (in.dtype() == DT_COMPLEX64) {
+      if (is_complex128) {
+        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
+        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
+        auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
+        auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
+        output.device(device) =
+            input.template fft<Eigen::BothParts, direction>(axes);
+      } else {
+        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
         DCHECK_EQ(out->dtype(), DT_COMPLEX64);
         auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
         auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
         output.device(device) =
             input.template fft<Eigen::BothParts, direction>(axes);
-      } else {
-        DCHECK_EQ(DT_COMPLEX128, in.dtype());
-        DCHECK_EQ(DT_COMPLEX128, out->dtype());
-        auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
-        auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
-        output.device(device) =
-            input.template fft<Eigen::BothParts, direction>(axes);
       }
     } else {
       if (IsForward()) {
-        auto input = Tensor(in).flat_inner_dims<float, FFTRank + 1>();
-        const auto input_dims = input.dimensions();
-
-        // Slice input to fft_shape on its inner-most dimensions.
-        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
-        input_slice_sizes[0] = input_dims[0];
-        TensorShape temp_shape{input_dims[0]};
-        for (int i = 1; i <= FFTRank; ++i) {
-          input_slice_sizes[i] = fft_shape[i - 1];
-          temp_shape.AddDim(fft_shape[i - 1]);
+        if (is_complex128) {
+          DCHECK_EQ(in.dtype(), DT_DOUBLE);
+          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
+          DoRealForwardFFT<double, complex128>(ctx, fft_shape, in, out);
+        } else {
+          DCHECK_EQ(in.dtype(), DT_FLOAT);
+          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
+          DoRealForwardFFT<float, complex64>(ctx, fft_shape, in, out);
         }
-
-        auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
-        const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
-
-        // Compute the full FFT using a temporary tensor.
-        Tensor temp;
-        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
-                                               temp_shape, &temp));
-        auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
-        full_fft.device(device) =
-            input.slice(zero_start_indices, input_slice_sizes)
-                .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
-
-        // Slice away the negative frequency components.
-        output.device(device) =
-            full_fft.slice(zero_start_indices, output.dimensions());
       } else {
-        // Reconstruct the full FFT and take the inverse.
-        auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
-        auto output = out->flat_inner_dims<float, FFTRank + 1>();
-        const auto input_dims = input.dimensions();
-
-        // Calculate the shape of the temporary tensor for the full FFT and the
-        // region we will slice from input given fft_shape. We slice input to
-        // fft_shape on its inner-most dimensions, except the last (which we
-        // slice to fft_shape[-1] / 2 + 1).
-        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
-        input_slice_sizes[0] = input_dims[0];
-        TensorShape full_fft_shape;
-        full_fft_shape.AddDim(input_dims[0]);
-        for (auto i = 1; i <= FFTRank; i++) {
-          input_slice_sizes[i] =
-              i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
-          full_fft_shape.AddDim(fft_shape[i - 1]);
+        if (is_complex128) {
+          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
+          DCHECK_EQ(out->dtype(), DT_DOUBLE);
+          DoRealBackwardFFT<double, complex128>(ctx, fft_shape, in, out);
+        } else {
+          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
+          DCHECK_EQ(out->dtype(), DT_FLOAT);
+          DoRealBackwardFFT<float, complex64>(ctx, fft_shape, in, out);
         }
-
-        Tensor temp;
-        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
-                                               full_fft_shape, &temp));
-        auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
-
-        // Calculate the starting point and range of the source of
-        // negative frequency part.
-        auto neg_sizes = input_slice_sizes;
-        neg_sizes[FFTRank] =
-            fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
-        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
-        neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];
-
-        const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
-        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
-        neg_start_indices[FFTRank] = 1;
-
-        full_fft.slice(start_indices, input_slice_sizes).device(device) =
-            input.slice(start_indices, input_slice_sizes);
-
-        // First, conduct IFFTs on outer dimensions. We save computation (and
-        // avoid touching uninitialized memory) by slicing full_fft to the
-        // subregion we wrote input to.
-        if (FFTRank > 1) {
-          const auto outer_axes =
-              Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
-          full_fft.slice(start_indices, input_slice_sizes).device(device) =
-              full_fft.slice(start_indices, input_slice_sizes)
-                  .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(
-                      outer_axes);
-        }
-
-        // Reconstruct the full FFT by appending reversed and conjugated
-        // spectrum as the negative frequency part.
-        Eigen::array<bool, FFTRank + 1> reverse_last_axis;
-        for (auto i = 0; i <= FFTRank; i++) {
-          reverse_last_axis[i] = i == FFTRank;
-        }
-
-        if (neg_sizes[FFTRank] != 0) {
-          full_fft.slice(neg_target_indices, neg_sizes).device(device) =
-              full_fft.slice(neg_start_indices, neg_sizes)
-                  .reverse(reverse_last_axis)
-                  .conjugate();
-        }
-
-        auto inner_axis = Eigen::array<int, 1>{FFTRank};
-        output.device(device) =
-            full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(
-                inner_axis);
       }
     }
   }
+
+  template <typename RealT, typename ComplexT>
+  void DoRealForwardFFT(OpKernelContext* ctx, uint64* fft_shape,
+                        const Tensor& in, Tensor* out) {
+    // Create the axes (which are always trailing).
+    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
+    auto device = ctx->eigen_device<CPUDevice>();
+    auto input = Tensor(in).flat_inner_dims<RealT, FFTRank + 1>();
+    const auto input_dims = input.dimensions();
+
+    // Slice input to fft_shape on its inner-most dimensions.
+    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
+    input_slice_sizes[0] = input_dims[0];
+    TensorShape temp_shape{input_dims[0]};
+    for (int i = 1; i <= FFTRank; ++i) {
+      input_slice_sizes[i] = fft_shape[i - 1];
+      temp_shape.AddDim(fft_shape[i - 1]);
+    }
+
+    auto output = out->flat_inner_dims<ComplexT, FFTRank + 1>();
+    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
+
+    // Compute the full FFT using a temporary tensor.
+    Tensor temp;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
+                                           temp_shape, &temp));
+    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();
+    full_fft.device(device) =
+        input.slice(zero_start_indices, input_slice_sizes)
+            .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
+
+    // Slice away the negative frequency components.
+    output.device(device) =
+        full_fft.slice(zero_start_indices, output.dimensions());
+  }
+
+  template <typename RealT, typename ComplexT>
+  void DoRealBackwardFFT(OpKernelContext* ctx, uint64* fft_shape,
+                         const Tensor& in, Tensor* out) {
+    auto device = ctx->eigen_device<CPUDevice>();
+    // Reconstruct the full FFT and take the inverse.
+    auto input = Tensor(in).flat_inner_dims<ComplexT, FFTRank + 1>();
+    auto output = out->flat_inner_dims<RealT, FFTRank + 1>();
+    const auto input_dims = input.dimensions();
+
+    // Calculate the shape of the temporary tensor for the full FFT and the
+    // region we will slice from input given fft_shape. We slice input to
+    // fft_shape on its inner-most dimensions, except the last (which we
+    // slice to fft_shape[-1] / 2 + 1).
+    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
+    input_slice_sizes[0] = input_dims[0];
+    TensorShape full_fft_shape;
+    full_fft_shape.AddDim(input_dims[0]);
+    for (auto i = 1; i <= FFTRank; i++) {
+      input_slice_sizes[i] =
+          i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
+      full_fft_shape.AddDim(fft_shape[i - 1]);
+    }
+
+    Tensor temp;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
+                                           full_fft_shape, &temp));
+    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();
+
+    // Calculate the starting point and range of the source of
+    // negative frequency part.
+    auto neg_sizes = input_slice_sizes;
+    neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
+    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
+    neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];
+
+    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
+    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
+    neg_start_indices[FFTRank] = 1;
+
+    full_fft.slice(start_indices, input_slice_sizes).device(device) =
+        input.slice(start_indices, input_slice_sizes);
+
+    // First, conduct IFFTs on outer dimensions. We save computation (and
+    // avoid touching uninitialized memory) by slicing full_fft to the
+    // subregion we wrote input to.
+    if (FFTRank > 1) {
+      const auto outer_axes =
+          Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
+      full_fft.slice(start_indices, input_slice_sizes).device(device) =
+          full_fft.slice(start_indices, input_slice_sizes)
+              .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
+    }
+
+    // Reconstruct the full FFT by appending reversed and conjugated
+    // spectrum as the negative frequency part.
+    Eigen::array<bool, FFTRank + 1> reverse_last_axis;
+    for (auto i = 0; i <= FFTRank; i++) {
+      reverse_last_axis[i] = i == FFTRank;
+    }
+
+    if (neg_sizes[FFTRank] != 0) {
+      full_fft.slice(neg_target_indices, neg_sizes).device(device) =
+          full_fft.slice(neg_start_indices, neg_sizes)
+              .reverse(reverse_last_axis)
+              .conjugate();
+    }
+
+    auto inner_axis = Eigen::array<int, 1>{FFTRank};
+    output.device(device) =
+        full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
+  }
 };
 
 REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_CPU), FFTCPU<true, false, 1>);
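The `DoRealBackwardFFT` helper above rebuilds the full Hermitian spectrum from the half-spectrum before taking an ordinary inverse FFT. The same reconstruction in NumPy for rank 1, as an illustration only (not part of the patch):

```python
import numpy as np

n = 6
x = np.random.rand(n)
half = np.fft.rfft(x)  # n // 2 + 1 bins; negative frequencies are dropped.

# Append the reversed, conjugated positive frequencies (excluding DC and,
# for even n, the Nyquist bin) to recover the full length-n spectrum.
full = np.concatenate([half, np.conj(half[1:(n + 1) // 2])[::-1]])
np.testing.assert_allclose(np.fft.ifft(full).real, x, atol=1e-12)
```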
@@ -390,16 +451,19 @@ class FFTGPUBase : public FFTBase {
     }
 
     constexpr bool kInPlaceFft = false;
-    const bool is_complex128 = in.dtype() == DT_COMPLEX128;
-    // complex128 real FFT is not supported yet.
-    DCHECK(!IsReal() || !is_complex128);
+    const bool is_complex128 =
+        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;
 
     const auto kFftType =
-        IsReal() ? (IsForward() ? se::fft::Type::kR2C : se::fft::Type::kC2R)
-                 : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
-                                                 : se::fft::Type::kC2CForward)
-                                : (is_complex128 ? se::fft::Type::kZ2ZInverse
-                                                 : se::fft::Type::kC2CInverse));
+        IsReal()
+            ? (IsForward()
+                   ? (is_complex128 ? se::fft::Type::kD2Z : se::fft::Type::kR2C)
+                   : (is_complex128 ? se::fft::Type::kZ2D
+                                    : se::fft::Type::kC2R))
+            : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
+                                            : se::fft::Type::kC2CForward)
+                           : (is_complex128 ? se::fft::Type::kZ2ZInverse
+                                            : se::fft::Type::kC2CInverse));
 
     CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
     auto plan =
@@ -410,67 +474,80 @@ class FFTGPUBase : public FFTBase {
 
     if (IsReal()) {
       if (IsForward()) {
-        auto src = AsDeviceMemory<float>(in.flat<float>().data());
-        auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
-        OP_REQUIRES(
-            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
-            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
-                             " in.shape=", input_shape.DebugString()));
+        if (is_complex128) {
+          DCHECK_EQ(in.dtype(), DT_DOUBLE);
+          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
+          DoFFTInternal<double, complex128>(ctx, stream, plan.get(), kFftType,
+                                            output_distance, in, out);
+        } else {
+          DCHECK_EQ(in.dtype(), DT_FLOAT);
+          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
+          DoFFTInternal<float, complex64>(ctx, stream, plan.get(), kFftType,
+                                          output_distance, in, out);
+        }
       } else {
-        auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
-        auto dst = AsDeviceMemory<float>(out->flat<float>().data());
-        OP_REQUIRES(
-            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
-            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
-                             " in.shape=", input_shape.DebugString()));
-        auto alpha = 1.f / output_distance;
-        OP_REQUIRES(
-            ctx,
-            stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
-                .ok(),
-            errors::Internal("BlasScal failed : in.shape=",
-                             input_shape.DebugString()));
+        if (is_complex128) {
+          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
+          DCHECK_EQ(out->dtype(), DT_DOUBLE);
+          DoFFTInternal<complex128, double>(ctx, stream, plan.get(), kFftType,
+                                            output_distance, in, out);
+        } else {
+          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
+          DCHECK_EQ(out->dtype(), DT_FLOAT);
+          DoFFTInternal<complex64, float>(ctx, stream, plan.get(), kFftType,
+                                          output_distance, in, out);
+        }
       }
     } else {
-      if (!is_complex128) {
-        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
-        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
-        auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
-        auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
-        OP_REQUIRES(
-            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
-            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
-                             " in.shape=", input_shape.DebugString()));
-        if (!IsForward()) {
-          float alpha = 1.f / output_distance;
-          OP_REQUIRES(
-              ctx,
-              stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
-                  .ok(),
-              errors::Internal("BlasScal failed : in.shape=",
-                               input_shape.DebugString()));
-        }
-      } else {
+      if (is_complex128) {
         DCHECK_EQ(in.dtype(), DT_COMPLEX128);
         DCHECK_EQ(out->dtype(), DT_COMPLEX128);
-        auto src = AsDeviceMemory<complex128>(in.flat<complex128>().data());
-        auto dst = AsDeviceMemory<complex128>(out->flat<complex128>().data());
-        OP_REQUIRES(
-            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
-            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
-                             " in.shape=", input_shape.DebugString()));
-        if (!IsForward()) {
-          double alpha = 1.0 / output_distance;
-          OP_REQUIRES(
-              ctx,
-              stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
-                  .ok(),
-              errors::Internal("BlasScal failed : in.shape=",
-                               input_shape.DebugString()));
-        }
+        DoFFTInternal<complex128, complex128>(ctx, stream, plan.get(),
+                                              kFftType, output_distance, in,
+                                              out);
+      } else {
+        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
+        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
+        DoFFTInternal<complex64, complex64>(ctx, stream, plan.get(), kFftType,
+                                            output_distance, in, out);
       }
     }
   }
+
+ private:
+  template <typename T>
+  struct RealTypeFromComplexType {
+    typedef T RealT;
+  };
+
+  template <typename T>
+  struct RealTypeFromComplexType<std::complex<T>> {
+    typedef T RealT;
+  };
+
+  template <typename InT, typename OutT>
+  void DoFFTInternal(OpKernelContext* ctx, se::Stream* stream,
+                     se::fft::Plan* plan, const se::fft::Type fft_type,
+                     const uint64 output_distance, const Tensor& in,
+                     Tensor* out) {
+    auto src = AsDeviceMemory<InT>(in.flat<InT>().data());
+    auto dst = AsDeviceMemory<OutT>(out->flat<OutT>().data());
+    const TensorShape& input_shape = in.shape();
+    const TensorShape& output_shape = out->shape();
+    OP_REQUIRES(
+        ctx, stream->ThenFft(plan, src, &dst).ok(),
+        errors::Internal("fft failed : type=", static_cast<int>(fft_type),
+                         " in.shape=", input_shape.DebugString()));
+    if (!IsForward()) {
+      typedef typename RealTypeFromComplexType<OutT>::RealT RealT;
+      RealT alpha = 1.0 / output_distance;
+      OP_REQUIRES(
+          ctx,
+          stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
+              .ok(),
+          errors::Internal("BlasScal failed : in.shape=",
+                           input_shape.DebugString()));
+    }
+  }
 };
 
 int64 FFTGPUBase::CufftScratchSize = GetCufftWorkspaceLimit(
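Background on the `ThenBlasScal` call that `DoFFTInternal` now centralizes: cuFFT inverse transforms are unnormalized, so the result is rescaled by `1 / output_distance` (the per-batch element count). A NumPy analogue of that normalization (illustrative only, not part of the patch):

```python
import numpy as np

x = np.random.rand(8) + 1j * np.random.rand(8)
spectrum = np.fft.fft(x)

# An unnormalized inverse DFT (what cuFFT computes) is N times the true
# inverse; scaling by 1 / N recovers the original signal.
unnormalized_inverse = np.conj(np.fft.fft(np.conj(spectrum)))
np.testing.assert_allclose(unnormalized_inverse / len(x), x)
```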
diff --git a/tensorflow/core/ops/spectral_ops.cc b/tensorflow/core/ops/spectral_ops.cc
index b1ae7040f02..3b9b962143b 100644
--- a/tensorflow/core/ops/spectral_ops.cc
+++ b/tensorflow/core/ops/spectral_ops.cc
@@ -109,39 +109,51 @@ Status RFFTShape(InferenceContext* c, const bool forward, const int rank) {
 }
 
 REGISTER_OP("RFFT")
-    .Input("input: float")
+    .Input("input: Treal")
     .Input("fft_length: int32")
-    .Output("output: complex64")
+    .Output("output: Tcomplex")
+    .Attr("Treal: {float32, float64} = DT_FLOAT")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 1); });
 
 REGISTER_OP("IRFFT")
-    .Input("input: complex64")
+    .Input("input: Tcomplex")
    .Input("fft_length: int32")
-    .Output("output: float")
+    .Output("output: Treal")
+    .Attr("Treal: {float32, float64} = DT_FLOAT")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 1); });
 
 REGISTER_OP("RFFT2D")
-    .Input("input: float")
+    .Input("input: Treal")
     .Input("fft_length: int32")
-    .Output("output: complex64")
+    .Output("output: Tcomplex")
+    .Attr("Treal: {float32, float64} = DT_FLOAT")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 2); });
 
 REGISTER_OP("IRFFT2D")
-    .Input("input: complex64")
+    .Input("input: Tcomplex")
     .Input("fft_length: int32")
-    .Output("output: float")
+    .Output("output: Treal")
+    .Attr("Treal: {float32, float64} = DT_FLOAT")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 2); });
 
 REGISTER_OP("RFFT3D")
-    .Input("input: float")
+    .Input("input: Treal")
     .Input("fft_length: int32")
-    .Output("output: complex64")
+    .Output("output: Tcomplex")
+    .Attr("Treal: {float32, float64} = DT_FLOAT")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 3); });
 
 REGISTER_OP("IRFFT3D")
-    .Input("input: complex64")
+    .Input("input: Tcomplex")
     .Input("fft_length: int32")
-    .Output("output: float")
+    .Output("output: Treal")
+    .Attr("Treal: {float32, float64} = DT_FLOAT")
+    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
     .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 3); });
 
 // Deprecated ops:
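With the op definitions above, both type attrs default to the old float32/complex64 behavior, and the attr that cannot be inferred from `input` surfaces as a keyword argument of the generated raw op (this is what the golden API updates at the end of the patch record). A hedged sketch of driving the raw ops directly:

```python
import tensorflow as tf

x = tf.random.normal([8], dtype=tf.float64)
# For RFFT, Treal is inferred from `input`; Tcomplex picks the output type.
y = tf.raw_ops.RFFT(input=x, fft_length=[8], Tcomplex=tf.complex128)
assert y.dtype == tf.complex128

# For IRFFT, Tcomplex is inferred from `input`; Treal picks the output type.
z = tf.raw_ops.IRFFT(input=y, fft_length=[8], Treal=tf.float64)
assert z.dtype == tf.float64
```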
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 54ab48053e9..86b5b9ebd8b 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -48,7 +48,6 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops.signal import fft_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training import training
 from tensorflow.python.util import nest
@@ -1579,23 +1578,6 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
     with self.assertRaisesRegexp(ValueError, 'ndarray'):
       g.watch(np.array(1.))
 
-  def testOpWithNoAttrs(self):
-
-    @function.defun(autograph=False)
-    def f():
-      with backprop.GradientTape() as tape:
-        xs = random_ops.random_normal([10, 32])
-        tape.watch(xs)
-        # The `rfft()` op has no defined attrs, which exercises a different
-        # branch in the Python op wrapper code generator for recording
-        # gradients.
-        ys = fft_ops.rfft(xs)
-        self.assertEmpty(ys.op.node_def.attr)
-        gs = tape.gradient(ys, xs)
-        self.assertIsNotNone(gs)
-
-    f.get_concrete_function()
-
 
 class JacobianTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/signal/BUILD b/tensorflow/python/kernel_tests/signal/BUILD
index 3806783ca11..b260fff573e 100644
--- a/tensorflow/python/kernel_tests/signal/BUILD
+++ b/tensorflow/python/kernel_tests/signal/BUILD
@@ -37,6 +37,7 @@ cuda_py_tests(
 
 cuda_py_tests(
     name = "fft_ops_test",
+    # TODO(rjryan): Parameterize the test to reduce the time it takes.
     size = "medium",
     srcs = ["fft_ops_test.py"],
     additional_deps = [
@@ -46,7 +47,7 @@ cuda_py_tests(
         "//tensorflow/python:math_ops",
         "//tensorflow/python/ops/signal",
     ],
-    shard_count = 4,
+    shard_count = 8,
    tags = [
         "no_rocm",
         "optonly",
diff --git a/tensorflow/python/kernel_tests/signal/fft_ops_test.py b/tensorflow/python/kernel_tests/signal/fft_ops_test.py
index 5745da73045..be7e5c18abc 100644
--- a/tensorflow/python/kernel_tests/signal/fft_ops_test.py
+++ b/tensorflow/python/kernel_tests/signal/fft_ops_test.py
@@ -18,10 +18,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import contextlib
+
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -36,6 +39,18 @@ from tensorflow.python.platform import test
 VALID_FFT_RANKS = (1, 2, 3)
 
 
+def _forward_compat_context(np_dtype):
+  @contextlib.contextmanager
+  def null_context():
+    yield
+  if np_dtype in (np.float64, np.complex128):
+    return compat.forward_compatibility_horizon(2019, 10, 13)
+  else:
+    return null_context()
+
+
+# TODO(rjryan): Investigate precision issues. We should be able to achieve
+# better tolerances, at least for the complex128 tests.
 class BaseFFTOpsTest(test.TestCase):
 
   def _compare(self, x, rank, fft_length=None, use_placeholder=False,
@@ -84,8 +99,9 @@ class BaseFFTOpsTest(test.TestCase):
       loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
       return loss
 
-    ((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
-        gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
+    with _forward_compat_context(x.dtype):
+      ((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
+          gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
     self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
     self.assertAllClose(y_jacob_t, y_jacob_n, rtol=rtol, atol=atol)
 
@@ -99,8 +115,9 @@ class BaseFFTOpsTest(test.TestCase):
       loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
       return loss
 
-    (x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
-        f, [x], delta=1e-2)
+    with _forward_compat_context(x.dtype):
+      (x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
+          f, [x], delta=1e-2)
     self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
@@ -174,13 +191,12 @@ class FFTOpsTest(BaseFFTOpsTest):
               (4,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
 
   def test_large_batch(self):
-    if test.is_gpu_available(cuda_only=True):
-      rank = 1
-      for dims in xrange(rank, rank + 3):
-        for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-5)):
-          self._compare(
-              np.mod(np.arange(np.power(128, dims)), 10).reshape(
-                  (128,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
+    rank = 1
+    for dims in xrange(rank, rank + 3):
+      for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 5e-5)):
+        self._compare(
+            np.mod(np.arange(np.power(128, dims)), 10).reshape(
+                (128,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
 
   # TODO(yangzihao): Disable before we can figure out a way to
   # properly test memory fail for large batch fft.
@@ -277,17 +293,15 @@ class FFTOpsTest(BaseFFTOpsTest):
 
 @test_util.run_all_in_graph_and_eager_modes
 class RFFTOpsTest(BaseFFTOpsTest):
 
-  def _compare_backward(self, x, rank, fft_length=None, use_placeholder=False):
-    super(RFFTOpsTest, self)._compare_backward(x, rank, fft_length,
-                                               use_placeholder)
-
   def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
-    with self.cached_session(use_gpu=True) as sess:
+    with _forward_compat_context(x.dtype), self.cached_session(
+        use_gpu=True) as sess:
       return sess.run(
           self._tf_fft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
 
   def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
-    with self.cached_session(use_gpu=True) as sess:
+    with _forward_compat_context(x.dtype), self.cached_session(
+        use_gpu=True) as sess:
       return sess.run(
           self._tf_ifft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
 
@@ -332,61 +346,74 @@ class RFFTOpsTest(BaseFFTOpsTest):
       raise ValueError("invalid rank")
 
   def test_empty(self):
-    for rank in VALID_FFT_RANKS:
-      for dims in xrange(rank, rank + 3):
-        x = np.zeros((0,) * dims).astype(np.float32)
-        self.assertEqual(x.shape, self._tf_fft(x, rank).shape)
-        x = np.zeros((0,) * dims).astype(np.complex64)
-        self.assertEqual(x.shape, self._tf_ifft(x, rank).shape)
+    for np_rtype, np_ctype in ((np.float32, np.complex64),
+                               (np.float64, np.complex128)):
+      for rank in VALID_FFT_RANKS:
+        for dims in xrange(rank, rank + 3):
+          x = np.zeros((0,) * dims).astype(np_rtype)
+          self.assertEqual(x.shape, self._tf_fft(x, rank).shape)
+          x = np.zeros((0,) * dims).astype(np_ctype)
+          self.assertEqual(x.shape, self._tf_ifft(x, rank).shape)
 
   def test_basic(self):
-    for rank in VALID_FFT_RANKS:
-      for dims in xrange(rank, rank + 3):
-        for size in (5, 6):
-          inner_dim = size // 2 + 1
-          r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
-              (size,) * dims)
-          self._compare_forward(r2c.astype(np.float32), rank, (size,) * rank)
-          c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
-                       10).reshape((size,) * (dims - 1) + (inner_dim,))
-          self._compare_backward(
-              c2r.astype(np.complex64), rank, (size,) * rank)
+    for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
+                                    (np.float64, np.complex128, 5e-5)):
+      for rank in VALID_FFT_RANKS:
+        for dims in xrange(rank, rank + 3):
+          for size in (5, 6):
+            inner_dim = size // 2 + 1
+            r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
+                (size,) * dims)
+            self._compare_forward(r2c.astype(np_rtype), rank, (size,) * rank,
+                                  rtol=tol, atol=tol)
+            c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
+                         10).reshape((size,) * (dims - 1) + (inner_dim,))
+            self._compare_backward(
+                c2r.astype(np_ctype), rank, (size,) * rank, rtol=tol, atol=tol)
 
   def test_large_batch(self):
-    if test.is_gpu_available(cuda_only=True):
-      rank = 1
+    rank = 1
+    for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
+                                    (np.float64, np.complex128, 1e-5)):
       for dims in xrange(rank, rank + 3):
         for size in (64, 128):
           inner_dim = size // 2 + 1
           r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
               (size,) * dims)
-          self._compare_forward(r2c.astype(np.float32), rank, (size,) * rank)
+          self._compare_forward(r2c.astype(np_rtype), rank, (size,) * rank,
+                                rtol=tol, atol=tol)
           c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
                        10).reshape((size,) * (dims - 1) + (inner_dim,))
-          self._compare_backward(c2r.astype(np.complex64), rank, (size,) * rank)
+          self._compare_backward(c2r.astype(np_ctype), rank, (size,) * rank,
+                                 rtol=tol, atol=tol)
 
   def test_placeholder(self):
     if context.executing_eagerly():
       return
-    for rank in VALID_FFT_RANKS:
-      for dims in xrange(rank, rank + 3):
-        for size in (5, 6):
-          inner_dim = size // 2 + 1
-          r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
-              (size,) * dims)
-          self._compare_forward(
-              r2c.astype(np.float32),
-              rank, (size,) * rank,
-              use_placeholder=True)
-          c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
-                       10).reshape((size,) * (dims - 1) + (inner_dim,))
-          self._compare_backward(
-              c2r.astype(np.complex64),
-              rank, (size,) * rank,
-              use_placeholder=True)
+    for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
+                                    (np.float64, np.complex128, 1e-8)):
+      for rank in VALID_FFT_RANKS:
+        for dims in xrange(rank, rank + 3):
+          for size in (5, 6):
+            inner_dim = size // 2 + 1
+            r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape(
+                (size,) * dims)
+            self._compare_forward(
+                r2c.astype(np_rtype),
+                rank, (size,) * rank,
+                use_placeholder=True,
+                rtol=tol, atol=tol)
+            c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
+                         10).reshape((size,) * (dims - 1) + (inner_dim,))
+            self._compare_backward(
+                c2r.astype(np_ctype),
+                rank, (size,) * rank,
+                use_placeholder=True,
+                rtol=tol, atol=tol)
 
-  def test_fft_lenth(self):
-    if test.is_gpu_available(cuda_only=True):
+  def test_fft_length(self):
+    for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
+                                    (np.float64, np.complex128, 8e-5)):
       for rank in VALID_FFT_RANKS:
         for dims in xrange(rank, rank + 3):
           for size in (5, 6):
@@ -397,36 +424,44 @@ class RFFTOpsTest(BaseFFTOpsTest):
                            10).reshape((size,) * (dims - 1) + (inner_dim,))
             # Test truncation (FFT size < dimensions).
             fft_length = (size - 2,) * rank
-            self._compare_forward(r2c.astype(np.float32), rank, fft_length)
-            self._compare_backward(c2r.astype(np.complex64), rank, fft_length)
+            self._compare_forward(r2c.astype(np_rtype), rank, fft_length,
+                                  rtol=tol, atol=tol)
+            self._compare_backward(c2r.astype(np_ctype), rank, fft_length,
+                                   rtol=tol, atol=tol)
             # Confirm it works with unknown shapes as well.
             if not context.executing_eagerly():
               self._compare_forward(
-                  r2c.astype(np.float32),
+                  r2c.astype(np_rtype),
                   rank,
                   fft_length,
-                  use_placeholder=True)
+                  use_placeholder=True,
+                  rtol=tol, atol=tol)
               self._compare_backward(
-                  c2r.astype(np.complex64),
+                  c2r.astype(np_ctype),
                   rank,
                   fft_length,
-                  use_placeholder=True)
+                  use_placeholder=True,
+                  rtol=tol, atol=tol)
             # Test padding (FFT size > dimensions).
             fft_length = (size + 2,) * rank
-            self._compare_forward(r2c.astype(np.float32), rank, fft_length)
-            self._compare_backward(c2r.astype(np.complex64), rank, fft_length)
+            self._compare_forward(r2c.astype(np_rtype), rank, fft_length,
+                                  rtol=tol, atol=tol)
+            self._compare_backward(c2r.astype(np_ctype), rank, fft_length,
+                                   rtol=tol, atol=tol)
             # Confirm it works with unknown shapes as well.
             if not context.executing_eagerly():
               self._compare_forward(
-                  r2c.astype(np.float32),
+                  r2c.astype(np_rtype),
                   rank,
                   fft_length,
-                  use_placeholder=True)
+                  use_placeholder=True,
+                  rtol=tol, atol=tol)
               self._compare_backward(
-                  c2r.astype(np.complex64),
+                  c2r.astype(np_ctype),
                   rank,
                   fft_length,
-                  use_placeholder=True)
+                  use_placeholder=True,
+                  rtol=tol, atol=tol)
 
@@ -442,14 +477,20 @@ class RFFTOpsTest(BaseFFTOpsTest):
       ret = (re + im * 1j).reshape(shape)
       return ret
 
-    for rank in VALID_FFT_RANKS:
-      for dims in xrange(rank, rank + 3):
-        for size in (5, 6):
-          inner_dim = size // 2 + 1
-          self._compare_forward(gen_real((size,) * dims), rank, (size,) * rank)
-          complex_dims = (size,) * (dims - 1) + (inner_dim,)
-          self._compare_backward(
-              gen_complex(complex_dims), rank, (size,) * rank)
+    for np_rtype, np_ctype, tol in ((np.float32, np.complex64, 1e-4),
+                                    (np.float64, np.complex128, 1e-5)):
+      for rank in VALID_FFT_RANKS:
+        for dims in xrange(rank, rank + 3):
+          for size in (5, 6):
+            inner_dim = size // 2 + 1
+            self._compare_forward(gen_real((size,) * dims).astype(np_rtype),
+                                  rank, (size,) * rank,
+                                  rtol=tol, atol=tol)
+            complex_dims = (size,) * (dims - 1) + (inner_dim,)
+            self._compare_backward(
+                gen_complex(complex_dims).astype(np_ctype),
+                rank, (size,) * rank,
+                rtol=tol, atol=tol)
 
   def test_error(self):
     # TODO(rjryan): Fix this test under Eager.
@@ -510,30 +551,36 @@ class RFFTOpsTest(BaseFFTOpsTest):
         self.evaluate(irfft_fn(x, fft_length))
 
   def test_grad_simple(self):
-    for rank in VALID_FFT_RANKS:
-      # rfft3d/irfft3d do not have gradients yet.
-      if rank == 3:
-        continue
-      for dims in xrange(rank, rank + 2):
-        for size in (5, 6):
-          re = np.ones(shape=(size,) * dims, dtype=np.float32)
-          im = -np.ones(shape=(size,) * dims, dtype=np.float32)
-          self._check_grad_real(self._tf_fft_for_rank(rank), re)
-          self._check_grad_complex(
-              self._tf_ifft_for_rank(rank), re, im, result_is_complex=False)
+    for np_rtype, tol in ((np.float32, 1e-3), (np.float64, 1e-10)):
+      for rank in VALID_FFT_RANKS:
+        # rfft3d/irfft3d do not have gradients yet.
+        if rank == 3:
+          continue
+        for dims in xrange(rank, rank + 2):
+          for size in (5, 6):
+            re = np.ones(shape=(size,) * dims, dtype=np_rtype)
+            im = -np.ones(shape=(size,) * dims, dtype=np_rtype)
+            self._check_grad_real(self._tf_fft_for_rank(rank), re,
+                                  rtol=tol, atol=tol)
+            self._check_grad_complex(
+                self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
+                rtol=tol, atol=tol)
 
   def test_grad_random(self):
-    for rank in VALID_FFT_RANKS:
-      # rfft3d/irfft3d do not have gradients yet.
-      if rank == 3:
-        continue
-      for dims in xrange(rank, rank + 2):
-        for size in (5, 6):
-          re = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
-          im = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
-          self._check_grad_real(self._tf_fft_for_rank(rank), re)
-          self._check_grad_complex(
-              self._tf_ifft_for_rank(rank), re, im, result_is_complex=False)
+    for np_rtype, tol in ((np.float32, 1e-2), (np.float64, 1e-10)):
+      for rank in VALID_FFT_RANKS:
+        # rfft3d/irfft3d do not have gradients yet.
+        if rank == 3:
+          continue
+        for dims in xrange(rank, rank + 2):
+          for size in (5, 6):
+            re = np.random.rand(*((size,) * dims)).astype(np_rtype) * 2 - 1
+            im = np.random.rand(*((size,) * dims)).astype(np_rtype) * 2 - 1
+            self._check_grad_real(self._tf_fft_for_rank(rank), re,
+                                  rtol=tol, atol=tol)
+            self._check_grad_complex(
+                self._tf_ifft_for_rank(rank), re, im, result_is_complex=False,
+                rtol=tol, atol=tol)
 
 
 @test_util.run_all_in_graph_and_eager_modes
diff --git a/tensorflow/python/ops/signal/fft_ops.py b/tensorflow/python/ops/signal/fft_ops.py
index 0e18c217fc5..cfe7799ac57 100644
--- a/tensorflow/python/ops/signal/fft_ops.py
+++ b/tensorflow/python/ops/signal/fft_ops.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import dtypes as _dtypes
 from tensorflow.python.framework import ops as _ops
 from tensorflow.python.framework import tensor_util as _tensor_util
@@ -115,14 +116,23 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
     """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
     with _ops.name_scope(name, default_name,
                          [input_tensor, fft_length]) as name:
-      input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.float32)
+      input_tensor = _ops.convert_to_tensor(input_tensor,
+                                            preferred_dtype=_dtypes.float32)
+      real_dtype = input_tensor.dtype
+      if real_dtype == _dtypes.float32:
+        complex_dtype = _dtypes.complex64
+      elif real_dtype == _dtypes.float64:
+        complex_dtype = _dtypes.complex128
       input_tensor.shape.with_rank_at_least(fft_rank)
       if fft_length is None:
         fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)
      else:
         fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
       input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
-      return fft_fn(input_tensor, fft_length, name)
+
+      if not compat.forward_compatible(2019, 10, 12):
+        return fft_fn(input_tensor, fft_length, name=name)
+      return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype, name=name)
   _rfft.__doc__ = fft_fn.__doc__
   return _rfft
 
@@ -134,15 +144,20 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
     """Wrapper irfft* that infers fft_length argument."""
     with _ops.name_scope(name, default_name,
                          [input_tensor, fft_length]) as name:
-      input_tensor = _ops.convert_to_tensor(input_tensor, _dtypes.complex64)
+      input_tensor = _ops.convert_to_tensor(input_tensor,
+                                            preferred_dtype=_dtypes.complex64)
       input_tensor.shape.with_rank_at_least(fft_rank)
+      complex_dtype = input_tensor.dtype
+      real_dtype = complex_dtype.real_dtype
       if fft_length is None:
         fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)
       else:
         fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
       input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
                                          is_reverse=True)
-      return ifft_fn(input_tensor, fft_length, name)
+      if not compat.forward_compatible(2019, 10, 12):
+        return ifft_fn(input_tensor, fft_length, name=name)
+      return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name)
   _irfft.__doc__ = ifft_fn.__doc__
   return _irfft
 
@@ -223,8 +238,10 @@ def _rfft_grad_helper(rank, irfft_fn):
   def _grad(op, grad):
     """A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
     fft_length = op.inputs[1]
+    complex_dtype = grad.dtype
+    real_dtype = complex_dtype.real_dtype
     input_shape = _array_ops.shape(op.inputs[0])
-    is_even = _math_ops.cast(1 - (fft_length[-1] % 2), _dtypes.complex64)
+    is_even = _math_ops.cast(1 - (fft_length[-1] % 2), complex_dtype)
 
     def _tile_for_broadcasting(matrix, t):
       expanded = _array_ops.reshape(
@@ -248,13 +265,13 @@ def _rfft_grad_helper(rank, irfft_fn):
           _array_ops.expand_dims(_math_ops.range(length), 0), (length, 1))
       b = _array_ops.transpose(a, [1, 0])
       return _math_ops.exp(
-          -2j * np.pi * _math_ops.cast(a * b, _dtypes.complex64) /
-          _math_ops.cast(length, _dtypes.complex64))
+          -2j * np.pi * _math_ops.cast(a * b, complex_dtype) /
+          _math_ops.cast(length, complex_dtype))
 
     def _ymask(length):
       """A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
       return _math_ops.cast(1 - 2 * (_math_ops.range(length) % 2),
-                            _dtypes.complex64)
+                            complex_dtype)
 
     y0 = grad[..., 0:1]
     if rank == 1:
@@ -288,7 +305,7 @@ def _rfft_grad_helper(rank, irfft_fn):
     # factor, plus some additional terms to make up for the components dropped
    # due to Hermitian symmetry.
     input_size = _math_ops.cast(
-        _fft_size_for_grad(op.inputs[0], rank), _dtypes.float32)
+        _fft_size_for_grad(op.inputs[0], rank), real_dtype)
     the_irfft = irfft_fn(grad, fft_length)
     return 0.5 * (the_irfft * input_size + _math_ops.real(extra_terms)), None
 
@@ -307,21 +324,27 @@ def _irfft_grad_helper(rank, rfft_fn):
     # graph we special-case the situation where the FFT length and last
     # dimension of the input are known at graph construction time.
     fft_length = op.inputs[1]
+    real_dtype = grad.dtype
+    if real_dtype == _dtypes.float32:
+      complex_dtype = _dtypes.complex64
+    elif real_dtype == _dtypes.float64:
+      complex_dtype = _dtypes.complex128
     is_odd = _math_ops.mod(fft_length[-1], 2)
     input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
     mask = _array_ops.concat(
-        [[1.0], 2.0 * _array_ops.ones([input_last_dimension - 2 + is_odd]),
-         _array_ops.ones([1 - is_odd])], 0)
+        [[1.0], 2.0 * _array_ops.ones(
+            [input_last_dimension - 2 + is_odd], real_dtype),
+         _array_ops.ones([1 - is_odd], real_dtype)], 0)
 
     rsize = _math_ops.reciprocal(_math_ops.cast(
-        _fft_size_for_grad(grad, rank), _dtypes.float32))
+        _fft_size_for_grad(grad, rank), real_dtype))
 
     # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
     # factor and a mask. The mask scales the gradient for the Hermitian
     # symmetric components of the RFFT by a factor of two, since these
     # components are de-duplicated in the RFFT.
     the_rfft = rfft_fn(grad, fft_length)
-    return the_rfft * _math_ops.cast(rsize * mask, _dtypes.complex64), None
+    return the_rfft * _math_ops.cast(rsize * mask, complex_dtype), None
 
   return _grad
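The `compat.forward_compatible(2019, 10, 12)` gates above keep GraphDefs emitted by new binaries parseable by older binaries within the forward-compatibility window: until the horizon passes, the wrappers call the op without the new attr. Tests opt in to the new behavior early by pushing the horizon forward, which is what `_forward_compat_context` in fft_ops_test.py does. A small sketch of the mechanism (dates mirror the patch):

```python
from tensorflow.python.compat import compat

# Outside the context manager this is False until the real date passes.
with compat.forward_compatibility_horizon(2019, 10, 13):
  assert compat.forward_compatible(2019, 10, 12)
```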
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index ba8bb05df48..f6d953d5df7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1626,15 +1626,15 @@ tf_module {
   }
   member_method {
     name: "IRFFT"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "IRFFT2D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "IRFFT3D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "Identity"
@@ -2858,15 +2858,15 @@ tf_module {
   }
   member_method {
     name: "RFFT"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
   }
   member_method {
     name: "RFFT2D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
   }
   member_method {
     name: "RFFT3D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
   }
   member_method {
     name: "RGBToHSV"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index ba8bb05df48..f6d953d5df7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1626,15 +1626,15 @@ tf_module {
   }
   member_method {
     name: "IRFFT"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "IRFFT2D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "IRFFT3D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Treal\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
   member_method {
     name: "Identity"
@@ -2858,15 +2858,15 @@ tf_module {
   }
   member_method {
     name: "RFFT"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
  }
   member_method {
     name: "RFFT2D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
   }
   member_method {
     name: "RFFT3D"
-    argspec: "args=[\'input\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'input\', \'fft_length\', \'Tcomplex\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\'], "
   }
   member_method {
     name: "RGBToHSV"