Remove forward compatibility guards for float64/complex128 support since the time window has passed.

Also, add type checks for tf.signal.rfft/irfft to ensure correct input types are provided.

PiperOrigin-RevId: 277150624
Change-Id: I8bd5a0fbb6d4ce90351ccbf8ea2b094689fe0da5
This commit is contained in:
RJ Skerry-Ryan 2019-10-28 15:02:26 -07:00 committed by TensorFlower Gardener
parent 8c5f1b7943
commit 7e31db966d
5 changed files with 74 additions and 100 deletions

View File

@ -24,7 +24,6 @@ import itertools
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops.signal import dct_ops from tensorflow.python.ops.signal import dct_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -169,7 +168,6 @@ class DCTOpsTest(parameterized.TestCase, test.TestCase):
# "ortho" normalization is not implemented for type I. # "ortho" normalization is not implemented for type I.
if dct_type == 1 and norm == "ortho": if dct_type == 1 and norm == "ortho":
return return
with compat.forward_compatibility_horizon(2019, 10, 13):
with self.session(use_gpu=True): with self.session(use_gpu=True):
tol = 5e-4 if dtype == np.float32 else 1e-7 tol = 5e-4 if dtype == np.float32 else 1e-7
signals = np.random.rand(*shape).astype(dtype) signals = np.random.rand(*shape).astype(dtype)

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import contextlib
import itertools import itertools
from absl.testing import parameterized from absl.testing import parameterized
@ -26,7 +25,6 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
@ -41,16 +39,6 @@ from tensorflow.python.platform import test
VALID_FFT_RANKS = (1, 2, 3) VALID_FFT_RANKS = (1, 2, 3)
def _forward_compat_context(np_dtype):
@contextlib.contextmanager
def null_context():
yield
if np_dtype in (np.float64, np.complex128):
return compat.forward_compatibility_horizon(2019, 10, 13)
else:
return null_context()
# TODO(rjryan): Investigate precision issues. We should be able to achieve # TODO(rjryan): Investigate precision issues. We should be able to achieve
# better tolerances, at least for the complex128 tests. # better tolerances, at least for the complex128 tests.
class BaseFFTOpsTest(test.TestCase): class BaseFFTOpsTest(test.TestCase):
@ -101,7 +89,6 @@ class BaseFFTOpsTest(test.TestCase):
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z))) loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
return loss return loss
with _forward_compat_context(x.dtype):
((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = ( ((x_jacob_t, y_jacob_t), (x_jacob_n, y_jacob_n)) = (
gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2)) gradient_checker_v2.compute_gradient(f, [x, y], delta=1e-2))
@ -117,7 +104,6 @@ class BaseFFTOpsTest(test.TestCase):
loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z))) loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
return loss return loss
with _forward_compat_context(x.dtype):
(x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient( (x_jacob_t,), (x_jacob_n,) = gradient_checker_v2.compute_gradient(
f, [x], delta=1e-2) f, [x], delta=1e-2)
self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol) self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol)
@ -301,14 +287,12 @@ class FFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
class RFFTOpsTest(BaseFFTOpsTest, parameterized.TestCase): class RFFTOpsTest(BaseFFTOpsTest, parameterized.TestCase):
def _tf_fft(self, x, rank, fft_length=None, feed_dict=None): def _tf_fft(self, x, rank, fft_length=None, feed_dict=None):
with _forward_compat_context(x.dtype), self.cached_session( with self.cached_session(use_gpu=True) as sess:
use_gpu=True) as sess:
return sess.run( return sess.run(
self._tf_fft_for_rank(rank)(x, fft_length), feed_dict=feed_dict) self._tf_fft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)
def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None): def _tf_ifft(self, x, rank, fft_length=None, feed_dict=None):
with _forward_compat_context(x.dtype), self.cached_session( with self.cached_session(use_gpu=True) as sess:
use_gpu=True) as sess:
return sess.run( return sess.run(
self._tf_ifft_for_rank(rank)(x, fft_length), feed_dict=feed_dict) self._tf_ifft_for_rank(rank)(x, fft_length), feed_dict=feed_dict)

View File

@ -20,7 +20,6 @@ from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
from tensorflow.python.compat import compat
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
@ -46,7 +45,6 @@ class MFCCTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(dtypes.float32, dtypes.float64) @parameterized.parameters(dtypes.float32, dtypes.float64)
def test_basic(self, dtype): def test_basic(self, dtype):
"""A basic test that the op runs on random input.""" """A basic test that the op runs on random input."""
with compat.forward_compatibility_horizon(2019, 10, 13):
signal = random_ops.random_normal((2, 3, 5), dtype=dtype) signal = random_ops.random_normal((2, 3, 5), dtype=dtype)
self.evaluate(mfcc_ops.mfccs_from_log_mel_spectrograms(signal)) self.evaluate(mfcc_ops.mfccs_from_log_mel_spectrograms(signal))

View File

