From 30e5e29d484ac8cfa196ae037546d01748a8b56f Mon Sep 17 00:00:00 2001 From: Koan-Sin Tan Date: Wed, 20 May 2020 16:26:15 +0800 Subject: [PATCH] address review commments --- tensorflow/lite/examples/python/label_image.py | 7 ++++--- tensorflow/lite/python/interpreter.py | 1 + .../lite/python/interpreter_wrapper/interpreter_wrapper.cc | 4 ++-- .../lite/python/interpreter_wrapper/interpreter_wrapper.h | 2 +- .../interpreter_wrapper/interpreter_wrapper_pybind11.cc | 2 +- 5 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tensorflow/lite/examples/python/label_image.py b/tensorflow/lite/examples/python/label_image.py index c6e0fbdb2bd..9d3c47fc4c2 100644 --- a/tensorflow/lite/examples/python/label_image.py +++ b/tensorflow/lite/examples/python/label_image.py @@ -57,13 +57,13 @@ if __name__ == '__main__': help='input standard deviation') parser.add_argument( '--num_threads', - default=1, + default=1, type=int, help='number of threads') args = parser.parse_args() interpreter = tf.lite.Interpreter( model_path=args.model_file, - num_threads=int(args.num_threads)) + num_threads=args.num_threads) interpreter.allocate_tensors() input_details = interpreter.get_input_details() @@ -100,4 +100,5 @@ if __name__ == '__main__': else: print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i])) - print("time: ", stop_time - start_time) + #print("time: ", stop_time - start_time) + print('time: {:.3f}ms'.format((stop_time - start_time) * 1000)) diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index 4c2528000da..36ea2e26943 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -523,6 +523,7 @@ class Interpreter(object): def reset_all_variables(self): return self._interpreter.ResetVariableTensors() + class InterpreterWithCustomOps(Interpreter): """Interpreter interface for TensorFlow Lite Models that accepts custom ops. diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index c457e68c91b..2a8c1ffdcd6 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -706,9 +706,9 @@ PyObject* InterpreterWrapper::ResetVariableTensors() { Py_RETURN_NONE; } -PyObject* InterpreterWrapper::SetNumThreads(int i) { +PyObject* InterpreterWrapper::SetNumThreads(int num_threads) { TFLITE_PY_ENSURE_VALID_INTERPRETER(); - interpreter_->SetNumThreads(i); + interpreter_->SetNumThreads(num_threads); Py_RETURN_NONE; } diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h index d7141189319..b799a3067f6 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -87,7 +87,7 @@ class InterpreterWrapper { // should be the interpreter object providing the memory. PyObject* tensor(PyObject* base_object, int i); - PyObject* SetNumThreads(int i); + PyObject* SetNumThreads(int num_threads); // Adds a delegate to the interpreter. PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate); diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc index 55c377c2bf1..74bbf6fdedd 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc @@ -149,7 +149,7 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) { .def( "SetNumThreads", [](InterpreterWrapper& self, int i) { - return tensorflow::pyo_or_throw(self.SetNumThreads(i)); + return tensorflow::PyoOrThrow(self.SetNumThreads(i)); }, R"pbdoc( ask the interpreter to set the number of threads to use.