Add Python delegate interface to interpreter.
PiperOrigin-RevId: 252130897
This commit is contained in:
parent
6b9defc3e8
commit
860fc4df25
@ -21,7 +21,10 @@ py_library(
|
||||
py_test(
|
||||
name = "interpreter_test",
|
||||
srcs = ["interpreter_test.py"],
|
||||
data = ["//tensorflow/lite/python/testdata:interpreter_test_data"],
|
||||
data = [
|
||||
"//tensorflow/lite/python/testdata:interpreter_test_data",
|
||||
"//tensorflow/lite/python/testdata:test_delegate.so",
|
||||
],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ctypes
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
@ -47,6 +48,76 @@ except ImportError:
|
||||
_tf_export = tf_export_dummy
|
||||
|
||||
|
||||
class Delegate(object):
|
||||
"""Python wrapper class to manage TfLiteDelegate objects.
|
||||
|
||||
Attributes:
|
||||
library: Name of shared library containing the delegate with two functions:
|
||||
TfLiteDelegate* tflite_plugin_create_delegate (char **, char **, int) void
|
||||
tflite_plugin_destroy_delegate (TfLiteDelegate *)
|
||||
options: Dictionary of options that are required to load the delegate. All
|
||||
keys and values in the dictionary should be serializable. Consult the
|
||||
documentation of the specific delegate for required and legal options.
|
||||
(default None)
|
||||
"""
|
||||
|
||||
def __init__(self, library, options=None):
|
||||
self._library = ctypes.pydll.LoadLibrary(library)
|
||||
self._library.tflite_plugin_create_delegate.argtypes = [
|
||||
ctypes.POINTER(ctypes.c_char_p),
|
||||
ctypes.POINTER(ctypes.c_char_p), ctypes.c_int
|
||||
]
|
||||
self._library.tflite_plugin_create_delegate.restype = ctypes.c_void_p
|
||||
|
||||
# Convert the options from a dictionary to lists of char pointers.
|
||||
options = options or {}
|
||||
options_keys = (ctypes.c_char_p * len(options))()
|
||||
options_values = (ctypes.c_char_p * len(options))()
|
||||
for idx, (key, value) in enumerate(options.items()):
|
||||
options_keys[idx] = str(key)
|
||||
options_values[idx] = str(value)
|
||||
|
||||
# Do not make a copy of _delegate_ptr. It is freed by Delegate's finalizer.
|
||||
self._delegate_ptr = self._library.tflite_plugin_create_delegate(
|
||||
options_keys, options_values, len(options))
|
||||
|
||||
def __del__(self):
|
||||
self._library.tflite_plugin_destroy_delegate.argtypes = [ctypes.c_void_p]
|
||||
self._library.tflite_plugin_destroy_delegate(self._delegate_ptr)
|
||||
|
||||
def _get_native_delegate_pointer(self):
|
||||
"""Returns the native TfLiteDelegate pointer.
|
||||
|
||||
It is not safe to copy this pointer because it needs to be freed.
|
||||
|
||||
Returns:
|
||||
TfLiteDelegate *
|
||||
"""
|
||||
return self._delegate_ptr
|
||||
|
||||
|
||||
@_tf_export('lite.experimental.load_delegate')
|
||||
def load_delegate(library, options=None):
|
||||
"""Returns a Delegate object.
|
||||
|
||||
The `library` is expected to have two functions:
|
||||
TfLiteDelegate* tflite_plugin_create_delegate (char **, char **, int)
|
||||
void tflite_plugin_destroy_delegate (TfLiteDelegate *)
|
||||
|
||||
Args:
|
||||
library: Name of shared library containing the
|
||||
[TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates).
|
||||
options: Dictionary of options that are required to load the delegate. All
|
||||
keys and values in the dictionary should be serializable. Consult the
|
||||
documentation of the specific delegate for required and legal options.
|
||||
(default None)
|
||||
|
||||
Returns:
|
||||
Delegate object.
|
||||
"""
|
||||
return Delegate(library, options)
|
||||
|
||||
|
||||
@_tf_export('lite.Interpreter')
|
||||
class Interpreter(object):
|
||||
"""Interpreter interface for TensorFlow Lite Models.
|
||||
@ -61,12 +132,19 @@ class Interpreter(object):
|
||||
you must use a synchronization primitive between the threads to ensure invoke
|
||||
has returned before calling tensor().
|
||||
"""
|
||||
def __init__(self, model_path=None, model_content=None):
|
||||
|
||||
def __init__(self,
|
||||
model_path=None,
|
||||
model_content=None,
|
||||
experimental_delegates=None):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
model_path: Path to TF-Lite Flatbuffer file.
|
||||
model_content: Content of model.
|
||||
experimental_delegates: Experimental. Subject to change. List of
|
||||
[TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
|
||||
objects returned by lite.load_delegate().
|
||||
|
||||
Raises:
|
||||
ValueError: If the interpreter was unable to create.
|
||||
@ -90,6 +168,17 @@ class Interpreter(object):
|
||||
else:
|
||||
raise ValueError('Can\'t both provide `model_path` and `model_content`')
|
||||
|
||||
# Each delegate is a wrapper that owns the delegates that have been loaded
|
||||
# as plugins. The interpreter wrapper will be using them, but we need to
|
||||
# hold them in a list so that the lifetime is preserved at least as long as
|
||||
# the interpreter wrapper.
|
||||
self._delegates = []
|
||||
if experimental_delegates:
|
||||
self._delegates = experimental_delegates
|
||||
for delegate in self._delegates:
|
||||
self._interpreter.ModifyGraphWithDelegate(
|
||||
delegate._get_native_delegate_pointer()) # pylint: disable=protected-access
|
||||
|
||||
def allocate_tensors(self):
|
||||
self._ensure_safe()
|
||||
return self._interpreter.AllocateTensors()
|
||||
|
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ctypes
|
||||
import io
|
||||
import numpy as np
|
||||
import six
|
||||
@ -158,7 +159,7 @@ class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'))
|
||||
interpreter.allocate_tensors()
|
||||
#Invalid tensor index passed.
|
||||
# Invalid tensor index passed.
|
||||
with self.assertRaisesRegexp(ValueError, 'Tensor with no shape found.'):
|
||||
interpreter._get_tensor_details(4)
|
||||
|
||||
@ -219,5 +220,104 @@ class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase):
|
||||
_ = self.interpreter.allocate_tensors()
|
||||
del in0safe # make sure in0Safe is held but lint doesn't complain
|
||||
|
||||
|
||||
class InterpreterDelegateTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._delegate_file = resource_loader.get_path_to_datafile(
|
||||
'testdata/test_delegate.so')
|
||||
self._model_file = resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite')
|
||||
|
||||
# Load the library to reset the counters.
|
||||
library = ctypes.pydll.LoadLibrary(self._delegate_file)
|
||||
library.initialize_counters()
|
||||
|
||||
def _TestInterpreter(self, model_path, options=None):
|
||||
"""Test wrapper function that creates an interpreter with the delegate."""
|
||||
delegate = interpreter_wrapper.load_delegate(self._delegate_file, options)
|
||||
return interpreter_wrapper.Interpreter(
|
||||
model_path=model_path, experimental_delegates=[delegate])
|
||||
|
||||
def testDelegate(self):
|
||||
"""Tests the delegate creation and destruction."""
|
||||
interpreter = self._TestInterpreter(model_path=self._model_file)
|
||||
lib = interpreter._delegates[0]._library
|
||||
|
||||
self.assertEqual(lib.get_num_delegates_created(), 1)
|
||||
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
|
||||
self.assertEqual(lib.get_num_delegates_invoked(), 1)
|
||||
|
||||
del interpreter
|
||||
|
||||
self.assertEqual(lib.get_num_delegates_created(), 1)
|
||||
self.assertEqual(lib.get_num_delegates_destroyed(), 1)
|
||||
self.assertEqual(lib.get_num_delegates_invoked(), 1)
|
||||
|
||||
def testMultipleInterpreters(self):
|
||||
delegate = interpreter_wrapper.load_delegate(self._delegate_file)
|
||||
lib = delegate._library
|
||||
|
||||
self.assertEqual(lib.get_num_delegates_created(), 1)
|
||||
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
|
||||
self.assertEqual(lib.get_num_delegates_invoked(), 0)
|
||||
|
||||
interpreter_a = interpreter_wrapper.Interpreter(
|
||||
model_path=self._model_file, experimental_delegates=[delegate])
|
||||
|
||||
self.assertEqual(lib.get_num_delegates_created(), 1)
|
||||
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
|
||||
self.assertEqual(lib.get_num_delegates_invoked(), 1)
|
||||
|
||||
interpreter_b = interpreter_wrapper.Interpreter(
|
||||
model_path=self._model_file, experimental_delegates=[delegate])
|
||||
|
||||
self.assertEqual(lib.get_num_delegates_created(), 1)
|
||||
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
|
||||
self.assertEqual(lib.get_num_delegates_invoked(), 2)
|
||||
|
||||
del delegate
|
||||
del interpreter_a
|
||||
|
||||
self.assertEqual(lib.get_num_delegates_created(), 1)
|
||||
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
|
||||
self.assertEqual(lib.get_num_delegates_invoked(), 2)
|
||||
|
||||
del interpreter_b
|
||||
|
||||
self.assertEqual(lib.get_num_delegates_created(), 1)
|
||||
self.assertEqual(lib.get_num_delegates_destroyed(), 1)
|
||||
self.assertEqual(lib.get_num_delegates_invoked(), 2)
|
||||
|
||||
def testOptions(self):
|
||||
delegate_a = interpreter_wrapper.load_delegate(self._delegate_file)
|
||||
lib = delegate_a._library
|
||||
|
||||
self.assertEqual(lib.get_num_delegates_created(), 1)
|
||||
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
|
||||
self.assertEqual(lib.get_num_delegates_invoked(), 0)
|
||||
self.assertEqual(lib.get_options_counter(), 0)
|
||||
|
||||
delegate_b = interpreter_wrapper.load_delegate(
|
||||
self._delegate_file, options={
|
||||
'unused': False,
|
||||
'options_counter': 2
|
||||
})
|
||||
lib = delegate_b._library
|
||||
|
||||
self.assertEqual(lib.get_num_delegates_created(), 2)
|
||||
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
|
||||
self.assertEqual(lib.get_num_delegates_invoked(), 0)
|
||||
self.assertEqual(lib.get_options_counter(), 2)
|
||||
|
||||
del delegate_a
|
||||
del delegate_b
|
||||
|
||||
self.assertEqual(lib.get_num_delegates_created(), 2)
|
||||
self.assertEqual(lib.get_num_delegates_destroyed(), 2)
|
||||
self.assertEqual(lib.get_num_delegates_invoked(), 0)
|
||||
self.assertEqual(lib.get_options_counter(), 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -25,6 +25,7 @@ cc_library(
|
||||
":python_utils",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//third_party/python_runtime:headers",
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
@ -446,5 +447,12 @@ PyObject* InterpreterWrapper::ResetVariableTensors() {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
|
||||
TfLiteDelegate* delegate) {
|
||||
TFLITE_PY_ENSURE_VALID_INTERPRETER();
|
||||
TFLITE_PY_CHECK(interpreter_->ModifyGraphWithDelegate(delegate));
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
} // namespace interpreter_wrapper
|
||||
} // namespace tflite
|
||||
|
@ -26,6 +26,9 @@ limitations under the License.
|
||||
// automatically move <Python.h> before <locale>.
|
||||
#include <Python.h>
|
||||
|
||||
struct _TfLiteDelegate;
|
||||
typedef struct _TfLiteDelegate TfLiteDelegate;
|
||||
|
||||
// We forward declare TFLite classes here to avoid exposing them to SWIG.
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -72,6 +75,9 @@ class InterpreterWrapper {
|
||||
// should be the interpreter object providing the memory.
|
||||
PyObject* tensor(PyObject* base_object, int i);
|
||||
|
||||
// Adds a delegate to the interpreter.
|
||||
PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate);
|
||||
|
||||
private:
|
||||
// Helper function to construct an `InterpreterWrapper` object.
|
||||
// It only returns InterpreterWrapper if it can construct an `Interpreter`.
|
||||
|
@ -25,6 +25,14 @@ limitations under the License.
|
||||
%}
|
||||
|
||||
|
||||
%typemap(in) TfLiteDelegate* {
|
||||
auto pointer_as_int = PyInt_AsLong($input);
|
||||
static_assert(sizeof(pointer_as_int)==sizeof(TfLiteDelegate*),
|
||||
"TFLiteDelegate must be representable as a long.");
|
||||
$1 = reinterpret_cast<TfLiteDelegate*>(pointer_as_int);
|
||||
}
|
||||
|
||||
|
||||
%include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
|
||||
|
||||
namespace tflite {
|
||||
|
@ -39,6 +39,7 @@ from tensorflow.lite.python.convert import toco_convert_impl as _toco_convert_im
|
||||
from tensorflow.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
|
||||
from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
|
||||
from tensorflow.lite.python.optimize import calibrator as _calibrator
|
||||
|
20
tensorflow/lite/python/testdata/BUILD
vendored
20
tensorflow/lite/python/testdata/BUILD
vendored
@ -52,3 +52,23 @@ filegroup(
|
||||
],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test_delegate",
|
||||
testonly = 1,
|
||||
srcs = ["test_delegate.cc"],
|
||||
visibility = ["//tensorflow/lite:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "test_delegate.so",
|
||||
testonly = 1,
|
||||
linkshared = 1,
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":test_delegate",
|
||||
],
|
||||
)
|
||||
|
77
tensorflow/lite/python/testdata/test_delegate.cc
vendored
Normal file
77
tensorflow/lite/python/testdata/test_delegate.cc
vendored
Normal file
@ -0,0 +1,77 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace {
|
||||
|
||||
int num_delegates_created = 0;
|
||||
int num_delegates_destroyed = 0;
|
||||
int num_delegates_invoked = 0;
|
||||
int options_counter = 0;
|
||||
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
TfLiteDelegate* tflite_plugin_create_delegate(char** options_keys,
|
||||
char** options_values,
|
||||
size_t num_options) {
|
||||
num_delegates_created++;
|
||||
|
||||
for (int idx = 0; idx < num_options; idx++) {
|
||||
if (std::strncmp("options_counter", options_keys[idx], 15) == 0) {
|
||||
int int_value;
|
||||
if (sscanf(options_values[idx], "%d", &int_value) == 1) {
|
||||
options_counter += int_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteDelegate* ptr = new TfLiteDelegate;
|
||||
ptr->Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
num_delegates_invoked++;
|
||||
return kTfLiteOk;
|
||||
};
|
||||
ptr->flags = kTfLiteDelegateFlagsNone;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
void tflite_plugin_destroy_delegate(TfLiteDelegate* delegate) {
|
||||
num_delegates_destroyed++;
|
||||
delete delegate;
|
||||
}
|
||||
|
||||
void initialize_counters() {
|
||||
num_delegates_created = 0;
|
||||
num_delegates_destroyed = 0;
|
||||
num_delegates_invoked = 0;
|
||||
options_counter = 0;
|
||||
}
|
||||
|
||||
int get_num_delegates_created() { return num_delegates_created; }
|
||||
|
||||
int get_num_delegates_destroyed() { return num_delegates_destroyed; }
|
||||
|
||||
int get_num_delegates_invoked() { return num_delegates_invoked; }
|
||||
|
||||
int get_options_counter() { return options_counter; }
|
||||
}
|
||||
|
||||
} // namespace tflite
|
@ -30,6 +30,7 @@ TENSORFLOW_API_INIT_FILES = [
|
||||
"queue/__init__.py",
|
||||
"linalg/__init__.py",
|
||||
"lite/__init__.py",
|
||||
"lite/experimental/__init__.py",
|
||||
"lookup/__init__.py",
|
||||
"lookup/experimental/__init__.py",
|
||||
"math/__init__.py",
|
||||
|
@ -4,7 +4,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'model_path\', \'model_content\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "allocate_tensors"
|
||||
|
@ -12,4 +12,8 @@ tf_module {
|
||||
name: "get_potentially_supported_ops"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "load_delegate"
|
||||
argspec: "args=[\'library\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ tf_class {
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'model_path\', \'model_content\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "allocate_tensors"
|
||||
|
@ -0,0 +1,7 @@
|
||||
path: "tensorflow.lite.experimental"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "load_delegate"
|
||||
argspec: "args=[\'library\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
}
|
@ -24,4 +24,8 @@ tf_module {
|
||||
name: "TargetSpec"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user