Adds utility methods for storing SignatureDefs in the metadata table in the flatbuffer

PiperOrigin-RevId: 311652937
Change-Id: I397c7ce6fad843cff789dedb583d6df44545db3f
This commit is contained in:
David Rim 2020-05-14 19:19:49 -07:00 committed by TensorFlower Gardener
parent efa3fb28d9
commit 37df93331e
7 changed files with 785 additions and 0 deletions

View File

@ -0,0 +1,106 @@
# Utilities for signature_defs in TFLite
load("//tensorflow:tensorflow.bzl", "pybind_extension")
load("//tensorflow:tensorflow.bzl", "if_not_windows")
load("//tensorflow/lite:build_def.bzl", "tflite_copts")
load("//tensorflow/lite/micro:build_def.bzl", "cc_library")
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
TFLITE_DEFAULT_COPTS = if_not_windows([
"-Wall",
"-Wno-comment",
"-Wno-extern-c-compat",
])
cc_library(
name = "signature_def_util",
srcs = ["signature_def_util.cc"],
hdrs = ["signature_def_util.h"],
copts = TFLITE_DEFAULT_COPTS + tflite_copts(),
deps = [
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:protos_all_cc_impl",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/memory",
"@com_google_protobuf//:protobuf",
"@flatbuffers",
],
)
cc_test(
name = "signature_def_util_test",
size = "small",
srcs = ["signature_def_util_test.cc"],
data = [
"//tensorflow/lite:testdata/add.bin",
],
tags = [
"tflite_not_portable",
],
deps = [
":signature_def_util",
"//tensorflow/cc/saved_model:signature_constants",
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/core/platform:errors",
"//tensorflow/lite:framework_lib",
"//tensorflow/lite/c:c_api",
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
)
pybind_extension(
name = "_pywrap_signature_def_util_wrapper",
srcs = [
"signature_def_util_wrapper_pybind11.cc",
],
module_name = "_pywrap_signature_def_util_wrapper",
deps = [
":signature_def_util",
"//tensorflow/lite:framework_lib",
"//tensorflow/python:pybind11_lib",
"@pybind11",
],
)
py_library(
name = "signature_def_utils",
srcs = ["signature_def_utils.py"],
srcs_version = "PY2AND3",
deps = [
":_pywrap_signature_def_util_wrapper",
"//tensorflow/core:protos_all_py",
],
)
py_test(
name = "signature_def_utils_test",
srcs = ["signature_def_utils_test.py"],
data = ["//tensorflow/lite:testdata/add.bin"],
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_mac",
],
visibility = ["//visibility:public"],
deps = [
":signature_def_utils",
"//tensorflow:tensorflow_py",
"//tensorflow/core:protos_all_py",
],
)
tflite_portable_test_suite()

View File

