Add tf.lite MLIR tests for tf.signal.hann_window/hamming_window.
PiperOrigin-RevId: 268756379
This commit is contained in:
parent
bf50319afe
commit
3576dfea71
@ -14,6 +14,8 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/lite/python:interpreter",
|
||||
"//tensorflow/lite/python:lite",
|
||||
"//tensorflow/python:tf_optimizer",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
|
@ -19,6 +19,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.lite.python import interpreter
|
||||
from tensorflow.lite.python import lite
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.grappler import tf_optimizer
|
||||
from tensorflow.python.training import saver
|
||||
|
||||
@ -45,3 +48,55 @@ def grappler_optimize(graph, fetches=None, config_proto=None):
|
||||
graph.add_to_collection('train_op', fetch)
|
||||
metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def())
|
||||
return tf_optimizer.OptimizeGraph(config_proto, metagraph)
|
||||
|
||||
|
||||
def tflite_convert(fn, input_templates, use_mlir=False):
|
||||
"""Converts the provided fn to tf.lite model.
|
||||
|
||||
Args:
|
||||
fn: A callable that expects a list of inputs like input_templates that
|
||||
returns a tensor or structure of tensors.
|
||||
input_templates: A list of Tensors, ndarrays or TensorSpecs describing the
|
||||
inputs that fn expects. The actual values of the Tensors or ndarrays are
|
||||
unused.
|
||||
use_mlir: Experimental. Whether to use the tf.lite MLIR converter.
|
||||
|
||||
Returns:
|
||||
The serialized tf.lite model.
|
||||
"""
|
||||
fn = def_function.function(fn)
|
||||
concrete_func = fn.get_concrete_function(*input_templates)
|
||||
converter = lite.TFLiteConverterV2([concrete_func])
|
||||
converter.experimental_enable_mlir_converter = use_mlir
|
||||
return converter.convert()
|
||||
|
||||
|
||||
def evaluate_tflite_model(tflite_model, input_ndarrays):
|
||||
"""Evaluates the provided tf.lite model with the given input ndarrays.
|
||||
|
||||
Args:
|
||||
tflite_model: bytes. The serialized tf.lite model.
|
||||
input_ndarrays: A list of NumPy arrays to feed as input to the model.
|
||||
|
||||
Returns:
|
||||
A list ndarrays produced by the model.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of input arrays does not match the number of
|
||||
inputs the model expects.
|
||||
"""
|
||||
the_interpreter = interpreter.Interpreter(model_content=tflite_model)
|
||||
the_interpreter.allocate_tensors()
|
||||
|
||||
input_details = the_interpreter.get_input_details()
|
||||
output_details = the_interpreter.get_output_details()
|
||||
|
||||
if len(input_details) != len(input_ndarrays):
|
||||
raise ValueError('Wrong number of inputs: provided=%s, '
|
||||
'input_details=%s output_details=%s' % (
|
||||
input_ndarrays, input_details, output_details))
|
||||
for input_tensor, data in zip(input_details, input_ndarrays):
|
||||
the_interpreter.set_tensor(input_tensor['index'], data)
|
||||
the_interpreter.invoke()
|
||||
return [the_interpreter.get_tensor(details['index'])
|
||||
for details in output_details]
|
||||
|
@ -20,11 +20,13 @@ from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.kernel_tests.signal import test_util
|
||||
from tensorflow.python.ops.signal import window_ops
|
||||
@ -58,7 +60,7 @@ def _scipy_raised_cosine(length, symmetric=True, a=0.5, b=0.5):
|
||||
|
||||
|
||||
@tf_test_util.run_all_in_graph_and_eager_modes
|
||||
class WindowOpsTest(test.TestCase):
|
||||
class WindowOpsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(WindowOpsTest, self).setUp()
|
||||
@ -107,7 +109,28 @@ class WindowOpsTest(test.TestCase):
|
||||
with g.as_default():
|
||||
window = window_fn(100, periodic=periodic, dtype=dtype)
|
||||
rewritten_graph = test_util.grappler_optimize(g, [window])
|
||||
self.assertEqual(1, len(rewritten_graph.node))
|
||||
self.assertLen(rewritten_graph.node, 1)
|
||||
|
||||
@parameterized.parameters(
|
||||
# Due to control flow, only MLIR is supported.
|
||||
# Only float32 is supported.
|
||||
(window_ops.hann_window, 10, False, dtypes.float32, True),
|
||||
(window_ops.hann_window, 10, True, dtypes.float32, True),
|
||||
(window_ops.hamming_window, 10, False, dtypes.float32, True),
|
||||
(window_ops.hamming_window, 10, True, dtypes.float32, True))
|
||||
def test_tflite_convert(self, window_fn, window_length, periodic, dtype,
|
||||
use_mlir):
|
||||
def fn(window_length):
|
||||
return window_fn(window_length, periodic, dtype=dtype)
|
||||
|
||||
tflite_model = test_util.tflite_convert(
|
||||
fn, [tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32)], use_mlir)
|
||||
window_length = np.array(window_length).astype(np.int32)
|
||||
actual_output, = test_util.evaluate_tflite_model(
|
||||
tflite_model, [window_length])
|
||||
|
||||
expected_output = self.evaluate(fn(window_length))
|
||||
self.assertAllClose(actual_output, expected_output, rtol=1e-7, atol=1e-7)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user