Allow sample_rate
to be a constant-value Tensor.
PiperOrigin-RevId: 257024106
This commit is contained in:
parent
afd99cdb8c
commit
e77886d0ce
@ -20,8 +20,10 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.framework import test_util as tf_test_util
|
from tensorflow.python.framework import test_util as tf_test_util
|
||||||
from tensorflow.python.kernel_tests.signal import test_util
|
from tensorflow.python.kernel_tests.signal import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -95,6 +97,7 @@ def spectrogram_to_mel_matrix(num_mel_bins=20,
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: if frequency edges are incorrectly ordered.
|
ValueError: if frequency edges are incorrectly ordered.
|
||||||
"""
|
"""
|
||||||
|
audio_sample_rate = tensor_util.constant_value(audio_sample_rate)
|
||||||
nyquist_hertz = audio_sample_rate / 2.
|
nyquist_hertz = audio_sample_rate / 2.
|
||||||
if lower_edge_hertz >= upper_edge_hertz:
|
if lower_edge_hertz >= upper_edge_hertz:
|
||||||
raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
|
raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
|
||||||
@ -135,8 +138,10 @@ class LinearToMelTest(test.TestCase):
|
|||||||
configs = [
|
configs = [
|
||||||
# Defaults.
|
# Defaults.
|
||||||
(20, 129, 8000.0, 125.0, 3800.0, dtypes.float64),
|
(20, 129, 8000.0, 125.0, 3800.0, dtypes.float64),
|
||||||
|
# Same as above, but with a constant Tensor sample rate.
|
||||||
|
(20, 129, constant_op.constant(8000.0), 125.0, 3800.0, dtypes.float64),
|
||||||
# Settings used by Tacotron (https://arxiv.org/abs/1703.10135).
|
# Settings used by Tacotron (https://arxiv.org/abs/1703.10135).
|
||||||
(80, 1025, 24000.0, 80.0, 12000.0, dtypes.float64)
|
(80, 1025, 24000.0, 80.0, 12000.0, dtypes.float64),
|
||||||
]
|
]
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
for config in configs:
|
for config in configs:
|
||||||
@ -159,6 +164,9 @@ class LinearToMelTest(test.TestCase):
|
|||||||
mel_ops.linear_to_mel_weight_matrix(num_spectrogram_bins=0)
|
mel_ops.linear_to_mel_weight_matrix(num_spectrogram_bins=0)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mel_ops.linear_to_mel_weight_matrix(sample_rate=0.0)
|
mel_ops.linear_to_mel_weight_matrix(sample_rate=0.0)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mel_ops.linear_to_mel_weight_matrix(
|
||||||
|
sample_rate=array_ops.placeholder(dtypes.float32))
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mel_ops.linear_to_mel_weight_matrix(lower_edge_hertz=-1)
|
mel_ops.linear_to_mel_weight_matrix(lower_edge_hertz=-1)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.signal import shape_ops
|
from tensorflow.python.ops.signal import shape_ops
|
||||||
@ -143,11 +144,21 @@ def linear_to_mel_weight_matrix(num_mel_bins=20,
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not
|
ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not
|
||||||
positive, `lower_edge_hertz` is negative, frequency edges are incorrectly
|
positive, `lower_edge_hertz` is negative, frequency edges are incorrectly
|
||||||
ordered, or `upper_edge_hertz` is larger than the Nyquist frequency.
|
ordered, `upper_edge_hertz` is larger than the Nyquist frequency, or
|
||||||
|
`sample_rate` is neither a Python float nor a constant Tensor.
|
||||||
|
|
||||||
[mel]: https://en.wikipedia.org/wiki/Mel_scale
|
[mel]: https://en.wikipedia.org/wiki/Mel_scale
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, 'linear_to_mel_weight_matrix') as name:
|
with ops.name_scope(name, 'linear_to_mel_weight_matrix') as name:
|
||||||
|
# Convert Tensor `sample_rate` to float, if possible.
|
||||||
|
if isinstance(sample_rate, ops.Tensor):
|
||||||
|
maybe_const_val = tensor_util.constant_value(sample_rate)
|
||||||
|
if maybe_const_val is not None:
|
||||||
|
sample_rate = maybe_const_val
|
||||||
|
else:
|
||||||
|
raise ValueError('`sample_rate` was a non-constant Tensor. Must be a '
|
||||||
|
'Python float or a constant Tensor.')
|
||||||
|
|
||||||
# Note: As num_spectrogram_bins is passed to `math_ops.linspace`
|
# Note: As num_spectrogram_bins is passed to `math_ops.linspace`
|
||||||
# and the validation is already done in linspace (both in shape function
|
# and the validation is already done in linspace (both in shape function
|
||||||
# and in kernel), there is no need to validate num_spectrogram_bins here.
|
# and in kernel), there is no need to validate num_spectrogram_bins here.
|
||||||
|
Loading…
Reference in New Issue
Block a user