To run a specific SignatureDef, use get_signature_runner(...) to get a SignatureRunner for running inference. The returned SignatureRunner is a callable object and can be called to invoke inference. Example:

  my_signature = interpreter.get_signature_runner("my_method")
  results = my_signature(input_1=input_tensor_1, input_2=input_tensor_2)
  print(results["my_output"])

To get details about the available signatures, use interpreter.get_signature_list(). Example:

  signatures = interpreter.get_signature_list()
  print(signatures)

PiperOrigin-RevId: 345583496
Change-Id: Ie2a0a7e4e5676f06e98c82247cf4327534ce308e
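For reference, a minimal end-to-end sketch of the new API; the model path "model.tflite", the signature name "my_method", and the tensor names and shapes below are illustrative assumptions rather than part of this change:

  # Sketch only: assumes model.tflite defines a SignatureDef "my_method"
  # with inputs "input_1"/"input_2" and an output "my_output".
  import numpy as np
  import tensorflow as tf

  interpreter = tf.lite.Interpreter(model_path="model.tflite")
  # List the available signatures and their input/output names.
  print(interpreter.get_signature_list())

  # Get a callable runner for one signature and invoke it with keyword
  # arguments named after the SignatureDef inputs.
  my_signature = interpreter.get_signature_runner("my_method")
  results = my_signature(
      input_1=np.zeros((1, 4), dtype=np.float32),
      input_2=np.zeros((1, 4), dtype=np.float32))
  print(results["my_output"])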
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_

#include <functional>
#include <memory>
#include <string>
#include <vector>

// Place `<locale>` before <Python.h> to avoid build failures on macOS.
#include <locale>

// The empty line above is on purpose as otherwise clang-format will
// automatically move <Python.h> before <locale>.
#include <Python.h>

#include "tensorflow/lite/interpreter.h"

struct TfLiteDelegate;

// We forward declare TFLite classes here to avoid exposing them to SWIG.
namespace tflite {
namespace ops {
namespace builtin {
class BuiltinOpResolver;
}  // namespace builtin
}  // namespace ops

class FlatBufferModel;

namespace interpreter_wrapper {

class PythonErrorReporter;

class InterpreterWrapper {
 public:
  using Model = FlatBufferModel;

  // SWIG caller takes ownership of pointer.
  static InterpreterWrapper* CreateWrapperCPPFromFile(
      const char* model_path, const std::vector<std::string>& registerers,
      std::string* error_msg);
  static InterpreterWrapper* CreateWrapperCPPFromFile(
      const char* model_path,
      const std::vector<std::string>& registerers_by_name,
      const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
      std::string* error_msg);

  // SWIG caller takes ownership of pointer.
  static InterpreterWrapper* CreateWrapperCPPFromBuffer(
      PyObject* data, const std::vector<std::string>& registerers,
      std::string* error_msg);
  static InterpreterWrapper* CreateWrapperCPPFromBuffer(
      PyObject* data, const std::vector<std::string>& registerers_by_name,
      const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
      std::string* error_msg);

  ~InterpreterWrapper();
  PyObject* AllocateTensors();
  PyObject* Invoke();

  PyObject* InputIndices() const;
  PyObject* OutputIndices() const;
  PyObject* ResizeInputTensor(int i, PyObject* value, bool strict);

  int NumTensors() const;
  std::string TensorName(int i) const;
  PyObject* TensorType(int i) const;
  PyObject* TensorSize(int i) const;
  PyObject* TensorSizeSignature(int i) const;
  PyObject* TensorSparsityParameters(int i) const;
  // Deprecated in favor of TensorQuantizationScales, below.
  PyObject* TensorQuantization(int i) const;
  PyObject* TensorQuantizationParameters(int i) const;
  PyObject* SetTensor(int i, PyObject* value);
  PyObject* GetTensor(int i) const;
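  // Sets the input tensor named `input_name` in the SignatureDef identified
  // by `method_name` from the numpy array passed in `value`.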
  PyObject* SetInputTensorFromSignatureDefName(const char* input_name,
                                               const char* method_name,
                                               PyObject* value);
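  // Returns the value of the output tensor named `output_name` in the
  // SignatureDef identified by `method_name`.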
  PyObject* GetOutputTensorFromSignatureDefName(const char* output_name,
                                                const char* method_name) const;
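  // Returns a map of the model's SignatureDefs; this backs the Python
  // Interpreter.get_signature_list() API.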
  PyObject* GetSignatureDefs() const;
  PyObject* ResetVariableTensors();

  int NumNodes() const;
  std::string NodeName(int i) const;
  PyObject* NodeInputs(int i) const;
  PyObject* NodeOutputs(int i) const;

  // Returns a reference to tensor index i as a numpy array. The base_object
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int num_threads);

  // Adds a delegate to the interpreter.
  PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate);

  // Experimental and subject to change.
  //
  // Returns a pointer to the underlying interpreter.
  Interpreter* interpreter() { return interpreter_.get(); }

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.
  // Otherwise it returns `nullptr`.
  static InterpreterWrapper* CreateInterpreterWrapper(
      std::unique_ptr<Model> model,
      std::unique_ptr<PythonErrorReporter> error_reporter,
      const std::vector<std::string>& registerers_by_name,
      const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
      std::string* error_msg);

  InterpreterWrapper(
      std::unique_ptr<Model> model,
      std::unique_ptr<PythonErrorReporter> error_reporter,
      std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
      std::unique_ptr<Interpreter> interpreter);

  // InterpreterWrapper is not copyable or assignable. We avoid the use of
  // InterpreterWrapper() = delete here for SWIG compatibility.
  InterpreterWrapper();
  InterpreterWrapper(const InterpreterWrapper& rhs);

  // Helper function to resize an input tensor.
  PyObject* ResizeInputTensorImpl(int i, PyObject* value);

  // The public functions that create `InterpreterWrapper` should ensure all
  // these member variables are initialized successfully. Otherwise they
  // should report the error and return `nullptr`.
  const std::unique_ptr<Model> model_;
  const std::unique_ptr<PythonErrorReporter> error_reporter_;
  const std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver_;
  const std::unique_ptr<Interpreter> interpreter_;
};

}  // namespace interpreter_wrapper
}  // namespace tflite

#endif  // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_