From 3576dfea7146e4c14979a669fbc3c5c6634c89e6 Mon Sep 17 00:00:00 2001
From: RJ Skerry-Ryan
Date: Thu, 12 Sep 2019 13:44:15 -0700
Subject: [PATCH] Add tf.lite MLIR tests for tf.signal.hann_window/hamming_window.

PiperOrigin-RevId: 268756379
---
 tensorflow/python/kernel_tests/signal/BUILD   |  2 +
 .../python/kernel_tests/signal/test_util.py   | 55 +++++++++++++++++++
 .../kernel_tests/signal/window_ops_test.py    | 27 ++++++++-
 3 files changed, 82 insertions(+), 2 deletions(-)

diff --git a/tensorflow/python/kernel_tests/signal/BUILD b/tensorflow/python/kernel_tests/signal/BUILD
index 2e32b7113ad..3806783ca11 100644
--- a/tensorflow/python/kernel_tests/signal/BUILD
+++ b/tensorflow/python/kernel_tests/signal/BUILD
@@ -14,6 +14,8 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/core:protos_all_py",
+        "//tensorflow/lite/python:interpreter",
+        "//tensorflow/lite/python:lite",
         "//tensorflow/python:tf_optimizer",
         "//tensorflow/python:training",
     ],
diff --git a/tensorflow/python/kernel_tests/signal/test_util.py b/tensorflow/python/kernel_tests/signal/test_util.py
index 4d1807e513c..9f7de258be1 100644
--- a/tensorflow/python/kernel_tests/signal/test_util.py
+++ b/tensorflow/python/kernel_tests/signal/test_util.py
@@ -19,6 +19,9 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.lite.python import interpreter
+from tensorflow.lite.python import lite
+from tensorflow.python.eager import def_function
 from tensorflow.python.grappler import tf_optimizer
 from tensorflow.python.training import saver
 
@@ -45,3 +48,55 @@ def grappler_optimize(graph, fetches=None, config_proto=None):
       graph.add_to_collection('train_op', fetch)
   metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def())
   return tf_optimizer.OptimizeGraph(config_proto, metagraph)
+
+
+def tflite_convert(fn, input_templates, use_mlir=False):
+  """Converts the provided fn to a tf.lite model.
+
+  Args:
+    fn: A callable that expects a list of inputs like input_templates and
+      returns a tensor or structure of tensors.
+    input_templates: A list of Tensors, ndarrays or TensorSpecs describing the
+      inputs that fn expects. The actual values of the Tensors or ndarrays are
+      unused.
+    use_mlir: Experimental. Whether to use the tf.lite MLIR converter.
+
+  Returns:
+    The serialized tf.lite model.
+  """
+  fn = def_function.function(fn)
+  concrete_func = fn.get_concrete_function(*input_templates)
+  converter = lite.TFLiteConverterV2([concrete_func])
+  converter.experimental_enable_mlir_converter = use_mlir
+  return converter.convert()
+
+
+def evaluate_tflite_model(tflite_model, input_ndarrays):
+  """Evaluates the provided tf.lite model with the given input ndarrays.
+
+  Args:
+    tflite_model: bytes. The serialized tf.lite model.
+    input_ndarrays: A list of NumPy arrays to feed as input to the model.
+
+  Returns:
+    A list of ndarrays produced by the model.
+
+  Raises:
+    ValueError: If the number of input arrays does not match the number of
+      inputs the model expects.
+ """ + the_interpreter = interpreter.Interpreter(model_content=tflite_model) + the_interpreter.allocate_tensors() + + input_details = the_interpreter.get_input_details() + output_details = the_interpreter.get_output_details() + + if len(input_details) != len(input_ndarrays): + raise ValueError('Wrong number of inputs: provided=%s, ' + 'input_details=%s output_details=%s' % ( + input_ndarrays, input_details, output_details)) + for input_tensor, data in zip(input_details, input_ndarrays): + the_interpreter.set_tensor(input_tensor['index'], data) + the_interpreter.invoke() + return [the_interpreter.get_tensor(details['index']) + for details in output_details] diff --git a/tensorflow/python/kernel_tests/signal/window_ops_test.py b/tensorflow/python/kernel_tests/signal/window_ops_test.py index 4d8ec2d7224..18eb0681df0 100644 --- a/tensorflow/python/kernel_tests/signal/window_ops_test.py +++ b/tensorflow/python/kernel_tests/signal/window_ops_test.py @@ -20,11 +20,13 @@ from __future__ import print_function import functools +from absl.testing import parameterized import numpy as np from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.kernel_tests.signal import test_util from tensorflow.python.ops.signal import window_ops @@ -58,7 +60,7 @@ def _scipy_raised_cosine(length, symmetric=True, a=0.5, b=0.5): @tf_test_util.run_all_in_graph_and_eager_modes -class WindowOpsTest(test.TestCase): +class WindowOpsTest(test.TestCase, parameterized.TestCase): def setUp(self): super(WindowOpsTest, self).setUp() @@ -107,7 +109,28 @@ class WindowOpsTest(test.TestCase): with g.as_default(): window = window_fn(100, periodic=periodic, dtype=dtype) rewritten_graph = test_util.grappler_optimize(g, [window]) - self.assertEqual(1, len(rewritten_graph.node)) + self.assertLen(rewritten_graph.node, 1) + + @parameterized.parameters( + # Due to control flow, only MLIR is supported. + # Only float32 is supported. + (window_ops.hann_window, 10, False, dtypes.float32, True), + (window_ops.hann_window, 10, True, dtypes.float32, True), + (window_ops.hamming_window, 10, False, dtypes.float32, True), + (window_ops.hamming_window, 10, True, dtypes.float32, True)) + def test_tflite_convert(self, window_fn, window_length, periodic, dtype, + use_mlir): + def fn(window_length): + return window_fn(window_length, periodic, dtype=dtype) + + tflite_model = test_util.tflite_convert( + fn, [tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32)], use_mlir) + window_length = np.array(window_length).astype(np.int32) + actual_output, = test_util.evaluate_tflite_model( + tflite_model, [window_length]) + + expected_output = self.evaluate(fn(window_length)) + self.assertAllClose(actual_output, expected_output, rtol=1e-7, atol=1e-7) if __name__ == '__main__':