tf.signal: Add a Modified Discrete Cosine Transform (MDCT) and its inverse to tf.signal.

Also adds 2 new window types which are commonly used with the MDCT.
- Kaiser-Bessel derived window
- Vorbis window

Also adds a Kaiser window which is used to calculate Kaiser-Bessel derived window and can also be used elsewhere.

TESTED:
- unit tests
PiperOrigin-RevId: 289103282
Change-Id: Id5972a413b7635716cef29b5be51e285a4ac5de5
This commit is contained in:
A. Unique TensorFlower 2020-01-10 09:04:43 -08:00 committed by TensorFlower Gardener
parent f6efdc52b1
commit f146ef1740
6 changed files with 400 additions and 5 deletions

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from absl.testing import parameterized
import numpy as np
@ -315,6 +317,46 @@ class SpectralOpsTest(test.TestCase, parameterized.TestCase):
self.assertAllClose(b_jacob_t, b_jacob_n,
rtol=backward_tol, atol=backward_tol)
@parameterized.parameters(
itertools.product(
(4000,),
(256,),
(np.float32, np.float64),
("ortho", None),
("vorbis", "kaiser_bessel_derived", None),
(False, True)))
def test_mdct_round_trip(self, signal_length, frame_length, np_rtype,
norm, window_type, pad_end):
if np_rtype == np.float32:
tol = 1e-5
else:
if window_type == "kaiser_bessel_derived":
tol = 1e-6
else:
tol = 1e-8
# Generate a random white Gaussian signal.
signal = np.random.normal(size=signal_length).astype(np_rtype)
if window_type == "vorbis":
window_fn = window_ops.vorbis_window
elif window_type == "kaiser_bessel_derived":
window_fn = window_ops.kaiser_bessel_derived_window
elif window_type is None:
window_fn = None
mdct = spectral_ops.mdct(signal, frame_length, norm=norm,
window_fn=window_fn, pad_end=pad_end)
inverse_mdct = spectral_ops.inverse_mdct(mdct, norm=norm,
window_fn=window_fn)
inverse_mdct = self.evaluate(inverse_mdct)
# Truncate signal and inverse_mdct to their minimum length.
min_length = np.minimum(signal.shape[0], inverse_mdct.shape[0])
# Ignore the half_len samples at either edge.
half_len = frame_length // 2
signal = signal[half_len:min_length-half_len]
inverse_mdct = inverse_mdct[half_len:min_length-half_len]
# Check that the inverse and original signal are close.
self.assertAllClose(inverse_mdct, signal, atol=tol, rtol=tol)
if __name__ == "__main__":
test.main()

View File

