Merge pull request #25748 from freedomtan:export_SetNumThreads_to_tflite_python
PiperOrigin-RevId: 316137187 Change-Id: I12729a4f760b786a57c36a481261f6232f7d3fee
This commit is contained in:
commit
ad4363bdcf
|
@ -19,11 +19,10 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import tensorflow as tf # TF2
|
||||
|
||||
|
||||
|
@ -57,9 +56,12 @@ if __name__ == '__main__':
|
|||
'--input_std',
|
||||
default=127.5, type=float,
|
||||
help='input standard deviation')
|
||||
parser.add_argument(
|
||||
'--num_threads', default=None, type=int, help='number of threads')
|
||||
args = parser.parse_args()
|
||||
|
||||
interpreter = tf.lite.Interpreter(model_path=args.model_file)
|
||||
interpreter = tf.lite.Interpreter(
|
||||
model_path=args.model_file, num_threads=args.num_threads)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
|
@ -81,7 +83,9 @@ if __name__ == '__main__':
|
|||
|
||||
interpreter.set_tensor(input_details[0]['index'], input_data)
|
||||
|
||||
start_time = time.time()
|
||||
interpreter.invoke()
|
||||
stop_time = time.time()
|
||||
|
||||
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||
results = np.squeeze(output_data)
|
||||
|
@ -93,3 +97,5 @@ if __name__ == '__main__':
|
|||
print('{:08.6f}: {}'.format(float(results[i]), labels[i]))
|
||||
else:
|
||||
print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i]))
|
||||
|
||||
print('time: {:.3f}ms'.format((stop_time - start_time) * 1000))
|
||||
|
|
|
@ -172,7 +172,8 @@ class Interpreter(object):
|
|||
def __init__(self,
|
||||
model_path=None,
|
||||
model_content=None,
|
||||
experimental_delegates=None):
|
||||
experimental_delegates=None,
|
||||
num_threads=None):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
|
@ -181,6 +182,10 @@ class Interpreter(object):
|
|||
experimental_delegates: Experimental. Subject to change. List of
|
||||
[TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
|
||||
objects returned by lite.load_delegate().
|
||||
num_threads: Sets the number of threads used by the interpreter and
|
||||
available to CPU kernels. If not set, the interpreter will use an
|
||||
implementation-dependent default number of threads. Currently,
|
||||
only a subset of kernels, such as conv, support multi-threading.
|
||||
|
||||
Raises:
|
||||
ValueError: If the interpreter was unable to create.
|
||||
|
@ -206,6 +211,13 @@ class Interpreter(object):
|
|||
else:
|
||||
raise ValueError('Can\'t both provide `model_path` and `model_content`')
|
||||
|
||||
if num_threads is not None:
|
||||
if not isinstance(num_threads, int):
|
||||
raise ValueError('type of num_threads should be int')
|
||||
if num_threads < 1:
|
||||
raise ValueError('num_threads should >= 1')
|
||||
self._interpreter.SetNumThreads(num_threads)
|
||||
|
||||
# Each delegate is a wrapper that owns the delegates that have been loaded
|
||||
# as plugins. The interpreter wrapper will be using them, but we need to
|
||||
# hold them in a list so that the lifetime is preserved at least as long as
|
||||
|
|
|
@ -68,6 +68,20 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
|||
self.assertAllEqual(zero_points, params['zero_points'])
|
||||
self.assertEqual(quantized_dimension, params['quantized_dimension'])
|
||||
|
||||
def testThreads_NegativeValue(self):
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'num_threads should >= 1'):
|
||||
interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'), num_threads=-1)
|
||||
|
||||
def testThreads_WrongType(self):
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'type of num_threads should be int'):
|
||||
interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'), num_threads=4.2)
|
||||
|
||||
def testFloat(self):
|
||||
interpreter = interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
|
@ -100,6 +114,22 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
|||
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||
self.assertTrue((expected_output == output_data).all())
|
||||
|
||||
def testFloatWithTwoThreads(self):
|
||||
interpreter = interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'), num_threads=2)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
|
||||
expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
|
||||
interpreter.set_tensor(input_details[0]['index'], test_input)
|
||||
interpreter.invoke()
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||
self.assertTrue((expected_output == output_data).all())
|
||||
|
||||
def testUint8(self):
|
||||
model_path = resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_uint8.tflite')
|
||||
|
|
|
@ -706,6 +706,12 @@ PyObject* InterpreterWrapper::ResetVariableTensors() {
|
|||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
|
||||
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||
interpreter_->SetNumThreads(num_threads);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
|
||||
TfLiteDelegate* delegate) {
|
||||
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||
|
|
|
@ -87,6 +87,8 @@ class InterpreterWrapper {
|
|||
// should be the interpreter object providing the memory.
|
||||
PyObject* tensor(PyObject* base_object, int i);
|
||||
|
||||
PyObject* SetNumThreads(int num_threads);
|
||||
|
||||
// Adds a delegate to the interpreter.
|
||||
PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate);
|
||||
|
||||
|
|
|
@ -145,5 +145,13 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
|
|||
},
|
||||
R"pbdoc(
|
||||
Adds a delegate to the interpreter.
|
||||
)pbdoc")
|
||||
.def(
|
||||
"SetNumThreads",
|
||||
[](InterpreterWrapper& self, int num_threads) {
|
||||
return tensorflow::PyoOrThrow(self.SetNumThreads(num_threads));
|
||||
},
|
||||
R"pbdoc(
|
||||
ask the interpreter to set the number of threads to use.
|
||||
)pbdoc");
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ tf_class {
|
|||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\', \'num_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "allocate_tensors"
|
||||
|
|
|
@ -4,7 +4,7 @@ tf_class {
|
|||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\', \'num_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "allocate_tensors"
|
||||
|
|
Loading…
Reference in New Issue