Add mel-scale conversion matrix support to tf.contrib.signal.
PiperOrigin-RevId: 168560255
This commit is contained in:
parent
b00b6d23c8
commit
a4f6e7c1af
@ -24,6 +24,16 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "mel_ops_test",
|
||||
srcs = ["python/kernel_tests/mel_ops_test.py"],
|
||||
additional_deps = [
|
||||
":signal_py",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "reconstruction_ops_test",
|
||||
srcs = ["python/kernel_tests/reconstruction_ops_test.py"],
|
||||
|
@ -20,11 +20,13 @@ See the @{$python/contrib.signal} guide.
|
||||
@@hamming_window
|
||||
@@hann_window
|
||||
@@inverse_stft
|
||||
@@linear_to_mel_weight_matrix
|
||||
@@overlap_and_add
|
||||
@@stft
|
||||
|
||||
[hamming]: https://en.wikipedia.org/wiki/Window_function#Hamming_window
|
||||
[hann]: https://en.wikipedia.org/wiki/Window_function#Hann_window
|
||||
[mel]: https://en.wikipedia.org/wiki/Mel_scale
|
||||
[stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
|
||||
"""
|
||||
|
||||
@ -32,6 +34,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.signal.python.ops.mel_ops import linear_to_mel_weight_matrix
|
||||
from tensorflow.contrib.signal.python.ops.reconstruction_ops import overlap_and_add
|
||||
from tensorflow.contrib.signal.python.ops.shape_ops import frame
|
||||
# `frame` used to be named `frames`, which is a noun and not a verb.
|
||||
|
164
tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
Normal file
164
tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py
Normal file
@ -0,0 +1,164 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for mel_ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.signal.python.ops import mel_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
# mel spectrum constants and functions.
|
||||
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
|
||||
_MEL_HIGH_FREQUENCY_Q = 1127.0
|
||||
|
||||
|
||||
def hertz_to_mel(frequencies_hertz):
|
||||
"""Convert frequencies to mel scale using HTK formula.
|
||||
|
||||
Copied from
|
||||
https://github.com/tensorflow/models/blob/master/audioset/mel_features.py.
|
||||
|
||||
Args:
|
||||
frequencies_hertz: Scalar or np.array of frequencies in hertz.
|
||||
|
||||
Returns:
|
||||
Object of same size as frequencies_hertz containing corresponding values
|
||||
on the mel scale.
|
||||
"""
|
||||
return _MEL_HIGH_FREQUENCY_Q * np.log(
|
||||
1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
|
||||
|
||||
|
||||
def spectrogram_to_mel_matrix(num_mel_bins=20,
|
||||
num_spectrogram_bins=129,
|
||||
audio_sample_rate=8000,
|
||||
lower_edge_hertz=125.0,
|
||||
upper_edge_hertz=3800.0):
|
||||
"""Return a matrix that can post-multiply spectrogram rows to make mel.
|
||||
|
||||
Copied from
|
||||
https://github.com/tensorflow/models/blob/master/audioset/mel_features.py.
|
||||
|
||||
Returns a np.array matrix A that can be used to post-multiply a matrix S of
|
||||
spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
|
||||
"mel spectrogram" M of frames x num_mel_bins. M = S A.
|
||||
|
||||
The classic HTK algorithm exploits the complementarity of adjacent mel bands
|
||||
to multiply each FFT bin by only one mel weight, then add it, with positive
|
||||
and negative signs, to the two adjacent mel bands to which that bin
|
||||
contributes. Here, by expressing this operation as a matrix multiply, we go
|
||||
from num_fft multiplies per frame (plus around 2*num_fft adds) to around
|
||||
num_fft^2 multiplies and adds. However, because these are all presumably
|
||||
accomplished in a single call to np.dot(), it's not clear which approach is
|
||||
faster in Python. The matrix multiplication has the attraction of being more
|
||||
general and flexible, and much easier to read.
|
||||
|
||||
Args:
|
||||
num_mel_bins: How many bands in the resulting mel spectrum. This is
|
||||
the number of columns in the output matrix.
|
||||
num_spectrogram_bins: How many bins there are in the source spectrogram
|
||||
data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
|
||||
only contains the nonredundant FFT bins.
|
||||
audio_sample_rate: Samples per second of the audio at the input to the
|
||||
spectrogram. We need this to figure out the actual frequencies for
|
||||
each spectrogram bin, which dictates how they are mapped into mel.
|
||||
lower_edge_hertz: Lower bound on the frequencies to be included in the mel
|
||||
spectrum. This corresponds to the lower edge of the lowest triangular
|
||||
band.
|
||||
upper_edge_hertz: The desired top edge of the highest frequency band.
|
||||
|
||||
Returns:
|
||||
An np.array with shape (num_spectrogram_bins, num_mel_bins).
|
||||
|
||||
Raises:
|
||||
ValueError: if frequency edges are incorrectly ordered.
|
||||
"""
|
||||
nyquist_hertz = audio_sample_rate / 2.
|
||||
if lower_edge_hertz >= upper_edge_hertz:
|
||||
raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
|
||||
(lower_edge_hertz, upper_edge_hertz))
|
||||
spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
|
||||
spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
|
||||
# The i'th mel band (starting from i=1) has center frequency
|
||||
# band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
|
||||
# band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
|
||||
# the band_edges_mel arrays.
|
||||
band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
|
||||
hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
|
||||
# Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
|
||||
# of spectrogram values.
|
||||
mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
|
||||
for i in range(num_mel_bins):
|
||||
lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
|
||||
# Calculate lower and upper slopes for every spectrogram bin.
|
||||
# Line segments are linear in the *mel* domain, not hertz.
|
||||
lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
|
||||
(center_mel - lower_edge_mel))
|
||||
upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
|
||||
(upper_edge_mel - center_mel))
|
||||
# .. then intersect them with each other and zero.
|
||||
mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
|
||||
upper_slope))
|
||||
# HTK excludes the spectrogram DC bin; make sure it always gets a zero
|
||||
# coefficient.
|
||||
mel_weights_matrix[0, :] = 0.0
|
||||
return mel_weights_matrix
|
||||
|
||||
|
||||
class LinearToMelTest(test.TestCase):
|
||||
|
||||
def test_matches_reference_implementation(self):
|
||||
# Tuples of (num_mel_bins, num_spectrogram_bins, sample_rate,
|
||||
# lower_edge_hertz, upper_edge_hertz) to test.
|
||||
configs = [
|
||||
# Defaults.
|
||||
(20, 129, 8000.0, 125.0, 3800.0),
|
||||
# Settings used by Tacotron (https://arxiv.org/abs/1703.10135).
|
||||
(80, 1025, 24000.0, 80.0, 12000.0)
|
||||
]
|
||||
with self.test_session(use_gpu=True):
|
||||
for config in configs:
|
||||
mel_matrix_np = spectrogram_to_mel_matrix(*config)
|
||||
mel_matrix = mel_ops.linear_to_mel_weight_matrix(*config)
|
||||
self.assertAllClose(mel_matrix_np, mel_matrix.eval(), atol=3e-6)
|
||||
|
||||
def test_dtypes(self):
|
||||
for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
|
||||
self.assertEqual(dtype,
|
||||
mel_ops.linear_to_mel_weight_matrix(dtype=dtype).dtype)
|
||||
|
||||
def test_error(self):
|
||||
with self.assertRaises(ValueError):
|
||||
mel_ops.linear_to_mel_weight_matrix(num_mel_bins=0)
|
||||
with self.assertRaises(ValueError):
|
||||
mel_ops.linear_to_mel_weight_matrix(num_spectrogram_bins=0)
|
||||
with self.assertRaises(ValueError):
|
||||
mel_ops.linear_to_mel_weight_matrix(sample_rate=0.0)
|
||||
with self.assertRaises(ValueError):
|
||||
mel_ops.linear_to_mel_weight_matrix(lower_edge_hertz=-1)
|
||||
with self.assertRaises(ValueError):
|
||||
mel_ops.linear_to_mel_weight_matrix(lower_edge_hertz=100,
|
||||
upper_edge_hertz=10)
|
||||
with self.assertRaises(ValueError):
|
||||
mel_ops.linear_to_mel_weight_matrix(dtype=dtypes.int32)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
199
tensorflow/contrib/signal/python/ops/mel_ops.py
Normal file
199
tensorflow/contrib/signal/python/ops/mel_ops.py
Normal file
@ -0,0 +1,199 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""mel conversion ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.signal.python.ops import shape_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
# mel spectrum constants.
|
||||
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
|
||||
_MEL_HIGH_FREQUENCY_Q = 1127.0
|
||||
|
||||
|
||||
def _mel_to_hertz(mel_values, name=None):
|
||||
"""Converts frequencies in `mel_values` from the mel scale to linear scale.
|
||||
|
||||
Args:
|
||||
mel_values: A `Tensor` of frequencies in the mel scale.
|
||||
name: An optional name for the operation.
|
||||
|
||||
Returns:
|
||||
A `Tensor` of the same shape and type as `mel_values` containing linear
|
||||
scale frequencies in Hertz.
|
||||
"""
|
||||
with ops.name_scope(name, 'mel_to_hertz', [mel_values]):
|
||||
mel_values = ops.convert_to_tensor(mel_values)
|
||||
return _MEL_BREAK_FREQUENCY_HERTZ * (
|
||||
math_ops.exp(mel_values / _MEL_HIGH_FREQUENCY_Q) - 1.0
|
||||
)
|
||||
|
||||
|
||||
def _hertz_to_mel(frequencies_hertz, name=None):
|
||||
"""Converts frequencies in `frequencies_hertz` in Hertz to the mel scale.
|
||||
|
||||
Args:
|
||||
frequencies_hertz: A `Tensor` of frequencies in Hertz.
|
||||
name: An optional name for the operation.
|
||||
|
||||
Returns:
|
||||
A `Tensor` of the same shape and type of `frequencies_hertz` containing
|
||||
frequencies in the mel scale.
|
||||
"""
|
||||
with ops.name_scope(name, 'hertz_to_mel', [frequencies_hertz]):
|
||||
frequencies_hertz = ops.convert_to_tensor(frequencies_hertz)
|
||||
return _MEL_HIGH_FREQUENCY_Q * math_ops.log(
|
||||
1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
|
||||
|
||||
|
||||
def _validate_arguments(num_mel_bins, num_spectrogram_bins, sample_rate,
|
||||
lower_edge_hertz, upper_edge_hertz, dtype):
|
||||
"""Checks the inputs to linear_to_mel_weight_matrix."""
|
||||
if num_mel_bins <= 0:
|
||||
raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins)
|
||||
if num_spectrogram_bins <= 0:
|
||||
raise ValueError('num_spectrogram_bins must be positive. Got: %s' %
|
||||
num_spectrogram_bins)
|
||||
if sample_rate <= 0.0:
|
||||
raise ValueError('sample_rate must be positive. Got: %s' % sample_rate)
|
||||
if lower_edge_hertz < 0.0:
|
||||
raise ValueError('lower_edge_hertz must be non-negative. Got: %s' %
|
||||
lower_edge_hertz)
|
||||
if lower_edge_hertz >= upper_edge_hertz:
|
||||
raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' %
|
||||
(lower_edge_hertz, upper_edge_hertz))
|
||||
if not dtype.is_floating:
|
||||
raise ValueError('dtype must be a floating point type. Got: %s' % dtype)
|
||||
|
||||
|
||||
def linear_to_mel_weight_matrix(num_mel_bins=20,
|
||||
num_spectrogram_bins=129,
|
||||
sample_rate=8000,
|
||||
lower_edge_hertz=125.0,
|
||||
upper_edge_hertz=3800.0,
|
||||
dtype=dtypes.float32,
|
||||
name=None):
|
||||
"""Returns a matrix to warp linear scale spectrograms to the [mel scale][mel].
|
||||
|
||||
Returns a weight matrix that can be used to re-weight a `Tensor` containing
|
||||
`num_spectrogram_bins` linearly sampled frequency information from
|
||||
`[0, sample_rate / 2]` into `num_mel_bins` frequency information from
|
||||
`[lower_edge_hertz, upper_edge_hertz]` on the [mel scale][mel].
|
||||
|
||||
For example, the returned matrix `A` can be used to right-multiply a
|
||||
spectrogram `S` of shape `[frames, num_spectrogram_bins]` of linear
|
||||
scale spectrum values (e.g. STFT magnitudes) to generate a "mel spectrogram"
|
||||
`M` of shape `[frames, num_mel_bins]`.
|
||||
|
||||
# `S` has shape [frames, num_spectrogram_bins]
|
||||
# `M` has shape [frames, num_mel_bins]
|
||||
M = tf.matmul(S, A)
|
||||
|
||||
The matrix can be used with @{tf.tensordot} to convert an arbitrary rank
|
||||
`Tensor` of linear-scale spectral bins into the mel scale.
|
||||
|
||||
# S has shape [..., num_spectrogram_bins].
|
||||
# M has shape [..., num_mel_bins].
|
||||
M = tf.tensordot(S, A, 1)
|
||||
# tf.tensordot does not support shape inference for this case yet.
|
||||
M.set_shape(S.shape[:-1].concatenate(A.shape[-1:]))
|
||||
|
||||
Args:
|
||||
num_mel_bins: Python int. How many bands in the resulting mel spectrum.
|
||||
num_spectrogram_bins: Python int. How many bins there are in the source
|
||||
spectrogram data, which is understood to be `fft_size // 2 + 1`, i.e. the
|
||||
spectrogram only contains the nonredundant FFT bins.
|
||||
sample_rate: Python float. Samples per second of the input signal used to
|
||||
create the spectrogram. We need this to figure out the actual frequencies
|
||||
for each spectrogram bin, which dictates how they are mapped into the mel
|
||||
scale.
|
||||
lower_edge_hertz: Python float. Lower bound on the frequencies to be
|
||||
included in the mel spectrum. This corresponds to the lower edge of the
|
||||
lowest triangular band.
|
||||
upper_edge_hertz: Python float. The desired top edge of the highest
|
||||
frequency band.
|
||||
dtype: The `DType` of the result matrix. Must be a floating point type.
|
||||
name: An optional name for the operation.
|
||||
|
||||
Returns:
|
||||
A `Tensor` of shape `[num_spectrogram_bins, num_mel_bins]`.
|
||||
|
||||
Raises:
|
||||
ValueError: If num_mel_bins/num_spectrogram_bins/sample_rate are not
|
||||
positive, lower_edge_hertz is negative, or frequency edges are incorrectly
|
||||
ordered.
|
||||
|
||||
[mel]: https://en.wikipedia.org/wiki/Mel_scale
|
||||
"""
|
||||
with ops.name_scope(name, 'linear_to_mel_weight_matrix') as name:
|
||||
_validate_arguments(num_mel_bins, num_spectrogram_bins, sample_rate,
|
||||
lower_edge_hertz, upper_edge_hertz, dtype)
|
||||
|
||||
# To preserve accuracy, we compute the matrix at float64 precision and then
|
||||
# cast to `dtype` at the end. This function can be constant folded by graph
|
||||
# optimization since there are no Tensor inputs.
|
||||
sample_rate = ops.convert_to_tensor(
|
||||
sample_rate, dtypes.float64, name='sample_rate')
|
||||
lower_edge_hertz = ops.convert_to_tensor(
|
||||
lower_edge_hertz, dtypes.float64, name='lower_edge_hertz')
|
||||
upper_edge_hertz = ops.convert_to_tensor(
|
||||
upper_edge_hertz, dtypes.float64, name='upper_edge_hertz')
|
||||
zero_float64 = ops.convert_to_tensor(0.0, dtypes.float64)
|
||||
|
||||
# HTK excludes the spectrogram DC bin.
|
||||
bands_to_zero = 1
|
||||
nyquist_hertz = sample_rate / 2.0
|
||||
linear_frequencies = math_ops.linspace(
|
||||
zero_float64, nyquist_hertz, num_spectrogram_bins)[bands_to_zero:]
|
||||
spectrogram_bins_mel = array_ops.expand_dims(
|
||||
_hertz_to_mel(linear_frequencies), 1)
|
||||
|
||||
# Compute num_mel_bins triples of (lower_edge, center, upper_edge). The
|
||||
# center of each band is the lower and upper edge of the adjacent bands.
|
||||
# Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into
|
||||
# num_mel_bins + 2 pieces.
|
||||
band_edges_mel = shape_ops.frame(
|
||||
math_ops.linspace(_hertz_to_mel(lower_edge_hertz),
|
||||
_hertz_to_mel(upper_edge_hertz),
|
||||
num_mel_bins + 2), frame_length=3, frame_step=1)
|
||||
|
||||
# Split the triples up and reshape them into [1, num_mel_bins] tensors.
|
||||
lower_edge_mel, center_mel, upper_edge_mel = tuple(array_ops.reshape(
|
||||
t, [1, num_mel_bins]) for t in array_ops.split(
|
||||
band_edges_mel, 3, axis=1))
|
||||
|
||||
# Calculate lower and upper slopes for every spectrogram bin.
|
||||
# Line segments are linear in the mel domain, not Hertz.
|
||||
lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (
|
||||
center_mel - lower_edge_mel)
|
||||
upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (
|
||||
upper_edge_mel - center_mel)
|
||||
|
||||
# Intersect the line segments with each other and zero.
|
||||
mel_weights_matrix = math_ops.maximum(
|
||||
zero_float64, math_ops.minimum(lower_slopes, upper_slopes))
|
||||
|
||||
# Re-add the zeroed lower bins we sliced out above.
|
||||
mel_weights_matrix = array_ops.pad(
|
||||
mel_weights_matrix, [[bands_to_zero, 0], [0, 0]])
|
||||
|
||||
# Cast to the desired type.
|
||||
return math_ops.cast(mel_weights_matrix, dtype, name=name)
|
Loading…
Reference in New Issue
Block a user