@ -0,0 +1,175 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/signature/signature_def_util.h"
#include <string>
#include "absl/memory/memory.h"
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace {
using tensorflow::Status;
using SerializedSignatureDefMap = std::map<std::string, std::string>;
using SignatureDefMap = std::map<std::string, tensorflow::SignatureDef>;
const Metadata* GetSignatureDefMetadata(const Model* model) {
if (!model || !model->metadata()) {
return nullptr;
}
for (int i = 0; i < model->metadata()->size(); ++i) {
const Metadata* metadata = model->metadata()->Get(i);
if (metadata->name()->str() == kSignatureDefsMetadataName) {
return metadata;
}
}
return nullptr;
}
Status ReadSignatureDefMap(const Model* model, const Metadata* metadata,
SerializedSignatureDefMap* map) {
if (!model || !metadata || !map) {
return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
}
const flatbuffers::Vector<uint8_t>* flatbuffer_data =
model->buffers()->Get(metadata->buffer())->data();
const auto signature_defs =
flexbuffers::GetRoot(flatbuffer_data->data(), flatbuffer_data->size())
.AsMap();
for (int i = 0; i < signature_defs.Keys().size(); ++i) {
const std::string key = signature_defs.Keys()[i].AsString().c_str();
(*map)[key] = signature_defs[key].AsString().c_str();
}
return tensorflow::Status::OK();
}
} // namespace
Status SetSignatureDefMap(const Model* model,
const SignatureDefMap& signature_def_map,
std::string* model_data_with_signature_def) {
if (!model || !model_data_with_signature_def) {
return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
}
if (signature_def_map.empty()) {
return tensorflow::errors::InvalidArgument(
"signature_def_map should not be empty");
}
flexbuffers::Builder fbb;
const size_t start_map = fbb.StartMap();
auto mutable_model = absl::make_unique<ModelT>();
model->UnPackTo(mutable_model.get(), nullptr);
int buffer_id = mutable_model->buffers.size();
const Metadata* metadata = GetSignatureDefMetadata(model);
if (metadata) {
buffer_id = metadata->buffer();
} else {
auto buffer = absl::make_unique<BufferT>();
mutable_model->buffers.emplace_back(std::move(buffer));
auto sigdef_metadata = absl::make_unique<MetadataT>();
sigdef_metadata->buffer = buffer_id;
sigdef_metadata->name = kSignatureDefsMetadataName;
mutable_model->metadata.emplace_back(std::move(sigdef_metadata));
}
for (const auto& entry : signature_def_map) {
fbb.String(entry.first.c_str(), entry.second.SerializeAsString());
}
fbb.EndMap(start_map);
fbb.Finish();
mutable_model->buffers[buffer_id]->data = fbb.GetBuffer();
flatbuffers::FlatBufferBuilder builder;
auto packed_model = Model::Pack(builder, mutable_model.get());
FinishModelBuffer(builder, packed_model);
*model_data_with_signature_def =
std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize());
return Status::OK();
}
bool HasSignatureDef(const Model* model, const std::string& signature_key) {
if (!model) {
return false;
}
const Metadata* metadata = GetSignatureDefMetadata(model);
if (!metadata) {
return false;
}
SerializedSignatureDefMap signature_defs;
if (ReadSignatureDefMap(model, metadata, &signature_defs) !=
tensorflow::Status::OK()) {
return false;
}
return (signature_defs.find(signature_key) != signature_defs.end());
}
Status GetSignatureDefMap(const Model* model,
SignatureDefMap* signature_def_map) {
if (!model || !signature_def_map) {
return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
}
SignatureDefMap retrieved_signature_def_map;
const Metadata* metadata = GetSignatureDefMetadata(model);
if (metadata) {
SerializedSignatureDefMap signature_defs;
auto status = ReadSignatureDefMap(model, metadata, &signature_defs);
if (status != tensorflow::Status::OK()) {
return tensorflow::errors::Internal("Error reading signature def map: %s",
status.error_message());
}
for (const auto& entry : signature_defs) {
tensorflow::SignatureDef signature_def;
if (!signature_def.ParseFromString(entry.second)) {
return tensorflow::errors::Internal(
"Cannot parse signature def found in flatbuffer.");
}
retrieved_signature_def_map[entry.first] = signature_def;
}
*signature_def_map = retrieved_signature_def_map;
}
return Status::OK();
}
Status ClearSignatureDefMap(const Model* model, std::string* model_data) {
if (!model || !model_data) {
return tensorflow::errors::InvalidArgument("Arguments must not be nullptr");
}
auto mutable_model = absl::make_unique<ModelT>();
model->UnPackTo(mutable_model.get(), nullptr);
for (int id = 0; id < model->metadata()->size(); ++id) {
const Metadata* metadata = model->metadata()->Get(id);
if (metadata->name()->str() == kSignatureDefsMetadataName) {
auto* buffers = &(mutable_model->buffers);
buffers->erase(buffers->begin() + metadata->buffer());
mutable_model->metadata.erase(mutable_model->metadata.begin() + id);
break;
}
}
flatbuffers::FlatBufferBuilder builder;
auto packed_model = Model::Pack(builder, mutable_model.get());
FinishModelBuffer(builder, packed_model);
*model_data =
std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize());
return Status::OK();
}
} // namespace tflite

View File

