Add tf.lite MLIR tests for tf.signal.hann_window/hamming_window.

PiperOrigin-RevId: 268756379
RJ Skerry-Ryan 2019-09-12 13:44:15 -07:00 committed by TensorFlower Gardener
parent bf50319afe
commit 3576dfea71
3 changed files with 82 additions and 2 deletions


@@ -14,6 +14,8 @@ py_library(
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/core:protos_all_py",
        "//tensorflow/lite/python:interpreter",
        "//tensorflow/lite/python:lite",
        "//tensorflow/python:tf_optimizer",
        "//tensorflow/python:training",
    ],


@@ -19,6 +19,9 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
from tensorflow.lite.python import interpreter
from tensorflow.lite.python import lite
from tensorflow.python.eager import def_function
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training import saver
@@ -45,3 +48,55 @@ def grappler_optimize(graph, fetches=None, config_proto=None):
      graph.add_to_collection('train_op', fetch)
  metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def())
  return tf_optimizer.OptimizeGraph(config_proto, metagraph)

def tflite_convert(fn, input_templates, use_mlir=False):
  """Converts the provided fn to a tf.lite model.

  Args:
    fn: A callable that accepts inputs matching input_templates and returns a
      tensor or structure of tensors.
    input_templates: A list of Tensors, ndarrays or TensorSpecs describing the
      inputs that fn expects. The actual values of the Tensors or ndarrays are
      unused.
    use_mlir: Experimental. Whether to use the tf.lite MLIR converter.

  Returns:
    The serialized tf.lite model.
  """
  fn = def_function.function(fn)
  concrete_func = fn.get_concrete_function(*input_templates)
  converter = lite.TFLiteConverterV2([concrete_func])
  converter.experimental_enable_mlir_converter = use_mlir
  return converter.convert()
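As a quick illustration of how this new helper is meant to be called (a minimal sketch, not part of the commit; it assumes window_ops, dtypes and tensor_spec are importable as they are in the test file further below):

# Hypothetical usage sketch: convert a scalar-input Hann window function with
# the experimental MLIR converter. The TensorSpec only describes the input
# signature; its concrete value is never used.
model_bytes = tflite_convert(
    lambda length: window_ops.hann_window(length, dtype=dtypes.float32),
    [tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32)],
    use_mlir=True)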

def evaluate_tflite_model(tflite_model, input_ndarrays):
  """Evaluates the provided tf.lite model with the given input ndarrays.

  Args:
    tflite_model: bytes. The serialized tf.lite model.
    input_ndarrays: A list of NumPy arrays to feed as input to the model.

  Returns:
    A list of ndarrays produced by the model.

  Raises:
    ValueError: If the number of input arrays does not match the number of
      inputs the model expects.
  """
  the_interpreter = interpreter.Interpreter(model_content=tflite_model)
  the_interpreter.allocate_tensors()
  input_details = the_interpreter.get_input_details()
  output_details = the_interpreter.get_output_details()
  if len(input_details) != len(input_ndarrays):
    raise ValueError('Wrong number of inputs: provided=%s, '
                     'input_details=%s output_details=%s' % (
                         input_ndarrays, input_details, output_details))
  for input_tensor, data in zip(input_details, input_ndarrays):
    the_interpreter.set_tensor(input_tensor['index'], data)
  the_interpreter.invoke()
  return [the_interpreter.get_tensor(details['index'])
          for details in output_details]
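Continuing the same hypothetical sketch (again not part of the commit; assumes numpy is imported as np), the converted model can then be fed a concrete length and its single output read back:

# Run the converted model on a concrete window length and check its shape.
window_length = np.array(16, dtype=np.int32)
output, = evaluate_tflite_model(model_bytes, [window_length])
assert output.shape == (16,)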


@@ -20,11 +20,13 @@ from __future__ import print_function
import functools
from absl.testing import parameterized
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.kernel_tests.signal import test_util
from tensorflow.python.ops.signal import window_ops
@@ -58,7 +60,7 @@ def _scipy_raised_cosine(length, symmetric=True, a=0.5, b=0.5):
@tf_test_util.run_all_in_graph_and_eager_modes
class WindowOpsTest(test.TestCase):
class WindowOpsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(WindowOpsTest, self).setUp()
@@ -107,7 +109,28 @@ class WindowOpsTest(test.TestCase):
    with g.as_default():
      window = window_fn(100, periodic=periodic, dtype=dtype)
      rewritten_graph = test_util.grappler_optimize(g, [window])
      self.assertEqual(1, len(rewritten_graph.node))
      self.assertLen(rewritten_graph.node, 1)

  @parameterized.parameters(
      # Due to control flow, only MLIR is supported.
      # Only float32 is supported.
      (window_ops.hann_window, 10, False, dtypes.float32, True),
      (window_ops.hann_window, 10, True, dtypes.float32, True),
      (window_ops.hamming_window, 10, False, dtypes.float32, True),
      (window_ops.hamming_window, 10, True, dtypes.float32, True))
  def test_tflite_convert(self, window_fn, window_length, periodic, dtype,
                          use_mlir):
    def fn(window_length):
      return window_fn(window_length, periodic, dtype=dtype)

    tflite_model = test_util.tflite_convert(
        fn, [tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32)], use_mlir)
    window_length = np.array(window_length).astype(np.int32)
    actual_output, = test_util.evaluate_tflite_model(
        tflite_model, [window_length])

    expected_output = self.evaluate(fn(window_length))
    self.assertAllClose(actual_output, expected_output, rtol=1e-7, atol=1e-7)
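For context on what the test compares against: both windows are raised cosines (Hann uses a = b = 0.5, Hamming uses a = 0.54 and b = 0.46), in the spirit of the _scipy_raised_cosine helper named in the hunk header above. A small illustrative sketch of the expected values, not part of the commit:

# Hypothetical reference computation for a raised-cosine window.
# periodic=True divides by length (matching periodic=True in window_ops);
# a symmetric window divides by length - 1 (length > 1 assumed).
def raised_cosine(length, a, b, periodic=True):
  n = np.arange(length)
  denom = length if periodic else length - 1
  return a - b * np.cos(2.0 * np.pi * n / denom)

hann_10 = raised_cosine(10, 0.5, 0.5)       # ~ window_ops.hann_window(10)
hamming_10 = raised_cosine(10, 0.54, 0.46)  # ~ window_ops.hamming_window(10)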
if __name__ == '__main__':