diff --git a/tensorflow/lite/examples/python/label_image.py b/tensorflow/lite/examples/python/label_image.py
index 2ef1aa14fb2..d51bed91cfa 100644
--- a/tensorflow/lite/examples/python/label_image.py
+++ b/tensorflow/lite/examples/python/label_image.py
@@ -19,11 +19,10 @@ from __future__ import division
 from __future__ import print_function
 
 import argparse
+import time
 
 import numpy as np
-
 from PIL import Image
-
 import tensorflow as tf # TF2
 
 
@@ -57,9 +56,12 @@ if __name__ == '__main__':
       '--input_std', default=127.5, type=float, help='input standard deviation')
+  parser.add_argument(
+      '--num_threads', default=None, type=int, help='number of threads')
   args = parser.parse_args()
 
-  interpreter = tf.lite.Interpreter(model_path=args.model_file)
+  interpreter = tf.lite.Interpreter(
+      model_path=args.model_file, num_threads=args.num_threads)
   interpreter.allocate_tensors()
 
   input_details = interpreter.get_input_details()
@@ -81,7 +83,9 @@ if __name__ == '__main__':
 
   interpreter.set_tensor(input_details[0]['index'], input_data)
 
+  start_time = time.time()
   interpreter.invoke()
+  stop_time = time.time()
 
   output_data = interpreter.get_tensor(output_details[0]['index'])
   results = np.squeeze(output_data)
@@ -93,3 +97,5 @@ if __name__ == '__main__':
       print('{:08.6f}: {}'.format(float(results[i]), labels[i]))
     else:
       print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i]))
+
+  print('time: {:.3f}ms'.format((stop_time - start_time) * 1000))
diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py
index 04863b12853..f4a9d96da3f 100644
--- a/tensorflow/lite/python/interpreter.py
+++ b/tensorflow/lite/python/interpreter.py
@@ -172,7 +172,8 @@ class Interpreter(object):
   def __init__(self,
                model_path=None,
                model_content=None,
-               experimental_delegates=None):
+               experimental_delegates=None,
+               num_threads=None):
     """Constructor.
 
     Args:
@@ -181,6 +182,10 @@ class Interpreter(object):
       experimental_delegates: Experimental. Subject to change. List of
        [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
        objects returned by lite.load_delegate().
+      num_threads: Sets the number of threads used by the interpreter and
+        available to CPU kernels. If not set, the interpreter will use an
+        implementation-dependent default number of threads. Currently,
+        only a subset of kernels, such as conv, support multi-threading.
 
     Raises:
       ValueError: If the interpreter was unable to create.
@@ -206,6 +211,13 @@ class Interpreter(object):
     else:
       raise ValueError('Can\'t both provide `model_path` and `model_content`')
 
+    if num_threads is not None:
+      if not isinstance(num_threads, int):
+        raise ValueError('type of num_threads should be int')
+      if num_threads < 1:
+        raise ValueError('num_threads should >= 1')
+      self._interpreter.SetNumThreads(num_threads)
+
     # Each delegate is a wrapper that owns the delegates that have been loaded
     # as plugins. The interpreter wrapper will be using them, but we need to
     # hold them in a list so that the lifetime is preserved at least as long as
diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py
index 2a10eb0cc69..770c9dc3090 100644
--- a/tensorflow/lite/python/interpreter_test.py
+++ b/tensorflow/lite/python/interpreter_test.py
@@ -68,6 +68,20 @@ class InterpreterTest(test_util.TensorFlowTestCase):
     self.assertAllEqual(zero_points, params['zero_points'])
     self.assertEqual(quantized_dimension, params['quantized_dimension'])
 
+  def testThreads_NegativeValue(self):
+    with self.assertRaisesRegexp(ValueError,
+                                 'num_threads should >= 1'):
+      interpreter_wrapper.Interpreter(
+          model_path=resource_loader.get_path_to_datafile(
+              'testdata/permute_float.tflite'), num_threads=-1)
+
+  def testThreads_WrongType(self):
+    with self.assertRaisesRegexp(ValueError,
+                                 'type of num_threads should be int'):
+      interpreter_wrapper.Interpreter(
+          model_path=resource_loader.get_path_to_datafile(
+              'testdata/permute_float.tflite'), num_threads=4.2)
+
   def testFloat(self):
     interpreter = interpreter_wrapper.Interpreter(
         model_path=resource_loader.get_path_to_datafile(
@@ -100,6 +114,22 @@ class InterpreterTest(test_util.TensorFlowTestCase):
     output_data = interpreter.get_tensor(output_details[0]['index'])
     self.assertTrue((expected_output == output_data).all())
 
+  def testFloatWithTwoThreads(self):
+    interpreter = interpreter_wrapper.Interpreter(
+        model_path=resource_loader.get_path_to_datafile(
+            'testdata/permute_float.tflite'), num_threads=2)
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
+    expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
+    interpreter.set_tensor(input_details[0]['index'], test_input)
+    interpreter.invoke()
+
+    output_details = interpreter.get_output_details()
+    output_data = interpreter.get_tensor(output_details[0]['index'])
+    self.assertTrue((expected_output == output_data).all())
+
   def testUint8(self):
     model_path = resource_loader.get_path_to_datafile(
         'testdata/permute_uint8.tflite')
diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 92e7c22a702..2a8c1ffdcd6 100644
--- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -706,6 +706,12 @@ PyObject* InterpreterWrapper::ResetVariableTensors() {
   Py_RETURN_NONE;
 }
 
+PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
+  TFLITE_PY_ENSURE_VALID_INTERPRETER();
+  interpreter_->SetNumThreads(num_threads);
+  Py_RETURN_NONE;
+}
+
 PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
     TfLiteDelegate* delegate) {
   TFLITE_PY_ENSURE_VALID_INTERPRETER();
diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
index 2de38d07ed6..b799a3067f6 100644
--- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -87,6 +87,8 @@ class InterpreterWrapper {
   // should be the interpreter object providing the memory.
   PyObject* tensor(PyObject* base_object, int i);
 
+  PyObject* SetNumThreads(int num_threads);
+
   // Adds a delegate to the interpreter.
   PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate);
 
diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc
index 1a61c2aa33b..a85bdc8baf4 100644
--- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc
+++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc
@@ -145,5 +145,13 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
       },
       R"pbdoc(
         Adds a delegate to the interpreter.
+      )pbdoc")
+      .def(
+          "SetNumThreads",
+          [](InterpreterWrapper& self, int num_threads) {
+            return tensorflow::PyoOrThrow(self.SetNumThreads(num_threads));
+          },
+          R"pbdoc(
+            ask the interpreter to set the number of threads to use.
       )pbdoc");
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-interpreter.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-interpreter.pbtxt
index e1c235b5150..fdc7a9e4014 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.lite.-interpreter.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-interpreter.pbtxt
@@ -4,7 +4,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\', \'num_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "allocate_tensors"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.-interpreter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.-interpreter.pbtxt
index e1c235b5150..fdc7a9e4014 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.lite.-interpreter.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.-interpreter.pbtxt
@@ -4,7 +4,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\', \'num_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "allocate_tensors"
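For reference, a minimal usage sketch of the new num_threads constructor argument follows. It assumes a TensorFlow build that includes this change; the model path and the zero-filled input are placeholders for illustration, not part of the patch. The interpreter calls themselves (allocate_tensors, set_tensor, invoke, get_tensor) are the standard tf.lite Python API already used in label_image.py above.

import time

import numpy as np
import tensorflow as tf  # TF2

# Hypothetical model path; substitute any .tflite model available locally.
MODEL_PATH = 'model.tflite'

# num_threads must be a positive int; leaving it as None keeps the
# implementation-dependent default described in the docstring above.
interpreter = tf.lite.Interpreter(model_path=MODEL_PATH, num_threads=2)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed zeros of the model's expected input shape and dtype, then time invoke().
input_data = np.zeros(input_details[0]['shape'], dtype=input_details[0]['dtype'])
interpreter.set_tensor(input_details[0]['index'], input_data)

start_time = time.time()
interpreter.invoke()
stop_time = time.time()

output_data = interpreter.get_tensor(output_details[0]['index'])
print('output shape:', output_data.shape)
print('time: {:.3f}ms'.format((stop_time - start_time) * 1000))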