@ -0,0 +1,71 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_SIGNATURE_DEF_UTIL_H_
#define TENSORFLOW_LITE_TOOLS_SIGNATURE_DEF_UTIL_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
// Constant for name of the Metadata entry associated with SignatureDefs.
constexpr char kSignatureDefsMetadataName[] = "signature_defs_metadata";
// The function `SetSignatureDefMap()` results in
// `model_data_with_signature_defs` containing a serialized TFLite model
// identical to `model` with a metadata and associated buffer containing
// a FlexBuffer::Map with `signature_def_map` keys and values serialized to
// String.
//
// If a Metadata entry containing a SignatureDef map exists, it will be
// overwritten.
//
// Returns error if `model_data_with_signature_defs` is null or
// `signature_def_map` is empty.
//
// On success, returns tensorflow::Status::OK() or error otherwise.
// On error, `model_data_with_signature_defs` is unchanged.
tensorflow::Status SetSignatureDefMap(
const Model* model,
const std::map<std::string, tensorflow::SignatureDef>& signature_def_map,
std::string* model_data_with_signature_defs);
// The function `HasSignatureDef()` returns true if `model` contains a Metadata
// table pointing to a buffer containing a FlexBuffer::Map and the map has
// `signature_key` as a key, or false otherwise.
bool HasSignatureDef(const Model* model, const std::string& signature_key);
// The function `GetSignatureDefMap()` results in `signature_def_map`
// pointing to a map<std::string, tensorflow::SignatureDef>
// parsed from `model`'s metadata buffer.
//
// If the Metadata entry does not exist, `signature_def_map` is unchanged.
// If the Metadata entry exists but cannot be parsed, returns an error.
tensorflow::Status GetSignatureDefMap(
const Model* model,
std::map<std::string, tensorflow::SignatureDef>* signature_def_map);
// The function `ClearSignatureDefs` results in `model_data`
// containing a serialized Model identical to `model` omitting any
// SignatureDef-related metadata or buffers.
tensorflow::Status ClearSignatureDefMap(const Model* model,
std::string* model_data);
} // namespace tflite
#endif // TENSORFLOW_LITE_TOOLS_SIGNATURE_DEF_UTIL_H_

View File

