Add tf.lite tests for tf.signal.frame.
PiperOrigin-RevId: 283624804 Change-Id: I6fd2f5234fb474de3cc5490e926938f76ffe3b5b
This commit is contained in:
parent
0365083580
commit
d68fa26586
@ -18,9 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
@ -28,7 +25,6 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.kernel_tests.signal import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -38,7 +34,7 @@ from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@tf_test_util.run_all_in_graph_and_eager_modes
|
||||
class FrameTest(test.TestCase, parameterized.TestCase):
|
||||
class FrameTest(test.TestCase):
|
||||
|
||||
def test_mapping_of_indices_without_padding(self):
|
||||
tensor = constant_op.constant(np.arange(9152), dtypes.int32)
|
||||
@ -356,46 +352,6 @@ class FrameTest(test.TestCase, parameterized.TestCase):
|
||||
rewritten_graph = test_util.grappler_optimize(g, [frames])
|
||||
self.assertEqual(1, len(rewritten_graph.node))
|
||||
|
||||
@parameterized.parameters(
|
||||
itertools.product(
|
||||
# length % step == 0
|
||||
((32, 16),
|
||||
# gcd(length, step) == 1
|
||||
(32, 15),
|
||||
# gcd(length, step) == 5
|
||||
(25, 15),
|
||||
# length == step
|
||||
(32, 32)),
|
||||
(False, True), # pad_end
|
||||
(False, True), # use_mlir
|
||||
(False, True))) # known_batch
|
||||
def test_tflite_convert(self, length_step, pad_end, use_mlir, known_batch):
|
||||
"""Check for tf.lite compatibility in a variety of settings."""
|
||||
def fn(signal):
|
||||
return shape_ops.frame(
|
||||
signal, length_step[0], length_step[1], pad_end=pad_end)
|
||||
|
||||
# TODO(b/144998258): unknown batch does not currently work with padding.
|
||||
if not known_batch and pad_end:
|
||||
return
|
||||
|
||||
signal_length, dtype = 8001, dtypes.float32
|
||||
# If batch size is unknown, tf.lite assumes it's 1. Test batch_size > 1
|
||||
# only when batch size is known.
|
||||
batch_size = 2 if known_batch else 1
|
||||
static_batch_size = batch_size if known_batch else None
|
||||
tflite_model = test_util.tflite_convert(
|
||||
fn, [tensor_spec.TensorSpec(
|
||||
shape=[static_batch_size, signal_length], dtype=dtype)],
|
||||
use_mlir)
|
||||
signal = np.random.normal(size=(batch_size, signal_length)).astype(
|
||||
dtype.as_numpy_dtype)
|
||||
actual_output, = test_util.evaluate_tflite_model(
|
||||
tflite_model, [signal])
|
||||
|
||||
expected_output = self.evaluate(fn(signal))
|
||||
self.assertAllClose(actual_output, expected_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user