Add dct type IV to tf.signal.dct.

PiperOrigin-RevId: 286485474
Change-Id: I38e87fcc0fcf8ebc38b0dfdc36971a0820242009
This commit is contained in:
A. Unique TensorFlower 2019-12-19 16:25:28 -08:00 committed by TensorFlower Gardener
parent a5d49ba936
commit 4d252ebeed
2 changed files with 58 additions and 20 deletions

View File

@ -87,7 +87,7 @@ def _np_dct2(signals, n=None, norm=None):
phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
dct[..., k] = np.sum(signals_mod * phi, axis=-1)
# SciPy's `dct` has a scaling factor of 2.0 which we follow.
# https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
# https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
if norm == "ortho":
# The orthonormal scaling includes a factor of 0.5 which we combine with
# the overall scaling of 2.0 to cancel.
@ -101,7 +101,7 @@ def _np_dct2(signals, n=None, norm=None):
def _np_dct3(signals, n=None, norm=None):
"""Computes the DCT-III manually with NumPy."""
# SciPy's `dct` has a scaling factor of 2.0 which we follow.
# https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
# https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
signals_mod = _modify_input_for_dct(signals, n=n)
dct_size = signals_mod.shape[-1]
signals_mod = np.array(signals_mod) # make a copy so we can modify
@ -120,8 +120,30 @@ def _np_dct3(signals, n=None, norm=None):
return dct
NP_DCT = {1: _np_dct1, 2: _np_dct2, 3: _np_dct3}
NP_IDCT = {1: _np_dct1, 2: _np_dct3, 3: _np_dct2}
def _np_dct4(signals, n=None, norm=None):
"""Computes the DCT-IV manually with NumPy."""
# SciPy's `dct` has a scaling factor of 2.0 which we follow.
# https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
signals_mod = _modify_input_for_dct(signals, n=n)
dct_size = signals_mod.shape[-1]
signals_mod = np.array(signals_mod) # make a copy so we can modify
if norm == "ortho":
signals_mod *= np.sqrt(2.0 / dct_size)
else:
signals_mod *= 2.0
dct = np.zeros_like(signals_mod)
# X_k = sum_{n=0}^{N-1}
# x_n * cos(\frac{pi}{4N} * (2n + 1) * (2k + 1)) k=0,...,N-1
for k in range(dct_size):
phi = np.cos(np.pi *
(2 * np.arange(0, dct_size) + 1) * (2 * k + 1) /
(4.0 * dct_size))
dct[..., k] = np.sum(signals_mod * phi, axis=-1)
return dct
NP_DCT = {1: _np_dct1, 2: _np_dct2, 3: _np_dct3, 4: _np_dct4}
NP_IDCT = {1: _np_dct1, 2: _np_dct3, 3: _np_dct2, 4: _np_dct4}
@test_util.run_all_in_graph_and_eager_modes
@ -137,7 +159,7 @@ class DCTOpsTest(parameterized.TestCase, test.TestCase):
tf_idct = dct_ops.idct(signals, type=dct_type, norm=norm)
self.assertEqual(tf_idct.dtype.as_numpy_dtype, signals.dtype)
self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol)
if fftpack:
if fftpack and dct_type != 4:
scipy_dct = fftpack.dct(signals, n=n, type=dct_type, norm=norm)
self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol)
scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm)
@ -159,7 +181,7 @@ class DCTOpsTest(parameterized.TestCase, test.TestCase):
self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol)
@parameterized.parameters(itertools.product(
[1, 2, 3],
[1, 2, 3, 4],
[None, "ortho"],
[[2], [3], [10], [2, 20], [2, 3, 25]],
[np.float32, np.float64]))

View File

