Adds utility methods for storing SignatureDefs in the metadata table in the flatbuffer
PiperOrigin-RevId: 311652937 Change-Id: I397c7ce6fad843cff789dedb583d6df44545db3f
This commit is contained in:
parent
efa3fb28d9
commit
37df93331e
106
tensorflow/lite/tools/signature/BUILD
Normal file
106
tensorflow/lite/tools/signature/BUILD
Normal file
@ -0,0 +1,106 @@
|
||||
# Utilities for signature_defs in TFLite
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
load("//tensorflow:tensorflow.bzl", "if_not_windows")
|
||||
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
|
||||
load("//tensorflow/lite/micro:build_def.bzl", "cc_library")
|
||||
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
TFLITE_DEFAULT_COPTS = if_not_windows([
|
||||
"-Wall",
|
||||
"-Wno-comment",
|
||||
"-Wno-extern-c-compat",
|
||||
])
|
||||
|
||||
cc_library(
|
||||
name = "signature_def_util",
|
||||
srcs = ["signature_def_util.cc"],
|
||||
hdrs = ["signature_def_util.h"],
|
||||
copts = TFLITE_DEFAULT_COPTS + tflite_copts(),
|
||||
deps = [
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:protos_all_cc_impl",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_protobuf//:protobuf",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "signature_def_util_test",
|
||||
size = "small",
|
||||
srcs = ["signature_def_util_test.cc"],
|
||||
data = [
|
||||
"//tensorflow/lite:testdata/add.bin",
|
||||
],
|
||||
tags = [
|
||||
"tflite_not_portable",
|
||||
],
|
||||
deps = [
|
||||
":signature_def_util",
|
||||
"//tensorflow/cc/saved_model:signature_constants",
|
||||
"//tensorflow/core:tflite_portable_logging",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/lite:framework_lib",
|
||||
"//tensorflow/lite/c:c_api",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_pywrap_signature_def_util_wrapper",
|
||||
srcs = [
|
||||
"signature_def_util_wrapper_pybind11.cc",
|
||||
],
|
||||
module_name = "_pywrap_signature_def_util_wrapper",
|
||||
deps = [
|
||||
":signature_def_util",
|
||||
"//tensorflow/lite:framework_lib",
|
||||
"//tensorflow/python:pybind11_lib",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "signature_def_utils",
|
||||
srcs = ["signature_def_utils.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":_pywrap_signature_def_util_wrapper",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "signature_def_utils_test",
|
||||
srcs = ["signature_def_utils_test.py"],
|
||||
data = ["//tensorflow/lite:testdata/add.bin"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_mac",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":signature_def_utils",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
)
|
||||
|
||||
tflite_portable_test_suite()
|
175
tensorflow/lite/tools/signature/signature_def_util.cc
Normal file
175
tensorflow/lite/tools/signature/signature_def_util.cc
Normal file
@ -0,0 +1,175 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/tools/signature/signature_def_util.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/lite/model_builder.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using tensorflow::Status;
|
||||
using SerializedSignatureDefMap = std::map<std::string, std::string>;
|
||||
using SignatureDefMap = std::map<std::string, tensorflow::SignatureDef>;
|
||||
|
||||
const Metadata* GetSignatureDefMetadata(const Model* model) {
|
||||
if (!model || !model->metadata()) {
|
||||
return nullptr;
|
||||
}
|
||||
for (int i = 0; i < model->metadata()->size(); ++i) {
|
||||
const Metadata* metadata = model->metadata()->Get(i);
|
||||
if (metadata->name()->str() == kSignatureDefsMetadataName) {
|
||||
return metadata;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Status ReadSignatureDefMap(const Model* model, const Metadata* metadata,
|
||||
SerializedSignatureDefMap* map) {
|
||||
if (!model || !metadata || !map) {
|
||||
return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
|
||||
}
|
||||
const flatbuffers::Vector<uint8_t>* flatbuffer_data =
|
||||
model->buffers()->Get(metadata->buffer())->data();
|
||||
const auto signature_defs =
|
||||
flexbuffers::GetRoot(flatbuffer_data->data(), flatbuffer_data->size())
|
||||
.AsMap();
|
||||
for (int i = 0; i < signature_defs.Keys().size(); ++i) {
|
||||
const std::string key = signature_defs.Keys()[i].AsString().c_str();
|
||||
(*map)[key] = signature_defs[key].AsString().c_str();
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status SetSignatureDefMap(const Model* model,
|
||||
const SignatureDefMap& signature_def_map,
|
||||
std::string* model_data_with_signature_def) {
|
||||
if (!model || !model_data_with_signature_def) {
|
||||
return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
|
||||
}
|
||||
if (signature_def_map.empty()) {
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"signature_def_map should not be empty");
|
||||
}
|
||||
flexbuffers::Builder fbb;
|
||||
const size_t start_map = fbb.StartMap();
|
||||
auto mutable_model = absl::make_unique<ModelT>();
|
||||
model->UnPackTo(mutable_model.get(), nullptr);
|
||||
int buffer_id = mutable_model->buffers.size();
|
||||
const Metadata* metadata = GetSignatureDefMetadata(model);
|
||||
if (metadata) {
|
||||
buffer_id = metadata->buffer();
|
||||
} else {
|
||||
auto buffer = absl::make_unique<BufferT>();
|
||||
mutable_model->buffers.emplace_back(std::move(buffer));
|
||||
auto sigdef_metadata = absl::make_unique<MetadataT>();
|
||||
sigdef_metadata->buffer = buffer_id;
|
||||
sigdef_metadata->name = kSignatureDefsMetadataName;
|
||||
mutable_model->metadata.emplace_back(std::move(sigdef_metadata));
|
||||
}
|
||||
for (const auto& entry : signature_def_map) {
|
||||
fbb.String(entry.first.c_str(), entry.second.SerializeAsString());
|
||||
}
|
||||
fbb.EndMap(start_map);
|
||||
fbb.Finish();
|
||||
mutable_model->buffers[buffer_id]->data = fbb.GetBuffer();
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
auto packed_model = Model::Pack(builder, mutable_model.get());
|
||||
FinishModelBuffer(builder, packed_model);
|
||||
*model_data_with_signature_def =
|
||||
std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
|
||||
builder.GetSize());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool HasSignatureDef(const Model* model, const std::string& signature_key) {
|
||||
if (!model) {
|
||||
return false;
|
||||
}
|
||||
const Metadata* metadata = GetSignatureDefMetadata(model);
|
||||
if (!metadata) {
|
||||
return false;
|
||||
}
|
||||
SerializedSignatureDefMap signature_defs;
|
||||
if (ReadSignatureDefMap(model, metadata, &signature_defs) !=
|
||||
tensorflow::Status::OK()) {
|
||||
return false;
|
||||
}
|
||||
return (signature_defs.find(signature_key) != signature_defs.end());
|
||||
}
|
||||
|
||||
Status GetSignatureDefMap(const Model* model,
|
||||
SignatureDefMap* signature_def_map) {
|
||||
if (!model || !signature_def_map) {
|
||||
return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
|
||||
}
|
||||
SignatureDefMap retrieved_signature_def_map;
|
||||
const Metadata* metadata = GetSignatureDefMetadata(model);
|
||||
if (metadata) {
|
||||
SerializedSignatureDefMap signature_defs;
|
||||
auto status = ReadSignatureDefMap(model, metadata, &signature_defs);
|
||||
if (status != tensorflow::Status::OK()) {
|
||||
return tensorflow::errors::Internal("Error reading signature def map: %s",
|
||||
status.error_message());
|
||||
}
|
||||
for (const auto& entry : signature_defs) {
|
||||
tensorflow::SignatureDef signature_def;
|
||||
if (!signature_def.ParseFromString(entry.second)) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Cannot parse signature def found in flatbuffer.");
|
||||
}
|
||||
retrieved_signature_def_map[entry.first] = signature_def;
|
||||
}
|
||||
*signature_def_map = retrieved_signature_def_map;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClearSignatureDefMap(const Model* model, std::string* model_data) {
|
||||
if (!model || !model_data) {
|
||||
return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
|
||||
}
|
||||
auto mutable_model = absl::make_unique<ModelT>();
|
||||
model->UnPackTo(mutable_model.get(), nullptr);
|
||||
for (int id = 0; id < model->metadata()->size(); ++id) {
|
||||
const Metadata* metadata = model->metadata()->Get(id);
|
||||
if (metadata->name()->str() == kSignatureDefsMetadataName) {
|
||||
auto* buffers = &(mutable_model->buffers);
|
||||
buffers->erase(buffers->begin() + metadata->buffer());
|
||||
mutable_model->metadata.erase(mutable_model->metadata.begin() + id);
|
||||
break;
|
||||
}
|
||||
}
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
auto packed_model = Model::Pack(builder, mutable_model.get());
|
||||
FinishModelBuffer(builder, packed_model);
|
||||
*model_data =
|
||||
std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
|
||||
builder.GetSize());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tflite
|
71
tensorflow/lite/tools/signature/signature_def_util.h
Normal file
71
tensorflow/lite/tools/signature/signature_def_util.h
Normal file
@ -0,0 +1,71 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_SIGNATURE_DEF_UTIL_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_SIGNATURE_DEF_UTIL_H_
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Constant for name of the Metadata entry associated with SignatureDefs.
|
||||
constexpr char kSignatureDefsMetadataName[] = "signature_defs_metadata";
|
||||
|
||||
// The function `SetSignatureDefMap()` results in
|
||||
// `model_data_with_signature_defs` containing a serialized TFLite model
|
||||
// identical to `model` with a metadata and associated buffer containing
|
||||
// a FlexBuffer::Map with `signature_def_map` keys and values serialized to
|
||||
// String.
|
||||
//
|
||||
// If a Metadata entry containing a SignatureDef map exists, it will be
|
||||
// overwritten.
|
||||
//
|
||||
// Returns error if `model_data_with_signature_defs` is null or
|
||||
// `signature_def_map` is empty.
|
||||
//
|
||||
// On success, returns tensorflow::Status::OK() or error otherwise.
|
||||
// On error, `model_data_with_signature_defs` is unchanged.
|
||||
tensorflow::Status SetSignatureDefMap(
|
||||
const Model* model,
|
||||
const std::map<std::string, tensorflow::SignatureDef>& signature_def_map,
|
||||
std::string* model_data_with_signature_defs);
|
||||
|
||||
// The function `HasSignatureDef()` returns true if `model` contains a Metadata
|
||||
// table pointing to a buffer containing a FlexBuffer::Map and the map has
|
||||
// `signature_key` as a key, or false otherwise.
|
||||
bool HasSignatureDef(const Model* model, const std::string& signature_key);
|
||||
|
||||
// The function `GetSignatureDefMap()` results in `signature_def_map`
|
||||
// pointing to a map<std::string, tensorflow::SignatureDef>
|
||||
// parsed from `model`'s metadata buffer.
|
||||
//
|
||||
// If the Metadata entry does not exist, `signature_def_map` is unchanged.
|
||||
// If the Metadata entry exists but cannot be parsed, returns an error.
|
||||
tensorflow::Status GetSignatureDefMap(
|
||||
const Model* model,
|
||||
std::map<std::string, tensorflow::SignatureDef>* signature_def_map);
|
||||
|
||||
// The function `ClearSignatureDefs` results in `model_data`
|
||||
// containing a serialized Model identical to `model` omitting any
|
||||
// SignatureDef-related metadata or buffers.
|
||||
tensorflow::Status ClearSignatureDefMap(const Model* model,
|
||||
std::string* model_data);
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_TOOLS_SIGNATURE_DEF_UTIL_H_
|
167
tensorflow/lite/tools/signature/signature_def_util_test.cc
Normal file
167
tensorflow/lite/tools/signature/signature_def_util_test.cc
Normal file
@ -0,0 +1,167 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/tools/signature/signature_def_util.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/cc/saved_model/signature_constants.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/c/c_api.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/model_builder.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using tensorflow::kClassifyMethodName;
|
||||
using tensorflow::kDefaultServingSignatureDefKey;
|
||||
using tensorflow::kPredictMethodName;
|
||||
using tensorflow::SignatureDef;
|
||||
using tensorflow::Status;
|
||||
|
||||
constexpr char kSignatureInput[] = "input";
|
||||
constexpr char kSignatureOutput[] = "output";
|
||||
constexpr char kTestFilePath[] = "tensorflow/lite/testdata/add.bin";
|
||||
|
||||
class SimpleSignatureDefUtilTest : public testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
flatbuffer_model_ = FlatBufferModel::BuildFromFile(kTestFilePath);
|
||||
ASSERT_NE(flatbuffer_model_, nullptr);
|
||||
model_ = flatbuffer_model_->GetModel();
|
||||
ASSERT_NE(model_, nullptr);
|
||||
}
|
||||
|
||||
SignatureDef GetTestSignatureDef() {
|
||||
auto signature_def = SignatureDef();
|
||||
tensorflow::TensorInfo input_tensor;
|
||||
tensorflow::TensorInfo output_tensor;
|
||||
*input_tensor.mutable_name() = kSignatureInput;
|
||||
*output_tensor.mutable_name() = kSignatureOutput;
|
||||
*signature_def.mutable_method_name() = kClassifyMethodName;
|
||||
(*signature_def.mutable_inputs())[kSignatureInput] = input_tensor;
|
||||
(*signature_def.mutable_outputs())[kSignatureOutput] = output_tensor;
|
||||
return signature_def;
|
||||
}
|
||||
std::unique_ptr<FlatBufferModel> flatbuffer_model_;
|
||||
const Model* model_;
|
||||
};
|
||||
|
||||
TEST_F(SimpleSignatureDefUtilTest, SetSignatureDefTest) {
|
||||
SignatureDef expected_signature_def = GetTestSignatureDef();
|
||||
std::string model_output;
|
||||
const std::map<string, SignatureDef> expected_signature_def_map = {
|
||||
{kDefaultServingSignatureDefKey, expected_signature_def}};
|
||||
EXPECT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
|
||||
&model_output));
|
||||
const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
|
||||
EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));
|
||||
std::map<string, SignatureDef> test_signature_def_map;
|
||||
EXPECT_EQ(Status::OK(),
|
||||
GetSignatureDefMap(add_model, &test_signature_def_map));
|
||||
SignatureDef test_signature_def =
|
||||
test_signature_def_map[kDefaultServingSignatureDefKey];
|
||||
EXPECT_EQ(expected_signature_def.SerializeAsString(),
|
||||
test_signature_def.SerializeAsString());
|
||||
}
|
||||
|
||||
TEST_F(SimpleSignatureDefUtilTest, OverwriteSignatureDefTest) {
|
||||
auto expected_signature_def = GetTestSignatureDef();
|
||||
std::string model_output;
|
||||
std::map<string, SignatureDef> expected_signature_def_map = {
|
||||
{kDefaultServingSignatureDefKey, expected_signature_def}};
|
||||
EXPECT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
|
||||
&model_output));
|
||||
const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
|
||||
EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));
|
||||
std::map<string, SignatureDef> test_signature_def_map;
|
||||
EXPECT_EQ(Status::OK(),
|
||||
GetSignatureDefMap(add_model, &test_signature_def_map));
|
||||
SignatureDef test_signature_def =
|
||||
test_signature_def_map[kDefaultServingSignatureDefKey];
|
||||
EXPECT_EQ(expected_signature_def.SerializeAsString(),
|
||||
test_signature_def.SerializeAsString());
|
||||
*expected_signature_def.mutable_method_name() = kPredictMethodName;
|
||||
expected_signature_def_map.erase(
|
||||
expected_signature_def_map.find(kDefaultServingSignatureDefKey));
|
||||
constexpr char kTestSignatureDefKey[] = "ServingTest";
|
||||
expected_signature_def_map[kTestSignatureDefKey] = expected_signature_def;
|
||||
EXPECT_EQ(
|
||||
Status::OK(),
|
||||
SetSignatureDefMap(add_model, expected_signature_def_map, &model_output));
|
||||
const Model* final_model = flatbuffers::GetRoot<Model>(model_output.data());
|
||||
EXPECT_FALSE(HasSignatureDef(final_model, kDefaultServingSignatureDefKey));
|
||||
EXPECT_EQ(Status::OK(),
|
||||
GetSignatureDefMap(final_model, &test_signature_def_map));
|
||||
EXPECT_NE(expected_signature_def.SerializeAsString(),
|
||||
test_signature_def.SerializeAsString());
|
||||
EXPECT_TRUE(HasSignatureDef(final_model, kTestSignatureDefKey));
|
||||
EXPECT_EQ(Status::OK(),
|
||||
GetSignatureDefMap(final_model, &test_signature_def_map));
|
||||
test_signature_def = test_signature_def_map[kTestSignatureDefKey];
|
||||
EXPECT_EQ(expected_signature_def.SerializeAsString(),
|
||||
test_signature_def.SerializeAsString());
|
||||
}
|
||||
|
||||
TEST_F(SimpleSignatureDefUtilTest, GetSignatureDefTest) {
|
||||
std::map<string, SignatureDef> test_signature_def_map;
|
||||
EXPECT_EQ(Status::OK(), GetSignatureDefMap(model_, &test_signature_def_map));
|
||||
EXPECT_FALSE(HasSignatureDef(model_, kDefaultServingSignatureDefKey));
|
||||
}
|
||||
|
||||
TEST_F(SimpleSignatureDefUtilTest, ClearSignatureDefTest) {
|
||||
const int expected_num_buffers = model_->buffers()->size();
|
||||
auto expected_signature_def = GetTestSignatureDef();
|
||||
std::string model_output;
|
||||
std::map<string, SignatureDef> expected_signature_def_map = {
|
||||
{kDefaultServingSignatureDefKey, expected_signature_def}};
|
||||
EXPECT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
|
||||
&model_output));
|
||||
const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
|
||||
EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));
|
||||
SignatureDef test_signature_def;
|
||||
std::map<string, SignatureDef> test_signature_def_map;
|
||||
EXPECT_EQ(Status::OK(),
|
||||
GetSignatureDefMap(add_model, &test_signature_def_map));
|
||||
test_signature_def = test_signature_def_map[kDefaultServingSignatureDefKey];
|
||||
EXPECT_EQ(expected_signature_def.SerializeAsString(),
|
||||
test_signature_def.SerializeAsString());
|
||||
EXPECT_EQ(Status::OK(), ClearSignatureDefMap(add_model, &model_output));
|
||||
const Model* clear_model = flatbuffers::GetRoot<Model>(model_output.data());
|
||||
EXPECT_FALSE(HasSignatureDef(clear_model, kDefaultServingSignatureDefKey));
|
||||
EXPECT_EQ(expected_num_buffers, clear_model->buffers()->size());
|
||||
}
|
||||
|
||||
TEST_F(SimpleSignatureDefUtilTest, SetSignatureDefErrorsTest) {
|
||||
std::map<string, SignatureDef> test_signature_def_map;
|
||||
std::string model_output;
|
||||
EXPECT_TRUE(tensorflow::errors::IsInvalidArgument(
|
||||
SetSignatureDefMap(model_, test_signature_def_map, &model_output)));
|
||||
SignatureDef test_signature_def;
|
||||
test_signature_def_map[kDefaultServingSignatureDefKey] = test_signature_def;
|
||||
EXPECT_TRUE(tensorflow::errors::IsInvalidArgument(
|
||||
SetSignatureDefMap(model_, test_signature_def_map, nullptr)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -0,0 +1,95 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "tensorflow/lite/model_builder.h"
|
||||
#include "tensorflow/lite/tools/signature/signature_def_util.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
|
||||
py::bytes WrappedSetSignatureDefMap(
|
||||
const std::vector<uint8_t>& model_buffer,
|
||||
const std::map<std::string, std::string>& serialized_signature_def_map) {
|
||||
auto flatbuffer_model = tflite::FlatBufferModel::BuildFromBuffer(
|
||||
reinterpret_cast<const char*>(model_buffer.data()), model_buffer.size());
|
||||
auto* model = flatbuffer_model->GetModel();
|
||||
if (!model) {
|
||||
throw std::invalid_argument("Invalid model");
|
||||
}
|
||||
std::string data;
|
||||
std::map<std::string, tensorflow::SignatureDef> signature_def_map;
|
||||
for (const auto& entry : serialized_signature_def_map) {
|
||||
tensorflow::SignatureDef signature_def;
|
||||
if (!signature_def.ParseFromString(entry.second)) {
|
||||
throw std::invalid_argument("Cannot parse signature def");
|
||||
}
|
||||
signature_def_map[entry.first] = signature_def;
|
||||
}
|
||||
auto status = tflite::SetSignatureDefMap(model, signature_def_map, &data);
|
||||
if (status != tensorflow::Status::OK()) {
|
||||
throw std::invalid_argument(status.error_message());
|
||||
}
|
||||
return py::bytes(data);
|
||||
}
|
||||
|
||||
std::map<std::string, py::bytes> WrappedGetSignatureDefMap(
|
||||
const std::vector<uint8_t>& model_buffer) {
|
||||
auto flatbuffer_model = tflite::FlatBufferModel::BuildFromBuffer(
|
||||
reinterpret_cast<const char*>(model_buffer.data()), model_buffer.size());
|
||||
auto* model = flatbuffer_model->GetModel();
|
||||
if (!model) {
|
||||
throw std::invalid_argument("Invalid model");
|
||||
}
|
||||
std::string content;
|
||||
std::map<std::string, tensorflow::SignatureDef> signature_def_map;
|
||||
auto status = tflite::GetSignatureDefMap(model, &signature_def_map);
|
||||
if (status != tensorflow::Status::OK()) {
|
||||
throw std::invalid_argument("Cannot parse signature def");
|
||||
}
|
||||
std::map<std::string, py::bytes> serialized_signature_def_map;
|
||||
for (const auto& entry : signature_def_map) {
|
||||
serialized_signature_def_map[entry.first] =
|
||||
py::bytes(entry.second.SerializeAsString());
|
||||
}
|
||||
return serialized_signature_def_map;
|
||||
}
|
||||
|
||||
py::bytes WrappedClearSignatureDefs(const std::vector<uint8_t>& model_buffer) {
|
||||
auto flatbuffer_model = tflite::FlatBufferModel::BuildFromBuffer(
|
||||
reinterpret_cast<const char*>(model_buffer.data()), model_buffer.size());
|
||||
auto* model = flatbuffer_model->GetModel();
|
||||
if (!model) {
|
||||
throw std::invalid_argument("Invalid model");
|
||||
}
|
||||
std::string content;
|
||||
auto status = tflite::ClearSignatureDefMap(model, &content);
|
||||
if (status != tensorflow::Status::OK()) {
|
||||
throw std::invalid_argument("An unknown error occurred");
|
||||
}
|
||||
return py::bytes(content);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_pywrap_signature_def_util_wrapper, m) {
|
||||
m.doc() = R"pbdoc(
|
||||
_pywrap_signature_def_util_wrapper
|
||||
-----
|
||||
)pbdoc";
|
||||
|
||||
m.def("SetSignatureDefMap", &WrappedSetSignatureDefMap);
|
||||
|
||||
m.def("GetSignatureDefMap", &WrappedGetSignatureDefMap);
|
||||
|
||||
m.def("ClearSignatureDefs", &WrappedClearSignatureDefs);
|
||||
}
|
95
tensorflow/lite/tools/signature/signature_def_utils.py
Normal file
95
tensorflow/lite/tools/signature/signature_def_utils.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utility functions related to SignatureDefs."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.lite.tools.signature import _pywrap_signature_def_util_wrapper as signature_def_util
|
||||
|
||||
|
||||
def set_signature_defs(tflite_model, signature_def_map):
|
||||
"""Sets SignatureDefs to the Metadata of a TfLite flatbuffer buffer.
|
||||
|
||||
Args:
|
||||
tflite_model: Binary TFLite model (bytes or bytes-like object) to which to
|
||||
add signature_def.
|
||||
signature_def_map: dict containing SignatureDefs to store in metadata.
|
||||
Returns:
|
||||
buffer: A TFLite model binary identical to model buffer with
|
||||
metadata field containing SignatureDef.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
tflite_model buffer does not contain a valid TFLite model.
|
||||
signature_def_map is empty or does not contain a SignatureDef.
|
||||
"""
|
||||
model = tflite_model
|
||||
if not isinstance(tflite_model, bytearray):
|
||||
model = bytearray(tflite_model)
|
||||
serialized_signature_def_map = {
|
||||
k: v.SerializeToString() for k, v in signature_def_map.items()}
|
||||
model_buffer = signature_def_util.SetSignatureDefMap(
|
||||
model, serialized_signature_def_map)
|
||||
return model_buffer
|
||||
|
||||
|
||||
def get_signature_defs(tflite_model):
|
||||
"""Get SignatureDef dict from the Metadata of a TfLite flatbuffer buffer.
|
||||
|
||||
Args:
|
||||
tflite_model: TFLite model buffer to get the signature_def.
|
||||
|
||||
Returns:
|
||||
dict containing serving names to SignatureDefs if exists, otherwise, empty
|
||||
dict.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
tflite_model buffer does not contain a valid TFLite model.
|
||||
DecodeError:
|
||||
SignatureDef cannot be parsed from TfLite SignatureDef metadata.
|
||||
"""
|
||||
model = tflite_model
|
||||
if not isinstance(tflite_model, bytearray):
|
||||
model = bytearray(tflite_model)
|
||||
serialized_signature_def_map = signature_def_util.GetSignatureDefMap(model)
|
||||
def _deserialize(serialized):
|
||||
signature_def = meta_graph_pb2.SignatureDef()
|
||||
signature_def.ParseFromString(serialized)
|
||||
return signature_def
|
||||
return {k: _deserialize(v) for k, v in serialized_signature_def_map.items()}
|
||||
|
||||
|
||||
def clear_signature_defs(tflite_model):
|
||||
"""Clears SignatureDefs from the Metadata of a TfLite flatbuffer buffer.
|
||||
|
||||
Args:
|
||||
tflite_model: TFLite model buffer to remove signature_defs.
|
||||
|
||||
Returns:
|
||||
buffer: A TFLite model binary identical to model buffer with
|
||||
no SignatureDef metadata.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
tflite_model buffer does not contain a valid TFLite model.
|
||||
"""
|
||||
model = tflite_model
|
||||
if not isinstance(tflite_model, bytearray):
|
||||
model = bytearray(tflite_model)
|
||||
return signature_def_util.ClearSignatureDefs(model)
|
76
tensorflow/lite/tools/signature/signature_def_utils_test.py
Normal file
76
tensorflow/lite/tools/signature/signature_def_utils_test.py
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for signature_def_util.py.
|
||||
|
||||
- Tests adding a SignatureDef to TFLite metadata.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tensorflow as tf
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.lite.tools.signature import signature_def_utils
|
||||
|
||||
|
||||
class SignatureDefUtilsTest(tf.test.TestCase):
|
||||
|
||||
def testAddSignatureDefToFlatbufferMetadata(self):
|
||||
"""Test a SavedModel conversion has correct Metadata."""
|
||||
filename = tf.compat.v1.resource_loader.get_path_to_datafile(
|
||||
'../../testdata/add.bin')
|
||||
if not os.path.exists(filename):
|
||||
raise IOError('File "{0}" does not exist in {1}.'.format(
|
||||
filename,
|
||||
tf.compat.v1.resource_loader.get_root_dir_with_all_resources()))
|
||||
|
||||
with tf.io.gfile.GFile(filename, 'rb') as fp:
|
||||
tflite_model = bytearray(fp.read())
|
||||
|
||||
self.assertIsNotNone(tflite_model, 'TFLite model is none')
|
||||
sig_input_tensor = meta_graph_pb2.TensorInfo(
|
||||
dtype=tf.as_dtype(tf.float32).as_datatype_enum,
|
||||
tensor_shape=tf.TensorShape([1, 8, 8, 3]).as_proto())
|
||||
sig_input_tensor_signature = {'x': sig_input_tensor}
|
||||
sig_output_tensor = meta_graph_pb2.TensorInfo(
|
||||
dtype=tf.as_dtype(tf.float32).as_datatype_enum,
|
||||
tensor_shape=tf.TensorShape([1, 8, 8, 3]).as_proto())
|
||||
sig_output_tensor_signature = {'y': sig_output_tensor}
|
||||
predict_signature_def = (
|
||||
tf.compat.v1.saved_model.build_signature_def(
|
||||
sig_input_tensor_signature, sig_output_tensor_signature,
|
||||
tf.saved_model.PREDICT_METHOD_NAME))
|
||||
serving_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
||||
signature_def_map = {serving_key: predict_signature_def}
|
||||
tflite_model = signature_def_utils.set_signature_defs(
|
||||
tflite_model, signature_def_map)
|
||||
saved_signature_def_map = signature_def_utils.get_signature_defs(
|
||||
tflite_model)
|
||||
signature_def = saved_signature_def_map.get(serving_key)
|
||||
self.assertIsNotNone(signature_def, 'SignatureDef not found')
|
||||
self.assertEqual(signature_def.SerializeToString(),
|
||||
predict_signature_def.SerializeToString())
|
||||
remove_tflite_model = (
|
||||
signature_def_utils.clear_signature_defs(tflite_model))
|
||||
signature_def_map = signature_def_utils.get_signature_defs(
|
||||
remove_tflite_model)
|
||||
self.assertIsNone(signature_def_map.get(serving_key),
|
||||
'SignatureDef found, but should be missing')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
Loading…
Reference in New Issue
Block a user