Add Interpreter changes for SignatureDef support.
This change includes updates to InterpreterBuilder to use SignatureDef available in the tflite file. Also, updates Interpreter API to - List all signatures available - Fetch Inputs/Outputs in single signature - Fetch Input/Output tensor using name defined in SignatureDef. PiperOrigin-RevId: 338711676 Change-Id: I70355ece46295cec57cc2e3732309ad3e62f8708
This commit is contained in:
parent
fba2679b56
commit
92bcc8d011
@ -22,7 +22,9 @@ limitations under the License.
|
|||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/lite/allocation.h"
|
#include "tensorflow/lite/allocation.h"
|
||||||
@ -275,6 +277,70 @@ class Interpreter {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// WARNING: Experimental interface, subject to change
|
||||||
|
/// Returns list of all names of different method signatures defined
|
||||||
|
/// in the model.
|
||||||
|
/// Note, pointers returned have lifetime same as the Interpreter object.
|
||||||
|
std::vector<const std::string*> signature_def_names() const {
|
||||||
|
std::vector<const std::string*> method_names;
|
||||||
|
method_names.reserve(signature_defs_.size());
|
||||||
|
for (const auto& sig_def : signature_defs_) {
|
||||||
|
method_names.emplace_back(&sig_def.method_name);
|
||||||
|
}
|
||||||
|
return method_names;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// WARNING: Experimental interface, subject to change
|
||||||
|
/// Returns the mapping of inputs to tensor index in the signature
|
||||||
|
/// specified through 'method_name'.
|
||||||
|
/// If invalid name passed, an empty list will be returned.
|
||||||
|
const std::map<std::string, uint32_t>& signature_inputs(
|
||||||
|
const char* method_name) const {
|
||||||
|
for (const auto& sig_def : signature_defs_) {
|
||||||
|
if (sig_def.method_name == method_name) return sig_def.inputs;
|
||||||
|
}
|
||||||
|
static const std::map<std::string, uint32_t>* default_empty_list =
|
||||||
|
new std::map<std::string, uint32_t>();
|
||||||
|
return *default_empty_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// WARNING: Experimental interface, subject to change
|
||||||
|
/// Returns the mapping of outputs to tensor index in the signature
|
||||||
|
/// specified through 'method_name'.
|
||||||
|
/// If invalid name passed, an empty list will be returned.
|
||||||
|
const std::map<std::string, uint32_t>& signature_outputs(
|
||||||
|
const char* method_name) const {
|
||||||
|
for (const auto& sig_def : signature_defs_) {
|
||||||
|
if (sig_def.method_name == method_name) return sig_def.outputs;
|
||||||
|
}
|
||||||
|
static const std::map<std::string, uint32_t>* default_empty_list =
|
||||||
|
new std::map<std::string, uint32_t>();
|
||||||
|
return *default_empty_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// WARNING: Experimental interface, subject to change
|
||||||
|
/// Returns the input tensor identified by 'signature_input_name' in the
|
||||||
|
/// signature identified by 'signature_method_name'.
|
||||||
|
/// Returns nullptr if not found.
|
||||||
|
TfLiteTensor* input_tensor_by_signature_name(
|
||||||
|
const char* signature_input_name, const char* signature_method_name) {
|
||||||
|
const int tensor_index = GetTensorIndexFromSignatureDefName(
|
||||||
|
signature_input_name, signature_method_name, /*is_input=*/true);
|
||||||
|
return tensor_index == -1 ? nullptr : tensor(tensor_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// WARNING: Experimental interface, subject to change
|
||||||
|
/// Returns the output tensor identified by 'signature_output_name' in the
|
||||||
|
/// signature identified by 'signature_method_name'.
|
||||||
|
/// Returns nullptr if not found.
|
||||||
|
const TfLiteTensor* output_tensor_by_signature_name(
|
||||||
|
const char* signature_output_name,
|
||||||
|
const char* signature_method_name) const {
|
||||||
|
const int tensor_index = GetTensorIndexFromSignatureDefName(
|
||||||
|
signature_output_name, signature_method_name, /*is_input=*/false);
|
||||||
|
return tensor_index == -1 ? nullptr : tensor(tensor_index);
|
||||||
|
}
|
||||||
|
|
||||||
/// Return a mutable pointer to the given input tensor. The given index must
|
/// Return a mutable pointer to the given input tensor. The given index must
|
||||||
/// be between 0 and inputs().size().
|
/// be between 0 and inputs().size().
|
||||||
TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); }
|
TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); }
|
||||||
@ -592,6 +658,17 @@ class Interpreter {
|
|||||||
#endif // DOXYGEN_SKIP
|
#endif // DOXYGEN_SKIP
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// Structure representing SignatureDef inputs/outputs.
|
||||||
|
struct SignatureDef {
|
||||||
|
// Maps name in signature def as key to index of the tensor in the model.
|
||||||
|
std::map<std::string, uint32_t> inputs;
|
||||||
|
// Maps name in signature def as key to index of the tensor in the model.
|
||||||
|
std::map<std::string, uint32_t> outputs;
|
||||||
|
// The method name for this signature.
|
||||||
|
std::string method_name;
|
||||||
|
// The key of this SignatureDef in the SavedModel signature def map.
|
||||||
|
std::string signature_def_key;
|
||||||
|
};
|
||||||
friend class InterpreterBuilder;
|
friend class InterpreterBuilder;
|
||||||
friend class tflite::InterpreterTest;
|
friend class tflite::InterpreterTest;
|
||||||
friend class tflite::TestDelegate;
|
friend class tflite::TestDelegate;
|
||||||
@ -602,6 +679,26 @@ class Interpreter {
|
|||||||
TfLiteExternalContextType type,
|
TfLiteExternalContextType type,
|
||||||
TfLiteExternalContext* ctx);
|
TfLiteExternalContext* ctx);
|
||||||
|
|
||||||
|
// Helper method that return the tensot index that corresponds to
|
||||||
|
// a name in a SignatureDef. Defined by 'signature_method_name', and
|
||||||
|
// 'signature_tensor_name'.
|
||||||
|
// If 'is_input' is true then the tensor is checked in input tensors,
|
||||||
|
// otherwise it will be checked in output tensors.
|
||||||
|
// Returns -1 if the tensor is not found.
|
||||||
|
int GetTensorIndexFromSignatureDefName(const char* signature_tensor_name,
|
||||||
|
const char* signature_method_name,
|
||||||
|
bool is_input) const {
|
||||||
|
// Iterate directly and don't use other methods to avoid extra allocation.
|
||||||
|
for (const auto& signature : signature_defs_) {
|
||||||
|
if (signature.method_name != signature_method_name) continue;
|
||||||
|
auto& signature_list = (is_input ? signature.inputs : signature.outputs);
|
||||||
|
auto tensor_iter = signature_list.find(signature_tensor_name);
|
||||||
|
if (tensor_iter == signature_list.end()) return -1;
|
||||||
|
return tensor_iter->second;
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
// Sets the profiler to all subgraphs.
|
// Sets the profiler to all subgraphs.
|
||||||
void SetSubgraphProfiler();
|
void SetSubgraphProfiler();
|
||||||
|
|
||||||
@ -615,6 +712,11 @@ class Interpreter {
|
|||||||
// Returns true if cancellation function returns true.
|
// Returns true if cancellation function returns true.
|
||||||
bool IsCancelled();
|
bool IsCancelled();
|
||||||
|
|
||||||
|
// Sets the list of signature defs in the model.
|
||||||
|
void SetSignatureDef(std::vector<SignatureDef> signature_defs) {
|
||||||
|
signature_defs_ = std::move(signature_defs);
|
||||||
|
}
|
||||||
|
|
||||||
// A pure C data structure used to communicate with the pure C plugin
|
// A pure C data structure used to communicate with the pure C plugin
|
||||||
// interface. To avoid copying tensor metadata, this is also the definitive
|
// interface. To avoid copying tensor metadata, this is also the definitive
|
||||||
// structure to store tensors.
|
// structure to store tensors.
|
||||||
@ -661,6 +763,10 @@ class Interpreter {
|
|||||||
// An empty one means there's no delegate to be applied by default or
|
// An empty one means there's no delegate to be applied by default or
|
||||||
// delegates have been applied and doesn't need to be applied again.
|
// delegates have been applied and doesn't need to be applied again.
|
||||||
std::vector<TfLiteDelegatePtr> lazy_delegate_providers_;
|
std::vector<TfLiteDelegatePtr> lazy_delegate_providers_;
|
||||||
|
|
||||||
|
// List of signature def mapping inputs/output to tensor ids.
|
||||||
|
// We just keep track of tensor index.
|
||||||
|
std::vector<SignatureDef> signature_defs_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|||||||
@ -21,6 +21,9 @@ limitations under the License.
|
|||||||
#include <sys/stat.h>
|
#include <sys/stat.h>
|
||||||
#include <sys/types.h>
|
#include <sys/types.h>
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "tensorflow/lite/allocation.h"
|
#include "tensorflow/lite/allocation.h"
|
||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
@ -103,6 +106,20 @@ TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src,
|
|||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper that returns std::map that corresponds to vector of TensorMap.
|
||||||
|
std::map<std::string, uint32_t> GetMapFromTensorMap(
|
||||||
|
const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMap>>*
|
||||||
|
tensor_map) {
|
||||||
|
if (!tensor_map) return {};
|
||||||
|
std::map<std::string, uint32_t> result;
|
||||||
|
for (const auto tensor : *tensor_map) {
|
||||||
|
if (tensor != nullptr && tensor->name() != nullptr) {
|
||||||
|
result[tensor->name()->c_str()] = tensor->tensor_index();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
const char* kEmptyTensorName = "";
|
const char* kEmptyTensorName = "";
|
||||||
@ -435,6 +452,48 @@ TfLiteStatus InterpreterBuilder::ParseSparsity(
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TfLiteStatus InterpreterBuilder::ParseSignatureDefs(
|
||||||
|
const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
|
||||||
|
signature_def_list,
|
||||||
|
Interpreter* interpreter) {
|
||||||
|
if (signature_def_list == nullptr || signature_def_list->size() == 0) {
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
std::vector<Interpreter::SignatureDef> signature_defs;
|
||||||
|
signature_defs.reserve(signature_def_list->size());
|
||||||
|
for (const auto fb_signature_def : *signature_def_list) {
|
||||||
|
if (fb_signature_def == nullptr) {
|
||||||
|
TF_LITE_REPORT_ERROR(error_reporter_, "NULL SignatureDef in the model.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
if (fb_signature_def->method_name() == nullptr) {
|
||||||
|
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||||
|
"Missing exported method name for SignatureDef");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
if (fb_signature_def->inputs() == nullptr) {
|
||||||
|
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||||
|
"NULL SignatureDef inputs for exported method %s",
|
||||||
|
fb_signature_def->method_name()->c_str());
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
if (fb_signature_def->outputs() == nullptr) {
|
||||||
|
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||||
|
"NULL SignatureDef outputs for exported method %s",
|
||||||
|
fb_signature_def->method_name()->c_str());
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
signature_defs.resize(signature_defs.size() + 1);
|
||||||
|
auto& signature_def = signature_defs.back();
|
||||||
|
signature_def.inputs = GetMapFromTensorMap(fb_signature_def->inputs());
|
||||||
|
signature_def.outputs = GetMapFromTensorMap(fb_signature_def->outputs());
|
||||||
|
signature_def.method_name = fb_signature_def->method_name()->c_str();
|
||||||
|
signature_def.signature_def_key = fb_signature_def->key()->c_str();
|
||||||
|
}
|
||||||
|
interpreter->SetSignatureDef(std::move(signature_defs));
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
TfLiteStatus InterpreterBuilder::ParseTensors(
|
TfLiteStatus InterpreterBuilder::ParseTensors(
|
||||||
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
|
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
|
||||||
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
|
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
|
||||||
@ -667,6 +726,11 @@ TfLiteStatus InterpreterBuilder::operator()(
|
|||||||
modified_subgraph->SetVariables(std::move(variables));
|
modified_subgraph->SetVariables(std::move(variables));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ParseSignatureDefs(model_->signature_defs(), interpreter->get()) !=
|
||||||
|
kTfLiteOk) {
|
||||||
|
return cleanup_and_error();
|
||||||
|
}
|
||||||
|
|
||||||
if (num_fp32_tensors_ > 0) {
|
if (num_fp32_tensors_ > 0) {
|
||||||
(*interpreter)->lazy_delegate_providers_ =
|
(*interpreter)->lazy_delegate_providers_ =
|
||||||
op_resolver_.GetDelegates(num_threads);
|
op_resolver_.GetDelegates(num_threads);
|
||||||
|
|||||||
@ -80,6 +80,10 @@ class InterpreterBuilder {
|
|||||||
const std::vector<int>& dims);
|
const std::vector<int>& dims);
|
||||||
TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity,
|
TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity,
|
||||||
TfLiteSparsity** sparsity);
|
TfLiteSparsity** sparsity);
|
||||||
|
TfLiteStatus ParseSignatureDefs(
|
||||||
|
const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
|
||||||
|
signature_def_list,
|
||||||
|
Interpreter* interpreter);
|
||||||
|
|
||||||
const ::tflite::Model* model_;
|
const ::tflite::Model* model_;
|
||||||
const OpResolver& op_resolver_;
|
const OpResolver& op_resolver_;
|
||||||
|
|||||||
@ -56,6 +56,17 @@ class InterpreterTest : public ::testing::Test {
|
|||||||
|
|
||||||
bool HasDelegates() { return interpreter_.HasDelegates(); }
|
bool HasDelegates() { return interpreter_.HasDelegates(); }
|
||||||
|
|
||||||
|
void BuildSignature(const std::string& method_name, const std::string& key,
|
||||||
|
const std::map<std::string, uint32_t>& inputs,
|
||||||
|
const std::map<std::string, uint32_t>& outputs) {
|
||||||
|
Interpreter::SignatureDef signature;
|
||||||
|
signature.inputs = inputs;
|
||||||
|
signature.outputs = outputs;
|
||||||
|
signature.method_name = method_name;
|
||||||
|
signature.signature_def_key = key;
|
||||||
|
interpreter_.SetSignatureDef({signature});
|
||||||
|
}
|
||||||
|
|
||||||
Interpreter interpreter_;
|
Interpreter interpreter_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1846,6 +1857,70 @@ TEST_F(TestLazyDelegateProvider, ApplicationSkipped) {
|
|||||||
EXPECT_FALSE(HasDelegates());
|
EXPECT_FALSE(HasDelegates());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(InterpreterTest, SingleSignature_get_signatures) {
|
||||||
|
const char kMethodName[] = "test_method";
|
||||||
|
const char kSignatureDefKey[] = "test_key";
|
||||||
|
BuildSignature(kMethodName, kSignatureDefKey, {{"Input1", 0}, {"Input2", 1}},
|
||||||
|
{{"Output1", 5}});
|
||||||
|
auto results = interpreter_.signature_def_names();
|
||||||
|
ASSERT_EQ(1, results.size());
|
||||||
|
EXPECT_EQ(kMethodName, *results[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(InterpreterTest, SingleSignature_get_inputs) {
|
||||||
|
const char kMethodName[] = "test_method";
|
||||||
|
const char kSignatureDefKey[] = "test_key";
|
||||||
|
const std::map<std::string, uint32_t> inputs = {{"Input1", 0}, {"Input2", 1}};
|
||||||
|
const std::map<std::string, uint32_t> outputs = {{"Output1", 5}};
|
||||||
|
BuildSignature(kMethodName, kSignatureDefKey, inputs, outputs);
|
||||||
|
EXPECT_THAT(interpreter_.signature_inputs(kMethodName), testing::Eq(inputs));
|
||||||
|
EXPECT_THAT(interpreter_.signature_outputs(kMethodName),
|
||||||
|
testing::Eq(outputs));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(InterpreterTest, SingleSignature_validate_get_tensor) {
|
||||||
|
const char kMethodName[] = "test_method";
|
||||||
|
const char kSignatureDefKey[] = "test_key";
|
||||||
|
const std::map<std::string, uint32_t> inputs = {{"Input1", 0}, {"Input2", 1}};
|
||||||
|
const std::map<std::string, uint32_t> outputs = {{"Output1", 5}};
|
||||||
|
|
||||||
|
BuildSignature(kMethodName, kSignatureDefKey, inputs, outputs);
|
||||||
|
ASSERT_EQ(interpreter_.AddTensors(6), kTfLiteOk);
|
||||||
|
ASSERT_EQ(interpreter_.SetInputs({0, 1}), kTfLiteOk);
|
||||||
|
ASSERT_EQ(interpreter_.SetOutputs({5}), kTfLiteOk);
|
||||||
|
ASSERT_EQ(interpreter_.SetTensorParametersReadWrite(
|
||||||
|
0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
|
||||||
|
kTfLiteOk);
|
||||||
|
ASSERT_EQ(interpreter_.SetTensorParametersReadWrite(
|
||||||
|
1, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
|
||||||
|
kTfLiteOk);
|
||||||
|
ASSERT_EQ(interpreter_.ResizeInputTensor(interpreter_.inputs()[0], {1, 2, 3}),
|
||||||
|
kTfLiteOk);
|
||||||
|
ASSERT_EQ(interpreter_.ResizeInputTensor(interpreter_.inputs()[1], {1, 2, 3}),
|
||||||
|
kTfLiteOk);
|
||||||
|
ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
|
||||||
|
|
||||||
|
EXPECT_TRUE(interpreter_.input_tensor_by_signature_name(
|
||||||
|
"Input1", kMethodName) != nullptr);
|
||||||
|
EXPECT_TRUE(interpreter_.input_tensor_by_signature_name(
|
||||||
|
"Input2", kMethodName) != nullptr);
|
||||||
|
EXPECT_TRUE(interpreter_.output_tensor_by_signature_name(
|
||||||
|
"Output1", kMethodName) != nullptr);
|
||||||
|
|
||||||
|
// Invalid tensor
|
||||||
|
EXPECT_EQ(interpreter_.input_tensor_by_signature_name("Input3", kMethodName),
|
||||||
|
nullptr);
|
||||||
|
EXPECT_EQ(interpreter_.output_tensor_by_signature_name("Input3", kMethodName),
|
||||||
|
nullptr);
|
||||||
|
// Invalid method
|
||||||
|
EXPECT_EQ(
|
||||||
|
interpreter_.input_tensor_by_signature_name("Input1", "InvalidMethod"),
|
||||||
|
nullptr);
|
||||||
|
EXPECT_EQ(
|
||||||
|
interpreter_.output_tensor_by_signature_name("Output1", "InvalidMethod"),
|
||||||
|
nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user