Add Interpreter changes for SignatureDef support.
This change includes updates to InterpreterBuilder to use SignatureDef available in the tflite file. Also, updates Interpreter API to - List all signatures available - Fetch Inputs/Outputs in single signature - Fetch Input/Output tensor using name defined in SignatureDef. PiperOrigin-RevId: 338711676 Change-Id: I70355ece46295cec57cc2e3732309ad3e62f8708
This commit is contained in:
parent
fba2679b56
commit
92bcc8d011
@ -22,7 +22,9 @@ limitations under the License.
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/allocation.h"
|
||||
@ -275,6 +277,70 @@ class Interpreter {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// WARNING: Experimental interface, subject to change
|
||||
/// Returns list of all names of different method signatures defined
|
||||
/// in the model.
|
||||
/// Note, pointers returned have lifetime same as the Interpreter object.
|
||||
std::vector<const std::string*> signature_def_names() const {
|
||||
std::vector<const std::string*> method_names;
|
||||
method_names.reserve(signature_defs_.size());
|
||||
for (const auto& sig_def : signature_defs_) {
|
||||
method_names.emplace_back(&sig_def.method_name);
|
||||
}
|
||||
return method_names;
|
||||
}
|
||||
|
||||
/// WARNING: Experimental interface, subject to change
|
||||
/// Returns the mapping of inputs to tensor index in the signature
|
||||
/// specified through 'method_name'.
|
||||
/// If invalid name passed, an empty list will be returned.
|
||||
const std::map<std::string, uint32_t>& signature_inputs(
|
||||
const char* method_name) const {
|
||||
for (const auto& sig_def : signature_defs_) {
|
||||
if (sig_def.method_name == method_name) return sig_def.inputs;
|
||||
}
|
||||
static const std::map<std::string, uint32_t>* default_empty_list =
|
||||
new std::map<std::string, uint32_t>();
|
||||
return *default_empty_list;
|
||||
}
|
||||
|
||||
/// WARNING: Experimental interface, subject to change
|
||||
/// Returns the mapping of outputs to tensor index in the signature
|
||||
/// specified through 'method_name'.
|
||||
/// If invalid name passed, an empty list will be returned.
|
||||
const std::map<std::string, uint32_t>& signature_outputs(
|
||||
const char* method_name) const {
|
||||
for (const auto& sig_def : signature_defs_) {
|
||||
if (sig_def.method_name == method_name) return sig_def.outputs;
|
||||
}
|
||||
static const std::map<std::string, uint32_t>* default_empty_list =
|
||||
new std::map<std::string, uint32_t>();
|
||||
return *default_empty_list;
|
||||
}
|
||||
|
||||
/// WARNING: Experimental interface, subject to change
|
||||
/// Returns the input tensor identified by 'signature_input_name' in the
|
||||
/// signature identified by 'signature_method_name'.
|
||||
/// Returns nullptr if not found.
|
||||
TfLiteTensor* input_tensor_by_signature_name(
|
||||
const char* signature_input_name, const char* signature_method_name) {
|
||||
const int tensor_index = GetTensorIndexFromSignatureDefName(
|
||||
signature_input_name, signature_method_name, /*is_input=*/true);
|
||||
return tensor_index == -1 ? nullptr : tensor(tensor_index);
|
||||
}
|
||||
|
||||
/// WARNING: Experimental interface, subject to change
|
||||
/// Returns the output tensor identified by 'signature_output_name' in the
|
||||
/// signature identified by 'signature_method_name'.
|
||||
/// Returns nullptr if not found.
|
||||
const TfLiteTensor* output_tensor_by_signature_name(
|
||||
const char* signature_output_name,
|
||||
const char* signature_method_name) const {
|
||||
const int tensor_index = GetTensorIndexFromSignatureDefName(
|
||||
signature_output_name, signature_method_name, /*is_input=*/false);
|
||||
return tensor_index == -1 ? nullptr : tensor(tensor_index);
|
||||
}
|
||||
|
||||
/// Return a mutable pointer to the given input tensor. The given index must
|
||||
/// be between 0 and inputs().size().
|
||||
TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); }
|
||||
@ -592,6 +658,17 @@ class Interpreter {
|
||||
#endif // DOXYGEN_SKIP
|
||||
|
||||
private:
|
||||
// Structure representing SignatureDef inputs/outputs.
|
||||
struct SignatureDef {
|
||||
// Maps name in signature def as key to index of the tensor in the model.
|
||||
std::map<std::string, uint32_t> inputs;
|
||||
// Maps name in signature def as key to index of the tensor in the model.
|
||||
std::map<std::string, uint32_t> outputs;
|
||||
// The method name for this signature.
|
||||
std::string method_name;
|
||||
// The key of this SignatureDef in the SavedModel signature def map.
|
||||
std::string signature_def_key;
|
||||
};
|
||||
friend class InterpreterBuilder;
|
||||
friend class tflite::InterpreterTest;
|
||||
friend class tflite::TestDelegate;
|
||||
@ -602,6 +679,26 @@ class Interpreter {
|
||||
TfLiteExternalContextType type,
|
||||
TfLiteExternalContext* ctx);
|
||||
|
||||
// Helper method that return the tensot index that corresponds to
|
||||
// a name in a SignatureDef. Defined by 'signature_method_name', and
|
||||
// 'signature_tensor_name'.
|
||||
// If 'is_input' is true then the tensor is checked in input tensors,
|
||||
// otherwise it will be checked in output tensors.
|
||||
// Returns -1 if the tensor is not found.
|
||||
int GetTensorIndexFromSignatureDefName(const char* signature_tensor_name,
|
||||
const char* signature_method_name,
|
||||
bool is_input) const {
|
||||
// Iterate directly and don't use other methods to avoid extra allocation.
|
||||
for (const auto& signature : signature_defs_) {
|
||||
if (signature.method_name != signature_method_name) continue;
|
||||
auto& signature_list = (is_input ? signature.inputs : signature.outputs);
|
||||
auto tensor_iter = signature_list.find(signature_tensor_name);
|
||||
if (tensor_iter == signature_list.end()) return -1;
|
||||
return tensor_iter->second;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Sets the profiler to all subgraphs.
|
||||
void SetSubgraphProfiler();
|
||||
|
||||
@ -615,6 +712,11 @@ class Interpreter {
|
||||
// Returns true if cancellation function returns true.
|
||||
bool IsCancelled();
|
||||
|
||||
// Sets the list of signature defs in the model.
|
||||
void SetSignatureDef(std::vector<SignatureDef> signature_defs) {
|
||||
signature_defs_ = std::move(signature_defs);
|
||||
}
|
||||
|
||||
// A pure C data structure used to communicate with the pure C plugin
|
||||
// interface. To avoid copying tensor metadata, this is also the definitive
|
||||
// structure to store tensors.
|
||||
@ -661,6 +763,10 @@ class Interpreter {
|
||||
// An empty one means there's no delegate to be applied by default or
|
||||
// delegates have been applied and doesn't need to be applied again.
|
||||
std::vector<TfLiteDelegatePtr> lazy_delegate_providers_;
|
||||
|
||||
// List of signature def mapping inputs/output to tensor ids.
|
||||
// We just keep track of tensor index.
|
||||
std::vector<SignatureDef> signature_defs_;
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -21,6 +21,9 @@ limitations under the License.
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/allocation.h"
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
@ -103,6 +106,20 @@ TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src,
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Helper that returns std::map that corresponds to vector of TensorMap.
|
||||
std::map<std::string, uint32_t> GetMapFromTensorMap(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMap>>*
|
||||
tensor_map) {
|
||||
if (!tensor_map) return {};
|
||||
std::map<std::string, uint32_t> result;
|
||||
for (const auto tensor : *tensor_map) {
|
||||
if (tensor != nullptr && tensor->name() != nullptr) {
|
||||
result[tensor->name()->c_str()] = tensor->tensor_index();
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
const char* kEmptyTensorName = "";
|
||||
@ -435,6 +452,48 @@ TfLiteStatus InterpreterBuilder::ParseSparsity(
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus InterpreterBuilder::ParseSignatureDefs(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
|
||||
signature_def_list,
|
||||
Interpreter* interpreter) {
|
||||
if (signature_def_list == nullptr || signature_def_list->size() == 0) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
std::vector<Interpreter::SignatureDef> signature_defs;
|
||||
signature_defs.reserve(signature_def_list->size());
|
||||
for (const auto fb_signature_def : *signature_def_list) {
|
||||
if (fb_signature_def == nullptr) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter_, "NULL SignatureDef in the model.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (fb_signature_def->method_name() == nullptr) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||
"Missing exported method name for SignatureDef");
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (fb_signature_def->inputs() == nullptr) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||
"NULL SignatureDef inputs for exported method %s",
|
||||
fb_signature_def->method_name()->c_str());
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (fb_signature_def->outputs() == nullptr) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||
"NULL SignatureDef outputs for exported method %s",
|
||||
fb_signature_def->method_name()->c_str());
|
||||
return kTfLiteError;
|
||||
}
|
||||
signature_defs.resize(signature_defs.size() + 1);
|
||||
auto& signature_def = signature_defs.back();
|
||||
signature_def.inputs = GetMapFromTensorMap(fb_signature_def->inputs());
|
||||
signature_def.outputs = GetMapFromTensorMap(fb_signature_def->outputs());
|
||||
signature_def.method_name = fb_signature_def->method_name()->c_str();
|
||||
signature_def.signature_def_key = fb_signature_def->key()->c_str();
|
||||
}
|
||||
interpreter->SetSignatureDef(std::move(signature_defs));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus InterpreterBuilder::ParseTensors(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
|
||||
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
|
||||
@ -667,6 +726,11 @@ TfLiteStatus InterpreterBuilder::operator()(
|
||||
modified_subgraph->SetVariables(std::move(variables));
|
||||
}
|
||||
|
||||
if (ParseSignatureDefs(model_->signature_defs(), interpreter->get()) !=
|
||||
kTfLiteOk) {
|
||||
return cleanup_and_error();
|
||||
}
|
||||
|
||||
if (num_fp32_tensors_ > 0) {
|
||||
(*interpreter)->lazy_delegate_providers_ =
|
||||
op_resolver_.GetDelegates(num_threads);
|
||||
|
@ -80,6 +80,10 @@ class InterpreterBuilder {
|
||||
const std::vector<int>& dims);
|
||||
TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity,
|
||||
TfLiteSparsity** sparsity);
|
||||
TfLiteStatus ParseSignatureDefs(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
|
||||
signature_def_list,
|
||||
Interpreter* interpreter);
|
||||
|
||||
const ::tflite::Model* model_;
|
||||
const OpResolver& op_resolver_;
|
||||
|
@ -56,6 +56,17 @@ class InterpreterTest : public ::testing::Test {
|
||||
|
||||
bool HasDelegates() { return interpreter_.HasDelegates(); }
|
||||
|
||||
void BuildSignature(const std::string& method_name, const std::string& key,
|
||||
const std::map<std::string, uint32_t>& inputs,
|
||||
const std::map<std::string, uint32_t>& outputs) {
|
||||
Interpreter::SignatureDef signature;
|
||||
signature.inputs = inputs;
|
||||
signature.outputs = outputs;
|
||||
signature.method_name = method_name;
|
||||
signature.signature_def_key = key;
|
||||
interpreter_.SetSignatureDef({signature});
|
||||
}
|
||||
|
||||
Interpreter interpreter_;
|
||||
};
|
||||
|
||||
@ -1846,6 +1857,70 @@ TEST_F(TestLazyDelegateProvider, ApplicationSkipped) {
|
||||
EXPECT_FALSE(HasDelegates());
|
||||
}
|
||||
|
||||
TEST_F(InterpreterTest, SingleSignature_get_signatures) {
|
||||
const char kMethodName[] = "test_method";
|
||||
const char kSignatureDefKey[] = "test_key";
|
||||
BuildSignature(kMethodName, kSignatureDefKey, {{"Input1", 0}, {"Input2", 1}},
|
||||
{{"Output1", 5}});
|
||||
auto results = interpreter_.signature_def_names();
|
||||
ASSERT_EQ(1, results.size());
|
||||
EXPECT_EQ(kMethodName, *results[0]);
|
||||
}
|
||||
|
||||
TEST_F(InterpreterTest, SingleSignature_get_inputs) {
|
||||
const char kMethodName[] = "test_method";
|
||||
const char kSignatureDefKey[] = "test_key";
|
||||
const std::map<std::string, uint32_t> inputs = {{"Input1", 0}, {"Input2", 1}};
|
||||
const std::map<std::string, uint32_t> outputs = {{"Output1", 5}};
|
||||
BuildSignature(kMethodName, kSignatureDefKey, inputs, outputs);
|
||||
EXPECT_THAT(interpreter_.signature_inputs(kMethodName), testing::Eq(inputs));
|
||||
EXPECT_THAT(interpreter_.signature_outputs(kMethodName),
|
||||
testing::Eq(outputs));
|
||||
}
|
||||
|
||||
TEST_F(InterpreterTest, SingleSignature_validate_get_tensor) {
|
||||
const char kMethodName[] = "test_method";
|
||||
const char kSignatureDefKey[] = "test_key";
|
||||
const std::map<std::string, uint32_t> inputs = {{"Input1", 0}, {"Input2", 1}};
|
||||
const std::map<std::string, uint32_t> outputs = {{"Output1", 5}};
|
||||
|
||||
BuildSignature(kMethodName, kSignatureDefKey, inputs, outputs);
|
||||
ASSERT_EQ(interpreter_.AddTensors(6), kTfLiteOk);
|
||||
ASSERT_EQ(interpreter_.SetInputs({0, 1}), kTfLiteOk);
|
||||
ASSERT_EQ(interpreter_.SetOutputs({5}), kTfLiteOk);
|
||||
ASSERT_EQ(interpreter_.SetTensorParametersReadWrite(
|
||||
0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(interpreter_.SetTensorParametersReadWrite(
|
||||
1, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(interpreter_.ResizeInputTensor(interpreter_.inputs()[0], {1, 2, 3}),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(interpreter_.ResizeInputTensor(interpreter_.inputs()[1], {1, 2, 3}),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
|
||||
|
||||
EXPECT_TRUE(interpreter_.input_tensor_by_signature_name(
|
||||
"Input1", kMethodName) != nullptr);
|
||||
EXPECT_TRUE(interpreter_.input_tensor_by_signature_name(
|
||||
"Input2", kMethodName) != nullptr);
|
||||
EXPECT_TRUE(interpreter_.output_tensor_by_signature_name(
|
||||
"Output1", kMethodName) != nullptr);
|
||||
|
||||
// Invalid tensor
|
||||
EXPECT_EQ(interpreter_.input_tensor_by_signature_name("Input3", kMethodName),
|
||||
nullptr);
|
||||
EXPECT_EQ(interpreter_.output_tensor_by_signature_name("Input3", kMethodName),
|
||||
nullptr);
|
||||
// Invalid method
|
||||
EXPECT_EQ(
|
||||
interpreter_.input_tensor_by_signature_name("Input1", "InvalidMethod"),
|
||||
nullptr);
|
||||
EXPECT_EQ(
|
||||
interpreter_.output_tensor_by_signature_name("Output1", "InvalidMethod"),
|
||||
nullptr);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user