address review commments

This commit is contained in:
Koan-Sin Tan 2020-05-20 16:26:15 +08:00
parent 969b77defb
commit 30e5e29d48
5 changed files with 9 additions and 7 deletions

View File

@ -57,13 +57,13 @@ if __name__ == '__main__':
help='input standard deviation') help='input standard deviation')
parser.add_argument( parser.add_argument(
'--num_threads', '--num_threads',
default=1, default=1, type=int,
help='number of threads') help='number of threads')
args = parser.parse_args() args = parser.parse_args()
interpreter = tf.lite.Interpreter( interpreter = tf.lite.Interpreter(
model_path=args.model_file, model_path=args.model_file,
num_threads=int(args.num_threads)) num_threads=args.num_threads)
interpreter.allocate_tensors() interpreter.allocate_tensors()
input_details = interpreter.get_input_details() input_details = interpreter.get_input_details()
@ -100,4 +100,5 @@ if __name__ == '__main__':
else: else:
print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i])) print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i]))
print("time: ", stop_time - start_time) #print("time: ", stop_time - start_time)
print('time: {:.3f}ms'.format((stop_time - start_time) * 1000))

View File

@ -523,6 +523,7 @@ class Interpreter(object):
def reset_all_variables(self): def reset_all_variables(self):
return self._interpreter.ResetVariableTensors() return self._interpreter.ResetVariableTensors()
class InterpreterWithCustomOps(Interpreter): class InterpreterWithCustomOps(Interpreter):
"""Interpreter interface for TensorFlow Lite Models that accepts custom ops. """Interpreter interface for TensorFlow Lite Models that accepts custom ops.

View File

@ -706,9 +706,9 @@ PyObject* InterpreterWrapper::ResetVariableTensors() {
Py_RETURN_NONE; Py_RETURN_NONE;
} }
PyObject* InterpreterWrapper::SetNumThreads(int i) { PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
TFLITE_PY_ENSURE_VALID_INTERPRETER(); TFLITE_PY_ENSURE_VALID_INTERPRETER();
interpreter_->SetNumThreads(i); interpreter_->SetNumThreads(num_threads);
Py_RETURN_NONE; Py_RETURN_NONE;
} }

View File

@ -87,7 +87,7 @@ class InterpreterWrapper {
// should be the interpreter object providing the memory. // should be the interpreter object providing the memory.
PyObject* tensor(PyObject* base_object, int i); PyObject* tensor(PyObject* base_object, int i);
PyObject* SetNumThreads(int i); PyObject* SetNumThreads(int num_threads);
// Adds a delegate to the interpreter. // Adds a delegate to the interpreter.
PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate); PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate);

View File

@ -149,7 +149,7 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
.def( .def(
"SetNumThreads", "SetNumThreads",
[](InterpreterWrapper& self, int i) { [](InterpreterWrapper& self, int i) {
return tensorflow::pyo_or_throw(self.SetNumThreads(i)); return tensorflow::PyoOrThrow(self.SetNumThreads(i));
}, },
R"pbdoc( R"pbdoc(
ask the interpreter to set the number of threads to use. ask the interpreter to set the number of threads to use.