Add dct type IV to tf.signal.dct.

PiperOrigin-RevId: 286485474
Change-Id: I38e87fcc0fcf8ebc38b0dfdc36971a0820242009
This commit is contained in:
A. Unique TensorFlower 2019-12-19 16:25:28 -08:00 committed by TensorFlower Gardener
parent a5d49ba936
commit 4d252ebeed
2 changed files with 58 additions and 20 deletions

View File

@ -87,7 +87,7 @@ def _np_dct2(signals, n=None, norm=None):
phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size) phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
dct[..., k] = np.sum(signals_mod * phi, axis=-1) dct[..., k] = np.sum(signals_mod * phi, axis=-1)
# SciPy's `dct` has a scaling factor of 2.0 which we follow. # SciPy's `dct` has a scaling factor of 2.0 which we follow.
# https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
if norm == "ortho": if norm == "ortho":
# The orthonormal scaling includes a factor of 0.5 which we combine with # The orthonormal scaling includes a factor of 0.5 which we combine with
# the overall scaling of 2.0 to cancel. # the overall scaling of 2.0 to cancel.
@ -101,7 +101,7 @@ def _np_dct2(signals, n=None, norm=None):
def _np_dct3(signals, n=None, norm=None): def _np_dct3(signals, n=None, norm=None):
"""Computes the DCT-III manually with NumPy.""" """Computes the DCT-III manually with NumPy."""
# SciPy's `dct` has a scaling factor of 2.0 which we follow. # SciPy's `dct` has a scaling factor of 2.0 which we follow.
# https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
signals_mod = _modify_input_for_dct(signals, n=n) signals_mod = _modify_input_for_dct(signals, n=n)
dct_size = signals_mod.shape[-1] dct_size = signals_mod.shape[-1]
signals_mod = np.array(signals_mod) # make a copy so we can modify signals_mod = np.array(signals_mod) # make a copy so we can modify
@ -120,8 +120,30 @@ def _np_dct3(signals, n=None, norm=None):
return dct return dct
NP_DCT = {1: _np_dct1, 2: _np_dct2, 3: _np_dct3} def _np_dct4(signals, n=None, norm=None):
NP_IDCT = {1: _np_dct1, 2: _np_dct3, 3: _np_dct2} """Computes the DCT-IV manually with NumPy."""
# SciPy's `dct` has a scaling factor of 2.0 which we follow.
# https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
signals_mod = _modify_input_for_dct(signals, n=n)
dct_size = signals_mod.shape[-1]
signals_mod = np.array(signals_mod) # make a copy so we can modify
if norm == "ortho":
signals_mod *= np.sqrt(2.0 / dct_size)
else:
signals_mod *= 2.0
dct = np.zeros_like(signals_mod)
# X_k = sum_{n=0}^{N-1}
# x_n * cos(\frac{pi}{4N} * (2n + 1) * (2k + 1)) k=0,...,N-1
for k in range(dct_size):
phi = np.cos(np.pi *
(2 * np.arange(0, dct_size) + 1) * (2 * k + 1) /
(4.0 * dct_size))
dct[..., k] = np.sum(signals_mod * phi, axis=-1)
return dct
NP_DCT = {1: _np_dct1, 2: _np_dct2, 3: _np_dct3, 4: _np_dct4}
NP_IDCT = {1: _np_dct1, 2: _np_dct3, 3: _np_dct2, 4: _np_dct4}
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes
@ -137,7 +159,7 @@ class DCTOpsTest(parameterized.TestCase, test.TestCase):
tf_idct = dct_ops.idct(signals, type=dct_type, norm=norm) tf_idct = dct_ops.idct(signals, type=dct_type, norm=norm)
self.assertEqual(tf_idct.dtype.as_numpy_dtype, signals.dtype) self.assertEqual(tf_idct.dtype.as_numpy_dtype, signals.dtype)
self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol) self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol)
if fftpack: if fftpack and dct_type != 4:
scipy_dct = fftpack.dct(signals, n=n, type=dct_type, norm=norm) scipy_dct = fftpack.dct(signals, n=n, type=dct_type, norm=norm)
self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol) self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol)
scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm) scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm)
@ -159,7 +181,7 @@ class DCTOpsTest(parameterized.TestCase, test.TestCase):
self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol) self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol)
@parameterized.parameters(itertools.product( @parameterized.parameters(itertools.product(
[1, 2, 3], [1, 2, 3, 4],
[None, "ortho"], [None, "ortho"],
[[2], [3], [10], [2, 20], [2, 3, 25]], [[2], [3], [10], [2, 20], [2, 3, 25]],
[np.float32, np.float64])) [np.float32, np.float64]))

View File

