# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite Python Interface: Sanity check."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ctypes
import io
import sys
import numpy as np
import six
# Force loaded shared object symbols to be globally visible. This is needed so
# that the interpreter_wrapper, in one .so file, can see the test_registerer,
# in a different .so file. Note that this may already be set by default.
# pylint: disable=g-import-not-at-top
if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'):
  sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
from tensorflow.lite.python import interpreter as interpreter_wrapper
from tensorflow.lite.python.testdata import _pywrap_test_registerer as test_registerer
from tensorflow.python.framework import test_util
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top
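

# The tests below exercise the usual Interpreter usage pattern: construct an
# Interpreter from model_path or model_content, call allocate_tensors(), write
# inputs with set_tensor(), run invoke(), and read outputs with get_tensor()
# or tensor().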


class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):

  def testRegisterer(self):
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'),
        custom_op_registerers=['TF_TestRegisterer'])
    self.assertTrue(interpreter._safe_to_run())
    self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1)

  def testRegistererFailure(self):
    bogus_name = 'CompletelyBogusRegistererName'
    with self.assertRaisesRegexp(
        ValueError, 'Looking up symbol \'' + bogus_name + '\' failed'):
      interpreter_wrapper.InterpreterWithCustomOps(
          model_path=resource_loader.get_path_to_datafile(
              'testdata/permute_float.tflite'),
          custom_op_registerers=[bogus_name])


class InterpreterTest(test_util.TensorFlowTestCase):

  def assertQuantizationParamsEqual(self, scales, zero_points,
                                    quantized_dimension, params):
    self.assertAllEqual(scales, params['scales'])
    self.assertAllEqual(zero_points, params['zero_points'])
    self.assertEqual(quantized_dimension, params['quantized_dimension'])

  def testThreads_NegativeValue(self):
    with self.assertRaisesRegexp(ValueError,
                                 'num_threads should >= 1'):
      interpreter_wrapper.Interpreter(
          model_path=resource_loader.get_path_to_datafile(
              'testdata/permute_float.tflite'), num_threads=-1)

  def testThreads_WrongType(self):
    with self.assertRaisesRegexp(ValueError,
                                 'type of num_threads should be int'):
      interpreter_wrapper.Interpreter(
          model_path=resource_loader.get_path_to_datafile(
              'testdata/permute_float.tflite'), num_threads=4.2)

  def testFloat(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'))
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 4] == input_details[0]['shape']).all())
    self.assertEqual((0.0, 0), input_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, input_details[0]['quantization_parameters'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 4] == output_details[0]['shape']).all())
    self.assertEqual((0.0, 0), output_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, output_details[0]['quantization_parameters'])

    test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
    expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertTrue((expected_output == output_data).all())

  def testFloatWithTwoThreads(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'), num_threads=2)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
    expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertTrue((expected_output == output_data).all())

  def testUint8(self):
    model_path = resource_loader.get_path_to_datafile(
        'testdata/permute_uint8.tflite')
    with io.open(model_path, 'rb') as model_file:
      data = model_file.read()

    interpreter = interpreter_wrapper.Interpreter(model_content=data)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertTrue(([1, 4] == input_details[0]['shape']).all())
    self.assertEqual((1.0, 0), input_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [1.0], [0], 0, input_details[0]['quantization_parameters'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.uint8, output_details[0]['dtype'])
    self.assertTrue(([1, 4] == output_details[0]['shape']).all())
    self.assertEqual((1.0, 0), output_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [1.0], [0], 0, output_details[0]['quantization_parameters'])

    test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
    expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
    interpreter.resize_tensor_input(input_details[0]['index'],
                                    test_input.shape)
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertTrue((expected_output == output_data).all())

  def testString(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/gather_string.tflite'))
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.string_, input_details[0]['dtype'])
    self.assertTrue(([10] == input_details[0]['shape']).all())
    self.assertEqual((0.0, 0), input_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, input_details[0]['quantization_parameters'])
    self.assertEqual('indices', input_details[1]['name'])
    self.assertEqual(np.int64, input_details[1]['dtype'])
    self.assertTrue(([3] == input_details[1]['shape']).all())
    self.assertEqual((0.0, 0), input_details[1]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, input_details[1]['quantization_parameters'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.string_, output_details[0]['dtype'])
    self.assertTrue(([3] == output_details[0]['shape']).all())
    self.assertEqual((0.0, 0), output_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, output_details[0]['quantization_parameters'])

    test_input = np.array([1, 2, 3], dtype=np.int64)
    interpreter.set_tensor(input_details[1]['index'], test_input)

    test_input = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'])
    expected_output = np.array([b'b', b'c', b'd'])
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertTrue((expected_output == output_data).all())

  def testPerChannelParams(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin'))
    interpreter.allocate_tensors()

    # Tensor index 1 is the weight.
    weight_details = interpreter.get_tensor_details()[1]
    qparams = weight_details['quantization_parameters']
    # Ensure that we retrieve per channel quantization params correctly.
    self.assertEqual(len(qparams['scales']), 128)

  def testDenseTensorAccess(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin'))
    interpreter.allocate_tensors()

    weight_details = interpreter.get_tensor_details()[1]
    s_params = weight_details['sparsity_parameters']
    self.assertEqual(s_params, {})

  def testSparseTensorAccess(self):
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_path=resource_loader.get_path_to_datafile(
            '../testdata/sparse_tensor.bin'),
        custom_op_registerers=['TF_TestRegisterer'])
    interpreter.allocate_tensors()

    # Tensor at index 0 is sparse.
    compressed_buffer = interpreter.get_tensor(0)
    # Ensure that the buffer is of correct size and value.
    self.assertEqual(len(compressed_buffer), 12)
    sparse_value = [1, 0, 0, 4, 2, 3, 0, 0, 5, 0, 0, 6]
    self.assertAllEqual(compressed_buffer, sparse_value)

    tensor_details = interpreter.get_tensor_details()[0]
    s_params = tensor_details['sparsity_parameters']
    # Ensure that the sparsity parameters returned are correct.
    self.assertAllEqual(s_params['traversal_order'], [0, 1, 2, 3])
    self.assertAllEqual(s_params['block_map'], [0, 1])
    dense_dim_metadata = {'format': 0, 'dense_size': 2}
    self.assertAllEqual(s_params['dim_metadata'][0], dense_dim_metadata)
    self.assertAllEqual(s_params['dim_metadata'][2], dense_dim_metadata)
    self.assertAllEqual(s_params['dim_metadata'][3], dense_dim_metadata)
    self.assertEqual(s_params['dim_metadata'][1]['format'], 1)
    self.assertAllEqual(s_params['dim_metadata'][1]['array_segments'],
                        [0, 2, 3])
    self.assertAllEqual(s_params['dim_metadata'][1]['array_indices'], [0, 1, 1])


class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):

  def testInvalidModelContent(self):
    with self.assertRaisesRegexp(ValueError,
                                 'Model provided has model identifier \''):
      interpreter_wrapper.Interpreter(model_content=six.b('garbage'))

  def testInvalidModelFile(self):
    with self.assertRaisesRegexp(
        ValueError, 'Could not open \'totally_invalid_file_name\''):
      interpreter_wrapper.Interpreter(model_path='totally_invalid_file_name')

  def testInvokeBeforeReady(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'))
    with self.assertRaisesRegexp(RuntimeError,
                                 'Invoke called on model that is not ready'):
      interpreter.invoke()

  def testInvalidModelFileContent(self):
    with self.assertRaisesRegexp(
        ValueError, '`model_path` or `model_content` must be specified.'):
      interpreter_wrapper.Interpreter(model_path=None, model_content=None)

  def testInvalidIndex(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'))
    interpreter.allocate_tensors()
    # Invalid tensor index passed.
    with self.assertRaisesRegexp(ValueError, 'Tensor with no shape found.'):
      interpreter._get_tensor_details(4)
    with self.assertRaisesRegexp(ValueError, 'Invalid node index'):
      interpreter._get_op_details(4)


class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase):

  def setUp(self):
    self.interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'))
    self.interpreter.allocate_tensors()
    self.input0 = self.interpreter.get_input_details()[0]['index']
    self.initial_data = np.array([[-1., -2., -3., -4.]], np.float32)

  def testTensorAccessor(self):
    """Check that tensor returns a reference."""
    array_ref = self.interpreter.tensor(self.input0)
    np.copyto(array_ref(), self.initial_data)
    self.assertAllEqual(array_ref(), self.initial_data)
    self.assertAllEqual(
        self.interpreter.get_tensor(self.input0), self.initial_data)

  def testGetTensorAccessor(self):
    """Check that get_tensor returns a copy."""
    self.interpreter.set_tensor(self.input0, self.initial_data)
    array_initial_copy = self.interpreter.get_tensor(self.input0)
    new_value = np.add(1., array_initial_copy)
    self.interpreter.set_tensor(self.input0, new_value)
    self.assertAllEqual(array_initial_copy, self.initial_data)
    self.assertAllEqual(self.interpreter.get_tensor(self.input0), new_value)
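
  # Holding a buffer view returned by tensor() pins the underlying memory:
  # allocate_tensors() and invoke() refuse to run until every such view has
  # been released, as the two tests below verify.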
  def testBase(self):
    self.assertTrue(self.interpreter._safe_to_run())
    _ = self.interpreter.tensor(self.input0)
    self.assertTrue(self.interpreter._safe_to_run())
    in0 = self.interpreter.tensor(self.input0)()
    self.assertFalse(self.interpreter._safe_to_run())
    in0b = self.interpreter.tensor(self.input0)()
    self.assertFalse(self.interpreter._safe_to_run())
    # Now get rid of the buffers so that we can evaluate.
    del in0
    del in0b
    self.assertTrue(self.interpreter._safe_to_run())

  def testBaseProtectsFunctions(self):
    in0 = self.interpreter.tensor(self.input0)()
    # Make sure we get an exception if we try to run an unsafe operation.
    with self.assertRaisesRegexp(
        RuntimeError, 'There is at least 1 reference'):
      _ = self.interpreter.allocate_tensors()
    # Make sure we get an exception if we try to run an unsafe operation.
    with self.assertRaisesRegexp(
        RuntimeError, 'There is at least 1 reference'):
      _ = self.interpreter.invoke()
    # Now test that we can run.
    del in0  # This is our only buffer reference, so now it is safe to change.
    in0safe = self.interpreter.tensor(self.input0)
    _ = self.interpreter.allocate_tensors()
    del in0safe  # Make sure in0safe is held but lint doesn't complain.


class InterpreterDelegateTest(test_util.TensorFlowTestCase):

  def setUp(self):
    self._delegate_file = resource_loader.get_path_to_datafile(
        'testdata/test_delegate.so')
    self._model_file = resource_loader.get_path_to_datafile(
        'testdata/permute_float.tflite')
    # Load the library to reset the counters.
    library = ctypes.pydll.LoadLibrary(self._delegate_file)
    library.initialize_counters()
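
  # The test delegate library exposes counter entry points
  # (get_num_delegates_created, get_num_delegates_destroyed,
  # get_num_delegates_invoked, get_options_counter) that the tests below use
  # to observe delegate creation, application, and destruction.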

  def _TestInterpreter(self, model_path, options=None):
    """Test wrapper function that creates an interpreter with the delegate."""
    delegate = interpreter_wrapper.load_delegate(self._delegate_file, options)
    return interpreter_wrapper.Interpreter(
        model_path=model_path, experimental_delegates=[delegate])

  def testDelegate(self):
    """Tests the delegate creation and destruction."""
    interpreter = self._TestInterpreter(model_path=self._model_file)
    lib = interpreter._delegates[0]._library

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 1)

    del interpreter

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 1)
    self.assertEqual(lib.get_num_delegates_invoked(), 1)

  def testMultipleInterpreters(self):
    delegate = interpreter_wrapper.load_delegate(self._delegate_file)
    lib = delegate._library

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 0)

    interpreter_a = interpreter_wrapper.Interpreter(
        model_path=self._model_file, experimental_delegates=[delegate])
    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 1)

    interpreter_b = interpreter_wrapper.Interpreter(
        model_path=self._model_file, experimental_delegates=[delegate])
    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 2)

    del delegate
    del interpreter_a

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 2)

    del interpreter_b

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 1)
    self.assertEqual(lib.get_num_delegates_invoked(), 2)

  def testDestructionOrder(self):
    """Make sure internal _interpreter object is destroyed before delegate."""
    self.skipTest('TODO(b/142136355): fix flakiness and re-enable')
    # Track the order in which destructions happened.
    destructions = []

    def register_destruction(x):
      destructions.append(
          x if isinstance(x, str) else six.ensure_text(x, 'utf-8'))
      return 0

    # Make a wrapper for the callback so we can send this to ctypes.
    delegate = interpreter_wrapper.load_delegate(self._delegate_file)
    # Make an interpreter with the delegate.
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'),
        experimental_delegates=[delegate])

    class InterpreterDestroyCallback(object):

      def __del__(self):
        register_destruction('interpreter')

    interpreter._interpreter.stuff = InterpreterDestroyCallback()

    # Destroy both delegate and interpreter.
    library = delegate._library
    prototype = ctypes.CFUNCTYPE(ctypes.c_int, (ctypes.c_char_p))
    library.set_destroy_callback(prototype(register_destruction))
    del delegate
    del interpreter
    library.set_destroy_callback(None)

    # Check that the interpreter was destroyed before the delegate.
    self.assertEqual(destructions, ['interpreter', 'test_delegate'])

  def testOptions(self):
    delegate_a = interpreter_wrapper.load_delegate(self._delegate_file)
    lib = delegate_a._library

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 0)
    self.assertEqual(lib.get_options_counter(), 0)

    delegate_b = interpreter_wrapper.load_delegate(
        self._delegate_file, options={
            'unused': False,
            'options_counter': 2
        })
    lib = delegate_b._library

    self.assertEqual(lib.get_num_delegates_created(), 2)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 0)
    self.assertEqual(lib.get_options_counter(), 2)

    del delegate_a
    del delegate_b

    self.assertEqual(lib.get_num_delegates_created(), 2)
    self.assertEqual(lib.get_num_delegates_destroyed(), 2)
    self.assertEqual(lib.get_num_delegates_invoked(), 0)
    self.assertEqual(lib.get_options_counter(), 2)

  def testFail(self):
    with self.assertRaisesRegexp(
        # Due to exception chaining in PY3, we can't be more specific here and
        # check that the phrase 'Fail argument sent' is present.
        ValueError,
        r'Failed to load delegate from'):
      interpreter_wrapper.load_delegate(
          self._delegate_file, options={'fail': 'fail'})


if __name__ == '__main__':
  test.main()