Add DCT-I to tf.signal.dct.
PiperOrigin-RevId: 221356811
This commit is contained in:
parent
a17b4fc17b
commit
c4a6eb5857
@ -2630,6 +2630,7 @@ cuda_py_test(
|
||||
name = "dct_ops_test",
|
||||
srcs = ["dct_ops_test.py"],
|
||||
additional_deps = [
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import importlib
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.ops import spectral_ops
|
||||
@ -40,6 +41,20 @@ def try_import(name): # pylint: disable=invalid-name
|
||||
fftpack = try_import("scipy.fftpack")
|
||||
|
||||
|
||||
def _np_dct1(signals, norm=None):
|
||||
"""Computes the DCT-I manually with NumPy."""
|
||||
# X_k = (x_0 + (-1)**k * x_{N-1} +
|
||||
# 2 * sum_{n=0}^{N-2} x_n * cos(\frac{pi}{N-1} * n * k) k=0,...,N-1
|
||||
del norm
|
||||
dct_size = signals.shape[-1]
|
||||
dct = np.zeros_like(signals)
|
||||
for k in range(dct_size):
|
||||
phi = np.cos(np.pi * np.arange(1, dct_size - 1) * k / (dct_size - 1))
|
||||
dct[..., k] = 2 * np.sum(signals[..., 1:-1] * phi, axis=-1) + (
|
||||
signals[..., 0] + (-1) ** k * signals[..., -1])
|
||||
return dct
|
||||
|
||||
|
||||
def _np_dct2(signals, norm=None):
|
||||
"""Computes the DCT-II manually with NumPy."""
|
||||
# X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1
|
||||
@ -81,11 +96,11 @@ def _np_dct3(signals, norm=None):
|
||||
return dct
|
||||
|
||||
|
||||
NP_DCT = {2: _np_dct2, 3: _np_dct3}
|
||||
NP_IDCT = {2: _np_dct3, 3: _np_dct2}
|
||||
NP_DCT = {1: _np_dct1, 2: _np_dct2, 3: _np_dct3}
|
||||
NP_IDCT = {1: _np_dct1, 2: _np_dct3, 3: _np_dct2}
|
||||
|
||||
|
||||
class DCTOpsTest(test.TestCase):
|
||||
class DCTOpsTest(parameterized.TestCase, test.TestCase):
|
||||
|
||||
def _compare(self, signals, norm, dct_type, atol=5e-4, rtol=5e-4):
|
||||
"""Compares (I)DCT to SciPy (if available) and a NumPy implementation."""
|
||||
@ -106,26 +121,39 @@ class DCTOpsTest(test.TestCase):
|
||||
tf_dct_idct = spectral_ops.dct(
|
||||
tf_idct, type=dct_type, norm=norm).eval()
|
||||
if norm is None:
|
||||
tf_idct_dct *= 0.5 / signals.shape[-1]
|
||||
tf_dct_idct *= 0.5 / signals.shape[-1]
|
||||
if dct_type == 1:
|
||||
tf_idct_dct *= 0.5 / (signals.shape[-1] - 1)
|
||||
tf_dct_idct *= 0.5 / (signals.shape[-1] - 1)
|
||||
else:
|
||||
tf_idct_dct *= 0.5 / signals.shape[-1]
|
||||
tf_dct_idct *= 0.5 / signals.shape[-1]
|
||||
self.assertAllClose(signals, tf_idct_dct, atol=atol, rtol=rtol)
|
||||
self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol)
|
||||
|
||||
def test_random(self):
|
||||
@parameterized.parameters([
|
||||
[[2]], [[3]], [[10]], [[2, 20]], [[2, 3, 25]]])
|
||||
def test_random(self, shape):
|
||||
"""Test randomly generated batches of data."""
|
||||
with spectral_ops_test_util.fft_kernel_label_map():
|
||||
with self.session(use_gpu=True):
|
||||
for shape in ([1], [2], [3], [10], [2, 20], [2, 3, 25]):
|
||||
signals = np.random.rand(*shape).astype(np.float32)
|
||||
for norm in (None, "ortho"):
|
||||
self._compare(signals, norm, 2)
|
||||
self._compare(signals, norm, 3)
|
||||
signals = np.random.rand(*shape).astype(np.float32)
|
||||
# Normalization not implemented for orthonormal.
|
||||
self._compare(signals, norm=None, dct_type=1)
|
||||
for norm in (None, "ortho"):
|
||||
self._compare(signals, norm, 2)
|
||||
self._compare(signals, norm, 3)
|
||||
|
||||
def test_error(self):
|
||||
signals = np.random.rand(10)
|
||||
# Unsupported type.
|
||||
with self.assertRaises(ValueError):
|
||||
spectral_ops.dct(signals, type=1)
|
||||
spectral_ops.dct(signals, type=5)
|
||||
# DCT-I normalization not implemented.
|
||||
with self.assertRaises(ValueError):
|
||||
spectral_ops.dct(signals, type=1, norm="ortho")
|
||||
# DCT-I requires at least two inputs.
|
||||
with self.assertRaises(ValueError):
|
||||
spectral_ops.dct(np.random.rand(1), type=1)
|
||||
# Unknown normalization.
|
||||
with self.assertRaises(ValueError):
|
||||
spectral_ops.dct(signals, norm="bad")
|
||||
|
@ -165,37 +165,48 @@ irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d")
|
||||
tf_export("spectral.irfft3d")(irfft3d)
|
||||
|
||||
|
||||
def _validate_dct_arguments(dct_type, n, axis, norm):
|
||||
def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
|
||||
"""Checks that DCT/IDCT arguments are compatible and well formed."""
|
||||
if n is not None:
|
||||
raise NotImplementedError("The DCT length argument is not implemented.")
|
||||
if axis != -1:
|
||||
raise NotImplementedError("axis must be -1. Got: %s" % axis)
|
||||
if dct_type not in (2, 3):
|
||||
raise ValueError("Only Types II and III (I)DCT are supported.")
|
||||
if dct_type not in (1, 2, 3):
|
||||
raise ValueError("Only Types I, II and III (I)DCT are supported.")
|
||||
if dct_type == 1:
|
||||
if norm == "ortho":
|
||||
raise ValueError("Normalization is not supported for the Type-I DCT.")
|
||||
if input_tensor.shape[-1] is not None and input_tensor.shape[-1] < 2:
|
||||
raise ValueError(
|
||||
"Type-I DCT requires the dimension to be greater than one.")
|
||||
|
||||
if norm not in (None, "ortho"):
|
||||
raise ValueError(
|
||||
"Unknown normalization. Expected None or 'ortho', got: %s" % norm)
|
||||
|
||||
|
||||
# TODO(rjryan): Implement `type`, `n` and `axis` parameters.
|
||||
# TODO(rjryan): Implement `n` and `axis` parameters.
|
||||
@tf_export("spectral.dct")
|
||||
def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
|
||||
"""Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.
|
||||
|
||||
Currently only Types II and III are supported. Type II is implemented using a
|
||||
length `2N` padded `tf.spectral.rfft`, as described here:
|
||||
https://dsp.stackexchange.com/a/10606. Type III is a fairly straightforward
|
||||
inverse of Type II (i.e. using a length `2N` padded `tf.spectral.irfft`).
|
||||
Currently only Types I, II and III are supported.
|
||||
Type I is implemented using a length `2N` padded `tf.spectral.rfft`.
|
||||
Type II is implemented using a length `2N` padded `tf.spectral.rfft`, as
|
||||
described here:
|
||||
https://dsp.stackexchange.com/a/10606.
|
||||
Type III is a fairly straightforward inverse of Type II
|
||||
(i.e. using a length `2N` padded `tf.spectral.irfft`).
|
||||
|
||||
@compatibility(scipy)
|
||||
Equivalent to scipy.fftpack.dct for Type-II and Type-III DCT.
|
||||
Equivalent to scipy.fftpack.dct for Type-I, Type-II and Type-III DCT.
|
||||
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
|
||||
@end_compatibility
|
||||
|
||||
Args:
|
||||
input: A `[..., samples]` `float32` `Tensor` containing the signals to
|
||||
take the DCT of.
|
||||
type: The DCT type to perform. Must be 2 or 3.
|
||||
type: The DCT type to perform. Must be 1, 2 or 3.
|
||||
n: For future expansion. The length of the transform. Must be `None`.
|
||||
axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
|
||||
norm: The normalization to apply. `None` for no normalization or `'ortho'`
|
||||
@ -206,12 +217,13 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
|
||||
A `[..., samples]` `float32` `Tensor` containing the DCT of `input`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not
|
||||
`-1`, or `norm` is not `None` or `'ortho'`.
|
||||
ValueError: If `type` is not `1`, `2` or `3`, `n` is not `None, `axis` is
|
||||
not `-1`, or `norm` is not `None` or `'ortho'`.
|
||||
ValueError: If `type` is `1` and `norm` is `ortho`.
|
||||
|
||||
[dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform
|
||||
"""
|
||||
_validate_dct_arguments(type, n, axis, norm)
|
||||
_validate_dct_arguments(input, type, n, axis, norm)
|
||||
with _ops.name_scope(name, "dct", [input]):
|
||||
# We use the RFFT to compute the DCT and TensorFlow only supports float32
|
||||
# for FFTs at the moment.
|
||||
@ -220,6 +232,12 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
|
||||
axis_dim = (tensor_shape.dimension_value(input.shape[-1])
|
||||
or _array_ops.shape(input)[-1])
|
||||
axis_dim_float = _math_ops.to_float(axis_dim)
|
||||
|
||||
if type == 1:
|
||||
dct1_input = _array_ops.concat([input, input[..., -2:0:-1]], axis=-1)
|
||||
dct1 = _math_ops.real(rfft(dct1_input))
|
||||
return dct1
|
||||
|
||||
if type == 2:
|
||||
scale = 2.0 * _math_ops.exp(
|
||||
_math_ops.complex(
|
||||
@ -266,12 +284,12 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
|
||||
return dct3
|
||||
|
||||
|
||||
# TODO(rjryan): Implement `type`, `n` and `axis` parameters.
|
||||
# TODO(rjryan): Implement `n` and `axis` parameters.
|
||||
@tf_export("spectral.idct")
|
||||
def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
|
||||
"""Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.
|
||||
|
||||
Currently only Types II and III are supported. Type III is the inverse of
|
||||
Currently only Types I, II and III are supported. Type III is the inverse of
|
||||
Type II, and vice versa.
|
||||
|
||||
Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is
|
||||
@ -281,14 +299,14 @@ def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disab
|
||||
`signal == idct(dct(signal, norm='ortho'), norm='ortho')`.
|
||||
|
||||
@compatibility(scipy)
|
||||
Equivalent to scipy.fftpack.idct for Type-II and Type-III DCT.
|
||||
Equivalent to scipy.fftpack.idct for Type-I, Type-II and Type-III DCT.
|
||||
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html
|
||||
@end_compatibility
|
||||
|
||||
Args:
|
||||
input: A `[..., samples]` `float32` `Tensor` containing the signals to take
|
||||
the DCT of.
|
||||
type: The IDCT type to perform. Must be 2 or 3.
|
||||
type: The IDCT type to perform. Must be 1, 2 or 3.
|
||||
n: For future expansion. The length of the transform. Must be `None`.
|
||||
axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
|
||||
norm: The normalization to apply. `None` for no normalization or `'ortho'`
|
||||
@ -299,12 +317,12 @@ def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disab
|
||||
A `[..., samples]` `float32` `Tensor` containing the IDCT of `input`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not
|
||||
`-1`, or `norm` is not `None` or `'ortho'`.
|
||||
ValueError: If `type` is not `1`, `2` or `3`, `n` is not `None, `axis` is
|
||||
not `-1`, or `norm` is not `None` or `'ortho'`.
|
||||
|
||||
[idct]:
|
||||
https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
|
||||
"""
|
||||
_validate_dct_arguments(type, n, axis, norm)
|
||||
inverse_type = {2: 3, 3: 2}[type]
|
||||
_validate_dct_arguments(input, type, n, axis, norm)
|
||||
inverse_type = {1: 1, 2: 3, 3: 2}[type]
|
||||
return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)
|
||||
|
Loading…
x
Reference in New Issue
Block a user