Add Python delegate interface to interpreter.

PiperOrigin-RevId: 252130897
This commit is contained in:
Nupur Garg 2019-06-07 15:25:52 -07:00 committed by TensorFlower Gardener
parent 6b9defc3e8
commit 860fc4df25
16 changed files with 334 additions and 5 deletions

View File

@ -21,7 +21,10 @@ py_library(
py_test(
name = "interpreter_test",
srcs = ["interpreter_test.py"],
data = ["//tensorflow/lite/python/testdata:interpreter_test_data"],
data = [
"//tensorflow/lite/python/testdata:interpreter_test_data",
"//tensorflow/lite/python/testdata:test_delegate.so",
],
python_version = "PY2",
srcs_version = "PY2AND3",
tags = [

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ctypes
import sys
import numpy as np
@ -47,6 +48,76 @@ except ImportError:
_tf_export = tf_export_dummy
class Delegate(object):
"""Python wrapper class to manage TfLiteDelegate objects.
Attributes:
library: Name of shared library containing the delegate with two functions:
TfLiteDelegate* tflite_plugin_create_delegate (char **, char **, int) void
tflite_plugin_destroy_delegate (TfLiteDelegate *)
options: Dictionary of options that are required to load the delegate. All
keys and values in the dictionary should be serializable. Consult the
documentation of the specific delegate for required and legal options.
(default None)
"""
def __init__(self, library, options=None):
self._library = ctypes.pydll.LoadLibrary(library)
self._library.tflite_plugin_create_delegate.argtypes = [
ctypes.POINTER(ctypes.c_char_p),
ctypes.POINTER(ctypes.c_char_p), ctypes.c_int
]
self._library.tflite_plugin_create_delegate.restype = ctypes.c_void_p
# Convert the options from a dictionary to lists of char pointers.
options = options or {}
options_keys = (ctypes.c_char_p * len(options))()
options_values = (ctypes.c_char_p * len(options))()
for idx, (key, value) in enumerate(options.items()):
options_keys[idx] = str(key)
options_values[idx] = str(value)
# Do not make a copy of _delegate_ptr. It is freed by Delegate's finalizer.
self._delegate_ptr = self._library.tflite_plugin_create_delegate(
options_keys, options_values, len(options))
def __del__(self):
self._library.tflite_plugin_destroy_delegate.argtypes = [ctypes.c_void_p]
self._library.tflite_plugin_destroy_delegate(self._delegate_ptr)
def _get_native_delegate_pointer(self):
"""Returns the native TfLiteDelegate pointer.
It is not safe to copy this pointer because it needs to be freed.
Returns:
TfLiteDelegate *
"""
return self._delegate_ptr
@_tf_export('lite.experimental.load_delegate')
def load_delegate(library, options=None):
"""Returns a Delegate object.
The `library` is expected to have two functions:
TfLiteDelegate* tflite_plugin_create_delegate (char **, char **, int)
void tflite_plugin_destroy_delegate (TfLiteDelegate *)
Args:
library: Name of shared library containing the
[TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates).
options: Dictionary of options that are required to load the delegate. All
keys and values in the dictionary should be serializable. Consult the
documentation of the specific delegate for required and legal options.
(default None)
Returns:
Delegate object.
"""
return Delegate(library, options)
@_tf_export('lite.Interpreter')
class Interpreter(object):
"""Interpreter interface for TensorFlow Lite Models.
@ -61,12 +132,19 @@ class Interpreter(object):
you must use a synchronization primitive between the threads to ensure invoke
has returned before calling tensor().
"""
def __init__(self, model_path=None, model_content=None):
def __init__(self,
model_path=None,
model_content=None,
experimental_delegates=None):
"""Constructor.
Args:
model_path: Path to TF-Lite Flatbuffer file.
model_content: Content of model.
experimental_delegates: Experimental. Subject to change. List of
[TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
objects returned by lite.load_delegate().
Raises:
ValueError: If the interpreter was unable to create.
@ -90,6 +168,17 @@ class Interpreter(object):
else:
raise ValueError('Can\'t both provide `model_path` and `model_content`')
# Each delegate is a wrapper that owns the delegates that have been loaded
# as plugins. The interpreter wrapper will be using them, but we need to
# hold them in a list so that the lifetime is preserved at least as long as
# the interpreter wrapper.
self._delegates = []
if experimental_delegates:
self._delegates = experimental_delegates
for delegate in self._delegates:
self._interpreter.ModifyGraphWithDelegate(
delegate._get_native_delegate_pointer()) # pylint: disable=protected-access
def allocate_tensors(self):
self._ensure_safe()
return self._interpreter.AllocateTensors()

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ctypes
import io
import numpy as np
import six
@ -158,7 +159,7 @@ class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'))
interpreter.allocate_tensors()
#Invalid tensor index passed.
# Invalid tensor index passed.
with self.assertRaisesRegexp(ValueError, 'Tensor with no shape found.'):
interpreter._get_tensor_details(4)
@ -219,5 +220,104 @@ class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase):
_ = self.interpreter.allocate_tensors()
del in0safe # make sure in0Safe is held but lint doesn't complain
class InterpreterDelegateTest(test_util.TensorFlowTestCase):
def setUp(self):
self._delegate_file = resource_loader.get_path_to_datafile(
'testdata/test_delegate.so')
self._model_file = resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite')
# Load the library to reset the counters.
library = ctypes.pydll.LoadLibrary(self._delegate_file)
library.initialize_counters()
def _TestInterpreter(self, model_path, options=None):
"""Test wrapper function that creates an interpreter with the delegate."""
delegate = interpreter_wrapper.load_delegate(self._delegate_file, options)
return interpreter_wrapper.Interpreter(
model_path=model_path, experimental_delegates=[delegate])
def testDelegate(self):
"""Tests the delegate creation and destruction."""
interpreter = self._TestInterpreter(model_path=self._model_file)
lib = interpreter._delegates[0]._library
self.assertEqual(lib.get_num_delegates_created(), 1)
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
self.assertEqual(lib.get_num_delegates_invoked(), 1)
del interpreter
self.assertEqual(lib.get_num_delegates_created(), 1)
self.assertEqual(lib.get_num_delegates_destroyed(), 1)
self.assertEqual(lib.get_num_delegates_invoked(), 1)
def testMultipleInterpreters(self):
delegate = interpreter_wrapper.load_delegate(self._delegate_file)
lib = delegate._library
self.assertEqual(lib.get_num_delegates_created(), 1)
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
self.assertEqual(lib.get_num_delegates_invoked(), 0)
interpreter_a = interpreter_wrapper.Interpreter(
model_path=self._model_file, experimental_delegates=[delegate])
self.assertEqual(lib.get_num_delegates_created(), 1)
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
self.assertEqual(lib.get_num_delegates_invoked(), 1)
interpreter_b = interpreter_wrapper.Interpreter(
model_path=self._model_file, experimental_delegates=[delegate])
self.assertEqual(lib.get_num_delegates_created(), 1)
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
self.assertEqual(lib.get_num_delegates_invoked(), 2)
del delegate
del interpreter_a
self.assertEqual(lib.get_num_delegates_created(), 1)
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
self.assertEqual(lib.get_num_delegates_invoked(), 2)
del interpreter_b
self.assertEqual(lib.get_num_delegates_created(), 1)
self.assertEqual(lib.get_num_delegates_destroyed(), 1)
self.assertEqual(lib.get_num_delegates_invoked(), 2)
def testOptions(self):
delegate_a = interpreter_wrapper.load_delegate(self._delegate_file)
lib = delegate_a._library
self.assertEqual(lib.get_num_delegates_created(), 1)
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
self.assertEqual(lib.get_num_delegates_invoked(), 0)
self.assertEqual(lib.get_options_counter(), 0)
delegate_b = interpreter_wrapper.load_delegate(
self._delegate_file, options={
'unused': False,
'options_counter': 2
})
lib = delegate_b._library
self.assertEqual(lib.get_num_delegates_created(), 2)
self.assertEqual(lib.get_num_delegates_destroyed(), 0)
self.assertEqual(lib.get_num_delegates_invoked(), 0)
self.assertEqual(lib.get_options_counter(), 2)
del delegate_a
del delegate_b
self.assertEqual(lib.get_num_delegates_created(), 2)
self.assertEqual(lib.get_num_delegates_destroyed(), 2)
self.assertEqual(lib.get_num_delegates_invoked(), 0)
self.assertEqual(lib.get_options_counter(), 2)
if __name__ == '__main__':
test.main()

View File

@ -25,6 +25,7 @@ cc_library(
":python_utils",
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:builtin_ops",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include "absl/memory/memory.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
@ -446,5 +447,12 @@ PyObject* InterpreterWrapper::ResetVariableTensors() {
Py_RETURN_NONE;
}
PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
TfLiteDelegate* delegate) {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
TFLITE_PY_CHECK(interpreter_->ModifyGraphWithDelegate(delegate));
Py_RETURN_NONE;
}
} // namespace interpreter_wrapper
} // namespace tflite

View File

@ -26,6 +26,9 @@ limitations under the License.
// automatically move <Python.h> before <locale>.
#include <Python.h>
struct _TfLiteDelegate;
typedef struct _TfLiteDelegate TfLiteDelegate;
// We forward declare TFLite classes here to avoid exposing them to SWIG.
namespace tflite {
namespace ops {
@ -72,6 +75,9 @@ class InterpreterWrapper {
// should be the interpreter object providing the memory.
PyObject* tensor(PyObject* base_object, int i);
// Adds a delegate to the interpreter.
PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate);
private:
// Helper function to construct an `InterpreterWrapper` object.
// It only returns InterpreterWrapper if it can construct an `Interpreter`.

View File

@ -25,6 +25,14 @@ limitations under the License.
%}
%typemap(in) TfLiteDelegate* {
auto pointer_as_int = PyInt_AsLong($input);
static_assert(sizeof(pointer_as_int)==sizeof(TfLiteDelegate*),
"TFLiteDelegate must be representable as a long.");
$1 = reinterpret_cast<TfLiteDelegate*>(pointer_as_int);
}
%include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
namespace tflite {

View File

@ -39,6 +39,7 @@ from tensorflow.lite.python.convert import toco_convert_impl as _toco_convert_im
from tensorflow.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import
from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
from tensorflow.lite.python.optimize import calibrator as _calibrator

View File

@ -52,3 +52,23 @@ filegroup(
],
visibility = ["//tensorflow:__subpackages__"],
)
cc_library(
name = "test_delegate",
testonly = 1,
srcs = ["test_delegate.cc"],
visibility = ["//tensorflow/lite:__subpackages__"],
deps = [
"//tensorflow/lite/c:c_api_internal",
],
)
cc_binary(
name = "test_delegate.so",
testonly = 1,
linkshared = 1,
linkstatic = 1,
deps = [
":test_delegate",
],
)

View File

@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "tensorflow/lite/c/c_api_internal.h"
namespace tflite {
namespace {
int num_delegates_created = 0;
int num_delegates_destroyed = 0;
int num_delegates_invoked = 0;
int options_counter = 0;
} // namespace
extern "C" {
TfLiteDelegate* tflite_plugin_create_delegate(char** options_keys,
char** options_values,
size_t num_options) {
num_delegates_created++;
for (int idx = 0; idx < num_options; idx++) {
if (std::strncmp("options_counter", options_keys[idx], 15) == 0) {
int int_value;
if (sscanf(options_values[idx], "%d", &int_value) == 1) {
options_counter += int_value;
}
}
}
TfLiteDelegate* ptr = new TfLiteDelegate;
ptr->Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) {
num_delegates_invoked++;
return kTfLiteOk;
};
ptr->flags = kTfLiteDelegateFlagsNone;
return ptr;
}
void tflite_plugin_destroy_delegate(TfLiteDelegate* delegate) {
num_delegates_destroyed++;
delete delegate;
}
void initialize_counters() {
num_delegates_created = 0;
num_delegates_destroyed = 0;
num_delegates_invoked = 0;
options_counter = 0;
}
int get_num_delegates_created() { return num_delegates_created; }
int get_num_delegates_destroyed() { return num_delegates_destroyed; }
int get_num_delegates_invoked() { return num_delegates_invoked; }
int get_options_counter() { return options_counter; }
}
} // namespace tflite

View File

@ -30,6 +30,7 @@ TENSORFLOW_API_INIT_FILES = [
"queue/__init__.py",
"linalg/__init__.py",
"lite/__init__.py",
"lite/experimental/__init__.py",
"lookup/__init__.py",
"lookup/experimental/__init__.py",
"math/__init__.py",

View File

@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'model_path\', \'model_content\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "allocate_tensors"

View File

@ -12,4 +12,8 @@ tf_module {
name: "get_potentially_supported_ops"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "load_delegate"
argspec: "args=[\'library\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'model_path\', \'model_content\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'self\', \'model_path\', \'model_content\', \'experimental_delegates\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "allocate_tensors"

View File

@ -0,0 +1,7 @@
path: "tensorflow.lite.experimental"
tf_module {
member_method {
name: "load_delegate"
argspec: "args=[\'library\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -24,4 +24,8 @@ tf_module {
name: "TargetSpec"
mtype: "<type \'type\'>"
}
member {
name: "experimental"
mtype: "<type \'module\'>"
}
}