address review commments

This commit is contained in:
Koan-Sin Tan 2020-05-20 16:26:15 +08:00
parent 969b77defb
commit 30e5e29d48
5 changed files with 9 additions and 7 deletions

View File

@ -57,13 +57,13 @@ if __name__ == '__main__':
help='input standard deviation')
parser.add_argument(
'--num_threads',
default=1,
default=1, type=int,
help='number of threads')
args = parser.parse_args()
interpreter = tf.lite.Interpreter(
model_path=args.model_file,
num_threads=int(args.num_threads))
num_threads=args.num_threads)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
@ -100,4 +100,5 @@ if __name__ == '__main__':
else:
print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i]))
print("time: ", stop_time - start_time)
#print("time: ", stop_time - start_time)
print('time: {:.3f}ms'.format((stop_time - start_time) * 1000))

View File

@ -523,6 +523,7 @@ class Interpreter(object):
def reset_all_variables(self):
return self._interpreter.ResetVariableTensors()
class InterpreterWithCustomOps(Interpreter):
"""Interpreter interface for TensorFlow Lite Models that accepts custom ops.

View File

@ -706,9 +706,9 @@ PyObject* InterpreterWrapper::ResetVariableTensors() {
Py_RETURN_NONE;
}
PyObject* InterpreterWrapper::SetNumThreads(int i) {
PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
interpreter_->SetNumThreads(i);
interpreter_->SetNumThreads(num_threads);
Py_RETURN_NONE;
}

View File

@ -87,7 +87,7 @@ class InterpreterWrapper {
// should be the interpreter object providing the memory.
PyObject* tensor(PyObject* base_object, int i);
PyObject* SetNumThreads(int i);
PyObject* SetNumThreads(int num_threads);
// Adds a delegate to the interpreter.
PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate);

View File

@ -149,7 +149,7 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
.def(
"SetNumThreads",
[](InterpreterWrapper& self, int i) {
return tensorflow::pyo_or_throw(self.SetNumThreads(i));
return tensorflow::PyoOrThrow(self.SetNumThreads(i));
},
R"pbdoc(
ask the interpreter to set the number of threads to use.