Remove the forward-compatibility guards for float64/complex128 support in tf.signal, since the compatibility window has passed.

Also add type checks to tf.signal.rfft/irfft so that inputs with unsupported dtypes raise a clear ValueError.

PiperOrigin-RevId: 277150624
Change-Id: I8bd5a0fbb6d4ce90351ccbf8ea2b094689fe0da5
parent 8c5f1b7943
commit 7e31db966d
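For context, a minimal sketch of the user-visible behavior this change settles (illustrative only, assuming a TensorFlow build that includes this revision; not part of the diff below): float64 inputs now go through tf.signal.rfft/irfft directly, without a forward-compatibility horizon, and unsupported dtypes are rejected up front.

    import numpy as np
    import tensorflow as tf

    # float64 input -> complex128 spectrum; no forward_compatibility_horizon needed.
    signal = np.random.rand(16).astype(np.float64)
    spectrum = tf.signal.rfft(signal)        # dtype: complex128
    recovered = tf.signal.irfft(spectrum)    # dtype: float64

    # Unsupported dtypes now fail fast with a ValueError from the wrapper.
    try:
      tf.signal.rfft(tf.constant([1, 2, 3, 4], dtype=tf.int32))
    except ValueError as err:
      print(err)  # RFFT requires tf.float32 or tf.float64 inputs, got: ...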
@@ -24,7 +24,6 @@ import itertools
 from absl.testing import parameterized
 import numpy as np
 
-from tensorflow.python.compat import compat
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops.signal import dct_ops
 from tensorflow.python.platform import test
@@ -169,14 +168,13 @@ class DCTOpsTest(parameterized.TestCase, test.TestCase):
     # "ortho" normalization is not implemented for type I.
     if dct_type == 1 and norm == "ortho":
       return
-    with compat.forward_compatibility_horizon(2019, 10, 13):
-      with self.session(use_gpu=True):
-        tol = 5e-4 if dtype == np.float32 else 1e-7
-        signals = np.random.rand(*shape).astype(dtype)
-        n = np.random.randint(1, 2 * signals.shape[-1])
-        n = np.random.choice([None, n])
-        self._compare(signals, n, norm=norm, dct_type=dct_type,
-                      rtol=tol, atol=tol)
+    with self.session(use_gpu=True):
+      tol = 5e-4 if dtype == np.float32 else 1e-7
+      signals = np.random.rand(*shape).astype(dtype)
+      n = np.random.randint(1, 2 * signals.shape[-1])
+      n = np.random.choice([None, n])
+      self._compare(signals, n, norm=norm, dct_type=dct_type,
+                    rtol=tol, atol=tol)
 
   def test_error(self):
     signals = np.random.rand(10)
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import contextlib
 import itertools
 
 from absl.testing import parameterized
@@ -26,7 +25,6 @@ import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -41,16 +39,6 @@ from tensorflow.python.platform import test
 VALID_FFT_RANKS = (1, 2, 3)
 
 
-def _forward_compat_context(np_dtype):
-  @contextlib.contextmanager
-  def null_context():
-    yield
-  if np_dtype in (np.float64, np.complex128):
-    return compat.forward_compatibility_horizon(2019, 10, 13)
-  else:
-    return null_context()
-
-
 # TODO(rjryan): Investigate precision issues. We should be able to achieve
 # better tolerances, at least for the complex128 tests.
 class BaseFFTOpsTest(test.TestCase):
@@ -101,9 +89,8 @@ class BaseFFTOpsTest(test.TestCase):
       loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
       return loss
 
-    with _forward_compat_context(x.dtype):
-      ((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
-          gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
+    ((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
+        gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
 
     self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
     self.assertAllClose(y_jacob_t, y_jacob_n, rtol=rtol, atol=atol)
@@ -117,9 +104,8 @@ class BaseFFTOpsTest(test.TestCase):
       loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
       return loss
 
-    with _forward_compat_context(x.dtype):
-      (x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
-          f, [x], delta=1e-2)
+    (x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
+        f, [x], delta=1e-2)
     self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
 
 
@@ -301,14 +287,12 @@ class FFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
 class RFFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
 
   def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
-    with _forward_compat_context(x.dtype), self.cached_session(
-        use_gpu=True) as sess:
+    with self.cached_session(use_gpu=True) as sess:
       return sess.run(
           self._tf_fft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
 
   def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
-    with _forward_compat_context(x.dtype), self.cached_session(
-        use_gpu=True) as sess:
+    with self.cached_session(use_gpu=True) as sess:
       return sess.run(
           self._tf_ifft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
 
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 from absl.testing import parameterized
 
-from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
@@ -46,9 +45,8 @@ class MFCCTest(test.TestCase, parameterized.TestCase):
   @parameterized.parameters(dtypes.float32, dtypes.float64)
   def test_basic(self, dtype):
     """A basic test that the op runs on random input."""
-    with compat.forward_compatibility_horizon(2019, 10, 13):
-      signal = random_ops.random_normal((2, 3, 5), dtype=dtype)
-      self.evaluate(mfcc_ops.mfccs_from_log_mel_spectrograms(signal))
+    signal = random_ops.random_normal((2, 3, 5), dtype=dtype)
+    self.evaluate(mfcc_ops.mfccs_from_log_mel_spectrograms(signal))
 
   def test_unknown_shape(self):
     """A test that the op runs when shape and rank are unknown."""
@@ -21,7 +21,6 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np
 
-from tensorflow.python.compat import compat
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
@@ -165,10 +164,8 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
   def test_stft_and_inverse_stft(self, signal_length, frame_length,
                                  frame_step, fft_length, np_rtype, tol):
     """Test that spectral_ops.stft/inverse_stft match a NumPy implementation."""
-    # Enable float64 support for RFFTs.
-    with compat.forward_compatibility_horizon(2019, 10, 13):
-      signal = np.random.random(signal_length).astype(np_rtype)
-      self._compare(signal, frame_length, frame_step, fft_length, tol)
+    signal = np.random.random(signal_length).astype(np_rtype)
+    self._compare(signal, frame_length, frame_step, fft_length, tol)
 
   @parameterized.parameters(
       # 87.5% overlap.
@@ -189,39 +186,37 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
   def test_stft_round_trip(self, signal_length, frame_length, frame_step,
                            fft_length, np_rtype, threshold,
                            corrected_threshold):
-    # Enable float64 support for RFFTs.
-    with compat.forward_compatibility_horizon(2019, 10, 13):
-      # Generate a random white Gaussian signal.
-      signal = np.random.normal(size=signal_length).astype(np_rtype)
+    # Generate a random white Gaussian signal.
+    signal = np.random.normal(size=signal_length).astype(np_rtype)
 
-      stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length,
-                               pad_end=False)
-      inverse_stft = spectral_ops.inverse_stft(stft, frame_length, frame_step,
-                                               fft_length)
-      inverse_stft_corrected = spectral_ops.inverse_stft(
-          stft, frame_length, frame_step, fft_length,
-          window_fn=spectral_ops.inverse_stft_window_fn(frame_step))
-      inverse_stft, inverse_stft_corrected = self.evaluate(
-          [inverse_stft, inverse_stft_corrected])
+    stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length,
+                             pad_end=False)
+    inverse_stft = spectral_ops.inverse_stft(stft, frame_length, frame_step,
+                                             fft_length)
+    inverse_stft_corrected = spectral_ops.inverse_stft(
+        stft, frame_length, frame_step, fft_length,
+        window_fn=spectral_ops.inverse_stft_window_fn(frame_step))
+    inverse_stft, inverse_stft_corrected = self.evaluate(
+        [inverse_stft, inverse_stft_corrected])
 
-      # Truncate signal to the size of inverse stft.
-      signal = signal[:inverse_stft.shape[0]]
+    # Truncate signal to the size of inverse stft.
+    signal = signal[:inverse_stft.shape[0]]
 
-      # Ignore the frame_length samples at either edge.
-      signal = signal[frame_length:-frame_length]
-      inverse_stft = inverse_stft[frame_length:-frame_length]
-      inverse_stft_corrected = inverse_stft_corrected[
-          frame_length:-frame_length]
+    # Ignore the frame_length samples at either edge.
+    signal = signal[frame_length:-frame_length]
+    inverse_stft = inverse_stft[frame_length:-frame_length]
+    inverse_stft_corrected = inverse_stft_corrected[
+        frame_length:-frame_length]
 
-      # Check that the inverse and original signal are close up to a scale
-      # factor.
-      inverse_stft_scaled = inverse_stft / np.mean(np.abs(inverse_stft))
-      signal_scaled = signal / np.mean(np.abs(signal))
-      self.assertLess(np.std(inverse_stft_scaled - signal_scaled), threshold)
+    # Check that the inverse and original signal are close up to a scale
+    # factor.
+    inverse_stft_scaled = inverse_stft / np.mean(np.abs(inverse_stft))
+    signal_scaled = signal / np.mean(np.abs(signal))
+    self.assertLess(np.std(inverse_stft_scaled - signal_scaled), threshold)
 
-      # Check that the inverse with correction and original signal are close.
-      self.assertLess(np.std(inverse_stft_corrected - signal),
-                      corrected_threshold)
+    # Check that the inverse with correction and original signal are close.
+    self.assertLess(np.std(inverse_stft_corrected - signal),
+                    corrected_threshold)
 
   @parameterized.parameters(
       (256, 32),
@@ -259,12 +254,10 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
   def _compute_stft_gradient(signal, frame_length=32, frame_step=16,
                              fft_length=32):
     """Computes the gradient of the STFT with respect to `signal`."""
-    # Enable float64 support for RFFTs.
-    with compat.forward_compatibility_horizon(2019, 10, 13):
-      stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length)
-      magnitude_stft = math_ops.abs(stft)
-      loss = math_ops.reduce_sum(magnitude_stft)
-      return gradients_impl.gradients([loss], [signal])[0]
+    stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length)
+    magnitude_stft = math_ops.abs(stft)
+    loss = math_ops.reduce_sum(magnitude_stft)
+    return gradients_impl.gradients([loss], [signal])[0]
 
   def test_gradients(self):
     """Test that spectral_ops.stft has a working gradient."""
@@ -301,28 +294,26 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
       (29, 5, 1, 10, np.float64, 1e-8, 1e-8))
   def test_gradients_numerical(self, signal_length, frame_length, frame_step,
                                fft_length, np_rtype, forward_tol, backward_tol):
-    # Enable float64 support for RFFTs.
-    with compat.forward_compatibility_horizon(2019, 10, 13):
-      # TODO(rjryan): Investigate why STFT gradient error is so high.
-      signal = np.random.rand(signal_length).astype(np_rtype) * 2 - 1
+    # TODO(rjryan): Investigate why STFT gradient error is so high.
+    signal = np.random.rand(signal_length).astype(np_rtype) * 2 - 1
 
-      def forward(signal):
-        return spectral_ops.stft(
-            signal, frame_length, frame_step, fft_length, pad_end=False)
-      ((f_jacob_t,), (f_jacob_n,)) = gradient_checker_v2.compute_gradient(
-          forward, [signal])
-      self.assertAllClose(f_jacob_t, f_jacob_n,
-                          rtol=forward_tol, atol=forward_tol)
+    def forward(signal):
+      return spectral_ops.stft(
+          signal, frame_length, frame_step, fft_length, pad_end=False)
+    ((f_jacob_t,), (f_jacob_n,)) = gradient_checker_v2.compute_gradient(
+        forward, [signal])
+    self.assertAllClose(f_jacob_t, f_jacob_n,
+                        rtol=forward_tol, atol=forward_tol)
 
-      def backward(stft):
-        return spectral_ops.inverse_stft(
-            stft, frame_length, frame_step, fft_length)
+    def backward(stft):
+      return spectral_ops.inverse_stft(
+          stft, frame_length, frame_step, fft_length)
 
-      stft = forward(signal)
-      ((b_jacob_t,), (b_jacob_n,)) = gradient_checker_v2.compute_gradient(
-          backward, [stft])
-      self.assertAllClose(b_jacob_t, b_jacob_n,
-                          rtol=backward_tol, atol=backward_tol)
+    stft = forward(signal)
+    ((b_jacob_t,), (b_jacob_n,)) = gradient_checker_v2.compute_gradient(
+        backward, [stft])
+    self.assertAllClose(b_jacob_t, b_jacob_n,
+                        rtol=backward_tol, atol=backward_tol)
 
 
 if __name__ == "__main__":
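The spectral_ops test changes above rely on the same underlying behavior: an STFT over a float64 signal now runs without the compatibility horizon. A small sketch (assuming a build that includes this change; the shapes and frame parameters are illustrative):

    import numpy as np
    import tensorflow as tf

    signal = np.random.normal(size=1000).astype(np.float64)
    stft = tf.signal.stft(signal, frame_length=64, frame_step=32, fft_length=64)
    print(stft.dtype)            # complex128
    reconstructed = tf.signal.inverse_stft(
        stft, frame_length=64, frame_step=32, fft_length=64)
    print(reconstructed.dtype)   # float64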
@@ -19,7 +19,6 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.python.compat import compat
 from tensorflow.python.framework import dtypes as _dtypes
 from tensorflow.python.framework import ops as _ops
 from tensorflow.python.framework import tensor_util as _tensor_util
@@ -118,10 +117,15 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
                          [input_tensor, fft_length]) as name:
       input_tensor = _ops.convert_to_tensor(input_tensor,
                                             preferred_dtype=_dtypes.float32)
+      if input_tensor.dtype not in (_dtypes.float32, _dtypes.float64):
+        raise ValueError(
+            "RFFT requires tf.float32 or tf.float64 inputs, got: %s" %
+            input_tensor)
       real_dtype = input_tensor.dtype
       if real_dtype == _dtypes.float32:
         complex_dtype = _dtypes.complex64
-      elif real_dtype == _dtypes.float64:
+      else:
+        assert real_dtype == _dtypes.float64
         complex_dtype = _dtypes.complex128
       input_tensor.shape.with_rank_at_least(fft_rank)
       if fft_length is None:
@@ -129,9 +133,6 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
       else:
         fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
       input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
-
-      if not compat.forward_compatible(2019, 10, 12):
-        return fft_fn(input_tensor, fft_length, name=name)
       return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype, name=name)
   _rfft.__doc__ = fft_fn.__doc__
   return _rfft
@@ -147,6 +148,10 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
       input_tensor = _ops.convert_to_tensor(input_tensor,
                                             preferred_dtype=_dtypes.complex64)
       input_tensor.shape.with_rank_at_least(fft_rank)
+      if input_tensor.dtype not in (_dtypes.complex64, _dtypes.complex128):
+        raise ValueError(
+            "IRFFT requires tf.complex64 or tf.complex128 inputs, got: %s" %
+            input_tensor)
       complex_dtype = input_tensor.dtype
       real_dtype = complex_dtype.real_dtype
       if fft_length is None:
@@ -155,8 +160,6 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
         fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
       input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
                                          is_reverse=True)
-      if not compat.forward_compatible(2019, 10, 12):
-        return ifft_fn(input_tensor, fft_length, name=name)
       return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name)
   _irfft.__doc__ = ifft_fn.__doc__
   return _irfft
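To summarize the dtype dispatch that the rfft/irfft wrappers now perform, here is a sketch of the mapping only; the helper names are illustrative and are not part of the TensorFlow API:

    import tensorflow as tf

    def rfft_complex_dtype(real_dtype):
      # Mirrors _rfft_wrapper: float32 -> complex64, float64 -> complex128;
      # anything else is rejected.
      if real_dtype not in (tf.float32, tf.float64):
        raise ValueError(
            "RFFT requires tf.float32 or tf.float64 inputs, got: %s" % real_dtype)
      return tf.complex64 if real_dtype == tf.float32 else tf.complex128

    def irfft_real_dtype(complex_dtype):
      # Mirrors _irfft_wrapper: complex64 -> float32, complex128 -> float64,
      # via DType.real_dtype.
      if complex_dtype not in (tf.complex64, tf.complex128):
        raise ValueError(
            "IRFFT requires tf.complex64 or tf.complex128 inputs, got: %s" %
            complex_dtype)
      return complex_dtype.real_dtype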