Add tf.lite MLIR tests for tf.signal.hann_window/hamming_window.

PiperOrigin-RevId: 268756379
This commit is contained in:
RJ Skerry-Ryan 2019-09-12 13:44:15 -07:00 committed by TensorFlower Gardener
parent bf50319afe
commit 3576dfea71
3 changed files with 82 additions and 2 deletions

View File

@ -14,6 +14,8 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
"//tensorflow/lite/python:interpreter",
"//tensorflow/lite/python:lite",
"//tensorflow/python:tf_optimizer", "//tensorflow/python:tf_optimizer",
"//tensorflow/python:training", "//tensorflow/python:training",
], ],

View File

@ -19,6 +19,9 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.lite.python import interpreter
from tensorflow.lite.python import lite
from tensorflow.python.eager import def_function
from tensorflow.python.grappler import tf_optimizer from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training import saver from tensorflow.python.training import saver
@ -45,3 +48,55 @@ def grappler_optimize(graph, fetches=None, config_proto=None):
graph.add_to_collection('train_op', fetch) graph.add_to_collection('train_op', fetch)
metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def()) metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def())
return tf_optimizer.OptimizeGraph(config_proto, metagraph) return tf_optimizer.OptimizeGraph(config_proto, metagraph)
def tflite_convert(fn, input_templates, use_mlir=False):
"""Converts the provided fn to tf.lite model.
Args:
fn: A callable that expects a list of inputs like input_templates that
returns a tensor or structure of tensors.
input_templates: A list of Tensors, ndarrays or TensorSpecs describing the
inputs that fn expects. The actual values of the Tensors or ndarrays are
unused.
use_mlir: Experimental. Whether to use the tf.lite MLIR converter.
Returns:
The serialized tf.lite model.
"""
fn = def_function.function(fn)
concrete_func = fn.get_concrete_function(*input_templates)
converter = lite.TFLiteConverterV2([concrete_func])
converter.experimental_enable_mlir_converter = use_mlir
return converter.convert()
def evaluate_tflite_model(tflite_model, input_ndarrays):
"""Evaluates the provided tf.lite model with the given input ndarrays.
Args:
tflite_model: bytes. The serialized tf.lite model.
input_ndarrays: A list of NumPy arrays to feed as input to the model.
Returns:
A list ndarrays produced by the model.
Raises:
ValueError: If the number of input arrays does not match the number of
inputs the model expects.
"""
the_interpreter = interpreter.Interpreter(model_content=tflite_model)
the_interpreter.allocate_tensors()
input_details = the_interpreter.get_input_details()
output_details = the_interpreter.get_output_details()
if len(input_details) != len(input_ndarrays):
raise ValueError('Wrong number of inputs: provided=%s, '
'input_details=%s output_details=%s' % (
input_ndarrays, input_details, output_details))
for input_tensor, data in zip(input_details, input_ndarrays):
the_interpreter.set_tensor(input_tensor['index'], data)
the_interpreter.invoke()
return [the_interpreter.get_tensor(details['index'])
for details in output_details]

View File

@ -20,11 +20,13 @@ from __future__ import print_function
import functools import functools
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.kernel_tests.signal import test_util from tensorflow.python.kernel_tests.signal import test_util
from tensorflow.python.ops.signal import window_ops from tensorflow.python.ops.signal import window_ops
@ -58,7 +60,7 @@ def _scipy_raised_cosine(length, symmetric=True, a=0.5, b=0.5):
@tf_test_util.run_all_in_graph_and_eager_modes @tf_test_util.run_all_in_graph_and_eager_modes
class WindowOpsTest(test.TestCase): class WindowOpsTest(test.TestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
super(WindowOpsTest, self).setUp() super(WindowOpsTest, self).setUp()
@ -107,7 +109,28 @@ class WindowOpsTest(test.TestCase):
with g.as_default(): with g.as_default():
window = window_fn(100, periodic=periodic, dtype=dtype) window = window_fn(100, periodic=periodic, dtype=dtype)
rewritten_graph = test_util.grappler_optimize(g, [window]) rewritten_graph = test_util.grappler_optimize(g, [window])
self.assertEqual(1, len(rewritten_graph.node)) self.assertLen(rewritten_graph.node, 1)
@parameterized.parameters(
# Due to control flow, only MLIR is supported.
# Only float32 is supported.
(window_ops.hann_window, 10, False, dtypes.float32, True),
(window_ops.hann_window, 10, True, dtypes.float32, True),
(window_ops.hamming_window, 10, False, dtypes.float32, True),
(window_ops.hamming_window, 10, True, dtypes.float32, True))
def test_tflite_convert(self, window_fn, window_length, periodic, dtype,
use_mlir):
def fn(window_length):
return window_fn(window_length, periodic, dtype=dtype)
tflite_model = test_util.tflite_convert(
fn, [tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32)], use_mlir)
window_length = np.array(window_length).astype(np.int32)
actual_output, = test_util.evaluate_tflite_model(
tflite_model, [window_length])
expected_output = self.evaluate(fn(window_length))
self.assertAllClose(actual_output, expected_output, rtol=1e-7, atol=1e-7)
if __name__ == '__main__': if __name__ == '__main__':