Add tf.lite tests for tf.signal.frame.

PiperOrigin-RevId: 283624804
Change-Id: I6fd2f5234fb474de3cc5490e926938f76ffe3b5b
This commit is contained in:
Haoyu Zhang 2019-12-03 14:37:49 -08:00 committed by TensorFlower Gardener
parent 0365083580
commit d68fa26586

View File

@ -18,9 +18,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import itertools
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -28,7 +25,6 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.kernel_tests.signal import test_util from tensorflow.python.kernel_tests.signal import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -38,7 +34,7 @@ from tensorflow.python.platform import test
@tf_test_util.run_all_in_graph_and_eager_modes @tf_test_util.run_all_in_graph_and_eager_modes
class FrameTest(test.TestCase, parameterized.TestCase): class FrameTest(test.TestCase):
def test_mapping_of_indices_without_padding(self): def test_mapping_of_indices_without_padding(self):
tensor = constant_op.constant(np.arange(9152), dtypes.int32) tensor = constant_op.constant(np.arange(9152), dtypes.int32)
@ -356,46 +352,6 @@ class FrameTest(test.TestCase, parameterized.TestCase):
rewritten_graph = test_util.grappler_optimize(g, [frames]) rewritten_graph = test_util.grappler_optimize(g, [frames])
self.assertEqual(1, len(rewritten_graph.node)) self.assertEqual(1, len(rewritten_graph.node))
@parameterized.parameters(
itertools.product(
# length % step == 0
((32, 16),
# gcd(length, step) == 1
(32, 15),
# gcd(length, step) == 5
(25, 15),
# length == step
(32, 32)),
(False, True), # pad_end
(False, True), # use_mlir
(False, True))) # known_batch
def test_tflite_convert(self, length_step, pad_end, use_mlir, known_batch):
"""Check for tf.lite compatibility in a variety of settings."""
def fn(signal):
return shape_ops.frame(
signal, length_step[0], length_step[1], pad_end=pad_end)
# TODO(b/144998258): unknown batch does not currently work with padding.
if not known_batch and pad_end:
return
signal_length, dtype = 8001, dtypes.float32
# If batch size is unknown, tf.lite assumes it's 1. Test batch_size > 1
# only when batch size is known.
batch_size = 2 if known_batch else 1
static_batch_size = batch_size if known_batch else None
tflite_model = test_util.tflite_convert(
fn, [tensor_spec.TensorSpec(
shape=[static_batch_size, signal_length], dtype=dtype)],
use_mlir)
signal = np.random.normal(size=(batch_size, signal_length)).astype(
dtype.as_numpy_dtype)
actual_output, = test_util.evaluate_tflite_model(
tflite_model, [signal])
expected_output = self.evaluate(fn(signal))
self.assertAllClose(actual_output, expected_output)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()