Add dct type IV to tf.signal.dct.
PiperOrigin-RevId: 286485474 Change-Id: I38e87fcc0fcf8ebc38b0dfdc36971a0820242009
This commit is contained in:
parent
a5d49ba936
commit
4d252ebeed
@ -87,7 +87,7 @@ def _np_dct2(signals, n=None, norm=None):
|
||||
phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
|
||||
dct[..., k] = np.sum(signals_mod * phi, axis=-1)
|
||||
# SciPy's `dct` has a scaling factor of 2.0 which we follow.
|
||||
# https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
|
||||
# https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
|
||||
if norm == "ortho":
|
||||
# The orthonormal scaling includes a factor of 0.5 which we combine with
|
||||
# the overall scaling of 2.0 to cancel.
|
||||
@ -101,7 +101,7 @@ def _np_dct2(signals, n=None, norm=None):
|
||||
def _np_dct3(signals, n=None, norm=None):
|
||||
"""Computes the DCT-III manually with NumPy."""
|
||||
# SciPy's `dct` has a scaling factor of 2.0 which we follow.
|
||||
# https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
|
||||
# https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
|
||||
signals_mod = _modify_input_for_dct(signals, n=n)
|
||||
dct_size = signals_mod.shape[-1]
|
||||
signals_mod = np.array(signals_mod) # make a copy so we can modify
|
||||
@ -120,8 +120,30 @@ def _np_dct3(signals, n=None, norm=None):
|
||||
return dct
|
||||
|
||||
|
||||
NP_DCT = {1: _np_dct1, 2: _np_dct2, 3: _np_dct3}
|
||||
NP_IDCT = {1: _np_dct1, 2: _np_dct3, 3: _np_dct2}
|
||||
def _np_dct4(signals, n=None, norm=None):
|
||||
"""Computes the DCT-IV manually with NumPy."""
|
||||
# SciPy's `dct` has a scaling factor of 2.0 which we follow.
|
||||
# https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
|
||||
signals_mod = _modify_input_for_dct(signals, n=n)
|
||||
dct_size = signals_mod.shape[-1]
|
||||
signals_mod = np.array(signals_mod) # make a copy so we can modify
|
||||
if norm == "ortho":
|
||||
signals_mod *= np.sqrt(2.0 / dct_size)
|
||||
else:
|
||||
signals_mod *= 2.0
|
||||
dct = np.zeros_like(signals_mod)
|
||||
# X_k = sum_{n=0}^{N-1}
|
||||
# x_n * cos(\frac{pi}{4N} * (2n + 1) * (2k + 1)) k=0,...,N-1
|
||||
for k in range(dct_size):
|
||||
phi = np.cos(np.pi *
|
||||
(2 * np.arange(0, dct_size) + 1) * (2 * k + 1) /
|
||||
(4.0 * dct_size))
|
||||
dct[..., k] = np.sum(signals_mod * phi, axis=-1)
|
||||
return dct
|
||||
|
||||
|
||||
NP_DCT = {1: _np_dct1, 2: _np_dct2, 3: _np_dct3, 4: _np_dct4}
|
||||
NP_IDCT = {1: _np_dct1, 2: _np_dct3, 3: _np_dct2, 4: _np_dct4}
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@ -137,7 +159,7 @@ class DCTOpsTest(parameterized.TestCase, test.TestCase):
|
||||
tf_idct = dct_ops.idct(signals, type=dct_type, norm=norm)
|
||||
self.assertEqual(tf_idct.dtype.as_numpy_dtype, signals.dtype)
|
||||
self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol)
|
||||
if fftpack:
|
||||
if fftpack and dct_type != 4:
|
||||
scipy_dct = fftpack.dct(signals, n=n, type=dct_type, norm=norm)
|
||||
self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol)
|
||||
scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm)
|
||||
@ -159,7 +181,7 @@ class DCTOpsTest(parameterized.TestCase, test.TestCase):
|
||||
self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol)
|
||||
|
||||
@parameterized.parameters(itertools.product(
|
||||
[1, 2, 3],
|
||||
[1, 2, 3, 4],
|
||||
[None, "ortho"],
|
||||
[[2], [3], [10], [2, 20], [2, 3, 25]],
|
||||
[np.float32, np.float64]))
|
||||
|
@ -34,8 +34,8 @@ def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
|
||||
raise NotImplementedError("axis must be -1. Got: %s" % axis)
|
||||
if n is not None and n < 1:
|
||||
raise ValueError("n should be a positive integer or None")
|
||||
if dct_type not in (1, 2, 3):
|
||||
raise ValueError("Only Types I, II and III (I)DCT are supported.")
|
||||
if dct_type not in (1, 2, 3, 4):
|
||||
raise ValueError("Types I, II, III and IV (I)DCT are supported.")
|
||||
if dct_type == 1:
|
||||
if norm == "ortho":
|
||||
raise ValueError("Normalization is not supported for the Type-I DCT.")
|
||||
@ -53,22 +53,26 @@ def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
|
||||
def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
|
||||
"""Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.
|
||||
|
||||
Currently only Types I, II and III are supported.
|
||||
Types I, II, III and IV are supported.
|
||||
Type I is implemented using a length `2N` padded `tf.signal.rfft`.
|
||||
Type II is implemented using a length `2N` padded `tf.signal.rfft`, as
|
||||
described here: [Type 2 DCT using 2N FFT padded (Makhoul)](https://dsp.stackexchange.com/a/10606).
|
||||
described here: [Type 2 DCT using 2N FFT padded (Makhoul)]
|
||||
(https://dsp.stackexchange.com/a/10606).
|
||||
Type III is a fairly straightforward inverse of Type II
|
||||
(i.e. using a length `2N` padded `tf.signal.irfft`).
|
||||
(i.e. using a length `2N` padded `tf.signal.irfft`).
|
||||
Type IV is calculated through 2N length DCT2 of padded signal and
|
||||
picking the odd indices.
|
||||
|
||||
@compatibility(scipy)
|
||||
Equivalent to [scipy.fftpack.dct](https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html)
|
||||
for Type-I, Type-II and Type-III DCT.
|
||||
Equivalent to [scipy.fftpack.dct]
|
||||
(https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.dct.html)
|
||||
for Type-I, Type-II, Type-III and Type-IV DCT.
|
||||
@end_compatibility
|
||||
|
||||
Args:
|
||||
input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
|
||||
signals to take the DCT of.
|
||||
type: The DCT type to perform. Must be 1, 2 or 3.
|
||||
type: The DCT type to perform. Must be 1, 2, 3 or 4.
|
||||
n: The length of the transform. If length is less than sequence length,
|
||||
only the first n elements of the sequence are considered for the DCT.
|
||||
If n is greater than the sequence length, zeros are padded and then
|
||||
@ -83,7 +87,7 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
|
||||
`input`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `type` is not `1`, `2` or `3`, `axis` is
|
||||
ValueError: If `type` is not `1`, `2`, `3` or `4`, `axis` is
|
||||
not `-1`, `n` is not `None` or greater than 0,
|
||||
or `norm` is not `None` or `'ortho'`.
|
||||
ValueError: If `type` is `1` and `norm` is `ortho`.
|
||||
@ -163,13 +167,24 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
|
||||
|
||||
return dct3
|
||||
|
||||
elif type == 4:
|
||||
# DCT-2 of 2N length zero-padded signal, unnormalized.
|
||||
dct2 = dct(input, type=2, n=2*axis_dim, axis=axis, norm=None)
|
||||
# Get odd indices of DCT-2 of zero padded 2N signal to obtain
|
||||
# DCT-4 of the original N length signal.
|
||||
dct4 = dct2[..., 1::2]
|
||||
if norm == "ortho":
|
||||
dct4 *= _math.sqrt(0.5) * _math_ops.rsqrt(axis_dim_float)
|
||||
|
||||
return dct4
|
||||
|
||||
|
||||
# TODO(rjryan): Implement `n` and `axis` parameters.
|
||||
@tf_export("signal.idct", v1=["signal.idct", "spectral.idct"])
|
||||
def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
|
||||
"""Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.
|
||||
|
||||
Currently only Types I, II and III are supported. Type III is the inverse of
|
||||
Currently Types I, II, III, IV are supported. Type III is the inverse of
|
||||
Type II, and vice versa.
|
||||
|
||||
Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is
|
||||
@ -179,14 +194,15 @@ def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disab
|
||||
`signal == idct(dct(signal, norm='ortho'), norm='ortho')`.
|
||||
|
||||
@compatibility(scipy)
|
||||
Equivalent to [scipy.fftpack.idct](https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html)
|
||||
for Type-I, Type-II and Type-III DCT.
|
||||
Equivalent to [scipy.fftpack.idct]
|
||||
(https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.idct.html)
|
||||
for Type-I, Type-II, Type-III and Type-IV DCT.
|
||||
@end_compatibility
|
||||
|
||||
Args:
|
||||
input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
|
||||
signals to take the DCT of.
|
||||
type: The IDCT type to perform. Must be 1, 2 or 3.
|
||||
type: The IDCT type to perform. Must be 1, 2, 3 or 4.
|
||||
n: For future expansion. The length of the transform. Must be `None`.
|
||||
axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
|
||||
norm: The normalization to apply. `None` for no normalization or `'ortho'`
|
||||
@ -205,5 +221,5 @@ def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disab
|
||||
https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
|
||||
"""
|
||||
_validate_dct_arguments(input, type, n, axis, norm)
|
||||
inverse_type = {1: 1, 2: 3, 3: 2}[type]
|
||||
inverse_type = {1: 1, 2: 3, 3: 2, 4: 4}[type]
|
||||
return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)
|
||||
|
Loading…
Reference in New Issue
Block a user