Merge pull request #25748 from freedomtan:export_SetNumThreads_to_tflite_python
PiperOrigin-RevId: 316137187 Change-Id: I12729a4f760b786a57c36a481261f6232f7d3fee
This commit is contained in:
commit
ad4363bdcf
|
@ -19,11 +19,10 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import tensorflow as tf # TF2
|
import tensorflow as tf # TF2
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,9 +56,12 @@ if __name__ == '__main__':
|
||||||
'--input_std',
|
'--input_std',
|
||||||
default=127.5, type=float,
|
default=127.5, type=float,
|
||||||
help='input standard deviation')
|
help='input standard deviation')
|
||||||
|
parser.add_argument(
|
||||||
|
'--num_threads', default=None, type=int, help='number of threads')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
interpreter = tf.lite.Interpreter(model_path=args.model_file)
|
interpreter = tf.lite.Interpreter(
|
||||||
|
model_path=args.model_file, num_threads=args.num_threads)
|
||||||
interpreter.allocate_tensors()
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
input_details = interpreter.get_input_details()
|
input_details = interpreter.get_input_details()
|
||||||
|
@ -81,7 +83,9 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
interpreter.set_tensor(input_details[0]['index'], input_data)
|
interpreter.set_tensor(input_details[0]['index'], input_data)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
interpreter.invoke()
|
interpreter.invoke()
|
||||||
|
stop_time = time.time()
|
||||||
|
|
||||||
output_data = interpreter.get_tensor(output_details[0]['index'])
|
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||||
results = np.squeeze(output_data)
|
results = np.squeeze(output_data)
|
||||||
|
@ -93,3 +97,5 @@ if __name__ == '__main__':
|
||||||
print('{:08.6f}: {}'.format(float(results[i]), labels[i]))
|
print('{:08.6f}: {}'.format(float(results[i]), labels[i]))
|
||||||
else:
|
else:
|
||||||
print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i]))
|
print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i]))
|
||||||
|
|
||||||
|
print('time: {:.3f}ms'.format((stop_time - start_time) * 1000))
|
||||||
|
|
|
@ -172,7 +172,8 @@ class Interpreter(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_path=None,
|
model_path=None,
|
||||||
model_content=None,
|
model_content=None,
|
||||||
experimental_delegates=None):
|
experimental_delegates=None,
|
||||||
|
num_threads=None):
|
||||||
"""Constructor.
|
"""Constructor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -181,6 +182,10 @@ class Interpreter(object):
|
||||||
experimental_delegates: Experimental. Subject to change. List of
|
experimental_delegates: Experimental. Subject to change. List of
|
||||||
[TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
|
[TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
|
||||||
objects returned by lite.load_delegate().
|
objects returned by lite.load_delegate().
|
||||||
|
num_threads: Sets the number of threads used by the interpreter and
|
||||||
|
available to CPU kernels. If not set, the interpreter will use an
|
||||||
|
implementation-dependent default number of threads. Currently,
|
||||||
|
only a subset of kernels, such as conv, support multi-threading.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the interpreter was unable to create.
|
ValueError: If the interpreter was unable to create.
|
||||||
|
@ -206,6 +211,13 @@ class Interpreter(object):
|
||||||
else:
|
else:
|
||||||
raise ValueError('Can\'t both provide `model_path` and `model_content`')
|
raise ValueError('Can\'t both provide `model_path` and `model_content`')
|
||||||
|
|
||||||
|
if num_threads is not None:
|
||||||
|
if not isinstance(num_threads, int):
|
||||||
|
raise ValueError('type of num_threads should be int')
|
||||||
|
if num_threads < 1:
|
||||||
|
raise ValueError('num_threads should >= 1')
|
||||||
|
self._interpreter.SetNumThreads(num_threads)
|
||||||
|
|
||||||
# Each delegate is a wrapper that owns the delegates that have been loaded
|
# Each delegate is a wrapper that owns the delegates that have been loaded
|
||||||
# as plugins. The interpreter wrapper will be using them, but we need to
|
# as plugins. The interpreter wrapper will be using them, but we need to
|
||||||
# hold them in a list so that the lifetime is preserved at least as long as
|
# hold them in a list so that the lifetime is preserved at least as long as
|
||||||
|
|
|
@ -68,6 +68,20 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||||
self.assertAllEqual(zero_points, params['zero_points'])
|
self.assertAllEqual(zero_points, params['zero_points'])
|
||||||
self.assertEqual(quantized_dimension, params['quantized_dimension'])
|
self.assertEqual(quantized_dimension, params['quantized_dimension'])
|
||||||
|
|
||||||
|
def testThreads_NegativeValue(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
'num_threads should >= 1'):
|
||||||
|
interpreter_wrapper.Interpreter(
|
||||||
|
model_path=resource_loader.get_path_to_datafile(
|
||||||
|
'testdata/permute_float.tflite'), num_threads=-1)
|
||||||
|
|
||||||
|
def testThreads_WrongType(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError,
|
||||||
|
'type of num_threads should be int'):
|
||||||
|
interpreter_wrapper.Interpreter(
|
||||||
|
model_path=resource_loader.get_path_to_datafile(
|
||||||
|
'testdata/permute_float.tflite'), num_threads=4.2)
|
||||||
|
|
||||||
def testFloat(self):
|
def testFloat(self):
|
||||||
interpreter = interpreter_wrapper.Interpreter(
|
interpreter = interpreter_wrapper.Interpreter(
|
||||||
model_path=resource_loader.get_path_to_datafile(
|
model_path=resource_loader.get_path_to_datafile(
|
||||||
|
@ -100,6 +114,22 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||||
output_data = interpreter.get_tensor(output_details[0]['index'])
|
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||||
self.assertTrue((expected_output == output_data).all())
|
self.assertTrue((expected_output == output_data).all())
|
||||||
|
|
||||||
|
def testFloatWithTwoThreads(self):
|
||||||
|
interpreter = interpreter_wrapper.Interpreter(
|
||||||
|
model_path=resource_loader.get_path_to_datafile(
|
||||||
|
'testdata/permute_float.tflite'), num_threads=2)
|
||||||
|
interpreter.allocate_tensors()
|
||||||
|
|
||||||
|
input_details = interpreter.get_input_details()
|
||||||
|
test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
|
||||||
|
expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
|
||||||
|
interpreter.set_tensor(input_details[0]['index'], test_input)
|
||||||
|
interpreter.invoke()
|
||||||
|
|
||||||
|
output_details = interpreter.get_output_details()
|
||||||
|
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||||
|
self.assertTrue((expected_output == output_data).all())
|
||||||
|
|
||||||
def testUint8(self):
|
def testUint8(self):
|
||||||
model_path = resource_loader.get_path_to_datafile(
|
model_path = resource_loader.get_path_to_datafile(
|
||||||
'testdata/permute_uint8.tflite')
|
'testdata/permute_uint8.tflite')
|
||||||
|
|
|
@ -706,6 +706,12 @@ PyObject* InterpreterWrapper::ResetVariableTensors() {
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
|
||||||
|
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||||
|
interpreter_->SetNumThreads(num_threads);
|
||||||
|
Py_RETURN_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
|
PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
|
||||||
TfLiteDelegate* delegate) {
|
TfLiteDelegate* delegate) {
|
||||||
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||||
|
|
|
@ -87,6 +87,8 @@ class InterpreterWrapper {
|
||||||
// should be the interpreter object providing the memory.
|
// should be the interpreter object providing the memory.
|
||||||
PyObject* tensor(PyObject* base_object, int i);
|
PyObject* tensor(PyObject* base_object, int i);
|
||||||
|
|
||||||
|
PyObject* SetNumThreads(int num_threads);
|
||||||
|
|
||||||
// Adds a delegate to the interpreter.
|
// Adds a delegate to the interpreter.
|
||||||
PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate);
|
PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate);
|
||||||
|
|
||||||
|
|
|
@ -145,5 +145,13 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
|
||||||
},
|
},
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Adds a delegate to the interpreter.
|
Adds a delegate to the interpreter.
|
||||||
|
)pbdoc")
|
||||||
|
.def(
|
||||||
|
"SetNumThreads",
|
||||||
|
[](InterpreterWrapper& self, int num_threads) {
|
||||||
|
return tensorflow::PyoOrThrow(self.SetNumThreads(num_threads));
|
||||||
|
},
|
||||||
|
R"pbdoc(
|
||||||
|
ask the interpreter to set the number of threads to use.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ tf_class {
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\', \'num_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "allocate_tensors"
|
name: "allocate_tensors"
|
||||||
|
|
|
@ -4,7 +4,7 @@ tf_class {
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\', \'num_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "allocate_tensors"
|
name: "allocate_tensors"
|
||||||
|
|
Loading…
Reference in New Issue