Allow sample_rate to be a constant-value Tensor.

PiperOrigin-RevId: 257024106
This commit is contained in:
A. Unique TensorFlower 2019-07-08 11:44:16 -07:00 committed by TensorFlower Gardener
parent afd99cdb8c
commit e77886d0ce
2 changed files with 21 additions and 2 deletions

View File

@ -20,8 +20,10 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.kernel_tests.signal import test_util from tensorflow.python.kernel_tests.signal import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -95,6 +97,7 @@ def spectrogram_to_mel_matrix(num_mel_bins=20,
Raises: Raises:
ValueError: if frequency edges are incorrectly ordered. ValueError: if frequency edges are incorrectly ordered.
""" """
audio_sample_rate = tensor_util.constant_value(audio_sample_rate)
nyquist_hertz = audio_sample_rate / 2. nyquist_hertz = audio_sample_rate / 2.
if lower_edge_hertz >= upper_edge_hertz: if lower_edge_hertz >= upper_edge_hertz:
raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
@ -135,8 +138,10 @@ class LinearToMelTest(test.TestCase):
configs = [ configs = [
# Defaults. # Defaults.
(20, 129, 8000.0, 125.0, 3800.0, dtypes.float64), (20, 129, 8000.0, 125.0, 3800.0, dtypes.float64),
# Same as above, but with a constant Tensor sample rate.
(20, 129, constant_op.constant(8000.0), 125.0, 3800.0, dtypes.float64),
# Settings used by Tacotron (https://arxiv.org/abs/1703.10135). # Settings used by Tacotron (https://arxiv.org/abs/1703.10135).
(80, 1025, 24000.0, 80.0, 12000.0, dtypes.float64) (80, 1025, 24000.0, 80.0, 12000.0, dtypes.float64),
] ]
with self.session(use_gpu=True): with self.session(use_gpu=True):
for config in configs: for config in configs:
@ -159,6 +164,9 @@ class LinearToMelTest(test.TestCase):
mel_ops.linear_to_mel_weight_matrix(num_spectrogram_bins=0) mel_ops.linear_to_mel_weight_matrix(num_spectrogram_bins=0)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mel_ops.linear_to_mel_weight_matrix(sample_rate=0.0) mel_ops.linear_to_mel_weight_matrix(sample_rate=0.0)
with self.assertRaises(ValueError):
mel_ops.linear_to_mel_weight_matrix(
sample_rate=array_ops.placeholder(dtypes.float32))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mel_ops.linear_to_mel_weight_matrix(lower_edge_hertz=-1) mel_ops.linear_to_mel_weight_matrix(lower_edge_hertz=-1)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops.signal import shape_ops from tensorflow.python.ops.signal import shape_ops
@ -143,11 +144,21 @@ def linear_to_mel_weight_matrix(num_mel_bins=20,
Raises: Raises:
ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not
positive, `lower_edge_hertz` is negative, frequency edges are incorrectly positive, `lower_edge_hertz` is negative, frequency edges are incorrectly
ordered, or `upper_edge_hertz` is larger than the Nyquist frequency. ordered, `upper_edge_hertz` is larger than the Nyquist frequency, or
`sample_rate` is neither a Python float nor a constant Tensor.
[mel]: https://en.wikipedia.org/wiki/Mel_scale [mel]: https://en.wikipedia.org/wiki/Mel_scale
""" """
with ops.name_scope(name, 'linear_to_mel_weight_matrix') as name: with ops.name_scope(name, 'linear_to_mel_weight_matrix') as name:
# Convert Tensor `sample_rate` to float, if possible.
if isinstance(sample_rate, ops.Tensor):
maybe_const_val = tensor_util.constant_value(sample_rate)
if maybe_const_val is not None:
sample_rate = maybe_const_val
else:
raise ValueError('`sample_rate` was a non-constant Tensor. Must be a '
'Python float or a constant Tensor.')
# Note: As num_spectrogram_bins is passed to `math_ops.linspace` # Note: As num_spectrogram_bins is passed to `math_ops.linspace`
# and the validation is already done in linspace (both in shape function # and the validation is already done in linspace (both in shape function
# and in kernel), there is no need to validate num_spectrogram_bins here. # and in kernel), there is no need to validate num_spectrogram_bins here.