@ -34,8 +34,8 @@ def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
raise NotImplementedError("axis must be -1. Got: %s" % axis) raise NotImplementedError("axis must be -1. Got: %s" % axis)
if n is not None and n < 1: if n is not None and n < 1:
raise ValueError("n should be a positive integer or None") raise ValueError("n should be a positive integer or None")
if dct_type not in (1, 2, 3): if dct_type not in (1, 2, 3, 4):
raise ValueError("Only Types I, II and III (I)DCT are supported.") raise ValueError("Types I, II, III and IV (I)DCT are supported.")
if dct_type == 1: if dct_type == 1:
if norm == "ortho": if norm == "ortho":
raise ValueError("Normalization is not supported for the Type-I DCT.") raise ValueError("Normalization is not supported for the Type-I DCT.")
@ -53,22 +53,26 @@ def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`. """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.
Currently only Types I, II and III are supported. Types I, II, III and IV are supported.
Type I is implemented using a length `2N` padded `tf.signal.rfft`. Type I is implemented using a length `2N` padded `tf.signal.rfft`.
Type II is implemented using a length `2N` padded `tf.signal.rfft`, as Type II is implemented using a length `2N` padded `tf.signal.rfft`, as
described here: [Type 2 DCT using 2N FFT padded (Makhoul)](https://dsp.stackexchange.com/a/10606). described here: [Type 2 DCT using 2N FFT padded (Makhoul)]
(https://dsp.stackexchange.com/a/10606).
Type III is a fairly straightforward inverse of Type II Type III is a fairly straightforward inverse of Type II
(i.e. using a length `2N` padded `tf.signal.irfft`). (i.e. using a length `2N` padded `tf.signal.irfft`).
Type IV is calculated through 2N length DCT2 of padded signal and
picking the odd indices.
@compatibility(scipy) @compatibility(scipy)
Equivalent to [scipy.fftpack.dct](https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html) Equivalent to [scipy.fftpack.dct]
for Type-I, Type-II and Type-III DCT. (https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.dct.html)
for Type-I, Type-II, Type-III and Type-IV DCT.
@end_compatibility @end_compatibility
Args: Args:
input: A `[..., samples]` `float32`/`float64` `Tensor` containing the input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
signals to take the DCT of. signals to take the DCT of.
type: The DCT type to perform. Must be 1, 2 or 3. type: The DCT type to perform. Must be 1, 2, 3 or 4.
n: The length of the transform. If length is less than sequence length, n: The length of the transform. If length is less than sequence length,
only the first n elements of the sequence are considered for the DCT. only the first n elements of the sequence are considered for the DCT.
If n is greater than the sequence length, zeros are padded and then If n is greater than the sequence length, zeros are padded and then
@ -83,7 +87,7 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
`input`. `input`.
Raises: Raises:
ValueError: If `type` is not `1`, `2` or `3`, `axis` is ValueError: If `type` is not `1`, `2`, `3` or `4`, `axis` is
not `-1`, `n` is not `None` or greater than 0, not `-1`, `n` is not `None` or greater than 0,
or `norm` is not `None` or `'ortho'`. or `norm` is not `None` or `'ortho'`.
ValueError: If `type` is `1` and `norm` is `ortho`. ValueError: If `type` is `1` and `norm` is `ortho`.
@ -163,13 +167,24 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
return dct3 return dct3
elif type == 4:
# DCT-2 of 2N length zero-padded signal, unnormalized.
dct2 = dct(input, type=2, n=2*axis_dim, axis=axis, norm=None)
# Get odd indices of DCT-2 of zero padded 2N signal to obtain
# DCT-4 of the original N length signal.
dct4 = dct2[..., 1::2]
if norm == "ortho":
dct4 *= _math.sqrt(0.5) * _math_ops.rsqrt(axis_dim_float)
return dct4
# TODO(rjryan): Implement `n` and `axis` parameters. # TODO(rjryan): Implement `n` and `axis` parameters.
@tf_export("signal.idct", v1=["signal.idct", "spectral.idct"]) @tf_export("signal.idct", v1=["signal.idct", "spectral.idct"])
def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`. """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.
Currently only Types I, II and III are supported. Type III is the inverse of Currently Types I, II, III, IV are supported. Type III is the inverse of
Type II, and vice versa. Type II, and vice versa.
Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is
@ -179,14 +194,15 @@ def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disab
`signal == idct(dct(signal, norm='ortho'), norm='ortho')`. `signal == idct(dct(signal, norm='ortho'), norm='ortho')`.
@compatibility(scipy) @compatibility(scipy)
Equivalent to [scipy.fftpack.idct](https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html) Equivalent to [scipy.fftpack.idct]
for Type-I, Type-II and Type-III DCT. (https://docs.scipy.org/doc/scipy-1.4.0/reference/generated/scipy.fftpack.idct.html)
for Type-I, Type-II, Type-III and Type-IV DCT.
@end_compatibility @end_compatibility
Args: Args:
input: A `[..., samples]` `float32`/`float64` `Tensor` containing the input: A `[..., samples]` `float32`/`float64` `Tensor` containing the
signals to take the DCT of. signals to take the DCT of.
type: The IDCT type to perform. Must be 1, 2 or 3. type: The IDCT type to perform. Must be 1, 2, 3 or 4.
n: For future expansion. The length of the transform. Must be `None`. n: For future expansion. The length of the transform. Must be `None`.
axis: For future expansion. The axis to compute the DCT along. Must be `-1`. axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
norm: The normalization to apply. `None` for no normalization or `'ortho'` norm: The normalization to apply. `None` for no normalization or `'ortho'`
@ -205,5 +221,5 @@ def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disab
https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
""" """
_validate_dct_arguments(input, type, n, axis, norm) _validate_dct_arguments(input, type, n, axis, norm)
inverse_type = {1: 1, 2: 3, 3: 2}[type] inverse_type = {1: 1, 2: 3, 3: 2, 4: 4}[type]
return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name) return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)