diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h index ef5831fda50..8e85cd55321 100644 --- a/tensorflow/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -22,7 +22,9 @@ limitations under the License. #include #include #include +#include #include +#include #include #include "tensorflow/lite/allocation.h" @@ -275,6 +277,70 @@ class Interpreter { return nullptr; } + /// WARNING: Experimental interface, subject to change + /// Returns list of all names of different method signatures defined + /// in the model. + /// Note, pointers returned have lifetime same as the Interpreter object. + std::vector signature_def_names() const { + std::vector method_names; + method_names.reserve(signature_defs_.size()); + for (const auto& sig_def : signature_defs_) { + method_names.emplace_back(&sig_def.method_name); + } + return method_names; + } + + /// WARNING: Experimental interface, subject to change + /// Returns the mapping of inputs to tensor index in the signature + /// specified through 'method_name'. + /// If invalid name passed, an empty list will be returned. + const std::map& signature_inputs( + const char* method_name) const { + for (const auto& sig_def : signature_defs_) { + if (sig_def.method_name == method_name) return sig_def.inputs; + } + static const std::map* default_empty_list = + new std::map(); + return *default_empty_list; + } + + /// WARNING: Experimental interface, subject to change + /// Returns the mapping of outputs to tensor index in the signature + /// specified through 'method_name'. + /// If invalid name passed, an empty list will be returned. + const std::map& signature_outputs( + const char* method_name) const { + for (const auto& sig_def : signature_defs_) { + if (sig_def.method_name == method_name) return sig_def.outputs; + } + static const std::map* default_empty_list = + new std::map(); + return *default_empty_list; + } + + /// WARNING: Experimental interface, subject to change + /// Returns the input tensor identified by 'signature_input_name' in the + /// signature identified by 'signature_method_name'. + /// Returns nullptr if not found. + TfLiteTensor* input_tensor_by_signature_name( + const char* signature_input_name, const char* signature_method_name) { + const int tensor_index = GetTensorIndexFromSignatureDefName( + signature_input_name, signature_method_name, /*is_input=*/true); + return tensor_index == -1 ? nullptr : tensor(tensor_index); + } + + /// WARNING: Experimental interface, subject to change + /// Returns the output tensor identified by 'signature_output_name' in the + /// signature identified by 'signature_method_name'. + /// Returns nullptr if not found. + const TfLiteTensor* output_tensor_by_signature_name( + const char* signature_output_name, + const char* signature_method_name) const { + const int tensor_index = GetTensorIndexFromSignatureDefName( + signature_output_name, signature_method_name, /*is_input=*/false); + return tensor_index == -1 ? nullptr : tensor(tensor_index); + } + /// Return a mutable pointer to the given input tensor. The given index must /// be between 0 and inputs().size(). TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); } @@ -592,6 +658,17 @@ class Interpreter { #endif // DOXYGEN_SKIP private: + // Structure representing SignatureDef inputs/outputs. + struct SignatureDef { + // Maps name in signature def as key to index of the tensor in the model. + std::map inputs; + // Maps name in signature def as key to index of the tensor in the model. + std::map outputs; + // The method name for this signature. + std::string method_name; + // The key of this SignatureDef in the SavedModel signature def map. + std::string signature_def_key; + }; friend class InterpreterBuilder; friend class tflite::InterpreterTest; friend class tflite::TestDelegate; @@ -602,6 +679,26 @@ class Interpreter { TfLiteExternalContextType type, TfLiteExternalContext* ctx); + // Helper method that return the tensot index that corresponds to + // a name in a SignatureDef. Defined by 'signature_method_name', and + // 'signature_tensor_name'. + // If 'is_input' is true then the tensor is checked in input tensors, + // otherwise it will be checked in output tensors. + // Returns -1 if the tensor is not found. + int GetTensorIndexFromSignatureDefName(const char* signature_tensor_name, + const char* signature_method_name, + bool is_input) const { + // Iterate directly and don't use other methods to avoid extra allocation. + for (const auto& signature : signature_defs_) { + if (signature.method_name != signature_method_name) continue; + auto& signature_list = (is_input ? signature.inputs : signature.outputs); + auto tensor_iter = signature_list.find(signature_tensor_name); + if (tensor_iter == signature_list.end()) return -1; + return tensor_iter->second; + } + return -1; + } + // Sets the profiler to all subgraphs. void SetSubgraphProfiler(); @@ -615,6 +712,11 @@ class Interpreter { // Returns true if cancellation function returns true. bool IsCancelled(); + // Sets the list of signature defs in the model. + void SetSignatureDef(std::vector signature_defs) { + signature_defs_ = std::move(signature_defs); + } + // A pure C data structure used to communicate with the pure C plugin // interface. To avoid copying tensor metadata, this is also the definitive // structure to store tensors. @@ -661,6 +763,10 @@ class Interpreter { // An empty one means there's no delegate to be applied by default or // delegates have been applied and doesn't need to be applied again. std::vector lazy_delegate_providers_; + + // List of signature def mapping inputs/output to tensor ids. + // We just keep track of tensor index. + std::vector signature_defs_; }; } // namespace tflite diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc index f5c8d97b962..719bd02d859 100644 --- a/tensorflow/lite/interpreter_builder.cc +++ b/tensorflow/lite/interpreter_builder.cc @@ -21,6 +21,9 @@ limitations under the License. #include #include +#include +#include + #include "tensorflow/lite/allocation.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" @@ -103,6 +106,20 @@ TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src, return kTfLiteError; } +// Helper that returns std::map that corresponds to vector of TensorMap. +std::map GetMapFromTensorMap( + const flatbuffers::Vector>* + tensor_map) { + if (!tensor_map) return {}; + std::map result; + for (const auto tensor : *tensor_map) { + if (tensor != nullptr && tensor->name() != nullptr) { + result[tensor->name()->c_str()] = tensor->tensor_index(); + } + } + return result; +} + } // namespace const char* kEmptyTensorName = ""; @@ -435,6 +452,48 @@ TfLiteStatus InterpreterBuilder::ParseSparsity( return kTfLiteOk; } +TfLiteStatus InterpreterBuilder::ParseSignatureDefs( + const flatbuffers::Vector>* + signature_def_list, + Interpreter* interpreter) { + if (signature_def_list == nullptr || signature_def_list->size() == 0) { + return kTfLiteOk; + } + std::vector signature_defs; + signature_defs.reserve(signature_def_list->size()); + for (const auto fb_signature_def : *signature_def_list) { + if (fb_signature_def == nullptr) { + TF_LITE_REPORT_ERROR(error_reporter_, "NULL SignatureDef in the model."); + return kTfLiteError; + } + if (fb_signature_def->method_name() == nullptr) { + TF_LITE_REPORT_ERROR(error_reporter_, + "Missing exported method name for SignatureDef"); + return kTfLiteError; + } + if (fb_signature_def->inputs() == nullptr) { + TF_LITE_REPORT_ERROR(error_reporter_, + "NULL SignatureDef inputs for exported method %s", + fb_signature_def->method_name()->c_str()); + return kTfLiteError; + } + if (fb_signature_def->outputs() == nullptr) { + TF_LITE_REPORT_ERROR(error_reporter_, + "NULL SignatureDef outputs for exported method %s", + fb_signature_def->method_name()->c_str()); + return kTfLiteError; + } + signature_defs.resize(signature_defs.size() + 1); + auto& signature_def = signature_defs.back(); + signature_def.inputs = GetMapFromTensorMap(fb_signature_def->inputs()); + signature_def.outputs = GetMapFromTensorMap(fb_signature_def->outputs()); + signature_def.method_name = fb_signature_def->method_name()->c_str(); + signature_def.signature_def_key = fb_signature_def->key()->c_str(); + } + interpreter->SetSignatureDef(std::move(signature_defs)); + return kTfLiteOk; +} + TfLiteStatus InterpreterBuilder::ParseTensors( const flatbuffers::Vector>* buffers, const flatbuffers::Vector>* tensors, @@ -667,6 +726,11 @@ TfLiteStatus InterpreterBuilder::operator()( modified_subgraph->SetVariables(std::move(variables)); } + if (ParseSignatureDefs(model_->signature_defs(), interpreter->get()) != + kTfLiteOk) { + return cleanup_and_error(); + } + if (num_fp32_tensors_ > 0) { (*interpreter)->lazy_delegate_providers_ = op_resolver_.GetDelegates(num_threads); diff --git a/tensorflow/lite/interpreter_builder.h b/tensorflow/lite/interpreter_builder.h index 4b0052f66ce..34ba9742d9d 100644 --- a/tensorflow/lite/interpreter_builder.h +++ b/tensorflow/lite/interpreter_builder.h @@ -80,6 +80,10 @@ class InterpreterBuilder { const std::vector& dims); TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity, TfLiteSparsity** sparsity); + TfLiteStatus ParseSignatureDefs( + const flatbuffers::Vector>* + signature_def_list, + Interpreter* interpreter); const ::tflite::Model* model_; const OpResolver& op_resolver_; diff --git a/tensorflow/lite/interpreter_test.cc b/tensorflow/lite/interpreter_test.cc index b70908e7162..9caeefb1073 100644 --- a/tensorflow/lite/interpreter_test.cc +++ b/tensorflow/lite/interpreter_test.cc @@ -56,6 +56,17 @@ class InterpreterTest : public ::testing::Test { bool HasDelegates() { return interpreter_.HasDelegates(); } + void BuildSignature(const std::string& method_name, const std::string& key, + const std::map& inputs, + const std::map& outputs) { + Interpreter::SignatureDef signature; + signature.inputs = inputs; + signature.outputs = outputs; + signature.method_name = method_name; + signature.signature_def_key = key; + interpreter_.SetSignatureDef({signature}); + } + Interpreter interpreter_; }; @@ -1846,6 +1857,70 @@ TEST_F(TestLazyDelegateProvider, ApplicationSkipped) { EXPECT_FALSE(HasDelegates()); } +TEST_F(InterpreterTest, SingleSignature_get_signatures) { + const char kMethodName[] = "test_method"; + const char kSignatureDefKey[] = "test_key"; + BuildSignature(kMethodName, kSignatureDefKey, {{"Input1", 0}, {"Input2", 1}}, + {{"Output1", 5}}); + auto results = interpreter_.signature_def_names(); + ASSERT_EQ(1, results.size()); + EXPECT_EQ(kMethodName, *results[0]); +} + +TEST_F(InterpreterTest, SingleSignature_get_inputs) { + const char kMethodName[] = "test_method"; + const char kSignatureDefKey[] = "test_key"; + const std::map inputs = {{"Input1", 0}, {"Input2", 1}}; + const std::map outputs = {{"Output1", 5}}; + BuildSignature(kMethodName, kSignatureDefKey, inputs, outputs); + EXPECT_THAT(interpreter_.signature_inputs(kMethodName), testing::Eq(inputs)); + EXPECT_THAT(interpreter_.signature_outputs(kMethodName), + testing::Eq(outputs)); +} + +TEST_F(InterpreterTest, SingleSignature_validate_get_tensor) { + const char kMethodName[] = "test_method"; + const char kSignatureDefKey[] = "test_key"; + const std::map inputs = {{"Input1", 0}, {"Input2", 1}}; + const std::map outputs = {{"Output1", 5}}; + + BuildSignature(kMethodName, kSignatureDefKey, inputs, outputs); + ASSERT_EQ(interpreter_.AddTensors(6), kTfLiteOk); + ASSERT_EQ(interpreter_.SetInputs({0, 1}), kTfLiteOk); + ASSERT_EQ(interpreter_.SetOutputs({5}), kTfLiteOk); + ASSERT_EQ(interpreter_.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + ASSERT_EQ(interpreter_.SetTensorParametersReadWrite( + 1, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + ASSERT_EQ(interpreter_.ResizeInputTensor(interpreter_.inputs()[0], {1, 2, 3}), + kTfLiteOk); + ASSERT_EQ(interpreter_.ResizeInputTensor(interpreter_.inputs()[1], {1, 2, 3}), + kTfLiteOk); + ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk); + + EXPECT_TRUE(interpreter_.input_tensor_by_signature_name( + "Input1", kMethodName) != nullptr); + EXPECT_TRUE(interpreter_.input_tensor_by_signature_name( + "Input2", kMethodName) != nullptr); + EXPECT_TRUE(interpreter_.output_tensor_by_signature_name( + "Output1", kMethodName) != nullptr); + + // Invalid tensor + EXPECT_EQ(interpreter_.input_tensor_by_signature_name("Input3", kMethodName), + nullptr); + EXPECT_EQ(interpreter_.output_tensor_by_signature_name("Input3", kMethodName), + nullptr); + // Invalid method + EXPECT_EQ( + interpreter_.input_tensor_by_signature_name("Input1", "InvalidMethod"), + nullptr); + EXPECT_EQ( + interpreter_.output_tensor_by_signature_name("Output1", "InvalidMethod"), + nullptr); +} + } // namespace } // namespace tflite