@ -38,6 +38,7 @@ _TF_DTYPE_TOLERANCE = [(dtypes.float16, 1e-2),
(dtypes.float32, 1e-6),
(dtypes.float64, 1e-9)]
_WINDOW_LENGTHS = [1, 2, 3, 4, 5, 31, 64, 128]
_MDCT_WINDOW_LENGTHS = [4, 16, 256]
def _scipy_raised_cosine(length, symmetric=True, a=0.5, b=0.5):
@ -69,6 +70,21 @@ def _scipy_raised_cosine(length, symmetric=True, a=0.5, b=0.5):
@tf_test_util.run_all_in_graph_and_eager_modes
class WindowOpsTest(test.TestCase, parameterized.TestCase):
def _check_mdct_window(self, window, tol=1e-6):
"""Check that an MDCT window satisfies necessary conditions."""
# We check that the length of the window is a multiple of 4 and
# for symmetry of the window and also Princen-Bradley condition which
# requires that w[n]^2 + w[n + N//2]^2 = 1 for an N length window.
wlen = int(np.shape(window)[0])
assert wlen % 4 == 0
half_len = wlen // 2
squared_sums = window[:half_len]**2 + window[half_len:]**2
self.assertAllClose(squared_sums, np.ones((half_len,)),
tol, tol)
sym_diff = window[:half_len] - window[-1:half_len-1:-1]
self.assertAllClose(sym_diff, np.zeros((half_len,)),
tol, tol)
def _compare_window_fns(self, np_window_fn, tf_window_fn, window_length,
periodic, tf_dtype_tol):
tf_dtype, tol = tf_dtype_tol
@ -79,6 +95,18 @@ class WindowOpsTest(test.TestCase, parameterized.TestCase):
dtype=tf_dtype)
self.assertAllClose(expected, actual, tol, tol)
@parameterized.parameters(
itertools.product(
_WINDOW_LENGTHS,
(4., 8., 10., 12.),
_TF_DTYPE_TOLERANCE))
def test_kaiser_window(self, window_length, beta, tf_dtype_tol):
"""Check that kaiser_window matches np.kaiser behavior."""
self.assertAllClose(
np.kaiser(window_length, beta),
window_ops.kaiser_window(window_length, beta, tf_dtype_tol[0]),
tf_dtype_tol[1], tf_dtype_tol[1])
@parameterized.parameters(
itertools.product(
_WINDOW_LENGTHS,
@ -109,7 +137,9 @@ class WindowOpsTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(
itertools.product(
(window_ops.hann_window, window_ops.hamming_window),
(window_ops.hann_window, window_ops.hamming_window,
window_ops.kaiser_window, window_ops.kaiser_bessel_derived_window,
window_ops.vorbis_window),
(False, True),
_TF_DTYPE_TOLERANCE))
def test_constant_folding(self, window_fn, periodic, tf_dtype_tol):
@ -118,7 +148,10 @@ class WindowOpsTest(test.TestCase, parameterized.TestCase):
return
g = ops.Graph()
with g.as_default():
window = window_fn(100, periodic=periodic, dtype=tf_dtype_tol[0])
try:
window = window_fn(100, periodic=periodic, dtype=tf_dtype_tol[0])
except TypeError:
window = window_fn(100, dtype=tf_dtype_tol[0])
rewritten_graph = test_util.grappler_optimize(g, [window])
self.assertLen(rewritten_graph.node, 1)
@ -128,11 +161,15 @@ class WindowOpsTest(test.TestCase, parameterized.TestCase):
(window_ops.hann_window, 10, False, dtypes.float32, True),
(window_ops.hann_window, 10, True, dtypes.float32, True),
(window_ops.hamming_window, 10, False, dtypes.float32, True),
(window_ops.hamming_window, 10, True, dtypes.float32, True))
(window_ops.hamming_window, 10, True, dtypes.float32, True),
(window_ops.vorbis_window, 12, None, dtypes.float32, True))
def test_tflite_convert(self, window_fn, window_length, periodic, dtype,
use_mlir):
def fn(window_length):
return window_fn(window_length, periodic, dtype=dtype)
try:
return window_fn(window_length, periodic=periodic, dtype=dtype)
except TypeError:
return window_fn(window_length, dtype=dtype)
tflite_model = test_util.tflite_convert(
fn, [tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32)], use_mlir)
@ -143,6 +180,26 @@ class WindowOpsTest(test.TestCase, parameterized.TestCase):
expected_output = self.evaluate(fn(window_length))
self.assertAllClose(actual_output, expected_output, rtol=1e-6, atol=1e-6)
@parameterized.parameters(
itertools.product(
_MDCT_WINDOW_LENGTHS,
_TF_DTYPE_TOLERANCE))
def test_vorbis_window(self, window_length, tf_dtype_tol):
"""Check if vorbis windows satisfy MDCT window conditions."""
self._check_mdct_window(window_ops.vorbis_window(window_length,
dtype=tf_dtype_tol[0]),
tol=tf_dtype_tol[1])
@parameterized.parameters(
itertools.product(
_MDCT_WINDOW_LENGTHS,
(4., 8., 10., 12.),
_TF_DTYPE_TOLERANCE))
def test_kaiser_bessel_derived_window(self, window_length, beta,
tf_dtype_tol):
"""Check if Kaiser-Bessel derived windows satisfy MDCT window conditions."""
self._check_mdct_window(window_ops.kaiser_bessel_derived_window(
window_length, beta=beta, dtype=tf_dtype_tol[0]), tol=tf_dtype_tol[1])
if __name__ == '__main__':
test.main()

