STT-tensorflow/tensorflow/lite/python/lite_v2_test_util.py
Nupur Garg 10f4282cff Support unknown dimensions in quantized models.
PiperOrigin-RevId: 310269161
Change-Id: I83b58ec5d8128485bbed9839d60478e543a7c09e
2020-05-06 18:30:12 -07:00

108 lines
3.7 KiB
Python

# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py functionality related to TensorFlow 2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
from six.moves import zip
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python.eager import def_function
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import tracking
class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Base test class for TensorFlow Lite 2.x model tests."""

  def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None):
    """Evaluates the model on the `input_data`.

    Args:
      tflite_model: TensorFlow Lite model.
      input_data: List of EagerTensor const ops containing the input data for
        each input tensor, in the same order as the interpreter's input
        details.
      input_shapes: Optional list of `(shape_signature, final_shape)` tuples.
        The i-th tuple applies to the i-th input tensor: the interpreter's
        reported `shape_signature` for that input is asserted to equal
        `shape_signature`, and the input is then resized to `final_shape`.
        Intended for inputs that have unknown dimensions.

    Returns:
      [np.ndarray], one entry per output tensor, in the order returned by
      `interpreter.get_output_details()`.
    """
    interpreter = Interpreter(model_content=tflite_model)
    input_details = interpreter.get_input_details()
    if input_shapes:
      for idx, (shape_signature, final_shape) in enumerate(input_shapes):
        # assertAllEqual checks shapes before values, so a signature of the
        # wrong rank fails as a readable assertion. The previous
        # `(a == b).all()` form errored out in that case because NumPy's
        # element-wise comparison does not produce an array for mismatched
        # lengths.
        self.assertAllEqual(input_details[idx]['shape_signature'],
                            shape_signature)
        index = input_details[idx]['index']
        interpreter.resize_tensor_input(index, final_shape, strict=True)
    interpreter.allocate_tensors()

    output_details = interpreter.get_output_details()
    # Re-fetch the input details now that inputs may have been resized and
    # tensors have been allocated.
    input_details = interpreter.get_input_details()

    for input_tensor, tensor_data in zip(input_details, input_data):
      interpreter.set_tensor(input_tensor['index'], tensor_data.numpy())
    interpreter.invoke()

    return [
        interpreter.get_tensor(details['index']) for details in output_details
    ]

  def _getSimpleVariableModel(self):
    """Returns a trackable root with two variables and one function.

    The root holds `v1` (3.) and `v2` (2.), and `f(x)` computes
    `v1 * v2 * x`.
    """
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    return root

  def _getMultiFunctionModel(self):
    """Returns a trackable model exposing two `def_function.function` methods.

    `add(x)` returns `x + 2.` and `sub(x)` returns `x - 3.`. Each method
    creates its variable lazily, on its first call.
    """

    class BasicModel(tracking.AutoTrackable):

      def __init__(self):
        # Variables are created on first use inside the functions below.
        self.y = None
        self.z = None

      @def_function.function
      def add(self, x):
        if self.y is None:
          self.y = variables.Variable(2.)
        return x + self.y

      @def_function.function
      def sub(self, x):
        if self.z is None:
          self.z = variables.Variable(3.)
        return x - self.z

    return BasicModel()

  def _assertValidDebugInfo(self, debug_info):
    """Verify the DebugInfo is valid.

    Args:
      debug_info: Debug info object with a `files` field listing source file
        paths (presumably a `GraphDebugInfo` proto -- confirm with callers).
    """
    file_names = {os.path.basename(path) for path in debug_info.files}
    # To make the test independent of how the nodes are created, only check
    # that the v2 test file is recorded and the v1 test file is not.
    self.assertIn('lite_v2_test.py', file_names)
    self.assertNotIn('lite_test.py', file_names)