Update tf.signal.overlap_and_add tests to be parameterized and enable V2 numerical gradient tests.
PiperOrigin-RevId: 269343900
This commit is contained in:
parent
9d8beea892
commit
9a97b1ee2e
@ -86,6 +86,7 @@ cuda_py_tests(
|
|||||||
name = "reconstruction_ops_test",
|
name = "reconstruction_ops_test",
|
||||||
srcs = ["reconstruction_ops_test.py"],
|
srcs = ["reconstruction_ops_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -25,6 +26,7 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import gradient_checker_v2
|
||||||
from tensorflow.python.ops import gradients_impl
|
from tensorflow.python.ops import gradients_impl
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.signal import reconstruction_ops
|
from tensorflow.python.ops.signal import reconstruction_ops
|
||||||
@ -32,7 +34,7 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class ReconstructionOpsTest(test.TestCase):
|
class ReconstructionOpsTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(ReconstructionOpsTest, self).__init__(*args, **kwargs)
|
super(ReconstructionOpsTest, self).__init__(*args, **kwargs)
|
||||||
@ -87,8 +89,6 @@ class ReconstructionOpsTest(test.TestCase):
|
|||||||
self.assertAllClose(reconstruction, expected_output)
|
self.assertAllClose(reconstruction, expected_output)
|
||||||
|
|
||||||
def test_fast_path(self):
|
def test_fast_path(self):
|
||||||
if context.executing_eagerly():
|
|
||||||
return
|
|
||||||
# This test uses tensor names and does not work in eager mode.
|
# This test uses tensor names and does not work in eager mode.
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
return
|
return
|
||||||
@ -99,31 +99,27 @@ class ReconstructionOpsTest(test.TestCase):
|
|||||||
expected_output = np.ones([15])
|
expected_output = np.ones([15])
|
||||||
self.assertAllClose(reconstruction, expected_output)
|
self.assertAllClose(reconstruction, expected_output)
|
||||||
|
|
||||||
def test_simple(self):
|
@parameterized.parameters(
|
||||||
|
# All hop lengths on a frame length of 2.
|
||||||
|
(2, [1, 5, 9, 6], 1),
|
||||||
|
(2, [1, 2, 3, 4, 5, 6], 2),
|
||||||
|
|
||||||
|
# All hop lengths on a frame length of 3.
|
||||||
|
(3, [1, 6, 15, 14, 9], 1),
|
||||||
|
(3, [1, 2, 7, 5, 13, 8, 9], 2),
|
||||||
|
(3, [1, 2, 3, 4, 5, 6, 7, 8, 9], 3),
|
||||||
|
|
||||||
|
# All hop lengths on a frame length of 4.
|
||||||
|
(4, [1, 7, 18, 21, 19, 12], 1),
|
||||||
|
(4, [1, 2, 8, 10, 16, 18, 11, 12], 2),
|
||||||
|
(4, [1, 2, 3, 9, 6, 7, 17, 10, 11, 12], 3),
|
||||||
|
(4, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4))
|
||||||
|
def test_simple(self, frame_length, expected, frame_hop):
|
||||||
def make_input(frame_length, num_frames=3):
|
def make_input(frame_length, num_frames=3):
|
||||||
"""Generate a tensor of num_frames frames of frame_length."""
|
"""Generate a tensor of num_frames frames of frame_length."""
|
||||||
return np.reshape(np.arange(1, num_frames * frame_length + 1),
|
return np.reshape(np.arange(1, num_frames * frame_length + 1),
|
||||||
(-1, frame_length))
|
(-1, frame_length))
|
||||||
|
signal = make_input(frame_length)
|
||||||
# List of (signal, expected_result, frame_hop).
|
|
||||||
configurations = [
|
|
||||||
# All hop lengths on a frame length of 2.
|
|
||||||
(make_input(2), [1, 5, 9, 6], 1),
|
|
||||||
(make_input(2), [1, 2, 3, 4, 5, 6], 2),
|
|
||||||
|
|
||||||
# All hop lengths on a frame length of 3.
|
|
||||||
(make_input(3), [1, 6, 15, 14, 9], 1),
|
|
||||||
(make_input(3), [1, 2, 7, 5, 13, 8, 9], 2),
|
|
||||||
(make_input(3), [1, 2, 3, 4, 5, 6, 7, 8, 9], 3),
|
|
||||||
|
|
||||||
# All hop lengths on a frame length of 4.
|
|
||||||
(make_input(4), [1, 7, 18, 21, 19, 12], 1),
|
|
||||||
(make_input(4), [1, 2, 8, 10, 16, 18, 11, 12], 2),
|
|
||||||
(make_input(4), [1, 2, 3, 9, 6, 7, 17, 10, 11, 12], 3),
|
|
||||||
(make_input(4), [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4),
|
|
||||||
]
|
|
||||||
|
|
||||||
for signal, expected, frame_hop in configurations:
|
|
||||||
reconstruction = reconstruction_ops.overlap_and_add(
|
reconstruction = reconstruction_ops.overlap_and_add(
|
||||||
np.array(signal), frame_hop)
|
np.array(signal), frame_hop)
|
||||||
expected_output = np.array(expected)
|
expected_output = np.array(expected)
|
||||||
@ -165,20 +161,17 @@ class ReconstructionOpsTest(test.TestCase):
|
|||||||
self.assertEqual(output.shape, (1, 9))
|
self.assertEqual(output.shape, (1, 9))
|
||||||
self.assertEqual(string_output, self.expected_string)
|
self.assertEqual(string_output, self.expected_string)
|
||||||
|
|
||||||
def test_gradient(self):
|
@parameterized.parameters(
|
||||||
# TODO(rjryan): Eager gradient tests.
|
|
||||||
if context.executing_eagerly():
|
|
||||||
return
|
|
||||||
configurations = [
|
|
||||||
((1, 128), 1),
|
((1, 128), 1),
|
||||||
((5, 35), 17),
|
((5, 35), 17),
|
||||||
((10, 128), 128),
|
((10, 128), 128),
|
||||||
((2, 10, 128), 127),
|
((2, 10, 128), 127),
|
||||||
((2, 2, 10, 128), 126),
|
((2, 2, 10, 128), 126),
|
||||||
((2, 2, 2, 10, 128), 125),
|
((2, 2, 2, 10, 128), 125))
|
||||||
]
|
def test_gradient(self, shape, frame_hop):
|
||||||
|
# TODO(rjryan): Eager gradient tests.
|
||||||
for shape, frame_hop in configurations:
|
if context.executing_eagerly():
|
||||||
|
return
|
||||||
signal = array_ops.zeros(shape)
|
signal = array_ops.zeros(shape)
|
||||||
reconstruction = reconstruction_ops.overlap_and_add(signal, frame_hop)
|
reconstruction = reconstruction_ops.overlap_and_add(signal, frame_hop)
|
||||||
loss = math_ops.reduce_sum(reconstruction)
|
loss = math_ops.reduce_sum(reconstruction)
|
||||||
@ -214,18 +207,14 @@ class ReconstructionOpsTest(test.TestCase):
|
|||||||
self.assertAllEqual(expected_gradient, gradient)
|
self.assertAllEqual(expected_gradient, gradient)
|
||||||
|
|
||||||
def test_gradient_numerical(self):
|
def test_gradient_numerical(self):
|
||||||
# TODO(rjryan): Eager gradient tests.
|
|
||||||
if context.executing_eagerly():
|
|
||||||
return
|
|
||||||
with self.session(use_gpu=True):
|
|
||||||
shape = (2, 10, 10)
|
shape = (2, 10, 10)
|
||||||
framed_signal = array_ops.zeros(shape)
|
framed_signal = array_ops.zeros(shape)
|
||||||
frame_hop = 10
|
frame_hop = 10
|
||||||
reconstruction = reconstruction_ops.overlap_and_add(
|
def f(signal):
|
||||||
framed_signal, frame_hop)
|
return reconstruction_ops.overlap_and_add(signal, frame_hop)
|
||||||
error = test.compute_gradient_error(
|
((jacob_t,), (jacob_n,)) = gradient_checker_v2.compute_gradient(
|
||||||
framed_signal, shape, reconstruction, [2, 100])
|
f, [framed_signal])
|
||||||
self.assertLess(error, 2e-5)
|
self.assertAllClose(jacob_t, jacob_n)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user