View File

@ -26,6 +26,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.signal import dct_ops
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.ops.signal import reconstruction_ops
from tensorflow.python.ops.signal import shape_ops
@ -287,3 +288,146 @@ def _enclosing_power_of_two(value):
math_ops.ceil(
math_ops.log(math_ops.cast(value, dtypes.float32)) /
math_ops.log(2.0))), value.dtype)
@tf_export('signal.mdct')
def mdct(signals, frame_length, window_fn=window_ops.vorbis_window,
pad_end=False, norm=None, name=None):
"""Computes the [Modified Discrete Cosine Transform][mdct] of `signals`.
Implemented with TPU/GPU-compatible ops and supports gradients.
Args:
signals: A `[..., samples]` `float32`/`float64` `Tensor` of real-valued
signals.
frame_length: An integer scalar `Tensor`. The window length in samples
which must be divisible by 4.
window_fn: A callable that takes a window length and a `dtype` keyword
argument and returns a `[window_length]` `Tensor` of samples in the
provided datatype. If set to `None`, no windowing is used.
pad_end: Whether to pad the end of `signals` with zeros when the provided
frame length and step produces a frame that lies partially past its end.
norm: If it is None, unnormalized dct4 is used, if it is "ortho"
orthonormal dct4 is used.
name: An optional name for the operation.
Returns:
A `[..., frames, frame_length // 2]` `Tensor` of `float32`/`float64`
MDCT values where `frames` is roughly `samples // (frame_length // 2)`
when `pad_end=False`.
Raises:
ValueError: If `signals` is not at least rank 1, `frame_length` is
not scalar, or `frame_length` is not a multiple of `4`.
[mdct]: https://en.wikipedia.org/wiki/Modified_discrete_cosine_transform
"""
with ops.name_scope(name, 'mdct', [signals, frame_length]):
signals = ops.convert_to_tensor(signals, name='signals')
signals.shape.with_rank_at_least(1)
frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
frame_length.shape.assert_has_rank(0)
# Assert that frame_length is divisible by 4.
frame_length_static = tensor_util.constant_value(frame_length)
if frame_length_static is not None and frame_length_static % 4 != 0:
raise ValueError('The frame length must be a multiple of 4.')
frame_step = frame_length // 2
framed_signals = shape_ops.frame(
signals, frame_length, frame_step, pad_end=pad_end)
# Optionally window the framed signals.
if window_fn is not None:
window = window_fn(frame_length, dtype=framed_signals.dtype)
framed_signals *= window
else:
framed_signals *= 1.0 / np.sqrt(2)
split_frames = array_ops.split(framed_signals, 4, axis=-1)
frame_firsthalf = -array_ops.reverse(split_frames[2],
[-1]) - split_frames[3]
frame_secondhalf = split_frames[0] - array_ops.reverse(split_frames[1],
[-1])
frames_rearranged = array_ops.concat((frame_firsthalf, frame_secondhalf),
axis=-1)
# Below call produces the (frame_length // 2) unique components of the
# type 4 orthonormal DCT of the real windowed signals in frames_rearranged.
return dct_ops.dct(frames_rearranged, type=4, norm=norm)
@tf_export('signal.inverse_mdct')
def inverse_mdct(mdcts,
window_fn=window_ops.vorbis_window,
norm=None,
name=None):
"""Computes the inverse modified DCT of `mdcts`.
To reconstruct an original waveform, the same window function should
be used with `mdct` and `inverse_mdct`.
Example usage:
>>> @tf.function
... def compare_round_trip():
... samples = 1000
... frame_length = 400
... halflen = frame_length // 2
... waveform = tf.random.normal(dtype=tf.float32, shape=[samples])
... waveform_pad = tf.pad(waveform, [[halflen, 0],])
... mdct = tf.signal.mdct(waveform_pad, frame_length, pad_end=True,
... window_fn=tf.signal.vorbis_window)
... inverse_mdct = tf.signal.inverse_mdct(mdct,
... window_fn=tf.signal.vorbis_window)
... inverse_mdct = inverse_mdct[halflen: halflen + samples]
... return waveform, inverse_mdct
>>> waveform, inverse_mdct = compare_round_trip()
>>> np.allclose(waveform.numpy(), inverse_mdct.numpy(), rtol=1e-3, atol=1e-4)
True
Implemented with TPU/GPU-compatible ops and supports gradients.
Args:
mdcts: A `float32`/`float64` `[..., frames, frame_length // 2]`
`Tensor` of MDCT bins representing a batch of `frame_length // 2`-point
MDCTs.
window_fn: A callable that takes a window length and a `dtype` keyword
argument and returns a `[window_length]` `Tensor` of samples in the
provided datatype. If set to `None`, no windowing is used.
norm: If "ortho", orthonormal inverse DCT4 is performed, if it is None,
a regular dct4 followed by scaling of `1/frame_length` is performed.
name: An optional name for the operation.
Returns:
A `[..., samples]` `Tensor` of `float32`/`float64` signals representing
the inverse MDCT for each input MDCT in `mdcts` where `samples` is
`(frames - 1) * (frame_length // 2) + frame_length`.
Raises:
ValueError: If `mdcts` is not at least rank 2.
[mdct]: https://en.wikipedia.org/wiki/Modified_discrete_cosine_transform
"""
with ops.name_scope(name, 'inverse_mdct', [mdcts]):
mdcts = ops.convert_to_tensor(mdcts, name='mdcts')
mdcts.shape.with_rank_at_least(2)
half_len = math_ops.cast(mdcts.shape[-1], dtype=dtypes.int32)
if norm is None:
half_len_float = math_ops.cast(half_len, dtype=mdcts.dtype)
result_idct4 = (0.5 / half_len_float) * dct_ops.dct(mdcts, type=4)
elif norm == 'ortho':
result_idct4 = dct_ops.dct(mdcts, type=4, norm='ortho')
split_result = array_ops.split(result_idct4, 2, axis=-1)
real_frames = array_ops.concat((split_result[1],
-array_ops.reverse(split_result[1], [-1]),
-array_ops.reverse(split_result[0], [-1]),
-split_result[0]), axis=-1)
# Optionally window and overlap-add the inner 2 dimensions of real_frames
# into a single [samples] dimension.
if window_fn is not None:
window = window_fn(2 * half_len, dtype=mdcts.dtype)
real_frames *= window
else:
real_frames *= 1.0 / np.sqrt(2)
return reconstruction_ops.overlap_and_add(real_frames, half_len)

