Move metadata, codegen and java lib into the new repo.
PiperOrigin-RevId: 319147352 Change-Id: I79ab15ccebe9d50c62952c535746c6639883fc3a
parent 3be438aca2
commit e9695a20ee
tensorflow/lite/experimental/support/
  README.md
  codegen/
    BUILD
    README.md
    android_java_generator.cc
    android_java_generator.h
    code_generator.cc
    code_generator.h
    code_generator_test.cc
    metadata_helper.cc
    metadata_helper.h
    python/
    utils.cc
    utils.h
    utils_test.cc
  java/
    AndroidManifest.xml
    BUILD
    README.md
    src/java/org/tensorflow/lite/support/
  metadata/
    BUILD
    README.md
    build_defs.bzl
    cc/
    flatbuffers_lib/
    java/
    metadata.py
    metadata_parser.py.template
    metadata_parser_test.py
    metadata_schema.fbs
    metadata_test.py
    testdata/

tensorflow/lite/experimental/support/README.md (new file, +5 lines)
@@ -0,0 +1,5 @@
# TensorFlow Lite Support

The TensorFlow Lite Support project has been migrated to its own repo. Please
check out [TFLite Support](https://github.com/tensorflow/tflite-support) for the
latest updates.

tensorflow/lite/experimental/support/codegen/BUILD (deleted file)
@@ -1,87 +0,0 @@
# The tools for generating wrapper classes for a TFLite model with metadata.

package(
    default_visibility = [
        "//visibility:public",
    ],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "utils",
    srcs = [
        "utils.cc",
    ],
    hdrs = [
        "utils.h",
    ],
    deps = [
    ],
)

cc_library(
    name = "code_generator",
    srcs = [
        "code_generator.cc",
    ],
    hdrs = [
        "code_generator.h",
    ],
    deps = [
        ":utils",
        "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
    ],
)

cc_library(
    name = "metadata_helper",
    srcs = [
        "metadata_helper.cc",
    ],
    hdrs = [
        "metadata_helper.h",
    ],
    deps = [
        ":utils",
        "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
        "//tensorflow/lite/schema:schema_fbs",
    ],
)

cc_library(
    name = "android_java_generator",
    srcs = [
        "android_java_generator.cc",
    ],
    hdrs = [
        "android_java_generator.h",
    ],
    deps = [
        ":code_generator",
        ":metadata_helper",
        ":utils",
        "//tensorflow/core/platform:logging",
        "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
        "//tensorflow/lite/schema:schema_fbs",
    ],
)

cc_test(
    name = "code_generator_test",
    size = "small",
    srcs = ["code_generator_test.cc"],
    data = ["//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs"],
    deps = [
        ":code_generator",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_test(
    name = "utils_test",
    srcs = ["utils_test.cc"],
    deps = [
        ":utils",
        "@com_google_googletest//:gtest_main",
    ],
)

tensorflow/lite/experimental/support/codegen/README.md (deleted file)
@@ -1,13 +0,0 @@
# TensorFlow Lite Android Wrapper Code Generator

For TensorFlow Lite models enhanced with [metadata](https://www.tensorflow.org/lite/convert/metadata.md),
developers can use the TensorFlow Lite Android wrapper code generator to create
platform-specific wrapper code. The wrapper code removes the need to interact
directly with `ByteBuffer`. Instead, developers can interact with the TensorFlow
Lite model with typed objects such as `Bitmap` and `Rect`.

The usefulness of the code generator depends on the completeness of the
TensorFlow Lite model's metadata entry. Refer to the `<Codegen usage>` section
under relevant fields in
[metadata_schema.fbs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs)
to see how the codegen tool parses each field.
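
To make the payoff concrete, here is a minimal, hypothetical Java sketch of what a call site for a generated wrapper can look like. The wrapper class `MyModel`, its `newInstance`/`process`/`close` methods, and the nested `Outputs` type are illustrative assumptions only (the real names are derived from the model's metadata); `TensorImage` is a support-library type.

```java
// Hypothetical call site for a codegen-produced wrapper. MyModel, newInstance,
// process, Outputs, and close are assumed, illustrative names; the actual
// generated API depends on the model's metadata.
import android.content.Context;
import android.graphics.Bitmap;
import java.io.IOException;
import org.tensorflow.lite.support.image.TensorImage;

final class WrapperUsageSketch {
  private WrapperUsageSketch() {}

  static void classify(Context context, Bitmap bitmap) throws IOException {
    MyModel model = MyModel.newInstance(context);        // assumed generated factory
    TensorImage image = TensorImage.fromBitmap(bitmap);  // typed input, no ByteBuffer
    MyModel.Outputs outputs = model.process(image);      // assumed generated method
    // Typed accessors on `outputs` replace manual ByteBuffer parsing.
    model.close();
  }
}
```

The generator exists to produce exactly this shape of code, so the completeness of the metadata directly determines how much of the `ByteBuffer` handling it can hide.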

tensorflow/lite/experimental/support/codegen/android_java_generator.cc (deleted file): diff suppressed because it is too large.

tensorflow/lite/experimental/support/codegen/android_java_generator.h (deleted file)
@@ -1,116 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace support {
namespace codegen {

namespace details_android_java {

/// The intermediate data structure for generating code from TensorMetadata.
/// Should only be used as const reference when created.
struct TensorInfo {
  std::string name;
  std::string upper_camel_name;
  std::string content_type;
  std::string wrapper_type;
  std::string processor_type;
  bool is_input;
  /// Optional. Set to -1 if not applicable.
  int normalization_unit;
  /// Optional. Set to -1 if associated_axis_label is empty.
  int associated_axis_label_index;
  /// Optional. Set to -1 if associated_value_label is empty.
  int associated_value_label_index;
};

/// The intermediate data structure for generating code from ModelMetadata.
/// Should only be used as const reference when created.
struct ModelInfo {
  std::string package_name;
  std::string model_asset_path;
  std::string model_class_name;
  std::string model_versioned_name;
  std::vector<TensorInfo> inputs;
  std::vector<TensorInfo> outputs;
  // Extra helper fields. For models with inputs "a", "b" and outputs "x", "y":
  std::string input_type_param_list;
  // e.g. "TensorImage a, TensorBuffer b"
  std::string inputs_list;
  // e.g. "a, b"
  std::string postprocessor_type_param_list;
  // e.g. "ImageProcessor xPostprocessor, TensorProcessor yPostprocessor"
  std::string postprocessors_list;
  // e.g. "xPostprocessor, yPostprocessor"
};

}  // namespace details_android_java

constexpr char JAVA_EXT[] = ".java";

/// Generates Android supporting code and modules (in Java) based on TFLite
/// metadata.
class AndroidJavaGenerator : public CodeGenerator {
 public:
  /// Creates an AndroidJavaGenerator.
  /// Args:
  /// - module_root: The root of the destination Java module.
  explicit AndroidJavaGenerator(const std::string& module_root);

  /// Generates files. Returns the file paths and contents.
  /// Args:
  /// - model: The TFLite model with Metadata filled.
  /// - package_name: The name of the Java package which generated classes
  ///   belong to.
  /// - model_class_name: A readable name of the generated wrapper class, such
  ///   as "ImageClassifier", "MobileNetV2" or "MyModel".
  /// - model_asset_path: The relative path to the model file in the assets.
  // TODO(b/141225157): Automatically generate model_class_name.
  GenerationResult Generate(const Model* model, const std::string& package_name,
                            const std::string& model_class_name,
                            const std::string& model_asset_path);

  /// Generates files and returns the file paths and contents.
  /// It's mostly identical to the previous one, but the model here is
  /// provided as binary flatbuffer content without parsing.
  GenerationResult Generate(const char* model_storage,
                            const std::string& package_name,
                            const std::string& model_class_name,
                            const std::string& model_asset_path);

  std::string GetErrorMessage();

 private:
  const std::string module_root_;
  ErrorReporter err_;
};

}  // namespace codegen
}  // namespace support
}  // namespace tflite

#endif  // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_

tensorflow/lite/experimental/support/codegen/code_generator.cc (deleted file)
@@ -1,179 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/experimental/support/codegen/code_generator.h"

#include <cctype>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"

namespace tflite {
namespace support {
namespace codegen {

namespace {

void ResolveConflictedNamesByAddingIndex(std::vector<std::string>* names_ptr) {
  auto& names = *names_ptr;
  std::unordered_map<std::string, int> indexes;
  std::unordered_map<std::string, int> first_appearance;
  for (int i = 0; i < names.size(); i++) {
    if (indexes.find(names[i]) == indexes.end()) {
      indexes[names[i]] = 1;
      first_appearance[names[i]] = i;
    } else {
      indexes[names[i]] += 1;
      names[i].append(std::to_string(indexes[names[i]]));
    }
  }
  for (const auto& it : first_appearance) {
    const auto& name = it.first;
    const auto i = it.second;
    if (indexes[name] > 1) {
      names[i].append("1");
    }
  }
}

}  // namespace

CodeGenerator::CodeGenerator() {}

bool CodeGenerator::VerifyMetadata(const ModelMetadata* metadata,
                                   ErrorReporter* err) {
  if (metadata == nullptr) {
    err->Error("Loading nullptr is not allowed");
    return false;
  }
  if (metadata->subgraph_metadata()->size() != 1) {
    err->Error("Only exact 1 subgraph is supported");
    return false;
  }
  return true;
}

std::pair<std::vector<std::string>, std::vector<std::string>>
CodeGenerator::NameInputsAndOutputs(const TensorMetadataList* inputs,
                                    const TensorMetadataList* outputs) {
  std::vector<std::string> input_names;
  std::vector<std::string> output_names;
  if (inputs != nullptr) {
    input_names.reserve(inputs->size());
    for (const auto* tensor : *inputs) {
      input_names.push_back(NameTensor(*tensor, "input"));
    }
  }
  if (outputs != nullptr) {
    output_names.reserve(outputs->size());
    for (const auto* tensor : *outputs) {
      output_names.push_back(NameTensor(*tensor, "output"));
    }
  }
  // Solve conflict
  ResolveConflictedInputAndOutputNames(&input_names, &output_names);
  return std::make_pair(input_names, output_names);
}

std::string CodeGenerator::ConvertToValidName(const std::string& name) {
  // lowercase all
  std::string result = name;
  for (int i = 0; i < result.size(); i++) {
    result[i] = std::tolower(result[i]);
  }
  // replace all non-alpha or non-numeric with underscores, except underscore
  // itself
  for (int i = 0; i < result.size(); i++) {
    if (result[i] != '_' && !std::isalnum(result[i])) {
      result[i] = '_';
    }
  }
  // remove leading underscores
  int leading_underscores = 0;
  while (leading_underscores < result.size() &&
         result[leading_underscores] == '_') {
    leading_underscores++;
  }
  result.erase(0, leading_underscores);
  if (result.empty()) {
    return "";
  }
  // first char should be alpha
  if (std::isalpha(result[0])) {
    return result;
  }
  return "tensor_" + result;
}

std::string CodeGenerator::NameTensor(const TensorMetadata& tensor,
                                      const std::string& default_name) {
  if (tensor.name() != nullptr && tensor.name()->size() > 0) {
    // TODO(b/141225157) Validate tensor name. It should be in lower case.
    auto suggested_name = ConvertToValidName(tensor.name()->str());
    if (!suggested_name.empty()) {
      return suggested_name;
    }
  }
  auto* content = tensor.content();
  if (content == nullptr || content->content_properties() == nullptr) {
    return default_name;
  }
  switch (content->content_properties_type()) {
    case ContentProperties_ImageProperties:
      return "image";
    case ContentProperties_FeatureProperties:
      return "feature";
    default:
      return default_name;
  }
}

void CodeGenerator::ResolveConflictedInputAndOutputNames(
    std::vector<std::string>* inputs, std::vector<std::string>* outputs) {
  std::unordered_set<std::string> io_conflict;
  auto& input_names = *inputs;
  auto& output_names = *outputs;
  for (const auto& input : input_names) {
    if (io_conflict.find(input) != io_conflict.end()) {
      continue;
    }
    for (const auto& output : output_names) {
      if (input == output) {
        io_conflict.insert(input);
        break;
      }
    }
  }
  for (int i = 0; i < input_names.size(); i++) {
    if (io_conflict.find(input_names[i]) != io_conflict.end()) {
      input_names[i] = "input_" + input_names[i];
    }
  }
  for (int i = 0; i < output_names.size(); i++) {
    if (io_conflict.find(output_names[i]) != io_conflict.end()) {
      output_names[i] = "output_" + output_names[i];
    }
  }
  // 2. Second, add index if input[i] == input[j]
  ResolveConflictedNamesByAddingIndex(&input_names);
  ResolveConflictedNamesByAddingIndex(&output_names);
}

}  // namespace codegen
}  // namespace support
}  // namespace tflite

tensorflow/lite/experimental/support/codegen/code_generator.h (deleted file)
@@ -1,80 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_

#include <map>
#include <memory>
#include <sstream>
#include <string>

#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"

namespace tflite {
namespace support {
namespace codegen {

struct GenerationResult {
  struct File {
    std::string path;
    std::string content;
  };
  std::vector<File> files;
};

/// Defines language-independent codegen strategies, like class naming, etc.
/// Should not be used directly.
class CodeGenerator {
 public:
  CodeGenerator();

  using TensorMetadataList =
      typename flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>;

  virtual ~CodeGenerator() {}

  // Strategies.
  /// Names all the IO tensors. It's useful when they don't have names, or the
  /// names have conflicts. We have to name every tensor for code generation.
  // TODO(b/141225157): Add reserved keywords check.
  static std::pair<std::vector<std::string>, std::vector<std::string>>
  NameInputsAndOutputs(const TensorMetadataList* inputs,
                       const TensorMetadataList* outputs);

  /// Loads a metadata for code generation.
  /// Returns false if the metadata is not good for generation.
  static bool VerifyMetadata(const ModelMetadata* metadata, ErrorReporter* err);

 protected:
  /// Converts a name into a valid form. Rules:
  /// - lower all letters.
  /// - replace all non-alphabetic, non-numeric characters with underscores.
  /// - remove prefix underscores.
  /// - add prefix if the leading character is a number.
  /// Returns empty string if not possible.
  static std::string ConvertToValidName(const std::string& name);
  static std::string NameTensor(const TensorMetadata& tensor,
                                const std::string& default_name);
  static void ResolveConflictedInputAndOutputNames(
      std::vector<std::string>* input, std::vector<std::string>* output);
};

}  // namespace codegen
}  // namespace support
}  // namespace tflite

#endif  // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_

tensorflow/lite/experimental/support/codegen/code_generator_test.cc (deleted file)
@@ -1,126 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/experimental/support/codegen/code_generator.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>

namespace tflite {
namespace support {
namespace codegen {
namespace {

using ::testing::ElementsAreArray;

class CodeGeneratorTest : public ::testing::Test {
 public:
  class TestingCodeGenerator : public CodeGenerator {
   public:
    explicit TestingCodeGenerator() : CodeGenerator() {}

    // Make tested method public.
    static std::string ConvertToValidName(const std::string& name) {
      return CodeGenerator::ConvertToValidName(name);
    }
    static void ResolveConflictedInputAndOutputNames(
        std::vector<std::string>* input, std::vector<std::string>* output) {
      CodeGenerator::ResolveConflictedInputAndOutputNames(input, output);
    }
  };
};

TEST_F(CodeGeneratorTest, UpperCasesShouldLower) {
  EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("AlphaBetCOOL"),
              "alphabetcool");
}

TEST_F(CodeGeneratorTest, NonAlphaNumShouldReplace) {
  EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("A+=B C\t"), "a__b_c_");
}

TEST_F(CodeGeneratorTest, NoLeadingUnderscore) {
  EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("+KAI Z"), "kai_z");
}

TEST_F(CodeGeneratorTest, NoLeadingNumbers) {
  EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("3000 Cool Tensors"),
              "tensor_3000_cool_tensors");
}

TEST_F(CodeGeneratorTest, TestSimpleIONames) {
  std::vector<std::string> inputs = {"image"};
  std::vector<std::string> outputs = {"output"};
  TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
  EXPECT_THAT(inputs, ElementsAreArray({"image"}));
  EXPECT_THAT(outputs, ElementsAreArray({"output"}));
}

TEST_F(CodeGeneratorTest, TestIOConflict) {
  std::vector<std::string> inputs = {"image"};
  std::vector<std::string> outputs = {"image"};
  TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
  EXPECT_THAT(inputs, ElementsAreArray({"input_image"}));
  EXPECT_THAT(outputs, ElementsAreArray({"output_image"}));
}

TEST_F(CodeGeneratorTest, TestInternalConflict) {
  std::vector<std::string> inputs = {"image", "image"};
  std::vector<std::string> outputs = {"output"};
  TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
  EXPECT_THAT(inputs, ElementsAreArray({"image1", "image2"}));
  EXPECT_THAT(outputs, ElementsAreArray({"output"}));
}

TEST_F(CodeGeneratorTest, TestAllConflictNTo1) {
  std::vector<std::string> inputs = {"image", "image"};
  std::vector<std::string> outputs = {"image"};
  TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
  EXPECT_THAT(inputs, ElementsAreArray({"input_image1", "input_image2"}));
  EXPECT_THAT(outputs, ElementsAreArray({"output_image"}));
}

TEST_F(CodeGeneratorTest, TestAllConflict) {
  std::vector<std::string> inputs = {"image", "audio", "image", "audio",
                                     "audio"};
  std::vector<std::string> outputs = {"image", "image", "audio", "feature",
                                      "feature"};
  TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
  EXPECT_THAT(inputs,
              ElementsAreArray({"input_image1", "input_audio1", "input_image2",
                                "input_audio2", "input_audio3"}));
  EXPECT_THAT(outputs,
              ElementsAreArray({"output_image1", "output_image2",
                                "output_audio", "feature1", "feature2"}));
}

TEST_F(CodeGeneratorTest, TestAllConflictReversed) {
  std::vector<std::string> inputs = {"image", "image", "audio", "feature",
                                     "feature"};
  std::vector<std::string> outputs = {"image", "audio", "image", "audio",
                                      "audio"};
  TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
  EXPECT_THAT(inputs,
              ElementsAreArray({"input_image1", "input_image2", "input_audio",
                                "feature1", "feature2"}));
  EXPECT_THAT(outputs, ElementsAreArray({"output_image1", "output_audio1",
                                         "output_image2", "output_audio2",
                                         "output_audio3"}));
}

}  // namespace
}  // namespace codegen
}  // namespace support
}  // namespace tflite

tensorflow/lite/experimental/support/codegen/metadata_helper.cc (deleted file)
@@ -1,100 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/experimental/support/codegen/metadata_helper.h"

#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"

namespace tflite {
namespace support {
namespace codegen {

constexpr char BUFFER_KEY[] = "TFLITE_METADATA";
const ModelMetadata* GetMetadataFromModel(const Model* model) {
  if (model == nullptr || model->metadata() == nullptr) {
    return nullptr;
  }
  for (auto i = 0; i < model->metadata()->size(); i++) {
    const auto* name = model->metadata()->Get(i)->name();
    if (name != nullptr && name->str() == BUFFER_KEY) {
      const auto buffer_index = model->metadata()->Get(i)->buffer();
      if (model->buffers() == nullptr ||
          model->buffers()->size() <= buffer_index) {
        continue;
      }
      const auto* buffer_vec = model->buffers()->Get(buffer_index)->data();
      if (buffer_vec == nullptr || buffer_vec->data() == nullptr) {
        continue;
      }
      return GetModelMetadata(buffer_vec->data());
    }
  }
  return nullptr;
}

int FindAssociatedFile(const TensorMetadata* metadata,
                       const AssociatedFileType file_type,
                       const std::string& tensor_identifier,
                       ErrorReporter* err) {
  int result = -1;
  if (metadata->associated_files() == nullptr ||
      metadata->associated_files()->size() == 0) {
    return result;
  }
  for (int i = 0; i < metadata->associated_files()->size(); i++) {
    const auto* file_metadata = metadata->associated_files()->Get(i);
    if (file_metadata->type() == file_type) {
      if (result >= 0) {
        err->Warning(
            "Multiple associated file of type %d found on tensor %s. Only the "
            "first one will be used.",
            file_type, tensor_identifier.c_str());
        continue;
      }
      result = i;
    }
  }
  return result;
}

int FindNormalizationUnit(const TensorMetadata* metadata,
                          const std::string& tensor_identifier,
                          ErrorReporter* err) {
  int result = -1;
  if (metadata->process_units() == nullptr ||
      metadata->process_units()->size() == 0) {
    return result;
  }
  for (int i = 0; i < metadata->process_units()->size(); i++) {
    const auto* process_unit = metadata->process_units()->Get(i);
    if (process_unit->options_type() ==
        ProcessUnitOptions_NormalizationOptions) {
      if (result >= 0) {
        err->Warning(
            "Multiple normalization unit found in tensor %s. Only the first "
            "one will be effective.",
            tensor_identifier.c_str());
        continue;
      }
      result = i;
    }
  }
  return result;
}

}  // namespace codegen
}  // namespace support
}  // namespace tflite

tensorflow/lite/experimental/support/codegen/metadata_helper.h (deleted file)
@@ -1,51 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_

#include <string>

#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace support {
namespace codegen {

/// Parses a ModelMetadata out from a Model. The returned ModelMetadata's
/// lifetime is scoped by the model.
/// Returns nullptr if we cannot find any metadata.
const ModelMetadata* GetMetadataFromModel(const Model* model);

/// Finds an associated file of a certain type from a TensorMetadata. If
/// multiple files meet the criteria, only the first one is used. If no file
/// meets the criteria, -1 will be returned.
int FindAssociatedFile(const TensorMetadata* metadata,
                       const AssociatedFileType file_type,
                       const std::string& tensor_identifier,
                       ErrorReporter* err);

/// Finds the first normalization unit. If none, returns -1.
int FindNormalizationUnit(const TensorMetadata* metadata,
                          const std::string& tensor_identifier,
                          ErrorReporter* err);

}  // namespace codegen
}  // namespace support
}  // namespace tflite

#endif  // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_

tensorflow/lite/experimental/support/codegen/python/BUILD (deleted file)
@@ -1,38 +0,0 @@
load("//tensorflow:tensorflow.bzl", "pybind_extension")

package(
    default_visibility = [
        "//visibility:public",
    ],
    licenses = ["notice"],  # Apache 2.0
)

pybind_extension(
    name = "_pywrap_codegen",
    srcs = [
        "codegen_lib.cc",
    ],
    features = ["-use_header_modules"],
    module_name = "_pywrap_codegen",
    deps = [
        "//tensorflow/lite/experimental/support/codegen:android_java_generator",
        "//tensorflow/lite/experimental/support/codegen:code_generator",
        "//tensorflow/python:pybind11_lib",
        "//third_party/python_runtime:headers",
        "@pybind11",
    ],
)

py_binary(
    name = "codegen",
    srcs = [
        "codegen.py",
    ],
    python_version = "PY3",
    deps = [
        ":_pywrap_codegen",
        "@absl_py//absl:app",
        "@absl_py//absl/flags",
        "@absl_py//absl/logging",
    ],
)

tensorflow/lite/experimental/support/codegen/python/codegen.py (deleted file)
@@ -1,96 +0,0 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generates Android Java sources from a TFLite model with metadata."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil

from absl import app
from absl import flags
from absl import logging

from tensorflow.lite.experimental.support.codegen.python import _pywrap_codegen

FLAGS = flags.FLAGS

flags.DEFINE_string('model', None, 'Path to model (.tflite) flatbuffer file.')
flags.DEFINE_string('destination', None, 'Path of destination of generation.')
flags.DEFINE_string('package_name', 'org.tensorflow.lite.support',
                    'Name of generated java package to put the wrapper class.')
flags.DEFINE_string(
    'model_class_name', 'MyModel',
    'Name of generated wrapper class (should not contain package name).')
flags.DEFINE_string(
    'model_asset_path', '',
    '(Optional) Path to the model in generated assets/ dir. If not set, '
    'generator will use base name of input model.'
)


def get_model_buffer(path):
  if not os.path.isfile(path):
    logging.error('Cannot find model at path %s.', path)
  with open(path, 'rb') as f:
    buf = f.read()
  return buf


def prepare_directory_for_file(file_path):
  target_dir = os.path.dirname(file_path)
  if not os.path.exists(target_dir):
    os.makedirs(target_dir)
    return
  if not os.path.isdir(target_dir):
    logging.error('Cannot write to %s', target_dir)


def main(argv):
  if len(argv) > 1:
    logging.error('Non-flag arguments found: [%s]', ', '.join(argv[1:]))

  codegen = _pywrap_codegen.AndroidJavaGenerator(FLAGS.destination)
  model_buffer = get_model_buffer(FLAGS.model)
  model_asset_path = FLAGS.model_asset_path
  if not model_asset_path:
    model_asset_path = os.path.basename(FLAGS.model)
  result = codegen.generate(model_buffer, FLAGS.package_name,
                            FLAGS.model_class_name, model_asset_path)
  error_message = codegen.get_error_message().strip()
  if error_message:
    logging.error(error_message)
  if not result.files:
    logging.error('Generation failed!')
    return

  for each in result.files:
    prepare_directory_for_file(each.path)
    with open(each.path, 'w') as f:
      f.write(each.content)

  logging.info('Generation succeeded!')
  model_asset_path = os.path.join(FLAGS.destination, 'src/main/assets',
                                  model_asset_path)
  prepare_directory_for_file(model_asset_path)
  shutil.copy(FLAGS.model, model_asset_path)
  logging.info('Model copied into assets!')


if __name__ == '__main__':
  flags.mark_flag_as_required('model')
  flags.mark_flag_as_required('destination')
  app.run(main)

tensorflow/lite/experimental/support/codegen/python/codegen_lib.cc (deleted file)
@@ -1,49 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "pybind11/detail/common.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h"
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"

namespace tflite {
namespace support {
namespace codegen {

template <typename... Args>
using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;

PYBIND11_MODULE(_pywrap_codegen, m) {
  pybind11::class_<AndroidJavaGenerator>(m, "AndroidJavaGenerator")
      .def(pybind11::init<const std::string &>())
      .def("generate",
           overload_cast_<const char *, const std::string &,
                          const std::string &, const std::string &>()(
               &AndroidJavaGenerator::Generate))
      .def("get_error_message", &AndroidJavaGenerator::GetErrorMessage);
  pybind11::class_<GenerationResult>(m, "GenerationResult")
      .def(pybind11::init<>())
      .def_readwrite("files", &GenerationResult::files);
  pybind11::class_<GenerationResult::File>(m, "GenerationResultFile")
      .def(pybind11::init<>())
      .def_readwrite("path", &GenerationResult::File::path)
      .def_readwrite("content", &GenerationResult::File::content);
}

}  // namespace codegen
}  // namespace support
}  // namespace tflite

tensorflow/lite/experimental/support/codegen/utils.cc (deleted file)
@@ -1,194 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/codegen/utils.h"

#include <cstdarg>

namespace tflite {
namespace support {
namespace codegen {

int ErrorReporter::Warning(const char* format, ...) {
  va_list args;
  va_start(args, format);
  return Report("[WARN] ", format, args);
}

int ErrorReporter::Error(const char* format, ...) {
  va_list args;
  va_start(args, format);
  return Report("[ERROR] ", format, args);
}

int ErrorReporter::Report(const char* prefix, const char* format,
                          va_list args) {
  char buf[1024];
  int formatted = vsnprintf(buf, sizeof(buf), format, args);
  buffer_ << prefix << buf << std::endl;
  return formatted;
}

std::string ErrorReporter::GetMessage() {
  std::string value = buffer_.str();
  buffer_.str("");
  return value;
}

CodeWriter::CodeWriter(ErrorReporter* err) : indent_(0), err_(err) {}

void CodeWriter::SetTokenValue(const std::string& token,
                               const std::string& value) {
  value_map_[token] = value;
}

const std::string CodeWriter::GetTokenValue(const std::string& token) const {
  auto iter = value_map_.find(token);
  if (iter == value_map_.end()) {
    // Typically only code generators call this function (or `Append`). It is
    // their duty to make sure the token is valid, and requesting an invalid
    // token implies a flaw in the code generation logic.
    err_->Error("Internal: Cannot find value with token '%s'", token.c_str());
    return "";
  }
  return iter->second;
}

void CodeWriter::SetIndentString(const std::string& indent_str) {
  indent_str_ = indent_str;
}

void CodeWriter::Indent() { indent_++; }

void CodeWriter::Outdent() { indent_--; }

std::string CodeWriter::GenerateIndent() const {
  std::string res;
  res.reserve(indent_str_.size() * indent_);
  for (int i = 0; i < indent_; i++) {
    res.append(indent_str_);
  }
  return res;
}

void CodeWriter::Append(const std::string& text) { AppendInternal(text, true); }

void CodeWriter::AppendNoNewLine(const std::string& text) {
  AppendInternal(text, false);
}

void CodeWriter::AppendInternal(const std::string& text, bool newline) {
  // Prefix indent
  if ((buffer_.empty()             // nothing in the buffer
       || buffer_.back() == '\n')  // is on new line
      // is writing on current line
      && (!text.empty() && text[0] != '\n' && text[0] != '\r')) {
    buffer_.append(GenerateIndent());
  }
  // State machine variables
  bool in_token = false;
  int i = 0;
  // Rough memory reserve
  buffer_.reserve(buffer_.size() + text.size());
  std::string token_buffer;
  // A simple LL1 analysis
  while (i < text.size()) {
    char cur = text[i];
    char cur_next = i == text.size() - 1 ? '\0' : text[i + 1];  // Set guardian
    if (!in_token) {
      if (cur == '{' && cur_next == '{') {  // Enter token
        in_token = true;
        i += 2;
      } else if (cur == '\n') {  // We need to apply global indent here
        buffer_.push_back(cur);
        if (cur_next != '\0' && cur_next != '\n' && cur_next != '\r') {
          buffer_.append(GenerateIndent());
        }
        i += 1;
      } else {
        buffer_.push_back(cur);
        i += 1;
      }
    } else {
      if (cur == '}' && cur_next == '}') {  // Close token
        in_token = false;
        const auto value = GetTokenValue(token_buffer);
        buffer_.append(value);
        token_buffer.clear();
        i += 2;
      } else {
        token_buffer.push_back(cur);
        i += 1;
      }
    }
  }
  if (!token_buffer.empty()) {
    // Typically only code generators call this function. It is their duty to
    // make sure the code (or template) has valid syntax, and an unclosed
    // "{{...}}" implies a severe error in the template.
    err_->Error("Internal: Invalid template: {{token}} is not closed.");
  }
  if (newline) {
    buffer_.push_back('\n');
  }
}

void CodeWriter::NewLine() { Append(""); }

void CodeWriter::Backspace(int n) {
  buffer_.resize(buffer_.size() > n ? buffer_.size() - n : 0);
}

std::string CodeWriter::ToString() const { return buffer_; }

bool CodeWriter::IsStreamEmpty() const { return buffer_.empty(); }

void CodeWriter::Clear() {
  buffer_.clear();
  value_map_.clear();
  indent_ = 0;
}

std::string SnakeCaseToCamelCase(const std::string& s) {
  std::string t;
  t.reserve(s.length());
  size_t i = 0;
  // Note: Use simple string += for simplicity.
  bool cap = false;
  while (i < s.size()) {
    const char c = s[i++];
    if (c == '_') {
      cap = true;
    } else if (cap) {
      t += toupper(c);
      cap = false;
    } else {
      t += c;
    }
  }
  return t;
}

std::string JoinPath(const std::string& a, const std::string& b) {
  if (a.empty()) return b;
  std::string a_fixed = a;
  if (!a_fixed.empty() && a_fixed.back() == '/') a_fixed.pop_back();
  std::string b_fixed = b;
  if (!b_fixed.empty() && b_fixed.front() == '/') b_fixed.erase(0, 1);
  return a_fixed + "/" + b_fixed;
}

}  // namespace codegen
}  // namespace support
}  // namespace tflite

tensorflow/lite/experimental/support/codegen/utils.h (deleted file)
@@ -1,127 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_

#include <cstdarg>
#include <map>
#include <sstream>
#include <string>

namespace tflite {
namespace support {
namespace codegen {

/// Collects runtime error logs which could be shown later.
// TODO(b/150538286): Consider a better mechanism to simplify callsite code.
class ErrorReporter {
 public:
  int Warning(const char* format, ...);
  int Error(const char* format, ...);
  std::string GetMessage();

 private:
  int Report(const char* prefix, const char* format, va_list args);
  std::stringstream buffer_;
};

/// Implements basic code generating with text templates.
///
/// It could accept code templates and concatenate them into complete codes. A
/// template could contain named values.
///
/// Example code:
///   CodeWriter code;
///   code.SetValue("NAME", "Foo");
///   code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }");
///   code.SetValue("NAME", "Bar");
///   code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }");
///
/// Output:
///   void Foo() { printf("%s", "Foo"); }
///   void Bar() { printf("%s", "Bar"); }
class CodeWriter {
 public:
  explicit CodeWriter(ErrorReporter* err);
  /// Sets value to a token. When generating code with a template, a string in
  /// a pair of {{ and }} will be regarded as a token and replaced with the
  /// corresponding value in code generation.
  /// It rewrites if the token already has a value.
  void SetTokenValue(const std::string& token, const std::string& value);

  /// Gets the current value set on the given token.
  const std::string GetTokenValue(const std::string& token) const;

  /// Sets the unit indent string. For example, in Java it should be " ".
  void SetIndentString(const std::string& indent);

  /// Increases the indent by a unit (the string set in SetIndentString).
  void Indent();

  /// Decreases the indent by a unit (the string set in SetIndentString).
  void Outdent();

  /// Generates the indentation string.
  std::string GenerateIndent() const;

  /// Appends a piece of template code to the stream. Every named value will
  /// be replaced with the real value. A new line will always be appended at
  /// the end.
  void Append(const std::string& text);

  /// Appends a piece of template code to the stream. Same as `Append`, but a
  /// new line will not be appended at the end.
  void AppendNoNewLine(const std::string& text);

  /// Appends a new line to the stream.
  void NewLine();

  /// Deletes the last N characters in the stream. If the stream has fewer
  /// than N characters, deletes all.
  void Backspace(int n);

  std::string ToString() const;

  /// Checks if the internal string stream is empty. Note: This method has
  /// overhead.
  bool IsStreamEmpty() const;

  /// Clears the internal string stream and the value map.
  void Clear();

 private:
  void AppendInternal(const std::string& text, bool newline);

  std::string indent_str_;
  int indent_;

  std::map<std::string, std::string> value_map_;
  std::string buffer_;

  ErrorReporter* err_;
};

/// Converts foo_bar_name to fooBarName. It is the caller's duty to make sure
/// the given string "s" is already in snake case; otherwise unexpected
/// behavior may occur.
std::string SnakeCaseToCamelCase(const std::string& s);

/// Joins 2 parts of a file path into one, connected by the Unix path
/// separator '/'. It is the caller's duty to ensure the two parts are valid.
std::string JoinPath(const std::string& a, const std::string& b);

}  // namespace codegen
}  // namespace support
}  // namespace tflite

#endif  // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_

tensorflow/lite/experimental/support/codegen/utils_test.cc (deleted file)
@@ -1,97 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/experimental/support/codegen/utils.h"

#include <gtest/gtest.h>

namespace tflite {
namespace support {
namespace codegen {
namespace {

TEST(ErrorReporterTest, TestReportError) {
  ErrorReporter err;
  err.Error("some text");
  EXPECT_EQ(err.GetMessage(), "[ERROR] some text\n");
  EXPECT_EQ(err.GetMessage(), "");
}

TEST(CodeGeneratorTest, TestExample) {
  ErrorReporter err;
  CodeWriter writer(&err);
  writer.SetTokenValue("NAME", "Foo");
  const std::string text = R"(void {{NAME}}() { printf("%s", "{{NAME}}"); })";
  writer.Append(text);
  writer.SetTokenValue("NAME", "Bar");
  writer.Append(text);
  EXPECT_EQ(
      "void Foo() { printf(\"%s\", \"Foo\"); }\n"
      "void Bar() { printf(\"%s\", \"Bar\"); }\n",
      writer.ToString());
}

TEST(CodeGeneratorTest, TestInexistentToken) {
  ErrorReporter err;
  CodeWriter writer(&err);
  writer.SetTokenValue("NAME", "Foo");
  const std::string text = R"(void {{name}}() {})";
  writer.Append(text);
  EXPECT_EQ(err.GetMessage(),
            "[ERROR] Internal: Cannot find value with token 'name'\n");
}

TEST(CodeGeneratorTest, TestUnclosedToken) {
  ErrorReporter err;
  CodeWriter writer(&err);
  writer.SetTokenValue("NAME", "Foo");
  const std::string text = R"(void {{NAME}() {})";
  writer.Append(text);
  EXPECT_EQ(err.GetMessage(),
            "[ERROR] Internal: Invalid template: {{token}} is not closed.\n");
}

TEST(CodeGeneratorTest, TestIndentControl) {
  ErrorReporter err;
  CodeWriter writer(&err);
  writer.SetIndentString(" ");
  writer.Indent();
  writer.AppendNoNewLine("abcde");  // Will indent
  EXPECT_EQ(" abcde", writer.ToString());
  writer.Clear();
  writer.Indent();
  writer.AppendNoNewLine("abc\n\nde");
  // The blank line will not indent
  EXPECT_EQ(" abc\n\n de", writer.ToString());
  writer.Clear();
  writer.Indent();
  writer.Append("abc");
  writer.Outdent();
  writer.AppendNoNewLine("def");
  EXPECT_EQ(" abc\ndef", writer.ToString());
}

TEST(CaseConversionTest, TestSnakeToCamel) {
  EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel"));
  EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel_"));
  EXPECT_EQ("ImACamel", SnakeCaseToCamelCase("_im_a_camel"));
  EXPECT_EQ("", SnakeCaseToCamelCase("_"));
  EXPECT_EQ("camel", SnakeCaseToCamelCase("camel"));
}

}  // namespace
}  // namespace codegen
}  // namespace support
}  // namespace tflite

tensorflow/lite/experimental/support/java/AndroidManifest.xml (deleted file)
@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="org.tensorflow.lite.support">
  <uses-sdk android:minSdkVersion="19" />
</manifest>

tensorflow/lite/experimental/support/java/BUILD (deleted file)
@@ -1,66 +0,0 @@
# Description:
#   TensorFlow Lite Support API in Java.

load("@build_bazel_rules_android//android:rules.bzl", "android_library")
load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],  # Apache 2.0
)

# TODO(b/156482505): The NOGPU target is a temporary target. Internally, people
# may already depend on "tensorflow-lite-support" so we shouldn't remove GPU
# from its dependency. We will have CLs to help users migrate. After migration
# is done, the "NOGPU" target will be removed.
android_library(
    name = "tensorflow-lite-support-nogpu",
    srcs = glob(["src/java/org/tensorflow/lite/support/**/*.java"]),
    javacopts = JAVACOPTS,
    manifest = "AndroidManifest.xml",
    deps = [
        "//tensorflow/lite/java:tensorflowlite",
        "@org_checkerframework_qual",
    ],
)

# TODO(138904786): Split Java part and Android part to make the support library usable by pure Java.
# For new users: Please use "tensorflow-lite-support-nogpu" if possible, and
# additionally depends on "tensorflowlite_gpu" if needed.
android_library(
    name = "tensorflow-lite-support",
    srcs = glob(["src/java/org/tensorflow/lite/support/**/*.java"]),
    javacopts = JAVACOPTS,
    manifest = "AndroidManifest.xml",
    deps = [
        "//tensorflow/lite/java:tensorflowlite",
        "//tensorflow/lite/java:tensorflowlite_gpu",  # unuseddeps: keep
        "@org_checkerframework_qual",
    ],
)

# This alias matches the style of lite/java naming for android_library targets. We keep the
# `tensorflow-lite-support` variant to match the associated .aar library name output style.
alias(
    name = "tensorflowlite_support",
    actual = ":tensorflow-lite-support",
)

java_library(
    name = "tensorflow-lite-support-precondition",
    srcs = ["src/java/org/tensorflow/lite/support/common/SupportPreconditions.java"],
    javacopts = JAVACOPTS,
    deps = [
        "@org_checkerframework_qual",
    ],
)

android_library(
    name = "tensorflow-lite-support-precondition-lib-android",
    srcs = ["src/java/org/tensorflow/lite/support/common/SupportPreconditions.java"],
    javacopts = JAVACOPTS,
    manifest = "AndroidManifest.xml",
    deps = [
        "@org_checkerframework_qual",
    ],
)

tensorflow/lite/experimental/support/java/README.md (deleted file)
@@ -1,17 +0,0 @@
# TensorFlow Lite Android Support Library

Mobile application developers typically interact with typed objects such as
bitmaps or primitives such as integers. However, the TensorFlow Lite Interpreter
that runs the on-device machine learning model uses tensors in the form of
ByteBuffer, which can be difficult to debug and manipulate. The TensorFlow Lite
Android Support Library is designed to help process the input and output of
TensorFlow Lite models, and make the TensorFlow Lite interpreter easier to use.

We welcome feedback from the community as we develop this support library,
especially around:

*   Use cases we should support, including data types and operations
*   Ease of use: do the APIs make sense to the community?

See the [documentation](https://www.tensorflow.org/lite/guide/lite_support) for
instructions and examples.
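
The `FileUtil` class removed below is representative of the helpers this library provides. As a minimal usage sketch (the asset name `labels.txt` and the surrounding class are assumed, illustrative examples), loading a label file from the app's assets reduces to one call:

```java
// Minimal sketch of using the support library's FileUtil helper shown below.
// The asset file name "labels.txt" is an assumed example.
import android.content.Context;
import java.io.IOException;
import java.util.List;
import org.tensorflow.lite.support.common.FileUtil;

final class LabelLoadingSketch {
  private LabelLoadingSketch() {}

  static List<String> readLabels(Context context) throws IOException {
    // One label per line, read from assets with the default charset.
    return FileUtil.loadLabels(context, "labels.txt");
  }
}
```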
@ -1,184 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.common;
|
||||
|
||||
import android.content.Context;
|
||||
import android.content.res.AssetFileDescriptor;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.InputStreamReader;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.MappedByteBuffer;
|
||||
import java.nio.channels.FileChannel;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.checkerframework.checker.nullness.qual.NonNull;
|
||||
|
||||
/** File I/O utilities. */
|
||||
public class FileUtil {
|
||||
private FileUtil() {}
|
||||
|
||||
/**
|
||||
* Loads labels from the label file into a list of strings.
|
||||
*
|
||||
* <p>A legal label file is a plain text file whose contents are split into lines, with each line
* being an individual value. The file should be stored in the assets of the context.
|
||||
*
|
||||
* @param context The context that holds the assets.
* @param filePath The path of the label file, relative to the assets directory.
* @return a list of labels.
* @throws IOException if an error occurs when opening or reading the file.
|
||||
*/
|
||||
@NonNull
|
||||
public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath)
|
||||
throws IOException {
|
||||
return loadLabels(context, filePath, Charset.defaultCharset());
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads labels from the label file into a list of strings.
|
||||
*
|
||||
* <p>A legal label file is a plain text file whose contents are split into lines, with each line
* being an individual value. Empty lines will be ignored. The file should be stored in the
* assets of the context.
*
* @param context The context that holds the assets.
* @param filePath The path of the label file, relative to the assets directory.
* @param cs {@code Charset} to use when decoding content of the label file.
* @return a list of labels.
* @throws IOException if an error occurs when opening or reading the file.
|
||||
*/
|
||||
@NonNull
|
||||
public static List<String> loadLabels(
|
||||
@NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
|
||||
SupportPreconditions.checkNotNull(context, "Context cannot be null.");
|
||||
SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
|
||||
try (InputStream inputStream = context.getAssets().open(filePath)) {
|
||||
return loadLabels(inputStream, cs);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads labels from an input stream of an opened label file. See details for label files in
|
||||
* {@link FileUtil#loadLabels(Context, String)}.
|
||||
*
|
||||
* @param inputStream the input stream of an opened label file.
|
||||
* @return a list of labels.
|
||||
* @throws IOException if an error occurs when opening or reading the file.
|
||||
*/
|
||||
@NonNull
|
||||
public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException {
|
||||
return loadLabels(inputStream, Charset.defaultCharset());
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads labels from an input stream of an opened label file. See details for label files in
|
||||
* {@link FileUtil#loadLabels(Context, String)}.
|
||||
*
|
||||
* @param inputStream the input stream of an opened label file.
|
||||
* @param cs {@code Charset} to use when decoding content of label file.
|
||||
* @return a list of labels.
|
||||
* @throws IOException if an error occurs when opening or reading the file.
|
||||
*/
|
||||
@NonNull
|
||||
public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs)
|
||||
throws IOException {
|
||||
List<String> labels = new ArrayList<>();
|
||||
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs))) {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (line.trim().length() > 0) {
|
||||
labels.add(line);
|
||||
}
|
||||
}
|
||||
return labels;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a vocabulary file (a single-column text file) into a list of strings.
|
||||
*
|
||||
* <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
* with each line being an individual value. The file should be stored in the assets of the
* context.
|
||||
*
|
||||
* @param context The context that holds the assets.
* @param filePath The path of the vocabulary file, relative to the assets directory.
* @param cs {@code Charset} to use when decoding content of the vocabulary file.
* @return a list of vocabulary words.
* @throws IOException if an error occurs when opening or reading the file.
|
||||
*/
|
||||
@NonNull
|
||||
public static List<String> loadSingleColumnTextFile(
|
||||
@NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
|
||||
return loadLabels(context, filePath, cs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column
* text file). See details for vocabulary files in {@link
* FileUtil#loadSingleColumnTextFile(Context, String, Charset)}.
|
||||
*
|
||||
* @param inputStream the input stream of an opened vocabulary file.
* @param cs {@code Charset} to use when decoding content of the vocabulary file.
* @return a list of vocabulary words.
* @throws IOException if an error occurs when opening or reading the file.
|
||||
*/
|
||||
@NonNull
|
||||
public static List<String> loadSingleColumnTextFile(@NonNull InputStream inputStream, Charset cs)
|
||||
throws IOException {
|
||||
return loadLabels(inputStream, cs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a file from the asset folder through memory mapping.
|
||||
*
|
||||
* @param context Application context to access assets.
|
||||
* @param filePath Asset path of the file.
|
||||
* @return the loaded memory mapped file.
|
||||
* @throws IOException if an I/O error occurs when loading the tflite model.
|
||||
*/
|
||||
@NonNull
|
||||
public static MappedByteBuffer loadMappedFile(@NonNull Context context, @NonNull String filePath)
|
||||
throws IOException {
|
||||
SupportPreconditions.checkNotNull(context, "Context should not be null.");
|
||||
SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
|
||||
try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
|
||||
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
|
||||
FileChannel fileChannel = inputStream.getChannel();
|
||||
long startOffset = fileDescriptor.getStartOffset();
|
||||
long declaredLength = fileDescriptor.getDeclaredLength();
|
||||
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a binary file from the asset folder.
|
||||
*
|
||||
* @param context Application context to access assets.
|
||||
* @param filePath Asset path of the file.
|
||||
* @return the byte array for the binary file.
|
||||
* @throws IOException if an I/O error occurs when loading file.
|
||||
*/
|
||||
@NonNull
|
||||
public static byte[] loadByteFromFile(@NonNull Context context, @NonNull String filePath)
|
||||
throws IOException {
|
||||
ByteBuffer buffer = loadMappedFile(context, filePath);
|
||||
byte[] byteArray = new byte[buffer.remaining()];
|
||||
buffer.get(byteArray);
|
||||
return byteArray;
|
||||
}
|
||||
}
|
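For illustration, a minimal sketch of how `FileUtil` might be used from Android code; the asset names `labels.txt` and `model.tflite` are hypothetical placeholders:

```java
import android.content.Context;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.util.List;
import org.tensorflow.lite.support.common.FileUtil;

final class AssetLoadingExample {
  /** Loads a (hypothetical) label file and model bundled in the app's assets. */
  static void loadAssets(Context context) throws IOException {
    List<String> labels = FileUtil.loadLabels(context, "labels.txt");          // one label per line
    MappedByteBuffer model = FileUtil.loadMappedFile(context, "model.tflite"); // memory-mapped model
    System.out.println("Loaded " + labels.size() + " labels and " + model.capacity() + " model bytes");
  }

  private AssetLoadingExample() {}
}
```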
@ -1,31 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.common;
|
||||
|
||||
/**
|
||||
* The common interface for classes that carry an "apply" method, which converts a T object to another one.
|
||||
* @param <T> The class which Operator handles.
|
||||
*/
|
||||
public interface Operator<T> {
|
||||
|
||||
/**
|
||||
* Applies an operation on a T object, returning a T object.
|
||||
*
|
||||
* <p>Note: The returned object may be the same instance as the given input, and the given input
* may be modified in place.
|
||||
*/
|
||||
T apply(T x);
|
||||
}
|
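Because `Operator<T>` declares a single abstract method, it can be implemented with a lambda; a tiny hypothetical example:

```java
import org.tensorflow.lite.support.common.Operator;

final class OperatorExample {
  public static void main(String[] args) {
    // Operator<T> has one abstract method, so a lambda is a valid implementation.
    Operator<String> trim = s -> s.trim();
    System.out.println("[" + trim.apply("  hello  ") + "]"); // prints "[hello]"
  }
}
```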
@ -1,23 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.common;
|
||||
|
||||
/**
|
||||
* Processes a T object with the prepared {@link Operator<T>} chain.
|
||||
*/
|
||||
public interface Processor<T> {
|
||||
T process(T input);
|
||||
}
|
@ -1,82 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.common;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.checkerframework.checker.nullness.qual.NonNull;
|
||||
|
||||
/**
|
||||
* A processor base class that chains a series of {@link Operator<T>} and executes them.
*
* <p>Typically, users should use one of its subclasses, e.g. {@link
* org.tensorflow.lite.support.image.ImageProcessor}, rather than using this class directly.
|
||||
*
|
||||
* @param <T> The type that the Operator is handling.
|
||||
*/
|
||||
public class SequentialProcessor<T> implements Processor<T> {
|
||||
|
||||
/** List of operators added to this {@link SequentialProcessor}. */
|
||||
protected final List<Operator<T>> operatorList;
|
||||
/**
|
||||
* The {@link Map} between the operator name and the corresponding op indexes in {@code
|
||||
* operatorList}. An operator may be added multiple times into this {@link SequentialProcessor}.
|
||||
*/
|
||||
protected final Map<String, List<Integer>> operatorIndex;
|
||||
|
||||
protected SequentialProcessor(Builder<T> builder) {
|
||||
operatorList = builder.operatorList;
|
||||
operatorIndex = Collections.unmodifiableMap(builder.operatorIndex);
|
||||
}
|
||||
|
||||
@Override
|
||||
public T process(T x) {
|
||||
for (Operator<T> op : operatorList) {
|
||||
x = op.apply(x);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
/** The inner builder class to build a Sequential Processor. */
|
||||
protected static class Builder<T> {
|
||||
|
||||
private final List<Operator<T>> operatorList;
|
||||
private final Map<String, List<Integer>> operatorIndex;
|
||||
|
||||
protected Builder() {
|
||||
operatorList = new ArrayList<>();
|
||||
operatorIndex = new HashMap<>();
|
||||
}
|
||||
|
||||
public Builder<T> add(@NonNull Operator<T> op) {
|
||||
SupportPreconditions.checkNotNull(op, "Adding null Op is illegal.");
|
||||
operatorList.add(op);
|
||||
String operatorName = op.getClass().getName();
|
||||
if (!operatorIndex.containsKey(operatorName)) {
|
||||
operatorIndex.put(operatorName, new ArrayList<Integer>());
|
||||
}
|
||||
operatorIndex.get(operatorName).add(operatorList.size() - 1);
|
||||
return this;
|
||||
}
|
||||
|
||||
public SequentialProcessor<T> build() {
|
||||
return new SequentialProcessor<T>(this);
|
||||
}
|
||||
}
|
||||
}
|
@ -1,184 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.common;
|
||||
|
||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||
|
||||
/** Static error checking util methods. */
|
||||
public final class SupportPreconditions {
|
||||
/**
|
||||
* Ensures that an object reference passed as a parameter to the calling method is not null.
|
||||
*
|
||||
* @param reference an object reference
|
||||
* @return the non-null reference that was validated
|
||||
* @throws NullPointerException if {@code reference} is null
|
||||
*/
|
||||
public static <T extends Object> T checkNotNull(T reference) {
|
||||
if (reference == null) {
|
||||
throw new NullPointerException("The object reference is null.");
|
||||
}
|
||||
return reference;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that an object reference passed as a parameter to the calling method is not null.
|
||||
*
|
||||
* @param reference an object reference
|
||||
* @param errorMessage the exception message to use if the check fails; will be converted to a
|
||||
* string using {@link String#valueOf(Object)}
|
||||
* @return the non-null reference that was validated
|
||||
* @throws NullPointerException if {@code reference} is null
|
||||
*/
|
||||
public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
|
||||
if (reference == null) {
|
||||
throw new NullPointerException(String.valueOf(errorMessage));
|
||||
}
|
||||
return reference;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that the given String is not empty and not null.
|
||||
*
|
||||
* @param string the String to test
|
||||
* @return the non-null non-empty String that was validated
|
||||
* @throws IllegalArgumentException if {@code string} is null or empty
|
||||
*/
|
||||
public static String checkNotEmpty(String string) {
|
||||
if (string == null || string.length() == 0) {
|
||||
throw new IllegalArgumentException("Given String is empty or null.");
|
||||
}
|
||||
return string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that the given String is not empty and not null.
|
||||
*
|
||||
* @param string the String to test
|
||||
* @param errorMessage the exception message to use if the check fails; will be converted to a
|
||||
* string using {@link String#valueOf(Object)}
|
||||
* @return the non-null non-empty String that was validated
|
||||
* @throws IllegalArgumentException if {@code string} is null or empty
|
||||
*/
|
||||
public static String checkNotEmpty(String string, Object errorMessage) {
|
||||
if (string == null || string.length() == 0) {
|
||||
throw new IllegalArgumentException(String.valueOf(errorMessage));
|
||||
}
|
||||
return string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures the truth of an expression involving one or more parameters to the calling method.
|
||||
*
|
||||
* @param expression a boolean expression.
|
||||
* @throws IllegalArgumentException if {@code expression} is false.
|
||||
*/
|
||||
public static void checkArgument(boolean expression) {
|
||||
if (!expression) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures the truth of an expression involving one or more parameters to the calling method.
|
||||
*
|
||||
* @param expression a boolean expression.
|
||||
* @param errorMessage the exception message to use if the check fails; will be converted to a
|
||||
* string using {@link String#valueOf(Object)}.
|
||||
* @throws IllegalArgumentException if {@code expression} is false.
|
||||
*/
|
||||
public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
|
||||
if (!expression) {
|
||||
throw new IllegalArgumentException(String.valueOf(errorMessage));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
|
||||
* {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
|
||||
*
|
||||
* @param index a user-supplied index identifying an element of an array, list or string
|
||||
* @param size the size of that array, list or string
|
||||
* @return the value of {@code index}
|
||||
* @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
|
||||
* @throws IllegalArgumentException if {@code size} is negative
|
||||
*/
|
||||
public static int checkElementIndex(int index, int size) {
|
||||
return checkElementIndex(index, size, "index");
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
|
||||
* {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
|
||||
*
|
||||
* @param index a user-supplied index identifying an element of an array, list or string
|
||||
* @param size the size of that array, list or string
|
||||
* @param desc the text to use to describe this index in an error message
|
||||
* @return the value of {@code index}
|
||||
* @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
|
||||
* @throws IllegalArgumentException if {@code size} is negative
|
||||
*/
|
||||
public static int checkElementIndex(int index, int size, @Nullable String desc) {
|
||||
// Carefully optimized for execution by hotspot (explanatory comment above)
|
||||
if (index < 0 || index >= size) {
|
||||
throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures the truth of an expression involving the state of the calling instance, but not
|
||||
* involving any parameters to the calling method.
|
||||
*
|
||||
* @param expression a boolean expression
|
||||
* @throws IllegalStateException if {@code expression} is false
|
||||
* @see Verify#verify Verify.verify()
|
||||
*/
|
||||
public static void checkState(boolean expression) {
|
||||
if (!expression) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures the truth of an expression involving the state of the calling instance, but not
|
||||
* involving any parameters to the calling method.
|
||||
*
|
||||
* @param expression a boolean expression
|
||||
* @param errorMessage the exception message to use if the check fails; will be converted to a
|
||||
* string using {@link String#valueOf(Object)}
|
||||
* @throws IllegalStateException if {@code expression} is false
|
||||
* @see Verify#verify Verify.verify()
|
||||
*/
|
||||
public static void checkState(boolean expression, @Nullable Object errorMessage) {
|
||||
if (!expression) {
|
||||
throw new IllegalStateException(String.valueOf(errorMessage));
|
||||
}
|
||||
}
|
||||
|
||||
private static String badElementIndex(int index, int size, @Nullable String desc) {
|
||||
if (index < 0) {
|
||||
return String.format("%s (%s) must not be negative", desc, index);
|
||||
} else if (size < 0) {
|
||||
throw new IllegalArgumentException("negative size: " + size);
|
||||
} else { // index >= size
|
||||
return String.format("%s (%s) must be less than size (%s)", desc, index, size);
|
||||
}
|
||||
}
|
||||
|
||||
private SupportPreconditions() {
|
||||
throw new AssertionError("SupportPreconditions is Uninstantiable.");
|
||||
}
|
||||
}
|
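A small usage sketch of these checks; the method and values are made up for illustration:

```java
import org.tensorflow.lite.support.common.SupportPreconditions;

final class PreconditionsExample {
  /** Validates arguments before using them; the numbers here are arbitrary. */
  static int pickLabel(float[] scores, int index) {
    SupportPreconditions.checkNotNull(scores, "Scores cannot be null.");
    SupportPreconditions.checkArgument(scores.length > 0, "Scores cannot be empty.");
    return SupportPreconditions.checkElementIndex(index, scores.length, "label index");
  }

  public static void main(String[] args) {
    System.out.println(pickLabel(new float[] {0.1f, 0.7f, 0.2f}, 1)); // prints 1
  }
}
```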
@ -1,27 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.common;
|
||||
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
|
||||
|
||||
/**
|
||||
* Applies some operation on TensorBuffers.
|
||||
*/
|
||||
public interface TensorOperator extends Operator<TensorBuffer> {
|
||||
/** @see Operator#apply(Object) . */
|
||||
@Override
|
||||
TensorBuffer apply(TensorBuffer input);
|
||||
}
|
@ -1,68 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.common;
|
||||
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
|
||||
|
||||
/**
|
||||
* TensorProcessor is a helper class for preprocessing and postprocessing tensors. It can
* transform a {@link TensorBuffer} to another by executing a chain of {@link TensorOperator}s.
|
||||
*
|
||||
* <p>Example Usage:
|
||||
*
|
||||
* <pre>
|
||||
* TensorProcessor processor = new TensorProcessor.Builder().add(new NormalizeOp(1, 2)).build();
|
||||
* TensorBuffer anotherTensorBuffer = processor.process(tensorBuffer);
|
||||
* </pre>
|
||||
*
|
||||
* @see TensorProcessor.Builder to build a {@link TensorProcessor} instance.
|
||||
* @see TensorProcessor#process(TensorBuffer) to apply the processor on a {@link TensorBuffer}.
|
||||
*/
|
||||
public class TensorProcessor extends SequentialProcessor<TensorBuffer> {
|
||||
private TensorProcessor(Builder builder) {
|
||||
super(builder);
|
||||
}
|
||||
|
||||
/** The Builder to create a {@link TensorProcessor}, which could be executed later. */
|
||||
public static class Builder extends SequentialProcessor.Builder<TensorBuffer> {
|
||||
|
||||
/**
|
||||
* Creates a Builder to build {@link TensorProcessor}.
|
||||
*
|
||||
* @see #add(TensorOperator) to add an Op.
|
||||
* @see #build() to complete the building process and get a built Processor.
|
||||
*/
|
||||
public Builder() {
|
||||
super();
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a {@link TensorOperator} into the Operator chain.
*
* @param op the Operator instance to be executed in the chain.
|
||||
*/
|
||||
public TensorProcessor.Builder add(TensorOperator op) {
|
||||
super.add(op);
|
||||
return this;
|
||||
}
|
||||
|
||||
/** Completes the building process and gets the {@link TensorProcessor} instance. */
|
||||
@Override
|
||||
public TensorProcessor build() {
|
||||
return new TensorProcessor(this);
|
||||
}
|
||||
}
|
||||
}
|
@ -1,55 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.common.ops;
|
||||
|
||||
import org.tensorflow.lite.DataType;
|
||||
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||
import org.tensorflow.lite.support.common.TensorOperator;
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
|
||||
|
||||
/** Casts a {@link TensorBuffer} to a specified data type. */
|
||||
public class CastOp implements TensorOperator {
|
||||
|
||||
private final DataType destinationType;
|
||||
|
||||
/**
|
||||
* Constructs a CastOp.
|
||||
*
|
||||
* <p>Note: To convert the type of a single {@link TensorBuffer} on the fly, rather than in a
* processor, use {@link TensorBuffer#createFrom(TensorBuffer, DataType)} directly.
|
||||
*
|
||||
* <p>When this Op is executed, if the original {@link TensorBuffer} is already in {@code
|
||||
* destinationType}, the original buffer will be directly returned.
|
||||
*
|
||||
* @param destinationType The type of the cast {@link TensorBuffer}.
|
||||
* @throws IllegalArgumentException if {@code destinationType} is neither {@link DataType#UINT8}
|
||||
* nor {@link DataType#FLOAT32}.
|
||||
*/
|
||||
public CastOp(DataType destinationType) {
|
||||
SupportPreconditions.checkArgument(
|
||||
destinationType == DataType.UINT8 || destinationType == DataType.FLOAT32,
|
||||
"Destination type " + destinationType + " is not supported.");
|
||||
this.destinationType = destinationType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TensorBuffer apply(TensorBuffer input) {
|
||||
if (input.getDataType() == destinationType) {
|
||||
return input;
|
||||
}
|
||||
return TensorBuffer.createFrom(input, destinationType);
|
||||
}
|
||||
}
|
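A minimal sketch of chaining `CastOp` in a `TensorProcessor`; the `TensorBuffer.createFixedSize` and `loadArray` calls are assumed from the `TensorBuffer` API, which is not part of this page:

```java
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.common.TensorProcessor;
import org.tensorflow.lite.support.common.ops.CastOp;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

final class CastOpExample {
  public static void main(String[] args) {
    // Assumed TensorBuffer factory and loader methods.
    TensorBuffer uint8Buffer = TensorBuffer.createFixedSize(new int[] {2, 2}, DataType.UINT8);
    uint8Buffer.loadArray(new int[] {0, 64, 128, 255}, new int[] {2, 2});

    TensorProcessor toFloat =
        new TensorProcessor.Builder().add(new CastOp(DataType.FLOAT32)).build();
    TensorBuffer floatBuffer = toFloat.process(uint8Buffer);
    System.out.println(floatBuffer.getDataType()); // FLOAT32
  }
}
```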
@ -1,40 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.common.ops;
|
||||
|
||||
import org.tensorflow.lite.support.common.TensorOperator;
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
|
||||
|
||||
/**
|
||||
* Dequantizes a {@link TensorBuffer} with given {@code zeroPoint} and {@code scale}.
|
||||
*
|
||||
* <p>Note: The data type of output tensor is always {@code FLOAT32} except when the DequantizeOp is
|
||||
* created effectively as an identity Op such as setting {@code zeroPoint} to 0 and {@code scale} to
|
||||
* 1 (in this case, the output tensor is the same instance as input).
|
||||
*
|
||||
* <p>If both {@code zeroPoint} and {@code scale} are 0, the {@link DequantizeOp} will be bypassed,
|
||||
* which is equivalent to setting {@code zeroPoint} to 0 and {@code scale} to 1. This can be useful
|
||||
* when passing in the quantization parameters that are extracted directly from the TFLite model
|
||||
* flatbuffer. If the tensor is not quantized, both {@code zeroPoint} and {@code scale} will be read
|
||||
* as 0.
|
||||
*/
|
||||
public class DequantizeOp extends NormalizeOp implements TensorOperator {
|
||||
|
||||
public DequantizeOp(float zeroPoint, float scale) {
|
||||
// Quantization: f = (q - z) * s
|
||||
super(zeroPoint, 1 / scale);
|
||||
}
|
||||
}
|
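A worked example of the dequantization formula, assuming typical uint8 quantization parameters (zeroPoint = 128, scale = 1/128; the numbers are illustrative, not prescribed by the library):

```java
import org.tensorflow.lite.support.common.ops.DequantizeOp;

// Dequantization: f = (q - zeroPoint) * scale.
// With zeroPoint = 128 and scale = 1/128, a quantized value q = 255 maps to
// (255 - 128) * (1/128) ≈ 0.992, and q = 0 maps to (0 - 128) * (1/128) = -1.0.
final class DequantizeExample {
  static final DequantizeOp DEQUANTIZE = new DequantizeOp(128f, 1f / 128f);

  private DequantizeExample() {}
}
```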
@ -1,160 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.common.ops;
|
||||
|
||||
import org.checkerframework.checker.nullness.qual.NonNull;
|
||||
import org.tensorflow.lite.DataType;
|
||||
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||
import org.tensorflow.lite.support.common.TensorOperator;
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat;
|
||||
|
||||
/**
|
||||
* Normalizes a {@link TensorBuffer} with given mean and stddev: output = (input - mean) / stddev.
|
||||
*/
|
||||
public class NormalizeOp implements TensorOperator {
|
||||
|
||||
// mean.length should always be equal to stddev.length and always >= 1.
|
||||
private final float[] mean;
|
||||
private final float[] stddev;
|
||||
private final int numChannels;
|
||||
private final boolean isIdentityOp;
|
||||
|
||||
/**
|
||||
* Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
|
||||
* satisfies:
|
||||
*
|
||||
* <pre>
|
||||
* output = (input - mean) / stddev
|
||||
* </pre>
|
||||
*
|
||||
* <p>In the following two cases, reset {@code mean} to 0 and {@code stddev} to 1 to bypass the
|
||||
* normalization. <br>
|
||||
* 1. Both {@code mean} and {@code stddev} are 0. <br>
* 2. {@code mean} is 0 and {@code stddev} is Infinity.
|
||||
*
|
||||
* <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will
|
||||
* happen, and original input will be directly returned in execution.
|
||||
*
|
||||
* <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
|
||||
* present, except when the input is a {@link DataType#UINT8} tensor, {@code mean} is set to 0, and
|
||||
* {@code stddev} is set to 1.
|
||||
*
|
||||
* @param mean the mean value to be subtracted first.
|
||||
* @param stddev the standard deviation value to divide then.
|
||||
* @throws IllegalArgumentException if {@code stddev} is zero.
|
||||
*/
|
||||
public NormalizeOp(float mean, float stddev) {
|
||||
// Make exceptions to the cases that
|
||||
// 1. Both mean and stddev are 0.0f. This may happen when reading the normalization parameters
|
||||
// from a tensor which does not have the values populated in the metadata. The same situation
|
||||
// may also happen to the quantization parameters.
|
||||
// 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization
|
||||
// parameters from a tensor which does not have the values populated in the metadata, and then
|
||||
// passing the parameters into the DequantizeOp.
|
||||
// Bypass both of the two cases by resetting stddev to 1.0f.
|
||||
if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) {
|
||||
stddev = 1.0f;
|
||||
}
|
||||
|
||||
SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero.");
|
||||
boolean meansIsZeroAndDevsIs1 = false;
|
||||
if (mean == 0.0f && stddev == 1.0f) {
|
||||
meansIsZeroAndDevsIs1 = true;
|
||||
}
|
||||
|
||||
this.isIdentityOp = meansIsZeroAndDevsIs1;
|
||||
this.mean = new float[] {mean};
|
||||
this.stddev = new float[] {stddev};
|
||||
this.numChannels = 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
|
||||
* satisfies:
|
||||
*
|
||||
* <pre>
|
||||
* // Pseudo code. [...][i] means a certain element whose channel id is i.
|
||||
* output[...][i] = (input[...][i] - mean[i]) / stddev[i]
|
||||
* </pre>
|
||||
*
|
||||
* <p>Note: If all values in {@code mean} are set to 0 and all {@code stddev} are set to 1, no
|
||||
* computation will happen, and original input will be directly returned in execution.
|
||||
*
|
||||
* <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
|
||||
* present, except when the input is a {@link DataType#UINT8} tensor, all {@code mean} are set to
|
||||
* 0 and all {@code stddev} are set to 1.
|
||||
*
|
||||
* @param mean the mean values to be subtracted first for each channel.
|
||||
* @param stddev the standard deviation values to divide then for each channel.
|
||||
* @throws IllegalArgumentException if any {@code stddev} is zero, or {@code mean} has different
|
||||
* number of elements with {@code stddev}, or any of them is empty.
|
||||
*/
|
||||
public NormalizeOp(@NonNull float[] mean, @NonNull float[] stddev) {
|
||||
SupportPreconditions.checkNotNull(mean, "Mean cannot be null");
|
||||
SupportPreconditions.checkNotNull(stddev, "Stddev cannot be null");
|
||||
SupportPreconditions.checkArgument(
|
||||
mean.length == stddev.length,
|
||||
"Per channel normalization requires same number of means and stddevs");
|
||||
SupportPreconditions.checkArgument(mean.length > 0, "Means and stddevs are empty.");
|
||||
this.mean = mean.clone();
|
||||
this.stddev = stddev.clone();
|
||||
boolean allMeansAreZeroAndAllDevsAre1 = true;
|
||||
this.numChannels = mean.length;
|
||||
for (int i = 0; i < numChannels; i++) {
|
||||
SupportPreconditions.checkArgument(this.stddev[i] != 0, "Stddev cannot be zero.");
|
||||
if (this.stddev[i] != 1 || this.mean[i] != 0) {
|
||||
allMeansAreZeroAndAllDevsAre1 = false;
|
||||
}
|
||||
}
|
||||
this.isIdentityOp = allMeansAreZeroAndAllDevsAre1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies the defined normalization on given tensor and returns the result.
|
||||
*
|
||||
* <p>Note: {@code input} is possibly the same instance with the output.
|
||||
*
|
||||
* @param input input tensor. It may be the same instance with the output.
|
||||
* @return output tensor.
|
||||
*/
|
||||
@Override
|
||||
@NonNull
|
||||
public TensorBuffer apply(@NonNull TensorBuffer input) {
|
||||
if (isIdentityOp) {
|
||||
return input;
|
||||
}
|
||||
int[] shape = input.getShape();
|
||||
SupportPreconditions.checkArgument(
|
||||
numChannels == 1 || (shape.length != 0 && shape[shape.length - 1] == numChannels),
|
||||
"Number of means (stddevs) is not same with number of channels (size of last axis).");
|
||||
// TODO(136750944): Eliminate the array copy here.
|
||||
float[] values = input.getFloatArray();
|
||||
int j = 0;
|
||||
for (int i = 0; i < values.length; i++) {
|
||||
values[i] = (values[i] - mean[j]) / stddev[j];
|
||||
j = (j + 1) % numChannels;
|
||||
}
|
||||
TensorBuffer output;
|
||||
if (input.isDynamic()) {
|
||||
output = TensorBufferFloat.createDynamic(DataType.FLOAT32);
|
||||
} else {
|
||||
output = TensorBufferFloat.createFixedSize(shape, DataType.FLOAT32);
|
||||
}
|
||||
output.loadArray(values, shape);
|
||||
return output;
|
||||
}
|
||||
}
|
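A sketch of per-channel normalization wired into a `TensorProcessor`; the mean/stddev constants are illustrative (common ImageNet-style values), not prescribed by the library:

```java
import org.tensorflow.lite.support.common.TensorProcessor;
import org.tensorflow.lite.support.common.ops.NormalizeOp;

// Per-channel normalization: output[..][c] = (input[..][c] - MEAN[c]) / STDDEV[c].
final class NormalizeExample {
  static final float[] MEAN = {123.675f, 116.28f, 103.53f};
  static final float[] STDDEV = {58.395f, 57.12f, 57.375f};

  static final TensorProcessor NORMALIZE =
      new TensorProcessor.Builder().add(new NormalizeOp(MEAN, STDDEV)).build();

  private NormalizeExample() {}
}
```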
@ -1,41 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.common.ops;
|
||||
|
||||
import org.tensorflow.lite.support.common.TensorOperator;
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
|
||||
|
||||
/**
|
||||
* Quantizes a {@link TensorBuffer} with given {@code zeroPoint} and {@code scale}.
|
||||
*
|
||||
* <p>Note: {@link QuantizeOp} does not cast output to UINT8, but only performs the quantization
|
||||
* math on top of the input. The data type of the output tensor is always {@code FLOAT32} except when the Op
|
||||
* is effectively an identity Op (in this case, the output tensor is the same instance as the
|
||||
* input). To connect with quantized model, a {@link CastOp} is probably needed.
|
||||
*
|
||||
* <p>If both {@code zeroPoint} and {@code scale} are 0, the {@link QuantizeOp} will be bypassed,
|
||||
* which is equivalent to setting {@code zeroPoint} to 0 and {@code scale} to 1. This can be useful
|
||||
* when passing in the quantization parameters that are extracted directly from the TFLite model
|
||||
* flatbuffer. If the tensor is not quantized, both {@code zeroPoint} and {@code scale} will be read
|
||||
* as 0.
|
||||
*/
|
||||
public class QuantizeOp extends NormalizeOp implements TensorOperator {
|
||||
|
||||
public QuantizeOp(float zeroPoint, float scale) {
|
||||
// Quantization: f = (q - z) * s, i.e. q = f / s + z = (f - (-z * s)) / s
|
||||
super(-zeroPoint * scale, scale);
|
||||
}
|
||||
}
|
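A worked example of the quantization math and why the constructor delegates to `NormalizeOp` with `(-zeroPoint * scale, scale)`; the parameter values are illustrative:

```java
import org.tensorflow.lite.support.common.ops.QuantizeOp;

// Quantization: q = f / scale + zeroPoint, which can be rewritten as
// q = (f - (-zeroPoint * scale)) / scale, i.e. a NormalizeOp with
// mean = -zeroPoint * scale and stddev = scale.
// With zeroPoint = 128 and scale = 1/128, f = 0.5 maps to 0.5 * 128 + 128 = 192.
final class QuantizeExample {
  static final QuantizeOp QUANTIZE = new QuantizeOp(128f, 1f / 128f);

  private QuantizeExample() {}
}
```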
@ -1,202 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.image;
|
||||
|
||||
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
|
||||
|
||||
import android.graphics.RectF;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.tensorflow.lite.DataType;
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
|
||||
|
||||
/**
|
||||
* Helper class for converting values that represents bounding boxes into rectangles.
|
||||
*
|
||||
* <p>The class provides a static function to create bounding boxes as {@link RectF} from different
|
||||
* types of configurations.
|
||||
*
|
||||
* <p>Generally, a bounding box could be represented by 4 float values, but the values could be
|
||||
* interpreted in many ways. We currently support three {@link Type}s of configuration, and the order of
|
||||
* elements in each type is configurable as well.
|
||||
*/
|
||||
public final class BoundingBoxUtil {
|
||||
|
||||
/** Denotes how a bounding box is represented. */
|
||||
public enum Type {
|
||||
/**
|
||||
* Represents the bounding box by using the combination of boundaries, {left, top, right,
|
||||
* bottom}. The default order is {left, top, right, bottom}. Other orders can be indicated by an
|
||||
* index array.
|
||||
*/
|
||||
BOUNDARIES,
|
||||
/**
|
||||
* Represents the bounding box by using the upper_left corner, width and height. The default
|
||||
* order is {upper_left_x, upper_left_y, width, height}. Other orders can be indicated by an
|
||||
* index array.
|
||||
*/
|
||||
UPPER_LEFT,
|
||||
/**
|
||||
* Represents the bounding box by using the center of the box, width and height. The default
|
||||
* order is {center_x, center_y, width, height}. Other orders can be indicated by an index
|
||||
* array.
|
||||
*/
|
||||
CENTER,
|
||||
}
|
||||
|
||||
/** Denotes if the coordinates are actual pixels or relative ratios. */
|
||||
public enum CoordinateType {
|
||||
/** The coordinates are relative ratios in range [0, 1]. */
|
||||
RATIO,
|
||||
/** The coordinates are actual pixel values. */
|
||||
PIXEL
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a list of bounding boxes from a {@link TensorBuffer} which represents bounding boxes.
|
||||
*
|
||||
* @param tensor holds the data representing some boxes.
|
||||
* @param valueIndex denotes the order of the elements defined in each bounding box type. An index
* array of {0, 1, 2, 3} denotes the default order of each bounding box type. For example, to denote
|
||||
* the default order of BOUNDARIES, {left, top, right, bottom}, the index should be {0, 1, 2,
|
||||
* 3}. To denote the order {left, right, top, bottom}, the order should be {0, 2, 1, 3}.
|
||||
* <p>The index array can be applied to all bounding box types to adjust the order of their
|
||||
* corresponding underlying elements.
|
||||
* @param boundingBoxAxis specifies the index of the dimension that represents bounding box. The
|
||||
* size of that dimension is required to be 4. Index here starts from 0. For example, if the
|
||||
* tensor has shape 4x10, the axis for bounding boxes is likely to be 0. For shape 10x4, the
|
||||
* axis is likely to be 1 (or -1, equivalently).
|
||||
* @param type defines how values should be converted into boxes. See {@link Type}
|
||||
* @param coordinateType defines how values are interpreted to coordinates. See {@link
|
||||
* CoordinateType}
|
||||
* @param height the height of the image which the boxes belong to. Only has effects when {@code
|
||||
* coordinateType} is {@link CoordinateType#RATIO}
|
||||
* @param width the width of the image which the boxes belong to. Only has effects when {@code
|
||||
* coordinateType} is {@link CoordinateType#RATIO}
|
||||
* @return A list of bounding boxes that the {@code tensor} represents. All dimensions except
|
||||
* {@code boundingBoxAxis} will be collapsed with order kept. For example, given {@code
|
||||
* tensor} with shape {1, 4, 10, 2} and {@code boundingBoxAxis = 1}, the result will be a list
|
||||
* of 20 bounding boxes.
|
||||
* @throws IllegalArgumentException if size of bounding box dimension (set by {@code
|
||||
* boundingBoxAxis}) is not 4.
|
||||
* @throws IllegalArgumentException if {@code boundingBoxAxis} is not in {@code (-(D+1), D)} where
|
||||
* {@code D} is the number of dimensions of the {@code tensor}.
|
||||
* @throws IllegalArgumentException if {@code tensor} has data type other than {@link
|
||||
* DataType#FLOAT32}.
|
||||
*/
|
||||
public static List<RectF> convert(
|
||||
TensorBuffer tensor,
|
||||
int[] valueIndex,
|
||||
int boundingBoxAxis,
|
||||
Type type,
|
||||
CoordinateType coordinateType,
|
||||
int height,
|
||||
int width) {
|
||||
int[] shape = tensor.getShape();
|
||||
checkArgument(
|
||||
boundingBoxAxis >= -shape.length && boundingBoxAxis < shape.length,
|
||||
String.format(
|
||||
"Axis %d is not in range (-(D+1), D), where D is the number of dimensions of input"
|
||||
+ " tensor (shape=%s)",
|
||||
boundingBoxAxis, Arrays.toString(shape)));
|
||||
if (boundingBoxAxis < 0) {
|
||||
boundingBoxAxis = shape.length + boundingBoxAxis;
|
||||
}
|
||||
checkArgument(
|
||||
shape[boundingBoxAxis] == 4,
|
||||
String.format(
|
||||
"Size of bounding box dimension %d is not 4. Got %d in shape %s",
|
||||
boundingBoxAxis, shape[boundingBoxAxis], Arrays.toString(shape)));
|
||||
checkArgument(
|
||||
valueIndex.length == 4,
|
||||
String.format(
|
||||
"Bounding box index array length %d is not 4. Got index array %s",
|
||||
valueIndex.length, Arrays.toString(valueIndex)));
|
||||
checkArgument(
|
||||
tensor.getDataType() == DataType.FLOAT32,
|
||||
"Bounding Boxes only create from FLOAT32 buffers. Got: " + tensor.getDataType().name());
|
||||
List<RectF> boundingBoxList = new ArrayList<>();
|
||||
// Collapse dimensions to {a, 4, b}. So each bounding box could be represented as (i, j), and its
// four values are (i, k, j), where 0 <= k < 4. We can compute the flattened index of each value as
|
||||
// i * 4b + k * b + j.
|
||||
int a = 1;
|
||||
for (int i = 0; i < boundingBoxAxis; i++) {
|
||||
a *= shape[i];
|
||||
}
|
||||
int b = 1;
|
||||
for (int i = boundingBoxAxis + 1; i < shape.length; i++) {
|
||||
b *= shape[i];
|
||||
}
|
||||
float[] values = new float[4];
|
||||
ByteBuffer byteBuffer = tensor.getBuffer();
|
||||
byteBuffer.rewind();
|
||||
FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
|
||||
for (int i = 0; i < a; i++) {
|
||||
for (int j = 0; j < b; j++) {
|
||||
for (int k = 0; k < 4; k++) {
|
||||
values[k] = floatBuffer.get((i * 4 + k) * b + j);
|
||||
}
|
||||
boundingBoxList.add(
|
||||
convertOneBoundingBox(values, valueIndex, type, coordinateType, height, width));
|
||||
}
|
||||
}
|
||||
byteBuffer.rewind();
|
||||
return boundingBoxList;
|
||||
}
|
||||
|
||||
private static RectF convertOneBoundingBox(
|
||||
float[] values,
|
||||
int[] valueIndex,
|
||||
Type type,
|
||||
CoordinateType coordinateType,
|
||||
int height,
|
||||
int width) {
|
||||
float[] orderedValues = new float[4];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
orderedValues[i] = values[valueIndex[i]];
|
||||
}
|
||||
return convertOneBoundingBox(orderedValues, type, coordinateType, height, width);
|
||||
}
|
||||
|
||||
private static RectF convertOneBoundingBox(
|
||||
float[] values, Type type, CoordinateType coordinateType, int height, int width) {
|
||||
switch (type) {
|
||||
case BOUNDARIES:
|
||||
return convertFromBoundaries(values, coordinateType, height, width);
|
||||
case UPPER_LEFT:
|
||||
case CENTER:
|
||||
// TODO(b/150824448): convertFrom{UpperLeft, Center}
|
||||
throw new IllegalArgumentException("BoundingBox.Type " + type + " is not yet supported.");
|
||||
}
|
||||
throw new IllegalArgumentException("Cannot recognize BoundingBox.Type " + type);
|
||||
}
|
||||
|
||||
private static RectF convertFromBoundaries(
|
||||
float[] values, CoordinateType coordinateType, int height, int width) {
|
||||
if (coordinateType == CoordinateType.RATIO) {
|
||||
return new RectF(
|
||||
values[0] * width, values[1] * height, values[2] * width, values[3] * height);
|
||||
} else {
|
||||
return new RectF(values[0], values[1], values[2], values[3]);
|
||||
}
|
||||
}
|
||||
|
||||
// Private constructor to prevent initialization.
|
||||
private BoundingBoxUtil() {}
|
||||
}
|
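A minimal sketch of calling `convert` for a detection output of shape `{numBoxes, 4}` holding `{left, top, right, bottom}` ratios; the shape and coordinate convention are assumptions about the model output:

```java
import android.graphics.RectF;
import java.util.List;
import org.tensorflow.lite.support.image.BoundingBoxUtil;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

final class BoundingBoxExample {
  /** Converts a FLOAT32 output of shape {numBoxes, 4} with ratio boundaries into pixel rects. */
  static List<RectF> toRects(TensorBuffer boxes, int imageHeight, int imageWidth) {
    return BoundingBoxUtil.convert(
        boxes,
        new int[] {0, 1, 2, 3},     // default {left, top, right, bottom} order
        /* boundingBoxAxis= */ -1,  // the last axis has size 4
        BoundingBoxUtil.Type.BOUNDARIES,
        BoundingBoxUtil.CoordinateType.RATIO,
        imageHeight,
        imageWidth);
  }

  private BoundingBoxExample() {}
}
```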
@ -1,108 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.image;
|
||||
|
||||
import android.graphics.Bitmap;
|
||||
import android.graphics.Color;
|
||||
import java.util.Arrays;
|
||||
import org.tensorflow.lite.DataType;
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
|
||||
|
||||
/**
|
||||
* Implements some stateless image conversion methods.
|
||||
*
|
||||
* This class is an internal helper for {@link org.tensorflow.lite.support.image}.
|
||||
*/
|
||||
class ImageConversions {
|
||||
|
||||
/**
|
||||
* Converts an Image in a TensorBuffer to a Bitmap, whose memory is already allocated.
|
||||
*
|
||||
* <p>Notice: We only support ARGB_8888 at this point.
|
||||
*
|
||||
* @param buffer The TensorBuffer object representing the image. It should be an UInt8 buffer with
|
||||
* 3 dimensions: height, width, channel. Size of each dimension should be positive and the
* size of channels should be 3 (representing R, G, B). An optional 4th dimension "batch" is
* acceptable, and dimensions look like: batch, height, width, channel. In this case, size of
|
||||
* batches should be 1.
|
||||
* @param bitmap The destination of the conversion. Needs to be created in advance, needs to be
|
||||
* mutable, and needs to have the same width and height with the buffer.
|
||||
* @throws IllegalArgumentException 1) if the {@code buffer} is not uint8 (e.g. a float buffer),
|
||||
* or has an invalid shape. 2) if the {@code bitmap} is not mutable. 3) if the {@code bitmap}
|
||||
* has different height or width with the buffer.
|
||||
*/
|
||||
static void convertTensorBufferToBitmap(TensorBuffer buffer, Bitmap bitmap) {
|
||||
if (buffer.getDataType() != DataType.UINT8) {
|
||||
// We will add support for FLOAT format conversion in the future, as it may need other configs.
|
||||
throw new UnsupportedOperationException(
|
||||
String.format(
|
||||
"Converting TensorBuffer of type %s to ARGB_8888 Bitmap is not supported yet.",
|
||||
buffer.getDataType()));
|
||||
}
|
||||
int[] shape = buffer.getShape();
|
||||
TensorImage.checkImageTensorShape(shape);
|
||||
int h = shape[shape.length - 3];
|
||||
int w = shape[shape.length - 2];
|
||||
if (bitmap.getWidth() != w || bitmap.getHeight() != h) {
|
||||
throw new IllegalArgumentException(String.format(
|
||||
"Given bitmap has different width or height %s with the expected ones %s.",
|
||||
Arrays.toString(new int[]{bitmap.getWidth(), bitmap.getHeight()}),
|
||||
Arrays.toString(new int[]{w, h})));
|
||||
}
|
||||
if (!bitmap.isMutable()) {
|
||||
throw new IllegalArgumentException("Given bitmap is not mutable");
|
||||
}
|
||||
// TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
|
||||
int[] intValues = new int[w * h];
|
||||
int[] rgbValues = buffer.getIntArray();
|
||||
for (int i = 0, j = 0; i < intValues.length; i++) {
|
||||
int r = rgbValues[j++];
|
||||
int g = rgbValues[j++];
|
||||
int b = rgbValues[j++];
|
||||
intValues[i] = Color.rgb(r, g, b);
|
||||
}
|
||||
bitmap.setPixels(intValues, 0, w, 0, 0, w, h);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts an Image in a Bitmap to a TensorBuffer (3D Tensor: Height-Width-Channel) whose memory
|
||||
* is already allocated, or could be dynamically allocated.
|
||||
*
|
||||
* @param bitmap The Bitmap object representing the image. Currently we only support ARGB_8888
|
||||
* config.
|
||||
* @param buffer The destination of the conversion. Needs to be created in advance. If it's
|
||||
* fixed-size, its flat size should be w*h*3.
|
||||
* @throws IllegalArgumentException if the buffer is fixed-size, but the size doesn't match.
|
||||
*/
|
||||
static void convertBitmapToTensorBuffer(Bitmap bitmap, TensorBuffer buffer) {
|
||||
int w = bitmap.getWidth();
|
||||
int h = bitmap.getHeight();
|
||||
int[] intValues = new int[w * h];
|
||||
bitmap.getPixels(intValues, 0, w, 0, 0, w, h);
|
||||
// TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
|
||||
int[] rgbValues = new int[w * h * 3];
|
||||
for (int i = 0, j = 0; i < intValues.length; i++) {
|
||||
rgbValues[j++] = ((intValues[i] >> 16) & 0xFF);
|
||||
rgbValues[j++] = ((intValues[i] >> 8) & 0xFF);
|
||||
rgbValues[j++] = (intValues[i] & 0xFF);
|
||||
}
|
||||
int[] shape = new int[] {h, w, 3};
|
||||
buffer.loadArray(rgbValues, shape);
|
||||
}
|
||||
|
||||
// Hide the constructor as the class is static.
|
||||
private ImageConversions() {}
|
||||
}
|
@ -1,43 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.image;
|
||||
|
||||
import android.graphics.PointF;
|
||||
import org.tensorflow.lite.support.common.Operator;
|
||||
|
||||
/** Operates on a TensorImage object. Used in ImageProcessor. */
|
||||
public interface ImageOperator extends Operator<TensorImage> {
|
||||
/** @see org.tensorflow.lite.support.common.Operator#apply(java.lang.Object) */
|
||||
@Override
|
||||
TensorImage apply(TensorImage image);
|
||||
|
||||
/** Computes the width of the expected output image when input image size is given. */
|
||||
int getOutputImageWidth(int inputImageHeight, int inputImageWidth);
|
||||
|
||||
/** Computes the height of the expected output image when input image size is given. */
|
||||
int getOutputImageHeight(int inputImageHeight, int inputImageWidth);
|
||||
|
||||
/**
|
||||
* Transforms a point from the coordinate system of the result image back to that of the input
|
||||
* image.
|
||||
*
|
||||
* @param point the point from the result coordinates system.
|
||||
* @param inputImageHeight the height of input image.
|
||||
* @param inputImageWidth the width of input image.
|
||||
* @return the point with the coordinates from the coordinates system of the input image.
|
||||
*/
|
||||
PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth);
|
||||
}
|
@ -1,198 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.image;
|
||||
|
||||
import android.graphics.PointF;
|
||||
import android.graphics.RectF;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.ListIterator;
|
||||
import org.tensorflow.lite.support.common.Operator;
|
||||
import org.tensorflow.lite.support.common.SequentialProcessor;
|
||||
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||
import org.tensorflow.lite.support.common.TensorOperator;
|
||||
import org.tensorflow.lite.support.image.ops.Rot90Op;
|
||||
import org.tensorflow.lite.support.image.ops.TensorOperatorWrapper;
|
||||
|
||||
/**
|
||||
* ImageProcessor is a helper class for preprocessing and postprocessing {@link TensorImage}. It
|
||||
* could transform a {@link TensorImage} to another by executing a chain of {@link ImageOperator}.
|
||||
*
|
||||
* <p>Example Usage:
|
||||
*
|
||||
* <pre>
|
||||
* ImageProcessor processor = new ImageProcessor.Builder()
|
||||
* .add(new ResizeOp(224, 224, ResizeMethod.NEAREST_NEIGHBOR))
|
||||
* .add(new Rot90Op())
|
||||
* .add(new NormalizeOp(127.5f, 127.5f))
|
||||
* .build();
|
||||
* TensorImage anotherTensorImage = processor.process(tensorImage);
|
||||
* </pre>
|
||||
*
|
||||
* <p><b>WARNING:</b> Instances of an {@code ImageProcessor} are <b>not</b> thread-safe with {@link
|
||||
* #updateNumberOfRotations}. Updating the number of rotations and then processing images (using
|
||||
* {@link #process}) must be protected from concurrent access. It is recommended to create separate
|
||||
* {@code ImageProcessor} instances for each thread. If multiple threads access a {@code
|
||||
* ImageProcessor} concurrently, it must be synchronized externally.
|
||||
*
|
||||
* @see ImageProcessor.Builder to build a {@link ImageProcessor} instance
|
||||
* @see ImageProcessor#process(TensorImage) to apply the processor on a {@link TensorImage}
|
||||
*/
|
||||
public class ImageProcessor extends SequentialProcessor<TensorImage> {
|
||||
private ImageProcessor(Builder builder) {
|
||||
super(builder);
|
||||
}
|
||||
|
||||
/**
|
||||
* Transforms a point from the coordinate system of the result image back to that of the input
|
||||
* image.
|
||||
*
|
||||
* @param point the point from the result coordinates system.
|
||||
* @param inputImageHeight the height of input image.
|
||||
* @param inputImageWidth the width of input image.
|
||||
* @return the point with the coordinates from the coordinates system of the input image.
|
||||
*/
|
||||
public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
|
||||
List<Integer> widths = new ArrayList<>();
|
||||
List<Integer> heights = new ArrayList<>();
|
||||
int currentWidth = inputImageWidth;
|
||||
int currentHeight = inputImageHeight;
|
||||
for (Operator<TensorImage> op : operatorList) {
|
||||
widths.add(currentWidth);
|
||||
heights.add(currentHeight);
|
||||
ImageOperator imageOperator = (ImageOperator) op;
|
||||
int newHeight = imageOperator.getOutputImageHeight(currentHeight, currentWidth);
|
||||
int newWidth = imageOperator.getOutputImageWidth(currentHeight, currentWidth);
|
||||
currentHeight = newHeight;
|
||||
currentWidth = newWidth;
|
||||
}
|
||||
ListIterator<Operator<TensorImage>> opIterator = operatorList.listIterator(operatorList.size());
|
||||
ListIterator<Integer> widthIterator = widths.listIterator(widths.size());
|
||||
ListIterator<Integer> heightIterator = heights.listIterator(heights.size());
|
||||
while (opIterator.hasPrevious()) {
|
||||
ImageOperator imageOperator = (ImageOperator) opIterator.previous();
|
||||
int height = heightIterator.previous();
|
||||
int width = widthIterator.previous();
|
||||
point = imageOperator.inverseTransform(point, height, width);
|
||||
}
|
||||
return point;
|
||||
}
|
||||
|
||||
/**
|
||||
* Transforms a rectangle from the coordinate system of the result image back to that of the input
|
||||
* image.
|
||||
*
|
||||
* @param rect the rectangle from the result coordinates system.
|
||||
* @param inputImageHeight the height of input image.
|
||||
* @param inputImageWidth the width of input image.
|
||||
* @return the rectangle with the coordinates from the coordinates system of the input image.
|
||||
*/
|
||||
public RectF inverseTransform(RectF rect, int inputImageHeight, int inputImageWidth) {
|
||||
// When rotation is involved, the corner order may change (e.g. top-left becomes bottom-right).
|
||||
PointF p1 =
|
||||
inverseTransform(new PointF(rect.left, rect.top), inputImageHeight, inputImageWidth);
|
||||
PointF p2 =
|
||||
inverseTransform(new PointF(rect.right, rect.bottom), inputImageHeight, inputImageWidth);
|
||||
return new RectF(
|
||||
Math.min(p1.x, p2.x), Math.min(p1.y, p2.y), Math.max(p1.x, p2.x), Math.max(p1.y, p2.y));
|
||||
}
|
||||
|
||||
/**
|
||||
* The Builder to create an ImageProcessor, which could be executed later.
|
||||
*
|
||||
* @see #add(TensorOperator) to add a general TensorOperator
|
||||
* @see #add(ImageOperator) to add an ImageOperator
|
||||
* @see #build() complete the building process and get a built Processor
|
||||
*/
|
||||
public static class Builder extends SequentialProcessor.Builder<TensorImage> {
|
||||
public Builder() {
|
||||
super();
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds an {@link ImageOperator} into the Operator chain.
|
||||
*
|
||||
* @param op the Operator instance to be executed then
|
||||
*/
|
||||
public Builder add(ImageOperator op) {
|
||||
super.add(op);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a {@link TensorOperator} into the Operator chain. In execution, the processor calls
|
||||
* {@link TensorImage#getTensorBuffer()} to transform the {@link TensorImage} by transforming
|
||||
* the underlying {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
|
||||
*
|
||||
* @param op the Operator instance to be executed then
|
||||
*/
|
||||
public Builder add(TensorOperator op) {
|
||||
return add(new TensorOperatorWrapper(op));
|
||||
}
|
||||
|
||||
/** Completes the building process and gets the {@link ImageProcessor} instance. */
|
||||
@Override
|
||||
public ImageProcessor build() {
|
||||
return new ImageProcessor(this);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the number of rotations for the first {@link Rot90Op} in this {@link ImageProcessor}.
|
||||
*
|
||||
* <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
|
||||
* then processing images (using {@link #process}) must be protected from concurrent access with
|
||||
* additional synchronization.
|
||||
*
|
||||
* @param k the number of rotations
|
||||
* @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
|
||||
* ImageProcessor}
|
||||
*/
|
||||
public void updateNumberOfRotations(int k) {
|
||||
updateNumberOfRotations(k, /*occurrence=*/ 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the number of rotations for the {@link Rot90Op} specified by {@code occurrence} in this
|
||||
* {@link ImageProcessor}.
|
||||
*
|
||||
* <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
|
||||
* then processing images (using {@link #process}) must be protected from concurrent access with
|
||||
* additional synchronization.
|
||||
*
|
||||
* @param k the number of rotations
|
||||
* @param occurrence the index of the particular {@link Rot90Op} in this {@link ImageProcessor}. For
|
||||
* example, if the second {@link Rot90Op} needs to be updated, {@code occurrence} should be
|
||||
* set to 1.
|
||||
* @throws IndexOutOfBoundsException if {@code occurrence} is negative or is not less than the
|
||||
* number of {@link Rot90Op} in this {@link ImageProcessor}
|
||||
* @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
|
||||
* ImageProcessor}
|
||||
*/
|
||||
public synchronized void updateNumberOfRotations(int k, int occurrence) {
|
||||
SupportPreconditions.checkState(
|
||||
operatorIndex.containsKey(Rot90Op.class.getName()),
|
||||
"The Rot90Op has not been added to the ImageProcessor.");
|
||||
|
||||
List<Integer> indexes = operatorIndex.get(Rot90Op.class.getName());
|
||||
SupportPreconditions.checkElementIndex(occurrence, indexes.size(), "occurrence");
|
||||
|
||||
// The index of the Rot90Op to be replaced in operatorList.
|
||||
int index = indexes.get(occurrence);
|
||||
Rot90Op newRot = new Rot90Op(k);
|
||||
operatorList.set(index, newRot);
|
||||
}
|
||||
}
|
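Putting the pieces above together, a typical pipeline builds the processor once and reuses it per frame. A usage sketch based on the class javadoc; the class name, cameraBitmap, deviceRotations and the 224x224 target are illustrative assumptions, not values from the diff:

```java
// Usage sketch for ImageProcessor (Android).
import android.graphics.Bitmap;
import android.graphics.PointF;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod;
import org.tensorflow.lite.support.image.ops.Rot90Op;

final class PreprocessExample {
  static TensorImage preprocess(ImageProcessor processor, Bitmap cameraBitmap, int deviceRotations) {
    // Not thread-safe; see the warning in the class javadoc above.
    processor.updateNumberOfRotations(deviceRotations);
    return processor.process(TensorImage.fromBitmap(cameraBitmap));
  }

  static ImageProcessor buildProcessor() {
    return new ImageProcessor.Builder()
        .add(new ResizeOp(224, 224, ResizeMethod.BILINEAR))
        .add(new Rot90Op()) // placeholder rotation, adjusted per frame
        .build();
  }

  static PointF mapBack(ImageProcessor processor, PointF modelPoint, Bitmap source) {
    // Walks the operator chain in reverse, as implemented in inverseTransform above.
    return processor.inverseTransform(modelPoint, source.getHeight(), source.getWidth());
  }
}
```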
@ -1,381 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.image;
|
||||
|
||||
import android.graphics.Bitmap;
|
||||
import android.graphics.Bitmap.Config;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.checkerframework.checker.nullness.qual.NonNull;
|
||||
import org.tensorflow.lite.DataType;
|
||||
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
|
||||
|
||||
/**
|
||||
* TensorImage is the wrapper class for image objects. When using image processing utils in
|
||||
* the TFLite support library, it's common to convert image objects of various types to TensorImage
|
||||
* first.
|
||||
*
|
||||
* <p>At present, only RGB images are supported, and the A channel is always ignored.
|
||||
*
|
||||
* <p>Details of data storage: a {@link TensorImage} object may have 2 potential sources of truth: a
|
||||
* {@link Bitmap} or a {@link TensorBuffer}. {@link TensorImage} maintains the state and only
|
||||
* converts one to the other when needed.
|
||||
*
|
||||
* <p>IMPORTANT: The container doesn't own its data. Callers should not modify data objects that
|
||||
* are passed to {@link ImageContainer#set(Bitmap)} or {@link ImageContainer#set(TensorBuffer)}.
|
||||
*
|
||||
* <p>IMPORTANT: None of the methods are proven thread-safe.
|
||||
*
|
||||
* @see ImageProcessor which is often used for transforming a {@link TensorImage}.
|
||||
*/
|
||||
// TODO(b/138906681): Support basic Image properties (ColorType, DataType)
|
||||
// TODO(b/138907116): Support loading images from TensorBuffer with properties.
|
||||
// TODO(b/138905544): Support directly loading RGBBytes, YUVBytes and other types if necessary.
|
||||
public class TensorImage {
|
||||
|
||||
private final ImageContainer container;
|
||||
|
||||
/**
|
||||
* Initialize a TensorImage object.
|
||||
*
|
||||
* Note: The data type of this TensorImage is UINT8, which means it could naturally accept Bitmaps
|
||||
* whose pixel value range is [0, 255]. However, any image with float value pixels will not be
|
||||
* loaded correctly. In those cases, please use {@link TensorImage(DataType)}.
|
||||
*/
|
||||
public TensorImage() {
|
||||
this(DataType.UINT8);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a TensorImage object with data type specified.
|
||||
*
|
||||
* <p>Note: The shape of a TensorImage is not fixed. It is determined when {@code load} methods
|
||||
* are called, and could change later.
|
||||
*
|
||||
* @param dataType the expected internal data type of underlying tensor. The type is always fixed
|
||||
* during the lifetime of the {@link TensorImage}. To convert the data type, use {@link
|
||||
* TensorImage#createFrom(TensorImage, DataType)} to create a copy and convert data type at
|
||||
* the same time.
|
||||
* @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor
|
||||
* {@link DataType#FLOAT32}.
|
||||
*/
|
||||
public TensorImage(DataType dataType) {
|
||||
SupportPreconditions.checkArgument(
|
||||
dataType == DataType.UINT8 || dataType == DataType.FLOAT32,
|
||||
"Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted");
|
||||
container = new ImageContainer(dataType);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a {@link TensorImage} object with a {@link Bitmap}.
|
||||
*
|
||||
* @see TensorImage#load(Bitmap) for reusing the object when it's expensive to create objects
|
||||
* frequently, because every call of {@code fromBitmap} creates a new {@link TensorImage}.
|
||||
*/
|
||||
public static TensorImage fromBitmap(Bitmap bitmap) {
|
||||
TensorImage image = new TensorImage();
|
||||
image.load(bitmap);
|
||||
return image;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a deep-copy of a given {@link TensorImage} and converts internal tensor data type.
|
||||
*
|
||||
* <p>If the given {@code dataType} is different from {@code src.getDataType()}, an implicit data
|
||||
* conversion will be applied. Converting data from {@link DataType#FLOAT32} to {@link
|
||||
* DataType#UINT8} may involve default float->int conversion and value clamping, because {@link
|
||||
* DataType#UINT8} stores value from 0 to 255 (inclusively).
|
||||
*
|
||||
* @param src the TensorImage to copy from.
|
||||
* @param dataType the expected data type of newly created {@link TensorImage}.
|
||||
* @return a TensorImage whose data is copied from {@code src} and data type is {@code dataType}.
|
||||
*/
|
||||
@NonNull
|
||||
public static TensorImage createFrom(@NonNull TensorImage src, DataType dataType) {
|
||||
TensorImage dst = new TensorImage(dataType);
|
||||
if (src.container.isBufferUpdated) {
|
||||
dst.container.set(TensorBuffer.createFrom(src.getTensorBuffer(), dataType));
|
||||
} else if (src.container.isBitmapUpdated) {
|
||||
Bitmap srcBitmap = src.getBitmap();
|
||||
dst.container.set(srcBitmap.copy(srcBitmap.getConfig(), srcBitmap.isMutable()));
|
||||
}
|
||||
return dst;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a Bitmap image object into TensorImage.
|
||||
*
|
||||
* Important: When loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore. The
|
||||
* {@code TensorImage} object will rely on the bitmap. It will probably modify the bitmap as well.
|
||||
* In this method, we perform a zero-copy approach for that bitmap, by simply holding its
|
||||
* reference. Use {@code bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary.
|
||||
*
|
||||
* Note: To get the best performance, please load images in the same shape to avoid memory
|
||||
* re-allocation.
|
||||
*
|
||||
* @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888.
|
||||
*/
|
||||
public void load(@NonNull Bitmap bitmap) {
|
||||
SupportPreconditions.checkNotNull(bitmap, "Cannot load null bitmap.");
|
||||
SupportPreconditions.checkArgument(
|
||||
bitmap.getConfig().equals(Config.ARGB_8888), "Only supports loading ARGB_8888 bitmaps.");
|
||||
container.set(bitmap);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a float array as RGB pixels into TensorImage, representing the pixels inside.
|
||||
*
|
||||
* <p>Note: If the TensorImage has data type {@link DataType#UINT8}, numeric casting and clamping
|
||||
* will be applied.
|
||||
*
|
||||
* @param pixels The RGB pixels representing the image.
|
||||
* @param shape The shape of the image, either in the form (h, w, 3) or (1, h, w, 3).
|
||||
*/
|
||||
public void load(@NonNull float[] pixels, @NonNull int[] shape) {
|
||||
checkImageTensorShape(shape);
|
||||
TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
|
||||
buffer.loadArray(pixels, shape);
|
||||
load(buffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a uint8 array as RGB pixels into TensorImage, representing the pixels inside.
|
||||
*
|
||||
* <p>Note: If the TensorImage has data type {@link DataType#UINT8}, all pixel values will be clamped
|
||||
* into [0, 255].
|
||||
*
|
||||
* @param pixels The RGB pixels representing the image.
|
||||
* @param shape The shape of the image, either in the form (h, w, 3) or (1, h, w, 3).
|
||||
*/
|
||||
public void load(@NonNull int[] pixels, @NonNull int[] shape) {
|
||||
checkImageTensorShape(shape);
|
||||
TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
|
||||
buffer.loadArray(pixels, shape);
|
||||
load(buffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a TensorBuffer containing pixel values. The color layout should be RGB.
|
||||
*
|
||||
* @param buffer The TensorBuffer to load. Its shape should be either (h, w, 3) or (1, h, w, 3).
|
||||
*/
|
||||
public void load(TensorBuffer buffer) {
|
||||
checkImageTensorShape(buffer.getShape());
|
||||
container.set(buffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a bitmap representation of this TensorImage.
|
||||
*
|
||||
* <p>Important: It's only a reference. DO NOT MODIFY. We don't create a copy here for performance
|
||||
* reasons, but if modification is necessary, please make a copy.
|
||||
*
|
||||
* @return a reference to a Bitmap in ARGB_8888 config. "A" channel is always opaque.
|
||||
* @throws IllegalStateException if the TensorImage never loads data, or if the TensorImage is
|
||||
* holding a float-value image in {@code TensorBuffer}.
|
||||
*/
|
||||
@NonNull
|
||||
public Bitmap getBitmap() {
|
||||
return container.getBitmap();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a ByteBuffer representation of this TensorImage.
|
||||
*
|
||||
* <p>Important: It's only a reference. DO NOT MODIFY. We don't create a copy here for performance
|
||||
* reasons, but if modification is necessary, please make a copy.
|
||||
*
|
||||
* <p>It's essentially a shortcut for {@code getTensorBuffer().getBuffer()}.
|
||||
*
|
||||
* @return a reference to a ByteBuffer which holds the image data.
|
||||
* @throws IllegalStateException if the TensorImage never loads data.
|
||||
*/
|
||||
@NonNull
|
||||
public ByteBuffer getBuffer() {
|
||||
return container.getTensorBuffer().getBuffer();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a TensorBuffer representation of this TensorImage.
|
||||
*
|
||||
* <p>Important: It's only a reference. DO NOT MODIFY. We don't create a copy here for performance
|
||||
* reasons, but if modification is necessary, please make a copy.
|
||||
*
|
||||
* @return a reference to a TensorBuffer which holds the image data.
|
||||
* @throws IllegalStateException if the TensorImage never loads data.
|
||||
*/
|
||||
@NonNull
|
||||
public TensorBuffer getTensorBuffer() {
|
||||
return container.getTensorBuffer();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the current data type.
|
||||
*
|
||||
* @return a data type. Currently only UINT8 and FLOAT32 are possible.
|
||||
*/
|
||||
public DataType getDataType() {
|
||||
return container.getDataType();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the image width.
|
||||
*
|
||||
* @throws IllegalStateException if the TensorImage never loads data.
|
||||
* @throws IllegalArgumentException if the container data is corrupted.
|
||||
*/
|
||||
public int getWidth() {
|
||||
return container.getWidth();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the image height.
|
||||
*
|
||||
* @throws IllegalStateException if the TensorImage never loads data.
|
||||
* @throws IllegalArgumentException if the container data is corrupted.
|
||||
*/
|
||||
public int getHeight() {
|
||||
return container.getHeight();
|
||||
}
|
||||
|
||||
// Requires tensor shape [h, w, 3] or [1, h, w, 3].
|
||||
static void checkImageTensorShape(int[] shape) {
|
||||
SupportPreconditions.checkArgument(
|
||||
(shape.length == 3 || (shape.length == 4 && shape[0] == 1))
|
||||
&& shape[shape.length - 3] > 0
|
||||
&& shape[shape.length - 2] > 0
|
||||
&& shape[shape.length - 1] == 3,
|
||||
"Only supports image shape in (h, w, c) or (1, h, w, c), and channels representing R, G, B"
|
||||
+ " in order.");
|
||||
}
|
||||
|
||||
// Handles RGB image data storage strategy of TensorBuffer.
|
||||
private static class ImageContainer {
|
||||
|
||||
private TensorBuffer bufferImage;
|
||||
private boolean isBufferUpdated;
|
||||
private Bitmap bitmapImage;
|
||||
private boolean isBitmapUpdated;
|
||||
|
||||
private final DataType dataType;
|
||||
|
||||
private static final int ARGB_8888_ELEMENT_BYTES = 4;
|
||||
|
||||
ImageContainer(DataType dataType) {
|
||||
this.dataType = dataType;
|
||||
}
|
||||
|
||||
// Internal method to set the image source-of-truth with a bitmap. The bitmap has to be
|
||||
// ARGB_8888.
|
||||
void set(Bitmap bitmap) {
|
||||
bitmapImage = bitmap;
|
||||
isBufferUpdated = false;
|
||||
isBitmapUpdated = true;
|
||||
}
|
||||
|
||||
// Internal method to set the image source-of-truth with a TensorBuffer.
|
||||
void set(TensorBuffer buffer) {
|
||||
bufferImage = buffer;
|
||||
isBitmapUpdated = false;
|
||||
isBufferUpdated = true;
|
||||
}
|
||||
|
||||
int getWidth() {
|
||||
SupportPreconditions.checkState(
|
||||
isBitmapUpdated || isBufferUpdated,
|
||||
"Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
|
||||
if (isBitmapUpdated) {
|
||||
return bitmapImage.getWidth();
|
||||
}
|
||||
return getBufferDimensionSize(-2);
|
||||
}
|
||||
|
||||
int getHeight() {
|
||||
SupportPreconditions.checkState(
|
||||
isBitmapUpdated || isBufferUpdated,
|
||||
"Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
|
||||
if (isBitmapUpdated) {
|
||||
return bitmapImage.getHeight();
|
||||
}
|
||||
return getBufferDimensionSize(-3);
|
||||
}
|
||||
|
||||
// Internal helper method to get the size of one dimension in the shape of the `bufferImage`.
|
||||
// Requires `isBufferUpdated` is true.
|
||||
// Throws `IllegalArgumentException` if data is corrupted.
|
||||
private int getBufferDimensionSize(int dim) {
|
||||
int[] shape = bufferImage.getShape();
|
||||
// The defensive check is needed because bufferImage might be invalidly changed by user
|
||||
// (i.e. the internal data is corrupted).
|
||||
TensorImage.checkImageTensorShape(shape);
|
||||
dim = dim % shape.length;
|
||||
if (dim < 0) {
|
||||
dim += shape.length;
|
||||
}
|
||||
return shape[dim];
|
||||
}
|
||||
|
||||
public DataType getDataType() {
|
||||
return dataType;
|
||||
}
|
||||
|
||||
// Internal method to update the internal Bitmap data by TensorBuffer data.
|
||||
@NonNull
|
||||
Bitmap getBitmap() {
|
||||
if (isBitmapUpdated) {
|
||||
return bitmapImage;
|
||||
}
|
||||
if (!isBufferUpdated) {
|
||||
throw new IllegalStateException(
|
||||
"Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
|
||||
}
|
||||
if (bufferImage.getDataType() != DataType.UINT8) {
|
||||
throw new IllegalStateException(
|
||||
"TensorImage is holding a float-value image which is not able to convert a Bitmap.");
|
||||
}
|
||||
int requiredAllocation = bufferImage.getFlatSize() * ARGB_8888_ELEMENT_BYTES;
|
||||
// Create a new bitmap and reallocate memory for it.
|
||||
if (bitmapImage == null || bitmapImage.getAllocationByteCount() < requiredAllocation) {
|
||||
int[] shape = bufferImage.getShape();
|
||||
int h = shape[shape.length - 3];
|
||||
int w = shape[shape.length - 2];
|
||||
bitmapImage = Bitmap.createBitmap(w, h, Config.ARGB_8888);
|
||||
}
|
||||
ImageConversions.convertTensorBufferToBitmap(bufferImage, bitmapImage);
|
||||
isBitmapUpdated = true;
|
||||
return bitmapImage;
|
||||
}
|
||||
|
||||
// Internal method to update the internal TensorBuffer data by Bitmap data.
|
||||
@NonNull
|
||||
TensorBuffer getTensorBuffer() {
|
||||
if (isBufferUpdated) {
|
||||
return bufferImage;
|
||||
}
|
||||
SupportPreconditions.checkArgument(
|
||||
isBitmapUpdated,
|
||||
"Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
|
||||
int requiredFlatSize = bitmapImage.getWidth() * bitmapImage.getHeight() * 3;
|
||||
if (bufferImage == null
|
||||
|| (!bufferImage.isDynamic() && bufferImage.getFlatSize() != requiredFlatSize)) {
|
||||
bufferImage = TensorBuffer.createDynamic(dataType);
|
||||
}
|
||||
ImageConversions.convertBitmapToTensorBuffer(bitmapImage, bufferImage);
|
||||
isBufferUpdated = true;
|
||||
return bufferImage;
|
||||
}
|
||||
}
|
||||
}
|
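Typical use of the class above: wrap a bitmap once, convert to the model's data type, and hand the backing buffer to the interpreter. A sketch (Android); the class name and argb8888Bitmap are illustrative:

```java
// Usage sketch for TensorImage; sources and data types are illustrative assumptions.
import android.graphics.Bitmap;
import java.nio.ByteBuffer;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.image.TensorImage;

final class TensorImageExample {
  static ByteBuffer toModelInput(Bitmap argb8888Bitmap) {
    // Zero-copy wrap of the bitmap; do not mutate the bitmap afterwards.
    TensorImage uint8Image = TensorImage.fromBitmap(argb8888Bitmap);
    // Deep copy plus UINT8 -> FLOAT32 conversion for float models.
    TensorImage floatImage = TensorImage.createFrom(uint8Image, DataType.FLOAT32);
    return floatImage.getBuffer();
  }
}
```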
@ -1,89 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite.support.image.ops;

import android.graphics.Bitmap;
import android.graphics.PointF;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.image.ImageOperator;
import org.tensorflow.lite.support.image.TensorImage;

/**
 * As a computation unit for processing images, it can resize an image to a user-specified size.
 *
 * <p>It interpolates pixels when the image is stretched, and discards pixels when it is compressed.
 *
 * @see ResizeWithCropOrPadOp for resizing without content distortion.
 */
public class ResizeOp implements ImageOperator {

  /** Algorithms for resizing. */
  public enum ResizeMethod {
    BILINEAR,
    NEAREST_NEIGHBOR
  }

  private final int targetHeight;
  private final int targetWidth;
  private final boolean useBilinear;

  /**
   * Creates a ResizeOp which can resize images to the specified size using the specified method.
   *
   * @param targetHeight the expected height of the resized image.
   * @param targetWidth the expected width of the resized image.
   * @param resizeMethod the algorithm to use for resizing. Options: {@link ResizeMethod}
   */
  public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod) {
    this.targetHeight = targetHeight;
    this.targetWidth = targetWidth;
    useBilinear = (resizeMethod == ResizeMethod.BILINEAR);
  }

  /**
   * Applies the defined resizing on the given image and returns the result.
   *
   * <p>Note: the content of the input {@code image} will change, and {@code image} is the same
   * instance as the output.
   *
   * @param image input image.
   * @return output image.
   */
  @Override
  @NonNull
  public TensorImage apply(@NonNull TensorImage image) {
    Bitmap scaled =
        Bitmap.createScaledBitmap(image.getBitmap(), targetWidth, targetHeight, useBilinear);
    image.load(scaled);
    return image;
  }

  @Override
  public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
    return targetHeight;
  }

  @Override
  public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
    return targetWidth;
  }

  @Override
  public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
    return new PointF(
        point.x * inputImageWidth / targetWidth, point.y * inputImageHeight / targetHeight);
  }
}
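The inverse mapping above is a simple per-axis rescale: a point in the resized output is multiplied by inputSize / targetSize. A small worked example, assuming an illustrative 640x480 source resized to 224x224:

```java
// Worked example of ResizeOp#inverseTransform (all sizes are illustrative).
// A point detected at (112, 112) in the 224x224 model output maps back to the
// centre of the original 640x480 frame.
float targetWidth = 224f, targetHeight = 224f;
float inputWidth = 640f, inputHeight = 480f;
float x = 112f * inputWidth / targetWidth;   // 320.0
float y = 112f * inputHeight / targetHeight; // 240.0
```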
@ -1,125 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite.support.image.ops;

import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.PointF;
import android.graphics.Rect;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.image.ImageOperator;
import org.tensorflow.lite.support.image.TensorImage;

/**
 * As a computation unit for processing images, it can resize an image to a predefined size.
 *
 * <p>It will not stretch or compress the content of the image. However, to fit the new size, it
 * crops or pads pixels. When it crops an image, it performs a center-crop; when it pads pixels, it
 * performs a zero-padding.
 *
 * @see ResizeOp for resizing images while stretching / compressing the content.
 */
public class ResizeWithCropOrPadOp implements ImageOperator {
  private final int targetHeight;
  private final int targetWidth;
  private final Bitmap output;

  /**
   * Creates a ResizeWithCropOrPadOp which can crop/pad images to the specified size. It adopts
   * center-crop and zero-padding.
   *
   * @param targetHeight the expected height of the cropped/padded image.
   * @param targetWidth the expected width of the cropped/padded image.
   */
  public ResizeWithCropOrPadOp(int targetHeight, int targetWidth) {
    this.targetHeight = targetHeight;
    this.targetWidth = targetWidth;
    output = Bitmap.createBitmap(this.targetWidth, this.targetHeight, Config.ARGB_8888);
  }

  /**
   * Applies the defined resizing with cropping and/or padding on the given image and returns the
   * result.
   *
   * <p>Note: the content of the input {@code image} will change, and {@code image} is the same
   * instance as the output.
   *
   * @param image input image.
   * @return output image.
   */
  @Override
  @NonNull
  public TensorImage apply(@NonNull TensorImage image) {
    Bitmap input = image.getBitmap();
    int srcL;
    int srcR;
    int srcT;
    int srcB;
    int dstL;
    int dstR;
    int dstT;
    int dstB;
    int w = input.getWidth();
    int h = input.getHeight();
    if (targetWidth > w) { // padding
      srcL = 0;
      srcR = w;
      dstL = (targetWidth - w) / 2;
      dstR = dstL + w;
    } else { // cropping
      dstL = 0;
      dstR = targetWidth;
      srcL = (w - targetWidth) / 2;
      srcR = srcL + targetWidth;
    }
    if (targetHeight > h) { // padding
      srcT = 0;
      srcB = h;
      dstT = (targetHeight - h) / 2;
      dstB = dstT + h;
    } else { // cropping
      dstT = 0;
      dstB = targetHeight;
      srcT = (h - targetHeight) / 2;
      srcB = srcT + targetHeight;
    }
    Rect src = new Rect(srcL, srcT, srcR, srcB);
    Rect dst = new Rect(dstL, dstT, dstR, dstB);
    new Canvas(output).drawBitmap(input, src, dst, null);
    image.load(output);
    return image;
  }

  @Override
  public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
    return targetHeight;
  }

  @Override
  public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
    return targetWidth;
  }

  @Override
  public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
    return transformImpl(point, targetHeight, targetWidth, inputImageHeight, inputImageWidth);
  }

  private static PointF transformImpl(PointF point, int srcH, int srcW, int dstH, int dstW) {
    return new PointF(point.x + (dstW - srcW) / 2, point.y + (dstH - srcH) / 2);
  }
}
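The crop/pad decision above is made independently per axis: when the target is larger than the input, the source is centered inside the output (zero padding); otherwise a centered window of the source is copied. A worked example with an illustrative 640x200 input and a 300x300 target:

```java
// Worked example of the apply() arithmetic above (sizes are illustrative).
int w = 640, h = 200, targetWidth = 300, targetHeight = 300;

// Width: 640 > 300, so crop a centered 300-px-wide window from the source.
int srcL = (w - targetWidth) / 2;   // 170
int srcR = srcL + targetWidth;      // 470

// Height: 200 < 300, so pad; the source is drawn 50 px down in the output.
int dstT = (targetHeight - h) / 2;  // 50
int dstB = dstT + h;                // 250
```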
@ -1,103 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite.support.image.ops;

import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.graphics.PointF;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.image.ImageOperator;
import org.tensorflow.lite.support.image.TensorImage;

/** Rotates an image counter-clockwise. */
public class Rot90Op implements ImageOperator {

  private final int numRotation;

  /** Creates a Rot90 Op which will rotate an image by 90 degrees counter-clockwise. */
  public Rot90Op() {
    this(1);
  }

  /**
   * Creates a Rot90 Op which will rotate an image by 90 degrees {@code k} times counter-clockwise.
   *
   * @param k the number of times the image is rotated by 90 degrees. If it's positive, the image
   *     will be rotated counter-clockwise. If it's negative, the op will rotate the image clockwise.
   */
  public Rot90Op(int k) {
    numRotation = k % 4;
  }

  /**
   * Applies the defined rotation on the given image and returns the result.
   *
   * <p>Note: the content of the input {@code image} will change, and {@code image} is the same
   * instance as the output.
   *
   * @param image input image.
   * @return output image.
   */
  @NonNull
  @Override
  public TensorImage apply(@NonNull TensorImage image) {
    Bitmap input = image.getBitmap();
    if (numRotation == 0) {
      return image;
    }
    int w = input.getWidth();
    int h = input.getHeight();
    Matrix matrix = new Matrix();
    matrix.postTranslate(w * 0.5f, h * 0.5f);
    matrix.postRotate(-90 * numRotation);
    int newW = (numRotation % 2 == 0) ? w : h;
    int newH = (numRotation % 2 == 0) ? h : w;
    matrix.postTranslate(newW * 0.5f, newH * 0.5f);
    Bitmap output = Bitmap.createBitmap(input, 0, 0, w, h, matrix, false);
    image.load(output);
    return image;
  }

  @Override
  public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
    return (numRotation % 2 == 0) ? inputImageHeight : inputImageWidth;
  }

  @Override
  public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
    return (numRotation % 2 == 0) ? inputImageWidth : inputImageHeight;
  }

  @Override
  public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
    int inverseNumRotation = (4 - numRotation) % 4;
    int height = getOutputImageHeight(inputImageHeight, inputImageWidth);
    int width = getOutputImageWidth(inputImageHeight, inputImageWidth);
    return transformImpl(point, height, width, inverseNumRotation);
  }

  private static PointF transformImpl(PointF point, int height, int width, int numRotation) {
    if (numRotation == 0) {
      return point;
    } else if (numRotation == 1) {
      return new PointF(point.y, width - point.x);
    } else if (numRotation == 2) {
      return new PointF(width - point.x, height - point.y);
    } else { // numRotation == 3
      return new PointF(height - point.y, point.x);
    }
  }
}
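Tracing the inverse mapping above for a single rotation makes the axis swap concrete. The frame size and point below are illustrative; the comments simply follow the code path of inverseTransform and transformImpl:

```java
// Tracing Rot90Op#inverseTransform for k = 1 (values are illustrative).
// Source frame: 640 (w) x 480 (h). After one CCW rotation the output is 480 x 640.
Rot90Op op = new Rot90Op(1);
int outW = op.getOutputImageWidth(480, 640);   // 480 (odd rotation swaps axes)
int outH = op.getOutputImageHeight(480, 640);  // 640
// A point at (100, 200) in the rotated image maps back to (440, 100) in the source:
// inverseNumRotation = 3, so transformImpl returns (height - y, x) = (640 - 200, 100).
PointF source = op.inverseTransform(new PointF(100, 200), 480, 640);
```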
@ -1,70 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite.support.image.ops;

import android.graphics.PointF;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.common.SupportPreconditions;
import org.tensorflow.lite.support.common.TensorOperator;
import org.tensorflow.lite.support.image.ImageOperator;
import org.tensorflow.lite.support.image.TensorImage;

/**
 * The adapter that makes a TensorOperator able to run with TensorImage.
 *
 * @see org.tensorflow.lite.support.common.TensorOperator
 * @see org.tensorflow.lite.support.image.TensorImage
 */
public class TensorOperatorWrapper implements ImageOperator {

  private final TensorOperator tensorOp;

  /**
   * Wraps a {@link TensorOperator} object as an {@link ImageOperator}, so that the {@link
   * TensorOperator} can handle {@link TensorImage} objects by operating on their underlying {@link
   * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
   *
   * <p>Requirement: The {@code op} should not change the coordinate system when applied to an image.
   *
   * @param op The created operator.
   */
  public TensorOperatorWrapper(TensorOperator op) {
    tensorOp = op;
  }

  @Override
  @NonNull
  public TensorImage apply(@NonNull TensorImage image) {
    SupportPreconditions.checkNotNull(image, "Op cannot apply on null image.");
    image.load(tensorOp.apply(image.getTensorBuffer()));
    return image;
  }

  @Override
  public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
    return inputImageHeight;
  }

  @Override
  public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
    return inputImageWidth;
  }

  @Override
  public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
    return point;
  }
}
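A sketch of wrapping a custom TensorOperator. The single-method apply(TensorBuffer) signature is an assumption inferred from how tensorOp is invoked in apply() above, and the pass-through operator itself is hypothetical, not part of the library:

```java
// Hypothetical TensorOperator wrapped so it can sit in an image pipeline.
// Assumes TensorOperator exposes TensorBuffer apply(TensorBuffer), which matches
// the call tensorOp.apply(image.getTensorBuffer()) in TensorOperatorWrapper above.
TensorOperator passThrough =
    new TensorOperator() {
      @Override
      public TensorBuffer apply(TensorBuffer input) {
        return input; // a real op would normalize, quantize, etc.
      }
    };

ImageOperator asImageOp = new TensorOperatorWrapper(passThrough);
// Note: ImageProcessor.Builder#add(TensorOperator) performs this wrapping automatically.
```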
@ -1,62 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite.support.label;

import java.util.Objects;

/**
 * Category is a utility class that contains a label and a float score. Typically it's used as the
 * result of classification tasks.
 */
public final class Category {
  private final String label;
  private final float score;

  /** Constructs a Category. */
  public Category(String label, float score) {
    this.label = label;
    this.score = score;
  }

  /** Gets the reference of the category's label. */
  public String getLabel() {
    return label;
  }

  /** Gets the score of the category. */
  public float getScore() {
    return score;
  }

  @Override
  public boolean equals(Object o) {
    if (o instanceof Category) {
      Category other = (Category) o;
      return (other.getLabel().equals(this.label) && other.getScore() == this.score);
    }
    return false;
  }

  @Override
  public int hashCode() {
    return Objects.hash(label, score);
  }

  @Override
  public String toString() {
    return "<Category \"" + label + "\" (score=" + score + ")>";
  }
}
@ -1,64 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite.support.label;

import android.util.Log;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.common.SupportPreconditions;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

/** Label operation utils. */
public class LabelUtil {
  /**
   * Maps an int value tensor to a list of string labels. It takes an array of strings as the
   * dictionary. Example: if the given tensor is [3, 1, 0], and the given labels are ["background",
   * "apple", "banana", "cherry", "date"], the result will be ["date", "banana", "apple"].
   *
   * @param tensorBuffer A tensor with index values. The values should be non-negative integers,
   *     and each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor is
   *     given as a float {@link TensorBuffer}, values will be cast to integers. All values that are
   *     out of bounds will map to an empty string.
   * @param labels A list of strings, used as a dictionary to look up. The index of the array
   *     element will be used as the key. To get better performance, use an object that implements
   *     RandomAccess, such as {@link ArrayList}.
   * @param offset The offset value used when looking up int values in {@code labels}.
   * @return the mapped strings. The length of the list is {@link TensorBuffer#getFlatSize}.
   * @throws IllegalArgumentException if {@code tensorBuffer} or {@code labels} is null.
   */
  public static List<String> mapValueToLabels(
      @NonNull TensorBuffer tensorBuffer, @NonNull List<String> labels, int offset) {
    SupportPreconditions.checkNotNull(tensorBuffer, "Given tensor should not be null");
    SupportPreconditions.checkNotNull(labels, "Given labels should not be null");
    int[] values = tensorBuffer.getIntArray();
    Log.d("values", Arrays.toString(values));
    List<String> result = new ArrayList<>();
    for (int v : values) {
      int index = v + offset;
      if (index < 0 || index >= labels.size()) {
        result.add("");
      } else {
        result.add(labels.get(index));
      }
    }
    return result;
  }

  // Private constructor to prevent initialization.
  private LabelUtil() {}
}
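The javadoc example above can be reproduced directly. A sketch, assuming the documented output corresponds to an offset of 1 (imports of TensorBuffer, DataType, LabelUtil, Arrays and List are omitted):

```java
// Usage sketch reproducing the LabelUtil javadoc example; offset = 1 is an
// assumption needed to obtain the documented output.
TensorBuffer indices = TensorBuffer.createDynamic(DataType.UINT8);
indices.loadArray(new int[] {3, 1, 0}, new int[] {3});
List<String> labels = Arrays.asList("background", "apple", "banana", "cherry", "date");
List<String> mapped = LabelUtil.mapValueToLabels(indices, labels, 1);
// mapped -> ["date", "banana", "apple"]
```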
@ -1,224 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.label;
|
||||
|
||||
import android.content.Context;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.checkerframework.checker.nullness.qual.NonNull;
|
||||
import org.tensorflow.lite.DataType;
|
||||
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
|
||||
|
||||
/**
|
||||
* TensorLabel is a utility wrapper for TensorBuffers with meaningful labels on an axis.
|
||||
*
|
||||
* <p>For example, an image classification model may have an output tensor with shape as {1, 10},
|
||||
* where 1 is the batch size and 10 is the number of categories. In fact, on the 2nd axis, we could
|
||||
* label each sub-tensor with the name or description of each corresponding category. {@link
|
||||
* TensorLabel} could help converting the plain Tensor in {@link TensorBuffer} into a map from
|
||||
* predefined labels to sub-tensors. In this case, if provided 10 labels for the 2nd axis, {@link
|
||||
* TensorLabel} could convert the original {1, 10} Tensor to a 10 element map, each value of which
|
||||
* is Tensor in shape {} (scalar). Usage example:
|
||||
*
|
||||
* <pre>
|
||||
* TensorBuffer outputTensor = ...;
|
||||
* {@literal List<String>} labels = FileUtil.loadLabels(context, labelFilePath);
|
||||
* // labels the first axis with size greater than one
|
||||
* TensorLabel labeled = new TensorLabel(labels, outputTensor);
|
||||
* // If each sub-tensor has effectively size 1, we can directly get a float value
|
||||
* {@literal Map<String, Float>} probabilities = labeled.getMapWithFloatValue();
|
||||
* // Or get sub-tensors, when each sub-tensor has elements more than 1
|
||||
* {@literal Map<String, TensorBuffer>} subTensors = labeled.getMapWithTensorBuffer();
|
||||
* </pre>
|
||||
*
|
||||
* <p>Note: currently we only support tensor-to-map conversion for the first axis with size greater
|
||||
* than 1.
|
||||
*
|
||||
* @see org.tensorflow.lite.support.common.FileUtil#loadLabels(Context, String) to load labels from
|
||||
* a label file (a plain text file with one label per line) in assets.
|
||||
*/
|
||||
public class TensorLabel {
|
||||
private final Map<Integer, List<String>> axisLabels;
|
||||
private final TensorBuffer tensorBuffer;
|
||||
private final int[] shape;
|
||||
|
||||
/**
|
||||
* Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors.
|
||||
*
|
||||
* @param axisLabels A map, whose key is axis id (starting from 0) and value is corresponding
|
||||
* labels. Note: The size of labels should be the same as the size of the tensor on that axis.
|
||||
* @param tensorBuffer The TensorBuffer to be labeled.
|
||||
* @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any
|
||||
* value in {@code axisLabels} is null.
|
||||
* @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared to
|
||||
* the shape of {@code tensorBuffer}), or any value (labels) has a different size from the {@code
|
||||
* tensorBuffer} on the given dimension.
|
||||
*/
|
||||
public TensorLabel(
|
||||
@NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) {
|
||||
SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null.");
|
||||
SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null.");
|
||||
this.axisLabels = axisLabels;
|
||||
this.tensorBuffer = tensorBuffer;
|
||||
this.shape = tensorBuffer.getShape();
|
||||
for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) {
|
||||
int axis = entry.getKey();
|
||||
SupportPreconditions.checkArgument(
|
||||
axis >= 0 && axis < shape.length, "Invalid axis id: " + axis);
|
||||
SupportPreconditions.checkNotNull(entry.getValue(), "Label list is null on axis " + axis);
|
||||
SupportPreconditions.checkArgument(
|
||||
shape[axis] == entry.getValue().size(),
|
||||
"Label number " + entry.getValue().size() + " mismatch the shape on axis " + axis);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors.
|
||||
*
|
||||
* <p>Note: The labels are applied on the first axis whose size is larger than 1. For example, if
|
||||
* the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting from
|
||||
* 0), and size of {@code axisLabels} should be 10 as well.
|
||||
*
|
||||
* @param axisLabels A list of labels, whose size should be the same as the size of the tensor on
|
||||
* the to-be-labeled axis.
|
||||
* @param tensorBuffer The TensorBuffer to be labeled.
|
||||
*/
|
||||
public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) {
|
||||
this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the map with a pair of the label and the corresponding TensorBuffer. Only allow the
|
||||
* mapping on the first axis with size greater than 1 currently.
|
||||
*/
|
||||
@NonNull
|
||||
public Map<String, TensorBuffer> getMapWithTensorBuffer() {
|
||||
int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
|
||||
|
||||
Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>();
|
||||
SupportPreconditions.checkArgument(
|
||||
axisLabels.containsKey(labeledAxis),
|
||||
"get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis.");
|
||||
List<String> labels = axisLabels.get(labeledAxis);
|
||||
|
||||
DataType dataType = tensorBuffer.getDataType();
|
||||
int typeSize = tensorBuffer.getTypeSize();
|
||||
int flatSize = tensorBuffer.getFlatSize();
|
||||
|
||||
// Gets the underlying bytes that could be used to generate the sub-array later.
|
||||
ByteBuffer byteBuffer = tensorBuffer.getBuffer();
|
||||
byteBuffer.rewind();
|
||||
|
||||
// Note: computation below is only correct when labeledAxis is the first axis with size greater
|
||||
// than 1.
|
||||
int subArrayLength = flatSize / shape[labeledAxis] * typeSize;
|
||||
int i = 0;
|
||||
SupportPreconditions.checkNotNull(labels, "Label list should never be null");
|
||||
for (String label : labels) {
|
||||
// Gets the corresponding TensorBuffer.
|
||||
byteBuffer.position(i * subArrayLength);
|
||||
ByteBuffer subBuffer = byteBuffer.slice();
|
||||
// ByteBuffer.slice doesn't keep order. Modify it to align with the original one.
|
||||
subBuffer.order(byteBuffer.order()).limit(subArrayLength);
|
||||
TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType);
|
||||
labelBuffer.loadBuffer(subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length));
|
||||
labelToTensorMap.put(label, labelBuffer);
|
||||
i += 1;
|
||||
}
|
||||
return labelToTensorMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a map that maps label to float. Only allow the mapping on the first axis with size greater
|
||||
* than 1, and the axis should be effectively the last axis (which means every sub tensor
|
||||
* specified by this axis should have a flat size of 1).
|
||||
*
|
||||
* <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result.
|
||||
*
|
||||
* @throws IllegalStateException if size of a sub tensor on each label is not 1.
|
||||
*/
|
||||
@NonNull
|
||||
public Map<String, Float> getMapWithFloatValue() {
|
||||
int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
|
||||
SupportPreconditions.checkState(
|
||||
labeledAxis == shape.length - 1,
|
||||
"get a <String, Scalar> map is only valid when the only labeled axis is the last one.");
|
||||
List<String> labels = axisLabels.get(labeledAxis);
|
||||
float[] data = tensorBuffer.getFloatArray();
|
||||
SupportPreconditions.checkState(labels.size() == data.length);
|
||||
Map<String, Float> result = new LinkedHashMap<>();
|
||||
int i = 0;
|
||||
for (String label : labels) {
|
||||
result.put(label, data[i]);
|
||||
i += 1;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a list of {@link Category} from the {@link TensorLabel} object.
|
||||
*
|
||||
* <p>The axis of label should be effectively the last axis (which means every sub tensor
|
||||
* specified by this axis should have a flat size of 1), so that each labelled sub tensor could be
|
||||
* converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2, 5, 3}}
|
||||
* and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link Category}.
|
||||
*
|
||||
* <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as
|
||||
* the result.
|
||||
*
|
||||
* @throws IllegalStateException if size of a sub tensor on each label is not 1.
|
||||
*/
|
||||
@NonNull
|
||||
public List<Category> getCategoryList() {
|
||||
int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
|
||||
SupportPreconditions.checkState(
|
||||
labeledAxis == shape.length - 1,
|
||||
"get a Category list is only valid when the only labeled axis is the last one.");
|
||||
List<String> labels = axisLabels.get(labeledAxis);
|
||||
float[] data = tensorBuffer.getFloatArray();
|
||||
SupportPreconditions.checkState(labels.size() == data.length);
|
||||
List<Category> result = new ArrayList<>();
|
||||
int i = 0;
|
||||
for (String label : labels) {
|
||||
result.add(new Category(label, data[i]));
|
||||
i += 1;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) {
|
||||
int[] shape = tensorBuffer.getShape();
|
||||
for (int i = 0; i < shape.length; i++) {
|
||||
if (shape[i] > 1) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
throw new IllegalArgumentException(
|
||||
"Cannot find an axis to label. A valid axis to label should have size larger than 1.");
|
||||
}
|
||||
|
||||
// Helper function to wrap the List<String> to a one-entry map.
|
||||
private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) {
|
||||
Map<Integer, List<String>> map = new LinkedHashMap<>();
|
||||
map.put(axis, labels);
|
||||
return map;
|
||||
}
|
||||
}
|
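Building on the usage example in the class javadoc, the Category-based accessor can be exercised as follows. The shape, scores and label names are illustrative, and imports are omitted for brevity:

```java
// Usage sketch for TensorLabel#getCategoryList(); values are illustrative.
TensorBuffer output = TensorBuffer.createDynamic(DataType.FLOAT32);
output.loadArray(new float[] {0.1f, 0.7f, 0.2f}, new int[] {1, 3});
// Labels apply to the first axis with size > 1 (here the last axis), so
// getCategoryList() is valid.
TensorLabel labeled = new TensorLabel(Arrays.asList("cat", "dog", "bird"), output);
for (Category c : labeled.getCategoryList()) {
  System.out.println(c); // e.g. <Category "dog" (score=0.7)>
}
```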
@ -1,74 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.label.ops;
|
||||
|
||||
import android.content.Context;
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.checkerframework.checker.nullness.qual.NonNull;
|
||||
import org.tensorflow.lite.support.common.FileUtil;
|
||||
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||
import org.tensorflow.lite.support.label.TensorLabel;
|
||||
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
|
||||
|
||||
/**
|
||||
* Labels TensorBuffer with axisLabels for outputs.
|
||||
*
|
||||
* <p>Apply it to a {@code TensorBuffer} to get a {@code TensorLabel} that can output a Map of
* label names to the corresponding TensorBuffer values.
|
||||
*/
|
||||
public class LabelAxisOp {
|
||||
// Axis and its corresponding label names.
|
||||
private final Map<Integer, List<String>> axisLabels;
|
||||
|
||||
protected LabelAxisOp(Builder builder) {
|
||||
axisLabels = builder.axisLabels;
|
||||
}
|
||||
|
||||
public TensorLabel apply(@NonNull TensorBuffer buffer) {
|
||||
SupportPreconditions.checkNotNull(buffer, "Tensor buffer cannot be null.");
|
||||
return new TensorLabel(axisLabels, buffer);
|
||||
}
|
||||
|
||||
/** The inner builder class to build a LabelTensor Operator. */
|
||||
public static class Builder {
|
||||
private final Map<Integer, List<String>> axisLabels;
|
||||
|
||||
protected Builder() {
|
||||
axisLabels = new HashMap<>();
|
||||
}
|
||||
|
||||
public Builder addAxisLabel(@NonNull Context context, int axis, @NonNull String filePath)
|
||||
throws IOException {
|
||||
SupportPreconditions.checkNotNull(context, "Context cannot be null.");
|
||||
SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
|
||||
List<String> labels = FileUtil.loadLabels(context, filePath);
|
||||
axisLabels.put(axis, labels);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder addAxisLabel(int axis, @NonNull List<String> labels) {
|
||||
axisLabels.put(axis, labels);
|
||||
return this;
|
||||
}
|
||||
|
||||
public LabelAxisOp build() {
|
||||
return new LabelAxisOp(this);
|
||||
}
|
||||
}
|
||||
}
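The Builder constructor above is protected, so the op is typically built from code in the same
package (or a subclass). A hypothetical sketch, where outputBuffer is an existing TensorBuffer:

LabelAxisOp op =
    new LabelAxisOp.Builder()
        .addAxisLabel(/*axis=*/1, Arrays.asList("cat", "dog", "bird"))
        .build();
TensorLabel labeled = op.apply(outputBuffer);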
|
@ -1,69 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.model;
|
||||
|
||||
import android.util.Log;
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||
import org.tensorflow.lite.Delegate;
|
||||
|
||||
/**
|
||||
* Helper class to create and call necessary methods of {@code GpuDelegate}, which is not a strict
* dependency.
|
||||
*/
|
||||
class GpuDelegateProxy implements Delegate, Closeable {
|
||||
|
||||
private static final String TAG = "GpuDelegateProxy";
|
||||
|
||||
private final Delegate proxiedDelegate;
|
||||
private final Closeable proxiedCloseable;
|
||||
|
||||
@Nullable
|
||||
public static GpuDelegateProxy maybeNewInstance() {
|
||||
try {
|
||||
Class<?> clazz = Class.forName("org.tensorflow.lite.gpu.GpuDelegate");
|
||||
Object instance = clazz.getDeclaredConstructor().newInstance();
|
||||
return new GpuDelegateProxy(instance);
|
||||
} catch (ReflectiveOperationException e) {
|
||||
Log.e(TAG, "Failed to create the GpuDelegate dynamically.", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/** Calls {@code close()} method of the delegate. */
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
proxiedCloseable.close();
|
||||
} catch (IOException e) {
|
||||
// Should not trigger, because GpuDelegate#close never throws. The catch is required because
|
||||
// of Closeable#close.
|
||||
Log.e(TAG, "Failed to close the GpuDelegate.", e);
|
||||
}
|
||||
}
|
||||
|
||||
/** Calls {@code getNativeHandle()} method of the delegate. */
|
||||
@Override
|
||||
public long getNativeHandle() {
|
||||
return proxiedDelegate.getNativeHandle();
|
||||
}
|
||||
|
||||
private GpuDelegateProxy(Object instance) {
|
||||
this.proxiedCloseable = (Closeable) instance;
|
||||
this.proxiedDelegate = (Delegate) instance;
|
||||
}
|
||||
}
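The class above keeps "tensorflow-lite-gpu" an optional dependency by creating the delegate
reflectively. A generic sketch of the same optional-dependency pattern (the method name and the
class name passed in are hypothetical):

@Nullable
static Closeable maybeCreate(String className) {
  try {
    // e.g. className = "org.tensorflow.lite.gpu.GpuDelegate"
    Class<?> clazz = Class.forName(className);
    return (Closeable) clazz.getDeclaredConstructor().newInstance();
  } catch (ReflectiveOperationException e) {
    return null; // the optional dependency is not on the classpath
  }
}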
|
@ -1,285 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.model;
|
||||
|
||||
import android.content.Context;
|
||||
import java.io.IOException;
|
||||
import java.nio.MappedByteBuffer;
|
||||
import java.util.Map;
|
||||
import org.checkerframework.checker.nullness.qual.NonNull;
|
||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||
import org.tensorflow.lite.Interpreter;
|
||||
import org.tensorflow.lite.Tensor;
|
||||
import org.tensorflow.lite.support.common.FileUtil;
|
||||
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||
|
||||
/**
|
||||
* The wrapper class for a TFLite model and a TFLite interpreter.
|
||||
*
|
||||
* <p>Note: A {@link Model} can only hold one TFLite model at a time, and always holds a TFLite
* interpreter instance to run it.
|
||||
*/
|
||||
public class Model {
|
||||
|
||||
/** The runtime device type used for executing classification. */
|
||||
public enum Device {
|
||||
CPU,
|
||||
NNAPI,
|
||||
GPU
|
||||
}
|
||||
|
||||
/**
|
||||
* Options for running the model. Configurable parameters include:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@code device} {@link Builder#setDevice(Device)} specifies the hardware to run the model.
|
||||
* The default value is {@link Device#CPU}.
|
||||
* <li>{@code numThreads} {@link Builder#setNumThreads(int)} specifies the number of threads
* used by TFLite inference. It is only effective when the device is set to {@link Device#CPU},
* and the default value is 1.
|
||||
* </ul>
|
||||
*/
|
||||
public static class Options {
|
||||
private final Device device;
|
||||
private final int numThreads;
|
||||
|
||||
/** Builder of {@link Options}. See its doc for details. */
|
||||
public static class Builder {
|
||||
private Device device = Device.CPU;
|
||||
private int numThreads = 1;
|
||||
|
||||
public Builder setDevice(Device device) {
|
||||
this.device = device;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setNumThreads(int numThreads) {
|
||||
this.numThreads = numThreads;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Options build() {
|
||||
return new Options(this);
|
||||
}
|
||||
}
|
||||
|
||||
private Options(Builder builder) {
|
||||
device = builder.device;
|
||||
numThreads = builder.numThreads;
|
||||
}
|
||||
}
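For illustration, a minimal sketch of building Options from client code (the values are
illustrative, not recommended defaults):

Model.Options options =
    new Model.Options.Builder()
        .setDevice(Model.Device.CPU)
        .setNumThreads(4)
        .build();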
|
||||
|
||||
/** An instance of the driver class to run model inference with TensorFlow Lite. */
|
||||
private final Interpreter interpreter;
|
||||
|
||||
/** Path to tflite model file in asset folder. */
|
||||
private final String modelPath;
|
||||
|
||||
/** The memory-mapped model data. */
|
||||
private final MappedByteBuffer byteModel;
|
||||
|
||||
private final GpuDelegateProxy gpuDelegateProxy;
|
||||
|
||||
/**
|
||||
* Builder for {@link Model}.
|
||||
*
|
||||
* @deprecated Please use {@link Model#createModel(Context, String, Options)}.
|
||||
*/
|
||||
@Deprecated
|
||||
public static class Builder {
|
||||
private Device device = Device.CPU;
|
||||
private int numThreads = 1;
|
||||
private final String modelPath;
|
||||
private final MappedByteBuffer byteModel;
|
||||
|
||||
/**
|
||||
* Creates a builder which loads the tflite model from the asset folder using memory-mapped files.
*
* @param context Application context to access assets.
* @param modelPath Asset path of the model (.tflite file).
* @throws IOException if an I/O error occurs when loading the tflite model.
|
||||
*/
|
||||
@NonNull
|
||||
public Builder(@NonNull Context context, @NonNull String modelPath) throws IOException {
|
||||
this.modelPath = modelPath;
|
||||
byteModel = FileUtil.loadMappedFile(context, modelPath);
|
||||
}
|
||||
|
||||
/** Sets running device. By default, TFLite will run on CPU. */
|
||||
@NonNull
|
||||
public Builder setDevice(Device device) {
|
||||
this.device = device;
|
||||
return this;
|
||||
}
|
||||
|
||||
/** Sets number of threads. By default it's 1. */
|
||||
@NonNull
|
||||
public Builder setNumThreads(int numThreads) {
|
||||
this.numThreads = numThreads;
|
||||
return this;
|
||||
}
|
||||
|
||||
// Note: The implementation is copied from `Model#createModel`. As the builder is going to be
|
||||
// deprecated, this function is also to be removed.
|
||||
@NonNull
|
||||
public Model build() {
|
||||
Options options = new Options.Builder().setNumThreads(numThreads).setDevice(device).build();
|
||||
return createModel(byteModel, modelPath, options);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a model from assets and initializes the TFLite interpreter.
|
||||
*
|
||||
* <p>The default options are: (1) CPU device; (2) one thread.
|
||||
*
|
||||
* @param context The App Context.
|
||||
* @param modelPath The path of the model file.
|
||||
* @throws IOException if any exception occurs when opening the model file.
|
||||
*/
|
||||
public static Model createModel(@NonNull Context context, @NonNull String modelPath)
|
||||
throws IOException {
|
||||
return createModel(context, modelPath, new Options.Builder().build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a model from assets and initializes the TFLite interpreter with the given options.
|
||||
*
|
||||
* @see Options for details.
|
||||
* @param context The App Context.
|
||||
* @param modelPath The path of the model file.
|
||||
* @param options The options for running the model.
|
||||
* @throws IOException if any exception occurs when opening the model file.
|
||||
*/
|
||||
public static Model createModel(
|
||||
@NonNull Context context, @NonNull String modelPath, @NonNull Options options)
|
||||
throws IOException {
|
||||
SupportPreconditions.checkNotEmpty(
|
||||
modelPath, "Model path in the asset folder cannot be empty.");
|
||||
MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, modelPath);
|
||||
return createModel(byteModel, modelPath, options);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a model with loaded {@link MappedByteBuffer}.
|
||||
*
|
||||
* @see Options for details.
|
||||
* @param byteModel The loaded TFLite model.
|
||||
* @param modelPath The original path of the model. It can be fetched later by {@link
|
||||
* Model#getPath()}.
|
||||
* @param options The options for running the model.
|
||||
* @throws IllegalArgumentException if {@code options.device} is {@link Device#GPU} but
|
||||
* "tensorflow-lite-gpu" is not linked to the project.
|
||||
*/
|
||||
public static Model createModel(
|
||||
@NonNull MappedByteBuffer byteModel, @NonNull String modelPath, @NonNull Options options) {
|
||||
Interpreter.Options interpreterOptions = new Interpreter.Options();
|
||||
GpuDelegateProxy gpuDelegateProxy = null;
|
||||
switch (options.device) {
|
||||
case NNAPI:
|
||||
interpreterOptions.setUseNNAPI(true);
|
||||
break;
|
||||
case GPU:
|
||||
gpuDelegateProxy = GpuDelegateProxy.maybeNewInstance();
|
||||
SupportPreconditions.checkArgument(
|
||||
gpuDelegateProxy != null,
|
||||
"Cannot inference with GPU. Did you add \"tensorflow-lite-gpu\" as dependency?");
|
||||
interpreterOptions.addDelegate(gpuDelegateProxy);
|
||||
break;
|
||||
case CPU:
|
||||
break;
|
||||
}
|
||||
interpreterOptions.setNumThreads(options.numThreads);
|
||||
Interpreter interpreter = new Interpreter(byteModel, interpreterOptions);
|
||||
return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy);
|
||||
}
|
||||
|
||||
/** Returns the memory-mapped model data. */
|
||||
@NonNull
|
||||
public MappedByteBuffer getData() {
|
||||
return byteModel;
|
||||
}
|
||||
|
||||
/** Returns the path of the model file stored in Assets. */
|
||||
@NonNull
|
||||
public String getPath() {
|
||||
return modelPath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the Tensor associated with the provided input index.
|
||||
*
|
||||
* @throws IllegalStateException if the interpreter is closed.
|
||||
*/
|
||||
public Tensor getInputTensor(int inputIndex) {
|
||||
return interpreter.getInputTensor(inputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the Tensor associated with the provided output index.
|
||||
*
|
||||
* @throws IllegalStateException if the interpreter is closed.
|
||||
*/
|
||||
public Tensor getOutputTensor(int outputIndex) {
|
||||
return interpreter.getOutputTensor(outputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the output shape. Useful if the output shape is only determined when the graph is created.
|
||||
*
|
||||
* @throws IllegalStateException if the interpreter is closed.
|
||||
*/
|
||||
public int[] getOutputTensorShape(int outputIndex) {
|
||||
return interpreter.getOutputTensor(outputIndex).shape();
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs model inference on multiple inputs, and returns multiple outputs.
|
||||
*
|
||||
* @param inputs an array of input data. The inputs should be in the same order as inputs of the
|
||||
* model. Each input can be an array or multidimensional array, or a {@link
|
||||
* java.nio.ByteBuffer} of primitive types including int, float, long, and byte. {@link
|
||||
* java.nio.ByteBuffer} is the preferred way to pass large input data, whereas string types
|
||||
* require using the (multi-dimensional) array input path. When {@link java.nio.ByteBuffer} is
|
||||
* used, its content should remain unchanged until model inference is done.
|
||||
* @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
|
||||
* java.nio.ByteBuffer}s of primitive types including int, float, long, and byte. It only
|
||||
* needs to keep entries for the outputs to be used.
|
||||
*/
|
||||
public void run(@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
|
||||
interpreter.runForMultipleInputsOutputs(inputs, outputs);
|
||||
}
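For illustration, a sketch of the array-based path for a single-input, single-output model; the
output shape {1, 3} and the inputArray variable are hypothetical, and model is an existing Model
instance:

float[][] probabilities = new float[1][3];
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, probabilities);
model.run(new Object[] {inputArray}, outputs);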
|
||||
|
||||
public void close() {
|
||||
if (interpreter != null) {
|
||||
interpreter.close();
|
||||
}
|
||||
if (gpuDelegateProxy != null) {
|
||||
gpuDelegateProxy.close();
|
||||
}
|
||||
}
|
||||
|
||||
private Model(
|
||||
@NonNull String modelPath,
|
||||
@NonNull MappedByteBuffer byteModel,
|
||||
@NonNull Interpreter interpreter,
|
||||
@Nullable GpuDelegateProxy gpuDelegateProxy) {
|
||||
this.modelPath = modelPath;
|
||||
this.byteModel = byteModel;
|
||||
this.interpreter = interpreter;
|
||||
this.gpuDelegateProxy = gpuDelegateProxy;
|
||||
}
|
||||
}
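A minimal end-to-end sketch using the API above; the model file name and the float input/output
types are hypothetical, context is an Android Context, and the snippet is assumed to run where
IOException is handled:

Model model =
    Model.createModel(context, "mobilenet_v1.tflite",
        new Model.Options.Builder().setDevice(Model.Device.CPU).setNumThreads(2).build());
TensorBuffer input =
    TensorBuffer.createFixedSize(model.getInputTensor(0).shape(), DataType.FLOAT32);
TensorBuffer output =
    TensorBuffer.createFixedSize(model.getOutputTensorShape(0), DataType.FLOAT32);
// ... fill `input`, e.g. via input.loadArray(...) or input.loadBuffer(...) ...
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, output.getBuffer());
model.run(new Object[] {input.getBuffer()}, outputs);
model.close();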
|
@ -1,412 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.tensorbuffer;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.Arrays;
|
||||
import org.checkerframework.checker.nullness.qual.NonNull;
|
||||
import org.tensorflow.lite.DataType;
|
||||
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||
|
||||
/** Represents the data buffer for either a model's input or its output. */
|
||||
public abstract class TensorBuffer {
|
||||
/** Where the data is stored. */
|
||||
protected ByteBuffer buffer;
|
||||
|
||||
/** Shape of the tensor stored in this buffer. */
|
||||
protected int[] shape;
|
||||
|
||||
/** Number of elements in the buffer. It will be changed to a proper value in the constructor. */
|
||||
protected int flatSize = -1;
|
||||
|
||||
/**
|
||||
* Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers have pre-allocated
* memory with a fixed size, while the size of dynamic buffers can be changed.
|
||||
*/
|
||||
protected final boolean isDynamic;
|
||||
|
||||
/**
|
||||
* Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are some
|
||||
* examples:
|
||||
*
|
||||
* <pre>
|
||||
* Creating a float TensorBuffer with shape {2, 3}:
|
||||
* int[] shape = new int[] {2, 3};
|
||||
* TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
|
||||
* </pre>
|
||||
*
|
||||
* <pre>
|
||||
* Creating an uint8 TensorBuffer of a scalar:
|
||||
* int[] shape = new int[] {};
|
||||
* TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
|
||||
* </pre>
|
||||
*
|
||||
* <pre>
|
||||
* Creating an empty uint8 TensorBuffer:
|
||||
* int[] shape = new int[] {0};
|
||||
* TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
|
||||
* </pre>
|
||||
*
|
||||
* <p>The size of a fixed-size TensorBuffer cannot be changed once it is created.
|
||||
*
|
||||
* @param shape The shape of the {@link TensorBuffer} to be created.
|
||||
* @param dataType The dataType of the {@link TensorBuffer} to be created.
|
||||
* @throws NullPointerException if {@code shape} is null.
|
||||
* @throws IllegalArgumentException if {@code shape} has non-positive elements.
|
||||
*/
|
||||
@NonNull
|
||||
public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) {
|
||||
switch (dataType) {
|
||||
case FLOAT32:
|
||||
return new TensorBufferFloat(shape);
|
||||
case UINT8:
|
||||
return new TensorBufferUint8(shape);
|
||||
default:
|
||||
throw new AssertionError("TensorBuffer does not support data type: " + dataType);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of the
|
||||
* created {@link TensorBuffer} is {0}.
|
||||
*
|
||||
* <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of
|
||||
* different buffer sizes.
|
||||
*
|
||||
* @param dataType The dataType of the {@link TensorBuffer} to be created.
|
||||
*/
|
||||
@NonNull
|
||||
public static TensorBuffer createDynamic(DataType dataType) {
|
||||
switch (dataType) {
|
||||
case FLOAT32:
|
||||
return new TensorBufferFloat();
|
||||
case UINT8:
|
||||
return new TensorBufferUint8();
|
||||
default:
|
||||
throw new AssertionError("TensorBuffer does not support data type: " + dataType);
|
||||
}
|
||||
}
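A short sketch of the dynamic behavior described above (array values are illustrative):

TensorBuffer dynamic = TensorBuffer.createDynamic(DataType.FLOAT32);
dynamic.loadArray(new float[] {1f, 2f, 3f}, new int[] {3});
dynamic.loadArray(new float[] {1f, 2f, 3f, 4f, 5f, 6f}, new int[] {2, 3}); // reallocates as needed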
|
||||
|
||||
/**
|
||||
* Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link DataType}.
|
||||
*
|
||||
* @param buffer the source {@link TensorBuffer} to copy from.
|
||||
* @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}.
|
||||
* @throws NullPointerException if {@code buffer} is null.
|
||||
*/
|
||||
@NonNull
|
||||
public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) {
|
||||
SupportPreconditions.checkNotNull(buffer, "Cannot create a buffer from null");
|
||||
TensorBuffer result;
|
||||
if (buffer.isDynamic()) {
|
||||
result = createDynamic(dataType);
|
||||
} else {
|
||||
result = createFixedSize(buffer.shape, dataType);
|
||||
}
|
||||
// The only scenario where we need a float array is FLOAT32->FLOAT32; otherwise we can always use
// INT as an intermediate container.
// This assumption no longer holds once other data types are supported.
|
||||
if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) {
|
||||
float[] data = buffer.getFloatArray();
|
||||
result.loadArray(data, buffer.shape);
|
||||
} else {
|
||||
int[] data = buffer.getIntArray();
|
||||
result.loadArray(data, buffer.shape);
|
||||
}
|
||||
return result;
|
||||
}
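For example, converting a uint8 buffer into a float buffer (values are illustrative):

TensorBuffer uint8Buffer = TensorBuffer.createFixedSize(new int[] {2}, DataType.UINT8);
uint8Buffer.loadArray(new int[] {0, 255});
TensorBuffer floatBuffer = TensorBuffer.createFrom(uint8Buffer, DataType.FLOAT32);
float[] values = floatBuffer.getFloatArray(); // {0.0f, 255.0f}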
|
||||
|
||||
/** Returns the data buffer. */
|
||||
@NonNull
|
||||
public ByteBuffer getBuffer() {
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/** Gets the {@link TensorBuffer#flatSize} of the buffer. */
|
||||
public int getFlatSize() {
|
||||
return flatSize;
|
||||
}
|
||||
|
||||
/** Gets the current shape. (returning a copy here to avoid unexpected modification.) */
|
||||
@NonNull
|
||||
public int[] getShape() {
|
||||
return Arrays.copyOf(shape, shape.length);
|
||||
}
|
||||
|
||||
/** Returns the data type of this buffer. */
|
||||
public abstract DataType getDataType();
|
||||
|
||||
/**
|
||||
* Returns a float array of the values stored in this buffer. If the buffer is of a different type
* than float, the values will be converted into float. For example, values in {@link
* TensorBufferUint8} will be converted from uint8 to float.
|
||||
*/
|
||||
@NonNull
|
||||
public abstract float[] getFloatArray();
|
||||
|
||||
/**
|
||||
* Returns a float value at a given index. If the buffer is of a different type than float, the
* value will be converted into float. For example, when reading a value from {@link
* TensorBufferUint8}, the value will first be read out as uint8, and then converted from uint8 to
* float.
|
||||
*
|
||||
* <pre>
|
||||
* For example, a TensorBuffer with shape {2, 3} that represents the following array,
|
||||
* [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
|
||||
*
|
||||
* The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
|
||||
* float v = tensorBuffer.getFloatValue(3);
|
||||
* </pre>
|
||||
*
|
||||
* @param absIndex The absolute index of the value to be read.
|
||||
*/
|
||||
public abstract float getFloatValue(int absIndex);
|
||||
|
||||
/**
|
||||
* Returns an int array of the values stored in this buffer. If the buffer is of a different type
* than int, the values will be converted into int, and loss of precision may apply. For example,
* getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f} gives the
* output {400, 23}.
|
||||
*/
|
||||
@NonNull
|
||||
public abstract int[] getIntArray();
|
||||
|
||||
/**
|
||||
* Returns an int value at a given index. If the buffer is of a different type than int, the value
* will be converted into int. For example, when reading a value from {@link TensorBufferFloat},
* the value will first be read out as float, and then converted from float to int. Loss of
* precision may apply.
|
||||
*
|
||||
* <pre>
|
||||
* For example, a TensorBuffer with shape {2, 3} that represents the following array,
|
||||
* [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
|
||||
*
|
||||
* The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
|
||||
* int v = tensorBuffer.getIntValue(3);
|
||||
* Note that v is converted from 3.0f to 3 as a result of type conversion.
|
||||
* </pre>
|
||||
*
|
||||
* @param absIndex The absolute index of the value to be read.
|
||||
*/
|
||||
public abstract int getIntValue(int absIndex);
|
||||
|
||||
/**
|
||||
* Returns the number of bytes of a single element in the array. For example, a float buffer will
|
||||
* return 4, and a byte buffer will return 1.
|
||||
*/
|
||||
public abstract int getTypeSize();
|
||||
|
||||
/** Returns whether the {@link TensorBuffer} is dynamically sized (it can be resized arbitrarily). */
|
||||
public boolean isDynamic() {
|
||||
return isDynamic;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads an int array into this buffer with a specific shape. If the buffer is of a different type
* than int, the values will be converted into the buffer's type before being loaded into the
* buffer, and loss of precision may apply. For example, when loading an int array with values
* {400, -23} into a {@link TensorBufferUint8}, the values will be clamped to [0, 255] and then
* cast to uint8, resulting in {255, 0}.
|
||||
*
|
||||
* @param src The source array to be loaded.
|
||||
* @param shape Shape of the tensor that {@code src} represents.
|
||||
* @throws NullPointerException if {@code src} is null.
|
||||
* @throws NullPointerException if {@code shape} is null.
|
||||
* @throws IllegalArgumentException if the size of the array to be loaded does not match the
|
||||
* specified shape.
|
||||
*/
|
||||
public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape);
|
||||
|
||||
/**
|
||||
* Loads an int array into this buffer. If the buffer is of a different type than int, the values
* will be converted into the buffer's type before being loaded into the buffer, and loss of
* precision may apply. For example, when loading an int array with values {400, -23} into a
* {@link TensorBufferUint8}, the values will be clamped to [0, 255] and then cast to uint8,
* resulting in {255, 0}.
|
||||
*
|
||||
* <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for both
|
||||
* fixed-size and dynamic {@link TensorBuffer}.
|
||||
*
|
||||
* @param src The source array to be loaded.
|
||||
*/
|
||||
public void loadArray(@NonNull int[] src) {
|
||||
loadArray(src, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a float array into this buffer with a specific shape. If the buffer is of a different type
* than float, the values will be converted into the buffer's type before being loaded into the
* buffer, and loss of precision may apply. For example, when loading a float array with values
* {400.32f, -23.04f} into a {@link TensorBufferUint8}, the values will be clamped to [0, 255] and
* then cast to uint8, resulting in {255, 0}.
|
||||
*
|
||||
* @param src The source array to be loaded.
|
||||
* @param shape Shape of the tensor that {@code src} represents.
|
||||
* @throws NullPointerException if {@code src} is null.
|
||||
* @throws NullPointerException if {@code shape} is null.
|
||||
* @throws IllegalArgumentException if the size of the array to be loaded does not match the
|
||||
* specified shape.
|
||||
*/
|
||||
public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape);
|
||||
|
||||
/**
|
||||
* Loads a float array into this buffer. If the buffer is of a different type than float, the
* values will be converted into the buffer's type before being loaded into the buffer, and loss
* of precision may apply. For example, when loading a float array with values {400.32f, -23.04f}
* into a {@link TensorBufferUint8}, the values will be clamped to [0, 255] and then cast to uint8,
* resulting in {255, 0}.
|
||||
*
|
||||
* <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for both
|
||||
* fixed-size and dynamic {@link TensorBuffer}.
|
||||
*
|
||||
* @param src The source array to be loaded.
|
||||
*/
|
||||
public void loadArray(@NonNull float[] src) {
|
||||
loadArray(src, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a byte buffer into this {@link TensorBuffer} with specific shape.
|
||||
*
|
||||
* <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for
* performance reasons, but if modification is necessary, please make a copy.
|
||||
*
|
||||
* @param buffer The byte buffer to load.
|
||||
* @throws NullPointerException if {@code buffer} is null.
|
||||
* @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not
|
||||
* match or the size of {@code buffer} and {@code flatSize} do not match.
|
||||
*/
|
||||
public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) {
|
||||
SupportPreconditions.checkNotNull(buffer, "Byte buffer cannot be null.");
|
||||
int flatSize = computeFlatSize(shape);
|
||||
SupportPreconditions.checkArgument(
|
||||
(buffer.limit() == getTypeSize() * flatSize),
|
||||
"The size of byte buffer and the shape do not match.");
|
||||
|
||||
if (!isDynamic) {
|
||||
SupportPreconditions.checkArgument(
|
||||
flatSize == this.flatSize,
|
||||
"The size of byte buffer and the size of the tensor buffer do not match.");
|
||||
} else {
|
||||
this.flatSize = flatSize;
|
||||
}
|
||||
|
||||
this.shape = shape.clone();
|
||||
buffer.rewind();
|
||||
this.buffer = buffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of
|
||||
* this {@link TensorBuffer}.
|
||||
*
|
||||
* <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for
* performance reasons, but if modification is necessary, please make a copy.
|
||||
*
|
||||
* @param buffer The byte buffer to load.
|
||||
*/
|
||||
public void loadBuffer(@NonNull ByteBuffer buffer) {
|
||||
loadBuffer(buffer, shape);
|
||||
}
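A sketch of loading raw bytes produced elsewhere; the sizes are chosen so that the buffer limit
equals getTypeSize() * flatSize, as required above:

ByteBuffer raw = ByteBuffer.allocateDirect(4 * 6).order(ByteOrder.nativeOrder());
// ... write 6 float values into `raw` ...
TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(new int[] {2, 3}, DataType.FLOAT32);
tensorBuffer.loadBuffer(raw);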
|
||||
|
||||
/**
|
||||
* Constructs a fixed size {@link TensorBuffer} with specified {@code shape}.
|
||||
*
|
||||
* @throws NullPointerException if {@code shape} is null.
|
||||
* @throws IllegalArgumentException if {@code shape} has non-positive elements.
|
||||
*/
|
||||
protected TensorBuffer(@NonNull int[] shape) {
|
||||
isDynamic = false;
|
||||
allocateMemory(shape);
|
||||
}
|
||||
|
||||
/** Constructs a dynamic {@link TensorBuffer} which can be resized. */
|
||||
protected TensorBuffer() {
|
||||
isDynamic = true;
|
||||
// Initialize the dynamic TensorBuffer with an empty ByteBuffer.
|
||||
allocateMemory(new int[] {0});
|
||||
}
|
||||
|
||||
/** Calculates number of elements in the buffer. */
|
||||
protected static int computeFlatSize(@NonNull int[] shape) {
|
||||
SupportPreconditions.checkNotNull(shape, "Shape cannot be null.");
|
||||
int prod = 1;
|
||||
for (int s : shape) {
|
||||
prod = prod * s;
|
||||
}
|
||||
return prod;
|
||||
}
|
||||
|
||||
/**
|
||||
* For a dynamic buffer, resizes the memory if needed. For a fixed-size buffer, checks whether the
* {@code shape} of src fits the buffer size.
|
||||
*/
|
||||
protected void resize(@NonNull int[] shape) {
|
||||
if (isDynamic) {
|
||||
allocateMemory(shape);
|
||||
} else {
|
||||
// Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
|
||||
SupportPreconditions.checkArgument(Arrays.equals(shape, this.shape));
|
||||
this.shape = shape.clone();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Allocates a buffer whose size corresponds to the {@code shape}. If shape is an empty array, this
* {@link TensorBuffer} will be created as a scalar and its flatSize will be 1.
|
||||
*
|
||||
* @throws NullPointerException if {@code shape} is null.
|
||||
* @throws IllegalArgumentException if {@code shape} has negative elements.
|
||||
*/
|
||||
private void allocateMemory(@NonNull int[] shape) {
|
||||
SupportPreconditions.checkNotNull(shape, "TensorBuffer shape cannot be null.");
|
||||
SupportPreconditions.checkArgument(
|
||||
isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
|
||||
|
||||
// Check if the new shape is the same as current shape.
|
||||
int newFlatSize = computeFlatSize(shape);
|
||||
this.shape = shape.clone();
|
||||
if (flatSize == newFlatSize) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Update to the new shape.
|
||||
flatSize = newFlatSize;
|
||||
buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize());
|
||||
buffer.order(ByteOrder.nativeOrder());
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if {@code shape} meets one of following two requirements: 1. Elements in {@code shape}
|
||||
* are all non-negative numbers. 2. {@code shape} is an empty array, which corresponds to scalar.
|
||||
*/
|
||||
private static boolean isShapeValid(@NonNull int[] shape) {
|
||||
if (shape.length == 0) {
|
||||
// This shape refers to a scalar.
|
||||
return true;
|
||||
}
|
||||
|
||||
// This shape refers to a multidimensional array.
|
||||
for (int s : shape) {
|
||||
// All elements in shape should be non-negative.
|
||||
if (s < 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
@ -1,110 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.tensorbuffer;
|
||||
|
||||
import java.nio.FloatBuffer;
|
||||
import org.checkerframework.checker.nullness.qual.NonNull;
|
||||
import org.tensorflow.lite.DataType;
|
||||
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||
|
||||
/** Represents data buffer with float values. */
|
||||
public final class TensorBufferFloat extends TensorBuffer {
|
||||
private static final DataType DATA_TYPE = DataType.FLOAT32;
|
||||
|
||||
/**
|
||||
* Creates a {@link TensorBufferFloat} with specified {@code shape}.
|
||||
*
|
||||
* @throws NullPointerException if {@code shape} is null.
|
||||
* @throws IllegalArgumentException if {@code shape} has non-positive elements.
|
||||
*/
|
||||
TensorBufferFloat(@NonNull int[] shape) {
|
||||
super(shape);
|
||||
}
|
||||
|
||||
TensorBufferFloat() {
|
||||
super();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType() {
|
||||
return DATA_TYPE;
|
||||
}
|
||||
|
||||
@Override
|
||||
@NonNull
|
||||
public float[] getFloatArray() {
|
||||
buffer.rewind();
|
||||
float[] arr = new float[flatSize];
|
||||
|
||||
FloatBuffer floatBuffer = buffer.asFloatBuffer();
|
||||
floatBuffer.get(arr);
|
||||
return arr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float getFloatValue(int absIndex) {
|
||||
return buffer.getFloat(absIndex << 2);
|
||||
}
|
||||
|
||||
@Override
|
||||
@NonNull
|
||||
public int[] getIntArray() {
|
||||
buffer.rewind();
|
||||
int[] arr = new int[flatSize];
|
||||
|
||||
for (int i = 0; i < flatSize; i++) {
|
||||
arr[i] = (int) buffer.getFloat();
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getIntValue(int absIndex) {
|
||||
return (int) buffer.getFloat(absIndex << 2);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getTypeSize() {
|
||||
return DATA_TYPE.byteSize();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
|
||||
SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
|
||||
SupportPreconditions.checkArgument(
|
||||
src.length == computeFlatSize(shape),
|
||||
"The size of the array to be loaded does not match the specified shape.");
|
||||
resize(shape);
|
||||
buffer.rewind();
|
||||
|
||||
FloatBuffer floatBuffer = buffer.asFloatBuffer();
|
||||
floatBuffer.put(src);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
|
||||
SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
|
||||
SupportPreconditions.checkArgument(
|
||||
src.length == computeFlatSize(shape),
|
||||
"The size of the array to be loaded does not match the specified shape.");
|
||||
resize(shape);
|
||||
buffer.rewind();
|
||||
|
||||
for (int a : src) {
|
||||
buffer.putFloat((float) a);
|
||||
}
|
||||
}
|
||||
}
|
@ -1,111 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.tensorbuffer;
|
||||
|
||||
import org.checkerframework.checker.nullness.qual.NonNull;
|
||||
import org.tensorflow.lite.DataType;
|
||||
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||
|
||||
/** Represents data buffer with 8-bit unsigned integer values. */
|
||||
public final class TensorBufferUint8 extends TensorBuffer {
|
||||
private static final DataType DATA_TYPE = DataType.UINT8;
|
||||
|
||||
/**
|
||||
* Creates a {@link TensorBufferUint8} with specified {@code shape}.
|
||||
*
|
||||
* @throws NullPointerException if {@code shape} is null.
|
||||
* @throws IllegalArgumentException if {@code shape} has non-positive elements.
|
||||
*/
|
||||
TensorBufferUint8(@NonNull int[] shape) {
|
||||
super(shape);
|
||||
}
|
||||
|
||||
TensorBufferUint8() {
|
||||
super();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType() {
|
||||
return DATA_TYPE;
|
||||
}
|
||||
|
||||
@Override
|
||||
@NonNull
|
||||
public float[] getFloatArray() {
|
||||
buffer.rewind();
|
||||
float[] arr = new float[flatSize];
|
||||
|
||||
for (int i = 0; i < flatSize; i++) {
|
||||
arr[i] = (float) (buffer.get() & 0xff);
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float getFloatValue(int index) {
|
||||
return (float) (buffer.get(index) & 0xff);
|
||||
}
|
||||
|
||||
@Override
|
||||
@NonNull
|
||||
public int[] getIntArray() {
|
||||
buffer.rewind();
|
||||
int[] arr = new int[flatSize];
|
||||
|
||||
for (int i = 0; i < flatSize; i++) {
|
||||
arr[i] = buffer.get() & 0xff;
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getIntValue(int index) {
|
||||
return buffer.get(index) & 0xff;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getTypeSize() {
|
||||
return DATA_TYPE.byteSize();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
|
||||
SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
|
||||
SupportPreconditions.checkArgument(
|
||||
src.length == computeFlatSize(shape),
|
||||
"The size of the array to be loaded does not match the specified shape.");
|
||||
resize(shape);
|
||||
buffer.rewind();
|
||||
|
||||
for (float a : src) {
|
||||
buffer.put((byte) Math.max(Math.min(a, 255.0), 0.0));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
|
||||
SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
|
||||
SupportPreconditions.checkArgument(
|
||||
src.length == computeFlatSize(shape),
|
||||
"The size of the array to be loaded does not match the specified shape.");
|
||||
resize(shape);
|
||||
buffer.rewind();
|
||||
|
||||
for (int a : src) {
|
||||
buffer.put((byte) Math.max(Math.min(a, 255), 0));
|
||||
}
|
||||
}
|
||||
}
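A short sketch of the clamping behavior implemented above (values are illustrative):

TensorBuffer uint8 = TensorBuffer.createFixedSize(new int[] {3}, DataType.UINT8);
uint8.loadArray(new float[] {-5.0f, 42.7f, 400.0f});
int[] stored = uint8.getIntArray(); // {0, 42, 255}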
|
@ -1,113 +0,0 @@
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
load("@flatbuffers//:build_defs.bzl", "flatbuffer_android_library", "flatbuffer_cc_library", "flatbuffer_java_library", "flatbuffer_py_library")
|
||||
load("//tensorflow/lite/experimental/support/metadata:build_defs.bzl", "stamp_metadata_parser_version")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["metadata_schema.fbs"])
|
||||
|
||||
flatbuffer_py_library(
|
||||
name = "schema_py",
|
||||
srcs = ["//tensorflow/lite/schema:schema.fbs"],
|
||||
)
|
||||
|
||||
# Generic schema for inference on device.
|
||||
flatbuffer_android_library(
|
||||
name = "schema_fbs_android",
|
||||
srcs = ["//tensorflow/lite/schema:schema.fbs"],
|
||||
custom_package = "org.tensorflow.lite.schema",
|
||||
)
|
||||
|
||||
flatbuffer_java_library(
|
||||
name = "schema_fbs_java",
|
||||
srcs = ["//tensorflow/lite/schema:schema.fbs"],
|
||||
custom_package = "org.tensorflow.lite.schema",
|
||||
)
|
||||
|
||||
# Generic schema for model metadata.
|
||||
flatbuffer_cc_library(
|
||||
name = "metadata_schema_cc",
|
||||
srcs = ["metadata_schema.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_py_library(
|
||||
name = "metadata_schema_py",
|
||||
srcs = ["metadata_schema.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_java_library(
|
||||
name = "metadata_schema_java",
|
||||
srcs = ["metadata_schema.fbs"],
|
||||
custom_package = "org.tensorflow.lite.support.metadata.schema",
|
||||
)
|
||||
|
||||
flatbuffer_android_library(
|
||||
name = "metadata_schema_fbs_android",
|
||||
srcs = ["metadata_schema.fbs"],
|
||||
custom_package = "org.tensorflow.lite.support.metadata.schema",
|
||||
)
|
||||
|
||||
# TODO(b/157813075): move the metadata python library to metadata/python/ when migrating to the new repo.
|
||||
stamp_metadata_parser_version(
|
||||
name = "metadata_parser_py",
|
||||
srcs = ["metadata_parser.py.template"],
|
||||
outs = ["metadata_parser.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "metadata",
|
||||
srcs = [
|
||||
"metadata.py",
|
||||
":metadata_parser_py",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":metadata_schema_py",
|
||||
":schema_py",
|
||||
"//tensorflow/lite/experimental/support/metadata/cc/python:_pywrap_metadata_version",
|
||||
"//tensorflow/lite/experimental/support/metadata/flatbuffers_lib:_pywrap_flatbuffers",
|
||||
"//tensorflow/python:platform",
|
||||
"@flatbuffers//:runtime_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "metadata_test",
|
||||
srcs = ["metadata_test.py"],
|
||||
data = ["testdata/golden_json.json"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_mac", # TODO(b/148247402): flatbuffers import broken on Mac OS.
|
||||
],
|
||||
deps = [
|
||||
":metadata",
|
||||
":metadata_schema_py",
|
||||
":schema_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:platform_test",
|
||||
"@flatbuffers//:runtime_py",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "metadata_parser_test",
|
||||
srcs = ["metadata_parser_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":metadata",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
@ -1,15 +0,0 @@
|
||||
# TensorFlow Lite Metadata and Android wrapper code generator
|
||||
|
||||
Note: Both TensorFlow Lite Metadata and the Android wrapper code generator are
in the experimental (beta) phase.
|
||||
|
||||
TensorFlow Lite metadata provides a structured framework for storing metadata
that conveys information both to developers who will use the model and to code
generators that can create wrappers around the model. For information on how to
populate model metadata, please refer to the [TensorFlow Lite Metadata
documentation](https://www.tensorflow.org/lite/convert/metadata).
|
||||
|
||||
The first code generator which takes advantage of this metadata format is the
|
||||
TensorFlow Lite Android Code Generator. For more information on how to use this
|
||||
generator, please refer to the [TensorFlow Lite Android wrapper code generator
|
||||
documentation](https://www.tensorflow.org/lite/guide/codegen).
|
@ -1,43 +0,0 @@
|
||||
"""Build rules to generate metadata schema versions."""
|
||||
|
||||
METADATA_SCHEMA_FILE = "//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs"
|
||||
|
||||
def stamp_metadata_parser_version(
|
||||
name,
|
||||
srcs,
|
||||
outs):
|
||||
"""Stamps the latest metadata parser version into the srcs files.
|
||||
|
||||
Replaces all the occurrences of "{LATEST_METADATA_PARSER_VERSION}" in the
|
||||
srcs files with the metadata schema version extracted from
|
||||
METADATA_SCHEMA_FILE and then outputs the generated file into outs,
|
||||
respectively. The number of srcs files needs to match the number of outs
|
||||
files.
|
||||
|
||||
Args:
|
||||
name: Rule name. (required)
|
||||
srcs: List of source files. (required)
|
||||
outs: List of output files. (required)
|
||||
"""
|
||||
if len(srcs) != len(outs):
|
||||
fail(("The number of srcs files (%d) does not match that of the outs" +
|
||||
" files (%d).") %
|
||||
(len(srcs), len(outs)))
|
||||
|
||||
for i in range(0, len(srcs)):
|
||||
native.genrule(
|
||||
name = "%s_file%d" % (name, i),
|
||||
srcs = [srcs[i]],
|
||||
outs = [outs[i]],
|
||||
tools = [METADATA_SCHEMA_FILE],
|
||||
# Gets the metadata schema version from the file, and stamps it
|
||||
# into the srcs file.
|
||||
cmd = "version=$$(sed -n -e '/Schema Semantic version/ s/.*\\: *//p' $(location %s));" %
|
||||
METADATA_SCHEMA_FILE +
|
||||
'sed "s/{LATEST_METADATA_PARSER_VERSION}/$$version/" $< > $@',
|
||||
)
|
||||
|
||||
native.filegroup(
|
||||
name = name,
|
||||
srcs = outs,
|
||||
)
|
@ -1,29 +0,0 @@
|
||||
load("//tensorflow/lite/experimental/support/metadata:build_defs.bzl", "stamp_metadata_parser_version")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow/lite/experimental/support:users"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
stamp_metadata_parser_version(
|
||||
name = "metadata_parser_h",
|
||||
srcs = ["metadata_parser.h.template"],
|
||||
outs = ["metadata_parser.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "metadata_version",
|
||||
srcs = ["metadata_version.cc"],
|
||||
hdrs = [
|
||||
"metadata_version.h",
|
||||
":metadata_parser_h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
|
||||
"//tensorflow/lite/kernels/internal:compatibility",
|
||||
"//tensorflow/lite/tools:logging",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
@ -1,28 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_PARSER_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_PARSER_H_
|
||||
|
||||
namespace tflite {
|
||||
namespace metadata {
|
||||
|
||||
// The version of the metadata parser that this metadata versioning library is
|
||||
// depending on.
|
||||
inline constexpr char kMatadataParserVersion[] = "{LATEST_METADATA_PARSER_VERSION}";
|
||||
|
||||
} // namespace metadata
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_PARSER_H_
|
@ -1,214 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <array>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace metadata {
|
||||
namespace {
|
||||
|
||||
// Members that are added to the metadata schema after the initial version
|
||||
// of 1.0.0.
|
||||
enum class SchemaMembers {
|
||||
kAssociatedFileTypeVocabulary = 0,
|
||||
};
|
||||
|
||||
// Helper class to compare semantic versions in terms of three integers, major,
|
||||
// minor, and patch.
|
||||
class Version {
|
||||
public:
|
||||
explicit Version(int major, int minor = 0, int patch = 0)
|
||||
: version_({major, minor, patch}) {}
|
||||
|
||||
explicit Version(const std::string& version) {
|
||||
const std::vector<std::string> vec = absl::StrSplit(version, '.');
|
||||
// The version string should contain at most three numbers.
|
||||
TFLITE_DCHECK(vec.size() <= kElementNumber && !vec.empty());
|
||||
version_[0] = std::stoi(vec[0]);
|
||||
version_[1] = vec.size() > 1 ? std::stoi(vec[1]) : 0;
|
||||
version_[2] = vec.size() > 2 ? std::stoi(vec[2]) : 0;
|
||||
}
|
||||
|
||||
// Compares two semantic version numbers.
|
||||
//
|
||||
// Example results when comparing two versions strings:
|
||||
// "1.9" precedes "1.14";
|
||||
// "1.14" precedes "1.14.1";
|
||||
// "1.14" and "1.14.0" are equal.
|
||||
//
|
||||
// Returns the value 0 if the two versions are equal; a value less than 0 if
|
||||
// *this precedes v; a value greater than 0 if v precedes *this.
|
||||
int Compare(const Version& v) {
|
||||
for (int i = 0; i < kElementNumber; ++i) {
|
||||
if (version_[i] != v.version_[i]) {
|
||||
return version_[i] < v.version_[i] ? -1 : 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Converts version_ into a version string.
|
||||
std::string ToString() { return absl::StrJoin(version_, "."); }
|
||||
|
||||
private:
|
||||
static constexpr int kElementNumber = 3;
|
||||
std::array<int, kElementNumber> version_;
|
||||
};
|
||||
|
||||
Version GetMemberVersion(SchemaMembers member) {
|
||||
switch (member) {
|
||||
case SchemaMembers::kAssociatedFileTypeVocabulary:
|
||||
return Version(1, 0, 1);
|
||||
default:
|
||||
TFLITE_LOG(FATAL) << "Unsupported schema member: "
|
||||
<< static_cast<int>(member);
|
||||
}
|
||||
}
|
||||
|
||||
// Updates min_version if it precedes the new_version.
|
||||
inline void UpdateMinimumVersion(const Version& new_version,
|
||||
Version* min_version) {
|
||||
if (min_version->Compare(new_version) < 0) {
|
||||
*min_version = new_version;
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateMinimumVersionForAssociatedFile(
|
||||
const tflite::AssociatedFile* associated_file, Version* min_version) {
|
||||
if (associated_file == nullptr) return;
|
||||
|
||||
if (associated_file->type() == AssociatedFileType_VOCABULARY) {
|
||||
UpdateMinimumVersion(
|
||||
GetMemberVersion(SchemaMembers::kAssociatedFileTypeVocabulary),
|
||||
min_version);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateMinimumVersionForAssociatedFileArray(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
|
||||
associated_files,
|
||||
Version* min_version) {
|
||||
if (associated_files == nullptr) return;
|
||||
|
||||
for (int i = 0; i < associated_files->size(); ++i) {
|
||||
UpdateMinimumVersionForAssociatedFile(associated_files->Get(i),
|
||||
min_version);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateMinimumVersionForTensorMetadata(
|
||||
const tflite::TensorMetadata* tensor_metadata, Version* min_version) {
|
||||
if (tensor_metadata == nullptr) return;
|
||||
|
||||
// Checks the associated_files field.
|
||||
UpdateMinimumVersionForAssociatedFileArray(
|
||||
tensor_metadata->associated_files(), min_version);
|
||||
}
|
||||
|
||||
void UpdateMinimumVersionForTensorMetadataArray(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
|
||||
tensor_metadata_array,
|
||||
Version* min_version) {
|
||||
if (tensor_metadata_array == nullptr) return;
|
||||
|
||||
for (int i = 0; i < tensor_metadata_array->size(); ++i) {
|
||||
UpdateMinimumVersionForTensorMetadata(tensor_metadata_array->Get(i),
|
||||
min_version);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateMinimumVersionForSubGraphMetadata(
|
||||
const tflite::SubGraphMetadata* subgraph_metadata, Version* min_version) {
|
||||
if (subgraph_metadata == nullptr) return;
|
||||
|
||||
// Checks in the input/output metadata arrays.
|
||||
UpdateMinimumVersionForTensorMetadataArray(
|
||||
subgraph_metadata->input_tensor_metadata(), min_version);
|
||||
UpdateMinimumVersionForTensorMetadataArray(
|
||||
subgraph_metadata->output_tensor_metadata(), min_version);
|
||||
|
||||
// Checks the associated_files field.
|
||||
UpdateMinimumVersionForAssociatedFileArray(
|
||||
subgraph_metadata->associated_files(), min_version);
|
||||
}
|
||||
|
||||
void UpdateMinimumVersionForModelMetadata(
|
||||
const tflite::ModelMetadata& model_metadata, Version* min_version) {
|
||||
// Checks the subgraph_metadata field.
|
||||
if (model_metadata.subgraph_metadata() != nullptr) {
|
||||
for (int i = 0; i < model_metadata.subgraph_metadata()->size(); ++i) {
|
||||
UpdateMinimumVersionForSubGraphMetadata(
|
||||
model_metadata.subgraph_metadata()->Get(i), min_version);
|
||||
}
|
||||
}
|
||||
|
||||
// Checks the associated_files field.
|
||||
UpdateMinimumVersionForAssociatedFileArray(model_metadata.associated_files(),
|
||||
min_version);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
|
||||
size_t buffer_size,
|
||||
std::string* min_version_str) {
|
||||
flatbuffers::Verifier verifier =
|
||||
flatbuffers::Verifier(buffer_data, buffer_size);
|
||||
if (!tflite::VerifyModelMetadataBuffer(verifier)) {
|
||||
TFLITE_LOG(ERROR) << "The model metadata is not a valid FlatBuffer buffer.";
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
static constexpr char kDefaultVersion[] = "1.0.0";
|
||||
Version min_version = Version(kDefaultVersion);
|
||||
|
||||
// Checks if any member declared after 1.0.0 (such as those in
// SchemaMembers) exists, and updates min_version accordingly. The minimum
// metadata parser version will be the largest version number of all fields
// that have been added to a metadata flatbuffer.
|
||||
const tflite::ModelMetadata* model_metadata = GetModelMetadata(buffer_data);
|
||||
|
||||
// All tables in the metadata schema should have their dedicated
|
||||
// UpdateMinimumVersionFor**() methods, respectively. We'll gradually add
|
||||
// these methods when new fields show up in later schema versions.
|
||||
//
|
||||
// UpdateMinimumVersionFor<Foo>() takes a const pointer of Foo. The pointer
|
||||
// can be a nullptr if Foo is not populated into the corresponding table of
|
||||
// the Flatbuffer object. In this case, UpdateMinimumVersionFor<Foo>() will be
|
||||
// skipped. An exception is UpdateMinimumVersionForModelMetadata(), where
|
||||
// ModelMetadata is the root table, and it won't be null.
|
||||
UpdateMinimumVersionForModelMetadata(*model_metadata, &min_version);
|
||||
|
||||
*min_version_str = min_version.ToString();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace metadata
|
||||
} // namespace tflite
|
@ -1,38 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace metadata {
|
||||
|
||||
// Gets the minimum metadata parser version that can fully understand all fields
|
||||
// in a given metadata flatbuffer. TFLite Metadata follows Semantic Versioning
|
||||
// 2.0. Each release version has the form MAJOR.MINOR.PATCH.
|
||||
TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
|
||||
size_t buffer_size,
|
||||
std::string* min_version);
|
||||
|
||||
} // namespace metadata
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
|
@ -1,22 +0,0 @@
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/lite/experimental/support/metadata:__pkg__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_pywrap_metadata_version",
|
||||
srcs = [
|
||||
"metadata_version.cc",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_pywrap_metadata_version",
|
||||
deps = [
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/experimental/support/metadata/cc:metadata_version",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
@ -1,55 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace metadata {
|
||||
|
||||
PYBIND11_MODULE(_pywrap_metadata_version, m) {
|
||||
m.doc() = R"pbdoc(
|
||||
_pywrap_metadata_version
|
||||
A module that returns the minimum metadata parser version of a given
|
||||
metadata flatbuffer.
|
||||
)pbdoc";
|
||||
|
||||
  // Using pybind11 type conversions to convert between Python and native
  // C++ types. There are other options to provide access to native Python types
  // in C++ and vice versa. See the pybind11 documentation [1] for more details.
  // Type conversion is recommended by pybind11, though the main downside
  // is that a copy of the data must be made on every Python to C++ transition:
  // this is needed since the C++ and Python versions of the same type generally
  // won't have the same memory layout.
|
||||
//
|
||||
// [1]: https://pybind11.readthedocs.io/en/stable/advanced/cast/index.html
|
||||
m.def("GetMinimumMetadataParserVersion",
|
||||
[](const std::string& buffer_data) -> std::string {
|
||||
std::string min_version;
|
||||
if (GetMinimumMetadataParserVersion(
|
||||
reinterpret_cast<const uint8_t*>(buffer_data.c_str()),
|
||||
buffer_data.length(), &min_version) != kTfLiteOk) {
|
||||
          // Surfaces the failure to Python as a ValueError instead of silently
          // constructing and discarding the exception.
          throw pybind11::value_error(
              "Error occurred when getting the minimum metadata parser "
              "version of the metadata flatbuffer.");
|
||||
}
|
||||
return min_version;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace metadata
|
||||
} // namespace tflite
|
@ -1,24 +0,0 @@
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "metadata_version_test",
|
||||
srcs = ["metadata_version_test.cc"],
|
||||
deps = [
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
|
||||
"//tensorflow/lite/experimental/support/metadata/cc:metadata_version",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "metadata_parser_test",
|
||||
srcs = ["metadata_parser_test.cc"],
|
||||
deps = [
|
||||
"//tensorflow/lite/experimental/support/metadata/cc:metadata_version",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
@ -1,33 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_parser.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace tflite {
|
||||
namespace metadata {
|
||||
namespace {
|
||||
|
||||
using ::testing::MatchesRegex;
|
||||
|
||||
TEST(MetadataParserTest, MatadataParserVersionIsWellFormed) {
|
||||
// Validates that the version is well-formed (x.y.z).
|
||||
EXPECT_THAT(kMatadataParserVersion, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace metadata
|
||||
} // namespace tflite
|
@ -1,187 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace metadata {
|
||||
namespace {
|
||||
|
||||
using ::testing::MatchesRegex;
|
||||
using ::testing::StrEq;
|
||||
|
||||
TEST(MetadataVersionTest,
|
||||
GetMinimumMetadataParserVersionSucceedsWithValidMetadata) {
|
||||
// Creates a dummy metadata flatbuffer for test.
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto name = builder.CreateString("Foo");
|
||||
ModelMetadataBuilder metadata_builder(builder);
|
||||
metadata_builder.add_name(name);
|
||||
auto metadata = metadata_builder.Finish();
|
||||
FinishModelMetadataBuffer(builder, metadata);
|
||||
|
||||
  // Gets the minimum metadata parser version.
|
||||
std::string min_version;
|
||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||
builder.GetSize(), &min_version),
|
||||
kTfLiteOk);
|
||||
// Validates that the version is well-formed (x.y.z).
|
||||
EXPECT_THAT(min_version, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+"));
|
||||
}
|
||||
|
||||
TEST(MetadataVersionTest,
|
||||
GetMinimumMetadataParserVersionFailsWithInvalidIdentifier) {
|
||||
// Creates a dummy metadata flatbuffer without identifier.
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
ModelMetadataBuilder metadata_builder(builder);
|
||||
auto metadata = metadata_builder.Finish();
|
||||
builder.Finish(metadata);
|
||||
|
||||
  // Gets the minimum metadata parser version and triggers an error.
|
||||
std::string min_version;
|
||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||
builder.GetSize(), &min_version),
|
||||
kTfLiteError);
|
||||
EXPECT_TRUE(min_version.empty());
|
||||
}
|
||||
|
||||
TEST(MetadataVersionTest,
|
||||
GetMinimumMetadataParserVersionForModelMetadataVocabAssociatedFiles) {
|
||||
// Creates a metadata flatbuffer with the field,
|
||||
  // ModelMetadata.associated_files, populated with the vocabulary file type.
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
AssociatedFileBuilder associated_file_builder(builder);
|
||||
associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
|
||||
auto associated_files =
|
||||
builder.CreateVector(std::vector<flatbuffers::Offset<AssociatedFile>>{
|
||||
associated_file_builder.Finish()});
|
||||
ModelMetadataBuilder metadata_builder(builder);
|
||||
metadata_builder.add_associated_files(associated_files);
|
||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||
|
||||
  // Gets the minimum metadata parser version.
|
||||
std::string min_version;
|
||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||
builder.GetSize(), &min_version),
|
||||
kTfLiteOk);
|
||||
// Validates that the version is exactly 1.0.1.
|
||||
EXPECT_THAT(min_version, StrEq("1.0.1"));
|
||||
}
|
||||
|
||||
TEST(MetadataVersionTest,
|
||||
GetMinimumMetadataParserVersionForSubGraphMetadataVocabAssociatedFiles) {
|
||||
// Creates a metadata flatbuffer with the field,
|
||||
  // SubGraphMetadata.associated_files, populated with the vocabulary file type.
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
AssociatedFileBuilder associated_file_builder(builder);
|
||||
associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
|
||||
auto associated_files =
|
||||
builder.CreateVector(std::vector<flatbuffers::Offset<AssociatedFile>>{
|
||||
associated_file_builder.Finish()});
|
||||
SubGraphMetadataBuilder subgraph_builder(builder);
|
||||
subgraph_builder.add_associated_files(associated_files);
|
||||
auto subgraphs =
|
||||
builder.CreateVector(std::vector<flatbuffers::Offset<SubGraphMetadata>>{
|
||||
subgraph_builder.Finish()});
|
||||
ModelMetadataBuilder metadata_builder(builder);
|
||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||
|
||||
  // Gets the minimum metadata parser version.
|
||||
std::string min_version;
|
||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||
builder.GetSize(), &min_version),
|
||||
kTfLiteOk);
|
||||
// Validates that the version is exactly 1.0.1.
|
||||
EXPECT_THAT(min_version, StrEq("1.0.1"));
|
||||
}
|
||||
|
||||
TEST(MetadataVersionTest,
|
||||
GetMinimumMetadataParserVersionForInputMetadataVocabAssociatedFiles) {
|
||||
// Creates a metadata flatbuffer with the field,
|
||||
  // SubGraphMetadata.input_tensor_metadata.associated_files, populated with the
|
||||
// vocabulary file type.
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
AssociatedFileBuilder associated_file_builder(builder);
|
||||
associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
|
||||
auto associated_files =
|
||||
builder.CreateVector(std::vector<flatbuffers::Offset<AssociatedFile>>{
|
||||
associated_file_builder.Finish()});
|
||||
TensorMetadataBuilder tensor_builder(builder);
|
||||
tensor_builder.add_associated_files(associated_files);
|
||||
auto tensors =
|
||||
builder.CreateVector(std::vector<flatbuffers::Offset<TensorMetadata>>{
|
||||
tensor_builder.Finish()});
|
||||
SubGraphMetadataBuilder subgraph_builder(builder);
|
||||
subgraph_builder.add_input_tensor_metadata(tensors);
|
||||
auto subgraphs =
|
||||
builder.CreateVector(std::vector<flatbuffers::Offset<SubGraphMetadata>>{
|
||||
subgraph_builder.Finish()});
|
||||
ModelMetadataBuilder metadata_builder(builder);
|
||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||
|
||||
  // Gets the minimum metadata parser version.
|
||||
std::string min_version;
|
||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||
builder.GetSize(), &min_version),
|
||||
kTfLiteOk);
|
||||
// Validates that the version is exactly 1.0.1.
|
||||
EXPECT_THAT(min_version, StrEq("1.0.1"));
|
||||
}
|
||||
|
||||
TEST(MetadataVersionTest,
|
||||
GetMinimumMetadataParserVersionForOutputMetadataVocabAssociatedFiles) {
|
||||
// Creates a metadata flatbuffer with the field,
|
||||
  // SubGraphMetadata.output_tensor_metadata.associated_files, populated with
|
||||
// the vocabulary file type.
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
AssociatedFileBuilder associated_file_builder(builder);
|
||||
associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
|
||||
auto associated_files =
|
||||
builder.CreateVector(std::vector<flatbuffers::Offset<AssociatedFile>>{
|
||||
associated_file_builder.Finish()});
|
||||
TensorMetadataBuilder tensor_builder(builder);
|
||||
tensor_builder.add_associated_files(associated_files);
|
||||
auto tensors =
|
||||
builder.CreateVector(std::vector<flatbuffers::Offset<TensorMetadata>>{
|
||||
tensor_builder.Finish()});
|
||||
SubGraphMetadataBuilder subgraph_builder(builder);
|
||||
subgraph_builder.add_output_tensor_metadata(tensors);
|
||||
auto subgraphs =
|
||||
builder.CreateVector(std::vector<flatbuffers::Offset<SubGraphMetadata>>{
|
||||
subgraph_builder.Finish()});
|
||||
ModelMetadataBuilder metadata_builder(builder);
|
||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||
|
||||
  // Gets the minimum metadata parser version.
|
||||
std::string min_version;
|
||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||
builder.GetSize(), &min_version),
|
||||
kTfLiteOk);
|
||||
// Validates that the version is exactly 1.0.1.
|
||||
  EXPECT_THAT(min_version, StrEq("1.0.1"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace metadata
|
||||
} // namespace tflite
|
@ -1,23 +0,0 @@
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_pywrap_flatbuffers",
|
||||
srcs = [
|
||||
"flatbuffers_lib.cc",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_pywrap_flatbuffers",
|
||||
deps = [
|
||||
"//tensorflow/python:pybind11_lib",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@flatbuffers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
@ -1,59 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "flatbuffers/idl.h" // from @flatbuffers
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
|
||||
PYBIND11_MODULE(_pywrap_flatbuffers, m) {
|
||||
pybind11::class_<flatbuffers::IDLOptions>(m, "IDLOptions")
|
||||
.def(pybind11::init<>())
|
||||
.def_readwrite("strict_json", &flatbuffers::IDLOptions::strict_json);
|
||||
pybind11::class_<flatbuffers::Parser>(m, "Parser")
|
||||
.def(pybind11::init<const flatbuffers::IDLOptions&>())
|
||||
.def("parse",
|
||||
[](flatbuffers::Parser* self, const std::string& source) {
|
||||
return self->Parse(source.c_str());
|
||||
})
|
||||
.def_readonly("builder", &flatbuffers::Parser::builder_)
|
||||
.def_readonly("error", &flatbuffers::Parser::error_);
|
||||
pybind11::class_<flatbuffers::FlatBufferBuilder>(m, "FlatBufferBuilder")
|
||||
.def("clear", &flatbuffers::FlatBufferBuilder::Clear)
|
||||
.def("push_flat_buffer", [](flatbuffers::FlatBufferBuilder* self,
|
||||
const std::string& contents) {
|
||||
self->PushFlatBuffer(reinterpret_cast<const uint8_t*>(contents.c_str()),
|
||||
contents.length());
|
||||
});
|
||||
m.def("generate_text_file", &flatbuffers::GenerateTextFile);
|
||||
m.def(
|
||||
"generate_text",
|
||||
[](const flatbuffers::Parser& parser,
|
||||
const std::string& buffer) -> std::string {
|
||||
std::string text;
|
||||
if (!flatbuffers::GenerateText(
|
||||
parser, reinterpret_cast<const void*>(buffer.c_str()), &text)) {
|
||||
return "";
|
||||
}
|
||||
return text;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace support
|
||||
} // namespace tflite
|
@ -1,6 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="org.tensorflow.lite.support">
|
||||
<uses-sdk android:minSdkVersion="19" />
|
||||
</manifest>
|
||||
|
@ -1,40 +0,0 @@
|
||||
# Description:
|
||||
# TensorFlow Lite Support API in Java for metadata.
|
||||
|
||||
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||
load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
METADATA_SRCS = glob(
|
||||
["src/java/org/tensorflow/lite/support/metadata/**/*.java"],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "tensorflow-lite-support-metadata",
|
||||
srcs = METADATA_SRCS,
|
||||
manifest = "AndroidManifest.xml",
|
||||
deps = [
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema_fbs_android",
|
||||
"//tensorflow/lite/experimental/support/metadata:schema_fbs_android",
|
||||
"@org_checkerframework_qual",
|
||||
],
|
||||
)
|
||||
|
||||
java_library(
|
||||
name = "tensorflow-lite-support-metadata-lib",
|
||||
srcs = METADATA_SRCS,
|
||||
javacopts = JAVACOPTS,
|
||||
resource_jars = [
|
||||
"//tensorflow/lite/experimental/support/metadata:libmetadata_schema_java.jar",
|
||||
"//tensorflow/lite/experimental/support/metadata:libschema_fbs_java.jar",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema_java",
|
||||
"//tensorflow/lite/experimental/support/metadata:schema_fbs_java",
|
||||
"@org_checkerframework_qual",
|
||||
],
|
||||
)
|
@ -1,116 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
|
||||
import static org.tensorflow.lite.support.metadata.Preconditions.checkElementIndex;
|
||||
import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
/**
|
||||
* An {@link InputStream} that wraps a section of a {@link SeekableByteChannelCompat}.
|
||||
*
|
||||
 * <p><b>WARNING:</b> Similar to {@link InputStream}, instances of {@link BoundedInputStream} are
 * <b>not</b> thread-safe. If multiple threads read from the same {@link BoundedInputStream}
 * concurrently, access must be synchronized externally. Likewise, if multiple instances of {@link
 * BoundedInputStream} are created on the same {@link SeekableByteChannelCompat}, access must be
 * synchronized as well.
|
||||
*/
|
||||
final class BoundedInputStream extends InputStream {
|
||||
private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
|
||||
private final long end; // The valid data for the stream is between [start, end).
|
||||
private long position;
|
||||
private final SeekableByteChannelCompat channel;
|
||||
|
||||
/**
|
||||
* Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}.
|
||||
*
|
||||
 * @param channel the {@link SeekableByteChannelCompat} that backs this {@link
|
||||
* BoundedInputStream}
|
||||
* @param start the starting position of this {@link BoundedInputStream} in the given {@link
|
||||
* SeekableByteChannelCompat}
|
||||
* @param remaining the length of this {@link BoundedInputStream}
|
||||
* @throws IllegalArgumentException if {@code start} or {@code remaining} is negative
|
||||
*/
|
||||
BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) {
|
||||
checkArgument(
|
||||
remaining >= 0 && start >= 0,
|
||||
String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
|
||||
|
||||
end = start + remaining;
|
||||
this.channel = channel;
|
||||
position = start;
|
||||
}
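
  // Illustrative usage sketch added for documentation only; it is not part of the original class.
  // The byte offset and length are assumptions: it supposes the section of interest starts at
  // `start` and spans `length` bytes of the given channel.
  private static byte[] exampleReadSection(
      SeekableByteChannelCompat channel, long start, long length) throws IOException {
    InputStream stream = new BoundedInputStream(channel, start, length);
    byte[] section = new byte[(int) length];
    // A single read may return fewer bytes than requested; callers that need the whole section
    // should keep reading until -1 is returned.
    int bytesRead = stream.read(section, 0, section.length);
    return bytesRead > 0 ? section : new byte[0];
  }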
@Override
|
||||
public int available() throws IOException {
|
||||
return (int) (Math.min(end, channel.size()) - position);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int read() throws IOException {
|
||||
if (position >= end) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
singleByteBuffer.rewind();
|
||||
int count = read(position, singleByteBuffer);
|
||||
if (count < 0) {
|
||||
return count;
|
||||
}
|
||||
|
||||
position++;
|
||||
return singleByteBuffer.get() & 0xff;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int read(byte[] b, int off, int len) throws IOException {
|
||||
checkNotNull(b);
|
||||
checkElementIndex(off, b.length, "The start offset");
|
||||
    checkElementIndex(len, b.length - off + 1, "The maximum number of bytes to read");
|
||||
|
||||
if (len == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (len > end - position) {
|
||||
if (position >= end) {
|
||||
return -1;
|
||||
}
|
||||
len = (int) (end - position);
|
||||
}
|
||||
|
||||
ByteBuffer buf = ByteBuffer.wrap(b, off, len);
|
||||
int count = read(position, buf);
|
||||
if (count > 0) {
|
||||
position += count;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
private int read(long position, ByteBuffer buf) throws IOException {
|
||||
int count;
|
||||
synchronized (channel) {
|
||||
channel.position(position);
|
||||
count = channel.read(buf);
|
||||
}
|
||||
buf.flip();
|
||||
return count;
|
||||
}
|
||||
}
|
@ -1,130 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import static java.lang.Math.min;
|
||||
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
|
||||
import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.NonWritableChannelException;
|
||||
|
||||
/** Implements the {@link SeekableByteChannelCompat} on top of {@link ByteBuffer}. */
|
||||
final class ByteBufferChannel implements SeekableByteChannelCompat {
|
||||
|
||||
/** The ByteBuffer that holds the data. */
|
||||
private final ByteBuffer buffer;
|
||||
|
||||
/**
|
||||
* Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}.
|
||||
*
|
||||
* @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel}
|
||||
* @throws NullPointerException if {@code buffer} is null
|
||||
*/
|
||||
public ByteBufferChannel(ByteBuffer buffer) {
|
||||
checkNotNull(buffer, "The ByteBuffer cannot be null.");
|
||||
this.buffer = buffer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {}
|
||||
|
||||
@Override
|
||||
public boolean isOpen() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long position() {
|
||||
return buffer.position();
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets this channel's position.
|
||||
*
|
||||
* @param newPosition the new position, a non-negative integer counting the number of bytes from
|
||||
* the beginning of the entity
|
||||
* @return this channel
|
||||
* @throws IllegalArgumentException if the new position is negative, or greater than the size of
|
||||
* the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE
|
||||
*/
|
||||
@Override
|
||||
public synchronized ByteBufferChannel position(long newPosition) {
|
||||
checkArgument(
|
||||
(newPosition >= 0 && newPosition <= Integer.MAX_VALUE),
|
||||
"The new position should be non-negative and be less than Integer.MAX_VALUE.");
|
||||
buffer.position((int) newPosition);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>Bytes are read starting at this channel's current position, and then the position is updated
|
||||
* with the number of bytes actually read. Otherwise this method behaves exactly as specified in
|
||||
* the {@link ReadableByteChannel} interface.
|
||||
*/
|
||||
@Override
|
||||
public synchronized int read(ByteBuffer dst) {
|
||||
if (buffer.remaining() == 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
int count = min(dst.remaining(), buffer.remaining());
|
||||
if (count > 0) {
|
||||
ByteBuffer tempBuffer = buffer.slice();
|
||||
tempBuffer.order(buffer.order()).limit(count);
|
||||
dst.put(tempBuffer);
|
||||
buffer.position(buffer.position() + count);
|
||||
}
|
||||
return count;
|
||||
}
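
  // Illustrative usage sketch added for documentation only; it is not part of the original class.
  // The byte values are made-up example data.
  private static int exampleRead() {
    ByteBufferChannel channel = new ByteBufferChannel(ByteBuffer.wrap(new byte[] {1, 2, 3, 4}));
    ByteBuffer dst = ByteBuffer.allocate(2);
    // Copies min(dst.remaining(), buffer.remaining()) = 2 bytes and advances the channel
    // position from 0 to 2.
    return channel.read(dst);
  }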
@Override
|
||||
public long size() {
|
||||
return buffer.limit();
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized ByteBufferChannel truncate(long size) {
|
||||
checkArgument(
|
||||
(size >= 0 && size <= Integer.MAX_VALUE),
|
||||
"The new size should be non-negative and be less than Integer.MAX_VALUE.");
|
||||
|
||||
if (size < buffer.limit()) {
|
||||
buffer.limit((int) size);
|
||||
if (buffer.position() > size) {
|
||||
buffer.position((int) size);
|
||||
}
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized int write(ByteBuffer src) {
|
||||
if (buffer.isReadOnly()) {
|
||||
throw new NonWritableChannelException();
|
||||
}
|
||||
|
||||
int count = min(src.remaining(), buffer.remaining());
|
||||
if (count > 0) {
|
||||
ByteBuffer tempBuffer = src.slice();
|
||||
tempBuffer.order(buffer.order()).limit(count);
|
||||
buffer.put(tempBuffer);
|
||||
}
|
||||
return count;
|
||||
}
|
||||
}
|
@ -1,368 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.zip.ZipException;
|
||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||
import org.tensorflow.lite.schema.Tensor;
|
||||
import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
|
||||
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
|
||||
|
||||
/**
|
||||
* Loads metadata from TFLite Model FlatBuffer.
|
||||
*
|
||||
* <p>TFLite Model FlatBuffer can be generated using the <a
|
||||
* href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
|
||||
* Model schema file.</a>
|
||||
*
|
||||
* <p>Some models contain a TFLite Metadata Flatbuffer, which records more information about what
|
||||
 * the model does and how to interpret the model. TFLite Metadata Flatbuffer can be generated using
|
||||
* the <a
|
||||
* href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/metadata_schema.fbs">TFLite
|
||||
* Metadata schema file.</a>
|
||||
*
|
||||
* <p>It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking methods
|
||||
* that read from TFLite metadata will cause runtime errors.
|
||||
*
|
||||
* <p>Similarly, it is allowed to pass in a model FlatBuffer without associated files. However,
|
||||
* invoking methods that read the associated files will cause runtime errors.
|
||||
*
|
||||
* <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports a
|
||||
* single subgraph so far. See the <a
|
||||
* href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
|
||||
 * on how to specify a subgraph during conversion for more information.</a> Therefore, {@link
|
||||
* MetadataExtractor} omits subgraph index as an input in its methods.
|
||||
*/
|
||||
public class MetadataExtractor {
|
||||
|
||||
/** The helper class to load metadata from TFLite model FlatBuffer. */
|
||||
private final ModelInfo modelInfo;
|
||||
|
||||
/** The helper class to load metadata from TFLite metadata FlatBuffer. */
|
||||
@Nullable private final ModelMetadataInfo metadataInfo;
|
||||
|
||||
/** The handler to load associated files through zip. */
|
||||
@Nullable private final ZipFile zipFile;
|
||||
|
||||
/**
|
||||
* Creates a {@link MetadataExtractor} with TFLite model FlatBuffer.
|
||||
*
|
||||
* @param buffer the TFLite model FlatBuffer
|
||||
* @throws IllegalArgumentException if the number of input or output tensors in the model does not
|
||||
* match that in the metadata
|
||||
* @throws IOException if an error occurs while reading the model as a Zip file
|
||||
*/
|
||||
public MetadataExtractor(ByteBuffer buffer) throws IOException {
|
||||
modelInfo = new ModelInfo(buffer);
|
||||
ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer();
|
||||
if (metadataBuffer != null) {
|
||||
metadataInfo = new ModelMetadataInfo(metadataBuffer);
|
||||
|
||||
// Prints warning message if the minimum parser version is not satisfied.
|
||||
if (!isMinimumParserVersionSatisfied()) {
|
||||
System.err.printf(
|
||||
"<Warning> Some fields in the metadata belong to a future schema. The minimum parser"
|
||||
+ " version required is %s, but the version of the current metadata parser is %s",
|
||||
metadataInfo.getMininumParserVersion(), MetadataParser.VERSION);
|
||||
}
|
||||
|
||||
checkArgument(
|
||||
modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(),
|
||||
String.format(
|
||||
"The number of input tensors in the model is %d. The number of input tensors that"
|
||||
+ " recorded in the metadata is %d. These two values does not match.",
|
||||
modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount()));
|
||||
checkArgument(
|
||||
modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(),
|
||||
String.format(
|
||||
"The number of output tensors in the model is %d. The number of output tensors that"
|
||||
+ " recorded in the metadata is %d. These two values does not match.",
|
||||
modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount()));
|
||||
} else {
|
||||
// It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking
|
||||
// methods that read from TFLite metadata will cause runtime errors.
|
||||
metadataInfo = null;
|
||||
}
|
||||
|
||||
zipFile = createZipFile(buffer);
|
||||
}
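
  // Illustrative usage sketch added for documentation only; it is not part of the original class.
  // It assumes the given buffer is a valid TFLite model FlatBuffer with metadata attached and at
  // least one input tensor.
  private static void exampleUsage(ByteBuffer modelBuffer) throws IOException {
    MetadataExtractor extractor = new MetadataExtractor(modelBuffer);
    if (extractor.hasMetadata()) {
      // Reads the metadata, shape, and quantization parameters of the first input tensor.
      TensorMetadata inputMetadata = extractor.getInputTensorMetadata(0);
      int[] inputShape = extractor.getInputTensorShape(0);
      QuantizationParams inputQuant = extractor.getInputTensorQuantizationParams(0);
    }
  }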
/**
|
||||
* Quantization parameters that corresponds to the table, {@code QuantizationParameters}, in the
|
||||
* <a
|
||||
* href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
|
||||
* Model schema file.</a>
|
||||
*
|
||||
* <p>Since per-channel quantization does not apply to input and output tensors, {@code scale} and
|
||||
* {@code zero_point} are both single values instead of arrays.
|
||||
*
|
||||
 * <p>For tensors that are not quantized, the values of scale and zero_point are both 0.
|
||||
*
|
||||
* <p>Given a quantized value q, the corresponding float value f should be: <br>
|
||||
* f = scale * (q - zero_point) <br>
|
||||
*/
|
||||
public static class QuantizationParams {
|
||||
/** The scale value used in quantization. */
|
||||
private final float scale;
|
||||
/** The zero point value used in quantization. */
|
||||
private final int zeroPoint;
|
||||
|
||||
/**
|
||||
* Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}.
|
||||
*
|
||||
* @param scale The scale value used in quantization.
|
||||
* @param zeroPoint The zero point value used in quantization.
|
||||
*/
|
||||
public QuantizationParams(final float scale, final int zeroPoint) {
|
||||
this.scale = scale;
|
||||
this.zeroPoint = zeroPoint;
|
||||
}
|
||||
|
||||
/** Returns the scale value. */
|
||||
public float getScale() {
|
||||
return scale;
|
||||
}
|
||||
|
||||
/** Returns the zero point value. */
|
||||
public int getZeroPoint() {
|
||||
return zeroPoint;
|
||||
}
|
||||
}
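
  // Illustrative dequantization sketch added for documentation only; it is not part of the
  // original class. The scale, zero point, and quantized value below are made-up numbers.
  private static float exampleDequantize() {
    QuantizationParams params = new QuantizationParams(/* scale= */ 0.5f, /* zeroPoint= */ 1);
    int q = 9;
    // f = scale * (q - zero_point) = 0.5f * (9 - 1) = 4.0f
    return params.getScale() * (q - params.getZeroPoint());
  }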
/** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */
|
||||
public boolean hasMetadata() {
|
||||
return metadataInfo != null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the packed associated file with the specified {@code fileName}.
|
||||
*
|
||||
* @param fileName the name of the associated file
|
||||
* @return the raw input stream containing specified file
|
||||
* @throws IllegalStateException if the model is not a zip file
|
||||
* @throws IllegalArgumentException if the specified file does not exist in the model
|
||||
*/
|
||||
public InputStream getAssociatedFile(String fileName) {
|
||||
assertZipFile();
|
||||
return zipFile.getRawInputStream(fileName);
|
||||
}
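
  // Illustrative usage sketch added for documentation only; it is not part of the original class.
  // The file name "labels.txt" is a hypothetical example of an associated file packed into the
  // model; it throws IllegalStateException or IllegalArgumentException as documented above.
  private InputStream exampleGetLabelFile() {
    return getAssociatedFile("labels.txt");
  }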
/** Gets the count of input tensors in the model. */
|
||||
public int getInputTensorCount() {
|
||||
return modelInfo.getInputTensorCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the metadata for the input tensor specified by {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex the index of the desired input tensor
|
||||
* @throws IllegalStateException if this model does not contain model metadata
|
||||
*/
|
||||
@Nullable
|
||||
public TensorMetadata getInputTensorMetadata(int inputIndex) {
|
||||
assertMetadataInfo();
|
||||
return metadataInfo.getInputTensorMetadata(inputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the quantization parameters for the input tensor specified by {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex the index of the desired input tensor
|
||||
*/
|
||||
public QuantizationParams getInputTensorQuantizationParams(int inputIndex) {
|
||||
Tensor tensor = modelInfo.getInputTensor(inputIndex);
|
||||
return modelInfo.getQuantizationParams(tensor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the shape of the input tensor with {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex the index of the desired input tensor
|
||||
*/
|
||||
public int[] getInputTensorShape(int inputIndex) {
|
||||
return modelInfo.getInputTensorShape(inputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the {@link TensorType} of the input tensor with {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex the index of the desired input tensor
|
||||
*/
|
||||
public byte getInputTensorType(int inputIndex) {
|
||||
return modelInfo.getInputTensorType(inputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the root handler for the model metadata.
|
||||
*
|
||||
* @throws IllegalStateException if this model does not contain model metadata
|
||||
*/
|
||||
public ModelMetadata getModelMetadata() {
|
||||
assertMetadataInfo();
|
||||
return metadataInfo.getModelMetadata();
|
||||
}
|
||||
|
||||
/** Gets the count of output tensors in the model. */
|
||||
public int getOutputTensorCount() {
|
||||
return modelInfo.getOutputTensorCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the metadata for the output tensor specified by {@code outputIndex}.
|
||||
*
|
||||
* @param outputIndex the index of the desired output tensor
|
||||
* @throws IllegalStateException if this model does not contain model metadata
|
||||
*/
|
||||
@Nullable
|
||||
public TensorMetadata getOutputTensorMetadata(int outputIndex) {
|
||||
assertMetadataInfo();
|
||||
return metadataInfo.getOutputTensorMetadata(outputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the quantization parameters for the output tensor specified by {@code outputIndex}.
|
||||
*
|
||||
* @param outputIndex the index of the desired output tensor
|
||||
*/
|
||||
public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) {
|
||||
Tensor tensor = modelInfo.getOutputTensor(outputIndex);
|
||||
return modelInfo.getQuantizationParams(tensor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the shape of the output tensor with {@code outputIndex}.
|
||||
*
|
||||
* @param outputIndex the index of the desired output tensor
|
||||
*/
|
||||
public int[] getOutputTensorShape(int outputIndex) {
|
||||
return modelInfo.getOutputTensorShape(outputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the {@link TensorType} of the output tensor with {@code outputIndex}.
|
||||
*
|
||||
* @param outputIndex the index of the desired output tensor
|
||||
*/
|
||||
public byte getOutputTensorType(int outputIndex) {
|
||||
return modelInfo.getOutputTensorType(outputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns {@code true} if the minimum parser version required by the given metadata flatbuffer
|
||||
 * precedes or equals the version of the metadata parser that this MetadataExtractor library is
|
||||
* relying on. All fields in the metadata can be parsed correctly with this metadata extractor
|
||||
* library in this case. Otherwise, it returns {@code false}.
|
||||
*
|
||||
* <p>For example, assume the underlying metadata parser version is {@code 1.14.1},
|
||||
*
|
||||
* <ul>
|
||||
* <li>it returns {@code true}, if the required minimum parser version is the same or older,
|
||||
* such as {@code 1.14.1} or {@code 1.14.0}. Null version precedes all numeric versions,
|
||||
* because some metadata flatbuffers are generated before the first versioned release; <br>
|
||||
* <li>it returns {@code false}, if the required minimum parser version is newer, such as {@code
|
||||
* 1.14.2}.
|
||||
* </ul>
|
||||
*/
|
||||
public final boolean isMinimumParserVersionSatisfied() {
|
||||
String minVersion = metadataInfo.getMininumParserVersion();
|
||||
if (minVersion == null) {
|
||||
return true;
|
||||
}
|
||||
return compareVersions(minVersion, MetadataParser.VERSION) <= 0;
|
||||
}
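
  // Illustrative sketch added for documentation only; it is not part of the original class. With
  // the current parser version "1.0.1", metadata that requires at most "1.0.1" (or carries no
  // version at all) is fully supported, while metadata that requires, say, "1.1.0" is not.
  private static boolean exampleIsSupported(@Nullable String minVersionRequiredByMetadata) {
    return minVersionRequiredByMetadata == null
        || compareVersions(minVersionRequiredByMetadata, MetadataParser.VERSION) <= 0;
  }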
/**
|
||||
* Asserts if {@link #metadataInfo} is not initialized. Some models may not have metadata and this
|
||||
 * is allowed. However, invoking methods that read the metadata is not allowed.
|
||||
*
|
||||
* @throws IllegalStateException if this model does not contain model metadata
|
||||
*/
|
||||
private void assertMetadataInfo() {
|
||||
if (metadataInfo == null) {
|
||||
throw new IllegalStateException("This model does not contain model metadata.");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Asserts if {@link #zipFile} is not initialized. Some models may not have associated files, thus
|
||||
 * are not Zip files. This is allowed. However, invoking methods that read those associated files
|
||||
* is not allowed.
|
||||
*
|
||||
* @throws IllegalStateException if this model is not a Zip file
|
||||
*/
|
||||
private void assertZipFile() {
|
||||
if (zipFile == null) {
|
||||
throw new IllegalStateException(
|
||||
"This model does not contain associated files, and is not a Zip file.");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e.
|
||||
 * it does not have associated files, returns a null handler.
|
||||
*
|
||||
* @param buffer the TFLite model FlatBuffer
|
||||
* @throws IOException if an error occurs while reading the model as a Zip file
|
||||
*/
|
||||
@Nullable
|
||||
private static ZipFile createZipFile(ByteBuffer buffer) throws IOException {
|
||||
try {
|
||||
// Creates the handler to hold the associated files through the Zip.
|
||||
ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer);
|
||||
return ZipFile.createFrom(byteBufferChannel);
|
||||
} catch (ZipException e) {
|
||||
      // Some models may not have associated files, and therefore are not Zip files.
      // However, invoking methods that read associated files later will lead to errors.
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Compares two semantic version numbers.
|
||||
*
|
||||
* <p>Examples of comparing two versions: <br>
|
||||
* {@code 1.9} precedes {@code 1.14}; <br>
|
||||
* {@code 1.14} precedes {@code 1.14.1}; <br>
|
||||
 * {@code 1.14} and {@code 1.14.0} are equal;
|
||||
*
|
||||
* @return the value {@code 0} if the two versions are equal; a value less than {@code 0} if
|
||||
* {@code version1} precedes {@code version2}; a value greater than {@code 0} if {@code
|
||||
* version2} precedes {@code version1}.
|
||||
*/
|
||||
private static int compareVersions(String version1, String version2) {
|
||||
    // Using String.split instead of the recommended Guava Splitter because we've been avoiding
|
||||
// depending on other third party libraries in this project.
|
||||
String[] levels1 = version1.split("\\.", 0);
|
||||
String[] levels2 = version2.split("\\.", 0);
|
||||
|
||||
int length = Math.max(levels1.length, levels2.length);
|
||||
for (int i = 0; i < length; i++) {
|
||||
Integer v1 = i < levels1.length ? Integer.parseInt(levels1[i]) : 0;
|
||||
Integer v2 = i < levels2.length ? Integer.parseInt(levels2[i]) : 0;
|
||||
int compare = v1.compareTo(v2);
|
||||
if (compare != 0) {
|
||||
return compare;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
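
  // Illustrative comparisons added for documentation only; they are not part of the original
  // class. They restate the examples from the javadoc above.
  private static boolean exampleComparisons() {
    return compareVersions("1.9", "1.14") < 0 // 9 < 14, even though "1.9" > "1.14" as plain strings
        && compareVersions("1.14", "1.14.1") < 0 // missing levels are treated as 0
        && compareVersions("1.14", "1.14.0") == 0; // equal once padded with zeros
  }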
}
|
@ -1,27 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
/** Information about the metadata parser that this metadata extractor library is depending on. */
|
||||
public final class MetadataParser {
|
||||
/**
|
||||
* The version of the metadata parser that this metadata extractor library is depending on. The
|
||||
* value should match the value of "Schema Semantic version" in metadata_schema.fbs.
|
||||
*/
|
||||
public static final String VERSION = "1.0.1";
|
||||
|
||||
private MetadataParser() {}
|
||||
}
|
@ -1,266 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
|
||||
import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||
import org.tensorflow.lite.schema.Buffer;
|
||||
import org.tensorflow.lite.schema.Metadata;
|
||||
import org.tensorflow.lite.schema.Model;
|
||||
import org.tensorflow.lite.schema.QuantizationParameters;
|
||||
import org.tensorflow.lite.schema.SubGraph;
|
||||
import org.tensorflow.lite.schema.Tensor;
|
||||
import org.tensorflow.lite.schema.TensorType;
|
||||
import org.tensorflow.lite.support.metadata.MetadataExtractor.QuantizationParams;
|
||||
|
||||
/** Extracts model information out of TFLite model FlatBuffer. */
|
||||
final class ModelInfo {
|
||||
/** The model that is loaded from TFLite model FlatBuffer. */
|
||||
private final Model model;
|
||||
|
||||
/** A list of input tensors. */
|
||||
private final List</* @Nullable */ Tensor> inputTensors;
|
||||
|
||||
/** A list of output tensors. */
|
||||
private final List</* @Nullable */ Tensor> outputTensors;
|
||||
|
||||
/** Identifier of the TFLite model metadata in the Metadata array. */
|
||||
static final String METADATA_FIELD_NAME = "TFLITE_METADATA";
|
||||
|
||||
/**
|
||||
* Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}.
|
||||
*
|
||||
 * <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports
 * a single subgraph so far. See the <a
 * href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
 * on how to specify a subgraph during conversion for more information.</a> Therefore, all methods
 * in {@link ModelInfo} retrieve metadata of the first subgraph by default.
|
||||
*
|
||||
* @param buffer the TFLite model FlatBuffer
|
||||
* @throws NullPointerException if {@code buffer} is null
|
||||
* @throws IllegalArgumentException if the model does not contain any subgraph, or the model does
|
||||
* not contain the expected identifier
|
||||
*/
|
||||
ModelInfo(ByteBuffer buffer) {
|
||||
assertTFLiteModel(buffer);
|
||||
|
||||
model = Model.getRootAsModel(buffer);
|
||||
checkArgument(model.subgraphsLength() > 0, "The model does not contain any subgraph.");
|
||||
|
||||
inputTensors = getInputTensors(model);
|
||||
outputTensors = getOutputTensors(model);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the input tensor with {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex The index of the desired input tensor.
|
||||
* @throws IllegalArgumentException if the inputIndex specified is invalid.
|
||||
*/
|
||||
@Nullable
|
||||
Tensor getInputTensor(int inputIndex) {
|
||||
checkArgument(
|
||||
inputIndex >= 0 && inputIndex < inputTensors.size(),
|
||||
"The inputIndex specified is invalid.");
|
||||
return inputTensors.get(inputIndex);
|
||||
}
|
||||
|
||||
int getInputTensorCount() {
|
||||
return inputTensors.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets shape of the input tensor with {@code inputIndex}.
|
||||
*
|
||||
 * @param inputIndex The index of the desired input tensor.
|
||||
*/
|
||||
int[] getInputTensorShape(int inputIndex) {
|
||||
Tensor tensor = getInputTensor(inputIndex);
|
||||
return getShape(tensor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the {@link TensorType} in byte of the input tensor with {@code inputIndex}.
|
||||
*
|
||||
 * @param inputIndex The index of the desired input tensor.
|
||||
*/
|
||||
byte getInputTensorType(int inputIndex) {
|
||||
return getInputTensor(inputIndex).type();
|
||||
}
|
||||
|
||||
/** Gets the metadata FlatBuffer from the model FlatBuffer. */
|
||||
@Nullable
|
||||
ByteBuffer getMetadataBuffer() {
|
||||
// Some models may not have metadata, and this is allowed.
|
||||
if (model.metadataLength() == 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
for (int i = 0; i < model.metadataLength(); i++) {
|
||||
Metadata meta = model.metadata(i);
|
||||
if (METADATA_FIELD_NAME.equals(meta.name())) {
|
||||
long bufferIndex = meta.buffer();
|
||||
Buffer metadataBuf = model.buffers((int) bufferIndex);
|
||||
return metadataBuf.dataAsByteBuffer();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the output tensor with {@code outputIndex}.
|
||||
*
|
||||
 * @param outputIndex The index of the desired output tensor.
|
||||
* @throws IllegalArgumentException if the outputIndex specified is invalid.
|
||||
*/
|
||||
@Nullable
|
||||
Tensor getOutputTensor(int outputIndex) {
|
||||
checkArgument(
|
||||
outputIndex >= 0 && outputIndex < outputTensors.size(),
|
||||
"The outputIndex specified is invalid.");
|
||||
return outputTensors.get(outputIndex);
|
||||
}
|
||||
|
||||
int getOutputTensorCount() {
|
||||
return outputTensors.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets shape of the output tensor with {@code outputIndex}.
|
||||
*
|
||||
 * @param outputIndex The index of the desired output tensor.
|
||||
*/
|
||||
int[] getOutputTensorShape(int outputIndex) {
|
||||
Tensor tensor = getOutputTensor(outputIndex);
|
||||
return getShape(tensor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the {@link TensorType} in byte of the output tensor {@code outputIndex}.
|
||||
*
|
||||
 * @param outputIndex The index of the desired output tensor.
|
||||
*/
|
||||
byte getOutputTensorType(int outputIndex) {
|
||||
return getOutputTensor(outputIndex).type();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the quantization parameters of a tensor.
|
||||
*
|
||||
 * <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensors that are not
|
||||
* quantized, the values of scale and zero_point are both 0.
|
||||
*
|
||||
 * @param tensor The tensor whose quantization parameters are desired.
|
||||
* @throws NullPointerException if the tensor is null.
|
||||
* @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's {@link
|
||||
* QuantizationParameters} are not single values.
|
||||
*/
|
||||
QuantizationParams getQuantizationParams(Tensor tensor) {
|
||||
checkNotNull(tensor, "Tensor cannot be null.");
|
||||
|
||||
float scale;
|
||||
int zeroPoint;
|
||||
QuantizationParameters quantization = tensor.quantization();
|
||||
|
||||
// Tensors that are not quantized do not have quantization parameters, which can be null when
|
||||
// being extracted from the flatbuffer.
|
||||
if (quantization == null) {
|
||||
scale = 0.0f;
|
||||
zeroPoint = 0;
|
||||
return new QuantizationParams(scale, zeroPoint);
|
||||
}
|
||||
|
||||
// Tensors that are not quantized do not have quantization parameters.
|
||||
// quantization.scaleLength() and quantization.zeroPointLength() may both return 0.
|
||||
checkArgument(
|
||||
quantization.scaleLength() <= 1,
|
||||
"Input and output tensors do not support per-channel quantization.");
|
||||
checkArgument(
|
||||
quantization.zeroPointLength() <= 1,
|
||||
"Input and output tensors do not support per-channel quantization.");
|
||||
|
||||
// For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0) will
|
||||
// both be the default value in flatbuffer, 0. This behavior is consistent with the TFlite C++
|
||||
// runtime.
|
||||
scale = quantization.scale(0);
|
||||
// zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep it
|
||||
// consistent with the C++ runtime.
|
||||
zeroPoint = (int) quantization.zeroPoint(0);
|
||||
|
||||
return new QuantizationParams(scale, zeroPoint);
|
||||
}
|
||||
|
||||
/**
|
||||
* Verifies if the buffer is a valid TFLite model.
|
||||
*
|
||||
* @param buffer the TFLite model flatbuffer
|
||||
* @throws NullPointerException if {@code buffer} is null.
|
||||
* @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
|
||||
*/
|
||||
private static void assertTFLiteModel(ByteBuffer buffer) {
|
||||
checkNotNull(buffer, "Model flatbuffer cannot be null.");
|
||||
checkArgument(
|
||||
Model.ModelBufferHasIdentifier(buffer),
|
||||
"The identifier of the model is invalid. The buffer may not be a valid TFLite model"
|
||||
+ " flatbuffer.");
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the shape of a tensor.
|
||||
*
|
||||
 * @param tensor The tensor whose shape is desired.
|
||||
* @throws NullPointerException if the tensor is null.
|
||||
*/
|
||||
private static int[] getShape(Tensor tensor) {
|
||||
checkNotNull(tensor, "Tensor cannot be null.");
|
||||
int shapeDim = tensor.shapeLength();
|
||||
int[] tensorShape = new int[shapeDim];
|
||||
for (int i = 0; i < shapeDim; i++) {
|
||||
tensorShape[i] = tensor.shape(i);
|
||||
}
|
||||
return tensorShape;
|
||||
}
|
||||
|
||||
/** Gets input tensors from a model. */
|
||||
private static List<Tensor> getInputTensors(Model model) {
|
||||
// TFLite currently supports only one subgraph.
|
||||
SubGraph subgraph = model.subgraphs(0);
|
||||
int tensorNum = subgraph.inputsLength();
|
||||
ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum);
|
||||
for (int i = 0; i < tensorNum; i++) {
|
||||
inputTensors.add(subgraph.tensors(subgraph.inputs(i)));
|
||||
}
|
||||
return Collections.unmodifiableList(inputTensors);
|
||||
}
|
||||
|
||||
/** Gets output tensors from a model. */
|
||||
private static List<Tensor> getOutputTensors(Model model) {
|
||||
// TFLite currently supports only one subgraph.
|
||||
SubGraph subgraph = model.subgraphs(0);
|
||||
int tensorNum = subgraph.outputsLength();
|
||||
ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum);
|
||||
for (int i = 0; i < tensorNum; i++) {
|
||||
outputTensors.add(subgraph.tensors(subgraph.outputs(i)));
|
||||
}
|
||||
return Collections.unmodifiableList(outputTensors);
|
||||
}
|
||||
}
|
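The scale and zero point extracted by `getQuantizationParams` above follow TFLite's affine quantization scheme, in which a real value is recovered as `scale * (quantized - zeroPoint)`. A minimal sketch of how generated wrapper code can apply these parameters; the helper class and method names below are hypothetical and not part of the migrated sources:

```java
/** Illustrative helper: the affine quantization scheme behind the scale/zeroPoint pair above. */
final class QuantizationMath {
  private QuantizationMath() {}

  /**
   * Dequantizes a uint8 value into a real value: real = scale * (quantized - zeroPoint). Callers
   * skip this step when scale is 0, which marks a tensor that is not quantized (see above).
   */
  static float dequantize(int quantizedValue, float scale, int zeroPoint) {
    return scale * (quantizedValue - zeroPoint);
  }

  /** Quantizes a real value back into the uint8 range, rounding and clamping to [0, 255]. */
  static int quantize(float realValue, float scale, int zeroPoint) {
    int quantized = Math.round(realValue / scale) + zeroPoint;
    return Math.max(0, Math.min(255, quantized));
  }
}
```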
@ -1,153 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
|
||||
import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||
import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
|
||||
import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata;
|
||||
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
|
||||
|
||||
/** Extracts model metadata information out of TFLite metadata FlatBuffer. */
|
||||
final class ModelMetadataInfo {
|
||||
/** The root handler for the model metadata. */
|
||||
private final ModelMetadata modelMetadata;
|
||||
|
||||
/** Metadata array of input tensors. */
|
||||
private final List</* @Nullable */ TensorMetadata> inputsMetadata;
|
||||
|
||||
/** Metadata array of output tensors. */
|
||||
private final List</* @Nullable */ TensorMetadata> outputsMetadata;
|
||||
|
||||
/** The minimum parser version required to fully understand the metadata flatbuffer. */
|
||||
private final String /* @Nullable */ minVersion;
|
||||
|
||||
/**
|
||||
* Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}.
|
||||
*
|
||||
* @param buffer the TFLite metadata FlatBuffer
|
||||
* @throws NullPointerException if {@code buffer} is null
|
||||
* @throws IllegalArgumentException if {@code buffer} does not contain any subgraph metadata, or
|
||||
* it does not contain the expected identifier
|
||||
*/
|
||||
ModelMetadataInfo(ByteBuffer buffer) {
|
||||
assertTFLiteMetadata(buffer);
|
||||
|
||||
modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
|
||||
checkArgument(
|
||||
modelMetadata.subgraphMetadataLength() > 0,
|
||||
"The metadata flatbuffer does not contain any subgraph metadata.");
|
||||
|
||||
inputsMetadata = getInputsMetadata(modelMetadata);
|
||||
outputsMetadata = getOutputsMetadata(modelMetadata);
|
||||
minVersion = modelMetadata.minParserVersion();
|
||||
}
|
||||
|
||||
/** Gets the count of input tensors with metadata in the metadata FlatBuffer. */
|
||||
int getInputTensorCount() {
|
||||
return inputsMetadata.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the metadata for the input tensor specified by {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex The index of the desired input tensor.
|
||||
* @throws IllegalArgumentException if the inputIndex specified is invalid.
|
||||
*/
|
||||
@Nullable
|
||||
TensorMetadata getInputTensorMetadata(int inputIndex) {
|
||||
checkArgument(
|
||||
inputIndex >= 0 && inputIndex < inputsMetadata.size(),
|
||||
"The inputIndex specified is invalid.");
|
||||
return inputsMetadata.get(inputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the minimum parser version of the metadata. It can be {@code null} if the version is not
|
||||
* populated.
|
||||
*/
|
||||
@Nullable
|
||||
String getMininumParserVersion() {
|
||||
return minVersion;
|
||||
}
|
||||
|
||||
/** Gets the root handler for the model metadata. */
|
||||
ModelMetadata getModelMetadata() {
|
||||
return modelMetadata;
|
||||
}
|
||||
|
||||
/** Gets the count of output tensors with metadata in the metadata FlatBuffer. */
|
||||
int getOutputTensorCount() {
|
||||
return outputsMetadata.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the metadata for the output tensor specified by {@code outputIndex}.
|
||||
*
|
||||
* @param outputIndex The index of the desired output tensor.
|
||||
* @throws IllegalArgumentException if the outputIndex specified is invalid.
|
||||
*/
|
||||
@Nullable
|
||||
TensorMetadata getOutputTensorMetadata(int outputIndex) {
|
||||
checkArgument(
|
||||
outputIndex >= 0 && outputIndex < outputsMetadata.size(),
|
||||
"The outputIndex specified is invalid.");
|
||||
return outputsMetadata.get(outputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Verifies if the buffer is a valid TFLite metadata flatbuffer.
|
||||
*
|
||||
* @param buffer the TFLite metadata flatbuffer
|
||||
* @throws NullPointerException if {@code buffer} is null.
|
||||
* @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
|
||||
*/
|
||||
private static void assertTFLiteMetadata(ByteBuffer buffer) {
|
||||
checkNotNull(buffer, "Metadata flatbuffer cannot be null.");
|
||||
checkArgument(
|
||||
ModelMetadata.ModelMetadataBufferHasIdentifier(buffer),
|
||||
"The identifier of the metadata is invalid. The buffer may not be a valid TFLite metadata"
|
||||
+ " flatbuffer.");
|
||||
}
|
||||
|
||||
/** Gets metadata for all input tensors. */
|
||||
private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) {
|
||||
SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
|
||||
int tensorNum = subgraphMetadata.inputTensorMetadataLength();
|
||||
ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum);
|
||||
for (int i = 0; i < tensorNum; i++) {
|
||||
inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i));
|
||||
}
|
||||
return Collections.unmodifiableList(inputsMetadata);
|
||||
}
|
||||
|
||||
/** Gets metadata for all output tensors. */
|
||||
private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) {
|
||||
SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
|
||||
int tensorNum = subgraphMetadata.outputTensorMetadataLength();
|
||||
ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum);
|
||||
for (int i = 0; i < tensorNum; i++) {
|
||||
outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i));
|
||||
}
|
||||
return Collections.unmodifiableList(outputsMetadata);
|
||||
}
|
||||
}
|
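A minimal sketch of how the accessors above can be combined to walk the per-tensor metadata. It assumes a caller in the same package (the class is package-private), a metadata FlatBuffer already loaded into a `ByteBuffer`, and a hypothetical helper name:

```java
import java.nio.ByteBuffer;
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;

/** Illustrative helper: prints the metadata name of every input tensor, if present. */
final class InputMetadataPrinter {
  private InputMetadataPrinter() {}

  static void printInputTensorNames(ByteBuffer metadataBuffer) {
    ModelMetadataInfo metadataInfo = new ModelMetadataInfo(metadataBuffer);
    System.out.println("Minimum parser version: " + metadataInfo.getMininumParserVersion());
    for (int i = 0; i < metadataInfo.getInputTensorCount(); i++) {
      TensorMetadata tensorMetadata = metadataInfo.getInputTensorMetadata(i);
      // Tensors without metadata are stored as null entries.
      System.out.println(
          "Input " + i + ": " + (tensorMetadata == null ? "<no metadata>" : tensorMetadata.name()));
    }
  }
}
```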
@ -1,184 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||
|
||||
/** Static error checking util methods. */
|
||||
final class Preconditions {
|
||||
/**
|
||||
* Ensures that an object reference passed as a parameter to the calling method is not null.
|
||||
*
|
||||
* @param reference an object reference
|
||||
* @return the non-null reference that was validated
|
||||
* @throws NullPointerException if {@code reference} is null
|
||||
*/
|
||||
public static <T extends Object> T checkNotNull(T reference) {
|
||||
if (reference == null) {
|
||||
throw new NullPointerException("The object reference is null.");
|
||||
}
|
||||
return reference;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that an object reference passed as a parameter to the calling method is not null.
|
||||
*
|
||||
* @param reference an object reference
|
||||
* @param errorMessage the exception message to use if the check fails; will be converted to a
|
||||
* string using {@link String#valueOf(Object)}
|
||||
* @return the non-null reference that was validated
|
||||
* @throws NullPointerException if {@code reference} is null
|
||||
*/
|
||||
public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
|
||||
if (reference == null) {
|
||||
throw new NullPointerException(String.valueOf(errorMessage));
|
||||
}
|
||||
return reference;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that the given String is not empty and not null.
|
||||
*
|
||||
* @param string the String to test
|
||||
* @return the non-null non-empty String that was validated
|
||||
* @throws IllegalArgumentException if {@code string} is null or empty
|
||||
*/
|
||||
public static String checkNotEmpty(String string) {
|
||||
if (string == null || string.length() == 0) {
|
||||
throw new IllegalArgumentException("Given String is empty or null.");
|
||||
}
|
||||
return string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that the given String is not empty and not null.
|
||||
*
|
||||
* @param string the String to test
|
||||
* @param errorMessage the exception message to use if the check fails; will be converted to a
|
||||
* string using {@link String#valueOf(Object)}
|
||||
* @return the non-null non-empty String that was validated
|
||||
* @throws IllegalArgumentException if {@code string} is null or empty
|
||||
*/
|
||||
public static String checkNotEmpty(String string, Object errorMessage) {
|
||||
if (string == null || string.length() == 0) {
|
||||
throw new IllegalArgumentException(String.valueOf(errorMessage));
|
||||
}
|
||||
return string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures the truth of an expression involving one or more parameters to the calling method.
|
||||
*
|
||||
* @param expression a boolean expression.
|
||||
* @throws IllegalArgumentException if {@code expression} is false.
|
||||
*/
|
||||
public static void checkArgument(boolean expression) {
|
||||
if (!expression) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures the truth of an expression involving one or more parameters to the calling method.
|
||||
*
|
||||
* @param expression a boolean expression.
|
||||
* @param errorMessage the exception message to use if the check fails; will be converted to a
|
||||
* string using {@link String#valueOf(Object)}.
|
||||
* @throws IllegalArgumentException if {@code expression} is false.
|
||||
*/
|
||||
public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
|
||||
if (!expression) {
|
||||
throw new IllegalArgumentException(String.valueOf(errorMessage));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
|
||||
* {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
|
||||
*
|
||||
* @param index a user-supplied index identifying an element of an array, list or string
|
||||
* @param size the size of that array, list or string
|
||||
* @return the value of {@code index}
|
||||
* @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
|
||||
* @throws IllegalArgumentException if {@code size} is negative
|
||||
*/
|
||||
public static int checkElementIndex(int index, int size) {
|
||||
return checkElementIndex(index, size, "index");
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
|
||||
* {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
|
||||
*
|
||||
* @param index a user-supplied index identifying an element of an array, list or string
|
||||
* @param size the size of that array, list or string
|
||||
* @param desc the text to use to describe this index in an error message
|
||||
* @return the value of {@code index}
|
||||
* @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
|
||||
* @throws IllegalArgumentException if {@code size} is negative
|
||||
*/
|
||||
public static int checkElementIndex(int index, int size, @Nullable String desc) {
|
||||
// Carefully optimized for execution by hotspot (explanatory comment above)
|
||||
if (index < 0 || index >= size) {
|
||||
throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures the truth of an expression involving the state of the calling instance, but not
|
||||
* involving any parameters to the calling method.
|
||||
*
|
||||
* @param expression a boolean expression
|
||||
* @throws IllegalStateException if {@code expression} is false
|
||||
* @see Verify#verify Verify.verify()
|
||||
*/
|
||||
public static void checkState(boolean expression) {
|
||||
if (!expression) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures the truth of an expression involving the state of the calling instance, but not
|
||||
* involving any parameters to the calling method.
|
||||
*
|
||||
* @param expression a boolean expression
|
||||
* @param errorMessage the exception message to use if the check fails; will be converted to a
|
||||
* string using {@link String#valueOf(Object)}
|
||||
* @throws IllegalStateException if {@code expression} is false
|
||||
* @see Verify#verify Verify.verify()
|
||||
*/
|
||||
public static void checkState(boolean expression, @Nullable Object errorMessage) {
|
||||
if (!expression) {
|
||||
throw new IllegalStateException(String.valueOf(errorMessage));
|
||||
}
|
||||
}
|
||||
|
||||
private static String badElementIndex(int index, int size, @Nullable String desc) {
|
||||
if (index < 0) {
|
||||
return String.format("%s (%s) must not be negative", desc, index);
|
||||
} else if (size < 0) {
|
||||
throw new IllegalArgumentException("negative size: " + size);
|
||||
} else { // index >= size
|
||||
return String.format("%s (%s) must be less than size (%s)", desc, index, size);
|
||||
}
|
||||
}
|
||||
|
||||
private Preconditions() {
|
||||
throw new AssertionError("Preconditions is Uninstantiable.");
|
||||
}
|
||||
}
|
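A minimal usage sketch of the checks above, with a hypothetical caller in the same package:

```java
/** Illustrative helper: typical argument validation built from the Preconditions checks above. */
final class ScaleLookup {
  private ScaleLookup() {}

  static float scaleAt(float[] scales, int index) {
    Preconditions.checkNotNull(scales, "scales cannot be null.");
    Preconditions.checkArgument(scales.length > 0, "scales cannot be empty.");
    Preconditions.checkElementIndex(index, scales.length, "scale index");
    return scales[index];
  }
}
```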
@ -1,107 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.Channel;
|
||||
|
||||
/**
|
||||
* A byte channel that maintains a current <i>position</i> and allows the position to be changed.
|
||||
* {@link SeekableByteChannelCompat} is compatible with {@link
|
||||
* java.nio.channels.SeekableByteChannel}.
|
||||
*
|
||||
* <p>{@link java.nio.channels.SeekableByteChannel} is not available in Android API 23 and under.
|
||||
* Therefore, {@link SeekableByteChannelCompat} is introduced here to make the interfaces used in
|
||||
* the MetadataExtractor library consistent with the commonly used Java libraries.
|
||||
*/
|
||||
interface SeekableByteChannelCompat extends Channel {
|
||||
/**
|
||||
* Reads a sequence of bytes from this channel into the given buffer.
|
||||
*
|
||||
* @param dst The buffer into which bytes are to be transferred
|
||||
* @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached
|
||||
* end-of-stream
|
||||
* @throws NonReadableChannelException If this channel was not opened for reading
|
||||
* @throws ClosedChannelException If this channel is closed
|
||||
* @throws AsynchronousCloseException If another thread closes this channel while the read
|
||||
* operation is in progress
|
||||
* @throws ClosedByInterruptException If another thread interrupts the current thread while the
|
||||
* read operation is in progress, thereby closing the channel and setting the current thread's
|
||||
* interrupt status
|
||||
* @throws IOException If some other I/O error occurs
|
||||
*/
|
||||
int read(ByteBuffer dst) throws IOException;
|
||||
|
||||
/**
|
||||
* Writes a sequence of bytes to this channel from the given buffer.
|
||||
*
|
||||
* @param src The buffer from which bytes are to be retrieved
|
||||
* @return The number of bytes written, possibly zero
|
||||
* @throws NonWritableChannelException If this channel was not opened for writing
|
||||
* @throws ClosedChannelException If this channel is closed
|
||||
* @throws AsynchronousCloseException If another thread closes this channel while the write
|
||||
* operation is in progress
|
||||
* @throws ClosedByInterruptException If another thread interrupts the current thread while the
|
||||
* write operation is in progress, thereby closing the channel and setting the current
|
||||
* thread's interrupt status
|
||||
* @throws IOException If some other I/O error occurs
|
||||
*/
|
||||
int write(ByteBuffer src) throws IOException;
|
||||
|
||||
/**
|
||||
* Returns this channel's position.
|
||||
*
|
||||
* @return This channel's position, a non-negative integer counting the number of bytes from the
|
||||
* beginning of the entity to the current position
|
||||
* @throws ClosedChannelException If this channel is closed
|
||||
* @throws IOException If some other I/O error occurs
|
||||
*/
|
||||
long position() throws IOException;
|
||||
|
||||
/**
|
||||
* Sets this channel's position.
|
||||
*
|
||||
* @param newPosition The new position, a non-negative integer counting the number of bytes from
|
||||
* the beginning of the entity
|
||||
* @return This channel
|
||||
* @throws ClosedChannelException If this channel is closed
|
||||
* @throws IllegalArgumentException If the new position is negative
|
||||
* @throws IOException If some other I/O error occurs
|
||||
*/
|
||||
SeekableByteChannelCompat position(long newPosition) throws IOException;
|
||||
|
||||
/**
|
||||
* Returns the current size of entity to which this channel is connected.
|
||||
*
|
||||
* @return The current size, measured in bytes
|
||||
* @throws ClosedChannelException If this channel is closed
|
||||
* @throws IOException If some other I/O error occurs
|
||||
*/
|
||||
long size() throws IOException;
|
||||
|
||||
/**
|
||||
* Truncates the entity, to which this channel is connected, to the given size.
|
||||
*
|
||||
* @param size The new size, a non-negative byte count
|
||||
* @return This channel
|
||||
* @throws NonWritableChannelException If this channel was not opened for writing
|
||||
* @throws ClosedChannelException If this channel is closed
|
||||
* @throws IllegalArgumentException If the new size is negative
|
||||
* @throws IOException If some other I/O error occurs
|
||||
*/
|
||||
SeekableByteChannelCompat truncate(long size) throws IOException;
|
||||
}
|
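A minimal in-memory sketch of the interface backed by a `ByteBuffer`. In the library, the `ByteBufferChannel` class (defined elsewhere in this package) plays this role; the simplified class below is only an assumed stand-in, not that implementation:

```java
import java.nio.ByteBuffer;

/** Illustrative only: a simplified in-memory channel backed by a ByteBuffer. */
final class InMemoryByteChannel implements SeekableByteChannelCompat {
  private final ByteBuffer buffer;
  private boolean open = true;

  InMemoryByteChannel(ByteBuffer buffer) {
    this.buffer = buffer;
  }

  @Override
  public int read(ByteBuffer dst) {
    if (!buffer.hasRemaining()) {
      return -1; // End of the underlying buffer.
    }
    int count = Math.min(dst.remaining(), buffer.remaining());
    ByteBuffer slice = buffer.duplicate();
    slice.limit(slice.position() + count);
    dst.put(slice);
    buffer.position(buffer.position() + count);
    return count;
  }

  @Override
  public int write(ByteBuffer src) {
    int count = Math.min(src.remaining(), buffer.remaining());
    ByteBuffer slice = src.duplicate();
    slice.limit(slice.position() + count);
    buffer.put(slice);
    src.position(src.position() + count);
    return count;
  }

  @Override
  public long position() {
    return buffer.position();
  }

  @Override
  public SeekableByteChannelCompat position(long newPosition) {
    buffer.position((int) newPosition);
    return this;
  }

  @Override
  public long size() {
    return buffer.limit();
  }

  @Override
  public SeekableByteChannelCompat truncate(long size) {
    buffer.limit((int) size);
    return this;
  }

  @Override
  public boolean isOpen() {
    return open;
  }

  @Override
  public void close() {
    open = false;
  }
}
```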
@ -1,427 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
|
||||
import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.io.EOFException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.zip.ZipException;
|
||||
|
||||
/**
|
||||
* Reads uncompressed files from the TFLite model, a zip file.
|
||||
*
|
||||
* <p>TODO(b/150237111): add a link to the webpage of MetadataPopulator once it's available.
|
||||
*
|
||||
* <p>A TFLite model file becomes a zip file when it contains associated files. The associated files
|
||||
* can be packed to a TFLite model file using the MetadataPopulator. The associated files are not
|
||||
* compressed when being added to the model file.
|
||||
*
|
||||
* <p>{@link ZipFile} does not support Zip64 format, because TFLite models are much smaller than the
|
||||
* size limit for Zip64, which is 4GB.
|
||||
*/
|
||||
final class ZipFile implements Closeable {
|
||||
/** Maps String to list of ZipEntrys, name -> actual entries. */
|
||||
private final Map<String, List<ZipEntry>> nameMap;
|
||||
|
||||
/** The actual data source. */
|
||||
private final ByteBufferChannel archive;
|
||||
|
||||
/**
|
||||
* Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link
|
||||
* ZipFile} does not synchronize over the buffer that is passed into it.
|
||||
*
|
||||
* @param channel the archive
|
||||
* @throws IOException if an error occurs while creating this {@link ZipFile}
|
||||
* @throws ZipException if the channel is not a zip archive
|
||||
* @throws NullPointerException if the archive is null
|
||||
*/
|
||||
public static ZipFile createFrom(ByteBufferChannel channel) throws IOException {
|
||||
checkNotNull(channel);
|
||||
ZipParser zipParser = new ZipParser(channel);
|
||||
Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries();
|
||||
return new ZipFile(channel, nameMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
archive.close();
|
||||
}
|
||||
|
||||
/**
|
||||
* Exposes the raw stream of the archive entry.
|
||||
*
|
||||
* <p>Since the associated files will not be compressed when being packed to the zip file, the raw
|
||||
* stream represents the non-compressed files.
|
||||
*
|
||||
* <p><b>WARNING:</b> The returned {@link InputStream} is <b>not</b> thread-safe. If multiple
|
||||
* threads read from the returned {@link InputStream} concurrently, it must be synchronized
|
||||
* externally.
|
||||
*
|
||||
* @param name name of the entry to get the stream for
|
||||
* @return the raw input stream containing data
|
||||
* @throws IllegalArgumentException if the specified file does not exist in the zip file
|
||||
*/
|
||||
public InputStream getRawInputStream(String name) {
|
||||
checkArgument(
|
||||
nameMap.containsKey(name),
|
||||
String.format("The file, %s, does not exist in the zip file.", name));
|
||||
|
||||
List<ZipEntry> entriesWithTheSameName = nameMap.get(name);
|
||||
ZipEntry entry = entriesWithTheSameName.get(0);
|
||||
long start = entry.getDataOffset();
|
||||
long remaining = entry.getSize();
|
||||
return new BoundedInputStream(archive, start, remaining);
|
||||
}
|
||||
|
||||
private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) {
|
||||
archive = channel;
|
||||
this.nameMap = nameMap;
|
||||
}
|
||||
|
||||
/* Parses a Zip archive and gets the information for each {@link ZipEntry}. */
|
||||
private static class ZipParser {
|
||||
private final ByteBufferChannel archive;
|
||||
|
||||
// Cached buffers that will only be used locally in the class to reduce garbage collection.
|
||||
private final ByteBuffer longBuffer =
|
||||
ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
|
||||
private final ByteBuffer intBuffer =
|
||||
ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
|
||||
private final ByteBuffer shortBuffer =
|
||||
ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
|
||||
|
||||
private ZipParser(ByteBufferChannel archive) {
|
||||
this.archive = archive;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses the underlying {@code archive} and returns the entry information as a map from file
|
||||
* name to the {@link ZipEntry} instances with that name.
|
||||
*/
|
||||
private Map<String, List<ZipEntry>> parseEntries() throws IOException {
|
||||
List<ZipEntry> entries = parseCentralDirectory();
|
||||
return parseLocalFileHeaderData(entries);
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the current position contains a central file header signature, {@link
|
||||
* ZipConstants#CENSIG}.
|
||||
*/
|
||||
private boolean foundCentralFileheaderSignature() {
|
||||
long signature = (long) getInt();
|
||||
return signature == ZipConstants.CENSIG;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the value as a Java int from two bytes starting at the current position of the archive.
|
||||
*/
|
||||
private int getShort() {
|
||||
shortBuffer.rewind();
|
||||
archive.read(shortBuffer);
|
||||
shortBuffer.flip();
|
||||
return (int) shortBuffer.getShort();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the value as a Java int from four bytes starting at the current position of the
|
||||
* archive.
|
||||
*/
|
||||
private int getInt() {
|
||||
intBuffer.rewind();
|
||||
archive.read(intBuffer);
|
||||
intBuffer.flip();
|
||||
return intBuffer.getInt();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the value as a Java long from eight bytes starting at the current position of the
|
||||
* archive.
|
||||
*/
|
||||
private long getLong() {
|
||||
longBuffer.rewind();
|
||||
archive.read(longBuffer);
|
||||
longBuffer.flip();
|
||||
return longBuffer.getLong();
|
||||
}
|
||||
|
||||
/**
|
||||
* Positions the archive at the start of the central directory.
|
||||
*
|
||||
* <p>First, it searches for the signature of the "end of central directory record", {@link
|
||||
* ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory
|
||||
* record". The zip file are created without archive comments, thus {@link ZipConstants#ENDSIG}
|
||||
* should appear exactly at {@link ZipConstants#ENDHDR} from the end of the zip file.
|
||||
*
|
||||
* <p>Then, parse the "end of central dir record" and position the archive at the start of the
|
||||
* central directory.
|
||||
*/
|
||||
private void locateCentralDirectory() throws IOException {
|
||||
if (archive.size() < ZipConstants.ENDHDR) {
|
||||
throw new ZipException("The archive is not a ZIP archive.");
|
||||
}
|
||||
|
||||
// Positions the archive at the start of the "end of central directory record".
|
||||
long offsetRecord = archive.size() - ZipConstants.ENDHDR;
|
||||
archive.position(offsetRecord);
|
||||
|
||||
// Checks for the signature, {@link ZipConstants#ENDSIG}.
|
||||
long endSig = getLong();
|
||||
if (endSig != ZipConstants.ENDSIG) {
|
||||
throw new ZipException("The archive is not a ZIP archive.");
|
||||
}
|
||||
|
||||
// Positions the archive at the “offset of central directory”.
|
||||
skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB);
|
||||
// Gets the offset to central directory
|
||||
long offsetDirectory = getInt();
|
||||
// Goes to the central directory.
|
||||
archive.position(offsetDirectory);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads the central directory of the given archive and populates the internal tables with
|
||||
* {@link ZipEntry} instances.
|
||||
*/
|
||||
private List<ZipEntry> parseCentralDirectory() throws IOException {
|
||||
/** List of entries in the order they appear inside the central directory. */
|
||||
List<ZipEntry> entries = new ArrayList<>();
|
||||
locateCentralDirectory();
|
||||
|
||||
while (foundCentralFileheaderSignature()) {
|
||||
ZipEntry entry = parseCentralDirectoryEntry();
|
||||
entries.add(entry);
|
||||
}
|
||||
|
||||
return entries;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads an individual entry of the central directory, creates a ZipEntry from it and adds it to
|
||||
* the global maps.
|
||||
*/
|
||||
private ZipEntry parseCentralDirectoryEntry() throws IOException {
|
||||
// Positions the archive at the "compressed size" and read the value.
|
||||
skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM);
|
||||
long compressSize = getInt();
|
||||
|
||||
// Positions the archive at the "filename length" and read the value.
|
||||
skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN);
|
||||
int fileNameLen = getShort();
|
||||
|
||||
// Reads the extra field length and the comment length.
|
||||
int extraLen = getShort();
|
||||
int commentLen = getShort();
|
||||
|
||||
// Positions the archive at the "local file header offset" and read the value.
|
||||
skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK);
|
||||
long localHeaderOffset = getInt();
|
||||
|
||||
// Reads the file name.
|
||||
byte[] fileNameBuf = new byte[fileNameLen];
|
||||
archive.read(ByteBuffer.wrap(fileNameBuf));
|
||||
String fileName = new String(fileNameBuf, Charset.forName("UTF-8"));
|
||||
|
||||
// Skips the extra field and the comment.
|
||||
skipBytes(extraLen + commentLen);
|
||||
|
||||
ZipEntry entry = new ZipEntry();
|
||||
entry.setSize(compressSize);
|
||||
entry.setLocalHeaderOffset(localHeaderOffset);
|
||||
entry.setName(fileName);
|
||||
|
||||
return entry;
|
||||
}
|
||||
|
||||
/** Walks through all recorded entries and records the offsets for the entry data. */
|
||||
private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) {
|
||||
/** Maps String to list of ZipEntrys, name -> actual entries. */
|
||||
Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>();
|
||||
|
||||
for (ZipEntry entry : entries) {
|
||||
long offset = entry.getLocalHeaderOffset();
|
||||
archive.position(offset + ZipConstants.LOCNAM);
|
||||
|
||||
// Gets the data offset of this entry.
|
||||
int fileNameLen = getShort();
|
||||
int extraFieldLen = getShort();
|
||||
long dataOffset =
|
||||
offset
|
||||
+ ZipConstants.LOCEXT
|
||||
+ ZipConstants.SHORT_BYTE_SIZE
|
||||
+ fileNameLen
|
||||
+ extraFieldLen;
|
||||
entry.setDataOffset(dataOffset);
|
||||
|
||||
// Puts the entry into the nameMap.
|
||||
String name = entry.getName();
|
||||
List<ZipEntry> entriesWithTheSameName;
|
||||
if (nameMap.containsKey(name)) {
|
||||
entriesWithTheSameName = nameMap.get(name);
|
||||
} else {
|
||||
entriesWithTheSameName = new ArrayList<>();
|
||||
nameMap.put(name, entriesWithTheSameName);
|
||||
}
|
||||
entriesWithTheSameName.add(entry);
|
||||
}
|
||||
|
||||
return nameMap;
|
||||
}
|
||||
|
||||
/** Skips the given number of bytes or throws an EOFException if skipping failed. */
|
||||
private void skipBytes(int count) throws IOException {
|
||||
long currentPosition = archive.position();
|
||||
long newPosition = currentPosition + count;
|
||||
if (newPosition > archive.size()) {
|
||||
throw new EOFException();
|
||||
}
|
||||
archive.position(newPosition);
|
||||
}
|
||||
}
|
||||
|
||||
/** Stores the data offset and the size of an entry in the archive. */
|
||||
private static class ZipEntry {
|
||||
|
||||
private String name;
|
||||
private long dataOffset = -1;
|
||||
private long size = -1;
|
||||
private long localHeaderOffset = -1;
|
||||
|
||||
public long getSize() {
|
||||
return size;
|
||||
}
|
||||
|
||||
public long getDataOffset() {
|
||||
return dataOffset;
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public long getLocalHeaderOffset() {
|
||||
return localHeaderOffset;
|
||||
}
|
||||
|
||||
public void setSize(long size) {
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
public void setDataOffset(long dataOffset) {
|
||||
this.dataOffset = dataOffset;
|
||||
}
|
||||
|
||||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public void setLocalHeaderOffset(long localHeaderOffset) {
|
||||
this.localHeaderOffset = localHeaderOffset;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Various constants for this {@link ZipFile}.
|
||||
*
|
||||
* <p>Referenced from {@link java.util.zip.ZipConstants}.
|
||||
*/
|
||||
private static class ZipConstants {
|
||||
/** length of Java short in bytes. */
|
||||
static final int SHORT_BYTE_SIZE = Short.SIZE / 8;
|
||||
|
||||
/** length of Java int in bytes. */
|
||||
static final int INT_BYTE_SIZE = Integer.SIZE / 8;
|
||||
|
||||
/** length of Java long in bytes. */
|
||||
static final int LONG_BYTE_SIZE = Long.SIZE / 8;
|
||||
|
||||
/*
|
||||
* Header signatures
|
||||
*/
|
||||
static final long LOCSIG = 0x04034b50L; // "PK\003\004"
|
||||
static final long EXTSIG = 0x08074b50L; // "PK\007\008"
|
||||
static final long CENSIG = 0x02014b50L; // "PK\001\002"
|
||||
static final long ENDSIG = 0x06054b50L; // "PK\005\006"
|
||||
|
||||
/*
|
||||
* Header sizes in bytes (including signatures)
|
||||
*/
|
||||
static final int LOCHDR = 30; // LOC header size
|
||||
static final int EXTHDR = 16; // EXT header size
|
||||
static final int CENHDR = 46; // CEN header size
|
||||
static final int ENDHDR = 22; // END header size
|
||||
|
||||
/*
|
||||
* Local file (LOC) header field offsets
|
||||
*/
|
||||
static final int LOCVER = 4; // version needed to extract
|
||||
static final int LOCFLG = 6; // general purpose bit flag
|
||||
static final int LOCHOW = 8; // compression method
|
||||
static final int LOCTIM = 10; // modification time
|
||||
static final int LOCCRC = 14; // uncompressed file crc-32 value
|
||||
static final int LOCSIZ = 18; // compressed size
|
||||
static final int LOCLEN = 22; // uncompressed size
|
||||
static final int LOCNAM = 26; // filename length
|
||||
static final int LOCEXT = 28; // extra field length
|
||||
|
||||
/*
|
||||
* Extra local (EXT) header field offsets
|
||||
*/
|
||||
static final int EXTCRC = 4; // uncompressed file crc-32 value
|
||||
static final int EXTSIZ = 8; // compressed size
|
||||
static final int EXTLEN = 12; // uncompressed size
|
||||
|
||||
/*
|
||||
* Central directory (CEN) header field offsets
|
||||
*/
|
||||
static final int CENVEM = 4; // version made by
|
||||
static final int CENVER = 6; // version needed to extract
|
||||
static final int CENFLG = 8; // encrypt, decrypt flags
|
||||
static final int CENHOW = 10; // compression method
|
||||
static final int CENTIM = 12; // modification time
|
||||
static final int CENCRC = 16; // uncompressed file crc-32 value
|
||||
static final int CENSIZ = 20; // compressed size
|
||||
static final int CENLEN = 24; // uncompressed size
|
||||
static final int CENNAM = 28; // filename length
|
||||
static final int CENEXT = 30; // extra field length
|
||||
static final int CENCOM = 32; // comment length
|
||||
static final int CENDSK = 34; // disk number start
|
||||
static final int CENATT = 36; // internal file attributes
|
||||
static final int CENATX = 38; // external file attributes
|
||||
static final int CENOFF = 42; // LOC header offset
|
||||
|
||||
/*
|
||||
* End of central directory (END) header field offsets
|
||||
*/
|
||||
static final int ENDSUB = 8; // number of entries on this disk
|
||||
static final int ENDTOT = 10; // total number of entries
|
||||
static final int ENDSIZ = 12; // central directory size in bytes
|
||||
static final int ENDOFF = 16; // offset of first CEN header
|
||||
static final int ENDCOM = 20; // zip file comment length
|
||||
|
||||
private ZipConstants() {}
|
||||
}
|
||||
}
|
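A minimal sketch of pulling a packed associated file (for example a label file) out of a model with the class above. It assumes a caller in the same package and that `ByteBufferChannel`, referenced by `createFrom` but defined elsewhere in the package, wraps a `ByteBuffer` holding the whole model file; the helper name is hypothetical:

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

/** Illustrative helper: reads one associated file that was packed after the TFLite model. */
final class AssociatedFileReader {
  private AssociatedFileReader() {}

  static String readAssociatedFile(ByteBuffer modelBuffer, String fileName) throws IOException {
    // ByteBufferChannel is assumed to wrap the model buffer; ZipFile is the class defined above.
    ZipFile zipFile = ZipFile.createFrom(new ByteBufferChannel(modelBuffer));
    try (InputStream stream = zipFile.getRawInputStream(fileName);
        BufferedReader reader =
            new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
      StringBuilder contents = new StringBuilder();
      String line;
      while ((line = reader.readLine()) != null) {
        contents.append(line).append('\n');
      }
      return contents.toString();
    } finally {
      zipFile.close();
    }
  }
}
```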
@ -1,615 +0,0 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TensorFlow Lite metadata tools."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
import zipfile
|
||||
|
||||
from flatbuffers.python import flatbuffers
|
||||
from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb
|
||||
from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb
|
||||
from tensorflow.lite.experimental.support.metadata.cc.python import _pywrap_metadata_version
|
||||
from tensorflow.lite.experimental.support.metadata.flatbuffers_lib import _pywrap_flatbuffers
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
_FLATC_TFLITE_METADATA_SCHEMA_FILE = resource_loader.get_path_to_datafile(
|
||||
"metadata_schema.fbs")
|
||||
|
||||
|
||||
# TODO(b/141467403): add delete method for associated files.
|
||||
class MetadataPopulator(object):
|
||||
"""Packs metadata and associated files into TensorFlow Lite model file.
|
||||
|
||||
MetadataPopulator can be used to populate metadata and model associated files
|
||||
into a model file or a model buffer (in bytearray). It can also help to
|
||||
inspect list of files that have been packed into the model or are supposed to
|
||||
be packed into the model.
|
||||
|
||||
The metadata file (or buffer) should be generated based on the metadata
|
||||
schema:
|
||||
third_party/tensorflow/lite/schema/metadata_schema.fbs
|
||||
|
||||
Example usage:
|
||||
Populate metadata and a label file into an image classifier model.
|
||||
|
||||
First, based on metadata_schema.fbs, generate the metadata for this image
|
||||
classifier model using the Flatbuffers API. Attach the label file onto the output
|
||||
tensor (the tensor of probabilities) in the metadata.
|
||||
|
||||
Then, pack the metadata and label file into the model as follows.
|
||||
|
||||
```python
|
||||
# Populating a metadata file (or a metadata buffer) and associated files to
|
||||
a model file:
|
||||
populator = MetadataPopulator.with_model_file(model_file)
|
||||
# For metadata buffer (bytearray read from the metadata file), use:
|
||||
# populator.load_metadata_buffer(metadata_buf)
|
||||
populator.load_metadata_file(metadata_file)
|
||||
populator.load_associated_files(["label.txt"])
|
||||
populator.populate()
|
||||
|
||||
# Populating a metadata file (or a metadata buffer) and associated files to
|
||||
a model buffer:
|
||||
populator = MetadataPopulator.with_model_buffer(model_buf)
|
||||
populator.load_metadata_file(metadata_file)
|
||||
populator.load_associated_files(["label.txt"])
|
||||
populator.populate()
|
||||
# Writing the updated model buffer into a file.
|
||||
updated_model_buf = populator.get_model_buffer()
|
||||
with open("updated_model.tflite", "wb") as f:
|
||||
f.write(updated_model_buf)
|
||||
```
|
||||
|
||||
Note that existing metadata buffer (if applied) will be overridden by the new
|
||||
metadata buffer.
|
||||
"""
|
||||
# As Zip API is used to concatenate associated files after tflite model file,
|
||||
# the populating operation is developed based on a model file. For in-memory
|
||||
# model buffer, we create a tempfile to serve the populating operation.
|
||||
# Creating and deleting such a tempfile is handled by the class,
|
||||
# _MetadataPopulatorWithBuffer.
|
||||
|
||||
METADATA_FIELD_NAME = "TFLITE_METADATA"
|
||||
TFLITE_FILE_IDENTIFIER = b"TFL3"
|
||||
METADATA_FILE_IDENTIFIER = b"M001"
|
||||
|
||||
def __init__(self, model_file):
|
||||
"""Constructor for MetadataPopulator.
|
||||
|
||||
Args:
|
||||
model_file: valid path to a TensorFlow Lite model file.
|
||||
|
||||
Raises:
|
||||
IOError: File not found.
|
||||
ValueError: the model does not have the expected flatbuffer identifier.
|
||||
"""
|
||||
_assert_model_file_identifier(model_file)
|
||||
self._model_file = model_file
|
||||
self._metadata_buf = None
|
||||
self._associated_files = set()
|
||||
|
||||
@classmethod
|
||||
def with_model_file(cls, model_file):
|
||||
"""Creates a MetadataPopulator object that populates data to a model file.
|
||||
|
||||
Args:
|
||||
model_file: valid path to a TensorFlow Lite model file.
|
||||
|
||||
Returns:
|
||||
MetadataPopulator object.
|
||||
|
||||
Raises:
|
||||
IOError: File not found.
|
||||
ValueError: the model does not have the expected flatbuffer identifier.
|
||||
"""
|
||||
return cls(model_file)
|
||||
|
||||
# TODO(b/141468993): investigate if type check can be applied to model_buf for
|
||||
# FB.
|
||||
@classmethod
|
||||
def with_model_buffer(cls, model_buf):
|
||||
"""Creates a MetadataPopulator object that populates data to a model buffer.
|
||||
|
||||
Args:
|
||||
model_buf: TensorFlow Lite model buffer in bytearray.
|
||||
|
||||
Returns:
|
||||
A MetadataPopulator(_MetadataPopulatorWithBuffer) object.
|
||||
|
||||
Raises:
|
||||
ValueError: the model does not have the expected flatbuffer identifier.
|
||||
"""
|
||||
return _MetadataPopulatorWithBuffer(model_buf)
|
||||
|
||||
def get_model_buffer(self):
|
||||
"""Gets the buffer of the model with packed metadata and associated files.
|
||||
|
||||
Returns:
|
||||
Model buffer (in bytearray).
|
||||
"""
|
||||
with open(self._model_file, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def get_packed_associated_file_list(self):
|
||||
"""Gets a list of associated files packed to the model file.
|
||||
|
||||
Returns:
|
||||
List of packed associated files.
|
||||
"""
|
||||
if not zipfile.is_zipfile(self._model_file):
|
||||
return []
|
||||
|
||||
with zipfile.ZipFile(self._model_file, "r") as zf:
|
||||
return zf.namelist()
|
||||
|
||||
def get_recorded_associated_file_list(self):
|
||||
"""Gets a list of associated files recorded in metadata of the model file.
|
||||
|
||||
Associated files may be attached to a model, a subgraph, or an input/output
|
||||
tensor.
|
||||
|
||||
Returns:
|
||||
List of recorded associated files.
|
||||
"""
|
||||
recorded_files = []
|
||||
|
||||
if not self._metadata_buf:
|
||||
return recorded_files
|
||||
|
||||
metadata = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
|
||||
self._metadata_buf, 0)
|
||||
|
||||
# Add associated files attached to ModelMetadata
|
||||
self._get_associated_files_from_metadata_struct(metadata, recorded_files)
|
||||
|
||||
# Add associated files attached to each SubgraphMetadata
|
||||
for j in range(metadata.SubgraphMetadataLength()):
|
||||
subgraph = metadata.SubgraphMetadata(j)
|
||||
self._get_associated_files_from_metadata_struct(subgraph, recorded_files)
|
||||
|
||||
# Add associated files attached to each input tensor
|
||||
for k in range(subgraph.InputTensorMetadataLength()):
|
||||
tensor = subgraph.InputTensorMetadata(k)
|
||||
self._get_associated_files_from_metadata_struct(tensor, recorded_files)
|
||||
|
||||
# Add associated files attached to each output tensor
|
||||
for k in range(subgraph.OutputTensorMetadataLength()):
|
||||
tensor = subgraph.OutputTensorMetadata(k)
|
||||
self._get_associated_files_from_metadata_struct(tensor, recorded_files)
|
||||
|
||||
return recorded_files
|
||||
|
||||
def load_associated_files(self, associated_files):
|
||||
"""Loads associated files that to be concatenated after the model file.
|
||||
|
||||
Args:
|
||||
associated_files: list of file paths.
|
||||
|
||||
Raises:
|
||||
IOError:
|
||||
File not found.
|
||||
"""
|
||||
for af in associated_files:
|
||||
_assert_exist(af)
|
||||
self._associated_files.add(af)
|
||||
|
||||
def load_metadata_buffer(self, metadata_buf):
|
||||
"""Loads the metadata buffer (in bytearray) to be populated.
|
||||
|
||||
Args:
|
||||
metadata_buf: metadata buffer (in bytearray) to be populated.
|
||||
|
||||
Raises:
|
||||
ValueError: The metadata to be populated is empty.
|
||||
ValueError: The metadata does not have the expected flatbuffer identifier.
|
||||
ValueError: Error occurs when getting the minimum metadata parser version.
|
||||
"""
|
||||
if not metadata_buf:
|
||||
raise ValueError("The metadata to be populated is empty.")
|
||||
|
||||
_assert_metadata_buffer_identifier(metadata_buf)
|
||||
|
||||
# Gets the minimum metadata parser version of the metadata_buf.
|
||||
min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion(
|
||||
bytes(metadata_buf))
|
||||
|
||||
# Inserts the minimum metadata parser version into the metadata_buf.
|
||||
metadata = _metadata_fb.ModelMetadataT.InitFromObj(
|
||||
_metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
|
||||
metadata.minParserVersion = min_version
|
||||
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER)
|
||||
metadata_buf_with_version = b.Output()
|
||||
|
||||
self._metadata_buf = metadata_buf_with_version
|
||||
|
||||
def load_metadata_file(self, metadata_file):
|
||||
"""Loads the metadata file to be populated.
|
||||
|
||||
Args:
|
||||
metadata_file: path to the metadata file to be populated.
|
||||
|
||||
Raises:
|
||||
IOError: File not found.
|
||||
ValueError: The metadata does not have the expected flatbuffer identifier.
|
||||
"""
|
||||
_assert_exist(metadata_file)
|
||||
with open(metadata_file, "rb") as f:
|
||||
metadata_buf = f.read()
|
||||
self.load_metadata_buffer(bytearray(metadata_buf))
|
||||
|
||||
def populate(self):
|
||||
"""Populates loaded metadata and associated files into the model file."""
|
||||
self._assert_validate()
|
||||
self._populate_metadata_buffer()
|
||||
self._populate_associated_files()
|
||||
|
||||
def _assert_validate(self):
|
||||
"""Validates the metadata and associated files to be populated.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
File is recorded in the metadata, but is not going to be populated.
|
||||
File has already been packed.
|
||||
"""
|
||||
# Gets files that are recorded in metadata.
|
||||
recorded_files = self.get_recorded_associated_file_list()
|
||||
|
||||
# Gets files that have been packed to self._model_file.
|
||||
packed_files = self.get_packed_associated_file_list()
|
||||
|
||||
# Gets the file name of those associated files to be populated.
|
||||
to_be_populated_files = []
|
||||
for af in self._associated_files:
|
||||
to_be_populated_files.append(os.path.basename(af))
|
||||
|
||||
# Checks all files recorded in the metadata will be populated.
|
||||
for rf in recorded_files:
|
||||
if rf not in to_be_populated_files and rf not in packed_files:
|
||||
raise ValueError("File, '{0}', is recorded in the metadata, but has "
|
||||
"not been loaded into the populator.".format(rf))
|
||||
|
||||
for f in to_be_populated_files:
|
||||
if f in packed_files:
|
||||
raise ValueError("File, '{0}', has already been packed.".format(f))
|
||||
|
||||
if f not in recorded_files:
|
||||
warnings.warn(
|
||||
"File, '{0}', does not exsit in the metadata. But packing it to "
|
||||
"tflite model is still allowed.".format(f))
|
||||
|
||||
def _copy_archived_files(self, src_zip, dst_zip, file_list):
|
||||
"""Copy archieved files in file_list from src_zip ro dst_zip."""
|
||||
|
||||
if not zipfile.is_zipfile(src_zip):
|
||||
raise ValueError("File, '{0}', is not a zipfile.".format(src_zip))
|
||||
|
||||
with zipfile.ZipFile(src_zip,
|
||||
"r") as src_zf, zipfile.ZipFile(dst_zip,
|
||||
"a") as dst_zf:
|
||||
src_list = src_zf.namelist()
|
||||
for f in file_list:
|
||||
if f not in src_list:
|
||||
raise ValueError(
|
||||
"File, '{0}', does not exist in the zipfile, {1}.".format(
|
||||
f, src_zip))
|
||||
file_buffer = src_zf.read(f)
|
||||
dst_zf.writestr(f, file_buffer)
|
||||
|
||||
def _get_associated_files_from_metadata_struct(self, file_holder, file_list):
|
||||
for j in range(file_holder.AssociatedFilesLength()):
|
||||
file_list.append(file_holder.AssociatedFiles(j).Name().decode("utf-8"))
|
||||
|
||||
def _populate_associated_files(self):
|
||||
"""Concatenates associated files after TensorFlow Lite model file.
|
||||
|
||||
If the MetadataPopulator object is created using the method,
|
||||
with_model_file(model_file), the model file will be updated.
|
||||
"""
|
||||
# Opens up the model file in "appending" mode.
|
||||
# If self._model_file already has packed files, zipfile will concatenate
|
||||
# additional files after self._model_file. For example, suppose we have
|
||||
# self._model_file = old_tflite_file | label1.txt | label2.txt
|
||||
# Then, after triggering populate() to add label3.txt, self._model_file becomes
|
||||
# self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt
|
||||
with zipfile.ZipFile(self._model_file, "a") as zf:
|
||||
for af in self._associated_files:
|
||||
filename = os.path.basename(af)
|
||||
zf.write(af, filename)
|
||||
|
||||
def _populate_metadata_buffer(self):
|
||||
"""Populates the metadata buffer (in bytearray) into the model file.
|
||||
|
||||
Inserts metadata_buf into the metadata field of schema.Model. If the
|
||||
MetadataPopulator object is created using the method,
|
||||
with_model_file(model_file), the model file will be updated.
|
||||
|
||||
Existing metadata buffer (if applied) will be overridden by the new metadata
|
||||
buffer.
|
||||
"""
|
||||
|
||||
with open(self._model_file, "rb") as f:
|
||||
model_buf = f.read()
|
||||
|
||||
model = _schema_fb.ModelT.InitFromObj(
|
||||
_schema_fb.Model.GetRootAsModel(model_buf, 0))
|
||||
buffer_field = _schema_fb.BufferT()
|
||||
buffer_field.data = self._metadata_buf
|
||||
|
||||
is_populated = False
|
||||
if not model.metadata:
|
||||
model.metadata = []
|
||||
else:
|
||||
# Check if metadata has already been populated.
|
||||
for meta in model.metadata:
|
||||
if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME:
|
||||
is_populated = True
|
||||
model.buffers[meta.buffer] = buffer_field
|
||||
|
||||
if not is_populated:
|
||||
if not model.buffers:
|
||||
model.buffers = []
|
||||
model.buffers.append(buffer_field)
|
||||
# Creates a new metadata field.
|
||||
metadata_field = _schema_fb.MetadataT()
|
||||
metadata_field.name = self.METADATA_FIELD_NAME
|
||||
metadata_field.buffer = len(model.buffers) - 1
|
||||
model.metadata.append(metadata_field)
|
||||
|
||||
# Packs the model back into a flatbuffer binary file.
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER)
|
||||
model_buf = b.Output()
|
||||
|
||||
# Saves the updated model buffer to model file.
|
||||
# Gets files that have been packed to self._model_file.
|
||||
packed_files = self.get_packed_associated_file_list()
|
||||
if packed_files:
|
||||
# Writes the updated model buffer and associated files into a new model
|
||||
# file. Then overwrites the original model file.
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
new_file = temp.name
|
||||
with open(new_file, "wb") as f:
|
||||
f.write(model_buf)
|
||||
self._copy_archived_files(self._model_file, new_file, packed_files)
|
||||
shutil.copy(new_file, self._model_file)
|
||||
os.remove(new_file)
|
||||
else:
|
||||
with open(self._model_file, "wb") as f:
|
||||
f.write(model_buf)
|
||||
|
||||
|
||||
class _MetadataPopulatorWithBuffer(MetadataPopulator):
|
||||
"""Subclass of MetadtaPopulator that populates metadata to a model buffer.
|
||||
|
||||
This class is used to populate metadata into an in-memory model buffer. As we
|
||||
use Zip API to concatenate associated files after tflite model file, the
|
||||
populating operation is developed based on a model file. For in-memory model
|
||||
buffer, we create a tempfile to serve the populating operation. This class is
|
||||
then used to generate this tempfile, and delete the file when the
|
||||
MetadataPopulator object is deleted.
|
||||
"""
|
||||
|
||||
def __init__(self, model_buf):
|
||||
"""Constructor for _MetadataPopulatorWithBuffer.
|
||||
|
||||
Args:
|
||||
model_buf: TensorFlow Lite model buffer in bytearray.
|
||||
|
||||
Raises:
|
||||
ValueError: model_buf is empty.
|
||||
ValueError: model_buf does not have the expected flatbuffer identifier.
|
||||
"""
|
||||
if not model_buf:
|
||||
raise ValueError("model_buf cannot be empty.")
|
||||
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
model_file = temp.name
|
||||
|
||||
with open(model_file, "wb") as f:
|
||||
f.write(model_buf)
|
||||
|
||||
MetadataPopulator.__init__(self, model_file)
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor of _MetadataPopulatorWithBuffer.
|
||||
|
||||
Deletes the tempfile.
|
||||
"""
|
||||
if os.path.exists(self._model_file):
|
||||
os.remove(self._model_file)
|
||||
|
||||
|
||||
class MetadataDisplayer(object):
|
||||
"""Displays metadata and associated file info in human-readable format."""
|
||||
|
||||
def __init__(self, model_file, metadata_file, associated_file_list):
|
||||
"""Constructor for MetadataDisplayer.
|
||||
|
||||
Args:
|
||||
model_file: valid path to the model file.
|
||||
metadata_file: valid path to the metadata file.
|
||||
associated_file_list: list of associated files in the model file.
|
||||
"""
|
||||
_assert_model_file_identifier(model_file)
|
||||
_assert_metadata_file_identifier(metadata_file)
|
||||
self._model_file = model_file
|
||||
self._metadata_file = metadata_file
|
||||
self._associated_file_list = associated_file_list
|
||||
|
||||
@classmethod
|
||||
def with_model_file(cls, model_file):
|
||||
"""Creates a MetadataDisplayer object for the model file.
|
||||
|
||||
Args:
|
||||
model_file: valid path to a TensorFlow Lite model file.
|
||||
|
||||
Returns:
|
||||
MetadataDisplayer object.
|
||||
|
||||
Raises:
|
||||
IOError: File not found.
|
||||
ValueError: The model does not have metadata.
|
||||
"""
|
||||
_assert_exist(model_file)
|
||||
metadata_file = cls._save_temporary_metadata_file(model_file)
|
||||
associated_file_list = cls._parse_packed_associted_file_list(model_file)
|
||||
return cls(model_file, metadata_file, associated_file_list)
|
||||
|
||||
@classmethod
|
||||
def with_model_buffer(cls, model_buffer):
|
||||
"""Creates a MetadataDisplayer object for a file buffer.
|
||||
|
||||
Args:
|
||||
model_buffer: TensorFlow Lite model buffer in bytearray.
|
||||
|
||||
Returns:
|
||||
MetadataDisplayer object.
|
||||
"""
|
||||
if not model_buffer:
|
||||
raise ValueError("model_buffer cannot be empty.")
|
||||
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
model_file = temp.name
|
||||
|
||||
with open(model_file, "wb") as f:
|
||||
f.write(model_buffer)
|
||||
return cls.with_model_file(model_file)
|
||||
|
||||
def get_metadata_json(self):
|
||||
"""Converts the metadata into a json string."""
|
||||
opt = _pywrap_flatbuffers.IDLOptions()
|
||||
opt.strict_json = True
|
||||
parser = _pywrap_flatbuffers.Parser(opt)
|
||||
with open(_FLATC_TFLITE_METADATA_SCHEMA_FILE) as f:
|
||||
metadata_schema_content = f.read()
|
||||
with open(self._metadata_file, "rb") as f:
|
||||
metadata_file_content = f.read()
|
||||
if not parser.parse(metadata_schema_content):
|
||||
raise ValueError("Cannot parse metadata schema. Reason: " + parser.error)
|
||||
with open(self._metadata_file, "rb") as f:
|
||||
metadata_file_content = f.read()
|
||||
return _pywrap_flatbuffers.generate_text(parser, metadata_file_content)
|
||||
|
||||
def get_packed_associated_file_list(self):
|
||||
"""Returns a list of associated files that are packed in the model.
|
||||
|
||||
Returns:
|
||||
A name list of associated files.
|
||||
"""
|
||||
return copy.deepcopy(self._associated_file_list)
|
||||
|
||||
@staticmethod
|
||||
def _save_temporary_metadata_file(model_file):
|
||||
"""Saves the metadata in the model file to a temporary file.
|
||||
|
||||
Args:
|
||||
model_file: valid path to the model file.
|
||||
|
||||
Returns:
|
||||
Path to the metadata temporary file.
|
||||
|
||||
Raises:
|
||||
ValueError: The model does not have metadata.
|
||||
"""
|
||||
with open(model_file, "rb") as f:
|
||||
model_buf = f.read()
|
||||
|
||||
tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
|
||||
|
||||
# Gets metadata from the model file.
|
||||
for i in range(tflite_model.MetadataLength()):
|
||||
meta = tflite_model.Metadata(i)
|
||||
if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME:
|
||||
buffer_index = meta.Buffer()
|
||||
metadata = tflite_model.Buffers(buffer_index)
|
||||
metadata_buf = metadata.DataAsNumpy().tobytes()
|
||||
# Creates a temporary file to store the metadata.
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
metadata_file = temp.name
|
||||
# Saves the metadata into the temporary file.
|
||||
with open(metadata_file, "wb") as f:
|
||||
f.write(metadata_buf)
|
||||
return metadata_file
|
||||
|
||||
raise ValueError("The model does not have metadata.")
|
||||
|
||||
@staticmethod
|
||||
def _parse_packed_associated_file_list(model_file):
|
||||
"""Gets a list of associated files packed to the model file.
|
||||
|
||||
Args:
|
||||
model_file: valid path to the model file.
|
||||
|
||||
Returns:
|
||||
List of packed associated files.
|
||||
"""
|
||||
if not zipfile.is_zipfile(model_file):
|
||||
return []
|
||||
|
||||
with zipfile.ZipFile(model_file, "r") as zf:
|
||||
return zf.namelist()
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor of MetadataDisplayer.
|
||||
|
||||
Deletes the tempfile.
|
||||
"""
|
||||
if os.path.exists(self._metadata_file):
|
||||
os.remove(self._metadata_file)
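
A short usage sketch for the displayer above, assuming a model that already carries metadata and packed associated files; the path is a placeholder.

```python
# Sketch only: inspect the metadata packed into an existing model.
from tensorflow.lite.experimental.support.metadata import metadata as _metadata

displayer = _metadata.MetadataDisplayer.with_model_file("model_with_metadata.tflite")

# JSON rendering of the ModelMetadata flatbuffer.
print(displayer.get_metadata_json())

# Names of the associated files zipped into the model, e.g. label files.
print(displayer.get_packed_associated_file_list())
```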
|
||||
|
||||
|
||||
def _assert_exist(filename):
|
||||
"""Checks if a file exists."""
|
||||
if not os.path.exists(filename):
|
||||
raise IOError("File, '{0}', does not exist.".format(filename))
|
||||
|
||||
|
||||
def _assert_model_file_identifier(model_file):
|
||||
"""Checks if a model file has the expected TFLite schema identifier."""
|
||||
_assert_exist(model_file)
|
||||
with open(model_file, "rb") as f:
|
||||
model_buf = f.read()
|
||||
|
||||
if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0):
|
||||
raise ValueError(
|
||||
"The model provided does not have the expected identifier, and "
|
||||
"may not be a valid TFLite model.")
|
||||
|
||||
|
||||
def _assert_metadata_file_identifier(metadata_file):
|
||||
"""Checks if a metadata file has the expected Metadata schema identifier."""
|
||||
_assert_exist(metadata_file)
|
||||
with open(metadata_file, "rb") as f:
|
||||
metadata_buf = f.read()
|
||||
_assert_metadata_buffer_identifier(metadata_buf)
|
||||
|
||||
|
||||
def _assert_metadata_buffer_identifier(metadata_buf):
|
||||
"""Checks if a metadata buffer has the expected Metadata schema identifier."""
|
||||
if not _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier(
|
||||
metadata_buf, 0):
|
||||
raise ValueError(
|
||||
"The metadata buffer does not have the expected identifier, and may not"
|
||||
" be a valid TFLite Metadata.")
|
@ -1,26 +0,0 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Information about the metadata parser that this python library depends on."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class MetadataParser(object):
|
||||
"""Information about the metadata parser."""
|
||||
|
||||
# The version of the metadata parser.
|
||||
VERSION = "{LATEST_METADATA_PARSER_VERSION}"
|
@ -1,38 +0,0 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.lite.experimental.support.metadata.metadata_parser."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
|
||||
from tensorflow.lite.experimental.support.metadata import metadata_parser
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MetadataParserTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_version_wellFormedSemanticVersion(self):
|
||||
# Validates that the version is well-formed (x.y.z).
|
||||
self.assertTrue(
|
||||
re.match('[0-9]+\\.[0-9]+\\.[0-9]+',
|
||||
metadata_parser.MetadataParser.VERSION))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -1,570 +0,0 @@
|
||||
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
namespace tflite;
|
||||
|
||||
// TFLite metadata contains both human readable and machine readable information
|
||||
// about what the model does and how to use the model. It can be used as a
|
||||
// README file, which elaborates the details of the model, each input/output
|
||||
// tensor, and each associated file.
|
||||
//
|
||||
// An important use case of TFLite metadata is the TFLite codegen tool, which
|
||||
// automatically generates the model interface based on the properties of the
|
||||
// model and the tensors. The model interface provides high-level APIs to
|
||||
// interact with the model, such as preprocessing the input data and running
|
||||
// inferences.
|
||||
//
|
||||
// Entries marked with "<Codegen usage>" are used in TFLite codegen tool to
|
||||
// generate the model interface. It is recommended to fill in at least those
|
||||
// entries to boost the codegen performance.
|
||||
|
||||
// The Metadata schema is versioned by the Semantic versioning number, such as
|
||||
// MAJOR.MINOR.PATCH. It tracks the schema changes according to the rules below:
|
||||
// * Bump up the MAJOR number when making potentially backwards incompatible
|
||||
// changes. It must be incremented if the new changes break the backwards
|
||||
// compatibility. It may also include minor and patch level changes as
|
||||
// needed. The true backwards compatibility is indicated by the file
|
||||
// identifier.
|
||||
// * Bump up the MINOR number when making backwards compatible updates for
|
||||
// major features, such as supporting new content types or adding new
|
||||
// processing units.
|
||||
// * Bump up the PATCH number when making small backwards compatible changes,
|
||||
// such as adding new fields or deprecating certain fields (not deleting
|
||||
// them).
|
||||
//
|
||||
// ModelMetadata.min_parser_version indicates the minimum necessary metadata
|
||||
// parser version to fully understand all fields in a given metadata flatbuffer.
|
||||
//
|
||||
// New fields and types will have associated comments with the schema version
|
||||
// for which they were added.
|
||||
//
|
||||
// LINT.IfChange
|
||||
// Schema Semantic version: 1.0.1
|
||||
// LINT.ThenChange(//tensorflow/lite/experimental/\
|
||||
//. support/metadata/java/src/java/org/tensorflow/lite/support/metadata/\
|
||||
//. MetadataParser.java)
|
||||
|
||||
// This indicates the flatbuffer compatibility. The number will bump up when a
|
||||
// breaking change is applied to the schema, such as removing fields or adding new
|
||||
// fields to the middle of a table.
|
||||
file_identifier "M001";
|
||||
|
||||
// History:
|
||||
// 1.0.1 - Added VOCABULARY type to AssociatedFileType.
|
||||
|
||||
// File extension of any written files.
|
||||
file_extension "tflitemeta";
|
||||
|
||||
// LINT.IfChange
|
||||
enum AssociatedFileType : byte {
|
||||
UNKNOWN = 0,
|
||||
|
||||
// Files such as readme.txt.
|
||||
DESCRIPTIONS = 1,
|
||||
|
||||
// Contains labels that annotate certain axis of the tensor. For example,
|
||||
// the label file in image classification. Those labels annotate
|
||||
// the output tensor, such that each value in the output tensor is the
|
||||
// probability of that corresponding category specified by the label.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// If an output tensor has an associated file as TENSOR_AXIS_LABELS, return
|
||||
// the output as a mapping between the labels and probability in the model
|
||||
// interface.
|
||||
// If multiple files of the same type are present, the first one is used by
|
||||
// default; additional ones are to be distinguished from one another by their
|
||||
// specified locale.
|
||||
TENSOR_AXIS_LABELS = 2,
|
||||
|
||||
// Contains labels that tensor values correspond to. For example, in
|
||||
// the object detection model, one of the output tensors is the detected
|
||||
// classes. And each value in the tensor refers to the index of label in the
|
||||
// category label file.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// If an output tensor has an associated file as TENSOR_VALUE_LABELS, convert
|
||||
// the tensor values into labels, and return a list of string as the output.
|
||||
// If multiple files of the same type are present, the first one is used by
|
||||
// default; additional ones are to be distinguished from one another by their
|
||||
// specified locale.
|
||||
TENSOR_VALUE_LABELS = 3,
|
||||
|
||||
// Contains sigmoid-based score calibration parameters, formatted as CSV.
|
||||
// Lines contain for each index of an output tensor the scale, slope, offset
|
||||
// and (optional) min_score parameters to be used for sigmoid fitting (in this
|
||||
// order and in `strtof`-compatible [1] format).
|
||||
// A line may be left empty to default calibrated scores for this index to
|
||||
// default_score.
|
||||
// In summary, each line should thus contain 0, 3 or 4 comma-separated values.
|
||||
//
|
||||
// See documentation for ScoreCalibrationOptions for details.
|
||||
//
|
||||
// [1]: https://en.cppreference.com/w/c/string/byte/strtof
|
||||
TENSOR_AXIS_SCORE_CALIBRATION = 4,
|
||||
|
||||
// Contains a list of unique words (characters separated by "\n" or in lines)
|
||||
// that help to convert natural language words to embedding vectors.
|
||||
// Added in: 1.0.1
|
||||
VOCABULARY = 5,
|
||||
}
|
||||
|
||||
table AssociatedFile {
|
||||
// Name of this file. Needs to be exactly the same as the name of the actual file
|
||||
// packed into the TFLite model as a zip file.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Locates the actual file in the TFLite model.
|
||||
name:string;
|
||||
|
||||
// A description of what the file is.
|
||||
description:string;
|
||||
|
||||
// Type of the associated file. There may be special pre/post processing for
|
||||
// some types. For example in image classification, a label file of the output
|
||||
// will be used to convert object index into string.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Determines how to process the corresponding tensor.
|
||||
type:AssociatedFileType;
|
||||
|
||||
// An optional locale for this associated file (if applicable). It is
|
||||
// recommended to use an ISO 639-1 letter code (e.g. "en" for English),
|
||||
// optionally completed by a two letter region code (e.g. "en-US" for US
|
||||
// English and "en-CA" for Canadian English).
|
||||
// Leverage this in order to specify e.g. multiple label files translated in
|
||||
// different languages.
|
||||
locale:string;
|
||||
}
|
||||
|
||||
// The basic content type for all tensors.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Input feature tensors:
|
||||
// 1. Generates the method to load data from a TensorBuffer.
|
||||
// 2. Creates the preprocessing logic. The default processing pipeline is:
|
||||
// [NormalizeOp, QuantizeOp].
|
||||
// Output feature tensors:
|
||||
// 1. Generates the method to return the output data to a TensorBuffer.
|
||||
// 2. Creates the post-processing logic. The default processing pipeline is:
|
||||
// [DeQuantizeOp].
|
||||
table FeatureProperties {
|
||||
}
|
||||
|
||||
// The type of color space of an image.
|
||||
enum ColorSpaceType : byte {
|
||||
UNKNOWN = 0,
|
||||
RGB = 1,
|
||||
GRAYSCALE = 2,
|
||||
}
|
||||
|
||||
table ImageSize {
|
||||
width:uint;
|
||||
height:uint;
|
||||
}
|
||||
|
||||
// The properties for image tensors.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Input image tensors:
|
||||
// 1. Generates the method to load an image from a TensorImage.
|
||||
// 2. Creates the preprocessing logic. The default processing pipeline is:
|
||||
// [ResizeOp, NormalizeOp, QuantizeOp].
|
||||
// Output image tensors:
|
||||
// 1. Generates the method to return the output data to a TensorImage.
|
||||
// 2. Creates the post-processing logic. The default processing pipeline is:
|
||||
// [DeQuantizeOp].
|
||||
table ImageProperties {
|
||||
// The color space of the image.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Determines how to convert the color space of a given image from users.
|
||||
color_space:ColorSpaceType;
|
||||
|
||||
// Indicates the default value of image width and height if the tensor shape
|
||||
// is dynamic. For fixed-size tensor, this size will be consistent with the
|
||||
// expected size.
|
||||
default_size:ImageSize;
|
||||
}
|
||||
|
||||
// The properties for tensors representing bounding boxes.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Input image tensors: NA.
|
||||
// Output image tensors: parses the values into a data structure that represents
|
||||
// bounding boxes. For example, in the generated wrapper for Android, it returns
|
||||
// the output as android.graphics.Rect objects.
|
||||
enum BoundingBoxType : byte {
|
||||
UNKNOWN = 0,
|
||||
// Represents the bounding box by using the combination of boundaries,
|
||||
// {left, top, right, bottom}.
|
||||
// The default order is {left, top, right, bottom}. Other orders can be
|
||||
// indicated by BoundingBoxProperties.index.
|
||||
BOUNDARIES = 1,
|
||||
|
||||
// Represents the bounding box by using the upper_left corner, width and
|
||||
// height.
|
||||
// The default order is {upper_left_x, upper_left_y, width, height}. Other
|
||||
// orders can be indicated by BoundingBoxProperties.index.
|
||||
UPPER_LEFT = 2,
|
||||
|
||||
// Represents the bounding box by using the center of the box, width and
|
||||
// height. The default order is {center_x, center_y, width, height}. Other
|
||||
// orders can be indicated by BoundingBoxProperties.index.
|
||||
CENTER = 3,
|
||||
|
||||
}
|
||||
|
||||
enum CoordinateType : byte {
|
||||
// The coordinates are float values from 0 to 1.
|
||||
RATIO = 0,
|
||||
// The coordinates are integers.
|
||||
PIXEL = 1,
|
||||
}
|
||||
|
||||
table BoundingBoxProperties {
|
||||
// Denotes the order of the elements defined in each bounding box type. An
|
||||
// empty index array represents the default order of each bounding box type.
|
||||
// For example, to denote the default order of BOUNDARIES, {left, top, right,
|
||||
// bottom}, the index should be {0, 1, 2, 3}. To denote the order {left,
|
||||
// right, top, bottom}, the order should be {0, 2, 1, 3}.
|
||||
//
|
||||
// The index array can be applied to all bounding box types to adjust the
|
||||
// order of their corresponding underlying elements.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Indicates how to parse the bounding box values.
|
||||
index:[uint];
|
||||
|
||||
// <Codegen usage>:
|
||||
// Indicates how to parse the bounding box values.
|
||||
type:BoundingBoxType;
|
||||
|
||||
// <Codegen usage>:
|
||||
// Indicates how to convert the bounding box back to the original image in
|
||||
// pixels.
|
||||
coordinate_type:CoordinateType;
|
||||
}
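
To make the `index`, `type`, and `coordinate_type` fields above concrete, here is an illustrative Python sketch (not part of the schema or the codegen tool) that decodes one BOUNDARIES-type box; the image size and sample values are made up.

```python
# Illustrative only: index[k] gives the tensor position of the k-th canonical
# element {left, top, right, bottom}; RATIO coordinates are scaled to pixels.
def decode_boundaries_box(raw_box, index=(0, 1, 2, 3),
                          coordinate_type="RATIO",
                          image_width=320, image_height=320):
    left, top, right, bottom = (raw_box[index[k]] for k in range(4))
    if coordinate_type == "RATIO":
        left, right = left * image_width, right * image_width
        top, bottom = top * image_height, bottom * image_height
    return left, top, right, bottom

# Values stored in the order {left, right, top, bottom} -> index {0, 2, 1, 3}.
print(decode_boundaries_box([0.1, 0.9, 0.2, 0.8], index=(0, 2, 1, 3)))
```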
|
||||
|
||||
union ContentProperties {
|
||||
FeatureProperties,
|
||||
ImageProperties,
|
||||
BoundingBoxProperties,
|
||||
}
|
||||
|
||||
table ValueRange {
|
||||
min:int;
|
||||
max:int;
|
||||
}
|
||||
|
||||
table Content {
|
||||
// The properties that the content may have, indicating the type of the
|
||||
// Content.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Indicates how to process the tensor.
|
||||
content_properties:ContentProperties;
|
||||
|
||||
// The range of dimensions that the content corresponds to. A NULL
|
||||
// "range" indicates that the content uses up all dimensions,
|
||||
// except the batch axis if applied.
|
||||
//
|
||||
// Here are all the possible situations of how a tensor is composed.
|
||||
// Case 1: The tensor is a single object, such as an image.
|
||||
// For example, the input of an image classifier
|
||||
// (https://www.tensorflow.org/lite/models/image_classification/overview),
|
||||
// a tensor of shape [1, 224, 224, 3]. Dimensions 1 to 3 correspond to the
|
||||
// image. Since dimension 0 is a batch axis, which can be ignored,
|
||||
// "range" can be left as NULL.
|
||||
//
|
||||
// Case 2: The tensor contains multiple instances of the same object.
|
||||
// For example, the output tensor of detected bounding boxes of an object
|
||||
// detection model
|
||||
// (https://www.tensorflow.org/lite/models/object_detection/overview).
|
||||
// The tensor shape is [1, 10, 4]. Here is what the three dimensions
|
||||
// represent:
|
||||
// dimension 0: the batch axis.
|
||||
// dimension 1: the 10 objects detected with the highest confidence.
|
||||
// dimension 2: the bounding boxes of the 10 detected objects.
|
||||
// The tensor is essentially 10 bounding boxes. In this case,
|
||||
// "range" should be {min=2; max=2;}.
|
||||
// Another example is the pose estimation model
|
||||
// (https://www.tensorflow.org/lite/models/pose_estimation/overview).
|
||||
// The output tensor of heatmaps is in the shape of [1, 9, 9, 17].
|
||||
// Here is what the four dimensions represent:
|
||||
// dimension 0: the batch axis.
|
||||
// dimension 1/2: the heatmap image.
|
||||
// dimension 3: 17 body parts of a person.
|
||||
// Even though the last axis is body part, the real content of this tensor is
|
||||
// the heatmap. "range" should be [min=1; max=2].
|
||||
//
|
||||
// Case 3: The tensor contains multiple different objects. (Not supported by
|
||||
// Content at this point).
|
||||
// Sometimes a tensor may contain multiple different objects, thus different
|
||||
// contents. It is very common for regression models. For example, a model
|
||||
// to predict the fuel efficiency
|
||||
// (https://www.tensorflow.org/tutorials/keras/regression).
|
||||
// The input tensor has shape [1, 9], consisting of 9 features, such as
|
||||
// "Cylinders", "Displacement", "Weight", etc. In this case, dimension 1
|
||||
// contains 9 different contents. However, since these sub-dimension objects
|
||||
// barely need to be specifically processed, their contents are not recorded
|
||||
// in the metadata. However, the name of each dimension can be set through
|
||||
// TensorMetadata.dimension_names.
|
||||
//
|
||||
// Note that if it is not case 3, a tensor can only have one content type.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Case 1: return a processed single object of certain content type.
|
||||
// Case 2: return a list of processed objects of certain content type. The
|
||||
// generated model interface has APIs to randomly access those objects from
|
||||
// the output.
|
||||
range:ValueRange;
|
||||
}
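
A hedged sketch of what "Case 2" above means for the object detection example: with a `[1, 10, 4]` box tensor and `range = {min=2, max=2}`, dimension 2 holds the content (one bounding box) and dimension 1 enumerates the detections. The tensor here is randomly generated purely for illustration.

```python
import numpy as np

# Illustrative only: a fake detection output of shape [1, 10, 4].
boxes = np.random.rand(1, 10, 4).astype(np.float32)

for box in boxes[0]:  # boxes[0] drops the batch axis; iterates the 10 detections
    # `box` is one 4-element bounding box; its element order and coordinate
    # type are described by BoundingBoxProperties, as sketched above.
    pass
```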
|
||||
|
||||
// Parameters that are used when normalizing the tensor.
|
||||
table NormalizationOptions{
|
||||
// mean and std are normalization parameters. Tensor values are normalized
|
||||
// on a per-channel basis, by the formula
|
||||
// (x - mean) / std.
|
||||
// If there is only one value in mean or std, we'll propagate the value to
|
||||
// all channels.
|
||||
//
|
||||
// Quantized models share the same normalization parameters as their
|
||||
// corresponding float models. For example, an image input tensor may have
|
||||
// the normalization parameter of
|
||||
// mean = 127.5f and std = 127.5f.
|
||||
// The image value will be normalized from [0, 255] to [-1, 1].
|
||||
// Then, for quantized models, the image data should be further quantized
|
||||
// according to the quantization parameters. In the case of uint8, the image
|
||||
// data will be scaled back to [0, 255], while for int8, the image data will
|
||||
// be scaled to [-128, 127].
|
||||
//
|
||||
// Both the normalization parameters and quantization parameters can be
|
||||
// retrieved through the metadata extractor library.
|
||||
// TODO(b/156644598): add link for the metadata extractor library.
|
||||
|
||||
// Per-channel mean of the possible values used in normalization.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Apply normalization to input tensors accordingly.
|
||||
mean:[float];
|
||||
|
||||
// Per-channel standard dev. of the possible values used in normalization.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Apply normalization to input tensors accordingly.
|
||||
std:[float];
|
||||
}
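
The comment above combines two steps that are easy to conflate. A hedged sketch, assuming an input image in `[0, 255]`, `mean = std = 127.5`, and made-up uint8 quantization parameters (`scale = 1/127.5`, `zero_point = 128`) read from the input tensor:

```python
import numpy as np

# Float path: normalize with (x - mean) / std.
mean, std = 127.5, 127.5
image = np.random.randint(0, 256, size=(224, 224, 3)).astype(np.float32)
normalized = (image - mean) / std  # roughly in [-1, 1]

# Quantized path: the normalized values are further quantized with the
# tensor's (scale, zero_point); these example values map [-1, 1] back to [0, 255].
scale, zero_point = 1 / 127.5, 128
quantized = np.clip(np.round(normalized / scale) + zero_point, 0, 255).astype(np.uint8)
```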
|
||||
|
||||
// The different possible score transforms to apply to uncalibrated scores
|
||||
// before applying score calibration.
|
||||
enum ScoreTransformationType : byte {
|
||||
// Identity function: g(x) = x.
|
||||
IDENTITY = 0,
|
||||
// Log function: g(x) = log(x).
|
||||
LOG = 1,
|
||||
// Inverse logistic function: g(x) = log(x) - log(1-x).
|
||||
INVERSE_LOGISTIC = 2,
|
||||
}
|
||||
|
||||
// Options to perform score calibration on an output tensor through sigmoid
|
||||
// functions. One of the main purposes of score calibration is to make scores
|
||||
// across classes comparable, so that a common threshold can be used for all
|
||||
// output classes. This is meant for models producing class predictions as
|
||||
// output, e.g. image classification or detection models.
|
||||
//
|
||||
// For each index in the output tensor, this applies:
|
||||
// * `f(x) = scale / (1 + e^-(slope*g(x)+offset))` if `x > min_score` or if no
|
||||
// `min_score` has been specified,
|
||||
// * `f(x) = default_score` otherwise or if no scale, slope and offset have been
|
||||
// specified.
|
||||
// Where:
|
||||
// * scale, slope, offset and (optional) min_score are index-specific parameters
|
||||
// * g(x) is an index-independent transform among those defined in
|
||||
// ScoreTransformationType
|
||||
// * default_score is an index-independent parameter.
|
||||
// An AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION specifying the
|
||||
// index-specific parameters must be associated with the corresponding
|
||||
// TensorMetadata for score calibration to be applied.
|
||||
table ScoreCalibrationOptions {
|
||||
// The function to use for transforming the uncalibrated score before
|
||||
// applying score calibration.
|
||||
score_transformation:ScoreTransformationType;
|
||||
|
||||
// The default calibrated score to apply if the uncalibrated score is
|
||||
// below min_score or if no parameters were specified for a given index.
|
||||
default_score:float;
|
||||
}
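
A hedged sketch of the calibration formula above, assuming per-index `(scale, slope, offset, min_score)` parameters taken from one line of a TENSOR_AXIS_SCORE_CALIBRATION file and a LOG score transformation; the numbers are made up.

```python
import math

def calibrate_score(x, scale, slope, offset, min_score=None,
                    default_score=0.0, transform=math.log):
    """Applies f(x) = scale / (1 + e^-(slope * g(x) + offset)) above min_score."""
    if min_score is not None and x <= min_score:
        return default_score
    return scale / (1.0 + math.exp(-(slope * transform(x) + offset)))

# One hypothetical CSV line "1.0,0.8,-0.5,0.001" applied to a raw score.
print(calibrate_score(0.42, scale=1.0, slope=0.8, offset=-0.5, min_score=0.001))
```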
|
||||
|
||||
// Performs thresholding on output tensor values, in order to filter out
|
||||
// low-confidence results.
|
||||
table ScoreThresholdingOptions {
|
||||
// The recommended global threshold below which results are considered
|
||||
// low-confidence and should be filtered out.
|
||||
global_score_threshold:float;
|
||||
}
|
||||
|
||||
// Options that are used when processing the tensor.
|
||||
union ProcessUnitOptions {
|
||||
NormalizationOptions,
|
||||
ScoreCalibrationOptions,
|
||||
ScoreThresholdingOptions,
|
||||
}
|
||||
|
||||
// A process unit that is used to process the tensor out-of-graph.
|
||||
table ProcessUnit {
|
||||
options:ProcessUnitOptions;
|
||||
}
|
||||
|
||||
|
||||
// Statistics to describe a tensor.
|
||||
table Stats {
|
||||
// Max and min are not currently used in tflite.support codegen. They mainly
|
||||
// serve as references for users to better understand the model. They can also
|
||||
// be used to validate model pre/post processing results.
|
||||
// If there is only one value in max or min, we'll propagate the value to
|
||||
// all channels.
|
||||
|
||||
// Per-channel maximum value of the tensor.
|
||||
max:[float];
|
||||
|
||||
// Per-channel minimum value of the tensor.
|
||||
min:[float];
|
||||
}
|
||||
|
||||
// Detailed information of an input or output tensor.
|
||||
table TensorMetadata {
|
||||
// Name of the tensor.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// The name of this tensor in the generated model interface.
|
||||
name:string;
|
||||
|
||||
// A description of the tensor.
|
||||
description:string;
|
||||
|
||||
// A list of names of the dimensions in this tensor. The length of
|
||||
// dimension_names needs to match the number of dimensions in this tensor.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// The name of each dimension in the generated model interface. See "Case 2"
|
||||
// in the comments of Content.range.
|
||||
dimension_names:[string];
|
||||
|
||||
// The content that represents this tensor.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Determines how to process this tensor. See each item in ContentProperties
|
||||
// for the default process units that will be applied to the tensor.
|
||||
content:Content;
|
||||
|
||||
// The process units that are used to process the tensor out-of-graph.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Contains the parameters of the default processing pipeline for each content
|
||||
// type, such as the normalization parameters in all content types. See the
|
||||
// items under ContentProperties for the details of the default processing
|
||||
// pipeline.
|
||||
process_units:[ProcessUnit];
|
||||
|
||||
// The statistics of the tensor values.
|
||||
stats:Stats;
|
||||
|
||||
// A list of associated files of this tensor.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Contains processing parameters of this tensor, such as normalization.
|
||||
associated_files:[AssociatedFile];
|
||||
}
|
||||
|
||||
table SubGraphMetadata {
|
||||
// Name of the subgraph.
|
||||
//
|
||||
// Note that, since TFLite only supports one subgraph at this moment, the
|
||||
// Codegen tool will use the name in ModelMetadata in the generated model
|
||||
// interface.
|
||||
name:string;
|
||||
|
||||
// A description that explains details about what the subgraph does.
|
||||
description:string;
|
||||
|
||||
// Metadata of all input tensors used in this subgraph. It matches exactly
|
||||
// the input tensors specified by `SubGraph.inputs` in the TFLite
|
||||
// schema.fbs file[2]. The number of `TensorMetadata` in the array should
|
||||
// equal to the number of indices in `SubGraph.inputs`.
|
||||
//
|
||||
// [2]: tensorflow/lite/schema/schema.fbs
|
||||
// <Codegen usage>:
|
||||
// Determines how to process the inputs.
|
||||
input_tensor_metadata:[TensorMetadata];
|
||||
|
||||
// Metadata of all output tensors used in this subgraph. It matches exactly
|
||||
// the output tensors specified by `SubGraph.outputs` in the TFLite
|
||||
// schema.fbs file[2]. The number of `TensorMetadata` in the array should
|
||||
// equal to the number of indices in `SubGraph.outputs`.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Determines how to process the outputs.
|
||||
output_tensor_metadata:[TensorMetadata];
|
||||
|
||||
// A list of associated files of this subgraph.
|
||||
associated_files:[AssociatedFile];
|
||||
}
|
||||
|
||||
table ModelMetadata {
|
||||
// Name of the model.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// The name of the model in the generated model interface.
|
||||
name:string;
|
||||
|
||||
// Model description in schema.
|
||||
description:string;
|
||||
|
||||
// Version of the model specified by the model creators.
|
||||
version:string;
|
||||
|
||||
// Note that the minimum required TFLite runtime version that the model is
|
||||
// compatible with, has already been added as a metadata entry in tflite
|
||||
// schema. We'll decide later if we want to move it here, and keep it with
|
||||
// other metadata entries.
|
||||
|
||||
// Metadata of all the subgraphs of the model. The 0th is assumed to be the
|
||||
// main subgraph.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Determines how to process the inputs and outputs.
|
||||
subgraph_metadata:[SubGraphMetadata];
|
||||
|
||||
// The person who creates this model.
|
||||
author:string;
|
||||
|
||||
// Licenses that may apply to this model.
|
||||
license:string;
|
||||
|
||||
// A list of associated files of this model.
|
||||
associated_files:[AssociatedFile];
|
||||
|
||||
// The minimum metadata parser version that can fully understand the fields in
|
||||
// the metadata flatbuffer. The version is effectively the largest version
|
||||
// number among the versions of all the fields populated and the smallest
|
||||
// compatible version indicated by the file identifier.
|
||||
//
|
||||
// This field is automatically populated by the MetadataPopulator when
|
||||
// the metadata is populated into a TFLite model.
|
||||
min_parser_version:string;
|
||||
}
|
||||
// LINT.ThenChange(//tensorflow/lite/experimental/\
|
||||
// support/metadata/cc/metadata_version.cc)
|
||||
|
||||
root_type ModelMetadata;
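
A hedged sketch of how a consumer might check `min_parser_version` against the parser version exposed by `metadata_parser.MetadataParser.VERSION` (shown earlier in this change), assuming plain `x.y.z` semantic versions; the version strings below are examples only.

```python
def parser_supports(model_min_version, parser_version):
    """True if the installed metadata parser is new enough for the model."""
    as_tuple = lambda v: tuple(int(part) for part in v.split("."))
    return as_tuple(parser_version) >= as_tuple(model_min_version)

print(parser_supports("1.0.1", "1.0.0"))  # False: the parser is too old.
print(parser_supports("1.0.0", "1.2.0"))  # True
```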
|
@ -1,484 +0,0 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.lite.experimental.support.metadata.metadata."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import six
|
||||
|
||||
from flatbuffers.python import flatbuffers
|
||||
from tensorflow.lite.experimental.support.metadata import metadata as _metadata
|
||||
from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb
|
||||
from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MetadataTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(MetadataTest, self).setUp()
|
||||
self._invalid_model_buf = None
|
||||
self._invalid_file = "not_existed_file"
|
||||
self._empty_model_buf = self._create_empty_model_buf()
|
||||
self._empty_model_file = self.create_tempfile().full_path
|
||||
with open(self._empty_model_file, "wb") as f:
|
||||
f.write(self._empty_model_buf)
|
||||
self._model_file = self._create_model_file_with_metadata_and_buf_fields()
|
||||
self._metadata_file = self._create_metadata_file()
|
||||
self._metadata_file_with_version = self._create_metadata_file_with_version(
|
||||
self._metadata_file, "1.0.0")
|
||||
self._file1 = self.create_tempfile("file1").full_path
|
||||
self._file2 = self.create_tempfile("file2").full_path
|
||||
self._file3 = self.create_tempfile("file3").full_path
|
||||
|
||||
def _create_empty_model_buf(self):
|
||||
model = _schema_fb.ModelT()
|
||||
model_builder = flatbuffers.Builder(0)
|
||||
model_builder.Finish(
|
||||
model.Pack(model_builder),
|
||||
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
|
||||
return model_builder.Output()
|
||||
|
||||
def _create_model_file_with_metadata_and_buf_fields(self):
|
||||
metadata_field = _schema_fb.MetadataT()
|
||||
metadata_field.name = "meta"
|
||||
buffer_field = _schema_fb.BufferT()
|
||||
model = _schema_fb.ModelT()
|
||||
model.metadata = [metadata_field, metadata_field]
|
||||
model.buffers = [buffer_field, buffer_field, buffer_field]
|
||||
model_builder = flatbuffers.Builder(0)
|
||||
model_builder.Finish(
|
||||
model.Pack(model_builder),
|
||||
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
|
||||
|
||||
model_file = self.create_tempfile().full_path
|
||||
with open(mnodel_file, "wb") as f:
|
||||
f.write(model_builder.Output())
|
||||
|
||||
return model_file
|
||||
|
||||
def _create_metadata_file(self):
|
||||
associated_file1 = _metadata_fb.AssociatedFileT()
|
||||
associated_file1.name = b"file1"
|
||||
associated_file2 = _metadata_fb.AssociatedFileT()
|
||||
associated_file2.name = b"file2"
|
||||
self.expected_recorded_files = [
|
||||
six.ensure_str(associated_file1.name),
|
||||
six.ensure_str(associated_file2.name)
|
||||
]
|
||||
|
||||
output_meta = _metadata_fb.TensorMetadataT()
|
||||
output_meta.associatedFiles = [associated_file2]
|
||||
subgraph = _metadata_fb.SubGraphMetadataT()
|
||||
subgraph.outputTensorMetadata = [output_meta]
|
||||
|
||||
model_meta = _metadata_fb.ModelMetadataT()
|
||||
model_meta.name = "Mobilenet_quantized"
|
||||
model_meta.associatedFiles = [associated_file1]
|
||||
model_meta.subgraphMetadata = [subgraph]
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(
|
||||
model_meta.Pack(b),
|
||||
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
|
||||
metadata_file = self.create_tempfile().full_path
|
||||
with open(metadata_file, "wb") as f:
|
||||
f.write(b.Output())
|
||||
return metadata_file
|
||||
|
||||
def _create_model_buffer_with_wrong_identifier(self):
|
||||
wrong_identifier = b"widn"
|
||||
model = _schema_fb.ModelT()
|
||||
model_builder = flatbuffers.Builder(0)
|
||||
model_builder.Finish(model.Pack(model_builder), wrong_identifier)
|
||||
return model_builder.Output()
|
||||
|
||||
def _create_metadata_buffer_with_wrong_identifier(self):
|
||||
# Creates a metadata with wrong identifier
|
||||
wrong_identifier = b"widn"
|
||||
metadata = _metadata_fb.ModelMetadataT()
|
||||
metadata_builder = flatbuffers.Builder(0)
|
||||
metadata_builder.Finish(metadata.Pack(metadata_builder), wrong_identifier)
|
||||
return metadata_builder.Output()
|
||||
|
||||
def _populate_metadata_with_identifier(self, model_buf, metadata_buf,
|
||||
identifier):
|
||||
# For testing purposes only. MetadataPopulator cannot populate metadata with
|
||||
# wrong identifiers.
|
||||
model = _schema_fb.ModelT.InitFromObj(
|
||||
_schema_fb.Model.GetRootAsModel(model_buf, 0))
|
||||
buffer_field = _schema_fb.BufferT()
|
||||
buffer_field.data = metadata_buf
|
||||
model.buffers = [buffer_field]
|
||||
# Creates a new metadata field.
|
||||
metadata_field = _schema_fb.MetadataT()
|
||||
metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME
|
||||
metadata_field.buffer = len(model.buffers) - 1
|
||||
model.metadata = [metadata_field]
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(model.Pack(b), identifier)
|
||||
return b.Output()
|
||||
|
||||
def _create_metadata_file_with_version(self, metadata_file, min_version):
|
||||
# Creates a new metadata file with the specified min_version for testing
|
||||
# purposes.
|
||||
with open(metadata_file, "rb") as f:
|
||||
metadata_buf = bytearray(f.read())
|
||||
|
||||
metadata = _metadata_fb.ModelMetadataT.InitFromObj(
|
||||
_metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
|
||||
metadata.minParserVersion = min_version
|
||||
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(
|
||||
metadata.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
|
||||
metadata_file_with_version = self.create_tempfile().full_path
|
||||
with open(metadata_file_with_version, "wb") as f:
|
||||
f.write(b.Output())
|
||||
return metadata_file_with_version
|
||||
|
||||
|
||||
class MetadataPopulatorTest(MetadataTest):
|
||||
|
||||
def testToValidModelFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_file(
|
||||
self._empty_model_file)
|
||||
self.assertIsInstance(populator, _metadata.MetadataPopulator)
|
||||
|
||||
def testToInvalidModelFile(self):
|
||||
with self.assertRaises(IOError) as error:
|
||||
_metadata.MetadataPopulator.with_model_file(self._invalid_file)
|
||||
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
|
||||
str(error.exception))
|
||||
|
||||
def testToValidModelBuffer(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
self.assertIsInstance(populator, _metadata.MetadataPopulator)
|
||||
|
||||
def testToInvalidModelBuffer(self):
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf)
|
||||
self.assertEqual("model_buf cannot be empty.", str(error.exception))
|
||||
|
||||
def testToModelBufferWithWrongIdentifier(self):
|
||||
model_buf = self._create_model_buffer_with_wrong_identifier()
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataPopulator.with_model_buffer(model_buf)
|
||||
self.assertEqual(
|
||||
"The model provided does not have the expected identifier, and "
|
||||
"may not be a valid TFLite model.", str(error.exception))
|
||||
|
||||
def testSinglePopulateAssociatedFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
populator.load_associated_files([self._file1])
|
||||
populator.populate()
|
||||
|
||||
packed_files = populator.get_packed_associated_file_list()
|
||||
expected_packed_files = [os.path.basename(self._file1)]
|
||||
self.assertEqual(set(packed_files), set(expected_packed_files))
|
||||
|
||||
def testRepeatedPopulateAssociatedFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_file(
|
||||
self._empty_model_file)
|
||||
populator.load_associated_files([self._file1, self._file2])
|
||||
# Loads file2 multiple times.
|
||||
populator.load_associated_files([self._file2])
|
||||
populator.populate()
|
||||
|
||||
packed_files = populator.get_packed_associated_file_list()
|
||||
expected_packed_files = [
|
||||
os.path.basename(self._file1),
|
||||
os.path.basename(self._file2)
|
||||
]
|
||||
self.assertEqual(len(packed_files), 2)
|
||||
self.assertEqual(set(packed_files), set(expected_packed_files))
|
||||
|
||||
# Check if the model buffer read from file is the same as that read from
|
||||
# get_model_buffer().
|
||||
with open(self._empty_model_file, "rb") as f:
|
||||
model_buf_from_file = f.read()
|
||||
model_buf_from_getter = populator.get_model_buffer()
|
||||
self.assertEqual(model_buf_from_file, model_buf_from_getter)
|
||||
|
||||
def testPopulateInvalidAssociatedFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
with self.assertRaises(IOError) as error:
|
||||
populator.load_associated_files([self._invalid_file])
|
||||
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
|
||||
str(error.exception))
|
||||
|
||||
def testPopulatePackedAssociatedFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
populator.load_associated_files([self._file1])
|
||||
populator.populate()
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.load_associated_files([self._file1])
|
||||
populator.populate()
|
||||
self.assertEqual(
|
||||
"File, '{0}', has already been packed.".format(
|
||||
os.path.basename(self._file1)), str(error.exception))
|
||||
|
||||
def testGetPackedAssociatedFileList(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
packed_files = populator.get_packed_associated_file_list()
|
||||
self.assertEqual(packed_files, [])
|
||||
|
||||
def testPopulateMetadataFileToEmptyModelFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_file(
|
||||
self._empty_model_file)
|
||||
populator.load_metadata_file(self._metadata_file)
|
||||
populator.load_associated_files([self._file1, self._file2])
|
||||
populator.populate()
|
||||
|
||||
with open(self._empty_model_file, "rb") as f:
|
||||
model_buf_from_file = f.read()
|
||||
model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
|
||||
metadata_field = model.Metadata(0)
|
||||
self.assertEqual(
|
||||
six.ensure_str(metadata_field.Name()),
|
||||
six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
|
||||
|
||||
buffer_index = metadata_field.Buffer()
|
||||
buffer_data = model.Buffers(buffer_index)
|
||||
metadata_buf_np = buffer_data.DataAsNumpy()
|
||||
metadata_buf = metadata_buf_np.tobytes()
|
||||
with open(self._metadata_file_with_version, "rb") as f:
|
||||
expected_metadata_buf = bytearray(f.read())
|
||||
self.assertEqual(metadata_buf, expected_metadata_buf)
|
||||
|
||||
recorded_files = populator.get_recorded_associated_file_list()
|
||||
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
|
||||
|
||||
# Up to now, we've verified the correctness of the model buffer read from the
|
||||
# file. Then we'll test if get_model_buffer() gives the same model buffer.
|
||||
model_buf_from_getter = populator.get_model_buffer()
|
||||
self.assertEqual(model_buf_from_file, model_buf_from_getter)
|
||||
|
||||
def testPopulateMetadataFileWithoutAssociatedFiles(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_file(
|
||||
self._empty_model_file)
|
||||
populator.load_metadata_file(self._metadata_file)
|
||||
populator.load_associated_files([self._file1])
|
||||
# self._file2 is also supposed to be populated, because it is recorded in the metadata.
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.populate()
|
||||
self.assertEqual(("File, '{0}', is recorded in the metadata, but has "
|
||||
"not been loaded into the populator.").format(
|
||||
os.path.basename(self._file2)), str(error.exception))
|
||||
|
||||
def testPopulateMetadataBufferWithWrongIdentifier(self):
|
||||
metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
|
||||
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.load_metadata_buffer(metadata_buf)
|
||||
self.assertEqual(
|
||||
"The metadata buffer does not have the expected identifier, and may not"
|
||||
" be a valid TFLite Metadata.", str(error.exception))
|
||||
|
||||
def _assert_golden_metadata(self, model_file):
|
||||
with open(model_file, "rb") as f:
|
||||
model_buf_from_file = f.read()
|
||||
model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
|
||||
# There are two elements in model.Metadata array before the population.
|
||||
# Metadata should be packed to the third element in the array.
|
||||
metadata_field = model.Metadata(2)
|
||||
self.assertEqual(
|
||||
six.ensure_str(metadata_field.Name()),
|
||||
six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
|
||||
|
||||
buffer_index = metadata_field.Buffer()
|
||||
buffer_data = model.Buffers(buffer_index)
|
||||
metadata_buf_np = buffer_data.DataAsNumpy()
|
||||
metadata_buf = metadata_buf_np.tobytes()
|
||||
with open(self._metadata_file_with_version, "rb") as f:
|
||||
expected_metadata_buf = bytearray(f.read())
|
||||
self.assertEqual(metadata_buf, expected_metadata_buf)
|
||||
|
||||
def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self):
|
||||
# First, creates a dummy metadata. Populates it and the associated files
|
||||
# into the model.
|
||||
model_meta = _metadata_fb.ModelMetadataT()
|
||||
model_meta.name = "Mobilenet_quantized"
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(
|
||||
model_meta.Pack(b),
|
||||
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
metadata_buf = b.Output()
|
||||
|
||||
populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator1.load_metadata_buffer(metadata_buf)
|
||||
populator1.load_associated_files([self._file1, self._file2])
|
||||
populator1.populate()
|
||||
|
||||
# Then, populates the metadata again.
|
||||
populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator2.load_metadata_file(self._metadata_file)
|
||||
populator2.populate()
|
||||
|
||||
# Tests if the metadata is populated correctly.
|
||||
self._assert_golden_metadata(self._model_file)
|
||||
|
||||
def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator.load_metadata_file(self._metadata_file)
|
||||
populator.load_associated_files([self._file1, self._file2])
|
||||
populator.populate()
|
||||
|
||||
# Tests if the metadata is populated correctly.
|
||||
self._assert_golden_metadata(self._model_file)
|
||||
|
||||
recorded_files = populator.get_recorded_associated_file_list()
|
||||
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
|
||||
|
||||
# Up to now, we've verified the correctness of the model buffer read from the
|
||||
# file. Then we'll test if get_model_buffer() gives the same model buffer.
|
||||
with open(self._model_file, "rb") as f:
|
||||
model_buf_from_file = f.read()
|
||||
model_buf_from_getter = populator.get_model_buffer()
|
||||
self.assertEqual(model_buf_from_file, model_buf_from_getter)
|
||||
|
||||
def testPopulateInvalidMetadataFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
with self.assertRaises(IOError) as error:
|
||||
populator.load_metadata_file(self._invalid_file)
|
||||
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
|
||||
str(error.exception))
|
||||
|
||||
def testPopulateInvalidMetadataBuffer(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.load_metadata_buffer([])
|
||||
self.assertEqual("The metadata to be populated is empty.",
|
||||
str(error.exception))
|
||||
|
||||
def testGetModelBufferBeforePopulatingData(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
model_buf = populator.get_model_buffer()
|
||||
expected_model_buf = self._empty_model_buf
|
||||
self.assertEqual(model_buf, expected_model_buf)
|
||||
|
||||
|
||||
class MetadataDisplayerTest(MetadataTest):
|
||||
|
||||
def setUp(self):
|
||||
super(MetadataDisplayerTest, self).setUp()
|
||||
self._model_file = self._create_model_with_metadata_and_associated_files()
|
||||
|
||||
def _create_model_with_metadata_and_associated_files(self):
|
||||
model_buf = self._create_empty_model_buf()
|
||||
model_file = self.create_tempfile().full_path
|
||||
with open(model_file, "wb") as f:
|
||||
f.write(model_buf)
|
||||
|
||||
populator = _metadata.MetadataPopulator.with_model_file(model_file)
|
||||
populator.load_metadata_file(self._metadata_file)
|
||||
populator.load_associated_files([self._file1, self._file2])
|
||||
populator.populate()
|
||||
return model_file
|
||||
|
||||
def test_load_model_buffer_metadataBufferWithWrongIdentifier_throwsException(
|
||||
self):
|
||||
model_buf = self._create_model_buffer_with_wrong_identifier()
|
||||
metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
|
||||
model_buf = self._populate_metadata_with_identifier(
|
||||
model_buf, metadata_buf,
|
||||
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_buffer(model_buf)
|
||||
self.assertEqual(
|
||||
"The metadata buffer does not have the expected identifier, and may not"
|
||||
" be a valid TFLite Metadata.", str(error.exception))
|
||||
|
||||
def test_load_model_buffer_modelBufferWithWrongIdentifier_throwsException(
|
||||
self):
|
||||
model_buf = self._create_model_buffer_with_wrong_identifier()
|
||||
metadata_file = self._create_metadata_file()
|
||||
wrong_identifier = b"widn"
|
||||
with open(metadata_file, "rb") as f:
|
||||
metadata_buf = bytearray(f.read())
|
||||
model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf,
|
||||
wrong_identifier)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_buffer(model_buf)
|
||||
self.assertEqual(
|
||||
"The model provided does not have the expected identifier, and "
|
||||
"may not be a valid TFLite model.", str(error.exception))
|
||||
|
||||
def test_load_model_file_invalidModelFile_throwsException(self):
|
||||
with self.assertRaises(IOError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_file(self._invalid_file)
|
||||
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
|
||||
str(error.exception))
|
||||
|
||||
def test_load_model_file_modelWithoutMetadata_throwsException(self):
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_file(self._empty_model_file)
|
||||
self.assertEqual("The model does not have metadata.", str(error.exception))
|
||||
|
||||
def test_load_model_file_modelWithMetadata(self):
|
||||
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
|
||||
self.assertIsInstance(displayer, _metadata.MetadataDisplayer)
|
||||
|
||||
def test_load_model_buffer_modelWithOutMetadata_throwsException(self):
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_buffer(
|
||||
self._create_empty_model_buf())
|
||||
self.assertEqual("The model does not have metadata.", str(error.exception))
|
||||
|
||||
def test_load_model_buffer_modelWithMetadata(self):
|
||||
displayer = _metadata.MetadataDisplayer.with_model_buffer(
|
||||
open(self._model_file, "rb").read())
|
||||
self.assertIsInstance(displayer, _metadata.MetadataDisplayer)
|
||||
|
||||
def test_get_metadata_json_modelWithMetadata(self):
|
||||
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
|
||||
actual = displayer.get_metadata_json()
|
||||
|
||||
# Verifies the generated json file.
|
||||
golden_json_file_path = resource_loader.get_path_to_datafile(
|
||||
"testdata/golden_json.json")
|
||||
with open(golden_json_file_path, "r") as f:
|
||||
expected = f.read()
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_get_packed_associated_file_list_modelWithMetadata(self):
|
||||
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
|
||||
packed_files = displayer.get_packed_associated_file_list()
|
||||
|
||||
expected_packed_files = [
|
||||
os.path.basename(self._file1),
|
||||
os.path.basename(self._file2)
|
||||
]
|
||||
self.assertEqual(len(packed_files), 2)
|
||||
self.assertEqual(set(packed_files), set(expected_packed_files))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -1,22 +0,0 @@
|
||||
{
|
||||
"name": "Mobilenet_quantized",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "file2"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "file1"
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|