@ -34,8 +34,8 @@ def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
raise NotImplementedError("axis must be -1. Got: %s" % axis)
if n is not None and n < 1:
raise ValueError("n should be a positive integer or None")
if dct_type not in (1, 2, 3):
raise ValueError("Only Types I, II and III (I)DCT are supported.")
if dct_type not in (1, 2, 3, 4):
raise ValueError("Types I, II, III and IV (I)DCT are supported.")
if dct_type == 1:
if norm == "ortho":
raise ValueError("Normalization is not supported for the Type-I DCT.")
@ -53,22 +53,26 @@ def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.
Currently only Types I, II and III are supported.
Types I, II, III and IV are supported.
Type I is implemented using a length `2N` padded `tf.signal.rfft`.
Type II is implemented using a length `2N` padded `tf.signal.rfft`, as
described here: [Type 2 DCT using 2N FFT padded (Makhoul)](https://dsp.stackexchange.com/a/10606).
described here: [Type 2 DCT using 2N FFT padded (Makhoul)]
(https://dsp.stackexchange.com/a/10606).
Type III is a fairly straightforward inverse of Type II
(i.e. using a length `2N` padded `tf.signal.irfft`).
(i.e. using a length `2N` padded `tf.signal.irfft`).
Type IV is calculated through 2N length DCT2 of padded signal and
picking the odd indices.
@compatibility(scipy)
Equivalent to [scipy.fftpack.dct](https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html)
for Type-I, Type-II and Type-III DCT.
Equivalent to [scipy.fftpack.dct]
(https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.dct.html)
for Type-I, Type-II, Type-III and Type-IV DCT.
@end_compatibility
Args:
input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
signals to take the DCT of.
type: The DCT type to perform. Must be 1, 2 or 3.
type: The DCT type to perform. Must be 1, 2, 3 or 4.
n: The length of the transform. If length is less than sequence length,
only the first n elements of the sequence are considered for the DCT.
If n is greater than the sequence length, zeros are padded and then
@ -83,7 +87,7 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
`input`.
Raises:
ValueError: If `type` is not `1`, `2` or `3`, `axis` is
ValueError: If `type` is not `1`, `2`, `3` or `4`, `axis` is
not `-1`, `n` is not `None` or greater than 0,
or `norm` is not `None` or `'ortho'`.
ValueError: If `type` is `1` and `norm` is `ortho`.
@ -163,13 +167,24 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
return dct3
elif type == 4:
# DCT-2 of 2N length zero-padded signal, unnormalized.
dct2 = dct(input, type=2, n=2*axis_dim, axis=axis, norm=None)
# Get odd indices of DCT-2 of zero padded 2N signal to obtain
# DCT-4 of the original N length signal.
dct4 = dct2[..., 1::2]
if norm == "ortho":
dct4 *= _math.sqrt(0.5) * _math_ops.rsqrt(axis_dim_float)
return dct4
# TODO(rjryan): Implement `n` and `axis` parameters.
@tf_export("signal.idct", v1=["signal.idct", "spectral.idct"])
def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.
Currently only Types I, II and III are supported. Type III is the inverse of
Currently Types I, II, III, IV are supported. Type III is the inverse of
Type II, and vice versa.
Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is
@ -179,14 +194,15 @@ def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disab
`signal == idct(dct(signal, norm='ortho'), norm='ortho')`.
@compatibility(scipy)
Equivalent to [scipy.fftpack.idct](https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html)
for Type-I, Type-II and Type-III DCT.
Equivalent to [scipy.fftpack.idct]
(https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.idct.html)
for Type-I, Type-II, Type-III and Type-IV DCT.
@end_compatibility
Args:
input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
signals to take the DCT of.
type: The IDCT type to perform. Must be 1, 2 or 3.
type: The IDCT type to perform. Must be 1, 2, 3 or 4.
n: For future expansion. The length of the transform. Must be `None`.
axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
norm: The normalization to apply. `None` for no normalization or `'ortho'`
@ -205,5 +221,5 @@ def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disab
https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
"""
_validate_dct_arguments(input, type, n, axis, norm)
inverse_type = {1: 1, 2: 3, 3: 2}[type]
inverse_type = {1: 1, 2: 3, 3: 2, 4: 4}[type]
return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)