Add Interpreter changes for SignatureDef support.

This change includes updates to InterpreterBuilder to use SignatureDef available in the tflite file.
Also, updates Interpreter API to
- List all signatures available
- Fetch Inputs/Outputs in single signature
- Fetch Input/Output tensor using name defined in SignatureDef.

PiperOrigin-RevId: 338711676
Change-Id: I70355ece46295cec57cc2e3732309ad3e62f8708
This commit is contained in:
Karim Nosir 2020-10-23 11:26:55 -07:00 committed by TensorFlower Gardener
parent fba2679b56
commit 92bcc8d011
4 changed files with 249 additions and 0 deletions

View File

@ -22,7 +22,9 @@ limitations under the License.
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/lite/allocation.h"
@ -275,6 +277,70 @@ class Interpreter {
return nullptr;
}
/// WARNING: Experimental interface, subject to change
/// Returns list of all names of different method signatures defined
/// in the model.
/// Note, pointers returned have lifetime same as the Interpreter object.
std::vector<const std::string*> signature_def_names() const {
std::vector<const std::string*> method_names;
method_names.reserve(signature_defs_.size());
for (const auto& sig_def : signature_defs_) {
method_names.emplace_back(&sig_def.method_name);
}
return method_names;
}
/// WARNING: Experimental interface, subject to change
/// Returns the mapping of inputs to tensor index in the signature
/// specified through 'method_name'.
/// If invalid name passed, an empty list will be returned.
const std::map<std::string, uint32_t>& signature_inputs(
const char* method_name) const {
for (const auto& sig_def : signature_defs_) {
if (sig_def.method_name == method_name) return sig_def.inputs;
}
static const std::map<std::string, uint32_t>* default_empty_list =
new std::map<std::string, uint32_t>();
return *default_empty_list;
}
/// WARNING: Experimental interface, subject to change
/// Returns the mapping of outputs to tensor index in the signature
/// specified through 'method_name'.
/// If invalid name passed, an empty list will be returned.
const std::map<std::string, uint32_t>& signature_outputs(
const char* method_name) const {
for (const auto& sig_def : signature_defs_) {
if (sig_def.method_name == method_name) return sig_def.outputs;
}
static const std::map<std::string, uint32_t>* default_empty_list =
new std::map<std::string, uint32_t>();
return *default_empty_list;
}
/// WARNING: Experimental interface, subject to change
/// Returns the input tensor identified by 'signature_input_name' in the
/// signature identified by 'signature_method_name'.
/// Returns nullptr if not found.
TfLiteTensor* input_tensor_by_signature_name(
const char* signature_input_name, const char* signature_method_name) {
const int tensor_index = GetTensorIndexFromSignatureDefName(
signature_input_name, signature_method_name, /*is_input=*/true);
return tensor_index == -1 ? nullptr : tensor(tensor_index);
}
/// WARNING: Experimental interface, subject to change
/// Returns the output tensor identified by 'signature_output_name' in the
/// signature identified by 'signature_method_name'.
/// Returns nullptr if not found.
const TfLiteTensor* output_tensor_by_signature_name(
const char* signature_output_name,
const char* signature_method_name) const {
const int tensor_index = GetTensorIndexFromSignatureDefName(
signature_output_name, signature_method_name, /*is_input=*/false);
return tensor_index == -1 ? nullptr : tensor(tensor_index);
}
/// Return a mutable pointer to the given input tensor. The given index must
/// be between 0 and inputs().size().
TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); }
@ -592,6 +658,17 @@ class Interpreter {
#endif // DOXYGEN_SKIP
private:
// Structure representing SignatureDef inputs/outputs.
struct SignatureDef {
// Maps name in signature def as key to index of the tensor in the model.
std::map<std::string, uint32_t> inputs;
// Maps name in signature def as key to index of the tensor in the model.
std::map<std::string, uint32_t> outputs;
// The method name for this signature.
std::string method_name;
// The key of this SignatureDef in the SavedModel signature def map.
std::string signature_def_key;
};
friend class InterpreterBuilder;
friend class tflite::InterpreterTest;
friend class tflite::TestDelegate;
@ -602,6 +679,26 @@ class Interpreter {
TfLiteExternalContextType type,
TfLiteExternalContext* ctx);
// Helper method that return the tensot index that corresponds to
// a name in a SignatureDef. Defined by 'signature_method_name', and
// 'signature_tensor_name'.
// If 'is_input' is true then the tensor is checked in input tensors,
// otherwise it will be checked in output tensors.
// Returns -1 if the tensor is not found.
int GetTensorIndexFromSignatureDefName(const char* signature_tensor_name,
const char* signature_method_name,
bool is_input) const {
// Iterate directly and don't use other methods to avoid extra allocation.
for (const auto& signature : signature_defs_) {
if (signature.method_name != signature_method_name) continue;
auto& signature_list = (is_input ? signature.inputs : signature.outputs);
auto tensor_iter = signature_list.find(signature_tensor_name);
if (tensor_iter == signature_list.end()) return -1;
return tensor_iter->second;
}
return -1;
}
// Sets the profiler to all subgraphs.
void SetSubgraphProfiler();
@ -615,6 +712,11 @@ class Interpreter {
// Returns true if cancellation function returns true.
bool IsCancelled();
// Sets the list of signature defs in the model.
void SetSignatureDef(std::vector<SignatureDef> signature_defs) {
signature_defs_ = std::move(signature_defs);
}
// A pure C data structure used to communicate with the pure C plugin
// interface. To avoid copying tensor metadata, this is also the definitive
// structure to store tensors.
@ -661,6 +763,10 @@ class Interpreter {
// An empty one means there's no delegate to be applied by default or
// delegates have been applied and doesn't need to be applied again.
std::vector<TfLiteDelegatePtr> lazy_delegate_providers_;
// List of signature def mapping inputs/output to tensor ids.
// We just keep track of tensor index.
std::vector<SignatureDef> signature_defs_;
};
} // namespace tflite

View File

@ -21,6 +21,9 @@ limitations under the License.
#include <sys/stat.h>
#include <sys/types.h>
#include <map>
#include <string>
#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
@ -103,6 +106,20 @@ TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src,
return kTfLiteError;
}
// Helper that returns std::map that corresponds to vector of TensorMap.
std::map<std::string, uint32_t> GetMapFromTensorMap(
const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMap>>*
tensor_map) {
if (!tensor_map) return {};
std::map<std::string, uint32_t> result;
for (const auto tensor : *tensor_map) {
if (tensor != nullptr && tensor->name() != nullptr) {
result[tensor->name()->c_str()] = tensor->tensor_index();
}
}
return result;
}
} // namespace
const char* kEmptyTensorName = "";
@ -435,6 +452,48 @@ TfLiteStatus InterpreterBuilder::ParseSparsity(
return kTfLiteOk;
}
TfLiteStatus InterpreterBuilder::ParseSignatureDefs(
const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
signature_def_list,
Interpreter* interpreter) {
if (signature_def_list == nullptr || signature_def_list->size() == 0) {
return kTfLiteOk;
}
std::vector<Interpreter::SignatureDef> signature_defs;
signature_defs.reserve(signature_def_list->size());
for (const auto fb_signature_def : *signature_def_list) {
if (fb_signature_def == nullptr) {
TF_LITE_REPORT_ERROR(error_reporter_, "NULL SignatureDef in the model.");
return kTfLiteError;
}
if (fb_signature_def->method_name() == nullptr) {
TF_LITE_REPORT_ERROR(error_reporter_,
"Missing exported method name for SignatureDef");
return kTfLiteError;
}
if (fb_signature_def->inputs() == nullptr) {
TF_LITE_REPORT_ERROR(error_reporter_,
"NULL SignatureDef inputs for exported method %s",
fb_signature_def->method_name()->c_str());
return kTfLiteError;
}
if (fb_signature_def->outputs() == nullptr) {
TF_LITE_REPORT_ERROR(error_reporter_,
"NULL SignatureDef outputs for exported method %s",
fb_signature_def->method_name()->c_str());
return kTfLiteError;
}
signature_defs.resize(signature_defs.size() + 1);
auto& signature_def = signature_defs.back();
signature_def.inputs = GetMapFromTensorMap(fb_signature_def->inputs());
signature_def.outputs = GetMapFromTensorMap(fb_signature_def->outputs());
signature_def.method_name = fb_signature_def->method_name()->c_str();
signature_def.signature_def_key = fb_signature_def->key()->c_str();
}
interpreter->SetSignatureDef(std::move(signature_defs));
return kTfLiteOk;
}
TfLiteStatus InterpreterBuilder::ParseTensors(
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
@ -667,6 +726,11 @@ TfLiteStatus InterpreterBuilder::operator()(
modified_subgraph->SetVariables(std::move(variables));
}
if (ParseSignatureDefs(model_->signature_defs(), interpreter->get()) !=
kTfLiteOk) {
return cleanup_and_error();
}
if (num_fp32_tensors_ > 0) {
(*interpreter)->lazy_delegate_providers_ =
op_resolver_.GetDelegates(num_threads);

View File

@ -80,6 +80,10 @@ class InterpreterBuilder {
const std::vector<int>& dims);
TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity,
TfLiteSparsity** sparsity);
TfLiteStatus ParseSignatureDefs(
const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
signature_def_list,
Interpreter* interpreter);
const ::tflite::Model* model_;
const OpResolver& op_resolver_;

View File

@ -56,6 +56,17 @@ class InterpreterTest : public ::testing::Test {
bool HasDelegates() { return interpreter_.HasDelegates(); }
void BuildSignature(const std::string& method_name, const std::string& key,
const std::map<std::string, uint32_t>& inputs,
const std::map<std::string, uint32_t>& outputs) {
Interpreter::SignatureDef signature;
signature.inputs = inputs;
signature.outputs = outputs;
signature.method_name = method_name;
signature.signature_def_key = key;
interpreter_.SetSignatureDef({signature});
}
Interpreter interpreter_;
};
@ -1846,6 +1857,70 @@ TEST_F(TestLazyDelegateProvider, ApplicationSkipped) {
EXPECT_FALSE(HasDelegates());
}
TEST_F(InterpreterTest, SingleSignature_get_signatures) {
const char kMethodName[] = "test_method";
const char kSignatureDefKey[] = "test_key";
BuildSignature(kMethodName, kSignatureDefKey, {{"Input1", 0}, {"Input2", 1}},
{{"Output1", 5}});
auto results = interpreter_.signature_def_names();
ASSERT_EQ(1, results.size());
EXPECT_EQ(kMethodName, *results[0]);
}
TEST_F(InterpreterTest, SingleSignature_get_inputs) {
const char kMethodName[] = "test_method";
const char kSignatureDefKey[] = "test_key";
const std::map<std::string, uint32_t> inputs = {{"Input1", 0}, {"Input2", 1}};
const std::map<std::string, uint32_t> outputs = {{"Output1", 5}};
BuildSignature(kMethodName, kSignatureDefKey, inputs, outputs);
EXPECT_THAT(interpreter_.signature_inputs(kMethodName), testing::Eq(inputs));
EXPECT_THAT(interpreter_.signature_outputs(kMethodName),
testing::Eq(outputs));
}
TEST_F(InterpreterTest, SingleSignature_validate_get_tensor) {
const char kMethodName[] = "test_method";
const char kSignatureDefKey[] = "test_key";
const std::map<std::string, uint32_t> inputs = {{"Input1", 0}, {"Input2", 1}};
const std::map<std::string, uint32_t> outputs = {{"Output1", 5}};
BuildSignature(kMethodName, kSignatureDefKey, inputs, outputs);
ASSERT_EQ(interpreter_.AddTensors(6), kTfLiteOk);
ASSERT_EQ(interpreter_.SetInputs({0, 1}), kTfLiteOk);
ASSERT_EQ(interpreter_.SetOutputs({5}), kTfLiteOk);
ASSERT_EQ(interpreter_.SetTensorParametersReadWrite(
0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
kTfLiteOk);
ASSERT_EQ(interpreter_.SetTensorParametersReadWrite(
1, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
kTfLiteOk);
ASSERT_EQ(interpreter_.ResizeInputTensor(interpreter_.inputs()[0], {1, 2, 3}),
kTfLiteOk);
ASSERT_EQ(interpreter_.ResizeInputTensor(interpreter_.inputs()[1], {1, 2, 3}),
kTfLiteOk);
ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk);
EXPECT_TRUE(interpreter_.input_tensor_by_signature_name(
"Input1", kMethodName) != nullptr);
EXPECT_TRUE(interpreter_.input_tensor_by_signature_name(
"Input2", kMethodName) != nullptr);
EXPECT_TRUE(interpreter_.output_tensor_by_signature_name(
"Output1", kMethodName) != nullptr);
// Invalid tensor
EXPECT_EQ(interpreter_.input_tensor_by_signature_name("Input3", kMethodName),
nullptr);
EXPECT_EQ(interpreter_.output_tensor_by_signature_name("Input3", kMethodName),
nullptr);
// Invalid method
EXPECT_EQ(
interpreter_.input_tensor_by_signature_name("Input1", "InvalidMethod"),
nullptr);
EXPECT_EQ(
interpreter_.output_tensor_by_signature_name("Output1", "InvalidMethod"),
nullptr);
}
} // namespace
} // namespace tflite