Add DCT-I to tf.signal.dct.

PiperOrigin-RevId: 221356811
This commit is contained in:
A. Unique TensorFlower 2018-11-13 16:36:30 -08:00 committed by TensorFlower Gardener
parent a17b4fc17b
commit c4a6eb5857
3 changed files with 80 additions and 33 deletions

View File

@ -2630,6 +2630,7 @@ cuda_py_test(
name = "dct_ops_test",
srcs = ["dct_ops_test.py"],
additional_deps = [
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import importlib
from absl.testing import parameterized
import numpy as np
from tensorflow.python.ops import spectral_ops
@ -40,6 +41,20 @@ def try_import(name): # pylint: disable=invalid-name
fftpack = try_import("scipy.fftpack")
def _np_dct1(signals, norm=None):
"""Computes the DCT-I manually with NumPy."""
# X_k = (x_0 + (-1)**k * x_{N-1} +
# 2 * sum_{n=0}^{N-2} x_n * cos(\frac{pi}{N-1} * n * k) k=0,...,N-1
del norm
dct_size = signals.shape[-1]
dct = np.zeros_like(signals)
for k in range(dct_size):
phi = np.cos(np.pi * np.arange(1, dct_size - 1) * k / (dct_size - 1))
dct[..., k] = 2 * np.sum(signals[..., 1:-1] * phi, axis=-1) + (
signals[..., 0] + (-1) ** k * signals[..., -1])
return dct
def _np_dct2(signals, norm=None):
"""Computes the DCT-II manually with NumPy."""
# X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1
@ -81,11 +96,11 @@ def _np_dct3(signals, norm=None):
return dct
NP_DCT = {2: _np_dct2, 3: _np_dct3}
NP_IDCT = {2: _np_dct3, 3: _np_dct2}
NP_DCT = {1: _np_dct1, 2: _np_dct2, 3: _np_dct3}
NP_IDCT = {1: _np_dct1, 2: _np_dct3, 3: _np_dct2}
class DCTOpsTest(test.TestCase):
class DCTOpsTest(parameterized.TestCase, test.TestCase):
def _compare(self, signals, norm, dct_type, atol=5e-4, rtol=5e-4):
"""Compares (I)DCT to SciPy (if available) and a NumPy implementation."""
@ -106,26 +121,39 @@ class DCTOpsTest(test.TestCase):
tf_dct_idct = spectral_ops.dct(
tf_idct, type=dct_type, norm=norm).eval()
if norm is None:
tf_idct_dct *= 0.5 / signals.shape[-1]
tf_dct_idct *= 0.5 / signals.shape[-1]
if dct_type == 1:
tf_idct_dct *= 0.5 / (signals.shape[-1] - 1)
tf_dct_idct *= 0.5 / (signals.shape[-1] - 1)
else:
tf_idct_dct *= 0.5 / signals.shape[-1]
tf_dct_idct *= 0.5 / signals.shape[-1]
self.assertAllClose(signals, tf_idct_dct, atol=atol, rtol=rtol)
self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol)
def test_random(self):
@parameterized.parameters([
[[2]], [[3]], [[10]], [[2, 20]], [[2, 3, 25]]])
def test_random(self, shape):
"""Test randomly generated batches of data."""
with spectral_ops_test_util.fft_kernel_label_map():
with self.session(use_gpu=True):
for shape in ([1], [2], [3], [10], [2, 20], [2, 3, 25]):
signals = np.random.rand(*shape).astype(np.float32)
for norm in (None, "ortho"):
self._compare(signals, norm, 2)
self._compare(signals, norm, 3)
signals = np.random.rand(*shape).astype(np.float32)
# Normalization not implemented for orthonormal.
self._compare(signals, norm=None, dct_type=1)
for norm in (None, "ortho"):
self._compare(signals, norm, 2)
self._compare(signals, norm, 3)
def test_error(self):
signals = np.random.rand(10)
# Unsupported type.
with self.assertRaises(ValueError):
spectral_ops.dct(signals, type=1)
spectral_ops.dct(signals, type=5)
# DCT-I normalization not implemented.
with self.assertRaises(ValueError):
spectral_ops.dct(signals, type=1, norm="ortho")
# DCT-I requires at least two inputs.
with self.assertRaises(ValueError):
spectral_ops.dct(np.random.rand(1), type=1)
# Unknown normalization.
with self.assertRaises(ValueError):
spectral_ops.dct(signals, norm="bad")

View File

@ -165,37 +165,48 @@ irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d")
tf_export("spectral.irfft3d")(irfft3d)
def _validate_dct_arguments(dct_type, n, axis, norm):
def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm):
"""Checks that DCT/IDCT arguments are compatible and well formed."""
if n is not None:
raise NotImplementedError("The DCT length argument is not implemented.")
if axis != -1:
raise NotImplementedError("axis must be -1. Got: %s" % axis)
if dct_type not in (2, 3):
raise ValueError("Only Types II and III (I)DCT are supported.")
if dct_type not in (1, 2, 3):
raise ValueError("Only Types I, II and III (I)DCT are supported.")
if dct_type == 1:
if norm == "ortho":
raise ValueError("Normalization is not supported for the Type-I DCT.")
if input_tensor.shape[-1] is not None and input_tensor.shape[-1] < 2:
raise ValueError(
"Type-I DCT requires the dimension to be greater than one.")
if norm not in (None, "ortho"):
raise ValueError(
"Unknown normalization. Expected None or 'ortho', got: %s" % norm)
# TODO(rjryan): Implement `type`, `n` and `axis` parameters.
# TODO(rjryan): Implement `n` and `axis` parameters.
@tf_export("spectral.dct")
def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.
Currently only Types II and III are supported. Type II is implemented using a
length `2N` padded `tf.spectral.rfft`, as described here:
https://dsp.stackexchange.com/a/10606. Type III is a fairly straightforward
inverse of Type II (i.e. using a length `2N` padded `tf.spectral.irfft`).
Currently only Types I, II and III are supported.
Type I is implemented using a length `2N` padded `tf.spectral.rfft`.
Type II is implemented using a length `2N` padded `tf.spectral.rfft`, as
described here:
https://dsp.stackexchange.com/a/10606.
Type III is a fairly straightforward inverse of Type II
(i.e. using a length `2N` padded `tf.spectral.irfft`).
@compatibility(scipy)
Equivalent to scipy.fftpack.dct for Type-II and Type-III DCT.
Equivalent to scipy.fftpack.dct for Type-I, Type-II and Type-III DCT.
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
@end_compatibility
Args:
input: A `[..., samples]` `float32` `Tensor` containing the signals to
take the DCT of.
type: The DCT type to perform. Must be 2 or 3.
type: The DCT type to perform. Must be 1, 2 or 3.
n: For future expansion. The length of the transform. Must be `None`.
axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
norm: The normalization to apply. `None` for no normalization or `'ortho'`
@ -206,12 +217,13 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
A `[..., samples]` `float32` `Tensor` containing the DCT of `input`.
Raises:
ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not
`-1`, or `norm` is not `None` or `'ortho'`.
ValueError: If `type` is not `1`, `2` or `3`, `n` is not `None, `axis` is
not `-1`, or `norm` is not `None` or `'ortho'`.
ValueError: If `type` is `1` and `norm` is `ortho`.
[dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform
"""
_validate_dct_arguments(type, n, axis, norm)
_validate_dct_arguments(input, type, n, axis, norm)
with _ops.name_scope(name, "dct", [input]):
# We use the RFFT to compute the DCT and TensorFlow only supports float32
# for FFTs at the moment.
@ -220,6 +232,12 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
axis_dim = (tensor_shape.dimension_value(input.shape[-1])
or _array_ops.shape(input)[-1])
axis_dim_float = _math_ops.to_float(axis_dim)
if type == 1:
dct1_input = _array_ops.concat([input, input[..., -2:0:-1]], axis=-1)
dct1 = _math_ops.real(rfft(dct1_input))
return dct1
if type == 2:
scale = 2.0 * _math_ops.exp(
_math_ops.complex(
@ -266,12 +284,12 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
return dct3
# TODO(rjryan): Implement `type`, `n` and `axis` parameters.
# TODO(rjryan): Implement `n` and `axis` parameters.
@tf_export("spectral.idct")
def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.
Currently only Types II and III are supported. Type III is the inverse of
Currently only Types I, II and III are supported. Type III is the inverse of
Type II, and vice versa.
Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is
@ -281,14 +299,14 @@ def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disab
`signal == idct(dct(signal, norm='ortho'), norm='ortho')`.
@compatibility(scipy)
Equivalent to scipy.fftpack.idct for Type-II and Type-III DCT.
Equivalent to scipy.fftpack.idct for Type-I, Type-II and Type-III DCT.
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html
@end_compatibility
Args:
input: A `[..., samples]` `float32` `Tensor` containing the signals to take
the DCT of.
type: The IDCT type to perform. Must be 2 or 3.
type: The IDCT type to perform. Must be 1, 2 or 3.
n: For future expansion. The length of the transform. Must be `None`.
axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
norm: The normalization to apply. `None` for no normalization or `'ortho'`
@ -299,12 +317,12 @@ def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disab
A `[..., samples]` `float32` `Tensor` containing the IDCT of `input`.
Raises:
ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not
`-1`, or `norm` is not `None` or `'ortho'`.
ValueError: If `type` is not `1`, `2` or `3`, `n` is not `None, `axis` is
not `-1`, or `norm` is not `None` or `'ortho'`.
[idct]:
https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
"""
_validate_dct_arguments(type, n, axis, norm)
inverse_type = {2: 3, 3: 2}[type]
_validate_dct_arguments(input, type, n, axis, norm)
inverse_type = {1: 1, 2: 3, 3: 2}[type]
return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)