@ -0,0 +1,167 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/signature/signature_def_util.h"
#include <gtest/gtest.h>
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
namespace {
using tensorflow::kClassifyMethodName;
using tensorflow::kDefaultServingSignatureDefKey;
using tensorflow::kPredictMethodName;
using tensorflow::SignatureDef;
using tensorflow::Status;
constexpr char kSignatureInput[] = "input";
constexpr char kSignatureOutput[] = "output";
constexpr char kTestFilePath[] = "tensorflow/lite/testdata/add.bin";
class SimpleSignatureDefUtilTest : public testing::Test {
protected:
void SetUp() override {
flatbuffer_model_ = FlatBufferModel::BuildFromFile(kTestFilePath);
ASSERT_NE(flatbuffer_model_, nullptr);
model_ = flatbuffer_model_->GetModel();
ASSERT_NE(model_, nullptr);
}
SignatureDef GetTestSignatureDef() {
auto signature_def = SignatureDef();
tensorflow::TensorInfo input_tensor;
tensorflow::TensorInfo output_tensor;
*input_tensor.mutable_name() = kSignatureInput;
*output_tensor.mutable_name() = kSignatureOutput;
*signature_def.mutable_method_name() = kClassifyMethodName;
(*signature_def.mutable_inputs())[kSignatureInput] = input_tensor;
(*signature_def.mutable_outputs())[kSignatureOutput] = output_tensor;
return signature_def;
}
std::unique_ptr<FlatBufferModel> flatbuffer_model_;
const Model* model_;
};
TEST_F(SimpleSignatureDefUtilTest, SetSignatureDefTest) {
SignatureDef expected_signature_def = GetTestSignatureDef();
std::string model_output;
const std::map<string, SignatureDef> expected_signature_def_map = {
{kDefaultServingSignatureDefKey, expected_signature_def}};
EXPECT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
&model_output));
const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));
std::map<string, SignatureDef> test_signature_def_map;
EXPECT_EQ(Status::OK(),
GetSignatureDefMap(add_model, &test_signature_def_map));
SignatureDef test_signature_def =
test_signature_def_map[kDefaultServingSignatureDefKey];
EXPECT_EQ(expected_signature_def.SerializeAsString(),
test_signature_def.SerializeAsString());
}
TEST_F(SimpleSignatureDefUtilTest, OverwriteSignatureDefTest) {
auto expected_signature_def = GetTestSignatureDef();
std::string model_output;
std::map<string, SignatureDef> expected_signature_def_map = {
{kDefaultServingSignatureDefKey, expected_signature_def}};
EXPECT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
&model_output));
const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));
std::map<string, SignatureDef> test_signature_def_map;
EXPECT_EQ(Status::OK(),
GetSignatureDefMap(add_model, &test_signature_def_map));
SignatureDef test_signature_def =
test_signature_def_map[kDefaultServingSignatureDefKey];
EXPECT_EQ(expected_signature_def.SerializeAsString(),
test_signature_def.SerializeAsString());
*expected_signature_def.mutable_method_name() = kPredictMethodName;
expected_signature_def_map.erase(
expected_signature_def_map.find(kDefaultServingSignatureDefKey));
constexpr char kTestSignatureDefKey[] = "ServingTest";
expected_signature_def_map[kTestSignatureDefKey] = expected_signature_def;
EXPECT_EQ(
Status::OK(),
SetSignatureDefMap(add_model, expected_signature_def_map, &model_output));
const Model* final_model = flatbuffers::GetRoot<Model>(model_output.data());
EXPECT_FALSE(HasSignatureDef(final_model, kDefaultServingSignatureDefKey));
EXPECT_EQ(Status::OK(),
GetSignatureDefMap(final_model, &test_signature_def_map));
EXPECT_NE(expected_signature_def.SerializeAsString(),
test_signature_def.SerializeAsString());
EXPECT_TRUE(HasSignatureDef(final_model, kTestSignatureDefKey));
EXPECT_EQ(Status::OK(),
GetSignatureDefMap(final_model, &test_signature_def_map));
test_signature_def = test_signature_def_map[kTestSignatureDefKey];
EXPECT_EQ(expected_signature_def.SerializeAsString(),
test_signature_def.SerializeAsString());
}
TEST_F(SimpleSignatureDefUtilTest, GetSignatureDefTest) {
std::map<string, SignatureDef> test_signature_def_map;
EXPECT_EQ(Status::OK(), GetSignatureDefMap(model_, &test_signature_def_map));
EXPECT_FALSE(HasSignatureDef(model_, kDefaultServingSignatureDefKey));
}
TEST_F(SimpleSignatureDefUtilTest, ClearSignatureDefTest) {
const int expected_num_buffers = model_->buffers()->size();
auto expected_signature_def = GetTestSignatureDef();
std::string model_output;
std::map<string, SignatureDef> expected_signature_def_map = {
{kDefaultServingSignatureDefKey, expected_signature_def}};
EXPECT_EQ(Status::OK(), SetSignatureDefMap(model_, expected_signature_def_map,
&model_output));
const Model* add_model = flatbuffers::GetRoot<Model>(model_output.data());
EXPECT_TRUE(HasSignatureDef(add_model, kDefaultServingSignatureDefKey));
SignatureDef test_signature_def;
std::map<string, SignatureDef> test_signature_def_map;
EXPECT_EQ(Status::OK(),
GetSignatureDefMap(add_model, &test_signature_def_map));
test_signature_def = test_signature_def_map[kDefaultServingSignatureDefKey];
EXPECT_EQ(expected_signature_def.SerializeAsString(),
test_signature_def.SerializeAsString());
EXPECT_EQ(Status::OK(), ClearSignatureDefMap(add_model, &model_output));
const Model* clear_model = flatbuffers::GetRoot<Model>(model_output.data());
EXPECT_FALSE(HasSignatureDef(clear_model, kDefaultServingSignatureDefKey));
EXPECT_EQ(expected_num_buffers, clear_model->buffers()->size());
}
TEST_F(SimpleSignatureDefUtilTest, SetSignatureDefErrorsTest) {
std::map<string, SignatureDef> test_signature_def_map;
std::string model_output;
EXPECT_TRUE(tensorflow::errors::IsInvalidArgument(
SetSignatureDefMap(model_, test_signature_def_map, &model_output)));
SignatureDef test_signature_def;
test_signature_def_map[kDefaultServingSignatureDefKey] = test_signature_def;
EXPECT_TRUE(tensorflow::errors::IsInvalidArgument(
SetSignatureDefMap(model_, test_signature_def_map, nullptr)));
}
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -0,0 +1,95 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/tools/signature/signature_def_util.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
py::bytes WrappedSetSignatureDefMap(
const std::vector<uint8_t>& model_buffer,
const std::map<std::string, std::string>& serialized_signature_def_map) {
auto flatbuffer_model = tflite::FlatBufferModel::BuildFromBuffer(
reinterpret_cast<const char*>(model_buffer.data()), model_buffer.size());
auto* model = flatbuffer_model->GetModel();
if (!model) {
throw std::invalid_argument("Invalid model");
}
std::string data;
std::map<std::string, tensorflow::SignatureDef> signature_def_map;
for (const auto& entry : serialized_signature_def_map) {
tensorflow::SignatureDef signature_def;
if (!signature_def.ParseFromString(entry.second)) {
throw std::invalid_argument("Cannot parse signature def");
}
signature_def_map[entry.first] = signature_def;
}
auto status = tflite::SetSignatureDefMap(model, signature_def_map, &data);
if (status != tensorflow::Status::OK()) {
throw std::invalid_argument(status.error_message());
}
return py::bytes(data);
}
std::map<std::string, py::bytes> WrappedGetSignatureDefMap(
const std::vector<uint8_t>& model_buffer) {
auto flatbuffer_model = tflite::FlatBufferModel::BuildFromBuffer(
reinterpret_cast<const char*>(model_buffer.data()), model_buffer.size());
auto* model = flatbuffer_model->GetModel();
if (!model) {
throw std::invalid_argument("Invalid model");
}
std::string content;
std::map<std::string, tensorflow::SignatureDef> signature_def_map;
auto status = tflite::GetSignatureDefMap(model, &signature_def_map);
if (status != tensorflow::Status::OK()) {
throw std::invalid_argument("Cannot parse signature def");
}
std::map<std::string, py::bytes> serialized_signature_def_map;
for (const auto& entry : signature_def_map) {
serialized_signature_def_map[entry.first] =
py::bytes(entry.second.SerializeAsString());
}
return serialized_signature_def_map;
}
py::bytes WrappedClearSignatureDefs(const std::vector<uint8_t>& model_buffer) {
auto flatbuffer_model = tflite::FlatBufferModel::BuildFromBuffer(
reinterpret_cast<const char*>(model_buffer.data()), model_buffer.size());
auto* model = flatbuffer_model->GetModel();
if (!model) {
throw std::invalid_argument("Invalid model");
}
std::string content;
auto status = tflite::ClearSignatureDefMap(model, &content);
if (status != tensorflow::Status::OK()) {
throw std::invalid_argument("An unknown error occurred");
}
return py::bytes(content);
}
PYBIND11_MODULE(_pywrap_signature_def_util_wrapper, m) {
m.doc() = R"pbdoc(
_pywrap_signature_def_util_wrapper
-----
)pbdoc";
m.def("SetSignatureDefMap", &WrappedSetSignatureDefMap);
m.def("GetSignatureDefMap", &WrappedGetSignatureDefMap);
m.def("ClearSignatureDefs", &WrappedClearSignatureDefs);
}

