address review commments
This commit is contained in:
parent
969b77defb
commit
30e5e29d48
@ -57,13 +57,13 @@ if __name__ == '__main__':
|
||||
help='input standard deviation')
|
||||
parser.add_argument(
|
||||
'--num_threads',
|
||||
default=1,
|
||||
default=1, type=int,
|
||||
help='number of threads')
|
||||
args = parser.parse_args()
|
||||
|
||||
interpreter = tf.lite.Interpreter(
|
||||
model_path=args.model_file,
|
||||
num_threads=int(args.num_threads))
|
||||
num_threads=args.num_threads)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
@ -100,4 +100,5 @@ if __name__ == '__main__':
|
||||
else:
|
||||
print('{:08.6f}: {}'.format(float(results[i] / 255.0), labels[i]))
|
||||
|
||||
print("time: ", stop_time - start_time)
|
||||
#print("time: ", stop_time - start_time)
|
||||
print('time: {:.3f}ms'.format((stop_time - start_time) * 1000))
|
||||
|
@ -523,6 +523,7 @@ class Interpreter(object):
|
||||
def reset_all_variables(self):
|
||||
return self._interpreter.ResetVariableTensors()
|
||||
|
||||
|
||||
class InterpreterWithCustomOps(Interpreter):
|
||||
"""Interpreter interface for TensorFlow Lite Models that accepts custom ops.
|
||||
|
||||
|
@ -706,9 +706,9 @@ PyObject* InterpreterWrapper::ResetVariableTensors() {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::SetNumThreads(int i) {
|
||||
PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
|
||||
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||
interpreter_->SetNumThreads(i);
|
||||
interpreter_->SetNumThreads(num_threads);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
|
@ -87,7 +87,7 @@ class InterpreterWrapper {
|
||||
// should be the interpreter object providing the memory.
|
||||
PyObject* tensor(PyObject* base_object, int i);
|
||||
|
||||
PyObject* SetNumThreads(int i);
|
||||
PyObject* SetNumThreads(int num_threads);
|
||||
|
||||
// Adds a delegate to the interpreter.
|
||||
PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate);
|
||||
|
@ -149,7 +149,7 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
|
||||
.def(
|
||||
"SetNumThreads",
|
||||
[](InterpreterWrapper& self, int i) {
|
||||
return tensorflow::pyo_or_throw(self.SetNumThreads(i));
|
||||
return tensorflow::PyoOrThrow(self.SetNumThreads(i));
|
||||
},
|
||||
R"pbdoc(
|
||||
ask the interpreter to set the number of threads to use.
|
||||
|
Loading…
Reference in New Issue
Block a user