View File

@ -30,6 +30,117 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
def _check_params(window_length, dtype):
"""Check window_length and dtype params.
Args:
window_length: A scalar value or `Tensor`.
dtype: The data type to produce. Must be a floating point type.
Returns:
window_length converted to a tensor of type int32.
Raises:
ValueError: If `dtype` is not a floating point type or window_length is not
a scalar.
"""
if not dtype.is_floating:
raise ValueError('dtype must be a floating point type. Found %s' % dtype)
window_length = ops.convert_to_tensor(window_length, dtype=dtypes.int32)
window_length.shape.assert_has_rank(0)
return window_length
@tf_export('signal.kaiser_window')
def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None):
"""Generate a [Kaiser window][kaiser].
Args:
window_length: A scalar `Tensor` indicating the window length to generate.
beta: Beta parameter for Kaiser window, see reference below.
dtype: The data type to produce. Must be a floating point type.
name: An optional name for the operation.
Returns:
A `Tensor` of shape `[window_length]` of type `dtype`.
[kaiser]:
https://docs.scipy.org/doc/numpy/reference/generated/numpy.kaiser.html
"""
with ops.name_scope(name, 'kaiser_window'):
window_length = _check_params(window_length, dtype)
window_length_const = tensor_util.constant_value(window_length)
if window_length_const == 1:
return array_ops.ones([1], dtype=dtype)
# tf.range does not support float16 so we work with float32 initially.
halflen_float = (
math_ops.cast(window_length, dtype=dtypes.float32) - 1.0) / 2.0
arg = math_ops.range(-halflen_float, halflen_float + 0.1,
dtype=dtypes.float32)
# Convert everything into given dtype which can be float16.
arg = math_ops.cast(arg, dtype=dtype)
beta = math_ops.cast(beta, dtype=dtype)
one = math_ops.cast(1.0, dtype=dtype)
two = math_ops.cast(2.0, dtype=dtype)
halflen_float = math_ops.cast(halflen_float, dtype=dtype)
num = beta * math_ops.sqrt(
one - math_ops.pow(arg, two) / math_ops.pow(halflen_float, two))
window = math_ops.exp(num - beta) * (math_ops.bessel_i0e(num) /
math_ops.bessel_i0e(beta))
return window
@tf_export('signal.kaiser_bessel_derived_window')
def kaiser_bessel_derived_window(window_length, beta=12.,
dtype=dtypes.float32, name=None):
"""Generate a [Kaiser Bessel derived window][kbd].
Args:
window_length: A scalar `Tensor` indicating the window length to generate.
beta: Beta parameter for Kaiser window.
dtype: The data type to produce. Must be a floating point type.
name: An optional name for the operation.
Returns:
A `Tensor` of shape `[window_length]` of type `dtype`.
[kbd]:
https://en.wikipedia.org/wiki/Kaiser_window#Kaiser%E2%80%93Bessel-derived_(KBD)_window
"""
with ops.name_scope(name, 'kaiser_bessel_derived_window'):
window_length = _check_params(window_length, dtype)
halflen = window_length // 2
kaiserw = kaiser_window(halflen + 1, beta, dtype=dtype)
kaiserw_csum = math_ops.cumsum(kaiserw)
halfw = math_ops.sqrt(kaiserw_csum[:-1] / kaiserw_csum[-1])
window = array_ops.concat((halfw, halfw[::-1]), axis=0)
return window
@tf_export('signal.vorbis_window')
def vorbis_window(window_length, dtype=dtypes.float32, name=None):
"""Generate a [Vorbis power complementary window][vorbis].
Args:
window_length: A scalar `Tensor` indicating the window length to generate.
dtype: The data type to produce. Must be a floating point type.
name: An optional name for the operation.
Returns:
A `Tensor` of shape `[window_length]` of type `dtype`.
[vorbis]:
https://en.wikipedia.org/wiki/Modified_discrete_cosine_transform#Window_functions
"""
with ops.name_scope(name, 'vorbis_window'):
window_length = _check_params(window_length, dtype)
arg = math_ops.cast(math_ops.range(window_length), dtype=dtype)
window = math_ops.sin(np.pi / 2.0 * math_ops.pow(math_ops.sin(
np.pi / math_ops.cast(window_length, dtype=dtype) *
(arg + 0.5)), 2.0))
return window
@tf_export('signal.hann_window')
def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
"""Generate a [Hann window][hann].
@ -75,7 +186,8 @@ def hamming_window(window_length, periodic=True, dtype=dtypes.float32,
Raises:
ValueError: If `dtype` is not a floating point type.
[hamming]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
[hamming]:
https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
"""
return _raised_cosine_window(name, 'hamming_window', window_length, periodic,
dtype, 0.54, 0.46)