View File

@ -0,0 +1,95 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions related to SignatureDefs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.lite.tools.signature import _pywrap_signature_def_util_wrapper as signature_def_util
def set_signature_defs(tflite_model, signature_def_map):
"""Sets SignatureDefs to the Metadata of a TfLite flatbuffer buffer.
Args:
tflite_model: Binary TFLite model (bytes or bytes-like object) to which to
add signature_def.
signature_def_map: dict containing SignatureDefs to store in metadata.
Returns:
buffer: A TFLite model binary identical to model buffer with
metadata field containing SignatureDef.
Raises:
ValueError:
tflite_model buffer does not contain a valid TFLite model.
signature_def_map is empty or does not contain a SignatureDef.
"""
model = tflite_model
if not isinstance(tflite_model, bytearray):
model = bytearray(tflite_model)
serialized_signature_def_map = {
k: v.SerializeToString() for k, v in signature_def_map.items()}
model_buffer = signature_def_util.SetSignatureDefMap(
model, serialized_signature_def_map)
return model_buffer
def get_signature_defs(tflite_model):
"""Get SignatureDef dict from the Metadata of a TfLite flatbuffer buffer.
Args:
tflite_model: TFLite model buffer to get the signature_def.
Returns:
dict containing serving names to SignatureDefs if exists, otherwise, empty
dict.
Raises:
ValueError:
tflite_model buffer does not contain a valid TFLite model.
DecodeError:
SignatureDef cannot be parsed from TfLite SignatureDef metadata.
"""
model = tflite_model
if not isinstance(tflite_model, bytearray):
model = bytearray(tflite_model)
serialized_signature_def_map = signature_def_util.GetSignatureDefMap(model)
def _deserialize(serialized):
signature_def = meta_graph_pb2.SignatureDef()
signature_def.ParseFromString(serialized)
return signature_def
return {k: _deserialize(v) for k, v in serialized_signature_def_map.items()}
def clear_signature_defs(tflite_model):
"""Clears SignatureDefs from the Metadata of a TfLite flatbuffer buffer.
Args:
tflite_model: TFLite model buffer to remove signature_defs.
Returns:
buffer: A TFLite model binary identical to model buffer with
no SignatureDef metadata.
Raises:
ValueError:
tflite_model buffer does not contain a valid TFLite model.
"""
model = tflite_model
if not isinstance(tflite_model, bytearray):
model = bytearray(tflite_model)
return signature_def_util.ClearSignatureDefs(model)

