diff --git a/tensorflow/python/kernel_tests/signal/mel_ops_test.py b/tensorflow/python/kernel_tests/signal/mel_ops_test.py index 3134503daec..a36a8c3f758 100644 --- a/tensorflow/python/kernel_tests/signal/mel_ops_test.py +++ b/tensorflow/python/kernel_tests/signal/mel_ops_test.py @@ -20,8 +20,10 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.kernel_tests.signal import test_util from tensorflow.python.ops import array_ops @@ -95,6 +97,7 @@ def spectrogram_to_mel_matrix(num_mel_bins=20, Raises: ValueError: if frequency edges are incorrectly ordered. """ + audio_sample_rate = tensor_util.constant_value(audio_sample_rate) nyquist_hertz = audio_sample_rate / 2. if lower_edge_hertz >= upper_edge_hertz: raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % @@ -135,8 +138,10 @@ class LinearToMelTest(test.TestCase): configs = [ # Defaults. (20, 129, 8000.0, 125.0, 3800.0, dtypes.float64), + # Same as above, but with a constant Tensor sample rate. + (20, 129, constant_op.constant(8000.0), 125.0, 3800.0, dtypes.float64), # Settings used by Tacotron (https://arxiv.org/abs/1703.10135). - (80, 1025, 24000.0, 80.0, 12000.0, dtypes.float64) + (80, 1025, 24000.0, 80.0, 12000.0, dtypes.float64), ] with self.session(use_gpu=True): for config in configs: @@ -159,6 +164,9 @@ class LinearToMelTest(test.TestCase): mel_ops.linear_to_mel_weight_matrix(num_spectrogram_bins=0) with self.assertRaises(ValueError): mel_ops.linear_to_mel_weight_matrix(sample_rate=0.0) + with self.assertRaises(ValueError): + mel_ops.linear_to_mel_weight_matrix( + sample_rate=array_ops.placeholder(dtypes.float32)) with self.assertRaises(ValueError): mel_ops.linear_to_mel_weight_matrix(lower_edge_hertz=-1) with self.assertRaises(ValueError): diff --git a/tensorflow/python/ops/signal/mel_ops.py b/tensorflow/python/ops/signal/mel_ops.py index 9702d66506a..adc3519e193 100644 --- a/tensorflow/python/ops/signal/mel_ops.py +++ b/tensorflow/python/ops/signal/mel_ops.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.signal import shape_ops @@ -143,11 +144,21 @@ def linear_to_mel_weight_matrix(num_mel_bins=20, Raises: ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not positive, `lower_edge_hertz` is negative, frequency edges are incorrectly - ordered, or `upper_edge_hertz` is larger than the Nyquist frequency. + ordered, `upper_edge_hertz` is larger than the Nyquist frequency, or + `sample_rate` is neither a Python float nor a constant Tensor. [mel]: https://en.wikipedia.org/wiki/Mel_scale """ with ops.name_scope(name, 'linear_to_mel_weight_matrix') as name: + # Convert Tensor `sample_rate` to float, if possible. + if isinstance(sample_rate, ops.Tensor): + maybe_const_val = tensor_util.constant_value(sample_rate) + if maybe_const_val is not None: + sample_rate = maybe_const_val + else: + raise ValueError('`sample_rate` was a non-constant Tensor. Must be a ' + 'Python float or a constant Tensor.') + # Note: As num_spectrogram_bins is passed to `math_ops.linspace` # and the validation is already done in linspace (both in shape function # and in kernel), there is no need to validate num_spectrogram_bins here.