Merge pull request #25748 from freedomtan:export_SetNumThreads_to_tflite_python

PiperOrigin-RevId: 316137187
Change-Id: I12729a4f760b786a57c36a481261f6232f7d3fee
Author: TensorFlower Gardener
Date:   2020-06-12 11:13:35 -07:00
Commit: ad4363bdcf
8 changed files with 70 additions and 6 deletions


@@ -19,11 +19,10 @@ from __future__ import division
 from __future__ import print_function

 import argparse
+import time

 import numpy as np
 from PIL import Image
 import tensorflow as tf # TF2
@@ -57,9 +56,12 @@ if __name__ == '__main__':
       '--input_std',
       default=127.5, type=float,
       help='input standard deviation')
+  parser.add_argument(
+      '--num_threads', default=None, type=int, help='number of threads')
   args = parser.parse_args()

-  interpreter = tf.lite.Interpreter(model_path=args.model_file)
+  interpreter = tf.lite.Interpreter(
+      model_path=args.model_file, num_threads=args.num_threads)
   interpreter.allocate_tensors()

   input_details = interpreter.get_input_details()
@@ -81,7 +83,9 @@ if __name__ == '__main__':
   interpreter.set_tensor(input_details[0]['index'], input_data)

+  start_time = time.time()
   interpreter.invoke()
+  stop_time = time.time()

   output_data = interpreter.get_tensor(output_details[0]['index'])
   results = np.squeeze(output_data)
@@ -93,3 +97,5 @@ if __name__ == '__main__':
       print('{:08.6f}: {}'.format(float(results[i]), labels[i]))
     else:
       print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i]))
+
+  print('time: {:.3f}ms'.format((stop_time - start_time) * 1000))
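Taken together, the script change reduces to the pattern below: pass the thread count straight into the tf.lite.Interpreter constructor and time invoke(). This is a minimal standalone sketch, not the script itself; the model path and thread count are placeholders, and the dummy input is only there to make the snippet runnable against an arbitrary single-input model.

import time

import numpy as np
import tensorflow as tf  # TF2

# Placeholder model path and thread count; any .tflite model works here.
interpreter = tf.lite.Interpreter(model_path='model.tflite', num_threads=4)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed zeros shaped and typed like the model's first input.
dummy = np.zeros(input_details[0]['shape'], dtype=input_details[0]['dtype'])
interpreter.set_tensor(input_details[0]['index'], dummy)

start_time = time.time()
interpreter.invoke()
stop_time = time.time()

output_data = interpreter.get_tensor(output_details[0]['index'])
print('time: {:.3f}ms'.format((stop_time - start_time) * 1000))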


@@ -172,7 +172,8 @@ class Interpreter(object):
   def __init__(self,
                model_path=None,
                model_content=None,
-               experimental_delegates=None):
+               experimental_delegates=None,
+               num_threads=None):
     """Constructor.

     Args:
@@ -181,6 +182,10 @@ class Interpreter(object):
       experimental_delegates: Experimental. Subject to change. List of
         [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
         objects returned by lite.load_delegate().
+      num_threads: Sets the number of threads used by the interpreter and
+        available to CPU kernels. If not set, the interpreter will use an
+        implementation-dependent default number of threads. Currently,
+        only a subset of kernels, such as conv, support multi-threading.

     Raises:
       ValueError: If the interpreter was unable to create.
@@ -206,6 +211,13 @@ class Interpreter(object):
     else:
       raise ValueError('Can\'t both provide `model_path` and `model_content`')

+    if num_threads is not None:
+      if not isinstance(num_threads, int):
+        raise ValueError('type of num_threads should be int')
+      if num_threads < 1:
+        raise ValueError('num_threads should >= 1')
+      self._interpreter.SetNumThreads(num_threads)
+
     # Each delegate is a wrapper that owns the delegates that have been loaded
     # as plugins. The interpreter wrapper will be using them, but we need to
     # hold them in a list so that the lifetime is preserved at least as long as
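From the caller's side, the new argument is validated before it is forwarded to the wrapper's SetNumThreads call. A short hedged sketch of that behavior (the model path is again a placeholder):

import tensorflow as tf

# An int >= 1 is accepted and forwarded to the native SetNumThreads().
interpreter = tf.lite.Interpreter(model_path='model.tflite', num_threads=2)

# A non-int value is rejected by the type check...
try:
    tf.lite.Interpreter(model_path='model.tflite', num_threads=2.5)
except ValueError as e:
    print(e)  # type of num_threads should be int

# ...and an int below 1 is rejected by the range check.
try:
    tf.lite.Interpreter(model_path='model.tflite', num_threads=0)
except ValueError as e:
    print(e)  # num_threads should >= 1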


@@ -68,6 +68,20 @@ class InterpreterTest(test_util.TensorFlowTestCase):
     self.assertAllEqual(zero_points, params['zero_points'])
     self.assertEqual(quantized_dimension, params['quantized_dimension'])

+  def testThreads_NegativeValue(self):
+    with self.assertRaisesRegexp(ValueError,
+                                 'num_threads should >= 1'):
+      interpreter_wrapper.Interpreter(
+          model_path=resource_loader.get_path_to_datafile(
+              'testdata/permute_float.tflite'), num_threads=-1)
+
+  def testThreads_WrongType(self):
+    with self.assertRaisesRegexp(ValueError,
+                                 'type of num_threads should be int'):
+      interpreter_wrapper.Interpreter(
+          model_path=resource_loader.get_path_to_datafile(
+              'testdata/permute_float.tflite'), num_threads=4.2)
+
   def testFloat(self):
     interpreter = interpreter_wrapper.Interpreter(
         model_path=resource_loader.get_path_to_datafile(
@@ -100,6 +114,22 @@ class InterpreterTest(test_util.TensorFlowTestCase):
     output_data = interpreter.get_tensor(output_details[0]['index'])
     self.assertTrue((expected_output == output_data).all())

+  def testFloatWithTwoThreads(self):
+    interpreter = interpreter_wrapper.Interpreter(
+        model_path=resource_loader.get_path_to_datafile(
+            'testdata/permute_float.tflite'), num_threads=2)
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
+    expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
+    interpreter.set_tensor(input_details[0]['index'], test_input)
+    interpreter.invoke()
+
+    output_details = interpreter.get_output_details()
+    output_data = interpreter.get_tensor(output_details[0]['index'])
+    self.assertTrue((expected_output == output_data).all())
+
   def testUint8(self):
     model_path = resource_loader.get_path_to_datafile(
         'testdata/permute_uint8.tflite')


@@ -706,6 +706,12 @@ PyObject* InterpreterWrapper::ResetVariableTensors() {
   Py_RETURN_NONE;
 }

+PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
+  TFLITE_PY_ENSURE_VALID_INTERPRETER();
+  interpreter_->SetNumThreads(num_threads);
+  Py_RETURN_NONE;
+}
+
 PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
     TfLiteDelegate* delegate) {
   TFLITE_PY_ENSURE_VALID_INTERPRETER();


@@ -87,6 +87,8 @@ class InterpreterWrapper {
   // should be the interpreter object providing the memory.
   PyObject* tensor(PyObject* base_object, int i);

+  PyObject* SetNumThreads(int num_threads);
+
   // Adds a delegate to the interpreter.
   PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate);


@@ -145,5 +145,13 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
           },
          R"pbdoc(
            Adds a delegate to the interpreter.
-          )pbdoc");
+          )pbdoc")
+      .def(
+          "SetNumThreads",
+          [](InterpreterWrapper& self, int num_threads) {
+            return tensorflow::PyoOrThrow(self.SetNumThreads(num_threads));
+          },
+          R"pbdoc(
+            ask the interpreter to set the number of threads to use.
+          )pbdoc");
 }


@@ -4,7 +4,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\', \'num_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "allocate_tensors"


@@ -4,7 +4,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\', \'num_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "allocate_tensors"