View File

@ -0,0 +1,76 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for signature_def_util.py.
- Tests adding a SignatureDef to TFLite metadata.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.lite.tools.signature import signature_def_utils
class SignatureDefUtilsTest(tf.test.TestCase):
def testAddSignatureDefToFlatbufferMetadata(self):
"""Test a SavedModel conversion has correct Metadata."""
filename = tf.compat.v1.resource_loader.get_path_to_datafile(
'../../testdata/add.bin')
if not os.path.exists(filename):
raise IOError('File "{0}" does not exist in {1}.'.format(
filename,
tf.compat.v1.resource_loader.get_root_dir_with_all_resources()))
with tf.io.gfile.GFile(filename, 'rb') as fp:
tflite_model = bytearray(fp.read())
self.assertIsNotNone(tflite_model, 'TFLite model is none')
sig_input_tensor = meta_graph_pb2.TensorInfo(
dtype=tf.as_dtype(tf.float32).as_datatype_enum,
tensor_shape=tf.TensorShape([1, 8, 8, 3]).as_proto())
sig_input_tensor_signature = {'x': sig_input_tensor}
sig_output_tensor = meta_graph_pb2.TensorInfo(
dtype=tf.as_dtype(tf.float32).as_datatype_enum,
tensor_shape=tf.TensorShape([1, 8, 8, 3]).as_proto())
sig_output_tensor_signature = {'y': sig_output_tensor}
predict_signature_def = (
tf.compat.v1.saved_model.build_signature_def(
sig_input_tensor_signature, sig_output_tensor_signature,
tf.saved_model.PREDICT_METHOD_NAME))
serving_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
signature_def_map = {serving_key: predict_signature_def}
tflite_model = signature_def_utils.set_signature_defs(
tflite_model, signature_def_map)
saved_signature_def_map = signature_def_utils.get_signature_defs(
tflite_model)
signature_def = saved_signature_def_map.get(serving_key)
self.assertIsNotNone(signature_def, 'SignatureDef not found')
self.assertEqual(signature_def.SerializeToString(),
predict_signature_def.SerializeToString())
remove_tflite_model = (
signature_def_utils.clear_signature_defs(tflite_model))
signature_def_map = signature_def_utils.get_signature_defs(
remove_tflite_model)
self.assertIsNone(signature_def_map.get(serving_key),
'SignatureDef found, but should be missing')
if __name__ == '__main__':
tf.test.main()