View File

@ -52,6 +52,10 @@ tf_module {
name: "ifftshift"
argspec: "args=[\'x\', \'axes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "inverse_mdct"
argspec: "args=[\'mdcts\', \'window_fn\', \'norm\', \'name\'], varargs=None, keywords=None, defaults=[\'<function vorbis_window instance>\', \'None\', \'None\'], "
}
member_method {
name: "inverse_stft"
argspec: "args=[\'stfts\', \'frame_length\', \'frame_step\', \'fft_length\', \'window_fn\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'<function hann_window instance>\', \'None\'], "
@ -72,10 +76,22 @@ tf_module {
name: "irfft3d"
argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "kaiser_bessel_derived_window"
argspec: "args=[\'window_length\', \'beta\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'12.0\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "kaiser_window"
argspec: "args=[\'window_length\', \'beta\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'12.0\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "linear_to_mel_weight_matrix"
argspec: "args=[\'num_mel_bins\', \'num_spectrogram_bins\', \'sample_rate\', \'lower_edge_hertz\', \'upper_edge_hertz\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'20\', \'129\', \'8000\', \'125.0\', \'3800.0\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "mdct"
argspec: "args=[\'signals\', \'frame_length\', \'window_fn\', \'pad_end\', \'norm\', \'name\'], varargs=None, keywords=None, defaults=[\'<function vorbis_window instance>\', \'False\', \'None\', \'None\'], "
}
member_method {
name: "mfccs_from_log_mel_spectrograms"
argspec: "args=[\'log_mel_spectrograms\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -100,4 +116,8 @@ tf_module {
name: "stft"
argspec: "args=[\'signals\', \'frame_length\', \'frame_step\', \'fft_length\', \'window_fn\', \'pad_end\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'<function hann_window instance>\', \'False\', \'None\'], "
}
member_method {
name: "vorbis_window"
argspec: "args=[\'window_length\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
}

View File

@ -52,6 +52,10 @@ tf_module {
name: "ifftshift"
argspec: "args=[\'x\', \'axes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "inverse_mdct"
argspec: "args=[\'mdcts\', \'window_fn\', \'norm\', \'name\'], varargs=None, keywords=None, defaults=[\'<function vorbis_window instance>\', \'None\', \'None\'], "
}
member_method {
name: "inverse_stft"
argspec: "args=[\'stfts\', \'frame_length\', \'frame_step\', \'fft_length\', \'window_fn\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'<function hann_window instance>\', \'None\'], "
@ -72,10 +76,22 @@ tf_module {
name: "irfft3d"
argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "kaiser_bessel_derived_window"
argspec: "args=[\'window_length\', \'beta\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'12.0\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "kaiser_window"
argspec: "args=[\'window_length\', \'beta\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'12.0\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "linear_to_mel_weight_matrix"
argspec: "args=[\'num_mel_bins\', \'num_spectrogram_bins\', \'sample_rate\', \'lower_edge_hertz\', \'upper_edge_hertz\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'20\', \'129\', \'8000\', \'125.0\', \'3800.0\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "mdct"
argspec: "args=[\'signals\', \'frame_length\', \'window_fn\', \'pad_end\', \'norm\', \'name\'], varargs=None, keywords=None, defaults=[\'<function vorbis_window instance>\', \'False\', \'None\', \'None\'], "
}
member_method {
name: "mfccs_from_log_mel_spectrograms"
argspec: "args=[\'log_mel_spectrograms\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -100,4 +116,8 @@ tf_module {
name: "stft"
argspec: "args=[\'signals\', \'frame_length\', \'frame_step\', \'fft_length\', \'window_fn\', \'pad_end\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'<function hann_window instance>\', \'False\', \'None\'], "
}
member_method {
name: "vorbis_window"
argspec: "args=[\'window_length\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
}