Remove forward compatibility guards for float64/complex128 support since the time window has passed.
Also, add type checks for tf.signal.rfft/irfft to ensure correct input types are provided. PiperOrigin-RevId: 277150624 Change-Id: I8bd5a0fbb6d4ce90351ccbf8ea2b094689fe0da5
This commit is contained in:
parent
8c5f1b7943
commit
7e31db966d
@ -24,7 +24,6 @@ import itertools
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops.signal import dct_ops
|
from tensorflow.python.ops.signal import dct_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -169,7 +168,6 @@ class DCTOpsTest(parameterized.TestCase, test.TestCase):
|
|||||||
# "ortho" normalization is not implemented for type I.
|
# "ortho" normalization is not implemented for type I.
|
||||||
if dct_type == 1 and norm == "ortho":
|
if dct_type == 1 and norm == "ortho":
|
||||||
return
|
return
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 13):
|
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
tol = 5e-4 if dtype == np.float32 else 1e-7
|
tol = 5e-4 if dtype == np.float32 else 1e-7
|
||||||
signals = np.random.rand(*shape).astype(dtype)
|
signals = np.random.rand(*shape).astype(dtype)
|
||||||
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
@ -26,7 +25,6 @@ import numpy as np
|
|||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
@ -41,16 +39,6 @@ from tensorflow.python.platform import test
|
|||||||
VALID_FFT_RANKS = (1, 2, 3)
|
VALID_FFT_RANKS = (1, 2, 3)
|
||||||
|
|
||||||
|
|
||||||
def _forward_compat_context(np_dtype):
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def null_context():
|
|
||||||
yield
|
|
||||||
if np_dtype in (np.float64, np.complex128):
|
|
||||||
return compat.forward_compatibility_horizon(2019, 10, 13)
|
|
||||||
else:
|
|
||||||
return null_context()
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(rjryan): Investigate precision issues. We should be able to achieve
|
# TODO(rjryan): Investigate precision issues. We should be able to achieve
|
||||||
# better tolerances, at least for the complex128 tests.
|
# better tolerances, at least for the complex128 tests.
|
||||||
class BaseFFTOpsTest(test.TestCase):
|
class BaseFFTOpsTest(test.TestCase):
|
||||||
@ -101,7 +89,6 @@ class BaseFFTOpsTest(test.TestCase):
|
|||||||
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
|
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
with _forward_compat_context(x.dtype):
|
|
||||||
((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
|
((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
|
||||||
gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
|
gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
|
||||||
|
|
||||||
@ -117,7 +104,6 @@ class BaseFFTOpsTest(test.TestCase):
|
|||||||
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
|
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
with _forward_compat_context(x.dtype):
|
|
||||||
(x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
|
(x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
|
||||||
f, [x], delta=1e-2)
|
f, [x], delta=1e-2)
|
||||||
self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
|
self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
|
||||||
@ -301,14 +287,12 @@ class FFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
|
|||||||
class RFFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
|
class RFFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
|
||||||
|
|
||||||
def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
|
def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
|
||||||
with _forward_compat_context(x.dtype), self.cached_session(
|
with self.cached_session(use_gpu=True) as sess:
|
||||||
use_gpu=True) as sess:
|
|
||||||
return sess.run(
|
return sess.run(
|
||||||
self._tf_fft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
|
self._tf_fft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
|
||||||
|
|
||||||
def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
|
def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
|
||||||
with _forward_compat_context(x.dtype), self.cached_session(
|
with self.cached_session(use_gpu=True) as sess:
|
||||||
use_gpu=True) as sess:
|
|
||||||
return sess.run(
|
return sess.run(
|
||||||
self._tf_ifft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
|
self._tf_ifft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
@ -46,7 +45,6 @@ class MFCCTest(test.TestCase, parameterized.TestCase):
|
|||||||
@parameterized.parameters(dtypes.float32, dtypes.float64)
|
@parameterized.parameters(dtypes.float32, dtypes.float64)
|
||||||
def test_basic(self, dtype):
|
def test_basic(self, dtype):
|
||||||
"""A basic test that the op runs on random input."""
|
"""A basic test that the op runs on random input."""
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 13):
|
|
||||||
signal = random_ops.random_normal((2, 3, 5), dtype=dtype)
|
signal = random_ops.random_normal((2, 3, 5), dtype=dtype)
|
||||||
self.evaluate(mfcc_ops.mfccs_from_log_mel_spectrograms(signal))
|
self.evaluate(mfcc_ops.mfccs_from_log_mel_spectrograms(signal))
|
||||||
|
|
||||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
@ -165,8 +164,6 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
|
|||||||
def test_stft_and_inverse_stft(self, signal_length, frame_length,
|
def test_stft_and_inverse_stft(self, signal_length, frame_length,
|
||||||
frame_step, fft_length, np_rtype, tol):
|
frame_step, fft_length, np_rtype, tol):
|
||||||
"""Test that spectral_ops.stft/inverse_stft match a NumPy implementation."""
|
"""Test that spectral_ops.stft/inverse_stft match a NumPy implementation."""
|
||||||
# Enable float64 support for RFFTs.
|
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 13):
|
|
||||||
signal = np.random.random(signal_length).astype(np_rtype)
|
signal = np.random.random(signal_length).astype(np_rtype)
|
||||||
self._compare(signal, frame_length, frame_step, fft_length, tol)
|
self._compare(signal, frame_length, frame_step, fft_length, tol)
|
||||||
|
|
||||||
@ -189,8 +186,6 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
|
|||||||
def test_stft_round_trip(self, signal_length, frame_length, frame_step,
|
def test_stft_round_trip(self, signal_length, frame_length, frame_step,
|
||||||
fft_length, np_rtype, threshold,
|
fft_length, np_rtype, threshold,
|
||||||
corrected_threshold):
|
corrected_threshold):
|
||||||
# Enable float64 support for RFFTs.
|
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 13):
|
|
||||||
# Generate a random white Gaussian signal.
|
# Generate a random white Gaussian signal.
|
||||||
signal = np.random.normal(size=signal_length).astype(np_rtype)
|
signal = np.random.normal(size=signal_length).astype(np_rtype)
|
||||||
|
|
||||||
@ -259,8 +254,6 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
|
|||||||
def _compute_stft_gradient(signal, frame_length=32, frame_step=16,
|
def _compute_stft_gradient(signal, frame_length=32, frame_step=16,
|
||||||
fft_length=32):
|
fft_length=32):
|
||||||
"""Computes the gradient of the STFT with respect to `signal`."""
|
"""Computes the gradient of the STFT with respect to `signal`."""
|
||||||
# Enable float64 support for RFFTs.
|
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 13):
|
|
||||||
stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length)
|
stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length)
|
||||||
magnitude_stft = math_ops.abs(stft)
|
magnitude_stft = math_ops.abs(stft)
|
||||||
loss = math_ops.reduce_sum(magnitude_stft)
|
loss = math_ops.reduce_sum(magnitude_stft)
|
||||||
@ -301,8 +294,6 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
|
|||||||
(29, 5, 1, 10, np.float64, 1e-8, 1e-8))
|
(29, 5, 1, 10, np.float64, 1e-8, 1e-8))
|
||||||
def test_gradients_numerical(self, signal_length, frame_length, frame_step,
|
def test_gradients_numerical(self, signal_length, frame_length, frame_step,
|
||||||
fft_length, np_rtype, forward_tol, backward_tol):
|
fft_length, np_rtype, forward_tol, backward_tol):
|
||||||
# Enable float64 support for RFFTs.
|
|
||||||
with compat.forward_compatibility_horizon(2019, 10, 13):
|
|
||||||
# TODO(rjryan): Investigate why STFT gradient error is so high.
|
# TODO(rjryan): Investigate why STFT gradient error is so high.
|
||||||
signal = np.random.rand(signal_length).astype(np_rtype) * 2 - 1
|
signal = np.random.rand(signal_length).astype(np_rtype) * 2 - 1
|
||||||
|
|
||||||
|
@ -19,7 +19,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.compat import compat
|
|
||||||
from tensorflow.python.framework import dtypes as _dtypes
|
from tensorflow.python.framework import dtypes as _dtypes
|
||||||
from tensorflow.python.framework import ops as _ops
|
from tensorflow.python.framework import ops as _ops
|
||||||
from tensorflow.python.framework import tensor_util as _tensor_util
|
from tensorflow.python.framework import tensor_util as _tensor_util
|
||||||
@ -118,10 +117,15 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
|
|||||||
[input_tensor, fft_length]) as name:
|
[input_tensor, fft_length]) as name:
|
||||||
input_tensor = _ops.convert_to_tensor(input_tensor,
|
input_tensor = _ops.convert_to_tensor(input_tensor,
|
||||||
preferred_dtype=_dtypes.float32)
|
preferred_dtype=_dtypes.float32)
|
||||||
|
if input_tensor.dtype not in (_dtypes.float32, _dtypes.float64):
|
||||||
|
raise ValueError(
|
||||||
|
"RFFT requires tf.float32 or tf.float64 inputs, got: %s" %
|
||||||
|
input_tensor)
|
||||||
real_dtype = input_tensor.dtype
|
real_dtype = input_tensor.dtype
|
||||||
if real_dtype == _dtypes.float32:
|
if real_dtype == _dtypes.float32:
|
||||||
complex_dtype = _dtypes.complex64
|
complex_dtype = _dtypes.complex64
|
||||||
elif real_dtype == _dtypes.float64:
|
else:
|
||||||
|
assert real_dtype == _dtypes.float64
|
||||||
complex_dtype = _dtypes.complex128
|
complex_dtype = _dtypes.complex128
|
||||||
input_tensor.shape.with_rank_at_least(fft_rank)
|
input_tensor.shape.with_rank_at_least(fft_rank)
|
||||||
if fft_length is None:
|
if fft_length is None:
|
||||||
@ -129,9 +133,6 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
|
|||||||
else:
|
else:
|
||||||
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
|
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
|
||||||
input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
|
input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
|
||||||
|
|
||||||
if not compat.forward_compatible(2019, 10, 12):
|
|
||||||
return fft_fn(input_tensor, fft_length, name=name)
|
|
||||||
return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype, name=name)
|
return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype, name=name)
|
||||||
_rfft.__doc__ = fft_fn.__doc__
|
_rfft.__doc__ = fft_fn.__doc__
|
||||||
return _rfft
|
return _rfft
|
||||||
@ -147,6 +148,10 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
|
|||||||
input_tensor = _ops.convert_to_tensor(input_tensor,
|
input_tensor = _ops.convert_to_tensor(input_tensor,
|
||||||
preferred_dtype=_dtypes.complex64)
|
preferred_dtype=_dtypes.complex64)
|
||||||
input_tensor.shape.with_rank_at_least(fft_rank)
|
input_tensor.shape.with_rank_at_least(fft_rank)
|
||||||
|
if input_tensor.dtype not in (_dtypes.complex64, _dtypes.complex128):
|
||||||
|
raise ValueError(
|
||||||
|
"IRFFT requires tf.complex64 or tf.complex128 inputs, got: %s" %
|
||||||
|
input_tensor)
|
||||||
complex_dtype = input_tensor.dtype
|
complex_dtype = input_tensor.dtype
|
||||||
real_dtype = complex_dtype.real_dtype
|
real_dtype = complex_dtype.real_dtype
|
||||||
if fft_length is None:
|
if fft_length is None:
|
||||||
@ -155,8 +160,6 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
|
|||||||
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
|
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
|
||||||
input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
|
input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
|
||||||
is_reverse=True)
|
is_reverse=True)
|
||||||
if not compat.forward_compatible(2019, 10, 12):
|
|
||||||
return ifft_fn(input_tensor, fft_length, name=name)
|
|
||||||
return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name)
|
return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name)
|
||||||
_irfft.__doc__ = ifft_fn.__doc__
|
_irfft.__doc__ = ifft_fn.__doc__
|
||||||
return _irfft
|
return _irfft
|
||||||
|
Loading…
Reference in New Issue
Block a user