@ -21,7 +21,6 @@ from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
@ -165,8 +164,6 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
def test_stft_and_inverse_stft(self, signal_length, frame_length, def test_stft_and_inverse_stft(self, signal_length, frame_length,
frame_step, fft_length, np_rtype, tol): frame_step, fft_length, np_rtype, tol):
"""Test that spectral_ops.stft/inverse_stft match a NumPy implementation.""" """Test that spectral_ops.stft/inverse_stft match a NumPy implementation."""
# Enable float64 support for RFFTs.
with compat.forward_compatibility_horizon(2019, 10, 13):
signal = np.random.random(signal_length).astype(np_rtype) signal = np.random.random(signal_length).astype(np_rtype)
self._compare(signal, frame_length, frame_step, fft_length, tol) self._compare(signal, frame_length, frame_step, fft_length, tol)
@ -189,8 +186,6 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
def test_stft_round_trip(self, signal_length, frame_length, frame_step, def test_stft_round_trip(self, signal_length, frame_length, frame_step,
fft_length, np_rtype, threshold, fft_length, np_rtype, threshold,
corrected_threshold): corrected_threshold):
# Enable float64 support for RFFTs.
with compat.forward_compatibility_horizon(2019, 10, 13):
# Generate a random white Gaussian signal. # Generate a random white Gaussian signal.
signal = np.random.normal(size=signal_length).astype(np_rtype) signal = np.random.normal(size=signal_length).astype(np_rtype)
@ -259,8 +254,6 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
def _compute_stft_gradient(signal, frame_length=32, frame_step=16, def _compute_stft_gradient(signal, frame_length=32, frame_step=16,
fft_length=32): fft_length=32):
"""Computes the gradient of the STFT with respect to `signal`.""" """Computes the gradient of the STFT with respect to `signal`."""
# Enable float64 support for RFFTs.
with compat.forward_compatibility_horizon(2019, 10, 13):
stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length) stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length)
magnitude_stft = math_ops.abs(stft) magnitude_stft = math_ops.abs(stft)
loss = math_ops.reduce_sum(magnitude_stft) loss = math_ops.reduce_sum(magnitude_stft)
@ -301,8 +294,6 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
(29, 5, 1, 10, np.float64, 1e-8, 1e-8)) (29, 5, 1, 10, np.float64, 1e-8, 1e-8))
def test_gradients_numerical(self, signal_length, frame_length, frame_step, def test_gradients_numerical(self, signal_length, frame_length, frame_step,
fft_length, np_rtype, forward_tol, backward_tol): fft_length, np_rtype, forward_tol, backward_tol):
# Enable float64 support for RFFTs.
with compat.forward_compatibility_horizon(2019, 10, 13):
# TODO(rjryan): Investigate why STFT gradient error is so high. # TODO(rjryan): Investigate why STFT gradient error is so high.
signal = np.random.rand(signal_length).astype(np_rtype) * 2 - 1 signal = np.random.rand(signal_length).astype(np_rtype) * 2 - 1

View File

@ -19,7 +19,6 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import tensor_util as _tensor_util from tensorflow.python.framework import tensor_util as _tensor_util
@ -118,10 +117,15 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
[input_tensor, fft_length]) as name: [input_tensor, fft_length]) as name:
input_tensor = _ops.convert_to_tensor(input_tensor, input_tensor = _ops.convert_to_tensor(input_tensor,
preferred_dtype=_dtypes.float32) preferred_dtype=_dtypes.float32)
if input_tensor.dtype not in (_dtypes.float32, _dtypes.float64):
raise ValueError(
"RFFT requires tf.float32 or tf.float64 inputs, got: %s" %
input_tensor)
real_dtype = input_tensor.dtype real_dtype = input_tensor.dtype
if real_dtype == _dtypes.float32: if real_dtype == _dtypes.float32:
complex_dtype = _dtypes.complex64 complex_dtype = _dtypes.complex64
elif real_dtype == _dtypes.float64: else:
assert real_dtype == _dtypes.float64
complex_dtype = _dtypes.complex128 complex_dtype = _dtypes.complex128
input_tensor.shape.with_rank_at_least(fft_rank) input_tensor.shape.with_rank_at_least(fft_rank)
if fft_length is None: if fft_length is None:
@ -129,9 +133,6 @@ def _rfft_wrapper(fft_fn, fft_rank, default_name):
else: else:
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32) fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length) input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)
if not compat.forward_compatible(2019, 10, 12):
return fft_fn(input_tensor, fft_length, name=name)
return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype, name=name) return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype, name=name)
_rfft.__doc__ = fft_fn.__doc__ _rfft.__doc__ = fft_fn.__doc__
return _rfft return _rfft
@ -147,6 +148,10 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
input_tensor = _ops.convert_to_tensor(input_tensor, input_tensor = _ops.convert_to_tensor(input_tensor,
preferred_dtype=_dtypes.complex64) preferred_dtype=_dtypes.complex64)
input_tensor.shape.with_rank_at_least(fft_rank) input_tensor.shape.with_rank_at_least(fft_rank)
if input_tensor.dtype not in (_dtypes.complex64, _dtypes.complex128):
raise ValueError(
"IRFFT requires tf.complex64 or tf.complex128 inputs, got: %s" %
input_tensor)
complex_dtype = input_tensor.dtype complex_dtype = input_tensor.dtype
real_dtype = complex_dtype.real_dtype real_dtype = complex_dtype.real_dtype
if fft_length is None: if fft_length is None:
@ -155,8 +160,6 @@ def _irfft_wrapper(ifft_fn, fft_rank, default_name):
fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32) fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length, input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
is_reverse=True) is_reverse=True)
if not compat.forward_compatible(2019, 10, 12):
return ifft_fn(input_tensor, fft_length, name=name)
return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name) return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name)
_irfft.__doc__ = ifft_fn.__doc__ _irfft.__doc__ = ifft_fn.__doc__
return _irfft return _irfft