Remove forward compatibility guards for float64/complex128 support since the time window has passed.

Also, add type checks for tf.signal.rfft/irfft to ensure correct input types are provided.

PiperOrigin-RevId: 277150624
Change-Id: I8bd5a0fbb6d4ce90351ccbf8ea2b094689fe0da5
This commit is contained in:
RJ Skerry-Ryan 2019-10-28 15:02:26 -07:00 committed by TensorFlower Gardener
parent 8c5f1b7943
commit 7e31db966d
5 changed files with 74 additions and 100 deletions

View File

@ -24,7 +24,6 @@ import itertools
from absl.testing import parameterized
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.framework import test_util
from tensorflow.python.ops.signal import dct_ops
from tensorflow.python.platform import test
@ -169,7 +168,6 @@ class DCTOpsTest(parameterized.TestCase, test.TestCase):
# "ortho" normalization is not implemented for type I.
if dct_type == 1 and norm == "ortho":
return
with compat.forward_compatibility_horizon(2019, 10, 13):
with self.session(use_gpu=True):
tol = 5e-4 if dtype == np.float32 else 1e-7
signals = np.random.rand(*shape).astype(dtype)

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import itertools
from absl.testing import parameterized
@ -26,7 +25,6 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -41,16 +39,6 @@ from tensorflow.python.platform import test
VALID_FFT_RANKS = (1, 2, 3)
def _forward_compat_context(np_dtype):
@contextlib.contextmanager
def null_context():
yield
if np_dtype in (np.float64, np.complex128):
return compat.forward_compatibility_horizon(2019, 10, 13)
else:
return null_context()
# TODO(rjryan): Investigate precision issues. We should be able to achieve
# better tolerances, at least for the complex128 tests.
class BaseFFTOpsTest(test.TestCase):
@ -101,7 +89,6 @@ class BaseFFTOpsTest(test.TestCase):
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
return loss
with _forward_compat_context(x.dtype):
((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
@ -117,7 +104,6 @@ class BaseFFTOpsTest(test.TestCase):
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
return loss
with _forward_compat_context(x.dtype):
(x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
f, [x], delta=1e-2)
self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
@ -301,14 +287,12 @@ class FFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
class RFFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
with _forward_compat_context(x.dtype), self.cached_session(
use_gpu=True) as sess:
with self.cached_session(use_gpu=True) as sess:
return sess.run(
self._tf_fft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
with _forward_compat_context(x.dtype), self.cached_session(
use_gpu=True) as sess:
with self.cached_session(use_gpu=True) as sess:
return sess.run(
self._tf_ifft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)

View File

@ -20,7 +20,6 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
@ -46,7 +45,6 @@ class MFCCTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(dtypes.float32, dtypes.float64)
def test_basic(self, dtype):
"""A basic test that the op runs on random input."""
with compat.forward_compatibility_horizon(2019, 10, 13):
signal = random_ops.random_normal((2, 3, 5), dtype=dtype)
self.evaluate(mfcc_ops.mfccs_from_log_mel_spectrograms(signal))

View File

@ -21,7 +21,6 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@ -165,8 +164,6 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
def test_stft_and_inverse_stft(self, signal_length, frame_length,
frame_step, fft_length, np_rtype, tol):
"""Test that spectral_ops.stft/inverse_stft match a NumPy implementation."""
# Enable float64 support for RFFTs.
with compat.forward_compatibility_horizon(2019, 10, 13):
signal = np.random.random(signal_length).astype(np_rtype)
self._compare(signal, frame_length, frame_step, fft_length, tol)
@ -189,8 +186,6 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
def test_stft_round_trip(self, signal_length, frame_length, frame_step,
fft_length, np_rtype, threshold,
corrected_threshold):
# Enable float64 support for RFFTs.
with compat.forward_compatibility_horizon(2019, 10, 13):
# Generate a random white Gaussian signal.
signal = np.random.normal(size=signal_length).astype(np_rtype)
@ -259,8 +254,6 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
def _compute_stft_gradient(signal, frame_length=32, frame_step=16,
fft_length=32):
"""Computes the gradient of the STFT with respect to `signal`."""
# Enable float64 support for RFFTs.
with compat.forward_compatibility_horizon(2019, 10, 13):
stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length)
magnitude_stft = math_ops.abs(stft)
loss = math_ops.reduce_sum(magnitude_stft)
@ -301,8 +294,6 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
(29, 5, 1, 10, np.float64, 1e-8, 1e-8))
def test_gradients_numerical(self, signal_length, frame_length, frame_step,
fft_length, np_rtype, forward_tol, backward_tol):
# Enable float64 support for RFFTs.
with compat.forward_compatibility_horizon(2019, 10, 13):
# TODO(rjryan): Investigate why STFT gradient error is so high.
signal = np.random.rand(signal_length).astype(np_rtype) * 2 - 1

View File

@ -19,7 +19,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import tensor_util as _tensor_util
@ -118,10 +117,15 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
[input_tensor, fft_length]) as name:
input_tensor = _ops.convert_to_tensor(input_tensor,
preferred_dtype=_dtypes.float32)
if input_tensor.dtype not in (_dtypes.float32, _dtypes.float64):
raise ValueError(
"RFFT requires tf.float32 or tf.float64 inputs, got: %s" %
input_tensor)
real_dtype = input_tensor.dtype
if real_dtype == _dtypes.float32:
complex_dtype = _dtypes.complex64
elif real_dtype == _dtypes.float64:
else:
assert real_dtype == _dtypes.float64
complex_dtype = _dtypes.complex128
input_tensor.shape.with_rank_at_least(fft_rank)
if fft_length is None:
@ -129,9 +133,6 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
else:
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
if not compat.forward_compatible(2019, 10, 12):
return fft_fn(input_tensor, fft_length, name=name)
return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype, name=name)
_rfft.__doc__ = fft_fn.__doc__
return _rfft
@ -147,6 +148,10 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
input_tensor = _ops.convert_to_tensor(input_tensor,
preferred_dtype=_dtypes.complex64)
input_tensor.shape.with_rank_at_least(fft_rank)
if input_tensor.dtype not in (_dtypes.complex64, _dtypes.complex128):
raise ValueError(
"IRFFT requires tf.complex64 or tf.complex128 inputs, got: %s" %
input_tensor)
complex_dtype = input_tensor.dtype
real_dtype = complex_dtype.real_dtype
if fft_length is None:
@ -155,8 +160,6 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
is_reverse=True)
if not compat.forward_compatible(2019, 10, 12):
return ifft_fn(input_tensor, fft_length, name=name)
return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name)
_irfft.__doc__ = ifft_fn.__doc__
return _irfft