Add tf.lite MLIR tests for tf.signal.hann_window/hamming_window.
PiperOrigin-RevId: 268756379
This commit is contained in:
parent
bf50319afe
commit
3576dfea71
@ -14,6 +14,8 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
|
"//tensorflow/lite/python:interpreter",
|
||||||
|
"//tensorflow/lite/python:lite",
|
||||||
"//tensorflow/python:tf_optimizer",
|
"//tensorflow/python:tf_optimizer",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
],
|
],
|
||||||
|
@ -19,6 +19,9 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
|
from tensorflow.lite.python import interpreter
|
||||||
|
from tensorflow.lite.python import lite
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.grappler import tf_optimizer
|
from tensorflow.python.grappler import tf_optimizer
|
||||||
from tensorflow.python.training import saver
|
from tensorflow.python.training import saver
|
||||||
|
|
||||||
@ -45,3 +48,55 @@ def grappler_optimize(graph, fetches=None, config_proto=None):
|
|||||||
graph.add_to_collection('train_op', fetch)
|
graph.add_to_collection('train_op', fetch)
|
||||||
metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def())
|
metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def())
|
||||||
return tf_optimizer.OptimizeGraph(config_proto, metagraph)
|
return tf_optimizer.OptimizeGraph(config_proto, metagraph)
|
||||||
|
|
||||||
|
|
||||||
|
def tflite_convert(fn, input_templates, use_mlir=False):
|
||||||
|
"""Converts the provided fn to tf.lite model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: A callable that expects a list of inputs like input_templates that
|
||||||
|
returns a tensor or structure of tensors.
|
||||||
|
input_templates: A list of Tensors, ndarrays or TensorSpecs describing the
|
||||||
|
inputs that fn expects. The actual values of the Tensors or ndarrays are
|
||||||
|
unused.
|
||||||
|
use_mlir: Experimental. Whether to use the tf.lite MLIR converter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The serialized tf.lite model.
|
||||||
|
"""
|
||||||
|
fn = def_function.function(fn)
|
||||||
|
concrete_func = fn.get_concrete_function(*input_templates)
|
||||||
|
converter = lite.TFLiteConverterV2([concrete_func])
|
||||||
|
converter.experimental_enable_mlir_converter = use_mlir
|
||||||
|
return converter.convert()
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_tflite_model(tflite_model, input_ndarrays):
|
||||||
|
"""Evaluates the provided tf.lite model with the given input ndarrays.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tflite_model: bytes. The serialized tf.lite model.
|
||||||
|
input_ndarrays: A list of NumPy arrays to feed as input to the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list ndarrays produced by the model.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the number of input arrays does not match the number of
|
||||||
|
inputs the model expects.
|
||||||
|
"""
|
||||||
|
the_interpreter = interpreter.Interpreter(model_content=tflite_model)
|
||||||
|
the_interpreter.allocate_tensors()
|
||||||
|
|
||||||
|
input_details = the_interpreter.get_input_details()
|
||||||
|
output_details = the_interpreter.get_output_details()
|
||||||
|
|
||||||
|
if len(input_details) != len(input_ndarrays):
|
||||||
|
raise ValueError('Wrong number of inputs: provided=%s, '
|
||||||
|
'input_details=%s output_details=%s' % (
|
||||||
|
input_ndarrays, input_details, output_details))
|
||||||
|
for input_tensor, data in zip(input_details, input_ndarrays):
|
||||||
|
the_interpreter.set_tensor(input_tensor['index'], data)
|
||||||
|
the_interpreter.invoke()
|
||||||
|
return [the_interpreter.get_tensor(details['index'])
|
||||||
|
for details in output_details]
|
||||||
|
@ -20,11 +20,13 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util as tf_test_util
|
from tensorflow.python.framework import test_util as tf_test_util
|
||||||
from tensorflow.python.kernel_tests.signal import test_util
|
from tensorflow.python.kernel_tests.signal import test_util
|
||||||
from tensorflow.python.ops.signal import window_ops
|
from tensorflow.python.ops.signal import window_ops
|
||||||
@ -58,7 +60,7 @@ def _scipy_raised_cosine(length, symmetric=True, a=0.5, b=0.5):
|
|||||||
|
|
||||||
|
|
||||||
@tf_test_util.run_all_in_graph_and_eager_modes
|
@tf_test_util.run_all_in_graph_and_eager_modes
|
||||||
class WindowOpsTest(test.TestCase):
|
class WindowOpsTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(WindowOpsTest, self).setUp()
|
super(WindowOpsTest, self).setUp()
|
||||||
@ -107,7 +109,28 @@ class WindowOpsTest(test.TestCase):
|
|||||||
with g.as_default():
|
with g.as_default():
|
||||||
window = window_fn(100, periodic=periodic, dtype=dtype)
|
window = window_fn(100, periodic=periodic, dtype=dtype)
|
||||||
rewritten_graph = test_util.grappler_optimize(g, [window])
|
rewritten_graph = test_util.grappler_optimize(g, [window])
|
||||||
self.assertEqual(1, len(rewritten_graph.node))
|
self.assertLen(rewritten_graph.node, 1)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
# Due to control flow, only MLIR is supported.
|
||||||
|
# Only float32 is supported.
|
||||||
|
(window_ops.hann_window, 10, False, dtypes.float32, True),
|
||||||
|
(window_ops.hann_window, 10, True, dtypes.float32, True),
|
||||||
|
(window_ops.hamming_window, 10, False, dtypes.float32, True),
|
||||||
|
(window_ops.hamming_window, 10, True, dtypes.float32, True))
|
||||||
|
def test_tflite_convert(self, window_fn, window_length, periodic, dtype,
|
||||||
|
use_mlir):
|
||||||
|
def fn(window_length):
|
||||||
|
return window_fn(window_length, periodic, dtype=dtype)
|
||||||
|
|
||||||
|
tflite_model = test_util.tflite_convert(
|
||||||
|
fn, [tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32)], use_mlir)
|
||||||
|
window_length = np.array(window_length).astype(np.int32)
|
||||||
|
actual_output, = test_util.evaluate_tflite_model(
|
||||||
|
tflite_model, [window_length])
|
||||||
|
|
||||||
|
expected_output = self.evaluate(fn(window_length))
|
||||||
|
self.assertAllClose(actual_output, expected_output, rtol=1e-7, atol=1e-7)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
Reference in New Issue
Block a user