Add tf.lite tests for tf.signal.frame.
PiperOrigin-RevId: 283624804 Change-Id: I6fd2f5234fb474de3cc5490e926938f76ffe3b5b
This commit is contained in:
parent
0365083580
commit
d68fa26586
@ -18,9 +18,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import itertools
|
|
||||||
|
|
||||||
from absl.testing import parameterized
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -28,7 +25,6 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
|
||||||
from tensorflow.python.framework import test_util as tf_test_util
|
from tensorflow.python.framework import test_util as tf_test_util
|
||||||
from tensorflow.python.kernel_tests.signal import test_util
|
from tensorflow.python.kernel_tests.signal import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -38,7 +34,7 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
|
|
||||||
@tf_test_util.run_all_in_graph_and_eager_modes
|
@tf_test_util.run_all_in_graph_and_eager_modes
|
||||||
class FrameTest(test.TestCase, parameterized.TestCase):
|
class FrameTest(test.TestCase):
|
||||||
|
|
||||||
def test_mapping_of_indices_without_padding(self):
|
def test_mapping_of_indices_without_padding(self):
|
||||||
tensor = constant_op.constant(np.arange(9152), dtypes.int32)
|
tensor = constant_op.constant(np.arange(9152), dtypes.int32)
|
||||||
@ -356,46 +352,6 @@ class FrameTest(test.TestCase, parameterized.TestCase):
|
|||||||
rewritten_graph = test_util.grappler_optimize(g, [frames])
|
rewritten_graph = test_util.grappler_optimize(g, [frames])
|
||||||
self.assertEqual(1, len(rewritten_graph.node))
|
self.assertEqual(1, len(rewritten_graph.node))
|
||||||
|
|
||||||
@parameterized.parameters(
|
|
||||||
itertools.product(
|
|
||||||
# length % step == 0
|
|
||||||
((32, 16),
|
|
||||||
# gcd(length, step) == 1
|
|
||||||
(32, 15),
|
|
||||||
# gcd(length, step) == 5
|
|
||||||
(25, 15),
|
|
||||||
# length == step
|
|
||||||
(32, 32)),
|
|
||||||
(False, True), # pad_end
|
|
||||||
(False, True), # use_mlir
|
|
||||||
(False, True))) # known_batch
|
|
||||||
def test_tflite_convert(self, length_step, pad_end, use_mlir, known_batch):
|
|
||||||
"""Check for tf.lite compatibility in a variety of settings."""
|
|
||||||
def fn(signal):
|
|
||||||
return shape_ops.frame(
|
|
||||||
signal, length_step[0], length_step[1], pad_end=pad_end)
|
|
||||||
|
|
||||||
# TODO(b/144998258): unknown batch does not currently work with padding.
|
|
||||||
if not known_batch and pad_end:
|
|
||||||
return
|
|
||||||
|
|
||||||
signal_length, dtype = 8001, dtypes.float32
|
|
||||||
# If batch size is unknown, tf.lite assumes it's 1. Test batch_size > 1
|
|
||||||
# only when batch size is known.
|
|
||||||
batch_size = 2 if known_batch else 1
|
|
||||||
static_batch_size = batch_size if known_batch else None
|
|
||||||
tflite_model = test_util.tflite_convert(
|
|
||||||
fn, [tensor_spec.TensorSpec(
|
|
||||||
shape=[static_batch_size, signal_length], dtype=dtype)],
|
|
||||||
use_mlir)
|
|
||||||
signal = np.random.normal(size=(batch_size, signal_length)).astype(
|
|
||||||
dtype.as_numpy_dtype)
|
|
||||||
actual_output, = test_util.evaluate_tflite_model(
|
|
||||||
tflite_model, [signal])
|
|
||||||
|
|
||||||
expected_output = self.evaluate(fn(signal))
|
|
||||||
self.assertAllClose(actual_output, expected_output)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user