Adds utility methods for storing SignatureDefs in the metadata table in the flatbuffer
PiperOrigin-RevId: 311652937 Change-Id: I397c7ce6fad843cff789dedb583d6df44545db3f
This commit is contained in:
parent
efa3fb28d9
commit
37df93331e
106
tensorflow/lite/tools/signature/BUILD
Normal file
106
tensorflow/lite/tools/signature/BUILD
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
# Utilities for signature_defs in TFLite
|
||||||
|
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||||
|
load("//tensorflow:tensorflow.bzl", "if_not_windows")
|
||||||
|
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
|
||||||
|
load("//tensorflow/lite/micro:build_def.bzl", "cc_library")
|
||||||
|
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = [
|
||||||
|
"//visibility:public",
|
||||||
|
],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
TFLITE_DEFAULT_COPTS = if_not_windows([
|
||||||
|
"-Wall",
|
||||||
|
"-Wno-comment",
|
||||||
|
"-Wno-extern-c-compat",
|
||||||
|
])
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "signature_def_util",
|
||||||
|
srcs = ["signature_def_util.cc"],
|
||||||
|
hdrs = ["signature_def_util.h"],
|
||||||
|
copts = TFLITE_DEFAULT_COPTS + tflite_copts(),
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:lib_proto_parsing",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:protos_all_cc_impl",
|
||||||
|
"//tensorflow/core/platform:errors",
|
||||||
|
"//tensorflow/core/platform:status",
|
||||||
|
"//tensorflow/lite:framework",
|
||||||
|
"//tensorflow/lite/c:common",
|
||||||
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_protobuf//:protobuf",
|
||||||
|
"@flatbuffers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "signature_def_util_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["signature_def_util_test.cc"],
|
||||||
|
data = [
|
||||||
|
"//tensorflow/lite:testdata/add.bin",
|
||||||
|
],
|
||||||
|
tags = [
|
||||||
|
"tflite_not_portable",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":signature_def_util",
|
||||||
|
"//tensorflow/cc/saved_model:signature_constants",
|
||||||
|
"//tensorflow/core:tflite_portable_logging",
|
||||||
|
"//tensorflow/core/platform:errors",
|
||||||
|
"//tensorflow/lite:framework_lib",
|
||||||
|
"//tensorflow/lite/c:c_api",
|
||||||
|
"//tensorflow/lite/c:common",
|
||||||
|
"//tensorflow/lite/schema:schema_fbs",
|
||||||
|
"//tensorflow/lite/testing:util",
|
||||||
|
"@com_google_googletest//:gtest",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
pybind_extension(
|
||||||
|
name = "_pywrap_signature_def_util_wrapper",
|
||||||
|
srcs = [
|
||||||
|
"signature_def_util_wrapper_pybind11.cc",
|
||||||
|
],
|
||||||
|
module_name = "_pywrap_signature_def_util_wrapper",
|
||||||
|
deps = [
|
||||||
|
":signature_def_util",
|
||||||
|
"//tensorflow/lite:framework_lib",
|
||||||
|
"//tensorflow/python:pybind11_lib",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "signature_def_utils",
|
||||||
|
srcs = ["signature_def_utils.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":_pywrap_signature_def_util_wrapper",
|
||||||
|
"//tensorflow/core:protos_all_py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "signature_def_utils_test",
|
||||||
|
srcs = ["signature_def_utils_test.py"],
|
||||||
|
data = ["//tensorflow/lite:testdata/add.bin"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
tags = [
|
||||||
|
"no_mac",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":signature_def_utils",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/core:protos_all_py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tflite_portable_test_suite()
|
175
tensorflow/lite/tools/signature/signature_def_util.cc
Normal file
175
tensorflow/lite/tools/signature/signature_def_util.cc
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/tools/signature/signature_def_util.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||||
|
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||||
|
#include "tensorflow/lite/model_builder.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using tensorflow::Status;
|
||||||
|
using SerializedSignatureDefMap = std::map<std::string, std::string>;
|
||||||
|
using SignatureDefMap = std::map<std::string, tensorflow::SignatureDef>;
|
||||||
|
|
||||||
|
const Metadata* GetSignatureDefMetadata(const Model* model) {
|
||||||
|
if (!model || !model->metadata()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < model->metadata()->size(); ++i) {
|
||||||
|
const Metadata* metadata = model->metadata()->Get(i);
|
||||||
|
if (metadata->name()->str() == kSignatureDefsMetadataName) {
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ReadSignatureDefMap(const Model* model, const Metadata* metadata,
|
||||||
|
SerializedSignatureDefMap* map) {
|
||||||
|
if (!model || !metadata || !map) {
|
||||||
|
return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
|
||||||
|
}
|
||||||
|
const flatbuffers::Vector<uint8_t>* flatbuffer_data =
|
||||||
|
model->buffers()->Get(metadata->buffer())->data();
|
||||||
|
const auto signature_defs =
|
||||||
|
flexbuffers::GetRoot(flatbuffer_data->data(), flatbuffer_data->size())
|
||||||
|
.AsMap();
|
||||||
|
for (int i = 0; i < signature_defs.Keys().size(); ++i) {
|
||||||
|
const std::string key = signature_defs.Keys()[i].AsString().c_str();
|
||||||
|
(*map)[key] = signature_defs[key].AsString().c_str();
|
||||||
|
}
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Status SetSignatureDefMap(const Model* model,
|
||||||
|
const SignatureDefMap& signature_def_map,
|
||||||
|
std::string* model_data_with_signature_def) {
|
||||||
|
if (!model || !model_data_with_signature_def) {
|
||||||
|
return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
|
||||||
|
}
|
||||||
|
if (signature_def_map.empty()) {
|
||||||
|
return tensorflow::errors::InvalidArgument(
|
||||||
|
"signature_def_map should not be empty");
|
||||||
|
}
|
||||||
|
flexbuffers::Builder fbb;
|
||||||
|
const size_t start_map = fbb.StartMap();
|
||||||
|
auto mutable_model = absl::make_unique<ModelT>();
|
||||||
|
model->UnPackTo(mutable_model.get(), nullptr);
|
||||||
|
int buffer_id = mutable_model->buffers.size();
|
||||||
|
const Metadata* metadata = GetSignatureDefMetadata(model);
|
||||||
|
if (metadata) {
|
||||||
|
buffer_id = metadata->buffer();
|
||||||
|
} else {
|
||||||
|
auto buffer = absl::make_unique<BufferT>();
|
||||||
|
mutable_model->buffers.emplace_back(std::move(buffer));
|
||||||
|
auto sigdef_metadata = absl::make_unique<MetadataT>();
|
||||||
|
sigdef_metadata->buffer = buffer_id;
|
||||||
|
sigdef_metadata->name = kSignatureDefsMetadataName;
|
||||||
|
mutable_model->metadata.emplace_back(std::move(sigdef_metadata));
|
||||||
|
}
|
||||||
|
for (const auto& entry : signature_def_map) {
|
||||||
|
fbb.String(entry.first.c_str(), entry.second.SerializeAsString());
|
||||||
|
}
|
||||||
|
fbb.EndMap(start_map);
|
||||||
|
fbb.Finish();
|
||||||
|
mutable_model->buffers[buffer_id]->data = fbb.GetBuffer();
|
||||||
|
flatbuffers::FlatBufferBuilder builder;
|
||||||
|
auto packed_model = Model::Pack(builder, mutable_model.get());
|
||||||
|
FinishModelBuffer(builder, packed_model);
|
||||||
|
*model_data_with_signature_def =
|
||||||
|
std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
|
||||||
|
builder.GetSize());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HasSignatureDef(const Model* model, const std::string& signature_key) {
|
||||||
|
if (!model) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const Metadata* metadata = GetSignatureDefMetadata(model);
|
||||||
|
if (!metadata) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
SerializedSignatureDefMap signature_defs;
|
||||||
|
if (ReadSignatureDefMap(model, metadata, &signature_defs) !=
|
||||||
|
tensorflow::Status::OK()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return (signature_defs.find(signature_key) != signature_defs.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GetSignatureDefMap(const Model* model,
|
||||||
|
SignatureDefMap* signature_def_map) {
|
||||||
|
if (!model || !signature_def_map) {
|
||||||
|
return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
|
||||||
|
}
|
||||||
|
SignatureDefMap retrieved_signature_def_map;
|
||||||
|
const Metadata* metadata = GetSignatureDefMetadata(model);
|
||||||
|
if (metadata) {
|
||||||
|
SerializedSignatureDefMap signature_defs;
|
||||||
|
auto status = ReadSignatureDefMap(model, metadata, &signature_defs);
|
||||||
|
if (status != tensorflow::Status::OK()) {
|
||||||
|
return tensorflow::errors::Internal("Error reading signature def map: %s",
|
||||||
|
status.error_message());
|
||||||
|
}
|
||||||
|
for (const auto& entry : signature_defs) {
|
||||||
|
tensorflow::SignatureDef signature_def;
|
||||||
|
if (!signature_def.ParseFromString(entry.second)) {
|
||||||
|
return tensorflow::errors::Internal(
|
||||||
|
"Cannot parse signature def found in flatbuffer.");
|
||||||
|
}
|
||||||
|
retrieved_signature_def_map[entry.first] = signature_def;
|
||||||
|
}
|
||||||
|
*signature_def_map = retrieved_signature_def_map;
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ClearSignatureDefMap(const Model* model, std::string* model_data) {
|
||||||
|
if (!model || !model_data) {
|
||||||
|
return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
|
||||||
|
}
|
||||||
|
auto mutable_model = absl::make_unique<ModelT>();
|
||||||
|
model->UnPackTo(mutable_model.get(), nullptr);
|
||||||
|
for (int id = 0; id < model->metadata()->size(); ++id) {
|
||||||
|
const Metadata* metadata = model->metadata()->Get(id);
|
||||||
|
if (metadata->name()->str() == kSignatureDefsMetadataName) {
|
||||||
|
auto* buffers = &(mutable_model->buffers);
|
||||||
|
buffers->erase(buffers->begin() + metadata->buffer());
|
||||||
|
mutable_model->metadata.erase(mutable_model->metadata.begin() + id);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flatbuffers::FlatBufferBuilder builder;
|
||||||
|
auto packed_model = Model::Pack(builder, mutable_model.get());
|
||||||
|
FinishModelBuffer(builder, packed_model);
|
||||||
|
*model_data =
|
||||||
|
std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
|
||||||
|
builder.GetSize());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tflite
|
71
tensorflow/lite/tools/signature/signature_def_util.h
Normal file
71
tensorflow/lite/tools/signature/signature_def_util.h
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_LITE_TOOLS_SIGNATURE_DEF_UTIL_H_
|
||||||
|
#define TENSORFLOW_LITE_TOOLS_SIGNATURE_DEF_UTIL_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
|
||||||
|
// Constant for name of the Metadata entry associated with SignatureDefs.
|
||||||
|
constexpr char kSignatureDefsMetadataName[] = "signature_defs_metadata";
|
||||||
|
|
||||||
|
// The function `SetSignatureDefMap()` results in
|
||||||
|
// `model_data_with_signature_defs` containing a serialized TFLite model
|
||||||
|
// identical to `model` with a metadata and associated buffer containing
|
||||||
|
// a FlexBuffer::Map with `signature_def_map` keys and values serialized to
|
||||||
|
// String.
|
||||||
|
//
|
||||||
|
// If a Metadata entry containing a SignatureDef map exists, it will be
|
||||||
|
// overwritten.
|
||||||
|
//
|
||||||
|
// Returns error if `model_data_with_signature_defs` is null or
|
||||||
|
// `signature_def_map` is empty.
|
||||||
|
//
|
||||||
|
// On success, returns tensorflow::Status::OK() or error otherwise.
|
||||||
|
// On error, `model_data_with_signature_defs` is unchanged.
|
||||||
|
tensorflow::Status SetSignatureDefMap(
|
||||||
|
const Model* model,
|
||||||
|
const std::map<std::string, tensorflow::SignatureDef>& signature_def_map,
|
||||||
|
std::string* model_data_with_signature_defs);
|
||||||
|
|
||||||
|
// The function `HasSignatureDef()` returns true if `model` contains a Metadata
|
||||||
|
// table pointing to a buffer containing a FlexBuffer::Map and the map has
|
||||||
|
// `signature_key` as a key, or false otherwise.
|
||||||
|
bool HasSignatureDef(const Model* model, const std::string& signature_key);
|
||||||
|
|
||||||
|
// The function `GetSignatureDefMap()` results in `signature_def_map`
|
||||||
|
// pointing to a map<std::string, tensorflow::SignatureDef>
|
||||||
|
// parsed from `model`'s metadata buffer.
|
||||||
|
//
|
||||||
|
// If the Metadata entry does not exist, `signature_def_map` is unchanged.
|
||||||
|
// If the Metadata entry exists but cannot be parsed, returns an error.
|
||||||
|
tensorflow::Status GetSignatureDefMap(
|
||||||
|
const Model* model,
|
||||||
|
std::map<std::string, tensorflow::SignatureDef>* signature_def_map);
|
||||||
|
|
||||||
|
// The function `ClearSignatureDefs` results in `model_data`
|
||||||
|
// containing a serialized Model identical to `model` omitting any
|
||||||
|
// SignatureDef-related metadata or buffers.
|
||||||
|
tensorflow::Status ClearSignatureDefMap(const Model* model,
|
||||||
|
std::string* model_data);
|
||||||
|
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_TOOLS_SIGNATURE_DEF_UTIL_H_
|
167
tensorflow/lite/tools/signature/signature_def_util_test.cc
Normal file
167
tensorflow/lite/tools/signature/signature_def_util_test.cc
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/tools/signature/signature_def_util.h"
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/cc/saved_model/signature_constants.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/lite/c/c_api.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/model_builder.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
#include "tensorflow/lite/testing/util.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using tensorflow::kClassifyMethodName;
|
||||||
|
using tensorflow::kDefaultServingSignatureDefKey;
|
||||||
|
using tensorflow::kPredictMethodName;
|
||||||
|
using tensorflow::SignatureDef;
|
||||||
|
using tensorflow::Status;
|
||||||
|
|
||||||
|
constexpr char kSignatureInput[] = "input";
|
||||||
|
constexpr char kSignatureOutput[] = "output";
|
||||||
|
constexpr char kTestFilePath[] = "tensorflow/lite/testdata/add.bin";
|
||||||
|
|
||||||
|
class SimpleSignatureDefUtilTest : public testing::Test {
|
||||||
|
protected:
|
||||||
|
void SetUp() override {
|
||||||
|
flatbuffer_model_ = FlatBufferModel::BuildFromFile(kTestFilePath);
|
||||||
|
ASSERT_NE(flatbuffer_model_, nullptr);
|
||||||
|
model_ = flatbuffer_model_->GetModel();
|
||||||
|
ASSERT_NE(model_, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
SignatureDef GetTestSignatureDef() {
|
||||||
|
auto signature_def = SignatureDef();
|
||||||
|
tensorflow::TensorInfo input_tensor;
|
||||||
|
tensorflow::TensorInfo output_tensor;
|
||||||
|
*input_tensor.mutable_name() = kSignatureInput;
|
||||||
|
*output_tensor.mutable_name() = kSignatureOutput;
|
||||||
|
*signature_def.mutable_method_name() = kClassifyMethodName;
|
||||||
|
(*signature_def.mutable_inputs())[kSignatureInput] = input_tensor;
|
||||||
|
(*signature_def.mutable_outputs())[kSignatureOutput] = output_tensor;
|
||||||
|
return signature_def;
|
||||||
|
}
|
||||||
|
std::unique_ptr<FlatBufferModel> flatbuffer_model_;
|
||||||
|
const Model* model_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(SimpleSignatureDefUtilTest, SetSignatureDefTest) {
|
||||||
|
SignatureDef expected_signature_def = GetTestSignatureDef();
|
||||||
|
std::string model_output;
|
||||||
|
const std::map<string, SignatureDef> expected_signature_def_map = {
|
||||||
|
{kDefaultServingSignatureDefKey, expected_signature_def}};
|
||||||
|
EXPECT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
|
||||||
|
&model_output));
|
||||||
|
const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
|
||||||
|
EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));
|
||||||
|
std::map<string, SignatureDef> test_signature_def_map;
|
||||||
|
EXPECT_EQ(Status::OK(),
|
||||||
|
GetSignatureDefMap(add_model, &test_signature_def_map));
|
||||||
|
SignatureDef test_signature_def =
|
||||||
|
test_signature_def_map[kDefaultServingSignatureDefKey];
|
||||||
|
EXPECT_EQ(expected_signature_def.SerializeAsString(),
|
||||||
|
test_signature_def.SerializeAsString());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SimpleSignatureDefUtilTest, OverwriteSignatureDefTest) {
|
||||||
|
auto expected_signature_def = GetTestSignatureDef();
|
||||||
|
std::string model_output;
|
||||||
|
std::map<string, SignatureDef> expected_signature_def_map = {
|
||||||
|
{kDefaultServingSignatureDefKey, expected_signature_def}};
|
||||||
|
EXPECT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
|
||||||
|
&model_output));
|
||||||
|
const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
|
||||||
|
EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));
|
||||||
|
std::map<string, SignatureDef> test_signature_def_map;
|
||||||
|
EXPECT_EQ(Status::OK(),
|
||||||
|
GetSignatureDefMap(add_model, &test_signature_def_map));
|
||||||
|
SignatureDef test_signature_def =
|
||||||
|
test_signature_def_map[kDefaultServingSignatureDefKey];
|
||||||
|
EXPECT_EQ(expected_signature_def.SerializeAsString(),
|
||||||
|
test_signature_def.SerializeAsString());
|
||||||
|
*expected_signature_def.mutable_method_name() = kPredictMethodName;
|
||||||
|
expected_signature_def_map.erase(
|
||||||
|
expected_signature_def_map.find(kDefaultServingSignatureDefKey));
|
||||||
|
constexpr char kTestSignatureDefKey[] = "ServingTest";
|
||||||
|
expected_signature_def_map[kTestSignatureDefKey] = expected_signature_def;
|
||||||
|
EXPECT_EQ(
|
||||||
|
Status::OK(),
|
||||||
|
SetSignatureDefMap(add_model, expected_signature_def_map, &model_output));
|
||||||
|
const Model* final_model = flatbuffers::GetRoot<Model>(model_output.data());
|
||||||
|
EXPECT_FALSE(HasSignatureDef(final_model, kDefaultServingSignatureDefKey));
|
||||||
|
EXPECT_EQ(Status::OK(),
|
||||||
|
GetSignatureDefMap(final_model, &test_signature_def_map));
|
||||||
|
EXPECT_NE(expected_signature_def.SerializeAsString(),
|
||||||
|
test_signature_def.SerializeAsString());
|
||||||
|
EXPECT_TRUE(HasSignatureDef(final_model, kTestSignatureDefKey));
|
||||||
|
EXPECT_EQ(Status::OK(),
|
||||||
|
GetSignatureDefMap(final_model, &test_signature_def_map));
|
||||||
|
test_signature_def = test_signature_def_map[kTestSignatureDefKey];
|
||||||
|
EXPECT_EQ(expected_signature_def.SerializeAsString(),
|
||||||
|
test_signature_def.SerializeAsString());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SimpleSignatureDefUtilTest, GetSignatureDefTest) {
|
||||||
|
std::map<string, SignatureDef> test_signature_def_map;
|
||||||
|
EXPECT_EQ(Status::OK(), GetSignatureDefMap(model_, &test_signature_def_map));
|
||||||
|
EXPECT_FALSE(HasSignatureDef(model_, kDefaultServingSignatureDefKey));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SimpleSignatureDefUtilTest, ClearSignatureDefTest) {
|
||||||
|
const int expected_num_buffers = model_->buffers()->size();
|
||||||
|
auto expected_signature_def = GetTestSignatureDef();
|
||||||
|
std::string model_output;
|
||||||
|
std::map<string, SignatureDef> expected_signature_def_map = {
|
||||||
|
{kDefaultServingSignatureDefKey, expected_signature_def}};
|
||||||
|
EXPECT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
|
||||||
|
&model_output));
|
||||||
|
const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
|
||||||
|
EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));
|
||||||
|
SignatureDef test_signature_def;
|
||||||
|
std::map<string, SignatureDef> test_signature_def_map;
|
||||||
|
EXPECT_EQ(Status::OK(),
|
||||||
|
GetSignatureDefMap(add_model, &test_signature_def_map));
|
||||||
|
test_signature_def = test_signature_def_map[kDefaultServingSignatureDefKey];
|
||||||
|
EXPECT_EQ(expected_signature_def.SerializeAsString(),
|
||||||
|
test_signature_def.SerializeAsString());
|
||||||
|
EXPECT_EQ(Status::OK(), ClearSignatureDefMap(add_model, &model_output));
|
||||||
|
const Model* clear_model = flatbuffers::GetRoot<Model>(model_output.data());
|
||||||
|
EXPECT_FALSE(HasSignatureDef(clear_model, kDefaultServingSignatureDefKey));
|
||||||
|
EXPECT_EQ(expected_num_buffers, clear_model->buffers()->size());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SimpleSignatureDefUtilTest, SetSignatureDefErrorsTest) {
|
||||||
|
std::map<string, SignatureDef> test_signature_def_map;
|
||||||
|
std::string model_output;
|
||||||
|
EXPECT_TRUE(tensorflow::errors::IsInvalidArgument(
|
||||||
|
SetSignatureDefMap(model_, test_signature_def_map, &model_output)));
|
||||||
|
SignatureDef test_signature_def;
|
||||||
|
test_signature_def_map[kDefaultServingSignatureDefKey] = test_signature_def;
|
||||||
|
EXPECT_TRUE(tensorflow::errors::IsInvalidArgument(
|
||||||
|
SetSignatureDefMap(model_, test_signature_def_map, nullptr)));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
::tflite::LogToStderr();
|
||||||
|
::testing::InitGoogleTest(&argc, argv);
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
@ -0,0 +1,95 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "pybind11/pytypes.h"
|
||||||
|
#include "pybind11/stl.h"
|
||||||
|
#include "tensorflow/lite/model_builder.h"
|
||||||
|
#include "tensorflow/lite/tools/signature/signature_def_util.h"
|
||||||
|
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||||
|
|
||||||
|
py::bytes WrappedSetSignatureDefMap(
|
||||||
|
const std::vector<uint8_t>& model_buffer,
|
||||||
|
const std::map<std::string, std::string>& serialized_signature_def_map) {
|
||||||
|
auto flatbuffer_model = tflite::FlatBufferModel::BuildFromBuffer(
|
||||||
|
reinterpret_cast<const char*>(model_buffer.data()), model_buffer.size());
|
||||||
|
auto* model = flatbuffer_model->GetModel();
|
||||||
|
if (!model) {
|
||||||
|
throw std::invalid_argument("Invalid model");
|
||||||
|
}
|
||||||
|
std::string data;
|
||||||
|
std::map<std::string, tensorflow::SignatureDef> signature_def_map;
|
||||||
|
for (const auto& entry : serialized_signature_def_map) {
|
||||||
|
tensorflow::SignatureDef signature_def;
|
||||||
|
if (!signature_def.ParseFromString(entry.second)) {
|
||||||
|
throw std::invalid_argument("Cannot parse signature def");
|
||||||
|
}
|
||||||
|
signature_def_map[entry.first] = signature_def;
|
||||||
|
}
|
||||||
|
auto status = tflite::SetSignatureDefMap(model, signature_def_map, &data);
|
||||||
|
if (status != tensorflow::Status::OK()) {
|
||||||
|
throw std::invalid_argument(status.error_message());
|
||||||
|
}
|
||||||
|
return py::bytes(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<std::string, py::bytes> WrappedGetSignatureDefMap(
|
||||||
|
const std::vector<uint8_t>& model_buffer) {
|
||||||
|
auto flatbuffer_model = tflite::FlatBufferModel::BuildFromBuffer(
|
||||||
|
reinterpret_cast<const char*>(model_buffer.data()), model_buffer.size());
|
||||||
|
auto* model = flatbuffer_model->GetModel();
|
||||||
|
if (!model) {
|
||||||
|
throw std::invalid_argument("Invalid model");
|
||||||
|
}
|
||||||
|
std::string content;
|
||||||
|
std::map<std::string, tensorflow::SignatureDef> signature_def_map;
|
||||||
|
auto status = tflite::GetSignatureDefMap(model, &signature_def_map);
|
||||||
|
if (status != tensorflow::Status::OK()) {
|
||||||
|
throw std::invalid_argument("Cannot parse signature def");
|
||||||
|
}
|
||||||
|
std::map<std::string, py::bytes> serialized_signature_def_map;
|
||||||
|
for (const auto& entry : signature_def_map) {
|
||||||
|
serialized_signature_def_map[entry.first] =
|
||||||
|
py::bytes(entry.second.SerializeAsString());
|
||||||
|
}
|
||||||
|
return serialized_signature_def_map;
|
||||||
|
}
|
||||||
|
|
||||||
|
py::bytes WrappedClearSignatureDefs(const std::vector<uint8_t>& model_buffer) {
|
||||||
|
auto flatbuffer_model = tflite::FlatBufferModel::BuildFromBuffer(
|
||||||
|
reinterpret_cast<const char*>(model_buffer.data()), model_buffer.size());
|
||||||
|
auto* model = flatbuffer_model->GetModel();
|
||||||
|
if (!model) {
|
||||||
|
throw std::invalid_argument("Invalid model");
|
||||||
|
}
|
||||||
|
std::string content;
|
||||||
|
auto status = tflite::ClearSignatureDefMap(model, &content);
|
||||||
|
if (status != tensorflow::Status::OK()) {
|
||||||
|
throw std::invalid_argument("An unknown error occurred");
|
||||||
|
}
|
||||||
|
return py::bytes(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
PYBIND11_MODULE(_pywrap_signature_def_util_wrapper, m) {
|
||||||
|
m.doc() = R"pbdoc(
|
||||||
|
_pywrap_signature_def_util_wrapper
|
||||||
|
-----
|
||||||
|
)pbdoc";
|
||||||
|
|
||||||
|
m.def("SetSignatureDefMap", &WrappedSetSignatureDefMap);
|
||||||
|
|
||||||
|
m.def("GetSignatureDefMap", &WrappedGetSignatureDefMap);
|
||||||
|
|
||||||
|
m.def("ClearSignatureDefs", &WrappedClearSignatureDefs);
|
||||||
|
}
|
95
tensorflow/lite/tools/signature/signature_def_utils.py
Normal file
95
tensorflow/lite/tools/signature/signature_def_utils.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utility functions related to SignatureDefs."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.core.protobuf import meta_graph_pb2
|
||||||
|
from tensorflow.lite.tools.signature import _pywrap_signature_def_util_wrapper as signature_def_util
|
||||||
|
|
||||||
|
|
||||||
|
def set_signature_defs(tflite_model, signature_def_map):
|
||||||
|
"""Sets SignatureDefs to the Metadata of a TfLite flatbuffer buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tflite_model: Binary TFLite model (bytes or bytes-like object) to which to
|
||||||
|
add signature_def.
|
||||||
|
signature_def_map: dict containing SignatureDefs to store in metadata.
|
||||||
|
Returns:
|
||||||
|
buffer: A TFLite model binary identical to model buffer with
|
||||||
|
metadata field containing SignatureDef.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError:
|
||||||
|
tflite_model buffer does not contain a valid TFLite model.
|
||||||
|
signature_def_map is empty or does not contain a SignatureDef.
|
||||||
|
"""
|
||||||
|
model = tflite_model
|
||||||
|
if not isinstance(tflite_model, bytearray):
|
||||||
|
model = bytearray(tflite_model)
|
||||||
|
serialized_signature_def_map = {
|
||||||
|
k: v.SerializeToString() for k, v in signature_def_map.items()}
|
||||||
|
model_buffer = signature_def_util.SetSignatureDefMap(
|
||||||
|
model, serialized_signature_def_map)
|
||||||
|
return model_buffer
|
||||||
|
|
||||||
|
|
||||||
|
def get_signature_defs(tflite_model):
|
||||||
|
"""Get SignatureDef dict from the Metadata of a TfLite flatbuffer buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tflite_model: TFLite model buffer to get the signature_def.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict containing serving names to SignatureDefs if exists, otherwise, empty
|
||||||
|
dict.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError:
|
||||||
|
tflite_model buffer does not contain a valid TFLite model.
|
||||||
|
DecodeError:
|
||||||
|
SignatureDef cannot be parsed from TfLite SignatureDef metadata.
|
||||||
|
"""
|
||||||
|
model = tflite_model
|
||||||
|
if not isinstance(tflite_model, bytearray):
|
||||||
|
model = bytearray(tflite_model)
|
||||||
|
serialized_signature_def_map = signature_def_util.GetSignatureDefMap(model)
|
||||||
|
def _deserialize(serialized):
|
||||||
|
signature_def = meta_graph_pb2.SignatureDef()
|
||||||
|
signature_def.ParseFromString(serialized)
|
||||||
|
return signature_def
|
||||||
|
return {k: _deserialize(v) for k, v in serialized_signature_def_map.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def clear_signature_defs(tflite_model):
|
||||||
|
"""Clears SignatureDefs from the Metadata of a TfLite flatbuffer buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tflite_model: TFLite model buffer to remove signature_defs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
buffer: A TFLite model binary identical to model buffer with
|
||||||
|
no SignatureDef metadata.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError:
|
||||||
|
tflite_model buffer does not contain a valid TFLite model.
|
||||||
|
"""
|
||||||
|
model = tflite_model
|
||||||
|
if not isinstance(tflite_model, bytearray):
|
||||||
|
model = bytearray(tflite_model)
|
||||||
|
return signature_def_util.ClearSignatureDefs(model)
|
76
tensorflow/lite/tools/signature/signature_def_utils_test.py
Normal file
76
tensorflow/lite/tools/signature/signature_def_utils_test.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for signature_def_util.py.
|
||||||
|
|
||||||
|
- Tests adding a SignatureDef to TFLite metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.core.protobuf import meta_graph_pb2
|
||||||
|
from tensorflow.lite.tools.signature import signature_def_utils
|
||||||
|
|
||||||
|
|
||||||
|
class SignatureDefUtilsTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def testAddSignatureDefToFlatbufferMetadata(self):
|
||||||
|
"""Test a SavedModel conversion has correct Metadata."""
|
||||||
|
filename = tf.compat.v1.resource_loader.get_path_to_datafile(
|
||||||
|
'../../testdata/add.bin')
|
||||||
|
if not os.path.exists(filename):
|
||||||
|
raise IOError('File "{0}" does not exist in {1}.'.format(
|
||||||
|
filename,
|
||||||
|
tf.compat.v1.resource_loader.get_root_dir_with_all_resources()))
|
||||||
|
|
||||||
|
with tf.io.gfile.GFile(filename, 'rb') as fp:
|
||||||
|
tflite_model = bytearray(fp.read())
|
||||||
|
|
||||||
|
self.assertIsNotNone(tflite_model, 'TFLite model is none')
|
||||||
|
sig_input_tensor = meta_graph_pb2.TensorInfo(
|
||||||
|
dtype=tf.as_dtype(tf.float32).as_datatype_enum,
|
||||||
|
tensor_shape=tf.TensorShape([1, 8, 8, 3]).as_proto())
|
||||||
|
sig_input_tensor_signature = {'x': sig_input_tensor}
|
||||||
|
sig_output_tensor = meta_graph_pb2.TensorInfo(
|
||||||
|
dtype=tf.as_dtype(tf.float32).as_datatype_enum,
|
||||||
|
tensor_shape=tf.TensorShape([1, 8, 8, 3]).as_proto())
|
||||||
|
sig_output_tensor_signature = {'y': sig_output_tensor}
|
||||||
|
predict_signature_def = (
|
||||||
|
tf.compat.v1.saved_model.build_signature_def(
|
||||||
|
sig_input_tensor_signature, sig_output_tensor_signature,
|
||||||
|
tf.saved_model.PREDICT_METHOD_NAME))
|
||||||
|
serving_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
||||||
|
signature_def_map = {serving_key: predict_signature_def}
|
||||||
|
tflite_model = signature_def_utils.set_signature_defs(
|
||||||
|
tflite_model, signature_def_map)
|
||||||
|
saved_signature_def_map = signature_def_utils.get_signature_defs(
|
||||||
|
tflite_model)
|
||||||
|
signature_def = saved_signature_def_map.get(serving_key)
|
||||||
|
self.assertIsNotNone(signature_def, 'SignatureDef not found')
|
||||||
|
self.assertEqual(signature_def.SerializeToString(),
|
||||||
|
predict_signature_def.SerializeToString())
|
||||||
|
remove_tflite_model = (
|
||||||
|
signature_def_utils.clear_signature_defs(tflite_model))
|
||||||
|
signature_def_map = signature_def_utils.get_signature_defs(
|
||||||
|
remove_tflite_model)
|
||||||
|
self.assertIsNone(signature_def_map.get(serving_key),
|
||||||
|
'SignatureDef found, but should be missing')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
Loading…
Reference in New Issue
Block a user