Move metadata, codegen and java lib into the new repo.

PiperOrigin-RevId: 319147352
Change-Id: I79ab15ccebe9d50c62952c535746c6639883fc3a
This commit is contained in:
Xunkai Zhang 2020-06-30 19:42:37 -07:00 committed by TensorFlower Gardener
parent 3be438aca2
commit e9695a20ee
79 changed files with 5 additions and 10546 deletions

View File

@ -0,0 +1,5 @@
# TensorFlow Lite Support
The TensorFlow Lite Support project has been migrated to its own repo. Please
checkout [TFLite Support](https://github.com/tensorflow/tflite-support) for the
latest updates.

View File

@ -1,87 +0,0 @@
# The tools for generating wrapper classes for a TFLite model with metadata.
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "utils",
srcs = [
"utils.cc",
],
hdrs = [
"utils.h",
],
deps = [
],
)
cc_library(
name = "code_generator",
srcs = [
"code_generator.cc",
],
hdrs = [
"code_generator.h",
],
deps = [
":utils",
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
],
)
cc_library(
name = "metadata_helper",
srcs = [
"metadata_helper.cc",
],
hdrs = [
"metadata_helper.h",
],
deps = [
":utils",
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
"//tensorflow/lite/schema:schema_fbs",
],
)
cc_library(
name = "android_java_generator",
srcs = [
"android_java_generator.cc",
],
hdrs = [
"android_java_generator.h",
],
deps = [
":code_generator",
":metadata_helper",
":utils",
"//tensorflow/core/platform:logging",
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
"//tensorflow/lite/schema:schema_fbs",
],
)
cc_test(
name = "code_generator_test",
size = "small",
srcs = ["code_generator_test.cc"],
data = ["//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs"],
deps = [
":code_generator",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "utils_test",
srcs = ["utils_test.cc"],
deps = [
":utils",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -1,13 +0,0 @@
# TensorFlow Lite Android Wrapper Code Generator
For TensorFlow Lite model enhanced with [metadata](https://www.tensorflow.org/lite/convert/metadata.md),
developers can use the TensorFlow Lite Android wrapper code generator to create
platform specific wrapper code. The wrapper code removes the need to interact
directly with `ByteBuffer`. Instead, developers can interact with the TensorFlow
Lite model with typed objects such as `Bitmap` and `Rect`.
The usefulness of the code generator depend on the completeness of the
TensorFlow Lite model's metadata entry. Refer to the `<Codegen usage>` section
under relevant fields in
[metadata_schema.fbs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs),
to see how the codegen tool parses each field.

View File

@ -1,116 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace support {
namespace codegen {
namespace details_android_java {
/// The intermediate data structure for generating code from TensorMetadata.
/// Should only be used as const reference when created.
struct TensorInfo {
std::string name;
std::string upper_camel_name;
std::string content_type;
std::string wrapper_type;
std::string processor_type;
bool is_input;
/// Optional. Set to -1 if not applicable.
int normalization_unit;
/// Optional. Set to -1 if associated_axis_label is empty.
int associated_axis_label_index;
/// Optional. Set to -1 if associated_value_label is empty.
int associated_value_label_index;
};
/// The intermediate data structure for generating code from ModelMetadata.
/// Should only be used as const reference when created.
struct ModelInfo {
std::string package_name;
std::string model_asset_path;
std::string model_class_name;
std::string model_versioned_name;
std::vector<TensorInfo> inputs;
std::vector<TensorInfo> outputs;
// Extra helper fields. For models with inputs "a", "b" and outputs "x", "y":
std::string input_type_param_list;
// e.g. "TensorImage a, TensorBuffer b"
std::string inputs_list;
// e.g. "a, b"
std::string postprocessor_type_param_list;
// e.g. "ImageProcessor xPostprocessor, TensorProcessor yPostprocessor"
std::string postprocessors_list;
// e.g. "xPostprocessor, yPostprocessor"
};
} // namespace details_android_java
constexpr char JAVA_EXT[] = ".java";
/// Generates Android supporting codes and modules (in Java) based on TFLite
/// metadata.
class AndroidJavaGenerator : public CodeGenerator {
public:
/// Creates an AndroidJavaGenerator.
/// Args:
/// - module_root: The root of destination Java module.
explicit AndroidJavaGenerator(const std::string& module_root);
/// Generates files. Returns the file paths and contents.
/// Args:
/// - model: The TFLite model with Metadata filled.
/// - package_name: The name of the Java package which generated classes
/// belong to.
/// - model_class_name: A readable name of the generated wrapper class, such
/// as "ImageClassifier", "MobileNetV2" or "MyModel".
/// - model_asset_path: The relevant path to the model file in the asset.
// TODO(b/141225157): Automatically generate model_class_name.
GenerationResult Generate(const Model* model, const std::string& package_name,
const std::string& model_class_name,
const std::string& model_asset_path);
/// Generates files and returns the file paths and contents.
/// It's mostly identical with the previous one, but the model here is
/// provided as binary flatbuffer content without parsing.
GenerationResult Generate(const char* model_storage,
const std::string& package_name,
const std::string& model_class_name,
const std::string& model_asset_path);
std::string GetErrorMessage();
private:
const std::string module_root_;
ErrorReporter err_;
};
} // namespace codegen
} // namespace support
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_

View File

@ -1,179 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
#include <cctype>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace support {
namespace codegen {
namespace {
void ResolveConflictedNamesByAddingIndex(std::vector<std::string>* names_ptr) {
auto& names = *names_ptr;
std::unordered_map<std::string, int> indexes;
std::unordered_map<std::string, int> first_appearance;
for (int i = 0; i < names.size(); i++) {
if (indexes.find(names[i]) == indexes.end()) {
indexes[names[i]] = 1;
first_appearance[names[i]] = i;
} else {
indexes[names[i]] += 1;
names[i].append(std::to_string(indexes[names[i]]));
}
}
for (const auto& it : first_appearance) {
const auto& name = it.first;
const auto i = it.second;
if (indexes[name] > 1) {
names[i].append("1");
}
}
}
} // namespace
CodeGenerator::CodeGenerator() {}
bool CodeGenerator::VerifyMetadata(const ModelMetadata* metadata,
ErrorReporter* err) {
if (metadata == nullptr) {
err->Error("Loading nullptr is not allowed");
return false;
}
if (metadata->subgraph_metadata()->size() != 1) {
err->Error("Only exact 1 subgraph is supported");
return false;
}
return true;
}
std::pair<std::vector<std::string>, std::vector<std::string>>
CodeGenerator::NameInputsAndOutputs(const TensorMetadataList* inputs,
const TensorMetadataList* outputs) {
std::vector<std::string> input_names;
std::vector<std::string> output_names;
if (inputs != nullptr) {
input_names.reserve(inputs->size());
for (const auto* tensor : *inputs) {
input_names.push_back(NameTensor(*tensor, "input"));
}
}
if (outputs != nullptr) {
output_names.reserve(outputs->size());
for (const auto* tensor : *outputs) {
output_names.push_back(NameTensor(*tensor, "output"));
}
}
// Solve conflict
ResolveConflictedInputAndOutputNames(&input_names, &output_names);
return std::make_pair(input_names, output_names);
}
std::string CodeGenerator::ConvertToValidName(const std::string& name) {
// lowercase all
std::string result = name;
for (int i = 0; i < result.size(); i++) {
result[i] = std::tolower(result[i]);
}
// replace all non-alpha or non-numeric with underscores, except underscore
// itself
for (int i = 0; i < result.size(); i++) {
if (result[i] != '_' && !std::isalnum(result[i])) {
result[i] = '_';
}
}
// remove leading underscores
int leading_underscores = 0;
while (leading_underscores < result.size() &&
result[leading_underscores] == '_') {
leading_underscores++;
}
result.erase(0, leading_underscores);
if (result.empty()) {
return "";
}
// first char should be alpha
if (std::isalpha(result[0])) {
return result;
}
return "tensor_" + result;
}
std::string CodeGenerator::NameTensor(const TensorMetadata& tensor,
const std::string& default_name) {
if (tensor.name() != nullptr && tensor.name()->size() > 0) {
// TODO(b/141225157) Validate tensor name. It should be in lower case.
auto suggested_name = ConvertToValidName(tensor.name()->str());
if (!suggested_name.empty()) {
return suggested_name;
}
}
auto* content = tensor.content();
if (content == nullptr || content->content_properties() == nullptr) {
return default_name;
}
switch (content->content_properties_type()) {
case ContentProperties_ImageProperties:
return "image";
case ContentProperties_FeatureProperties:
return "feature";
default:
return default_name;
}
}
void CodeGenerator::ResolveConflictedInputAndOutputNames(
std::vector<std::string>* inputs, std::vector<std::string>* outputs) {
std::unordered_set<std::string> io_conflict;
auto& input_names = *inputs;
auto& output_names = *outputs;
for (const auto& input : input_names) {
if (io_conflict.find(input) != io_conflict.end()) {
continue;
}
for (const auto& output : output_names) {
if (input == output) {
io_conflict.insert(input);
break;
}
}
}
for (int i = 0; i < input_names.size(); i++) {
if (io_conflict.find(input_names[i]) != io_conflict.end()) {
input_names[i] = "input_" + input_names[i];
}
}
for (int i = 0; i < output_names.size(); i++) {
if (io_conflict.find(output_names[i]) != io_conflict.end()) {
output_names[i] = "output_" + output_names[i];
}
}
// 2. Second, add index if input[i] == input[j]
ResolveConflictedNamesByAddingIndex(&input_names);
ResolveConflictedNamesByAddingIndex(&output_names);
}
} // namespace codegen
} // namespace support
} // namespace tflite

View File

@ -1,80 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace support {
namespace codegen {
struct GenerationResult {
struct File {
std::string path;
std::string content;
};
std::vector<File> files;
};
/// Defines language-independent codegen strategies, like class naming, .etc.
/// Should not be used directly.
class CodeGenerator {
public:
CodeGenerator();
using TensorMetadataList =
typename flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>;
virtual ~CodeGenerator() {}
// Strategies.
/// Names all the IO tensors. It's useful when they don't have names, or the
/// names have conflicts. We have to name every tensor for code generation.
// TODO(b/141225157): Add reserved keywords check.
static std::pair<std::vector<std::string>, std::vector<std::string>>
NameInputsAndOutputs(const TensorMetadataList* inputs,
const TensorMetadataList* outputs);
/// Loads a metadata for code generation.
/// Returns false if the metadata is not good for generation.
static bool VerifyMetadata(const ModelMetadata* metadata, ErrorReporter* err);
protected:
/// Converts a name into a valid form. Rules:
/// - lower all letters.
/// - replace all non alphabet nor numeric characters with underscores.
/// - remove prefix underscores.
/// - add prefix if the leading character is a number.
/// Returns empty string if not possible.
static std::string ConvertToValidName(const std::string& name);
static std::string NameTensor(const TensorMetadata& tensor,
const std::string& default_name);
static void ResolveConflictedInputAndOutputNames(
std::vector<std::string>* input, std::vector<std::string>* output);
};
} // namespace codegen
} // namespace support
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_

View File

@ -1,126 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace tflite {
namespace support {
namespace codegen {
namespace {
using ::testing::ElementsAreArray;
class CodeGeneratorTest : public ::testing::Test {
public:
class TestingCodeGenerator : public CodeGenerator {
public:
explicit TestingCodeGenerator() : CodeGenerator() {}
// Make tested method public.
static std::string ConvertToValidName(const std::string& name) {
return CodeGenerator::ConvertToValidName(name);
}
static void ResolveConflictedInputAndOutputNames(
std::vector<std::string>* input, std::vector<std::string>* output) {
CodeGenerator::ResolveConflictedInputAndOutputNames(input, output);
}
};
};
TEST_F(CodeGeneratorTest, UpperCasesShouldLower) {
EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("AlphaBetCOOL"),
"alphabetcool");
}
TEST_F(CodeGeneratorTest, NonAlphaNumShouldReplace) {
EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("A+=B C\t"), "a__b_c_");
}
TEST_F(CodeGeneratorTest, NoLeadingUnderscore) {
EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("+KAI Z"), "kai_z");
}
TEST_F(CodeGeneratorTest, NoLeadingNumbers) {
EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("3000 Cool Tensors"),
"tensor_3000_cool_tensors");
}
TEST_F(CodeGeneratorTest, TestSimpleIONames) {
std::vector<std::string> inputs = {"image"};
std::vector<std::string> outputs = {"output"};
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
EXPECT_THAT(inputs, ElementsAreArray({"image"}));
EXPECT_THAT(outputs, ElementsAreArray({"output"}));
}
TEST_F(CodeGeneratorTest, TestIOConflict) {
std::vector<std::string> inputs = {"image"};
std::vector<std::string> outputs = {"image"};
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
EXPECT_THAT(inputs, ElementsAreArray({"input_image"}));
EXPECT_THAT(outputs, ElementsAreArray({"output_image"}));
}
TEST_F(CodeGeneratorTest, TestInternalConflict) {
std::vector<std::string> inputs = {"image", "image"};
std::vector<std::string> outputs = {"output"};
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
EXPECT_THAT(inputs, ElementsAreArray({"image1", "image2"}));
EXPECT_THAT(outputs, ElementsAreArray({"output"}));
}
TEST_F(CodeGeneratorTest, TestAllConflictNTo1) {
std::vector<std::string> inputs = {"image", "image"};
std::vector<std::string> outputs = {"image"};
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
EXPECT_THAT(inputs, ElementsAreArray({"input_image1", "input_image2"}));
EXPECT_THAT(outputs, ElementsAreArray({"output_image"}));
}
TEST_F(CodeGeneratorTest, TestAllConflict) {
std::vector<std::string> inputs = {"image", "audio", "image", "audio",
"audio"};
std::vector<std::string> outputs = {"image", "image", "audio", "feature",
"feature"};
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
EXPECT_THAT(inputs,
ElementsAreArray({"input_image1", "input_audio1", "input_image2",
"input_audio2", "input_audio3"}));
EXPECT_THAT(outputs,
ElementsAreArray({"output_image1", "output_image2",
"output_audio", "feature1", "feature2"}));
}
TEST_F(CodeGeneratorTest, TestAllConflictReversed) {
std::vector<std::string> inputs = {"image", "image", "audio", "feature",
"feature"};
std::vector<std::string> outputs = {"image", "audio", "image", "audio",
"audio"};
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
EXPECT_THAT(inputs,
ElementsAreArray({"input_image1", "input_image2", "input_audio",
"feature1", "feature2"}));
EXPECT_THAT(outputs, ElementsAreArray({"output_image1", "output_audio1",
"output_image2", "output_audio2",
"output_audio3"}));
}
} // namespace
} // namespace codegen
} // namespace support
} // namespace tflite

View File

@ -1,100 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/codegen/metadata_helper.h"
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace support {
namespace codegen {
constexpr char BUFFER_KEY[] = "TFLITE_METADATA";
const ModelMetadata* GetMetadataFromModel(const Model* model) {
if (model == nullptr || model->metadata() == nullptr) {
return nullptr;
}
for (auto i = 0; i < model->metadata()->size(); i++) {
const auto* name = model->metadata()->Get(i)->name();
if (name != nullptr && name->str() == BUFFER_KEY) {
const auto buffer_index = model->metadata()->Get(i)->buffer();
if (model->buffers() == nullptr ||
model->buffers()->size() <= buffer_index) {
continue;
}
const auto* buffer_vec = model->buffers()->Get(buffer_index)->data();
if (buffer_vec == nullptr || buffer_vec->data() == nullptr) {
continue;
}
return GetModelMetadata(buffer_vec->data());
}
}
return nullptr;
}
int FindAssociatedFile(const TensorMetadata* metadata,
const AssociatedFileType file_type,
const std::string& tensor_identifier,
ErrorReporter* err) {
int result = -1;
if (metadata->associated_files() == nullptr ||
metadata->associated_files()->size() == 0) {
return result;
}
for (int i = 0; i < metadata->associated_files()->size(); i++) {
const auto* file_metadata = metadata->associated_files()->Get(i);
if (file_metadata->type() == file_type) {
if (result >= 0) {
err->Warning(
"Multiple associated file of type %d found on tensor %s. Only the "
"first one will be used.",
file_type, tensor_identifier.c_str());
continue;
}
result = i;
}
}
return result;
}
int FindNormalizationUnit(const TensorMetadata* metadata,
const std::string& tensor_identifier,
ErrorReporter* err) {
int result = -1;
if (metadata->process_units() == nullptr ||
metadata->process_units()->size() == 0) {
return result;
}
for (int i = 0; i < metadata->process_units()->size(); i++) {
const auto* process_uint = metadata->process_units()->Get(i);
if (process_uint->options_type() ==
ProcessUnitOptions_NormalizationOptions) {
if (result >= 0) {
err->Warning(
"Multiple normalization unit found in tensor %s. Only the first "
"one will be effective.",
tensor_identifier.c_str());
continue;
}
result = i;
}
}
return result;
}
} // namespace codegen
} // namespace support
} // namespace tflite

View File

@ -1,51 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_
#include <string>
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace support {
namespace codegen {
/// Parses a ModelMetadata out from a Model. The returned ModelMetadata's
/// lifetime is scoped by the model.
/// Returns nullptr if we cannot find any metadata.
const ModelMetadata* GetMetadataFromModel(const Model* model);
/// Finds an associated file from a TensorMetadata of certain type. If there're
/// multiple files meet the criteria, only the first one is used. If there's no
/// file meets the criteria, -1 will be returned.
int FindAssociatedFile(const TensorMetadata* metadata,
const AssociatedFileType file_type,
const std::string& tensor_identifier,
ErrorReporter* err);
/// Find the first normalization unit. If none, return -1.
int FindNormalizationUnit(const TensorMetadata* metadata,
const std::string& tensor_identifier,
ErrorReporter* err);
} // namespace codegen
} // namespace support
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_

View File

@ -1,38 +0,0 @@
load("//tensorflow:tensorflow.bzl", "pybind_extension")
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
pybind_extension(
name = "_pywrap_codegen",
srcs = [
"codegen_lib.cc",
],
features = ["-use_header_modules"],
module_name = "_pywrap_codegen",
deps = [
"//tensorflow/lite/experimental/support/codegen:android_java_generator",
"//tensorflow/lite/experimental/support/codegen:code_generator",
"//tensorflow/python:pybind11_lib",
"//third_party/python_runtime:headers",
"@pybind11",
],
)
py_binary(
name = "codegen",
srcs = [
"codegen.py",
],
python_version = "PY3",
deps = [
":_pywrap_codegen",
"@absl_py//absl:app",
"@absl_py//absl/flags",
"@absl_py//absl/logging",
],
)

View File

@ -1,96 +0,0 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generates Android Java sources from a TFLite model with metadata."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
from absl import app
from absl import flags
from absl import logging
from tensorflow.lite.experimental.support.codegen.python import _pywrap_codegen
FLAGS = flags.FLAGS
flags.DEFINE_string('model', None, 'Path to model (.tflite) flatbuffer file.')
flags.DEFINE_string('destination', None, 'Path of destination of generation.')
flags.DEFINE_string('package_name', 'org.tensorflow.lite.support',
'Name of generated java package to put the wrapper class.')
flags.DEFINE_string(
'model_class_name', 'MyModel',
'Name of generated wrapper class (should not contain package name).')
flags.DEFINE_string(
'model_asset_path', '',
'(Optional) Path to the model in generated assets/ dir. If not set, '
'generator will use base name of input model.'
)
def get_model_buffer(path):
if not os.path.isfile(path):
logging.error('Cannot find model at path %s.', path)
with open(path, 'rb') as f:
buf = f.read()
return buf
def prepare_directory_for_file(file_path):
target_dir = os.path.dirname(file_path)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
return
if not os.path.isdir(target_dir):
logging.error('Cannot write to %s', target_dir)
def main(argv):
if len(argv) > 1:
logging.error('None flag arguments found: [%s]', ', '.join(argv[1:]))
codegen = _pywrap_codegen.AndroidJavaGenerator(FLAGS.destination)
model_buffer = get_model_buffer(FLAGS.model)
model_asset_path = FLAGS.model_asset_path
if not model_asset_path:
model_asset_path = os.path.basename(FLAGS.model)
result = codegen.generate(model_buffer, FLAGS.package_name,
FLAGS.model_class_name, model_asset_path)
error_message = codegen.get_error_message().strip()
if error_message:
logging.error(error_message)
if not result.files:
logging.error('Generation failed!')
return
for each in result.files:
prepare_directory_for_file(each.path)
with open(each.path, 'w') as f:
f.write(each.content)
logging.info('Generation succeeded!')
model_asset_path = os.path.join(FLAGS.destination, 'src/main/assets',
model_asset_path)
prepare_directory_for_file(model_asset_path)
shutil.copy(FLAGS.model, model_asset_path)
logging.info('Model copied into assets!')
if __name__ == '__main__':
flags.mark_flag_as_required('model')
flags.mark_flag_as_required('destination')
app.run(main)

View File

@ -1,49 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "pybind11/detail/common.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h"
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
namespace tflite {
namespace support {
namespace codegen {
template <typename... Args>
using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;
PYBIND11_MODULE(_pywrap_codegen, m) {
pybind11::class_<AndroidJavaGenerator>(m, "AndroidJavaGenerator")
.def(pybind11::init<const std::string &>())
.def("generate",
overload_cast_<const char *, const std::string &,
const std::string &, const std::string &>()(
&AndroidJavaGenerator::Generate))
.def("get_error_message", &AndroidJavaGenerator::GetErrorMessage);
pybind11::class_<GenerationResult>(m, "GenerationResult")
.def(pybind11::init<>())
.def_readwrite("files", &GenerationResult::files);
pybind11::class_<GenerationResult::File>(m, "GenerationResultFile")
.def(pybind11::init<>())
.def_readwrite("path", &GenerationResult::File::path)
.def_readwrite("content", &GenerationResult::File::content);
}
} // namespace codegen
} // namespace support
} // namespace tflite

View File

@ -1,194 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include <cstdarg>
namespace tflite {
namespace support {
namespace codegen {
int ErrorReporter::Warning(const char* format, ...) {
va_list args;
va_start(args, format);
return Report("[WARN] ", format, args);
}
int ErrorReporter::Error(const char* format, ...) {
va_list args;
va_start(args, format);
return Report("[ERROR] ", format, args);
}
int ErrorReporter::Report(const char* prefix, const char* format,
va_list args) {
char buf[1024];
int formatted = vsnprintf(buf, sizeof(buf), format, args);
buffer_ << prefix << buf << std::endl;
return formatted;
}
std::string ErrorReporter::GetMessage() {
std::string value = buffer_.str();
buffer_.str("");
return value;
}
CodeWriter::CodeWriter(ErrorReporter* err) : indent_(0), err_(err) {}
void CodeWriter::SetTokenValue(const std::string& token,
const std::string& value) {
value_map_[token] = value;
}
const std::string CodeWriter::GetTokenValue(const std::string& token) const {
auto iter = value_map_.find(token);
if (iter == value_map_.end()) {
// Typically only Code Generator's call this function (or `Append`). It's
// their duty to make sure the token is valid, and requesting for an invalid
// token implicits flaws in the code generation logic.
err_->Error("Internal: Cannot find value with token '%s'", token.c_str());
return "";
}
return iter->second;
}
void CodeWriter::SetIndentString(const std::string& indent_str) {
indent_str_ = indent_str;
}
void CodeWriter::Indent() { indent_++; }
void CodeWriter::Outdent() { indent_--; }
std::string CodeWriter::GenerateIndent() const {
std::string res;
res.reserve(indent_str_.size() * indent_);
for (int i = 0; i < indent_; i++) {
res.append(indent_str_);
}
return res;
}
void CodeWriter::Append(const std::string& text) { AppendInternal(text, true); }
void CodeWriter::AppendNoNewLine(const std::string& text) {
AppendInternal(text, false);
}
void CodeWriter::AppendInternal(const std::string& text, bool newline) {
// Prefix indent
if ((buffer_.empty() // nothing in the buffer
|| buffer_.back() == '\n') // is on new line
// is writing on current line
&& (!text.empty() && text[0] != '\n' && text[0] != '\r')) {
buffer_.append(GenerateIndent());
}
// State machine variables
bool in_token = false;
int i = 0;
// Rough memory reserve
buffer_.reserve(buffer_.size() + text.size());
std::string token_buffer;
// A simple LL1 analysis
while (i < text.size()) {
char cur = text[i];
char cur_next = i == text.size() - 1 ? '\0' : text[i + 1]; // Set guardian
if (!in_token) {
if (cur == '{' && cur_next == '{') { // Enter token
in_token = true;
i += 2;
} else if (cur == '\n') { // We need to apply global indent here
buffer_.push_back(cur);
if (cur_next != '\0' && cur_next != '\n' && cur_next != '\r') {
buffer_.append(GenerateIndent());
}
i += 1;
} else {
buffer_.push_back(cur);
i += 1;
}
} else {
if (cur == '}' && cur_next == '}') { // Close token
in_token = false;
const auto value = GetTokenValue(token_buffer);
buffer_.append(value);
token_buffer.clear();
i += 2;
} else {
token_buffer.push_back(cur);
i += 1;
}
}
}
if (!token_buffer.empty()) {
// Typically only Code Generator's call this function. It's
// their duty to make sure the code (or template) has valid syntax, and
// unclosed "{{...}}" implicits severe error in the template.
err_->Error("Internal: Invalid template: {{token}} is not closed.");
}
if (newline) {
buffer_.push_back('\n');
}
}
void CodeWriter::NewLine() { Append(""); }
void CodeWriter::Backspace(int n) {
buffer_.resize(buffer_.size() > n ? buffer_.size() - n : 0);
}
std::string CodeWriter::ToString() const { return buffer_; }
bool CodeWriter::IsStreamEmpty() const { return buffer_.empty(); }
void CodeWriter::Clear() {
buffer_.clear();
value_map_.clear();
indent_ = 0;
}
std::string SnakeCaseToCamelCase(const std::string& s) {
std::string t;
t.reserve(s.length());
size_t i = 0;
// Note: Use simple string += for simplicity.
bool cap = false;
while (i < s.size()) {
const char c = s[i++];
if (c == '_') {
cap = true;
} else if (cap) {
t += toupper(c);
cap = false;
} else {
t += c;
}
}
return t;
}
std::string JoinPath(const std::string& a, const std::string& b) {
if (a.empty()) return b;
std::string a_fixed = a;
if (!a_fixed.empty() && a_fixed.back() == '/') a_fixed.pop_back();
std::string b_fixed = b;
if (!b_fixed.empty() && b_fixed.front() == '/') b_fixed.erase(0, 1);
return a_fixed + "/" + b_fixed;
}
} // namespace codegen
} // namespace support
} // namespace tflite

View File

@ -1,127 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_
#include <map>
#include <sstream>
#include <string>
namespace tflite {
namespace support {
namespace codegen {
/// Collects runtime error logs which could be showed later.
// TODO(b/150538286): Consider a better mechanism to simplify callsite code.
class ErrorReporter {
public:
int Warning(const char* format, ...);
int Error(const char* format, ...);
std::string GetMessage();
private:
int Report(const char* prefix, const char* format, va_list args);
std::stringstream buffer_;
};
/// Implements basic code generating with text templates.
///
/// It could accept code templates and concatenate them into complete codes. A
/// template could contain named values.
///
/// Example code:
/// CodeWriter code;
/// code.SetValue("NAME", "Foo");
/// code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }");
/// code.SetValue("NAME", "Bar");
/// code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }");
///
/// Output:
/// void Foo() { printf("%s", "Foo"); }
/// void Bar() { printf("%s", "Bar"); }
class CodeWriter {
public:
explicit CodeWriter(ErrorReporter* err);
/// Sets value to a token. When generating code with template, a string in a
/// pair of {{ and }} will be regarded as a token and replaced with the
/// corresponding value in code generation.
/// It rewrites if the token already has a value.
void SetTokenValue(const std::string& token, const std::string& value);
/// Gets the current value set on the given token.
const std::string GetTokenValue(const std::string& token) const;
/// Sets the unit indent string. For example, in Java it should be " ".
void SetIndentString(const std::string& indent);
/// Increases the indent by a unit (the string set in SetIndentString).
void Indent();
/// Decreases the indent by a unit (the string set in SetIndentString).
void Outdent();
/// Generates the indentation string.
std::string GenerateIndent() const;
/// Appends a piece of template codes to the stream. Every named value will be
/// replaced via the real value. A new line will always be appended at the
/// end.
void Append(const std::string& text);
/// Appends a piece of template codes to the stream. Same with `Append`, but a
/// new line will not be appended at the end.
void AppendNoNewLine(const std::string& text);
/// Appends a new line to the stream.
void NewLine();
/// Deletes the last N charaters in the stream. If the stream has less than N
/// characters, deletes all.
void Backspace(int n);
std::string ToString() const;
/// Checks if the internal string stream is empty. Note: This method has
// overhead.
bool IsStreamEmpty() const;
/// Clears all the internal string stream and value map.
void Clear();
private:
void AppendInternal(const std::string& text, bool newline);
std::string indent_str_;
int indent_;
std::map<std::string, std::string> value_map_;
std::string buffer_;
ErrorReporter* err_;
};
/// Converts foo_bar_name to fooBarName. It's callers duty to make sure given
/// string "s" is already in snake case; or unexpected behavior may occur.
std::string SnakeCaseToCamelCase(const std::string& s);
/// Joins 2 parts of file path into one, connected by unix path seperator '/'.
/// It's callers duty to ensure the two parts are valid.
std::string JoinPath(const std::string& a, const std::string& b);
} // namespace codegen
} // namespace support
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_

View File

@ -1,97 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include <gtest/gtest.h>
namespace tflite {
namespace support {
namespace codegen {
namespace {
TEST(ErrorReporterTest, TestReportError) {
ErrorReporter err;
err.Error("some text");
EXPECT_EQ(err.GetMessage(), "[ERROR] some text\n");
EXPECT_EQ(err.GetMessage(), "");
}
TEST(CodeGeneratorTest, TestExample) {
ErrorReporter err;
CodeWriter writer(&err);
writer.SetTokenValue("NAME", "Foo");
const std::string text = R"(void {{NAME}}() { printf("%s", "{{NAME}}"); })";
writer.Append(text);
writer.SetTokenValue("NAME", "Bar");
writer.Append(text);
EXPECT_EQ(
"void Foo() { printf(\"%s\", \"Foo\"); }\n"
"void Bar() { printf(\"%s\", \"Bar\"); }\n",
writer.ToString());
}
TEST(CodeGeneratorTest, TestInexistentToken) {
ErrorReporter err;
CodeWriter writer(&err);
writer.SetTokenValue("NAME", "Foo");
const std::string text = R"(void {{name}}() {})";
writer.Append(text);
EXPECT_EQ(err.GetMessage(),
"[ERROR] Internal: Cannot find value with token 'name'\n");
}
TEST(CodeGeneratorTest, TestUnclosedToken) {
ErrorReporter err;
CodeWriter writer(&err);
writer.SetTokenValue("NAME", "Foo");
const std::string text = R"(void {{NAME}() {})";
writer.Append(text);
EXPECT_EQ(err.GetMessage(),
"[ERROR] Internal: Invalid template: {{token}} is not closed.\n");
}
TEST(CodeGeneratorTest, TestIndentControl) {
ErrorReporter err;
CodeWriter writer(&err);
writer.SetIndentString(" ");
writer.Indent();
writer.AppendNoNewLine("abcde"); // Will indent
EXPECT_EQ(" abcde", writer.ToString());
writer.Clear();
writer.Indent();
writer.AppendNoNewLine("abc\n\nde");
// The blank line will not indent
EXPECT_EQ(" abc\n\n de", writer.ToString());
writer.Clear();
writer.Indent();
writer.Append("abc");
writer.Outdent();
writer.AppendNoNewLine("def");
EXPECT_EQ(" abc\ndef", writer.ToString());
}
TEST(CaseConversionTest, TestSnakeToCamel) {
EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel"));
EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel_"));
EXPECT_EQ("ImACamel", SnakeCaseToCamelCase("_im_a_camel"));
EXPECT_EQ("", SnakeCaseToCamelCase("_"));
EXPECT_EQ("camel", SnakeCaseToCamelCase("camel"));
}
} // namespace
} // namespace codegen
} // namespace support
} // namespace tflite

View File

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.tensorflow.lite.support">
<uses-sdk android:minSdkVersion="19" />
</manifest>

View File

@ -1,66 +0,0 @@
# Description:
# TensorFlow Lite Support API in Java.
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
# TODO(b/156482505): The NOGPU target is a temporary target. Internally, people
# may already depend on "tensorflow-lite-support" so we shouldn't remove GPU
# from its dependency. We will have CLs to help users migrate. After migration
# is done, the "NOGPU" target will be removed.
android_library(
name = "tensorflow-lite-support-nogpu",
srcs = glob(["src/java/org/tensorflow/lite/support/**/*.java"]),
javacopts = JAVACOPTS,
manifest = "AndroidManifest.xml",
deps = [
"//tensorflow/lite/java:tensorflowlite",
"@org_checkerframework_qual",
],
)
# TODO(138904786): Split Java part and Android part to make the support library usable by pure Java.
# For new users: Please use "tensorflow-lite-support-nogpu" if possible, and
# additionally depends on "tensorflowlite_gpu" if needed.
android_library(
name = "tensorflow-lite-support",
srcs = glob(["src/java/org/tensorflow/lite/support/**/*.java"]),
javacopts = JAVACOPTS,
manifest = "AndroidManifest.xml",
deps = [
"//tensorflow/lite/java:tensorflowlite",
"//tensorflow/lite/java:tensorflowlite_gpu", # unuseddeps: keep
"@org_checkerframework_qual",
],
)
# This alias matches the style of lite/java naming for android_library targets. We keep the
# `tensorflow-lite-support` variant to match the associated .aar library name output style.
alias(
name = "tensorflowlite_support",
actual = ":tensorflow-lite-support",
)
java_library(
name = "tensorflow-lite-support-precondition",
srcs = ["src/java/org/tensorflow/lite/support/common/SupportPreconditions.java"],
javacopts = JAVACOPTS,
deps = [
"@org_checkerframework_qual",
],
)
android_library(
name = "tensorflow-lite-support-precondition-lib-android",
srcs = ["src/java/org/tensorflow/lite/support/common/SupportPreconditions.java"],
javacopts = JAVACOPTS,
manifest = "AndroidManifest.xml",
deps = [
"@org_checkerframework_qual",
],
)

View File

@ -1,17 +0,0 @@
# TensorFlow Lite Android Support Library
Mobile application developers typically interact with typed objects such as
bitmaps or primitives such as integers. However, the TensorFlow Lite Interpreter
that runs the on-device machine learning model uses tensors in the form of
ByteBuffer, which can be difficult to debug and manipulate. The TensorFlow Lite
Android Support Library is designed to help process the input and output of
TensorFlow Lite models, and make the TensorFlow Lite interpreter easier to use.
We welcome feedback from the community as we develop this support library,
especially around:
* Use-cases we should support including data types and operations
* Ease of use - does the APIs make sense to the community
See the [documentation](https://www.tensorflow.org/lite/guide/lite_support) for
instruction and examples.

View File

@ -1,184 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.common;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
import org.checkerframework.checker.nullness.qual.NonNull;
/** File I/O utilities. */
public class FileUtil {
private FileUtil() {}
/**
* Loads labels from the label file into a list of strings.
*
* <p>A legal label file is the plain text file whose contents are split into lines, and each line
* is an individual value. The file should be in assets of the context.
*
* @param context The context holds assets.
* @param filePath The path of the label file, relative with assets directory.
* @return a list of labels.
* @throws IOException if error occurs to open or read the file.
*/
@NonNull
public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath)
throws IOException {
return loadLabels(context, filePath, Charset.defaultCharset());
}
/**
* Loads labels from the label file into a list of strings.
*
* <p>A legal label file is the plain text file whose contents are split into lines, and each line
* is an individual value. The empty lines will be ignored. The file should be in assets of the
* context.
*
* @param context The context holds assets.
* @param filePath The path of the label file, relative with assets directory.
* @param cs {@code Charset} to use when decoding content of label file.
* @return a list of labels.
* @throws IOException if error occurs to open or read the file.
*/
@NonNull
public static List<String> loadLabels(
@NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
SupportPreconditions.checkNotNull(context, "Context cannot be null.");
SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
try (InputStream inputStream = context.getAssets().open(filePath)) {
return loadLabels(inputStream, cs);
}
}
/**
* Loads labels from an input stream of an opened label file. See details for label files in
* {@link FileUtil#loadLabels(Context, String)}.
*
* @param inputStream the input stream of an opened label file.
* @return a list of labels.
* @throws IOException if error occurs to open or read the file.
*/
@NonNull
public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException {
return loadLabels(inputStream, Charset.defaultCharset());
}
/**
* Loads labels from an input stream of an opened label file. See details for label files in
* {@link FileUtil#loadLabels(Context, String)}.
*
* @param inputStream the input stream of an opened label file.
* @param cs {@code Charset} to use when decoding content of label file.
* @return a list of labels.
* @throws IOException if error occurs to open or read the file.
*/
@NonNull
public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs)
throws IOException {
List<String> labels = new ArrayList<>();
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs))) {
String line;
while ((line = reader.readLine()) != null) {
if (line.trim().length() > 0) {
labels.add(line);
}
}
return labels;
}
}
/**
* Loads a vocabulary file (a single-column text file) into a list of strings.
*
* <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
* and each line is an individual value. The file should be in assets of the context.
*
* @param context The context holds assets.
* @param filePath The path of the vocabulary file, relative with assets directory.
* @return a list of vocabulary words.
* @throws IOException if error occurs to open or read the file.
*/
@NonNull
public static List<String> loadSingleColumnTextFile(
@NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
return loadLabels(context, filePath, cs);
}
/**
* Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column
* text file). See details for vocabulary files in {@link FileUtil#loadVocabularyFile(Context,
* String)}.
*
* @param inputStream the input stream of an opened vocabulary file.
* @return a list of vocabulary words.
* @throws IOException if error occurs to open or read the file.
*/
@NonNull
public static List<String> loadSingleColumnTextFile(@NonNull InputStream inputStream, Charset cs)
throws IOException {
return loadLabels(inputStream, cs);
}
/**
* Loads a file from the asset folder through memory mapping.
*
* @param context Application context to access assets.
* @param filePath Asset path of the file.
* @return the loaded memory mapped file.
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
@NonNull
public static MappedByteBuffer loadMappedFile(@NonNull Context context, @NonNull String filePath)
throws IOException {
SupportPreconditions.checkNotNull(context, "Context should not be null.");
SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
}
/**
* Loads a binary file from the asset folder.
*
* @param context Application context to access assets.
* @param filePath Asset path of the file.
* @return the byte array for the binary file.
* @throws IOException if an I/O error occurs when loading file.
*/
@NonNull
public static byte[] loadByteFromFile(@NonNull Context context, @NonNull String filePath)
throws IOException {
ByteBuffer buffer = loadMappedFile(context, filePath);
byte[] byteArray = new byte[buffer.remaining()];
buffer.get(byteArray);
return byteArray;
}
}

View File

@ -1,31 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.common;
/**
* The common interface for classes that carries an "apply" method, which converts T to another one.
* @param <T> The class which Operator handles.
*/
public interface Operator<T> {
/**
* Applies an operation on a T object, returning a T object.
*
* <p>Note: The returned object could probably be the same one with given input, and given input
* could probably be changed.
*/
T apply(T x);
}

View File

@ -1,23 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.common;
/**
* Processes T object with prepared {@link Operator<T>}.
*/
public interface Processor<T> {
T process(T input);
}

View File

@ -1,82 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.common;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;
/**
* A processor base class that chains a serial of {@link Operator<T>} and executes them.
*
* <p>Typically, users could use its subclasses, e.g. {@link
* org.tensorflow.lite.support.image.ImageProcessor} rather than directly use this one.
*
* @param <T> The type that the Operator is handling.
*/
public class SequentialProcessor<T> implements Processor<T> {
/** List of operators added to this {@link SequentialProcessor}. */
protected final List<Operator<T>> operatorList;
/**
* The {@link Map} between the operator name and the corresponding op indexes in {@code
* operatorList}. An operator may be added multiple times into this {@link SequentialProcessor}.
*/
protected final Map<String, List<Integer>> operatorIndex;
protected SequentialProcessor(Builder<T> builder) {
operatorList = builder.operatorList;
operatorIndex = Collections.unmodifiableMap(builder.operatorIndex);
}
@Override
public T process(T x) {
for (Operator<T> op : operatorList) {
x = op.apply(x);
}
return x;
}
/** The inner builder class to build a Sequential Processor. */
protected static class Builder<T> {
private final List<Operator<T>> operatorList;
private final Map<String, List<Integer>> operatorIndex;
protected Builder() {
operatorList = new ArrayList<>();
operatorIndex = new HashMap<>();
}
public Builder<T> add(@NonNull Operator<T> op) {
SupportPreconditions.checkNotNull(op, "Adding null Op is illegal.");
operatorList.add(op);
String operatorName = op.getClass().getName();
if (!operatorIndex.containsKey(operatorName)) {
operatorIndex.put(operatorName, new ArrayList<Integer>());
}
operatorIndex.get(operatorName).add(operatorList.size() - 1);
return this;
}
public SequentialProcessor<T> build() {
return new SequentialProcessor<T>(this);
}
}
}

View File

@ -1,184 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.common;
import org.checkerframework.checker.nullness.qual.Nullable;
/** Static error checking util methods. */
public final class SupportPreconditions {
/**
* Ensures that an object reference passed as a parameter to the calling method is not null.
*
* @param reference an object reference
* @return the non-null reference that was validated
* @throws NullPointerException if {@code reference} is null
*/
public static <T extends Object> T checkNotNull(T reference) {
if (reference == null) {
throw new NullPointerException("The object reference is null.");
}
return reference;
}
/**
* Ensures that an object reference passed as a parameter to the calling method is not null.
*
* @param reference an object reference
* @param errorMessage the exception message to use if the check fails; will be converted to a
* string using {@link String#valueOf(Object)}
* @return the non-null reference that was validated
* @throws NullPointerException if {@code reference} is null
*/
public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
if (reference == null) {
throw new NullPointerException(String.valueOf(errorMessage));
}
return reference;
}
/**
* Ensures that the given String is not empty and not null.
*
* @param string the String to test
* @return the non-null non-empty String that was validated
* @throws IllegalArgumentException if {@code string} is null or empty
*/
public static String checkNotEmpty(String string) {
if (string == null || string.length() == 0) {
throw new IllegalArgumentException("Given String is empty or null.");
}
return string;
}
/**
* Ensures that the given String is not empty and not null.
*
* @param string the String to test
* @param errorMessage the exception message to use if the check fails; will be converted to a
* string using {@link String#valueOf(Object)}
* @return the non-null non-empty String that was validated
* @throws IllegalArgumentException if {@code string} is null or empty
*/
public static String checkNotEmpty(String string, Object errorMessage) {
if (string == null || string.length() == 0) {
throw new IllegalArgumentException(String.valueOf(errorMessage));
}
return string;
}
/**
* Ensures the truth of an expression involving one or more parameters to the calling method.
*
* @param expression a boolean expression.
* @throws IllegalArgumentException if {@code expression} is false.
*/
public static void checkArgument(boolean expression) {
if (!expression) {
throw new IllegalArgumentException();
}
}
/**
* Ensures the truth of an expression involving one or more parameters to the calling method.
*
* @param expression a boolean expression.
* @param errorMessage the exception message to use if the check fails; will be converted to a
* string using {@link String#valueOf(Object)}.
* @throws IllegalArgumentException if {@code expression} is false.
*/
public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
if (!expression) {
throw new IllegalArgumentException(String.valueOf(errorMessage));
}
}
/**
* Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
* {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
*
* @param index a user-supplied index identifying an element of an array, list or string
* @param size the size of that array, list or string
* @return the value of {@code index}
* @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
* @throws IllegalArgumentException if {@code size} is negative
*/
public static int checkElementIndex(int index, int size) {
return checkElementIndex(index, size, "index");
}
/**
* Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
* {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
*
* @param index a user-supplied index identifying an element of an array, list or string
* @param size the size of that array, list or string
* @param desc the text to use to describe this index in an error message
* @return the value of {@code index}
* @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
* @throws IllegalArgumentException if {@code size} is negative
*/
public static int checkElementIndex(int index, int size, @Nullable String desc) {
// Carefully optimized for execution by hotspot (explanatory comment above)
if (index < 0 || index >= size) {
throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
}
return index;
}
/**
* Ensures the truth of an expression involving the state of the calling instance, but not
* involving any parameters to the calling method.
*
* @param expression a boolean expression
* @throws IllegalStateException if {@code expression} is false
* @see Verify#verify Verify.verify()
*/
public static void checkState(boolean expression) {
if (!expression) {
throw new IllegalStateException();
}
}
/**
* Ensures the truth of an expression involving the state of the calling instance, but not
* involving any parameters to the calling method.
*
* @param expression a boolean expression
* @param errorMessage the exception message to use if the check fails; will be converted to a
* string using {@link String#valueOf(Object)}
* @throws IllegalStateException if {@code expression} is false
* @see Verify#verify Verify.verify()
*/
public static void checkState(boolean expression, @Nullable Object errorMessage) {
if (!expression) {
throw new IllegalStateException(String.valueOf(errorMessage));
}
}
private static String badElementIndex(int index, int size, @Nullable String desc) {
if (index < 0) {
return String.format("%s (%s) must not be negative", desc, index);
} else if (size < 0) {
throw new IllegalArgumentException("negative size: " + size);
} else { // index >= size
return String.format("%s (%s) must be less than size (%s)", desc, index, size);
}
}
private SupportPreconditions() {
throw new AssertionError("SupportPreconditions is Uninstantiable.");
}
}

View File

@ -1,27 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.common;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/**
* Applies some operation on TensorBuffers.
*/
public interface TensorOperator extends Operator<TensorBuffer> {
/** @see Operator#apply(Object) . */
@Override
TensorBuffer apply(TensorBuffer input);
}

View File

@ -1,68 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.common;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/**
* TensorProcessor is a helper class for preprocessing and postprocessing tensors. It could
* transform a {@link TensorBuffer} to another by executing a chain of {@link TensorOperator}.
*
* <p>Example Usage:
*
* <pre>
* TensorProcessor processor = new TensorProcessor.Builder().add(new NormalizeOp(1, 2)).build();
* TensorBuffer anotherTensorBuffer = processor.process(tensorBuffer);
* </pre>
*
* @see TensorProcessor.Builder to build a {@link TensorProcessor} instance.
* @see TensorProcessor#process(TensorBuffer) to apply the processor on a {@link TensorBuffer}.
*/
public class TensorProcessor extends SequentialProcessor<TensorBuffer> {
private TensorProcessor(Builder builder) {
super(builder);
}
/** The Builder to create an {@link TensorProcessor}, which could be executed later. */
public static class Builder extends SequentialProcessor.Builder<TensorBuffer> {
/**
* Creates a Builder to build {@link TensorProcessor}.
*
* @see #add(TensorOperator) to add an Op.
* @see #build() to complete the building process and get a built Processor.
*/
public Builder() {
super();
}
/**
* Adds an {@link TensorOperator} into the Operator chain.
*
* @param op the Operator instance to be executed then.
*/
public TensorProcessor.Builder add(TensorOperator op) {
super.add(op);
return this;
}
/** Completes the building process and gets the {@link TensorProcessor} instance. */
@Override
public TensorProcessor build() {
return new TensorProcessor(this);
}
}
}

View File

@ -1,55 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.common.ops;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.common.SupportPreconditions;
import org.tensorflow.lite.support.common.TensorOperator;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/** Casts a {@link TensorBuffer} to a specified data type. */
public class CastOp implements TensorOperator {
private final DataType destinationType;
/**
* Constructs a CastOp.
*
* <p>Note: For only converting type for a certain {@link TensorBuffer} on-the-fly rather than in
* a processor, please directly use {@link TensorBuffer#createFrom(TensorBuffer, DataType)}.
*
* <p>When this Op is executed, if the original {@link TensorBuffer} is already in {@code
* destinationType}, the original buffer will be directly returned.
*
* @param destinationType: The type of the casted {@link TensorBuffer}.
* @throws IllegalArgumentException if {@code destinationType} is neither {@link DataType#UINT8}
* nor {@link DataType#FLOAT32}.
*/
public CastOp(DataType destinationType) {
SupportPreconditions.checkArgument(
destinationType == DataType.UINT8 || destinationType == DataType.FLOAT32,
"Destination type " + destinationType + " is not supported.");
this.destinationType = destinationType;
}
@Override
public TensorBuffer apply(TensorBuffer input) {
if (input.getDataType() == destinationType) {
return input;
}
return TensorBuffer.createFrom(input, destinationType);
}
}

View File

@ -1,40 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.common.ops;
import org.tensorflow.lite.support.common.TensorOperator;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/**
* Dequantizes a {@link TensorBuffer} with given {@code zeroPoint} and {@code scale}.
*
* <p>Note: The data type of output tensor is always {@code FLOAT32} except when the DequantizeOp is
* created effectively as an identity Op such as setting {@code zeroPoint} to 0 and {@code scale} to
* 1 (in this case, the output tensor is the same instance as input).
*
* <p>If both {@code zeroPoint} and {@code scale} are 0, the {@link DequantizeOp} will be bypassed,
* which is equivalent to setting {@code zeroPoint} to 0 and {@code scale} to 1. This can be useful
* when passing in the quantization parameters that are extracted directly from the TFLite model
* flatbuffer. If the tensor is not quantized, both {@code zeroPoint} and {@code scale} will be read
* as 0.
*/
public class DequantizeOp extends NormalizeOp implements TensorOperator {
public DequantizeOp(float zeroPoint, float scale) {
// Quantization: f = (q - z) * s
super(zeroPoint, 1 / scale);
}
}

View File

@ -1,160 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.common.ops;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.common.SupportPreconditions;
import org.tensorflow.lite.support.common.TensorOperator;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
import org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat;
/**
* Normalizes a {@link TensorBuffer} with given mean and stddev: output = (input - mean) / stddev.
*/
public class NormalizeOp implements TensorOperator {
// mean.length should always be equal to stddev.length and always >= 1.
private final float[] mean;
private final float[] stddev;
private final int numChannels;
private final boolean isIdentityOp;
/**
* Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
* satisfies:
*
* <pre>
* output = (input - mean) / stddev
* </pre>
*
* <p>In the following two cases, reset {@code mean} to 0 and {@code stddev} to 1 to bypass the
* normalization. <br>
* 1. Both {@code mean} and {code stddev} are 0. <br>
* 2. {@code mean} is 0 and {stddev} is Infinity.
*
* <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will
* happen, and original input will be directly returned in execution.
*
* <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
* present, except that the input is a {@link DataType#UINT8} tensor, {@code mean} is set to 0 and
* {@code stddev} is set to 1.
*
* @param mean the mean value to be subtracted first.
* @param stddev the standard deviation value to divide then.
* @throws IllegalArgumentException if {@code stddev} is zero.
*/
public NormalizeOp(float mean, float stddev) {
// Make exceptions to the cases that
// 1. Both mean and stddev are 0.0f. This may happen when reading the normalization parameters
// from a tensor which does not have the values populated in the metadata. The same situation
// may also happen to the quantization parameters.
// 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization
// parameters from a tensor which does not have the values populated in the metadata, and then
// passing the parameters into the DequantizeOp.
// Bypass both of the two cases, by reseting stddev to 1.0f.
if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) {
stddev = 1.0f;
}
SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero.");
boolean meansIsZeroAndDevsIs1 = false;
if (mean == 0.0f && stddev == 1.0f) {
meansIsZeroAndDevsIs1 = true;
}
this.isIdentityOp = meansIsZeroAndDevsIs1;
this.mean = new float[] {mean};
this.stddev = new float[] {stddev};
this.numChannels = 1;
}
/**
* Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which
* satisfies:
*
* <pre>
* // Pseudo code. [...][i] means a certain element whose channel id is i.
* output[...][i] = (input[...][i] - mean[i]) / stddev[i]
* </pre>
*
* <p>Note: If all values in {@code mean} are set to 0 and all {@code stddev} are set to 1, no
* computation will happen, and original input will be directly returned in execution.
*
* <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
* present, except that the input is a {@link DataType#UINT8} tensor, all {@code mean} are set to
* 0 and all {@code stddev} are set to 1.
*
* @param mean the mean values to be subtracted first for each channel.
* @param stddev the standard deviation values to divide then for each channel.
* @throws IllegalArgumentException if any {@code stddev} is zero, or {@code mean} has different
* number of elements with {@code stddev}, or any of them is empty.
*/
public NormalizeOp(@NonNull float[] mean, @NonNull float[] stddev) {
SupportPreconditions.checkNotNull(mean, "Mean cannot be null");
SupportPreconditions.checkNotNull(stddev, "Stddev cannot be null");
SupportPreconditions.checkArgument(
mean.length == stddev.length,
"Per channel normalization requires same number of means and stddevs");
SupportPreconditions.checkArgument(mean.length > 0, "Means and stddevs are empty.");
this.mean = mean.clone();
this.stddev = stddev.clone();
boolean allMeansAreZeroAndAllDevsAre1 = true;
this.numChannels = mean.length;
for (int i = 0; i < numChannels; i++) {
SupportPreconditions.checkArgument(this.stddev[i] != 0, "Stddev cannot be zero.");
if (this.stddev[i] != 1 || this.mean[i] != 0) {
allMeansAreZeroAndAllDevsAre1 = false;
}
}
this.isIdentityOp = allMeansAreZeroAndAllDevsAre1;
}
/**
* Applies the defined normalization on given tensor and returns the result.
*
* <p>Note: {@code input} is possibly the same instance with the output.
*
* @param input input tensor. It may be the same instance with the output.
* @return output tensor.
*/
@Override
@NonNull
public TensorBuffer apply(@NonNull TensorBuffer input) {
if (isIdentityOp) {
return input;
}
int[] shape = input.getShape();
SupportPreconditions.checkArgument(
numChannels == 1 || (shape.length != 0 && shape[shape.length - 1] == numChannels),
"Number of means (stddevs) is not same with number of channels (size of last axis).");
// TODO(136750944): Eliminate the array copy here.
float[] values = input.getFloatArray();
int j = 0;
for (int i = 0; i < values.length; i++) {
values[i] = (values[i] - mean[j]) / stddev[j];
j = (j + 1) % numChannels;
}
TensorBuffer output;
if (input.isDynamic()) {
output = TensorBufferFloat.createDynamic(DataType.FLOAT32);
} else {
output = TensorBufferFloat.createFixedSize(shape, DataType.FLOAT32);
}
output.loadArray(values, shape);
return output;
}
}

View File

@ -1,41 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.common.ops;
import org.tensorflow.lite.support.common.TensorOperator;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/**
* Quantizes a {@link TensorBuffer} with given {@code zeroPoint} and {@code scale}.
*
* <p>Note: {@link QuantizeOp} does not cast output to UINT8, but only performs the quantization
* math on top of input. The data type of output tensor is always {@code FLOAT32} except that the Op
* is effectively an identity Op (in this case, the output tensor is the same instance as the
* input). To connect with quantized model, a {@link CastOp} is probably needed.
*
* <p>If both {@code zeroPoint} and {@code scale} are 0, the {@link QuantizeOp} will be bypassed,
* which is equivalent to setting {@code zeroPoint} to 0 and {@code scale} to 1. This can be useful
* when passing in the quantization parameters that are extracted directly from the TFLite model
* flatbuffer. If the tensor is not quantized, both {@code zeroPoint} and {@code scale} will be read
* as 0.
*/
public class QuantizeOp extends NormalizeOp implements TensorOperator {
public QuantizeOp(float zeroPoint, float scale) {
// Quantization: f = (q - z) * s, i.e. q = f / s + z = (f - (-z * s)) / s
super(-zeroPoint * scale, scale);
}
}

View File

@ -1,202 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.image;
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
import android.graphics.RectF;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/**
* Helper class for converting values that represents bounding boxes into rectangles.
*
* <p>The class provides a static function to create bounding boxes as {@link RectF} from different
* types of configurations.
*
* <p>Generally, a bounding box could be represented by 4 float values, but the values could be
* interpreted in many ways. We now support 3 {@link Type} of configurations, and the order of
* elements in each type is configurable as well.
*/
public final class BoundingBoxUtil {
/** Denotes how a bounding box is represented. */
public enum Type {
/**
* Represents the bounding box by using the combination of boundaries, {left, top, right,
* bottom}. The default order is {left, top, right, bottom}. Other orders can be indicated by an
* index array.
*/
BOUNDARIES,
/**
* Represents the bounding box by using the upper_left corner, width and height. The default
* order is {upper_left_x, upper_left_y, width, height}. Other orders can be indicated by an
* index array.
*/
UPPER_LEFT,
/**
* Represents the bounding box by using the center of the box, width and height. The default
* order is {center_x, center_y, width, height}. Other orders can be indicated by an index
* array.
*/
CENTER,
}
/** Denotes if the coordinates are actual pixels or relative ratios. */
public enum CoordinateType {
/** The coordinates are relative ratios in range [0, 1]. */
RATIO,
/** The coordinates are actual pixel values. */
PIXEL
}
/**
* Creates a list of bounding boxes from a {@link TensorBuffer} which represents bounding boxes.
*
* @param tensor holds the data representing some boxes.
* @param valueIndex denotes the order of the elements defined in each bounding box type. An empty
* index array represent the default order of each bounding box type. For example, to denote
* the default order of BOUNDARIES, {left, top, right, bottom}, the index should be {0, 1, 2,
* 3}. To denote the order {left, right, top, bottom}, the order should be {0, 2, 1, 3}.
* <p>The index array can be applied to all bounding box types to adjust the order of their
* corresponding underlying elements.
* @param boundingBoxAxis specifies the index of the dimension that represents bounding box. The
* size of that dimension is required to be 4. Index here starts from 0. For example, if the
* tensor has shape 4x10, the axis for bounding boxes is likely to be 0. For shape 10x4, the
* axis is likely to be 1 (or -1, equivalently).
* @param type defines how values should be converted into boxes. See {@link Type}
* @param coordinateType defines how values are interpreted to coordinates. See {@link
* CoordinateType}
* @param height the height of the image which the boxes belong to. Only has effects when {@code
* coordinateType} is {@link CoordinateType#RATIO}
* @param width the width of the image which the boxes belong to. Only has effects when {@code
* coordinateType} is {@link CoordinateType#RATIO}
* @return A list of bounding boxes that the {@code tensor} represents. All dimensions except
* {@code boundingBoxAxis} will be collapsed with order kept. For example, given {@code
* tensor} with shape {1, 4, 10, 2} and {@code boundingBoxAxis = 1}, The result will be a list
* of 20 bounding boxes.
* @throws IllegalArgumentException if size of bounding box dimension (set by {@code
* boundingBoxAxis}) is not 4.
* @throws IllegalArgumentException if {@code boundingBoxAxis} is not in {@code (-(D+1), D)} where
* {@code D} is the number of dimensions of the {@code tensor}.
* @throws IllegalArgumentException if {@code tensor} has data type other than {@link
* DataType#FLOAT32}.
*/
public static List<RectF> convert(
TensorBuffer tensor,
int[] valueIndex,
int boundingBoxAxis,
Type type,
CoordinateType coordinateType,
int height,
int width) {
int[] shape = tensor.getShape();
checkArgument(
boundingBoxAxis >= -shape.length && boundingBoxAxis < shape.length,
String.format(
"Axis %d is not in range (-(D+1), D), where D is the number of dimensions of input"
+ " tensor (shape=%s)",
boundingBoxAxis, Arrays.toString(shape)));
if (boundingBoxAxis < 0) {
boundingBoxAxis = shape.length + boundingBoxAxis;
}
checkArgument(
shape[boundingBoxAxis] == 4,
String.format(
"Size of bounding box dimension %d is not 4. Got %d in shape %s",
boundingBoxAxis, shape[boundingBoxAxis], Arrays.toString(shape)));
checkArgument(
valueIndex.length == 4,
String.format(
"Bounding box index array length %d is not 4. Got index array %s",
valueIndex.length, Arrays.toString(valueIndex)));
checkArgument(
tensor.getDataType() == DataType.FLOAT32,
"Bounding Boxes only create from FLOAT32 buffers. Got: " + tensor.getDataType().name());
List<RectF> boundingBoxList = new ArrayList<>();
// Collapse dimensions to {a, 4, b}. So each bounding box could be represent as (i, j), and its
// four values are (i, k, j), where 0 <= k < 4. We can compute the 4 flattened index by
// i * 4b + k * b + j.
int a = 1;
for (int i = 0; i < boundingBoxAxis; i++) {
a *= shape[i];
}
int b = 1;
for (int i = boundingBoxAxis + 1; i < shape.length; i++) {
b *= shape[i];
}
float[] values = new float[4];
ByteBuffer byteBuffer = tensor.getBuffer();
byteBuffer.rewind();
FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
for (int i = 0; i < a; i++) {
for (int j = 0; j < b; j++) {
for (int k = 0; k < 4; k++) {
values[k] = floatBuffer.get((i * 4 + k) * b + j);
}
boundingBoxList.add(
convertOneBoundingBox(values, valueIndex, type, coordinateType, height, width));
}
}
byteBuffer.rewind();
return boundingBoxList;
}
private static RectF convertOneBoundingBox(
float[] values,
int[] valueIndex,
Type type,
CoordinateType coordinateType,
int height,
int width) {
float[] orderedValues = new float[4];
for (int i = 0; i < 4; i++) {
orderedValues[i] = values[valueIndex[i]];
}
return convertOneBoundingBox(orderedValues, type, coordinateType, height, width);
}
private static RectF convertOneBoundingBox(
float[] values, Type type, CoordinateType coordinateType, int height, int width) {
switch (type) {
case BOUNDARIES:
return convertFromBoundaries(values, coordinateType, height, width);
case UPPER_LEFT:
case CENTER:
// TODO(b/150824448): convertFrom{UpperLeft, Center}
throw new IllegalArgumentException("BoundingBox.Type " + type + " is not yet supported.");
}
throw new IllegalArgumentException("Cannot recognize BoundingBox.Type " + type);
}
private static RectF convertFromBoundaries(
float[] values, CoordinateType coordinateType, int height, int width) {
if (coordinateType == CoordinateType.RATIO) {
return new RectF(
values[0] * width, values[1] * height, values[2] * width, values[3] * height);
} else {
return new RectF(values[0], values[1], values[2], values[3]);
}
}
// Private constructor to prevent initialization.
private BoundingBoxUtil() {}
}

View File

@ -1,108 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.image;
import android.graphics.Bitmap;
import android.graphics.Color;
import java.util.Arrays;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/**
* Implements some stateless image conversion methods.
*
* This class is an internal helper for {@link org.tensorflow.lite.support.image}.
*/
class ImageConversions {
/**
* Converts an Image in a TensorBuffer to a Bitmap, whose memory is already allocated.
*
* <p>Notice: We only support ARGB_8888 at this point.
*
* @param buffer The TensorBuffer object representing the image. It should be an UInt8 buffer with
* 3 dimensions: width, height, channel. Size of each dimension should be positive and the
* size of channels should be 3 (representing R, G, B). An optional 4th dimension "batch" is
* acceptable, and dimensions look like: batch, width, height, channel. In this case, size of
* batches should be 1.
* @param bitmap The destination of the conversion. Needs to be created in advance, needs to be
* mutable, and needs to have the same width and height with the buffer.
* @throws IllegalArgumentException 1) if the {@code buffer} is not uint8 (e.g. a float buffer),
* or has an invalid shape. 2) if the {@code bitmap} is not mutable. 3) if the {@code bitmap}
* has different height or width with the buffer.
*/
static void convertTensorBufferToBitmap(TensorBuffer buffer, Bitmap bitmap) {
if (buffer.getDataType() != DataType.UINT8) {
// We will add support to FLOAT format conversion in the future, as it may need other configs.
throw new UnsupportedOperationException(
String.format(
"Converting TensorBuffer of type %s to ARGB_8888 Bitmap is not supported yet.",
buffer.getDataType()));
}
int[] shape = buffer.getShape();
TensorImage.checkImageTensorShape(shape);
int h = shape[shape.length - 3];
int w = shape[shape.length - 2];
if (bitmap.getWidth() != w || bitmap.getHeight() != h) {
throw new IllegalArgumentException(String.format(
"Given bitmap has different width or height %s with the expected ones %s.",
Arrays.toString(new int[]{bitmap.getWidth(), bitmap.getHeight()}),
Arrays.toString(new int[]{w, h})));
}
if (!bitmap.isMutable()) {
throw new IllegalArgumentException("Given bitmap is not mutable");
}
// TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
int[] intValues = new int[w * h];
int[] rgbValues = buffer.getIntArray();
for (int i = 0, j = 0; i < intValues.length; i++) {
int r = rgbValues[j++];
int g = rgbValues[j++];
int b = rgbValues[j++];
intValues[i] = Color.rgb(r, g, b);
}
bitmap.setPixels(intValues, 0, w, 0, 0, w, h);
}
/**
* Converts an Image in a Bitmap to a TensorBuffer (3D Tensor: Width-Height-Channel) whose memory
* is already allocated, or could be dynamically allocated.
*
* @param bitmap The Bitmap object representing the image. Currently we only support ARGB_8888
* config.
* @param buffer The destination of the conversion. Needs to be created in advance. If it's
* fixed-size, its flat size should be w*h*3.
* @throws IllegalArgumentException if the buffer is fixed-size, but the size doesn't match.
*/
static void convertBitmapToTensorBuffer(Bitmap bitmap, TensorBuffer buffer) {
int w = bitmap.getWidth();
int h = bitmap.getHeight();
int[] intValues = new int[w * h];
bitmap.getPixels(intValues, 0, w, 0, 0, w, h);
// TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
int[] rgbValues = new int[w * h * 3];
for (int i = 0, j = 0; i < intValues.length; i++) {
rgbValues[j++] = ((intValues[i] >> 16) & 0xFF);
rgbValues[j++] = ((intValues[i] >> 8) & 0xFF);
rgbValues[j++] = (intValues[i] & 0xFF);
}
int[] shape = new int[] {h, w, 3};
buffer.loadArray(rgbValues, shape);
}
// Hide the constructor as the class is static.
private ImageConversions() {}
}

View File

@ -1,43 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.image;
import android.graphics.PointF;
import org.tensorflow.lite.support.common.Operator;
/** Operates a TensorImage object. Used in ImageProcessor. */
public interface ImageOperator extends Operator<TensorImage> {
/** @see org.tensorflow.lite.support.common.Operator#apply(java.lang.Object) */
@Override
TensorImage apply(TensorImage image);
/** Computes the width of the expected output image when input image size is given. */
int getOutputImageWidth(int inputImageHeight, int inputImageWidth);
/** Computes the height of the expected output image when input image size is given. */
int getOutputImageHeight(int inputImageHeight, int inputImageWidth);
/**
* Transforms a point from coordinates system of the result image back to the one of the input
* image.
*
* @param point the point from the result coordinates system.
* @param inputImageHeight the height of input image.
* @param inputImageWidth the width of input image.
* @return the point with the coordinates from the coordinates system of the input image.
*/
PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth);
}

View File

@ -1,198 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.image;
import android.graphics.PointF;
import android.graphics.RectF;
import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;
import org.tensorflow.lite.support.common.Operator;
import org.tensorflow.lite.support.common.SequentialProcessor;
import org.tensorflow.lite.support.common.SupportPreconditions;
import org.tensorflow.lite.support.common.TensorOperator;
import org.tensorflow.lite.support.image.ops.Rot90Op;
import org.tensorflow.lite.support.image.ops.TensorOperatorWrapper;
/**
* ImageProcessor is a helper class for preprocessing and postprocessing {@link TensorImage}. It
* could transform a {@link TensorImage} to another by executing a chain of {@link ImageOperator}.
*
* <p>Example Usage:
*
* <pre>
* ImageProcessor processor = new ImageProcessor.Builder()
* .add(new ResizeOp(224, 224, ResizeMethod.NEAREST_NEIGHBOR)
* .add(new Rot90Op())
* .add(new NormalizeOp(127.5f, 127.5f))
* .build();
* TensorImage anotherTensorImage = processor.process(tensorImage);
* </pre>
*
* <p><b>WARNING:</b> Instances of an {@code ImageProcessor} are <b>not</b> thread-safe with {@link
* #updateNumberOfRotations}. Updating the number of rotations and then processing images (using
* {@link #process}) must be protected from concurrent access. It is recommended to create separate
* {@code ImageProcessor} instances for each thread. If multiple threads access a {@code
* ImageProcessor} concurrently, it must be synchronized externally.
*
* @see ImageProcessor.Builder to build a {@link ImageProcessor} instance
* @see ImageProcessor#process(TensorImage) to apply the processor on a {@link TensorImage}
*/
public class ImageProcessor extends SequentialProcessor<TensorImage> {
private ImageProcessor(Builder builder) {
super(builder);
}
/**
* Transforms a point from coordinates system of the result image back to the one of the input
* image.
*
* @param point the point from the result coordinates system.
* @param inputImageHeight the height of input image.
* @param inputImageWidth the width of input image.
* @return the point with the coordinates from the coordinates system of the input image.
*/
public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
List<Integer> widths = new ArrayList<>();
List<Integer> heights = new ArrayList<>();
int currentWidth = inputImageWidth;
int currentHeight = inputImageHeight;
for (Operator<TensorImage> op : operatorList) {
widths.add(currentWidth);
heights.add(currentHeight);
ImageOperator imageOperator = (ImageOperator) op;
int newHeight = imageOperator.getOutputImageHeight(currentHeight, currentWidth);
int newWidth = imageOperator.getOutputImageWidth(currentHeight, currentWidth);
currentHeight = newHeight;
currentWidth = newWidth;
}
ListIterator<Operator<TensorImage>> opIterator = operatorList.listIterator(operatorList.size());
ListIterator<Integer> widthIterator = widths.listIterator(widths.size());
ListIterator<Integer> heightIterator = heights.listIterator(heights.size());
while (opIterator.hasPrevious()) {
ImageOperator imageOperator = (ImageOperator) opIterator.previous();
int height = heightIterator.previous();
int width = widthIterator.previous();
point = imageOperator.inverseTransform(point, height, width);
}
return point;
}
/**
* Transforms a rectangle from coordinates system of the result image back to the one of the input
* image.
*
* @param rect the rectangle from the result coordinates system.
* @param inputImageHeight the height of input image.
* @param inputImageWidth the width of input image.
* @return the rectangle with the coordinates from the coordinates system of the input image.
*/
public RectF inverseTransform(RectF rect, int inputImageHeight, int inputImageWidth) {
// when rotation is involved, corner order may change - top left changes to bottom right, .etc
PointF p1 =
inverseTransform(new PointF(rect.left, rect.top), inputImageHeight, inputImageWidth);
PointF p2 =
inverseTransform(new PointF(rect.right, rect.bottom), inputImageHeight, inputImageWidth);
return new RectF(
Math.min(p1.x, p2.x), Math.min(p1.y, p2.y), Math.max(p1.x, p2.x), Math.max(p1.y, p2.y));
}
/**
* The Builder to create an ImageProcessor, which could be executed later.
*
* @see #add(TensorOperator) to add a general TensorOperator
* @see #add(ImageOperator) to add an ImageOperator
* @see #build() complete the building process and get a built Processor
*/
public static class Builder extends SequentialProcessor.Builder<TensorImage> {
public Builder() {
super();
}
/**
* Adds an {@link ImageOperator} into the Operator chain.
*
* @param op the Operator instance to be executed then
*/
public Builder add(ImageOperator op) {
super.add(op);
return this;
}
/**
* Adds a {@link TensorOperator} into the Operator chain. In execution, the processor calls
* {@link TensorImage#getTensorBuffer()} to transform the {@link TensorImage} by transforming
* the underlying {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
*
* @param op the Operator instance to be executed then
*/
public Builder add(TensorOperator op) {
return add(new TensorOperatorWrapper(op));
}
/** Completes the building process and gets the {@link ImageProcessor} instance. */
@Override
public ImageProcessor build() {
return new ImageProcessor(this);
}
}
/**
* Updates the number of rotations for the first {@link Rot90Op} in this {@link ImageProcessor}.
*
* <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
* then processing images (using {@link #process}) must be protected from concurrent access with
* additional synchronization.
*
* @param k the number of rotations
* @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
* ImageProcessor}
*/
public void updateNumberOfRotations(int k) {
updateNumberOfRotations(k, /*occurrence=*/ 0);
}
/**
* Updates the number of rotations for the {@link Rot90Op} specified by {@code occurrence} in this
* {@link ImageProcessor}.
*
* <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and
* then processing images (using {@link #process}) must be protected from concurrent access with
* additional synchronization.
*
* @param k the number of rotations
* @param occurrence the index of perticular {@link Rot90Op} in this {@link ImageProcessor}. For
* example, if the second {@link Rot90Op} needs to be updated, {@code occurrence} should be
* set to 1.
* @throws IndexOutOfBoundsException if {@code occurrence} is negative or is not less than the
* number of {@link Rot90Op} in this {@link ImageProcessor}
* @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
* ImageProcessor}
*/
public synchronized void updateNumberOfRotations(int k, int occurrence) {
SupportPreconditions.checkState(
operatorIndex.containsKey(Rot90Op.class.getName()),
"The Rot90Op has not been added to the ImageProcessor.");
List<Integer> indexes = operatorIndex.get(Rot90Op.class.getName());
SupportPreconditions.checkElementIndex(occurrence, indexes.size(), "occurrence");
// The index of the Rot90Op to be replaced in operatorList.
int index = indexes.get(occurrence);
Rot90Op newRot = new Rot90Op(k);
operatorList.set(index, newRot);
}
}

View File

@ -1,381 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.image;
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import java.nio.ByteBuffer;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.common.SupportPreconditions;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/**
* TensorImage is the wrapper class for Image object. When using image processing utils in
* TFLite.support library, it's common to convert image objects in variant types to TensorImage at
* first.
*
* <p>At present, only RGB images are supported, and the A channel is always ignored.
*
* <p>Details of data storage: a {@link TensorImage} object may have 2 potential sources of truth: a
* {@link Bitmap} or a {@link TensorBuffer}. {@link TensorImage} maintains the state and only
* convert one to the other when needed.
*
* <p>IMPORTANT: The container doesn't own its data. Callers should not modify data objects those
* are passed to {@link ImageContainer#set(Bitmap)} or {@link ImageContainer#set(TensorBuffer)}.
*
* <p>IMPORTANT: All methods are not proved thread-safe.
*
* @see ImageProcessor which is often used for transforming a {@link TensorImage}.
*/
// TODO(b/138906681): Support basic Image properties (ColorType, DataType)
// TODO(b/138907116): Support loading images from TensorBuffer with properties.
// TODO(b/138905544): Support directly loading RGBBytes, YUVBytes and other types if necessary.
public class TensorImage {
private final ImageContainer container;
/**
* Initialize a TensorImage object.
*
* Note: The data type of this TensorImage is UINT8, which means it could naturally accept Bitmaps
* whose pixel value range is [0, 255]. However, any image with float value pixels will not be
* loaded correctly. In those cases, please use {@link TensorImage(DataType)}.
*/
public TensorImage() {
this(DataType.UINT8);
}
/**
* Initializes a TensorImage object with data type specified.
*
* <p>Note: The shape of a TensorImage is not fixed. It is determined when {@code load} methods
* called, and could be change later.
*
* @param dataType the expected internal data type of underlying tensor. The type is always fixed
* during the lifetime of the {@link TensorImage}. To convert the data type, use {@link
* TensorImage#createFrom(TensorImage, DataType)} to create a copy and convert data type at
* the same time.
* @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor
* {@link DataType#FLOAT32}.
*/
public TensorImage(DataType dataType) {
SupportPreconditions.checkArgument(
dataType == DataType.UINT8 || dataType == DataType.FLOAT32,
"Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted");
container = new ImageContainer(dataType);
}
/**
* Initializes a {@link TensorImage} object with a {@link Bitmap}.
*
* @see TensorImage#load(Bitmap) for reusing the object when it's expensive to create objects
* frequently, because every call of {@code fromBitmap} creates a new {@link TensorImage}.
*/
public static TensorImage fromBitmap(Bitmap bitmap) {
TensorImage image = new TensorImage();
image.load(bitmap);
return image;
}
/**
* Creates a deep-copy of a given {@link TensorImage} and converts internal tensor data type.
*
* <p>If the given {@code dataType} is different with {@code src.getDataType()}, an implicit data
* conversion will be applied. Converting data from {@link DataType#FLOAT32} to {@link
* DataType#UINT8} may involve default float->int conversion and value clamping, because {@link
* DataType#UINT8} stores value from 0 to 255 (inclusively).
*
* @param src the TensorImage to copy from.
* @param dataType the expected data type of newly created {@link TensorImage}.
* @return a TensorImage whose data is copied from {@code src} and data type is {@code dataType}.
*/
@NonNull
public static TensorImage createFrom(@NonNull TensorImage src, DataType dataType) {
TensorImage dst = new TensorImage(dataType);
if (src.container.isBufferUpdated) {
dst.container.set(TensorBuffer.createFrom(src.getTensorBuffer(), dataType));
} else if (src.container.isBitmapUpdated) {
Bitmap srcBitmap = src.getBitmap();
dst.container.set(srcBitmap.copy(srcBitmap.getConfig(), srcBitmap.isMutable()));
}
return dst;
}
/**
* Loads a Bitmap image object into TensorImage.
*
* Important: When loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore. The
* {@code TensorImage} object will rely on the bitmap. It will probably modify the bitmap as well.
* In this method, we perform a zero-copy approach for that bitmap, by simply holding its
* reference. Use {@code bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary.
*
* Note: To get the best performance, please load images in the same shape to avoid memory
* re-allocation.
*
* @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888.
*/
public void load(@NonNull Bitmap bitmap) {
SupportPreconditions.checkNotNull(bitmap, "Cannot load null bitmap.");
SupportPreconditions.checkArgument(
bitmap.getConfig().equals(Config.ARGB_8888), "Only supports loading ARGB_8888 bitmaps.");
container.set(bitmap);
}
/**
* Loads a float array as RGB pixels into TensorImage, representing the pixels inside.
*
* <p>Note: If the TensorImage has data type {@link DataType#UINT8}, numeric casting and clamping
* will be applied.
*
* @param pixels The RGB pixels representing the image.
* @param shape The shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3).
*/
public void load(@NonNull float[] pixels, @NonNull int[] shape) {
checkImageTensorShape(shape);
TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
buffer.loadArray(pixels, shape);
load(buffer);
}
/**
* Loads an uint8 array as RGB pixels into TensorImage, representing the pixels inside.
*
* <p>Note: If the TensorImage has data type {@link DataType#UINT8}, all pixel values will clamp
* into [0, 255].
*
* @param pixels The RGB pixels representing the image.
* @param shape The shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3).
*/
public void load(@NonNull int[] pixels, @NonNull int[] shape) {
checkImageTensorShape(shape);
TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
buffer.loadArray(pixels, shape);
load(buffer);
}
/**
* Loads a TensorBuffer containing pixel values. The color layout should be RGB.
*
* @param buffer The TensorBuffer to load. Its shape should be either (h, w, 3) or (1, h, w, 3).
*/
public void load(TensorBuffer buffer) {
checkImageTensorShape(buffer.getShape());
container.set(buffer);
}
/**
* Returns a bitmap representation of this TensorImage.
*
* <p>Important: It's only a reference. DO NOT MODIFY. We don't create a copy here for performance
* concern, but if modification is necessary, please make a copy.
*
* @return a reference to a Bitmap in ARGB_8888 config. "A" channel is always opaque.
* @throws IllegalStateException if the TensorImage never loads data, or if the TensorImage is
* holding a float-value image in {@code TensorBuffer}.
*/
@NonNull
public Bitmap getBitmap() {
return container.getBitmap();
}
/**
* Returns a ByteBuffer representation of this TensorImage.
*
* <p>Important: It's only a reference. DO NOT MODIFY. We don't create a copy here for performance
* concern, but if modification is necessary, please make a copy.
*
* <p>It's essentially a short cut for {@code getTensorBuffer().getBuffer()}.
*
* @return a reference to a ByteBuffer which holds the image data.
* @throws IllegalStateException if the TensorImage never loads data.
*/
@NonNull
public ByteBuffer getBuffer() {
return container.getTensorBuffer().getBuffer();
}
/**
* Returns a ByteBuffer representation of this TensorImage.
*
* <p>Important: It's only a reference. DO NOT MODIFY. We don't create a copy here for performance
* concern, but if modification is necessary, please make a copy.
*
* @return a reference to a TensorBuffer which holds the image data.
* @throws IllegalStateException if the TensorImage never loads data.
*/
@NonNull
public TensorBuffer getTensorBuffer() {
return container.getTensorBuffer();
}
/**
* Gets the current data type.
*
* @return a data type. Currently only UINT8 and FLOAT32 are possible.
*/
public DataType getDataType() {
return container.getDataType();
}
/**
* Gets the image width.
*
* @throws IllegalStateException if the TensorImage never loads data.
* @throws IllegalArgumentException if the container data is corrupted.
*/
public int getWidth() {
return container.getWidth();
}
/**
* Gets the image height.
*
* @throws IllegalStateException if the TensorImage never loads data.
* @throws IllegalArgumentException if the container data is corrupted.
*/
public int getHeight() {
return container.getHeight();
}
// Requires tensor shape [h, w, 3] or [1, h, w, 3].
static void checkImageTensorShape(int[] shape) {
SupportPreconditions.checkArgument(
(shape.length == 3 || (shape.length == 4 && shape[0] == 1))
&& shape[shape.length - 3] > 0
&& shape[shape.length - 2] > 0
&& shape[shape.length - 1] == 3,
"Only supports image shape in (h, w, c) or (1, h, w, c), and channels representing R, G, B"
+ " in order.");
}
// Handles RGB image data storage strategy of TensorBuffer.
private static class ImageContainer {
private TensorBuffer bufferImage;
private boolean isBufferUpdated;
private Bitmap bitmapImage;
private boolean isBitmapUpdated;
private final DataType dataType;
private static final int ARGB_8888_ELEMENT_BYTES = 4;
ImageContainer(DataType dataType) {
this.dataType = dataType;
}
// Internal method to set the image source-of-truth with a bitmap. The bitmap has to be
// ARGB_8888.
void set(Bitmap bitmap) {
bitmapImage = bitmap;
isBufferUpdated = false;
isBitmapUpdated = true;
}
// Internal method to set the image source-of-truth with a TensorBuffer.
void set(TensorBuffer buffer) {
bufferImage = buffer;
isBitmapUpdated = false;
isBufferUpdated = true;
}
int getWidth() {
SupportPreconditions.checkState(
isBitmapUpdated || isBufferUpdated,
"Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
if (isBitmapUpdated) {
return bitmapImage.getWidth();
}
return getBufferDimensionSize(-2);
}
int getHeight() {
SupportPreconditions.checkState(
isBitmapUpdated || isBufferUpdated,
"Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
if (isBitmapUpdated) {
return bitmapImage.getHeight();
}
return getBufferDimensionSize(-3);
}
// Internal helper method to get the size of one dimension in the shape of the `bufferImage`.
// Requires `isBufferUpdated` is true.
// Throws `IllegalArgumentException` if data is corrupted.
private int getBufferDimensionSize(int dim) {
int[] shape = bufferImage.getShape();
// The defensive check is needed because bufferImage might be invalidly changed by user
// (a.k.a internal data is corrupted)
TensorImage.checkImageTensorShape(shape);
dim = dim % shape.length;
if (dim < 0) {
dim += shape.length;
}
return shape[dim];
}
public DataType getDataType() {
return dataType;
}
// Internal method to update the internal Bitmap data by TensorBuffer data.
@NonNull
Bitmap getBitmap() {
if (isBitmapUpdated) {
return bitmapImage;
}
if (!isBufferUpdated) {
throw new IllegalStateException(
"Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
}
if (bufferImage.getDataType() != DataType.UINT8) {
throw new IllegalStateException(
"TensorImage is holding a float-value image which is not able to convert a Bitmap.");
}
int requiredAllocation = bufferImage.getFlatSize() * ARGB_8888_ELEMENT_BYTES;
// Create a new bitmap and reallocate memory for it.
if (bitmapImage == null || bitmapImage.getAllocationByteCount() < requiredAllocation) {
int[] shape = bufferImage.getShape();
int h = shape[shape.length - 3];
int w = shape[shape.length - 2];
bitmapImage = Bitmap.createBitmap(w, h, Config.ARGB_8888);
}
ImageConversions.convertTensorBufferToBitmap(bufferImage, bitmapImage);
isBitmapUpdated = true;
return bitmapImage;
}
// Internal method to update the internal TensorBuffer data by Bitmap data.
@NonNull
TensorBuffer getTensorBuffer() {
if (isBufferUpdated) {
return bufferImage;
}
SupportPreconditions.checkArgument(
isBitmapUpdated,
"Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
int requiredFlatSize = bitmapImage.getWidth() * bitmapImage.getHeight() * 3;
if (bufferImage == null
|| (!bufferImage.isDynamic() && bufferImage.getFlatSize() != requiredFlatSize)) {
bufferImage = TensorBuffer.createDynamic(dataType);
}
ImageConversions.convertBitmapToTensorBuffer(bitmapImage, bufferImage);
isBufferUpdated = true;
return bufferImage;
}
}
}

View File

@ -1,89 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.image.ops;
import android.graphics.Bitmap;
import android.graphics.PointF;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.image.ImageOperator;
import org.tensorflow.lite.support.image.TensorImage;
/**
* As a computation unit for processing images, it can resize an image to user-specified size.
*
* <p>It interpolates pixels when image is stretched, and discards pixels when image is compressed.
*
* @see ResizeWithCropOrPadOp for resizing without content distortion.
*/
public class ResizeOp implements ImageOperator {
/** Algorithms for resizing. */
public enum ResizeMethod {
BILINEAR,
NEAREST_NEIGHBOR
}
private final int targetHeight;
private final int targetWidth;
private final boolean useBilinear;
/**
* Creates a ResizeOp which can resize images to specified size in specified method.
*
* @param targetHeight: The expected height of resized image.
* @param targetWidth: The expected width of resized image.
* @param resizeMethod: The algorithm to use for resizing. Options: {@link ResizeMethod}
*/
public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod) {
this.targetHeight = targetHeight;
this.targetWidth = targetWidth;
useBilinear = (resizeMethod == ResizeMethod.BILINEAR);
}
/**
* Applies the defined resizing on given image and returns the result.
*
* <p>Note: the content of input {@code image} will change, and {@code image} is the same instance
* with the output.
*
* @param image input image.
* @return output image.
*/
@Override
@NonNull
public TensorImage apply(@NonNull TensorImage image) {
Bitmap scaled =
Bitmap.createScaledBitmap(image.getBitmap(), targetWidth, targetHeight, useBilinear);
image.load(scaled);
return image;
}
@Override
public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
return targetHeight;
}
@Override
public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
return targetWidth;
}
@Override
public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
return new PointF(
point.x * inputImageWidth / targetWidth, point.y * inputImageHeight / targetHeight);
}
}

View File

@ -1,125 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.image.ops;
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.PointF;
import android.graphics.Rect;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.image.ImageOperator;
import org.tensorflow.lite.support.image.TensorImage;
/**
* As a computation unit for processing images, it could resize image to predefined size.
*
* <p>It will not stretch or compress the content of image. However, to fit the new size, it crops
* or pads pixels. When it crops image, it performs a center-crop; when it pads pixels, it performs
* a zero-padding.
*
* @see ResizeOp for reszing images while stretching / compressing the content.
*/
public class ResizeWithCropOrPadOp implements ImageOperator {
private final int targetHeight;
private final int targetWidth;
private final Bitmap output;
/**
* Creates a ResizeWithCropOrPadOp which could crop/pad images to specified size. It adopts
* center-crop and zero-padding.
*
* @param targetHeight: The expected height of cropped/padded image.
* @param targetWidth: The expected width of cropped/padded image.
*/
public ResizeWithCropOrPadOp(int targetHeight, int targetWidth) {
this.targetHeight = targetHeight;
this.targetWidth = targetWidth;
output = Bitmap.createBitmap(this.targetWidth, this.targetHeight, Config.ARGB_8888);
}
/**
* Applies the defined resizing with cropping or/and padding on given image and returns the
* result.
*
* <p>Note: the content of input {@code image} will change, and {@code image} is the same instance
* with the output.
*
* @param image input image.
* @return output image.
*/
@Override
@NonNull
public TensorImage apply(@NonNull TensorImage image) {
Bitmap input = image.getBitmap();
int srcL;
int srcR;
int srcT;
int srcB;
int dstL;
int dstR;
int dstT;
int dstB;
int w = input.getWidth();
int h = input.getHeight();
if (targetWidth > w) { // padding
srcL = 0;
srcR = w;
dstL = (targetWidth - w) / 2;
dstR = dstL + w;
} else { // cropping
dstL = 0;
dstR = targetWidth;
srcL = (w - targetWidth) / 2;
srcR = srcL + targetWidth;
}
if (targetHeight > h) { // padding
srcT = 0;
srcB = h;
dstT = (targetHeight - h) / 2;
dstB = dstT + h;
} else { // cropping
dstT = 0;
dstB = targetHeight;
srcT = (h - targetHeight) / 2;
srcB = srcT + targetHeight;
}
Rect src = new Rect(srcL, srcT, srcR, srcB);
Rect dst = new Rect(dstL, dstT, dstR, dstB);
new Canvas(output).drawBitmap(input, src, dst, null);
image.load(output);
return image;
}
@Override
public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
return targetHeight;
}
@Override
public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
return targetWidth;
}
@Override
public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
return transformImpl(point, targetHeight, targetWidth, inputImageHeight, inputImageWidth);
}
private static PointF transformImpl(PointF point, int srcH, int srcW, int dstH, int dstW) {
return new PointF(point.x + (dstW - srcW) / 2, point.y + (dstH - srcH) / 2);
}
}

View File

@ -1,103 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.image.ops;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.graphics.PointF;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.image.ImageOperator;
import org.tensorflow.lite.support.image.TensorImage;
/** Rotates image counter-clockwise. */
public class Rot90Op implements ImageOperator {
private final int numRotation;
/** Creates a Rot90 Op which will rotate image by 90 degree counter-clockwise. */
public Rot90Op() {
this(1);
}
/**
* Creates a Rot90 Op which will rotate image by 90 degree for {@code k} times counter-clockwise.
*
* @param k: The number of times the image is rotated by 90 degrees. If it's positive, the image
* will be rotated counter-clockwise. If it's negative, the op will rotate image clockwise.
*/
public Rot90Op(int k) {
numRotation = k % 4;
}
/**
* Applies the defined rotation on given image and returns the result.
*
* <p>Note: the content of input {@code image} will change, and {@code image} is the same instance
* with the output.
*
* @param image input image.
* @return output image.
*/
@NonNull
@Override
public TensorImage apply(@NonNull TensorImage image) {
Bitmap input = image.getBitmap();
if (numRotation == 0) {
return image;
}
int w = input.getWidth();
int h = input.getHeight();
Matrix matrix = new Matrix();
matrix.postTranslate(w * 0.5f, h * 0.5f);
matrix.postRotate(-90 * numRotation);
int newW = (numRotation % 2 == 0) ? w : h;
int newH = (numRotation % 2 == 0) ? h : w;
matrix.postTranslate(newW * 0.5f, newH * 0.5f);
Bitmap output = Bitmap.createBitmap(input, 0, 0, w, h, matrix, false);
image.load(output);
return image;
}
@Override
public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
return (numRotation % 2 == 0) ? inputImageHeight : inputImageWidth;
}
@Override
public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
return (numRotation % 2 == 0) ? inputImageWidth : inputImageHeight;
}
@Override
public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
int inverseNumRotation = (4 - numRotation) % 4;
int height = getOutputImageHeight(inputImageHeight, inputImageWidth);
int width = getOutputImageWidth(inputImageHeight, inputImageWidth);
return transformImpl(point, height, width, inverseNumRotation);
}
private static PointF transformImpl(PointF point, int height, int width, int numRotation) {
if (numRotation == 0) {
return point;
} else if (numRotation == 1) {
return new PointF(point.y, width - point.x);
} else if (numRotation == 2) {
return new PointF(width - point.x, height - point.y);
} else { // numRotation == 3
return new PointF(height - point.y, point.x);
}
}
}

View File

@ -1,70 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.image.ops;
import android.graphics.PointF;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.common.SupportPreconditions;
import org.tensorflow.lite.support.common.TensorOperator;
import org.tensorflow.lite.support.image.ImageOperator;
import org.tensorflow.lite.support.image.TensorImage;
/**
* The adapter that makes a TensorOperator able to run with TensorImage.
*
* @see org.tensorflow.lite.support.common.TensorOperator
* @see org.tensorflow.lite.support.image.TensorImage
*/
public class TensorOperatorWrapper implements ImageOperator {
private final TensorOperator tensorOp;
/**
* Wraps a {@link TensorOperator} object as an {@link ImageOperator}, so that the {@link
* TensorOperator} could handle {@link TensorImage} objects by handling its underlying {@link
* org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
*
* <p>Requirement: The {@code op} should not change coordinate system when applied on an image.
*
* @param op The created operator.
*/
public TensorOperatorWrapper(TensorOperator op) {
tensorOp = op;
}
@Override
@NonNull
public TensorImage apply(@NonNull TensorImage image) {
SupportPreconditions.checkNotNull(image, "Op cannot apply on null image.");
image.load(tensorOp.apply(image.getTensorBuffer()));
return image;
}
@Override
public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
return inputImageHeight;
}
@Override
public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
return inputImageWidth;
}
@Override
public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
return point;
}
}

View File

@ -1,62 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.label;
import java.util.Objects;
/**
* Category is a util class, contains a label and a float value. Typically it's used as result of
* classification tasks.
*/
public final class Category {
private final String label;
private final float score;
/** Constructs a Category. */
public Category(String label, float score) {
this.label = label;
this.score = score;
}
/** Gets the reference of category's label. */
public String getLabel() {
return label;
}
/** Gets the score of the category. */
public float getScore() {
return score;
}
@Override
public boolean equals(Object o) {
if (o instanceof Category) {
Category other = (Category) o;
return (other.getLabel().equals(this.label) && other.getScore() == this.score);
}
return false;
}
@Override
public int hashCode() {
return Objects.hash(label, score);
}
@Override
public String toString() {
return "<Category \"" + label + "\" (score=" + score + ")>";
}
}

View File

@ -1,64 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.label;
import android.util.Log;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.common.SupportPreconditions;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/** Label operation utils. */
public class LabelUtil {
/**
* Maps an int value tensor to a list of string labels. It takes an array of strings as the
* dictionary. Example: if the given tensor is [3, 1, 0], and given labels is ["background",
* "apple", "banana", "cherry", "date"], the result will be ["date", "banana", "apple"].
*
* @param tensorBuffer: A tensor with index values. The values should be non-negative integers,
* and each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor is
* given as a float {@link TensorBuffer}, values will be cast to integers. All values that are
* out of bound will map to empty string.
* @param labels: A list of strings, used as a dictionary to look up. The index of the array
* element will be used as the key. To get better performance, use an object that implements
* RandomAccess, such as {@link ArrayList}.
* @param offset: The offset value when look up int values in the {@code labels}.
* @return the mapped strings. The length of the list is {@link TensorBuffer#getFlatSize}.
* @throws IllegalArgumentException if {@code tensorBuffer} or {@code labels} is null.
*/
public static List<String> mapValueToLabels(
@NonNull TensorBuffer tensorBuffer, @NonNull List<String> labels, int offset) {
SupportPreconditions.checkNotNull(tensorBuffer, "Given tensor should not be null");
SupportPreconditions.checkNotNull(labels, "Given labels should not be null");
int[] values = tensorBuffer.getIntArray();
Log.d("values", Arrays.toString(values));
List<String> result = new ArrayList<>();
for (int v : values) {
int index = v + offset;
if (index < 0 || index >= labels.size()) {
result.add("");
} else {
result.add(labels.get(index));
}
}
return result;
}
// Private constructor to prevent initialization.
private LabelUtil() {}
}

View File

@ -1,224 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.label;
import android.content.Context;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.common.SupportPreconditions;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/**
* TensorLabel is an util wrapper for TensorBuffers with meaningful labels on an axis.
*
* <p>For example, an image classification model may have an output tensor with shape as {1, 10},
* where 1 is the batch size and 10 is the number of categories. In fact, on the 2nd axis, we could
* label each sub-tensor with the name or description of each corresponding category. {@link
* TensorLabel} could help converting the plain Tensor in {@link TensorBuffer} into a map from
* predefined labels to sub-tensors. In this case, if provided 10 labels for the 2nd axis, {@link
* TensorLabel} could convert the original {1, 10} Tensor to a 10 element map, each value of which
* is Tensor in shape {} (scalar). Usage example:
*
* <pre>
* TensorBuffer outputTensor = ...;
* {@literal List<String>} labels = FileUtil.loadLabels(context, labelFilePath);
* // labels the first axis with size greater than one
* TensorLabel labeled = new TensorLabel(labels, outputTensor);
* // If each sub-tensor has effectively size 1, we can directly get a float value
* {@literal Map<String, Float>} probabilities = labeled.getMapWithFloatValue();
* // Or get sub-tensors, when each sub-tensor has elements more than 1
* {@literal Map<String, TensorBuffer>} subTensors = labeled.getMapWithTensorBuffer();
* </pre>
*
* <p>Note: currently we only support tensor-to-map conversion for the first label with size greater
* than 1.
*
* @see org.tensorflow.lite.support.common.FileUtil#loadLabels(Context, String) to load labels from
* a label file (plain text file whose each line is a label) in assets simply.
*/
public class TensorLabel {
private final Map<Integer, List<String>> axisLabels;
private final TensorBuffer tensorBuffer;
private final int[] shape;
/**
* Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors.
*
* @param axisLabels A map, whose key is axis id (starting from 0) and value is corresponding
* labels. Note: The size of labels should be same with the size of the tensor on that axis.
* @param tensorBuffer The TensorBuffer to be labeled.
* @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any
* value in {@code axisLabels} is null.
* @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared to
* the shape of {@code tensorBuffer}, or any value (labels) has different size with the {@code
* tensorBuffer} on the given dimension.
*/
public TensorLabel(
@NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) {
SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null.");
SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null.");
this.axisLabels = axisLabels;
this.tensorBuffer = tensorBuffer;
this.shape = tensorBuffer.getShape();
for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) {
int axis = entry.getKey();
SupportPreconditions.checkArgument(
axis >= 0 && axis < shape.length, "Invalid axis id: " + axis);
SupportPreconditions.checkNotNull(entry.getValue(), "Label list is null on axis " + axis);
SupportPreconditions.checkArgument(
shape[axis] == entry.getValue().size(),
"Label number " + entry.getValue().size() + " mismatch the shape on axis " + axis);
}
}
/**
* Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors.
*
* <p>Note: The labels are applied on the first axis whose size is larger than 1. For example, if
* the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting from
* 0), and size of {@code axisLabels} should be 10 as well.
*
* @param axisLabels A list of labels, whose size should be same with the size of the tensor on
* the to-be-labeled axis.
* @param tensorBuffer The TensorBuffer to be labeled.
*/
public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) {
this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer);
}
/**
* Gets the map with a pair of the label and the corresponding TensorBuffer. Only allow the
* mapping on the first axis with size greater than 1 currently.
*/
@NonNull
public Map<String, TensorBuffer> getMapWithTensorBuffer() {
int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>();
SupportPreconditions.checkArgument(
axisLabels.containsKey(labeledAxis),
"get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis.");
List<String> labels = axisLabels.get(labeledAxis);
DataType dataType = tensorBuffer.getDataType();
int typeSize = tensorBuffer.getTypeSize();
int flatSize = tensorBuffer.getFlatSize();
// Gets the underlying bytes that could be used to generate the sub-array later.
ByteBuffer byteBuffer = tensorBuffer.getBuffer();
byteBuffer.rewind();
// Note: computation below is only correct when labeledAxis is the first axis with size greater
// than 1.
int subArrayLength = flatSize / shape[labeledAxis] * typeSize;
int i = 0;
SupportPreconditions.checkNotNull(labels, "Label list should never be null");
for (String label : labels) {
// Gets the corresponding TensorBuffer.
byteBuffer.position(i * subArrayLength);
ByteBuffer subBuffer = byteBuffer.slice();
// ByteBuffer.slice doesn't keep order. Modify it to align with the original one.
subBuffer.order(byteBuffer.order()).limit(subArrayLength);
TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType);
labelBuffer.loadBuffer(subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length));
labelToTensorMap.put(label, labelBuffer);
i += 1;
}
return labelToTensorMap;
}
/**
* Gets a map that maps label to float. Only allow the mapping on the first axis with size greater
* than 1, and the axis should be effectively the last axis (which means every sub tensor
* specified by this axis should have a flat size of 1).
*
* <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result.
*
* @throws IllegalStateException if size of a sub tensor on each label is not 1.
*/
@NonNull
public Map<String, Float> getMapWithFloatValue() {
int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
SupportPreconditions.checkState(
labeledAxis == shape.length - 1,
"get a <String, Scalar> map is only valid when the only labeled axis is the last one.");
List<String> labels = axisLabels.get(labeledAxis);
float[] data = tensorBuffer.getFloatArray();
SupportPreconditions.checkState(labels.size() == data.length);
Map<String, Float> result = new LinkedHashMap<>();
int i = 0;
for (String label : labels) {
result.put(label, data[i]);
i += 1;
}
return result;
}
/**
* Gets a list of {@link Category} from the {@link TensorLabel} object.
*
* <p>The axis of label should be effectively the last axis (which means every sub tensor
* specified by this axis should have a flat size of 1), so that each labelled sub tensor could be
* converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2, 5, 3}}
* and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link Category}.
*
* <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as
* the result.
*
* @throws IllegalStateException if size of a sub tensor on each label is not 1.
*/
@NonNull
public List<Category> getCategoryList() {
int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
SupportPreconditions.checkState(
labeledAxis == shape.length - 1,
"get a Category list is only valid when the only labeled axis is the last one.");
List<String> labels = axisLabels.get(labeledAxis);
float[] data = tensorBuffer.getFloatArray();
SupportPreconditions.checkState(labels.size() == data.length);
List<Category> result = new ArrayList<>();
int i = 0;
for (String label : labels) {
result.add(new Category(label, data[i]));
i += 1;
}
return result;
}
private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) {
int[] shape = tensorBuffer.getShape();
for (int i = 0; i < shape.length; i++) {
if (shape[i] > 1) {
return i;
}
}
throw new IllegalArgumentException(
"Cannot find an axis to label. A valid axis to label should have size larger than 1.");
}
// Helper function to wrap the List<String> to a one-entry map.
private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) {
Map<Integer, List<String>> map = new LinkedHashMap<>();
map.put(axis, labels);
return map;
}
}

View File

@ -1,74 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.label.ops;
import android.content.Context;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.common.SupportPreconditions;
import org.tensorflow.lite.support.label.TensorLabel;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
/**
* Labels TensorBuffer with axisLabels for outputs.
*
* <p>Apply on a {@code TensorBuffer} to get a {@code TensorLabel} that could output a Map, which is
* a pair of the label name and the corresponding TensorBuffer value.
*/
public class LabelAxisOp {
// Axis and its corresponding label names.
private final Map<Integer, List<String>> axisLabels;
protected LabelAxisOp(Builder builder) {
axisLabels = builder.axisLabels;
}
public TensorLabel apply(@NonNull TensorBuffer buffer) {
SupportPreconditions.checkNotNull(buffer, "Tensor buffer cannot be null.");
return new TensorLabel(axisLabels, buffer);
}
/** The inner builder class to build a LabelTensor Operator. */
public static class Builder {
private final Map<Integer, List<String>> axisLabels;
protected Builder() {
axisLabels = new HashMap<>();
}
public Builder addAxisLabel(@NonNull Context context, int axis, @NonNull String filePath)
throws IOException {
SupportPreconditions.checkNotNull(context, "Context cannot be null.");
SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
List<String> labels = FileUtil.loadLabels(context, filePath);
axisLabels.put(axis, labels);
return this;
}
public Builder addAxisLabel(int axis, @NonNull List<String> labels) {
axisLabels.put(axis, labels);
return this;
}
public LabelAxisOp build() {
return new LabelAxisOp(this);
}
}
}

View File

@ -1,69 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.model;
import android.util.Log;
import java.io.Closeable;
import java.io.IOException;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.Delegate;
/**
* Helper class to create and call necessary methods of {@code GpuDelegate} which is not a strict
* dependency.
*/
class GpuDelegateProxy implements Delegate, Closeable {
private static final String TAG = "GpuDelegateProxy";
private final Delegate proxiedDelegate;
private final Closeable proxiedCloseable;
@Nullable
public static GpuDelegateProxy maybeNewInstance() {
try {
Class<?> clazz = Class.forName("org.tensorflow.lite.gpu.GpuDelegate");
Object instance = clazz.getDeclaredConstructor().newInstance();
return new GpuDelegateProxy(instance);
} catch (ReflectiveOperationException e) {
Log.e(TAG, "Failed to create the GpuDelegate dynamically.", e);
return null;
}
}
/** Calls {@code close()} method of the delegate. */
@Override
public void close() {
try {
proxiedCloseable.close();
} catch (IOException e) {
// Should not trigger, because GpuDelegate#close never throws. The catch is required because
// of Closeable#close.
Log.e(TAG, "Failed to close the GpuDelegate.", e);
}
}
/** Calls {@code getNativeHandle()} method of the delegate. */
@Override
public long getNativeHandle() {
return proxiedDelegate.getNativeHandle();
}
private GpuDelegateProxy(Object instance) {
this.proxiedCloseable = (Closeable) instance;
this.proxiedDelegate = (Delegate) instance;
}
}

View File

@ -1,285 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.model;
import android.content.Context;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.Tensor;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.common.SupportPreconditions;
/**
* The wrapper class for a TFLite model and a TFLite interpreter.
*
* <p>Note: A {@link Model} can only holds 1 TFLite model at a time, and always holds a TFLite
* interpreter instance to run it.
*/
public class Model {
/** The runtime device type used for executing classification. */
public enum Device {
CPU,
NNAPI,
GPU
}
/**
* Options for running the model. Configurable parameters includes:
*
* <ul>
* <li>{@code device} {@link Builder#setDevice(Device)} specifies the hardware to run the model.
* The default value is {@link Device#CPU}.
* <li>{@code numThreads} {@link Builder#setNumThreads(int)} specifies the number of threads
* used by TFLite inference. It's only effective when device is set to {@link Device#CPU}
* and default value is 1.
* </ul>
*/
public static class Options {
private final Device device;
private final int numThreads;
/** Builder of {@link Options}. See its doc for details. */
public static class Builder {
private Device device = Device.CPU;
private int numThreads = 1;
public Builder setDevice(Device device) {
this.device = device;
return this;
}
public Builder setNumThreads(int numThreads) {
this.numThreads = numThreads;
return this;
}
public Options build() {
return new Options(this);
}
}
private Options(Builder builder) {
device = builder.device;
numThreads = builder.numThreads;
}
}
/** An instance of the driver class to run model inference with Tensorflow Lite. */
private final Interpreter interpreter;
/** Path to tflite model file in asset folder. */
private final String modelPath;
/** The memory-mapped model data. */
private final MappedByteBuffer byteModel;
private final GpuDelegateProxy gpuDelegateProxy;
/**
* Builder for {@link Model}.
*
* @deprecated Please use {@link Model#createModel(Context, String, Options)}.
*/
@Deprecated
public static class Builder {
private Device device = Device.CPU;
private int numThreads = 1;
private final String modelPath;
private final MappedByteBuffer byteModel;
/**
* Creates a builder which loads tflite model from asset folder using memory-mapped files.
*
* @param context: Application context to access assets.
* @param modelPath: Asset path of the model (.tflite file).
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
@NonNull
public Builder(@NonNull Context context, @NonNull String modelPath) throws IOException {
this.modelPath = modelPath;
byteModel = FileUtil.loadMappedFile(context, modelPath);
}
/** Sets running device. By default, TFLite will run on CPU. */
@NonNull
public Builder setDevice(Device device) {
this.device = device;
return this;
}
/** Sets number of threads. By default it's 1. */
@NonNull
public Builder setNumThreads(int numThreads) {
this.numThreads = numThreads;
return this;
}
// Note: The implementation is copied from `Model#createModel`. As the builder is going to be
// deprecated, this function is also to be removed.
@NonNull
public Model build() {
Options options = new Options.Builder().setNumThreads(numThreads).setDevice(device).build();
return createModel(byteModel, modelPath, options);
}
}
/**
* Loads a model from assets and initialize TFLite interpreter.
*
* <p>The default options are: (1) CPU device; (2) one thread.
*
* @param context The App Context.
* @param modelPath The path of the model file.
* @throws IOException if any exception occurs when open the model file.
*/
public static Model createModel(@NonNull Context context, @NonNull String modelPath)
throws IOException {
return createModel(context, modelPath, new Options.Builder().build());
}
/**
* Loads a model from assets and initialize TFLite interpreter with given options.
*
* @see Options for details.
* @param context The App Context.
* @param modelPath The path of the model file.
* @param options The options for running the model.
* @throws IOException if any exception occurs when open the model file.
*/
public static Model createModel(
@NonNull Context context, @NonNull String modelPath, @NonNull Options options)
throws IOException {
SupportPreconditions.checkNotEmpty(
modelPath, "Model path in the asset folder cannot be empty.");
MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, modelPath);
return createModel(byteModel, modelPath, options);
}
/**
* Creates a model with loaded {@link MappedByteBuffer}.
*
* @see Options for details.
* @param byteModel The loaded TFLite model.
* @param modelPath The original path of the model. It can be fetched later by {@link
* Model#getPath()}.
* @param options The options for running the model.
* @throws IllegalArgumentException if {@code options.device} is {@link Device#GPU} but
* "tensorflow-lite-gpu" is not linked to the project.
*/
public static Model createModel(
@NonNull MappedByteBuffer byteModel, @NonNull String modelPath, @NonNull Options options) {
Interpreter.Options interpreterOptions = new Interpreter.Options();
GpuDelegateProxy gpuDelegateProxy = null;
switch (options.device) {
case NNAPI:
interpreterOptions.setUseNNAPI(true);
break;
case GPU:
gpuDelegateProxy = GpuDelegateProxy.maybeNewInstance();
SupportPreconditions.checkArgument(
gpuDelegateProxy != null,
"Cannot inference with GPU. Did you add \"tensorflow-lite-gpu\" as dependency?");
interpreterOptions.addDelegate(gpuDelegateProxy);
break;
case CPU:
break;
}
interpreterOptions.setNumThreads(options.numThreads);
Interpreter interpreter = new Interpreter(byteModel, interpreterOptions);
return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy);
}
/** Returns the memory-mapped model data. */
@NonNull
public MappedByteBuffer getData() {
return byteModel;
}
/** Returns the path of the model file stored in Assets. */
@NonNull
public String getPath() {
return modelPath;
}
/**
* Gets the Tensor associated with the provdied input index.
*
* @throws IllegalStateException if the interpreter is closed.
*/
public Tensor getInputTensor(int inputIndex) {
return interpreter.getInputTensor(inputIndex);
}
/**
* Gets the Tensor associated with the provdied output index.
*
* @throws IllegalStateException if the interpreter is closed.
*/
public Tensor getOutputTensor(int outputIndex) {
return interpreter.getOutputTensor(outputIndex);
}
/**
* Returns the output shape. Useful if output shape is only determined when graph is created.
*
* @throws IllegalStateException if the interpreter is closed.
*/
public int[] getOutputTensorShape(int outputIndex) {
return interpreter.getOutputTensor(outputIndex).shape();
}
/**
* Runs model inference on multiple inputs, and returns multiple outputs.
*
* @param inputs an array of input data. The inputs should be in the same order as inputs of the
* model. Each input can be an array or multidimensional array, or a {@link
* java.nio.ByteBuffer} of primitive types including int, float, long, and byte. {@link
* java.nio.ByteBuffer} is the preferred way to pass large input data, whereas string types
* require using the (multi-dimensional) array input path. When {@link java.nio.ByteBuffer} is
* used, its content should remain unchanged until model inference is done.
* @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
* java.nio.ByteBuffer}s of primitive types including int, float, long, and byte. It only
* needs to keep entries for the outputs to be used.
*/
public void run(@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
interpreter.runForMultipleInputsOutputs(inputs, outputs);
}
public void close() {
if (interpreter != null) {
interpreter.close();
}
if (gpuDelegateProxy != null) {
gpuDelegateProxy.close();
}
}
private Model(
@NonNull String modelPath,
@NonNull MappedByteBuffer byteModel,
@NonNull Interpreter interpreter,
@Nullable GpuDelegateProxy gpuDelegateProxy) {
this.modelPath = modelPath;
this.byteModel = byteModel;
this.interpreter = interpreter;
this.gpuDelegateProxy = gpuDelegateProxy;
}
}

View File

@ -1,412 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.tensorbuffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.common.SupportPreconditions;
/** Represents the data buffer for either a model's input or its output. */
public abstract class TensorBuffer {
/** Where the data is stored. */
protected ByteBuffer buffer;
/** Shape of the tensor stored in this buffer. */
protected int[] shape;
/** Number of elements in the buffer. It will be changed to a proper value in the constructor. */
protected int flatSize = -1;
/**
* Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers will have
* pre-allocated memory and fixed size. While the size of dynamic buffers can be changed.
*/
protected final boolean isDynamic;
/**
* Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are some
* examples:
*
* <pre>
* Creating a float TensorBuffer with shape {2, 3}:
* int[] shape = new int[] {2, 3};
* TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
* </pre>
*
* <pre>
* Creating an uint8 TensorBuffer of a scalar:
* int[] shape = new int[] {};
* TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
* </pre>
*
* <pre>
* Creating an empty uint8 TensorBuffer:
* int[] shape = new int[] {0};
* TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
* </pre>
*
* <p>The size of a fixed-size TensorBuffer cannot be changed once it is created.
*
* @param shape The shape of the {@link TensorBuffer} to be created.
* @param dataType The dataType of the {@link TensorBuffer} to be created.
* @throws NullPointerException if {@code shape} is null.
* @throws IllegalArgumentException if {@code shape} has non-positive elements.
*/
@NonNull
public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) {
switch (dataType) {
case FLOAT32:
return new TensorBufferFloat(shape);
case UINT8:
return new TensorBufferUint8(shape);
default:
throw new AssertionError("TensorBuffer does not support data type: " + dataType);
}
}
/**
* Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of the
* created {@link TensorBuffer} is {0}.
*
* <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of
* different buffer sizes.
*
* @param dataType The dataType of the {@link TensorBuffer} to be created.
*/
@NonNull
public static TensorBuffer createDynamic(DataType dataType) {
switch (dataType) {
case FLOAT32:
return new TensorBufferFloat();
case UINT8:
return new TensorBufferUint8();
default:
throw new AssertionError("TensorBuffer does not support data type: " + dataType);
}
}
/**
* Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link DataType}.
*
* @param buffer the source {@link TensorBuffer} to copy from.
* @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}.
* @throws NullPointerException if {@code buffer} is null.
*/
@NonNull
public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) {
SupportPreconditions.checkNotNull(buffer, "Cannot create a buffer from null");
TensorBuffer result;
if (buffer.isDynamic()) {
result = createDynamic(dataType);
} else {
result = createFixedSize(buffer.shape, dataType);
}
// The only scenario we need float array is FLOAT32->FLOAT32, or we can always use INT as
// intermediate container.
// The assumption is not true when we support other data types.
if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) {
float[] data = buffer.getFloatArray();
result.loadArray(data, buffer.shape);
} else {
int[] data = buffer.getIntArray();
result.loadArray(data, buffer.shape);
}
return result;
}
/** Returns the data buffer. */
@NonNull
public ByteBuffer getBuffer() {
return buffer;
}
/** Gets the {@link TensorBuffer#flatSize} of the buffer. */
public int getFlatSize() {
return flatSize;
}
/** Gets the current shape. (returning a copy here to avoid unexpected modification.) */
@NonNull
public int[] getShape() {
return Arrays.copyOf(shape, shape.length);
}
/** Returns the data type of this buffer. */
public abstract DataType getDataType();
/**
* Returns a float array of the values stored in this buffer. If the buffer is of different types
* than float, the values will be converted into float. For example, values in {@link
* TensorBufferUint8} will be converted from uint8 to float.
*/
@NonNull
public abstract float[] getFloatArray();
/**
* Returns a float value at a given index. If the buffer is of different types than float, the
* value will be converted into float. For example, when reading a value from {@link
* TensorBufferUint8}, the value will be first read out as uint8, and then will be converted from
* uint8 to float.
*
* <pre>
* For example, a TensorBuffer with shape {2, 3} that represents the following array,
* [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
*
* The fourth element (whose value is 3.0f) in the TensorBuffer can be retrived by:
* float v = tensorBuffer.getFloatValue(3);
* </pre>
*
* @param absIndex The absolute index of the value to be read.
*/
public abstract float getFloatValue(int absIndex);
/**
* Returns an int array of the values stored in this buffer. If the buffer is of different type
* than int, the values will be converted into int, and loss of precision may apply. For example,
* getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f}, the output
* is {400, 23}.
*/
@NonNull
public abstract int[] getIntArray();
/**
* Returns an int value at a given index. If the buffer is of different types than int, the value
* will be converted into int. For example, when reading a value from {@link TensorBufferFloat},
* the value will be first read out as float, and then will be converted from float to int. Loss
* of precision may apply.
*
* <pre>
* For example, a TensorBuffer with shape {2, 3} that represents the following array,
* [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
*
* The fourth element (whose value is 3.0f) in the TensorBuffer can be retrived by:
* int v = tensorBuffer.getIntValue(3);
* Note that v is converted from 3.0f to 3 as a result of type conversion.
* </pre>
*
* @param absIndex The absolute index of the value to be read.
*/
public abstract int getIntValue(int absIndex);
/**
* Returns the number of bytes of a single element in the array. For example, a float buffer will
* return 4, and a byte buffer will return 1.
*/
public abstract int getTypeSize();
/** Returns if the {@link TensorBuffer} is dynamic sized (could resize arbitrarily). */
public boolean isDynamic() {
return isDynamic;
}
/**
* Loads an int array into this buffer with specific shape. If the buffer is of different types
* than int, the values will be converted into the buffer's type before being loaded into the
* buffer, and loss of precision may apply. For example, loading an int array with values {400,
* -23} into a {@link TensorBufferUint8} , the values will be clamped to [0, 255] and then be
* casted to uint8 by {255, 0}.
*
* @param src The source array to be loaded.
* @param shape Shape of the tensor that {@code src} represents.
* @throws NullPointerException if {@code src} is null.
* @throws NullPointerException if {@code shape} is null.
* @throws IllegalArgumentException if the size of the array to be loaded does not match the
* specified shape.
*/
public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape);
/**
* Loads an int array into this buffer. If the buffer is of different types than int, the values
* will be converted into the buffer's type before being loaded into the buffer, and loss of
* precision may apply. For example, loading an int array with values {400, -23} into a {@link
* TensorBufferUint8} , the values will be clamped to [0, 255] and then be casted to uint8 by
* {255, 0}.
*
* <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for both
* fixed-size and dynamic {@link TensorBuffer}.
*
* @param src The source array to be loaded.
*/
public void loadArray(@NonNull int[] src) {
loadArray(src, shape);
}
/**
* Loads a float array into this buffer with specific shape. If the buffer is of different types
* than float, the values will be converted into the buffer's type before being loaded into the
* buffer, and loss of precision may apply. For example, loading a float array into a {@link
* TensorBufferUint8} with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and
* then be casted to uint8 by {255, 0}.
*
* @param src The source array to be loaded.
* @param shape Shape of the tensor that {@code src} represents.
* @throws NullPointerException if {@code src} is null.
* @throws NullPointerException if {@code shape} is null.
* @throws IllegalArgumentException if the size of the array to be loaded does not match the
* specified shape.
*/
public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape);
/**
* Loads a float array into this buffer. If the buffer is of different types than float, the
* values will be converted into the buffer's type before being loaded into the buffer, and loss
* of precision may apply. For example, loading a float array into a {@link TensorBufferUint8}
* with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and then be casted to
* uint8 by {255, 0}.
*
* <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for both
* fixed-size and dynamic {@link TensorBuffer}.
*
* @param src The source array to be loaded.
*/
public void loadArray(@NonNull float[] src) {
loadArray(src, shape);
}
/**
* Loads a byte buffer into this {@link TensorBuffer} with specific shape.
*
* <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for
* performance concern, but if modification is necessary, please make a copy.
*
* @param buffer The byte buffer to load.
* @throws NullPointerException if {@code buffer} is null.
* @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not
* match or the size of {@code buffer} and {@code flatSize} do not match.
*/
public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) {
SupportPreconditions.checkNotNull(buffer, "Byte buffer cannot be null.");
int flatSize = computeFlatSize(shape);
SupportPreconditions.checkArgument(
(buffer.limit() == getTypeSize() * flatSize),
"The size of byte buffer and the shape do not match.");
if (!isDynamic) {
SupportPreconditions.checkArgument(
flatSize == this.flatSize,
"The size of byte buffer and the size of the tensor buffer do not match.");
} else {
this.flatSize = flatSize;
}
this.shape = shape.clone();
buffer.rewind();
this.buffer = buffer;
}
/**
* Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of
* this {@link TensorBuffer}.
*
* <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for
* performance concern, but if modification is necessary, please make a copy.
*
* @param buffer The byte buffer to load.
*/
public void loadBuffer(@NonNull ByteBuffer buffer) {
loadBuffer(buffer, shape);
}
/**
* Constructs a fixed size {@link TensorBuffer} with specified {@code shape}.
*
* @throws NullPointerException if {@code shape} is null.
* @throws IllegalArgumentException if {@code shape} has non-positive elements.
*/
protected TensorBuffer(@NonNull int[] shape) {
isDynamic = false;
allocateMemory(shape);
}
/** Constructs a dynamic {@link TensorBuffer} which can be resized. */
protected TensorBuffer() {
isDynamic = true;
// Initialize the dynamic TensorBuffer with an empty ByteBuffer.
allocateMemory(new int[] {0});
}
/** Calculates number of elements in the buffer. */
protected static int computeFlatSize(@NonNull int[] shape) {
SupportPreconditions.checkNotNull(shape, "Shape cannot be null.");
int prod = 1;
for (int s : shape) {
prod = prod * s;
}
return prod;
}
/**
* For dynamic buffer, resize the memory if needed. For fixed-size buffer, check if the {@code
* shape} of src fits the buffer size.
*/
protected void resize(@NonNull int[] shape) {
if (isDynamic) {
allocateMemory(shape);
} else {
// Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
SupportPreconditions.checkArgument(Arrays.equals(shape, this.shape));
this.shape = shape.clone();
}
}
/**
* Allocates buffer with corresponding size of the {@code shape}. If shape is an empty array, this
* {@link TensorBuffer} will be created as a scalar and its flatSize will be 1.
*
* @throws NullPointerException if {@code shape} is null.
* @throws IllegalArgumentException if {@code shape} has negative elements.
*/
private void allocateMemory(@NonNull int[] shape) {
SupportPreconditions.checkNotNull(shape, "TensorBuffer shape cannot be null.");
SupportPreconditions.checkArgument(
isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
// Check if the new shape is the same as current shape.
int newFlatSize = computeFlatSize(shape);
this.shape = shape.clone();
if (flatSize == newFlatSize) {
return;
}
// Update to the new shape.
flatSize = newFlatSize;
buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize());
buffer.order(ByteOrder.nativeOrder());
}
/**
* Checks if {@code shape} meets one of following two requirements: 1. Elements in {@code shape}
* are all non-negative numbers. 2. {@code shape} is an empty array, which corresponds to scalar.
*/
private static boolean isShapeValid(@NonNull int[] shape) {
if (shape.length == 0) {
// This shape refers to a scalar.
return true;
}
// This shape refers to a multidimensional array.
for (int s : shape) {
// All elements in shape should be non-negative.
if (s < 0) {
return false;
}
}
return true;
}
}

View File

@ -1,110 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.tensorbuffer;
import java.nio.FloatBuffer;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.common.SupportPreconditions;
/** Represents data buffer with float values. */
public final class TensorBufferFloat extends TensorBuffer {
private static final DataType DATA_TYPE = DataType.FLOAT32;
/**
* Creates a {@link TensorBufferFloat} with specified {@code shape}.
*
* @throws NullPointerException if {@code shape} is null.
* @throws IllegalArgumentException if {@code shape} has non-positive elements.
*/
TensorBufferFloat(@NonNull int[] shape) {
super(shape);
}
TensorBufferFloat() {
super();
}
@Override
public DataType getDataType() {
return DATA_TYPE;
}
@Override
@NonNull
public float[] getFloatArray() {
buffer.rewind();
float[] arr = new float[flatSize];
FloatBuffer floatBuffer = buffer.asFloatBuffer();
floatBuffer.get(arr);
return arr;
}
@Override
public float getFloatValue(int absIndex) {
return buffer.getFloat(absIndex << 2);
}
@Override
@NonNull
public int[] getIntArray() {
buffer.rewind();
int[] arr = new int[flatSize];
for (int i = 0; i < flatSize; i++) {
arr[i] = (int) buffer.getFloat();
}
return arr;
}
@Override
public int getIntValue(int absIndex) {
return (int) buffer.getFloat(absIndex << 2);
}
@Override
public int getTypeSize() {
return DATA_TYPE.byteSize();
}
@Override
public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
SupportPreconditions.checkArgument(
src.length == computeFlatSize(shape),
"The size of the array to be loaded does not match the specified shape.");
resize(shape);
buffer.rewind();
FloatBuffer floatBuffer = buffer.asFloatBuffer();
floatBuffer.put(src);
}
@Override
public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
SupportPreconditions.checkArgument(
src.length == computeFlatSize(shape),
"The size of the array to be loaded does not match the specified shape.");
resize(shape);
buffer.rewind();
for (int a : src) {
buffer.putFloat((float) a);
}
}
}

View File

@ -1,111 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.tensorbuffer;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.common.SupportPreconditions;
/** Represents data buffer with 8-bit unsigned integer values. */
public final class TensorBufferUint8 extends TensorBuffer {
private static final DataType DATA_TYPE = DataType.UINT8;
/**
* Creates a {@link TensorBufferUint8} with specified {@code shape}.
*
* @throws NullPointerException if {@code shape} is null.
* @throws IllegalArgumentException if {@code shape} has non-positive elements.
*/
TensorBufferUint8(@NonNull int[] shape) {
super(shape);
}
TensorBufferUint8() {
super();
}
@Override
public DataType getDataType() {
return DATA_TYPE;
}
@Override
@NonNull
public float[] getFloatArray() {
buffer.rewind();
float[] arr = new float[flatSize];
for (int i = 0; i < flatSize; i++) {
arr[i] = (float) (buffer.get() & 0xff);
}
return arr;
}
@Override
public float getFloatValue(int index) {
return (float) (buffer.get(index) & 0xff);
}
@Override
@NonNull
public int[] getIntArray() {
buffer.rewind();
int[] arr = new int[flatSize];
for (int i = 0; i < flatSize; i++) {
arr[i] = buffer.get() & 0xff;
}
return arr;
}
@Override
public int getIntValue(int index) {
return buffer.get(index) & 0xff;
}
@Override
public int getTypeSize() {
return DATA_TYPE.byteSize();
}
@Override
public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
SupportPreconditions.checkArgument(
src.length == computeFlatSize(shape),
"The size of the array to be loaded does not match the specified shape.");
resize(shape);
buffer.rewind();
for (float a : src) {
buffer.put((byte) Math.max(Math.min(a, 255.0), 0.0));
}
}
@Override
public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
SupportPreconditions.checkArgument(
src.length == computeFlatSize(shape),
"The size of the array to be loaded does not match the specified shape.");
resize(shape);
buffer.rewind();
for (int a : src) {
buffer.put((byte) Math.max(Math.min(a, 255), 0));
}
}
}

View File

@ -1,113 +0,0 @@
load("//tensorflow:tensorflow.bzl", "py_test")
load("@flatbuffers//:build_defs.bzl", "flatbuffer_android_library", "flatbuffer_cc_library", "flatbuffer_java_library", "flatbuffer_py_library")
load("//tensorflow/lite/experimental/support/metadata:build_defs.bzl", "stamp_metadata_parser_version")
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
exports_files(["metadata_schema.fbs"])
flatbuffer_py_library(
name = "schema_py",
srcs = ["//tensorflow/lite/schema:schema.fbs"],
)
# Generic schema for inference on device.
flatbuffer_android_library(
name = "schema_fbs_android",
srcs = ["//tensorflow/lite/schema:schema.fbs"],
custom_package = "org.tensorflow.lite.schema",
)
flatbuffer_java_library(
name = "schema_fbs_java",
srcs = ["//tensorflow/lite/schema:schema.fbs"],
custom_package = "org.tensorflow.lite.schema",
)
# Generic schema for model metadata.
flatbuffer_cc_library(
name = "metadata_schema_cc",
srcs = ["metadata_schema.fbs"],
)
flatbuffer_py_library(
name = "metadata_schema_py",
srcs = ["metadata_schema.fbs"],
)
flatbuffer_java_library(
name = "metadata_schema_java",
srcs = ["metadata_schema.fbs"],
custom_package = "org.tensorflow.lite.support.metadata.schema",
)
flatbuffer_android_library(
name = "metadata_schema_fbs_android",
srcs = ["metadata_schema.fbs"],
custom_package = "org.tensorflow.lite.support.metadata.schema",
)
# TODO(b/157813075): move the metadata python library to metadata/python/ when migrating to the new repo.
stamp_metadata_parser_version(
name = "metadata_parser_py",
srcs = ["metadata_parser.py.template"],
outs = ["metadata_parser.py"],
)
py_library(
name = "metadata",
srcs = [
"metadata.py",
":metadata_parser_py",
],
data = [
"//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":metadata_schema_py",
":schema_py",
"//tensorflow/lite/experimental/support/metadata/cc/python:_pywrap_metadata_version",
"//tensorflow/lite/experimental/support/metadata/flatbuffers_lib:_pywrap_flatbuffers",
"//tensorflow/python:platform",
"@flatbuffers//:runtime_py",
],
)
py_test(
name = "metadata_test",
srcs = ["metadata_test.py"],
data = ["testdata/golden_json.json"],
python_version = "PY3",
srcs_version = "PY2AND3",
tags = [
"no_mac", # TODO(b/148247402): flatbuffers import broken on Mac OS.
],
deps = [
":metadata",
":metadata_schema_py",
":schema_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"@flatbuffers//:runtime_py",
"@six_archive//:six",
],
)
py_test(
name = "metadata_parser_test",
srcs = ["metadata_parser_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":metadata",
"//tensorflow/python:client_testlib",
],
)

View File

@ -1,15 +0,0 @@
# TensorFlow Lite Metadata and Android wrapper code generator
Note: Both TensorFlow Lite Metadata and the Android wrapper code generator are
in experimental (beta) phase.
TensorFlow Lite metadata provides a structured framework for storing metadata
to convey information for both the developer that will utilitised the model and
code generators which can create wrapper around the model. For information on
how to populate model metadata, please refer to the [TensorFlow Lite Metadata
documentation](https://www.tensorflow.org/lite/convert/metadata).
The first code generator which takes advantage of this metadata format is the
TensorFlow Lite Android Code Generator. For more information on how to use this
generator, please refer to the [TensorFlow Lite Android wrapper code generator
documentation](https://www.tensorflow.org/lite/guide/codegen).

View File

@ -1,43 +0,0 @@
"""Build rules to generate metadata schema versions."""
METADATA_SCHEMA_FILE = "//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs"
def stamp_metadata_parser_version(
name,
srcs,
outs):
"""Stamps the latest metadata parser version into the srcs files.
Replaces all the occurrences of "{LATEST_METADATA_PARSER_VERSION}" in the
srcs files with the metadata schema version extracted from
METADATA_SCHEMA_FILE and then outputs the generated file into outs,
respectively. The number of srcs files needs to match the number of outs
files.
Args:
name: Rule name. (required)
srcs: List of source files. (required)
outs: List of output files. (required)
"""
if len(srcs) != len(outs):
fail(("The number of srcs files (%d) does not match that of the outs" +
" files (%d).") %
(len(srcs), len(outs)))
for i in range(0, len(srcs)):
native.genrule(
name = "%s_file%d" % (name, i),
srcs = [srcs[i]],
outs = [outs[i]],
tools = [METADATA_SCHEMA_FILE],
# Gets the metadata schema version from the file, and stamps it
# into the srcs file.
cmd = "version=$$(sed -n -e '/Schema Semantic version/ s/.*\\: *//p' $(location %s));" %
METADATA_SCHEMA_FILE +
'sed "s/{LATEST_METADATA_PARSER_VERSION}/$$version/" $< > $@',
)
native.filegroup(
name = name,
srcs = outs,
)

View File

@ -1,29 +0,0 @@
load("//tensorflow/lite/experimental/support/metadata:build_defs.bzl", "stamp_metadata_parser_version")
package(
default_visibility = ["//tensorflow/lite/experimental/support:users"],
licenses = ["notice"], # Apache 2.0
)
stamp_metadata_parser_version(
name = "metadata_parser_h",
srcs = ["metadata_parser.h.template"],
outs = ["metadata_parser.h"],
)
cc_library(
name = "metadata_version",
srcs = ["metadata_version.cc"],
hdrs = [
"metadata_version.h",
":metadata_parser_h",
],
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/tools:logging",
"@com_google_absl//absl/strings",
"@flatbuffers",
],
)

View File

@ -1,28 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_PARSER_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_PARSER_H_
namespace tflite {
namespace metadata {
// The version of the metadata parser that this metadata versioning library is
// depending on.
inline constexpr char kMatadataParserVersion[] = "{LATEST_METADATA_PARSER_VERSION}";
} // namespace metadata
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_PARSER_H_

View File

@ -1,214 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
#include <stddef.h>
#include <stdint.h>
#include <array>
#include <ostream>
#include <string>
#include <vector>
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/tools/logging.h"
namespace tflite {
namespace metadata {
namespace {
// Members that are added to the metadata schema after the initial version
// of 1.0.0.
enum class SchemaMembers {
kAssociatedFileTypeVocabulary = 0,
};
// Helper class to compare semantic versions in terms of three integers, major,
// minor, and patch.
class Version {
public:
explicit Version(int major, int minor = 0, int patch = 0)
: version_({major, minor, patch}) {}
explicit Version(const std::string& version) {
const std::vector<std::string> vec = absl::StrSplit(version, '.');
// The version string should always be less than four numbers.
TFLITE_DCHECK(vec.size() <= kElementNumber && !vec.empty());
version_[0] = std::stoi(vec[0]);
version_[1] = vec.size() > 1 ? std::stoi(vec[1]) : 0;
version_[2] = vec.size() > 2 ? std::stoi(vec[2]) : 0;
}
// Compares two semantic version numbers.
//
// Example results when comparing two versions strings:
// "1.9" precedes "1.14";
// "1.14" precedes "1.14.1";
// "1.14" and "1.14.0" are equal.
//
// Returns the value 0 if the two versions are equal; a value less than 0 if
// *this precedes v; a value greater than 0 if v precedes *this.
int Compare(const Version& v) {
for (int i = 0; i < kElementNumber; ++i) {
if (version_[i] != v.version_[i]) {
return version_[i] < v.version_[i] ? -1 : 1;
}
}
return 0;
}
// Converts version_ into a version string.
std::string ToString() { return absl::StrJoin(version_, "."); }
private:
static constexpr int kElementNumber = 3;
std::array<int, kElementNumber> version_;
};
Version GetMemberVersion(SchemaMembers member) {
switch (member) {
case SchemaMembers::kAssociatedFileTypeVocabulary:
return Version(1, 0, 1);
default:
TFLITE_LOG(FATAL) << "Unsupported schema member: "
<< static_cast<int>(member);
}
}
// Updates min_version if it precedes the new_version.
inline void UpdateMinimumVersion(const Version& new_version,
Version* min_version) {
if (min_version->Compare(new_version) < 0) {
*min_version = new_version;
}
}
void UpdateMinimumVersionForAssociatedFile(
const tflite::AssociatedFile* associated_file, Version* min_version) {
if (associated_file == nullptr) return;
if (associated_file->type() == AssociatedFileType_VOCABULARY) {
UpdateMinimumVersion(
GetMemberVersion(SchemaMembers::kAssociatedFileTypeVocabulary),
min_version);
}
}
void UpdateMinimumVersionForAssociatedFileArray(
const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
associated_files,
Version* min_version) {
if (associated_files == nullptr) return;
for (int i = 0; i < associated_files->size(); ++i) {
UpdateMinimumVersionForAssociatedFile(associated_files->Get(i),
min_version);
}
}
void UpdateMinimumVersionForTensorMetadata(
const tflite::TensorMetadata* tensor_metadata, Version* min_version) {
if (tensor_metadata == nullptr) return;
// Checks the associated_files field.
UpdateMinimumVersionForAssociatedFileArray(
tensor_metadata->associated_files(), min_version);
}
void UpdateMinimumVersionForTensorMetadataArray(
const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
tensor_metadata_array,
Version* min_version) {
if (tensor_metadata_array == nullptr) return;
for (int i = 0; i < tensor_metadata_array->size(); ++i) {
UpdateMinimumVersionForTensorMetadata(tensor_metadata_array->Get(i),
min_version);
}
}
void UpdateMinimumVersionForSubGraphMetadata(
const tflite::SubGraphMetadata* subgraph_metadata, Version* min_version) {
if (subgraph_metadata == nullptr) return;
// Checks in the input/output metadata arrays.
UpdateMinimumVersionForTensorMetadataArray(
subgraph_metadata->input_tensor_metadata(), min_version);
UpdateMinimumVersionForTensorMetadataArray(
subgraph_metadata->output_tensor_metadata(), min_version);
// Checks the associated_files field.
UpdateMinimumVersionForAssociatedFileArray(
subgraph_metadata->associated_files(), min_version);
}
void UpdateMinimumVersionForModelMetadata(
const tflite::ModelMetadata& model_metadata, Version* min_version) {
// Checks the subgraph_metadata field.
if (model_metadata.subgraph_metadata() != nullptr) {
for (int i = 0; i < model_metadata.subgraph_metadata()->size(); ++i) {
UpdateMinimumVersionForSubGraphMetadata(
model_metadata.subgraph_metadata()->Get(i), min_version);
}
}
// Checks the associated_files field.
UpdateMinimumVersionForAssociatedFileArray(model_metadata.associated_files(),
min_version);
}
} // namespace
TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
size_t buffer_size,
std::string* min_version_str) {
flatbuffers::Verifier verifier =
flatbuffers::Verifier(buffer_data, buffer_size);
if (!tflite::VerifyModelMetadataBuffer(verifier)) {
TFLITE_LOG(ERROR) << "The model metadata is not a valid FlatBuffer buffer.";
return kTfLiteError;
}
static constexpr char kDefaultVersion[] = "1.0.0";
Version min_version = Version(kDefaultVersion);
// Checks if any member declared after 1.0.0 (such as those in
// SchemaMembers) exists, and updates min_version accordingly. The minimum
// metadata parser version will be the largest version number of all fields
// that has been added to a metadata flatbuffer
const tflite::ModelMetadata* model_metadata = GetModelMetadata(buffer_data);
// All tables in the metadata schema should have their dedicated
// UpdateMinimumVersionFor**() methods, respectively. We'll gradually add
// these methods when new fields show up in later schema versions.
//
// UpdateMinimumVersionFor<Foo>() takes a const pointer of Foo. The pointer
// can be a nullptr if Foo is not populated into the corresponding table of
// the Flatbuffer object. In this case, UpdateMinimumVersionFor<Foo>() will be
// skipped. An exception is UpdateMinimumVersionForModelMetadata(), where
// ModelMetadata is the root table, and it won't be null.
UpdateMinimumVersionForModelMetadata(*model_metadata, &min_version);
*min_version_str = min_version.ToString();
return kTfLiteOk;
}
} // namespace metadata
} // namespace tflite

View File

@ -1,38 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_
#include <stddef.h>
#include <stdint.h>
#include <string>
#include "tensorflow/lite/c/common.h"
namespace tflite {
namespace metadata {
// Gets the minimum metadata parser version that can fully understand all fields
// in a given metadata flatbuffer. TFLite Metadata follows Semantic Versioning
// 2.0. Each release version has the form MAJOR.MINOR.PATCH.
TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
size_t buffer_size,
std::string* min_version);
} // namespace metadata
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_

View File

@ -1,22 +0,0 @@
load("//tensorflow:tensorflow.bzl", "pybind_extension")
package(
default_visibility = [
"//tensorflow/lite/experimental/support/metadata:__pkg__",
],
licenses = ["notice"], # Apache 2.0
)
pybind_extension(
name = "_pywrap_metadata_version",
srcs = [
"metadata_version.cc",
],
features = ["-use_header_modules"],
module_name = "_pywrap_metadata_version",
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/experimental/support/metadata/cc:metadata_version",
"@pybind11",
],
)

View File

@ -1,55 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
#include "pybind11/pybind11.h"
#include "tensorflow/lite/c/common.h"
namespace tflite {
namespace metadata {
PYBIND11_MODULE(_pywrap_metadata_version, m) {
m.doc() = R"pbdoc(
_pywrap_metadata_version
A module that returns the minimum metadata parser version of a given
metadata flatbuffer.
)pbdoc";
// Using pybind11 type conversions to convert between Python and native
// C++ types. There are other options to provide access to native Python types
// in C++ and vice versa. See the pybind 11 instrcution [1] for more details.
// Type converstions is recommended by pybind11, though the main downside
// is that a copy of the data must be made on every Python to C++ transition:
// this is needed since the C++ and Python versions of the same type generally
// wont have the same memory layout.
//
// [1]: https://pybind11.readthedocs.io/en/stable/advanced/cast/index.html
m.def("GetMinimumMetadataParserVersion",
[](const std::string& buffer_data) -> std::string {
std::string min_version;
if (GetMinimumMetadataParserVersion(
reinterpret_cast<const uint8_t*>(buffer_data.c_str()),
buffer_data.length(), &min_version) != kTfLiteOk) {
pybind11::value_error(
"Error occurred when getting the minimum metadata parser "
"version of the metadata flatbuffer.");
}
return min_version;
});
}
} // namespace metadata
} // namespace tflite

View File

@ -1,24 +0,0 @@
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
cc_test(
name = "metadata_version_test",
srcs = ["metadata_version_test.cc"],
deps = [
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
"//tensorflow/lite/experimental/support/metadata/cc:metadata_version",
"@com_google_googletest//:gtest_main",
"@flatbuffers",
],
)
cc_test(
name = "metadata_parser_test",
srcs = ["metadata_parser_test.cc"],
deps = [
"//tensorflow/lite/experimental/support/metadata/cc:metadata_version",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -1,33 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_parser.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace tflite {
namespace metadata {
namespace {
using ::testing::MatchesRegex;
TEST(MetadataParserTest, MatadataParserVersionIsWellFormed) {
// Validates that the version is well-formed (x.y.z).
EXPECT_THAT(kMatadataParserVersion, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+"));
}
} // namespace
} // namespace metadata
} // namespace tflite

View File

@ -1,187 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h"
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace metadata {
namespace {
using ::testing::MatchesRegex;
using ::testing::StrEq;
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionSucceedsWithValidMetadata) {
// Creates a dummy metadata flatbuffer for test.
flatbuffers::FlatBufferBuilder builder(1024);
auto name = builder.CreateString("Foo");
ModelMetadataBuilder metadata_builder(builder);
metadata_builder.add_name(name);
auto metadata = metadata_builder.Finish();
FinishModelMetadataBuffer(builder, metadata);
// Gets the mimimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
kTfLiteOk);
// Validates that the version is well-formed (x.y.z).
EXPECT_THAT(min_version, MatchesRegex("[0-9]+\\.[0-9]+\\.[0-9]+"));
}
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionFailsWithInvalidIdentifier) {
// Creates a dummy metadata flatbuffer without identifier.
flatbuffers::FlatBufferBuilder builder(1024);
ModelMetadataBuilder metadata_builder(builder);
auto metadata = metadata_builder.Finish();
builder.Finish(metadata);
// Gets the mimimum metadata parser version and triggers error.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
kTfLiteError);
EXPECT_TRUE(min_version.empty());
}
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionForModelMetadataVocabAssociatedFiles) {
// Creates a metadata flatbuffer with the field,
// ModelMetadata.associated_fiels, populated with the vocabulary file type.
flatbuffers::FlatBufferBuilder builder(1024);
AssociatedFileBuilder associated_file_builder(builder);
associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
auto associated_files =
builder.CreateVector(std::vector<flatbuffers::Offset<AssociatedFile>>{
associated_file_builder.Finish()});
ModelMetadataBuilder metadata_builder(builder);
metadata_builder.add_associated_files(associated_files);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
kTfLiteOk);
// Validates that the version is exactly 1.0.1.
EXPECT_THAT(min_version, StrEq("1.0.1"));
}
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionForSubGraphMetadataVocabAssociatedFiles) {
// Creates a metadata flatbuffer with the field,
// SubGraphMetadata.associated_fiels, populated with the vocabulary file type.
flatbuffers::FlatBufferBuilder builder(1024);
AssociatedFileBuilder associated_file_builder(builder);
associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
auto associated_files =
builder.CreateVector(std::vector<flatbuffers::Offset<AssociatedFile>>{
associated_file_builder.Finish()});
SubGraphMetadataBuilder subgraph_builder(builder);
subgraph_builder.add_associated_files(associated_files);
auto subgraphs =
builder.CreateVector(std::vector<flatbuffers::Offset<SubGraphMetadata>>{
subgraph_builder.Finish()});
ModelMetadataBuilder metadata_builder(builder);
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
kTfLiteOk);
// Validates that the version is exactly 1.0.1.
EXPECT_THAT(min_version, StrEq("1.0.1"));
}
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionForInputMetadataVocabAssociatedFiles) {
// Creates a metadata flatbuffer with the field,
// SubGraphMetadata.input_tensor_metadata.associated_fiels, populated with the
// vocabulary file type.
flatbuffers::FlatBufferBuilder builder(1024);
AssociatedFileBuilder associated_file_builder(builder);
associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
auto associated_files =
builder.CreateVector(std::vector<flatbuffers::Offset<AssociatedFile>>{
associated_file_builder.Finish()});
TensorMetadataBuilder tensor_builder(builder);
tensor_builder.add_associated_files(associated_files);
auto tensors =
builder.CreateVector(std::vector<flatbuffers::Offset<TensorMetadata>>{
tensor_builder.Finish()});
SubGraphMetadataBuilder subgraph_builder(builder);
subgraph_builder.add_input_tensor_metadata(tensors);
auto subgraphs =
builder.CreateVector(std::vector<flatbuffers::Offset<SubGraphMetadata>>{
subgraph_builder.Finish()});
ModelMetadataBuilder metadata_builder(builder);
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
kTfLiteOk);
// Validates that the version is exactly 1.0.1.
EXPECT_THAT(min_version, StrEq("1.0.1"));
}
TEST(MetadataVersionTest,
GetMinimumMetadataParserVersionForOutputMetadataVocabAssociatedFiles) {
// Creates a metadata flatbuffer with the field,
// SubGraphMetadata.output_tensor_metadata.associated_fiels, populated with
// the vocabulary file type.
flatbuffers::FlatBufferBuilder builder(1024);
AssociatedFileBuilder associated_file_builder(builder);
associated_file_builder.add_type(tflite::AssociatedFileType_VOCABULARY);
auto associated_files =
builder.CreateVector(std::vector<flatbuffers::Offset<AssociatedFile>>{
associated_file_builder.Finish()});
TensorMetadataBuilder tensor_builder(builder);
tensor_builder.add_associated_files(associated_files);
auto tensors =
builder.CreateVector(std::vector<flatbuffers::Offset<TensorMetadata>>{
tensor_builder.Finish()});
SubGraphMetadataBuilder subgraph_builder(builder);
subgraph_builder.add_output_tensor_metadata(tensors);
auto subgraphs =
builder.CreateVector(std::vector<flatbuffers::Offset<SubGraphMetadata>>{
subgraph_builder.Finish()});
ModelMetadataBuilder metadata_builder(builder);
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
kTfLiteOk);
// Validates that the version is exactly 1.0.1.
EXPECT_EQ(min_version, "1.0.1");
}
} // namespace
} // namespace metadata
} // namespace tflite

View File

@ -1,23 +0,0 @@
load("//tensorflow:tensorflow.bzl", "pybind_extension")
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
pybind_extension(
name = "_pywrap_flatbuffers",
srcs = [
"flatbuffers_lib.cc",
],
features = ["-use_header_modules"],
module_name = "_pywrap_flatbuffers",
deps = [
"//tensorflow/python:pybind11_lib",
"//third_party/python_runtime:headers",
"@flatbuffers",
"@pybind11",
],
)

View File

@ -1,59 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "flatbuffers/idl.h" // from @flatbuffers
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
namespace tflite {
namespace support {
PYBIND11_MODULE(_pywrap_flatbuffers, m) {
pybind11::class_<flatbuffers::IDLOptions>(m, "IDLOptions")
.def(pybind11::init<>())
.def_readwrite("strict_json", &flatbuffers::IDLOptions::strict_json);
pybind11::class_<flatbuffers::Parser>(m, "Parser")
.def(pybind11::init<const flatbuffers::IDLOptions&>())
.def("parse",
[](flatbuffers::Parser* self, const std::string& source) {
return self->Parse(source.c_str());
})
.def_readonly("builder", &flatbuffers::Parser::builder_)
.def_readonly("error", &flatbuffers::Parser::error_);
pybind11::class_<flatbuffers::FlatBufferBuilder>(m, "FlatBufferBuilder")
.def("clear", &flatbuffers::FlatBufferBuilder::Clear)
.def("push_flat_buffer", [](flatbuffers::FlatBufferBuilder* self,
const std::string& contents) {
self->PushFlatBuffer(reinterpret_cast<const uint8_t*>(contents.c_str()),
contents.length());
});
m.def("generate_text_file", &flatbuffers::GenerateTextFile);
m.def(
"generate_text",
[](const flatbuffers::Parser& parser,
const std::string& buffer) -> std::string {
std::string text;
if (!flatbuffers::GenerateText(
parser, reinterpret_cast<const void*>(buffer.c_str()), &text)) {
return "";
}
return text;
});
}
} // namespace support
} // namespace tflite

View File

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.tensorflow.lite.support">
<uses-sdk android:minSdkVersion="19" />
</manifest>

View File

@ -1,40 +0,0 @@
# Description:
# TensorFlow Lite Support API in Java for metadata.
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
METADATA_SRCS = glob(
["src/java/org/tensorflow/lite/support/metadata/**/*.java"],
)
android_library(
name = "tensorflow-lite-support-metadata",
srcs = METADATA_SRCS,
manifest = "AndroidManifest.xml",
deps = [
"//tensorflow/lite/experimental/support/metadata:metadata_schema_fbs_android",
"//tensorflow/lite/experimental/support/metadata:schema_fbs_android",
"@org_checkerframework_qual",
],
)
java_library(
name = "tensorflow-lite-support-metadata-lib",
srcs = METADATA_SRCS,
javacopts = JAVACOPTS,
resource_jars = [
"//tensorflow/lite/experimental/support/metadata:libmetadata_schema_java.jar",
"//tensorflow/lite/experimental/support/metadata:libschema_fbs_java.jar",
],
deps = [
"//tensorflow/lite/experimental/support/metadata:metadata_schema_java",
"//tensorflow/lite/experimental/support/metadata:schema_fbs_java",
"@org_checkerframework_qual",
],
)

View File

@ -1,116 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
import static org.tensorflow.lite.support.metadata.Preconditions.checkElementIndex;
import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
/**
* An {@link InputStream} that wraps a section of a {@link SeekableByteChannelCompat}.
*
* <p><b>WARNING:</b> Similar as {@link InputStream}, instances of an {@link BoundedInputStream} are
* <b>not</b> thread-safe. If multiple threads concurrently reading from the same {@link
* BoundedInputStream}, it must be synchronized externally. Also, if multiple instances of {@link
* BoundedInputStream} are created on the same {@link SeekableByteChannelCompat}, it must be
* synchronized as well.
*/
final class BoundedInputStream extends InputStream {
private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
private final long end; // The valid data for the stream is between [start, end).
private long position;
private final SeekableByteChannelCompat channel;
/**
* Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}.
*
* @param channel the {@link SeekableByteChannelCompat} that backs up this {@link
* BoundedInputStream}
* @param start the starting position of this {@link BoundedInputStream} in the given {@link
* SeekableByteChannelCompat}
* @param remaining the length of this {@link BoundedInputStream}
* @throws IllegalArgumentException if {@code start} or {@code remaining} is negative
*/
BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) {
checkArgument(
remaining >= 0 && start >= 0,
String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
end = start + remaining;
this.channel = channel;
position = start;
}
@Override
public int available() throws IOException {
return (int) (Math.min(end, channel.size()) - position);
}
@Override
public int read() throws IOException {
if (position >= end) {
return -1;
}
singleByteBuffer.rewind();
int count = read(position, singleByteBuffer);
if (count < 0) {
return count;
}
position++;
return singleByteBuffer.get() & 0xff;
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
checkNotNull(b);
checkElementIndex(off, b.length, "The start offset");
checkElementIndex(len, b.length - off + 1, "The maximumn number of bytes to read");
if (len == 0) {
return 0;
}
if (len > end - position) {
if (position >= end) {
return -1;
}
len = (int) (end - position);
}
ByteBuffer buf = ByteBuffer.wrap(b, off, len);
int count = read(position, buf);
if (count > 0) {
position += count;
}
return count;
}
private int read(long position, ByteBuffer buf) throws IOException {
int count;
synchronized (channel) {
channel.position(position);
count = channel.read(buf);
}
buf.flip();
return count;
}
}

View File

@ -1,130 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import static java.lang.Math.min;
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
import java.nio.ByteBuffer;
import java.nio.channels.NonWritableChannelException;
/** Implements the {@link SeekableByteChannelCompat} on top of {@link ByteBuffer}. */
final class ByteBufferChannel implements SeekableByteChannelCompat {
/** The ByteBuffer that holds the data. */
private final ByteBuffer buffer;
/**
* Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}.
*
* @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel}
* @throws NullPointerException if {@code buffer} is null
*/
public ByteBufferChannel(ByteBuffer buffer) {
checkNotNull(buffer, "The ByteBuffer cannot be null.");
this.buffer = buffer;
}
@Override
public void close() {}
@Override
public boolean isOpen() {
return true;
}
@Override
public long position() {
return buffer.position();
}
/**
* Sets this channel's position.
*
* @param newPosition the new position, a non-negative integer counting the number of bytes from
* the beginning of the entity
* @return this channel
* @throws IllegalArgumentException if the new position is negative, or greater than the size of
* the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE
*/
@Override
public synchronized ByteBufferChannel position(long newPosition) {
checkArgument(
(newPosition >= 0 && newPosition <= Integer.MAX_VALUE),
"The new position should be non-negative and be less than Integer.MAX_VALUE.");
buffer.position((int) newPosition);
return this;
}
/**
* {@inheritDoc}
*
* <p>Bytes are read starting at this channel's current position, and then the position is updated
* with the number of bytes actually read. Otherwise this method behaves exactly as specified in
* the {@link ReadableByteChannel} interface.
*/
@Override
public synchronized int read(ByteBuffer dst) {
if (buffer.remaining() == 0) {
return -1;
}
int count = min(dst.remaining(), buffer.remaining());
if (count > 0) {
ByteBuffer tempBuffer = buffer.slice();
tempBuffer.order(buffer.order()).limit(count);
dst.put(tempBuffer);
buffer.position(buffer.position() + count);
}
return count;
}
@Override
public long size() {
return buffer.limit();
}
@Override
public synchronized ByteBufferChannel truncate(long size) {
checkArgument(
(size >= 0 && size <= Integer.MAX_VALUE),
"The new size should be non-negative and be less than Integer.MAX_VALUE.");
if (size < buffer.limit()) {
buffer.limit((int) size);
if (buffer.position() > size) {
buffer.position((int) size);
}
}
return this;
}
@Override
public synchronized int write(ByteBuffer src) {
if (buffer.isReadOnly()) {
throw new NonWritableChannelException();
}
int count = min(src.remaining(), buffer.remaining());
if (count > 0) {
ByteBuffer tempBuffer = src.slice();
tempBuffer.order(buffer.order()).limit(count);
buffer.put(tempBuffer);
}
return count;
}
}

View File

@ -1,368 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.zip.ZipException;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.schema.Tensor;
import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
/**
* Loads metadata from TFLite Model FlatBuffer.
*
* <p>TFLite Model FlatBuffer can be generated using the <a
* href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
* Model schema file.</a>
*
* <p>Some models contain a TFLite Metadata Flatbuffer, which records more information about what
* the model does and how to interprete the model. TFLite Metadata Flatbuffer can be generated using
* the <a
* href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/metadata_schema.fbs">TFLite
* Metadata schema file.</a>
*
* <p>It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking methods
* that read from TFLite metadata will cause runtime errors.
*
* <p>Similarly, it is allowed to pass in a model FlatBuffer without associated files. However,
* invoking methods that read the associated files will cause runtime errors.
*
* <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports a
* single subgraph so far. See the <a
* href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
* of how to specify subgraph during convertion for more information.</a> Therefore, {@link
* MetadataExtractor} omits subgraph index as an input in its methods.
*/
public class MetadataExtractor {
/** The helper class to load metadata from TFLite model FlatBuffer. */
private final ModelInfo modelInfo;
/** The helper class to load metadata from TFLite metadata FlatBuffer. */
@Nullable private final ModelMetadataInfo metadataInfo;
/** The handler to load associated files through zip. */
@Nullable private final ZipFile zipFile;
/**
* Creates a {@link MetadataExtractor} with TFLite model FlatBuffer.
*
* @param buffer the TFLite model FlatBuffer
* @throws IllegalArgumentException if the number of input or output tensors in the model does not
* match that in the metadata
* @throws IOException if an error occurs while reading the model as a Zip file
*/
public MetadataExtractor(ByteBuffer buffer) throws IOException {
modelInfo = new ModelInfo(buffer);
ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer();
if (metadataBuffer != null) {
metadataInfo = new ModelMetadataInfo(metadataBuffer);
// Prints warning message if the minimum parser version is not satisfied.
if (!isMinimumParserVersionSatisfied()) {
System.err.printf(
"<Warning> Some fields in the metadata belong to a future schema. The minimum parser"
+ " version required is %s, but the version of the current metadata parser is %s",
metadataInfo.getMininumParserVersion(), MetadataParser.VERSION);
}
checkArgument(
modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(),
String.format(
"The number of input tensors in the model is %d. The number of input tensors that"
+ " recorded in the metadata is %d. These two values does not match.",
modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount()));
checkArgument(
modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(),
String.format(
"The number of output tensors in the model is %d. The number of output tensors that"
+ " recorded in the metadata is %d. These two values does not match.",
modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount()));
} else {
// It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking
// methods that read from TFLite metadata will cause runtime errors.
metadataInfo = null;
}
zipFile = createZipFile(buffer);
}
/**
* Quantization parameters that corresponds to the table, {@code QuantizationParameters}, in the
* <a
* href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
* Model schema file.</a>
*
* <p>Since per-channel quantization does not apply to input and output tensors, {@code scale} and
* {@code zero_point} are both single values instead of arrays.
*
* <p>For tensor that are not quantized, the values of scale and zero_point are both 0.
*
* <p>Given a quantized value q, the corresponding float value f should be: <br>
* f = scale * (q - zero_point) <br>
*/
public static class QuantizationParams {
/** The scale value used in quantization. */
private final float scale;
/** The zero point value used in quantization. */
private final int zeroPoint;
/**
* Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}.
*
* @param scale The scale value used in quantization.
* @param zeroPoint The zero point value used in quantization.
*/
public QuantizationParams(final float scale, final int zeroPoint) {
this.scale = scale;
this.zeroPoint = zeroPoint;
}
/** Returns the scale value. */
public float getScale() {
return scale;
}
/** Returns the zero point value. */
public int getZeroPoint() {
return zeroPoint;
}
}
/** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */
public boolean hasMetadata() {
return metadataInfo != null;
}
/**
* Gets the packed associated file with the specified {@code fileName}.
*
* @param fileName the name of the associated file
* @return the raw input stream containing specified file
* @throws IllegalStateException if the model is not a zip file
* @throws IllegalArgumentException if the specified file does not exist in the model
*/
public InputStream getAssociatedFile(String fileName) {
assertZipFile();
return zipFile.getRawInputStream(fileName);
}
/** Gets the count of input tensors in the model. */
public int getInputTensorCount() {
return modelInfo.getInputTensorCount();
}
/**
* Gets the metadata for the input tensor specified by {@code inputIndex}.
*
* @param inputIndex the index of the desired input tensor
* @throws IllegalStateException if this model does not contain model metadata
*/
@Nullable
public TensorMetadata getInputTensorMetadata(int inputIndex) {
assertMetadataInfo();
return metadataInfo.getInputTensorMetadata(inputIndex);
}
/**
* Gets the quantization parameters for the input tensor specified by {@code inputIndex}.
*
* @param inputIndex the index of the desired input tensor
*/
public QuantizationParams getInputTensorQuantizationParams(int inputIndex) {
Tensor tensor = modelInfo.getInputTensor(inputIndex);
return modelInfo.getQuantizationParams(tensor);
}
/**
* Gets the shape of the input tensor with {@code inputIndex}.
*
* @param inputIndex the index of the desired input tensor
*/
public int[] getInputTensorShape(int inputIndex) {
return modelInfo.getInputTensorShape(inputIndex);
}
/**
* Gets the {@link TensorType} of the input tensor with {@code inputIndex}.
*
* @param inputIndex the index of the desired input tensor
*/
public byte getInputTensorType(int inputIndex) {
return modelInfo.getInputTensorType(inputIndex);
}
/**
* Gets the root handler for the model metadata.
*
* @throws IllegalStateException if this model does not contain model metadata
*/
public ModelMetadata getModelMetadata() {
assertMetadataInfo();
return metadataInfo.getModelMetadata();
}
/** Gets the count of output tensors in the model. */
public int getOutputTensorCount() {
return modelInfo.getOutputTensorCount();
}
/**
* Gets the metadata for the output tensor specified by {@code outputIndex}.
*
* @param outputIndex the index of the desired output tensor
* @throws IllegalStateException if this model does not contain model metadata
*/
@Nullable
public TensorMetadata getOutputTensorMetadata(int outputIndex) {
assertMetadataInfo();
return metadataInfo.getOutputTensorMetadata(outputIndex);
}
/**
* Gets the quantization parameters for the output tensor specified by {@code outputIndex}.
*
* @param outputIndex the index of the desired output tensor
*/
public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) {
Tensor tensor = modelInfo.getOutputTensor(outputIndex);
return modelInfo.getQuantizationParams(tensor);
}
/**
* Gets the shape of the output tensor with {@code outputIndex}.
*
* @param outputIndex the index of the desired output tensor
*/
public int[] getOutputTensorShape(int outputIndex) {
return modelInfo.getOutputTensorShape(outputIndex);
}
/**
* Gets the {@link TensorType} of the output tensor with {@code outputIndex}.
*
* @param outputIndex the index of the desired output tensor
*/
public byte getOutputTensorType(int outputIndex) {
return modelInfo.getOutputTensorType(outputIndex);
}
/**
* Returns {@code true} if the minimum parser version required by the given metadata flatbuffer
* precedes or equals to the version of the metadata parser that this MetadataExtractor library is
* relying on. All fields in the metadata can be parsed correctly with this metadata extractor
* library in this case. Otherwise, it returns {@code false}.
*
* <p>For example, assume the underlying metadata parser version is {@code 1.14.1},
*
* <ul>
* <li>it returns {@code true}, if the required minimum parser version is the same or older,
* such as {@code 1.14.1} or {@code 1.14.0}. Null version precedes all numeric versions,
* because some metadata flatbuffers are generated before the first versioned release; <br>
* <li>it returns {@code false}, if the required minimum parser version is newer, such as {@code
* 1.14.2}.
* </ul>
*/
public final boolean isMinimumParserVersionSatisfied() {
String minVersion = metadataInfo.getMininumParserVersion();
if (minVersion == null) {
return true;
}
return compareVersions(minVersion, MetadataParser.VERSION) <= 0;
}
/**
* Asserts if {@link #metadataInfo} is not initialized. Some models may not have metadata and this
* is allowed. However, invoking methods that reads the metadata is not allowed.
*
* @throws IllegalStateException if this model does not contain model metadata
*/
private void assertMetadataInfo() {
if (metadataInfo == null) {
throw new IllegalStateException("This model does not contain model metadata.");
}
}
/**
* Asserts if {@link #zipFile} is not initialized. Some models may not have associated files, thus
* are not Zip files. This is allowed. However, invoking methods that reads those associated files
* is not allowed.
*
* @throws IllegalStateException if this model is not a Zip file
*/
private void assertZipFile() {
if (zipFile == null) {
throw new IllegalStateException(
"This model does not contain associated files, and is not a Zip file.");
}
}
/**
* Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e.
* it does not have associated files, return a null handler.
*
* @param buffer the TFLite model FlatBuffer
* @throws IOException if an error occurs while reading the model as a Zip file
*/
@Nullable
private static ZipFile createZipFile(ByteBuffer buffer) throws IOException {
try {
// Creates the handler to hold the associated files through the Zip.
ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer);
return ZipFile.createFrom(byteBufferChannel);
} catch (ZipException e) {
// Some models may not have associate files. Therefore, Those models are not zip files.
// However, invoking methods that read associated files later will lead into errors.
return null;
}
}
/**
* Compares two semantic version numbers.
*
* <p>Examples of comparing two versions: <br>
* {@code 1.9} precedes {@code 1.14}; <br>
* {@code 1.14} precedes {@code 1.14.1}; <br>
* {@code 1.14} and {@code 1.14.0} are euqal;
*
* @return the value {@code 0} if the two versions are equal; a value less than {@code 0} if
* {@code version1} precedes {@code version2}; a value greater than {@code 0} if {@code
* version2} precedes {@code version1}.
*/
private static int compareVersions(String version1, String version2) {
// Using String.split instead of the recommanded Guava Splitter because we've been avoiding
// depending on other third party libraries in this project.
String[] levels1 = version1.split("\\.", 0);
String[] levels2 = version2.split("\\.", 0);
int length = Math.max(levels1.length, levels2.length);
for (int i = 0; i < length; i++) {
Integer v1 = i < levels1.length ? Integer.parseInt(levels1[i]) : 0;
Integer v2 = i < levels2.length ? Integer.parseInt(levels2[i]) : 0;
int compare = v1.compareTo(v2);
if (compare != 0) {
return compare;
}
}
return 0;
}
}

View File

@ -1,27 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
/** Information about the metadata parser that this metadata extractor library is depending on. */
public final class MetadataParser {
/**
* The version of the metadata parser that this metadata extractor library is depending on. The
* value should match the value of "Schema Semantic version" in metadata_schema.fbs.
*/
public static final String VERSION = "1.0.1";
private MetadataParser() {}
}

View File

@ -1,266 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.schema.Buffer;
import org.tensorflow.lite.schema.Metadata;
import org.tensorflow.lite.schema.Model;
import org.tensorflow.lite.schema.QuantizationParameters;
import org.tensorflow.lite.schema.SubGraph;
import org.tensorflow.lite.schema.Tensor;
import org.tensorflow.lite.schema.TensorType;
import org.tensorflow.lite.support.metadata.MetadataExtractor.QuantizationParams;
/** Extracts model information out of TFLite model FLatBuffer. */
final class ModelInfo {
/** The model that is loaded from TFLite model FlatBuffer. */
private final Model model;
/** A list of input tensors. */
private final List</* @Nullable */ Tensor> inputTensors;
/** A list of output tensors. */
private final List</* @Nullable */ Tensor> outputTensors;
/** Identifier of the TFLite model metadata in the Metadata array. */
static final String METADATA_FIELD_NAME = "TFLITE_METADATA";
/**
* Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}.
*
* <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports
* single subgraph so far. See the <a
* href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
* of how to specify subgraph during convertion for more information.</a> Therefore, all methods
* in {@link ModelInfo} retrieves metadata of the first subgrpah as default.
*
* @param buffer the TFLite model FlatBuffer
* @throws NullPointerException if {@code buffer} is null
* @throws IllegalArgumentException if the model does not contain any subgraph, or the model does
* not contain the expected identifier
*/
ModelInfo(ByteBuffer buffer) {
assertTFLiteModel(buffer);
model = Model.getRootAsModel(buffer);
checkArgument(model.subgraphsLength() > 0, "The model does not contain any subgraph.");
inputTensors = getInputTensors(model);
outputTensors = getOutputTensors(model);
}
/**
* Gets the input tensor with {@code inputIndex}.
*
* @param inputIndex The index of the desired input tensor.
* @throws IllegalArgumentException if the inputIndex specified is invalid.
*/
@Nullable
Tensor getInputTensor(int inputIndex) {
checkArgument(
inputIndex >= 0 && inputIndex < inputTensors.size(),
"The inputIndex specified is invalid.");
return inputTensors.get(inputIndex);
}
int getInputTensorCount() {
return inputTensors.size();
}
/**
* Gets shape of the input tensor with {@code inputIndex}.
*
* @param inputIndex The index of the desired intput tensor.
*/
int[] getInputTensorShape(int inputIndex) {
Tensor tensor = getInputTensor(inputIndex);
return getShape(tensor);
}
/**
* Gets the {@link TensorType} in byte of the input tensor with {@code inputIndex}.
*
* @param inputIndex The index of the desired intput tensor.
*/
byte getInputTensorType(int inputIndex) {
return getInputTensor(inputIndex).type();
}
/** Gets the metadata FlatBuffer from the model FlatBuffer. */
@Nullable
ByteBuffer getMetadataBuffer() {
// Some models may not have metadata, and this is allowed.
if (model.metadataLength() == 0) {
return null;
}
for (int i = 0; i < model.metadataLength(); i++) {
Metadata meta = model.metadata(i);
if (METADATA_FIELD_NAME.equals(meta.name())) {
long bufferIndex = meta.buffer();
Buffer metadataBuf = model.buffers((int) bufferIndex);
return metadataBuf.dataAsByteBuffer();
}
}
return null;
}
/**
* Gets the output tensor with {@code outputIndex}.
*
* @param outputIndex The index of the desired outtput tensor.
* @throws IllegalArgumentException if the outputIndex specified is invalid.
*/
@Nullable
Tensor getOutputTensor(int outputIndex) {
checkArgument(
outputIndex >= 0 && outputIndex < outputTensors.size(),
"The outputIndex specified is invalid.");
return outputTensors.get(outputIndex);
}
int getOutputTensorCount() {
return outputTensors.size();
}
/**
* Gets shape of the output tensor with {@code outputIndex}.
*
* @param outputIndex The index of the desired outtput tensor.
*/
int[] getOutputTensorShape(int outputIndex) {
Tensor tensor = getOutputTensor(outputIndex);
return getShape(tensor);
}
/**
* Gets the {@link TensorType} in byte of the output tensor {@code outputIndex}.
*
* @param outputIndex The index of the desired outtput tensor.
*/
byte getOutputTensorType(int outputIndex) {
return getOutputTensor(outputIndex).type();
}
/**
* Gets the quantization parameters of a tensor.
*
* <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensor that are not
* quantized, the values of scale and zero_point are both 0.
*
* @param tensor The tensor whoes quantization parameters is desired.
* @throws NullPointerException if the tensor is null.
* @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's {@link
* QuantizationParameters} are not single values.
*/
QuantizationParams getQuantizationParams(Tensor tensor) {
checkNotNull(tensor, "Tensor cannot be null.");
float scale;
int zeroPoint;
QuantizationParameters quantization = tensor.quantization();
// Tensors that are not quantized do not have quantization parameters, which can be null when
// being extracted from the flatbuffer.
if (quantization == null) {
scale = 0.0f;
zeroPoint = 0;
return new QuantizationParams(scale, zeroPoint);
}
// Tensors that are not quantized do not have quantization parameters.
// quantization.scaleLength() and quantization.zeroPointLength() may both return 0.
checkArgument(
quantization.scaleLength() <= 1,
"Input and output tensors do not support per-channel quantization.");
checkArgument(
quantization.zeroPointLength() <= 1,
"Input and output tensors do not support per-channel quantization.");
// For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0) will
// both be the default value in flatbuffer, 0. This behavior is consistent with the TFlite C++
// runtime.
scale = quantization.scale(0);
// zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep it
// consistent with the C++ runtime.
zeroPoint = (int) quantization.zeroPoint(0);
return new QuantizationParams(scale, zeroPoint);
}
/**
* Verifies if the buffer is a valid TFLite model.
*
* @param buffer the TFLite model flatbuffer
* @throws NullPointerException if {@code buffer} is null.
* @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
*/
private static void assertTFLiteModel(ByteBuffer buffer) {
checkNotNull(buffer, "Model flatbuffer cannot be null.");
checkArgument(
Model.ModelBufferHasIdentifier(buffer),
"The identifier of the model is invalid. The buffer may not be a valid TFLite model"
+ " flatbuffer.");
}
/**
* Gets the shape of a tensor.
*
* @param tensor The tensor whoes shape is desired.
* @throws NullPointerException if the tensor is null.
*/
private static int[] getShape(Tensor tensor) {
checkNotNull(tensor, "Tensor cannot be null.");
int shapeDim = tensor.shapeLength();
int[] tensorShape = new int[shapeDim];
for (int i = 0; i < shapeDim; i++) {
tensorShape[i] = tensor.shape(i);
}
return tensorShape;
}
/** Gets input tensors from a model. */
private static List<Tensor> getInputTensors(Model model) {
// TFLite only support one subgraph currently.
SubGraph subgraph = model.subgraphs(0);
int tensorNum = subgraph.inputsLength();
ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum);
for (int i = 0; i < tensorNum; i++) {
inputTensors.add(subgraph.tensors(subgraph.inputs(i)));
}
return Collections.unmodifiableList(inputTensors);
}
/** Gets output tensors from a model. */
private static List<Tensor> getOutputTensors(Model model) {
// TFLite only support one subgraph currently.
SubGraph subgraph = model.subgraphs(0);
int tensorNum = subgraph.outputsLength();
ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum);
for (int i = 0; i < tensorNum; i++) {
outputTensors.add(subgraph.tensors(subgraph.outputs(i)));
}
return Collections.unmodifiableList(outputTensors);
}
}

View File

@ -1,153 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata;
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
/** Extracts model metadata information out of TFLite metadata FlatBuffer. */
final class ModelMetadataInfo {
/** The root handler for the model metadata. */
private final ModelMetadata modelMetadata;
/** Metadata array of input tensors. */
private final List</* @Nullable */ TensorMetadata> inputsMetadata;
/** Metadata array of output tensors. */
private final List</* @Nullable */ TensorMetadata> outputsMetadata;
/** The minimum parser version required to fully understand the metadata flatbuffer. */
private final String /* @Nullable */ minVersion;
/**
* Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}.
*
* @param buffer the TFLite metadata FlatBuffer
* @throws NullPointerException if {@code buffer} is null
* @throws IllegalArgumentException if {@code buffer} does not contain any subgraph metadata, or
* it does not contain the expected identifier
*/
ModelMetadataInfo(ByteBuffer buffer) {
assertTFLiteMetadata(buffer);
modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
checkArgument(
modelMetadata.subgraphMetadataLength() > 0,
"The metadata flatbuffer does not contain any subgraph metadata.");
inputsMetadata = getInputsMetadata(modelMetadata);
outputsMetadata = getOutputsMetadata(modelMetadata);
minVersion = modelMetadata.minParserVersion();
}
/** Gets the count of input tensors with metadata in the metadata FlatBuffer. */
int getInputTensorCount() {
return inputsMetadata.size();
}
/**
* Gets the metadata for the input tensor specified by {@code inputIndex}.
*
* @param inputIndex The index of the desired intput tensor.
* @throws IllegalArgumentException if the inputIndex specified is invalid.
*/
@Nullable
TensorMetadata getInputTensorMetadata(int inputIndex) {
checkArgument(
inputIndex >= 0 && inputIndex < inputsMetadata.size(),
"The inputIndex specified is invalid.");
return inputsMetadata.get(inputIndex);
}
/**
* Gets the minimum parser version of the metadata. It can be {@code null} if the version is not
* populated.
*/
@Nullable
String getMininumParserVersion() {
return minVersion;
}
/** Gets the root handler for the model metadata. */
ModelMetadata getModelMetadata() {
return modelMetadata;
}
/** Gets the count of output tensors with metadata in the metadata FlatBuffer. */
int getOutputTensorCount() {
return outputsMetadata.size();
}
/**
* Gets the metadata for the output tensor specified by {@code outputIndex}.
*
* @param outputIndex The index of the desired output tensor.
* @throws IllegalArgumentException if the outputIndex specified is invalid.
*/
@Nullable
TensorMetadata getOutputTensorMetadata(int outputIndex) {
checkArgument(
outputIndex >= 0 && outputIndex < outputsMetadata.size(),
"The outputIndex specified is invalid.");
return outputsMetadata.get(outputIndex);
}
/**
* Verifies if the buffer is a valid TFLite metadata flatbuffer.
*
* @param buffer the TFLite metadata flatbuffer
* @throws NullPointerException if {@code buffer} is null.
* @throws IllegalArgumentException if {@code buffer} does not contain the expected identifier
*/
private static void assertTFLiteMetadata(ByteBuffer buffer) {
checkNotNull(buffer, "Metadata flatbuffer cannot be null.");
checkArgument(
ModelMetadata.ModelMetadataBufferHasIdentifier(buffer),
"The identifier of the metadata is invalid. The buffer may not be a valid TFLite metadata"
+ " flatbuffer.");
}
/** Gets metadata for all input tensors. */
private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) {
SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
int tensorNum = subgraphMetadata.inputTensorMetadataLength();
ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum);
for (int i = 0; i < tensorNum; i++) {
inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i));
}
return Collections.unmodifiableList(inputsMetadata);
}
/** Gets metadata for all output tensors. */
private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) {
SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
int tensorNum = subgraphMetadata.outputTensorMetadataLength();
ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum);
for (int i = 0; i < tensorNum; i++) {
outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i));
}
return Collections.unmodifiableList(outputsMetadata);
}
}

View File

@ -1,184 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import org.checkerframework.checker.nullness.qual.Nullable;
/** Static error checking util methods. */
final class Preconditions {
/**
* Ensures that an object reference passed as a parameter to the calling method is not null.
*
* @param reference an object reference
* @return the non-null reference that was validated
* @throws NullPointerException if {@code reference} is null
*/
public static <T extends Object> T checkNotNull(T reference) {
if (reference == null) {
throw new NullPointerException("The object reference is null.");
}
return reference;
}
/**
* Ensures that an object reference passed as a parameter to the calling method is not null.
*
* @param reference an object reference
* @param errorMessage the exception message to use if the check fails; will be converted to a
* string using {@link String#valueOf(Object)}
* @return the non-null reference that was validated
* @throws NullPointerException if {@code reference} is null
*/
public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
if (reference == null) {
throw new NullPointerException(String.valueOf(errorMessage));
}
return reference;
}
/**
* Ensures that the given String is not empty and not null.
*
* @param string the String to test
* @return the non-null non-empty String that was validated
* @throws IllegalArgumentException if {@code string} is null or empty
*/
public static String checkNotEmpty(String string) {
if (string == null || string.length() == 0) {
throw new IllegalArgumentException("Given String is empty or null.");
}
return string;
}
/**
* Ensures that the given String is not empty and not null.
*
* @param string the String to test
* @param errorMessage the exception message to use if the check fails; will be converted to a
* string using {@link String#valueOf(Object)}
* @return the non-null non-empty String that was validated
* @throws IllegalArgumentException if {@code string} is null or empty
*/
public static String checkNotEmpty(String string, Object errorMessage) {
if (string == null || string.length() == 0) {
throw new IllegalArgumentException(String.valueOf(errorMessage));
}
return string;
}
/**
* Ensures the truth of an expression involving one or more parameters to the calling method.
*
* @param expression a boolean expression.
* @throws IllegalArgumentException if {@code expression} is false.
*/
public static void checkArgument(boolean expression) {
if (!expression) {
throw new IllegalArgumentException();
}
}
/**
* Ensures the truth of an expression involving one or more parameters to the calling method.
*
* @param expression a boolean expression.
* @param errorMessage the exception message to use if the check fails; will be converted to a
* string using {@link String#valueOf(Object)}.
* @throws IllegalArgumentException if {@code expression} is false.
*/
public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
if (!expression) {
throw new IllegalArgumentException(String.valueOf(errorMessage));
}
}
/**
* Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
* {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
*
* @param index a user-supplied index identifying an element of an array, list or string
* @param size the size of that array, list or string
* @return the value of {@code index}
* @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
* @throws IllegalArgumentException if {@code size} is negative
*/
public static int checkElementIndex(int index, int size) {
return checkElementIndex(index, size, "index");
}
/**
* Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
* {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
*
* @param index a user-supplied index identifying an element of an array, list or string
* @param size the size of that array, list or string
* @param desc the text to use to describe this index in an error message
* @return the value of {@code index}
* @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
* @throws IllegalArgumentException if {@code size} is negative
*/
public static int checkElementIndex(int index, int size, @Nullable String desc) {
// Carefully optimized for execution by hotspot (explanatory comment above)
if (index < 0 || index >= size) {
throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
}
return index;
}
/**
* Ensures the truth of an expression involving the state of the calling instance, but not
* involving any parameters to the calling method.
*
* @param expression a boolean expression
* @throws IllegalStateException if {@code expression} is false
* @see Verify#verify Verify.verify()
*/
public static void checkState(boolean expression) {
if (!expression) {
throw new IllegalStateException();
}
}
/**
* Ensures the truth of an expression involving the state of the calling instance, but not
* involving any parameters to the calling method.
*
* @param expression a boolean expression
* @param errorMessage the exception message to use if the check fails; will be converted to a
* string using {@link String#valueOf(Object)}
* @throws IllegalStateException if {@code expression} is false
* @see Verify#verify Verify.verify()
*/
public static void checkState(boolean expression, @Nullable Object errorMessage) {
if (!expression) {
throw new IllegalStateException(String.valueOf(errorMessage));
}
}
private static String badElementIndex(int index, int size, @Nullable String desc) {
if (index < 0) {
return String.format("%s (%s) must not be negative", desc, index);
} else if (size < 0) {
throw new IllegalArgumentException("negative size: " + size);
} else { // index >= size
return String.format("%s (%s) must be less than size (%s)", desc, index, size);
}
}
private Preconditions() {
throw new AssertionError("Preconditions is Uninstantiable.");
}
}

View File

@ -1,107 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channel;
/**
* A byte channel that maintains a current <i>position</i> and allows the position to be changed.
* {@link SeekableByteChannelCompat} is compatible with {@link
* java.nio.channels.SeekableByteChannel}.
*
* <p>{@link java.nio.channels.SeekableByteChannel} is not available in Android API 23 and under.
* Therefore, {@link SeekableByteChannelCompat} is introduced here to make the interfaces used in
* the MetadtaExtractor library consistent with the common used Java libraries.
*/
interface SeekableByteChannelCompat extends Channel {
/**
* Reads a sequence of bytes from this channel into the given buffer.
*
* @param dst The buffer into which bytes are to be transferred
* @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached
* end-of-stream
* @throws NonReadableChannelException If this channel was not opened for reading
* @throws ClosedChannelException If this channel is closed
* @throws AsynchronousCloseException If another thread closes this channel while the read
* operation is in progress
* @throws ClosedByInterruptException If another thread interrupts the current thread while the
* read operation is in progress, thereby closing the channel and setting the current thread's
* interrupt status
* @throws IOException If some other I/O error occurs
*/
int read(ByteBuffer dst) throws IOException;
/**
* Writes a sequence of bytes to this channel from the given buffer.
*
* @param src The buffer from which bytes are to be retrieved
* @return The number of bytes written, possibly zero
* @throws NonWritableChannelException If this channel was not opened for writing
* @throws ClosedChannelException If this channel is closed
* @throws AsynchronousCloseException If another thread closes this channel while the write
* operation is in progress
* @throws ClosedByInterruptException If another thread interrupts the current thread while the
* write operation is in progress, thereby closing the channel and setting the current
* thread's interrupt status
* @throws IOException If some other I/O error occurs
*/
int write(ByteBuffer src) throws IOException;
/**
* Returns this channel's position.
*
* @return This channel's position, a non-negative integer counting the number of bytes from the
* beginning of the entity to the current position
* @throws ClosedChannelException If this channel is closed
* @throws IOException If some other I/O error occurs
*/
long position() throws IOException;
/**
* Sets this channel's position.
*
* @param newPosition The new position, a non-negative integer counting the number of bytes from
* the beginning of the entity
* @return This channel
* @throws ClosedChannelException If this channel is closed
* @throws IllegalArgumentException If the new position is negative
* @throws IOException If some other I/O error occurs
*/
SeekableByteChannelCompat position(long newPosition) throws IOException;
/**
* Returns the current size of entity to which this channel is connected.
*
* @return The current size, measured in bytes
* @throws ClosedChannelException If this channel is closed
* @throws IOException If some other I/O error occurs
*/
long size() throws IOException;
/**
* Truncates the entity, to which this channel is connected, to the given size.
*
* @param size The new size, a non-negative byte count
* @return This channel
* @throws NonWritableChannelException If this channel was not opened for writing
* @throws ClosedChannelException If this channel is closed
* @throws IllegalArgumentException If the new size is negative
* @throws IOException If some other I/O error occurs
*/
SeekableByteChannelCompat truncate(long size) throws IOException;
}

View File

@ -1,427 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import static org.tensorflow.lite.support.metadata.Preconditions.checkArgument;
import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
import java.io.Closeable;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.ZipException;
/**
* Reads uncompressed files from the TFLite model, a zip file.
*
* <p>TODO(b/150237111): add a link to the webpage of MetadataPopulator once it's available.
*
* <p>A TFLite model file becomes a zip file when it contains associated files. The associated files
* can be packed to a TFLite model file using the MetadataPopulator. The associated files are not
* compressed when being added to the model file.
*
* <p>{@link ZipFile} does not support Zip64 format, because TFLite models are much smaller than the
* size limit for Zip64, which is 4GB.
*/
final class ZipFile implements Closeable {
/** Maps String to list of ZipEntrys, name -> actual entries. */
private final Map<String, List<ZipEntry>> nameMap;
/** The actual data source. */
private final ByteBufferChannel archive;
/**
* Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link
* ZipFile} does not synchronized over the buffer that is passed into it.
*
* @param channel the archive
* @throws IOException if an error occurs while creating this {@link ZipFile}
* @throws ZipException if the channel is not a zip archive
* @throws NullPointerException if the archive is null
*/
public static ZipFile createFrom(ByteBufferChannel channel) throws IOException {
checkNotNull(channel);
ZipParser zipParser = new ZipParser(channel);
Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries();
return new ZipFile(channel, nameMap);
}
@Override
public void close() {
archive.close();
}
/**
* Exposes the raw stream of the archive entry.
*
* <p>Since the associated files will not be compressed when being packed to the zip file, the raw
* stream represents the non-compressed files.
*
* <p><b>WARNING:</b> The returned {@link InputStream}, is <b>not</b> thread-safe. If multiple
* threads concurrently reading from the returned {@link InputStream}, it must be synchronized
* externally.
*
* @param name name of the entry to get the stream for
* @return the raw input stream containing data
* @throws IllegalArgumentException if the specified file does not exist in the zip file
*/
public InputStream getRawInputStream(String name) {
checkArgument(
nameMap.containsKey(name),
String.format("The file, %s, does not exist in the zip file.", name));
List<ZipEntry> entriesWithTheSameName = nameMap.get(name);
ZipEntry entry = entriesWithTheSameName.get(0);
long start = entry.getDataOffset();
long remaining = entry.getSize();
return new BoundedInputStream(archive, start, remaining);
}
private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) {
archive = channel;
this.nameMap = nameMap;
}
/* Parses a Zip archive and gets the information for each {@link ZipEntry}. */
private static class ZipParser {
private final ByteBufferChannel archive;
// Cached buffers that will only be used locally in the class to reduce garbage collection.
private final ByteBuffer longBuffer =
ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
private final ByteBuffer intBuffer =
ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
private final ByteBuffer shortBuffer =
ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
private ZipParser(ByteBufferChannel archive) {
this.archive = archive;
}
/**
* Parses the underlying {@code archive} and returns the information as a list of {@link
* ZipEntry}.
*/
private Map<String, List<ZipEntry>> parseEntries() throws IOException {
List<ZipEntry> entries = parseCentralDirectory();
return parseLocalFileHeaderData(entries);
}
/**
* Checks if the current position contains a central file header signature, {@link
* ZipConstants#CENSIG}.
*/
private boolean foundCentralFileheaderSignature() {
long signature = (long) getInt();
return signature == ZipConstants.CENSIG;
}
/**
* Gets the value as a Java int from two bytes starting at the current position of the archive.
*/
private int getShort() {
shortBuffer.rewind();
archive.read(shortBuffer);
shortBuffer.flip();
return (int) shortBuffer.getShort();
}
/**
* Gets the value as a Java long from four bytes starting at the current position of the
* archive.
*/
private int getInt() {
intBuffer.rewind();
archive.read(intBuffer);
intBuffer.flip();
return intBuffer.getInt();
}
/**
* Gets the value as a Java long from four bytes starting at the current position of the
* archive.
*/
private long getLong() {
longBuffer.rewind();
archive.read(longBuffer);
longBuffer.flip();
return longBuffer.getLong();
}
/**
* Positions the archive at the start of the central directory.
*
* <p>First, it searches for the signature of the "end of central directory record", {@link
* ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory
* record". The zip file are created without archive comments, thus {@link ZipConstants#ENDSIG}
* should appear exactly at {@link ZipConstants#ENDHDR} from the end of the zip file.
*
* <p>Then, parse the "end of central dir record" and position the archive at the start of the
* central directory.
*/
private void locateCentralDirectory() throws IOException {
if (archive.size() < ZipConstants.ENDHDR) {
throw new ZipException("The archive is not a ZIP archive.");
}
// Positions the archive at the start of the "end of central directory record".
long offsetRecord = archive.size() - ZipConstants.ENDHDR;
archive.position(offsetRecord);
// Checks for the signature, {@link ZipConstants#ENDSIG}.
long endSig = getLong();
if (endSig != ZipConstants.ENDSIG) {
throw new ZipException("The archive is not a ZIP archive.");
}
// Positions the archive at the offset of central directory.
skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB);
// Gets the offset to central directory
long offsetDirectory = getInt();
// Goes to the central directory.
archive.position(offsetDirectory);
}
/**
* Reads the central directory of the given archive and populates the internal tables with
* {@link ZipEntry} instances.
*/
private List<ZipEntry> parseCentralDirectory() throws IOException {
/** List of entries in the order they appear inside the central directory. */
List<ZipEntry> entries = new ArrayList<>();
locateCentralDirectory();
while (foundCentralFileheaderSignature()) {
ZipEntry entry = parseCentralDirectoryEntry();
entries.add(entry);
}
return entries;
}
/**
* Reads an individual entry of the central directory, creats an ZipEntry from it and adds it to
* the global maps.
*/
private ZipEntry parseCentralDirectoryEntry() throws IOException {
// Positions the archive at the "compressed size" and read the value.
skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM);
long compressSize = getInt();
// Positions the archive at the "filename length" and read the value.
skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN);
int fileNameLen = getShort();
// Reads the extra field length and the comment length.
int extraLen = getShort();
int commentLen = getShort();
// Positions the archive at the "local file header offset" and read the value.
skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK);
long localHeaderOffset = getInt();
// Reads the file name.
byte[] fileNameBuf = new byte[fileNameLen];
archive.read(ByteBuffer.wrap(fileNameBuf));
String fileName = new String(fileNameBuf, Charset.forName("UTF-8"));
// Skips the extra field and the comment.
skipBytes(extraLen + commentLen);
ZipEntry entry = new ZipEntry();
entry.setSize(compressSize);
entry.setLocalHeaderOffset(localHeaderOffset);
entry.setName(fileName);
return entry;
}
/** Walks through all recorded entries and records the offsets for the entry data. */
private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) {
/** Maps String to list of ZipEntrys, name -> actual entries. */
Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>();
for (ZipEntry entry : entries) {
long offset = entry.getLocalHeaderOffset();
archive.position(offset + ZipConstants.LOCNAM);
// Gets the data offset of this entry.
int fileNameLen = getShort();
int extraFieldLen = getShort();
long dataOffset =
offset
+ ZipConstants.LOCEXT
+ ZipConstants.SHORT_BYTE_SIZE
+ fileNameLen
+ extraFieldLen;
entry.setDataOffset(dataOffset);
// Puts the entry into the nameMap.
String name = entry.getName();
List<ZipEntry> entriesWithTheSameName;
if (nameMap.containsKey(name)) {
entriesWithTheSameName = nameMap.get(name);
} else {
entriesWithTheSameName = new ArrayList<>();
nameMap.put(name, entriesWithTheSameName);
}
entriesWithTheSameName.add(entry);
}
return nameMap;
}
/** Skips the given number of bytes or throws an EOFException if skipping failed. */
private void skipBytes(int count) throws IOException {
long currentPosition = archive.position();
long newPosition = currentPosition + count;
if (newPosition > archive.size()) {
throw new EOFException();
}
archive.position(newPosition);
}
}
/** Stores the data offset and the size of an entry in the archive. */
private static class ZipEntry {
private String name;
private long dataOffset = -1;
private long size = -1;
private long localHeaderOffset = -1;
public long getSize() {
return size;
}
public long getDataOffset() {
return dataOffset;
}
public String getName() {
return name;
}
public long getLocalHeaderOffset() {
return localHeaderOffset;
}
public void setSize(long size) {
this.size = size;
}
public void setDataOffset(long dataOffset) {
this.dataOffset = dataOffset;
}
public void setName(String name) {
this.name = name;
}
public void setLocalHeaderOffset(long localHeaderOffset) {
this.localHeaderOffset = localHeaderOffset;
}
}
/**
* Various constants for this {@link ZipFile}.
*
* <p>Referenced from {@link java.util.zip.ZipConstants}.
*/
private static class ZipConstants {
/** length of Java short in bytes. */
static final int SHORT_BYTE_SIZE = Short.SIZE / 8;
/** length of Java int in bytes. */
static final int INT_BYTE_SIZE = Integer.SIZE / 8;
/** length of Java long in bytes. */
static final int LONG_BYTE_SIZE = Long.SIZE / 8;
/*
* Header signatures
*/
static final long LOCSIG = 0x04034b50L; // "PK\003\004"
static final long EXTSIG = 0x08074b50L; // "PK\007\008"
static final long CENSIG = 0x02014b50L; // "PK\001\002"
static final long ENDSIG = 0x06054b50L; // "PK\005\006"
/*
* Header sizes in bytes (including signatures)
*/
static final int LOCHDR = 30; // LOC header size
static final int EXTHDR = 16; // EXT header size
static final int CENHDR = 46; // CEN header size
static final int ENDHDR = 22; // END header size
/*
* Local file (LOC) header field offsets
*/
static final int LOCVER = 4; // version needed to extract
static final int LOCFLG = 6; // general purpose bit flag
static final int LOCHOW = 8; // compression method
static final int LOCTIM = 10; // modification time
static final int LOCCRC = 14; // uncompressed file crc-32 value
static final int LOCSIZ = 18; // compressed size
static final int LOCLEN = 22; // uncompressed size
static final int LOCNAM = 26; // filename length
static final int LOCEXT = 28; // extra field length
/*
* Extra local (EXT) header field offsets
*/
static final int EXTCRC = 4; // uncompressed file crc-32 value
static final int EXTSIZ = 8; // compressed size
static final int EXTLEN = 12; // uncompressed size
/*
* Central directory (CEN) header field offsets
*/
static final int CENVEM = 4; // version made by
static final int CENVER = 6; // version needed to extract
static final int CENFLG = 8; // encrypt, decrypt flags
static final int CENHOW = 10; // compression method
static final int CENTIM = 12; // modification time
static final int CENCRC = 16; // uncompressed file crc-32 value
static final int CENSIZ = 20; // compressed size
static final int CENLEN = 24; // uncompressed size
static final int CENNAM = 28; // filename length
static final int CENEXT = 30; // extra field length
static final int CENCOM = 32; // comment length
static final int CENDSK = 34; // disk number start
static final int CENATT = 36; // internal file attributes
static final int CENATX = 38; // external file attributes
static final int CENOFF = 42; // LOC header offset
/*
* End of central directory (END) header field offsets
*/
static final int ENDSUB = 8; // number of entries on this disk
static final int ENDTOT = 10; // total number of entries
static final int ENDSIZ = 12; // central directory size in bytes
static final int ENDOFF = 16; // offset of first CEN header
static final int ENDCOM = 20; // zip file comment length
private ZipConstants() {}
}
}

View File

@ -1,615 +0,0 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite metadata tools."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import os
import shutil
import tempfile
import warnings
import zipfile
from flatbuffers.python import flatbuffers
from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb
from tensorflow.lite.experimental.support.metadata.cc.python import _pywrap_metadata_version
from tensorflow.lite.experimental.support.metadata.flatbuffers_lib import _pywrap_flatbuffers
from tensorflow.python.platform import resource_loader
_FLATC_TFLITE_METADATA_SCHEMA_FILE = resource_loader.get_path_to_datafile(
"metadata_schema.fbs")
# TODO(b/141467403): add delete method for associated files.
class MetadataPopulator(object):
"""Packs metadata and associated files into TensorFlow Lite model file.
MetadataPopulator can be used to populate metadata and model associated files
into a model file or a model buffer (in bytearray). It can also help to
inspect list of files that have been packed into the model or are supposed to
be packed into the model.
The metadata file (or buffer) should be generated based on the metadata
schema:
third_party/tensorflow/lite/schema/metadata_schema.fbs
Example usage:
Populate matadata and label file into an image classifier model.
First, based on metadata_schema.fbs, generate the metadata for this image
classifer model using Flatbuffers API. Attach the label file onto the ouput
tensor (the tensor of probabilities) in the metadata.
Then, pack the metadata and label file into the model as follows.
```python
# Populating a metadata file (or a metadta buffer) and associated files to
a model file:
populator = MetadataPopulator.with_model_file(model_file)
# For metadata buffer (bytearray read from the metadata file), use:
# populator.load_metadata_buffer(metadata_buf)
populator.load_metadata_file(metadata_file)
populator.load_associated_files([label.txt])
populator.populate()
# Populating a metadata file (or a metadta buffer) and associated files to
a model buffer:
populator = MetadataPopulator.with_model_buffer(model_buf)
populator.load_metadata_file(metadata_file)
populator.load_associated_files([label.txt])
populator.populate()
# Writing the updated model buffer into a file.
updated_model_buf = populator.get_model_buffer()
with open("updated_model.tflite", "wb") as f:
f.write(updated_model_buf)
```
Note that existing metadata buffer (if applied) will be overridden by the new
metadata buffer.
"""
# As Zip API is used to concatenate associated files after tflite model file,
# the populating operation is developed based on a model file. For in-memory
# model buffer, we create a tempfile to serve the populating operation.
# Creating the deleting such a tempfile is handled by the class,
# _MetadataPopulatorWithBuffer.
METADATA_FIELD_NAME = "TFLITE_METADATA"
TFLITE_FILE_IDENTIFIER = b"TFL3"
METADATA_FILE_IDENTIFIER = b"M001"
def __init__(self, model_file):
"""Constructor for MetadataPopulator.
Args:
model_file: valid path to a TensorFlow Lite model file.
Raises:
IOError: File not found.
ValueError: the model does not have the expected flatbuffer identifer.
"""
_assert_model_file_identifier(model_file)
self._model_file = model_file
self._metadata_buf = None
self._associated_files = set()
@classmethod
def with_model_file(cls, model_file):
"""Creates a MetadataPopulator object that populates data to a model file.
Args:
model_file: valid path to a TensorFlow Lite model file.
Returns:
MetadataPopulator object.
Raises:
IOError: File not found.
ValueError: the model does not have the expected flatbuffer identifer.
"""
return cls(model_file)
# TODO(b/141468993): investigate if type check can be applied to model_buf for
# FB.
@classmethod
def with_model_buffer(cls, model_buf):
"""Creates a MetadataPopulator object that populates data to a model buffer.
Args:
model_buf: TensorFlow Lite model buffer in bytearray.
Returns:
A MetadataPopulator(_MetadataPopulatorWithBuffer) object.
Raises:
ValueError: the model does not have the expected flatbuffer identifer.
"""
return _MetadataPopulatorWithBuffer(model_buf)
def get_model_buffer(self):
"""Gets the buffer of the model with packed metadata and associated files.
Returns:
Model buffer (in bytearray).
"""
with open(self._model_file, "rb") as f:
return f.read()
def get_packed_associated_file_list(self):
"""Gets a list of associated files packed to the model file.
Returns:
List of packed associated files.
"""
if not zipfile.is_zipfile(self._model_file):
return []
with zipfile.ZipFile(self._model_file, "r") as zf:
return zf.namelist()
def get_recorded_associated_file_list(self):
"""Gets a list of associated files recorded in metadata of the model file.
Associated files may be attached to a model, a subgraph, or an input/output
tensor.
Returns:
List of recorded associated files.
"""
recorded_files = []
if not self._metadata_buf:
return recorded_files
metadata = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
self._metadata_buf, 0)
# Add associated files attached to ModelMetadata
self._get_associated_files_from_metadata_struct(metadata, recorded_files)
# Add associated files attached to each SubgraphMetadata
for j in range(metadata.SubgraphMetadataLength()):
subgraph = metadata.SubgraphMetadata(j)
self._get_associated_files_from_metadata_struct(subgraph, recorded_files)
# Add associated files attached to each input tensor
for k in range(subgraph.InputTensorMetadataLength()):
tensor = subgraph.InputTensorMetadata(k)
self._get_associated_files_from_metadata_struct(tensor, recorded_files)
# Add associated files attached to each output tensor
for k in range(subgraph.OutputTensorMetadataLength()):
tensor = subgraph.OutputTensorMetadata(k)
self._get_associated_files_from_metadata_struct(tensor, recorded_files)
return recorded_files
def load_associated_files(self, associated_files):
"""Loads associated files that to be concatenated after the model file.
Args:
associated_files: list of file paths.
Raises:
IOError:
File not found.
"""
for af in associated_files:
_assert_exist(af)
self._associated_files.add(af)
def load_metadata_buffer(self, metadata_buf):
"""Loads the metadata buffer (in bytearray) to be populated.
Args:
metadata_buf: metadata buffer (in bytearray) to be populated.
Raises:
ValueError: The metadata to be populated is empty.
ValueError: The metadata does not have the expected flatbuffer identifer.
ValueError: Error occurs when getting the minimum metadata parser version.
"""
if not metadata_buf:
raise ValueError("The metadata to be populated is empty.")
_assert_metadata_buffer_identifier(metadata_buf)
# Gets the minimum metadata parser version of the metadata_buf.
min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion(
bytes(metadata_buf))
# Inserts in the minimum metadata parser version into the metadata_buf.
metadata = _metadata_fb.ModelMetadataT.InitFromObj(
_metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
metadata.minParserVersion = min_version
b = flatbuffers.Builder(0)
b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER)
metadata_buf_with_version = b.Output()
self._metadata_buf = metadata_buf_with_version
def load_metadata_file(self, metadata_file):
"""Loads the metadata file to be populated.
Args:
metadata_file: path to the metadata file to be populated.
Raises:
IOError: File not found.
ValueError: The metadata does not have the expected flatbuffer identifer.
"""
_assert_exist(metadata_file)
with open(metadata_file, "rb") as f:
metadata_buf = f.read()
self.load_metadata_buffer(bytearray(metadata_buf))
def populate(self):
"""Populates loaded metadata and associated files into the model file."""
self._assert_validate()
self._populate_metadata_buffer()
self._populate_associated_files()
def _assert_validate(self):
"""Validates the metadata and associated files to be populated.
Raises:
ValueError:
File is recorded in the metadata, but is not going to be populated.
File has already been packed.
"""
# Gets files that are recorded in metadata.
recorded_files = self.get_recorded_associated_file_list()
# Gets files that have been packed to self._model_file.
packed_files = self.get_packed_associated_file_list()
# Gets the file name of those associated files to be populated.
to_be_populated_files = []
for af in self._associated_files:
to_be_populated_files.append(os.path.basename(af))
# Checks all files recorded in the metadata will be populated.
for rf in recorded_files:
if rf not in to_be_populated_files and rf not in packed_files:
raise ValueError("File, '{0}', is recorded in the metadata, but has "
"not been loaded into the populator.".format(rf))
for f in to_be_populated_files:
if f in packed_files:
raise ValueError("File, '{0}', has already been packed.".format(f))
if f not in recorded_files:
warnings.warn(
"File, '{0}', does not exsit in the metadata. But packing it to "
"tflite model is still allowed.".format(f))
def _copy_archived_files(self, src_zip, dst_zip, file_list):
"""Copy archieved files in file_list from src_zip ro dst_zip."""
if not zipfile.is_zipfile(src_zip):
raise ValueError("File, '{0}', is not a zipfile.".format(src_zip))
with zipfile.ZipFile(src_zip,
"r") as src_zf, zipfile.ZipFile(dst_zip,
"a") as dst_zf:
src_list = src_zf.namelist()
for f in file_list:
if f not in src_list:
raise ValueError(
"File, '{0}', does not exist in the zipfile, {1}.".format(
f, src_zip))
file_buffer = src_zf.read(f)
dst_zf.writestr(f, file_buffer)
def _get_associated_files_from_metadata_struct(self, file_holder, file_list):
for j in range(file_holder.AssociatedFilesLength()):
file_list.append(file_holder.AssociatedFiles(j).Name().decode("utf-8"))
def _populate_associated_files(self):
"""Concatenates associated files after TensorFlow Lite model file.
If the MetadataPopulator object is created using the method,
with_model_file(model_file), the model file will be updated.
"""
# Opens up the model file in "appending" mode.
# If self._model_file already has pack files, zipfile will concatenate
# addition files after self._model_file. For example, suppose we have
# self._model_file = old_tflite_file | label1.txt | label2.txt
# Then after trigger populate() to add label3.txt, self._model_file becomes
# self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt
with zipfile.ZipFile(self._model_file, "a") as zf:
for af in self._associated_files:
filename = os.path.basename(af)
zf.write(af, filename)
def _populate_metadata_buffer(self):
"""Populates the metadata buffer (in bytearray) into the model file.
Inserts metadata_buf into the metadata field of schema.Model. If the
MetadataPopulator object is created using the method,
with_model_file(model_file), the model file will be updated.
Existing metadata buffer (if applied) will be overridden by the new metadata
buffer.
"""
with open(self._model_file, "rb") as f:
model_buf = f.read()
model = _schema_fb.ModelT.InitFromObj(
_schema_fb.Model.GetRootAsModel(model_buf, 0))
buffer_field = _schema_fb.BufferT()
buffer_field.data = self._metadata_buf
is_populated = False
if not model.metadata:
model.metadata = []
else:
# Check if metadata has already been populated.
for meta in model.metadata:
if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME:
is_populated = True
model.buffers[meta.buffer] = buffer_field
if not is_populated:
if not model.buffers:
model.buffers = []
model.buffers.append(buffer_field)
# Creates a new metadata field.
metadata_field = _schema_fb.MetadataT()
metadata_field.name = self.METADATA_FIELD_NAME
metadata_field.buffer = len(model.buffers) - 1
model.metadata.append(metadata_field)
# Packs model back to a flatbuffer binaray file.
b = flatbuffers.Builder(0)
b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER)
model_buf = b.Output()
# Saves the updated model buffer to model file.
# Gets files that have been packed to self._model_file.
packed_files = self.get_packed_associated_file_list()
if packed_files:
# Writes the updated model buffer and associated files into a new model
# file. Then overwrites the original model file.
with tempfile.NamedTemporaryFile() as temp:
new_file = temp.name
with open(new_file, "wb") as f:
f.write(model_buf)
self._copy_archived_files(self._model_file, new_file, packed_files)
shutil.copy(new_file, self._model_file)
os.remove(new_file)
else:
with open(self._model_file, "wb") as f:
f.write(model_buf)
class _MetadataPopulatorWithBuffer(MetadataPopulator):
"""Subclass of MetadtaPopulator that populates metadata to a model buffer.
This class is used to populate metadata into a in-memory model buffer. As we
use Zip API to concatenate associated files after tflite model file, the
populating operation is developed based on a model file. For in-memory model
buffer, we create a tempfile to serve the populating operation. This class is
then used to generate this tempfile, and delete the file when the
MetadataPopulator object is deleted.
"""
def __init__(self, model_buf):
"""Constructor for _MetadataPopulatorWithBuffer.
Args:
model_buf: TensorFlow Lite model buffer in bytearray.
Raises:
ValueError: model_buf is empty.
ValueError: model_buf does not have the expected flatbuffer identifer.
"""
if not model_buf:
raise ValueError("model_buf cannot be empty.")
with tempfile.NamedTemporaryFile() as temp:
model_file = temp.name
with open(model_file, "wb") as f:
f.write(model_buf)
MetadataPopulator.__init__(self, model_file)
def __del__(self):
"""Destructor of _MetadataPopulatorWithBuffer.
Deletes the tempfile.
"""
if os.path.exists(self._model_file):
os.remove(self._model_file)
class MetadataDisplayer(object):
"""Displays metadata and associated file info in human-readable format."""
def __init__(self, model_file, metadata_file, associated_file_list):
"""Constructor for MetadataDisplayer.
Args:
model_file: valid path to the model file.
metadata_file: valid path to the metadata file.
associated_file_list: list of associate files in the model file.
"""
_assert_model_file_identifier(model_file)
_assert_metadata_file_identifier(metadata_file)
self._model_file = model_file
self._metadata_file = metadata_file
self._associated_file_list = associated_file_list
@classmethod
def with_model_file(cls, model_file):
"""Creates a MetadataDisplayer object for the model file.
Args:
model_file: valid path to a TensorFlow Lite model file.
Returns:
MetadataDisplayer object.
Raises:
IOError: File not found.
ValueError: The model does not have metadata.
"""
_assert_exist(model_file)
metadata_file = cls._save_temporary_metadata_file(model_file)
associated_file_list = cls._parse_packed_associted_file_list(model_file)
return cls(model_file, metadata_file, associated_file_list)
@classmethod
def with_model_buffer(cls, model_buffer):
"""Creates a MetadataDisplayer object for a file buffer.
Args:
model_buffer: TensorFlow Lite model buffer in bytearray.
Returns:
MetadataDisplayer object.
"""
if not model_buffer:
raise ValueError("model_buffer cannot be empty.")
with tempfile.NamedTemporaryFile() as temp:
model_file = temp.name
with open(model_file, "wb") as f:
f.write(model_buffer)
return cls.with_model_file(model_file)
def get_metadata_json(self):
"""Converts the metadata into a json string."""
opt = _pywrap_flatbuffers.IDLOptions()
opt.strict_json = True
parser = _pywrap_flatbuffers.Parser(opt)
with open(_FLATC_TFLITE_METADATA_SCHEMA_FILE) as f:
metadata_schema_content = f.read()
with open(self._metadata_file, "rb") as f:
metadata_file_content = f.read()
if not parser.parse(metadata_schema_content):
raise ValueError("Cannot parse metadata schema. Reason: " + parser.error)
with open(self._metadata_file, "rb") as f:
metadata_file_content = f.read()
return _pywrap_flatbuffers.generate_text(parser, metadata_file_content)
def get_packed_associated_file_list(self):
"""Returns a list of associated files that are packed in the model.
Returns:
A name list of associated files.
"""
return copy.deepcopy(self._associated_file_list)
@staticmethod
def _save_temporary_metadata_file(model_file):
"""Saves the metadata in the model file to a temporary file.
Args:
model_file: valid path to the model file.
Returns:
Path to the metadata temporary file.
Raises:
ValueError: The model does not have metadata.
"""
with open(model_file, "rb") as f:
model_buf = f.read()
tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
# Gets metadata from the model file.
for i in range(tflite_model.MetadataLength()):
meta = tflite_model.Metadata(i)
if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME:
buffer_index = meta.Buffer()
metadata = tflite_model.Buffers(buffer_index)
metadata_buf = metadata.DataAsNumpy().tobytes()
# Creates a temporary file to store the metadata.
with tempfile.NamedTemporaryFile() as temp:
metadata_file = temp.name
# Saves the metadata into the temporary file.
with open(metadata_file, "wb") as f:
f.write(metadata_buf)
return metadata_file
raise ValueError("The model does not have metadata.")
@staticmethod
def _parse_packed_associted_file_list(model_file):
"""Gets a list of associated files packed to the model file.
Args:
model_file: valid path to the model file.
Returns:
List of packed associated files.
"""
if not zipfile.is_zipfile(model_file):
return []
with zipfile.ZipFile(model_file, "r") as zf:
return zf.namelist()
def __del__(self):
"""Destructor of MetadataDisplayer.
Deletes the tempfile.
"""
if os.path.exists(self._metadata_file):
os.remove(self._metadata_file)
def _assert_exist(filename):
"""Checks if a file exists."""
if not os.path.exists(filename):
raise IOError("File, '{0}', does not exist.".format(filename))
def _assert_model_file_identifier(model_file):
"""Checks if a model file has the expected TFLite schema identifier."""
_assert_exist(model_file)
with open(model_file, "rb") as f:
model_buf = f.read()
if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0):
raise ValueError(
"The model provided does not have the expected identifier, and "
"may not be a valid TFLite model.")
def _assert_metadata_file_identifier(metadata_file):
"""Checks if a metadata file has the expected Metadata schema identifier."""
_assert_exist(metadata_file)
with open(metadata_file, "rb") as f:
metadata_buf = f.read()
_assert_metadata_buffer_identifier(metadata_buf)
def _assert_metadata_buffer_identifier(metadata_buf):
"""Checks if a metadata buffer has the expected Metadata schema identifier."""
if not _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier(
metadata_buf, 0):
raise ValueError(
"The metadata buffer does not have the expected identifier, and may not"
" be a valid TFLite Metadata.")

View File

@ -1,26 +0,0 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Information about the metadata parser that this python library depends on."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class MetadataParser(object):
"""Information about the metadata parser."""
# The version of the metadata parser.
VERSION = "{LATEST_METADATA_PARSER_VERSION}"

View File

@ -1,38 +0,0 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.lite.experimental.support.metadata.metadata_parser."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from tensorflow.lite.experimental.support.metadata import metadata_parser
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
class MetadataParserTest(test_util.TensorFlowTestCase):
def test_version_wellFormedSemanticVersion(self):
# Validates that the version is well-formed (x.y.z).
self.assertTrue(
re.match('[0-9]+\\.[0-9]+\\.[0-9]+',
metadata_parser.MetadataParser.VERSION))
if __name__ == '__main__':
test.main()

View File

@ -1,570 +0,0 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace tflite;
// TFLite metadata contains both human readable and machine readable information
// about what the model does and how to use the model. It can be used as a
// README file, which elaborates the details of the model, each input/ouput
// tensor, and each associated file.
//
// An important use case of TFLite metadata is the TFLite codegen tool, which
// automatically generates the model interface based on the properties of the
// model and the tensors. The model interface provides high-level APIs to
// interact with the model, such as preprocessing the input data and running
// inferences.
//
// Entries marked with "<Codegen usage>" are used in TFLite codegen tool to
// generate the model interface. It is recommended to fill in at least those
// enties to boost the codegen performance.
// The Metadata schema is versioned by the Semantic versioning number, such as
// MAJOR.MINOR.PATCH. It tracks the schema changes according to the rules below:
// * Bump up the MAJOR number when making potentially backwards incompatible
// changes. It must be incremented if the new changes break the backwards
// compatibility. It may also include minor and patch level changes as
// needed. The true backwards compatibility is indicated by the file
// identifier.
// * Bump up the MINOR number when making backwards compatible updates for
// major features, such as supporting new content types or adding new
// processing units.
// * Bump up the PATCH number when making small backwards compatible changes,
// such as adding a new fields or deprecating certain fields (not deleting
// them).
//
// ModelMetadata.min_parser_version indicates the minimum necessary metadata
// parser version to fully understand all fields in a given metadata flatbuffer.
//
// New fields and types will have associated comments with the schema version
// for which they were added.
//
// LINT.IfChange
// Schema Semantic version: 1.0.1
// LINT.ThenChange(//tensorflow/lite/experimental/\
//. support/metadata/java/src/java/org/tensorflow/lite/support/metadata/\
//. MetadataParser.java)
// This indicates the flatbuffer compatibility. The number will bump up when a
// break change is applied to the schema, such as removing fields or adding new
// fields to the middle of a table.
file_identifier "M001";
// History:
// 1.0.1 - Added VOCABULARY type to AssociatedFileType.
// File extension of any written files.
file_extension "tflitemeta";
// LINT.IfChange
enum AssociatedFileType : byte {
UNKNOWN = 0,
// Files such as readme.txt.
DESCRIPTIONS = 1,
// Contains labels that annotate certain axis of the tensor. For example,
// the label file in image classification. Those labels annotate the
// the output tensor, such that each value in the output tensor is the
// probability of that corresponding category specified by the label.
//
// <Codegen usage>:
// If an output tensor has an associated file as TENSOR_AXIS_LABELS, return
// the output as a mapping between the labels and probability in the model
// interface.
// If multiple files of the same type are present, the first one is used by
// default; additional ones are to be distinguished from one another by their
// specified locale.
TENSOR_AXIS_LABELS = 2,
// Contains labels that tensor values correspond to. For example, in
// the object detection model, one of the output tensors is the detected
// classes. And each value in the tensor refers to the index of label in the
// category label file.
//
// <Codegen usage>:
// If an output tensor has an associated file as TENSOR_VALUE_LABELS, convert
// the tensor values into labels, and return a list of string as the output.
// If multiple files of the same type are present, the first one is used by
// default; additional ones are to be distinguished from one another by their
// specified locale.
TENSOR_VALUE_LABELS = 3,
// Contains sigmoid-based score calibration parameters, formatted as CSV.
// Lines contain for each index of an output tensor the scale, slope, offset
// and (optional) min_score parameters to be used for sigmoid fitting (in this
// order and in `strtof`-compatible [1] format).
// A line may be left empty to default calibrated scores for this index to
// default_score.
// In summary, each line should thus contain 0, 3 or 4 comma-separated values.
//
// See documentation for ScoreCalibrationOptions for details.
//
// [1]: https://en.cppreference.com/w/c/string/byte/strtof
TENSOR_AXIS_SCORE_CALIBRATION = 4,
// Contains a list of unique words (characters separated by "\n" or in lines)
// that help to convert natural language words to embedding vectors.
// Added in: 1.0.1
VOCABULARY = 5,
}
table AssociatedFile {
// Name of this file. Need to be exact the same as the name of the actual file
// packed into the TFLite model as a zip file.
//
// <Codegen usage>:
// Locates to the actual file in the TFLite model.
name:string;
// A description of what the file is.
description:string;
// Type of the associated file. There may be special pre/post processing for
// some types. For example in image classification, a label file of the output
// will be used to convert object index into string.
//
// <Codegen usage>:
// Determines how to process the corresponding tensor.
type:AssociatedFileType;
// An optional locale for this associated file (if applicable). It is
// recommended to use an ISO 639-1 letter code (e.g. "en" for English),
// optionally completed by a two letter region code (e.g. "en-US" for US
// English and "en-CA" for Canadian English).
// Leverage this in order to specify e.g multiple label files translated in
// different languages.
locale:string;
}
// The basic content type for all tensors.
//
// <Codegen usage>:
// Input feature tensors:
// 1. Generates the method to load data from a TensorBuffer.
// 2. Creates the preprocessing logic. The default processing pipeline is:
// [NormalizeOp, QuantizeOp].
// Output feature tensors:
// 1. Generates the method to return the output data to a TensorBuffer.
// 2. Creates the post-processing logic. The default processing pipeline is:
// [DeQuantizeOp].
table FeatureProperties {
}
// The type of color space of an image.
enum ColorSpaceType : byte {
UNKNOWN = 0,
RGB = 1,
GRAYSCALE = 2,
}
table ImageSize {
width:uint;
height:uint;
}
// The properties for image tensors.
//
// <Codegen usage>:
// Input image tensors:
// 1. Generates the method to load an image from a TensorImage.
// 2. Creates the preprocessing logic. The default processing pipeline is:
// [ResizeOp, NormalizeOp, QuantizeOp].
// Output image tensors:
// 1. Generates the method to return the output data to a TensorImage.
// 2. Creates the post-processing logic. The default processing pipeline is:
// [DeQuantizeOp].
table ImageProperties {
// The color space of the image.
//
// <Codegen usage>:
// Determines how to convert the color space of a given image from users.
color_space:ColorSpaceType;
// Indicates the default value of image width and height if the tensor shape
// is dynamic. For fixed-size tensor, this size will be consistent with the
// expected size.
default_size:ImageSize;
}
// The properties for tensors representing bounding boxes.
//
// <Codegen usage>:
// Input image tensors: NA.
// Output image tensors: parses the values into a data stucture that represents
// bounding boxes. For example, in the generated wrapper for Android, it returns
// the output as android.graphics.Rect objects.
enum BoundingBoxType : byte {
UNKNOWN = 0,
// Represents the bounding box by using the combination of boundaries,
// {left, top, right, bottom}.
// The default order is {left, top, right, bottom}. Other orders can be
// indicated by BoundingBoxProperties.index.
BOUNDARIES = 1,
// Represents the bounding box by using the upper_left corner, width and
// height.
// The default order is {upper_left_x, upper_left_y, width, height}. Other
// orders can be indicated by BoundingBoxProperties.index.
UPPER_LEFT = 2,
// Represents the bounding box by using the center of the box, width and
// height. The default order is {center_x, center_y, width, height}. Other
// orders can be indicated by BoundingBoxProperties.index.
CENTER = 3,
}
enum CoordinateType : byte {
// The coordinates are float values from 0 to 1.
RATIO = 0,
// The coordinates are integers.
PIXEL = 1,
}
table BoundingBoxProperties {
// Denotes the order of the elements defined in each bounding box type. An
// empty index array represent the default order of each bounding box type.
// For example, to denote the default order of BOUNDARIES, {left, top, right,
// bottom}, the index should be {0, 1, 2, 3}. To denote the order {left,
// right, top, bottom}, the order should be {0, 2, 1, 3}.
//
// The index array can be applied to all bounding box types to adjust the
// order of their corresponding underlying elements.
//
// <Codegen usage>:
// Indicates how to parse the bounding box values.
index:[uint];
// <Codegen usage>:
// Indicates how to parse the bounding box values.
type:BoundingBoxType;
// <Codegen usage>:
// Indicates how to convert the bounding box back to the original image in
// pixels.
coordinate_type:CoordinateType;
}
union ContentProperties {
FeatureProperties,
ImageProperties,
BoundingBoxProperties,
}
table ValueRange {
min:int;
max:int;
}
table Content {
// The properties that the content may have, indicating the type of the
// Content.
//
// <Codegen usage>:
// Indicates how to process the tensor.
content_properties:ContentProperties;
// The range of dimensions that the content corresponds to. A NULL
// "range" indicates that the content uses up all dimensions,
// except the batch axis if applied.
//
// Here are all the possible situations of how a tensor is composed.
// Case 1: The tensor is a single object, such as an image.
// For example, the input of an image classifier
// (https://www.tensorflow.org/lite/models/image_classification/overview),
// a tensor of shape [1, 224, 224, 3]. Dimensions 1 to 3 correspond to the
// image. Since dimension 0 is a batch axis, which can be ignored,
// "range" can be left as NULL.
//
// Case 2: The tensor contains multiple instances of the same object.
// For example, the output tensor of detected bounding boxes of an object
// detection model
// (https://www.tensorflow.org/lite/models/object_detection/overview).
// The tensor shape is [1, 10, 4]. Here is the what the three dimensions
// represent for:
// dimension 0: the batch axis.
// dimension 1: the 10 objects detected with the highest confidence.
// dimension 2: the bounding boxes of the 10 detected objects.
// The tensor is essentially 10 bounding boxes. In this case,
// "range" should be {min=2; max=2;}.
// Another example is the pose estimation model
// (https://www.tensorflow.org/lite/models/pose_estimation/overview).
// The output tensor of heatmaps is in the shape of [1, 9, 9, 17].
// Here is the what the four dimensions represent for:
// dimension 0: the batch axis.
// dimension 1/2: the heatmap image.
// dimension 3: 17 body parts of a person.
// Even though the last axis is body part, the real content of this tensor is
// the heatmap. "range" should be [min=1; max=2].
//
// Case 3: The tensor contains multiple different objects. (Not supported by
// Content at this point).
// Sometimes a tensor may contain multiple different objects, thus different
// contents. It is very common for regression models. For example, a model
// to predict the fuel efficiency
// (https://www.tensorflow.org/tutorials/keras/regression).
// The input tensor has shape [1, 9], consisting of 9 features, such as
// "Cylinders", "Displacement", "Weight", etc. In this case, dimension 1
// contains 9 different contents. However, since these sub-dimension objects
// barely need to be specifically processed, their contents are not recorded
// in the metadata. Through, the name of each dimension can be set through
// TensorMetadata.dimension_names.
//
// Note that if it is not case 3, a tensor can only have one content type.
//
// <Codegen usage>:
// Case 1: return a processed single object of certain content type.
// Case 2: return a list of processed objects of certain content type. The
// generated model interface have API to random access those objects from
// the output.
range:ValueRange;
}
// Parameters that are used when normalizing the tensor.
table NormalizationOptions{
// mean and std are normalization parameters. Tensor values are normalized
// on a per-channel basis, by the formula
// (x - mean) / std.
// If there is only one value in mean or std, we'll propogate the value to
// all channels.
//
// Quantized models share the same normalization parameters as their
// corresponding float models. For example, an image input tensor may have
// the normalization parameter of
// mean = 127.5f and std = 127.5f.
// The image value will be normalized from [0, 255] to [-1, 1].
// Then, for quantized models, the image data should be further quantized
// according to the quantization parameters. In the case of uint8, the image
// data will be scaled back to [0, 255], while for int8, the image data will
// be scaled to [-128, 127].
//
// Both the normalization parameters and quantization parameters can be
// retrieved through the metadata extractor library.
// TODO(b/156644598): add link for the metadata extractor library.
// Per-channel mean of the possible values used in normalization.
//
// <Codegen usage>:
// Apply normalization to input tensors accordingly.
mean:[float];
// Per-channel standard dev. of the possible values used in normalization.
//
// <Codegen usage>:
// Apply normalization to input tensors accordingly.
std:[float];
}
// The different possible score transforms to apply to uncalibrated scores
// before applying score calibration.
enum ScoreTransformationType : byte {
// Identity function: g(x) = x.
IDENTITY = 0,
// Log function: g(x) = log(x).
LOG = 1,
// Inverse logistic function: g(x) = log(x) - log(1-x).
INVERSE_LOGISTIC = 2,
}
// Options to perform score calibration on an output tensor through sigmoid
// functions. One of the main purposes of score calibration is to make scores
// across classes comparable, so that a common threshold can be used for all
// output classes. This is meant for models producing class predictions as
// output, e.g. image classification or detection models.
//
// For each index in the output tensor, this applies:
// * `f(x) = scale / (1 + e^-(slope*g(x)+offset))` if `x > min_score` or if no
// `min_score` has been specified,
// * `f(x) = default_score` otherwise or if no scale, slope and offset have been
// specified.
// Where:
// * scale, slope, offset and (optional) min_score are index-specific parameters
// * g(x) is an index-independent transform among those defined in
// ScoreTransformationType
// * default_score is an index-independent parameter.
// An AssociatedFile with type TANSOR_AXIS_SCORE_CALIBRATION specifying the
// index-specific parameters must be associated with the corresponding
// TensorMetadata for score calibration be applied.
table ScoreCalibrationOptions {
// The function to use for transforming the uncalibrated score before
// applying score calibration.
score_transformation:ScoreTransformationType;
// The default calibrated score to apply if the uncalibrated score is
// below min_score or if no parameters were specified for a given index.
default_score:float;
}
// Performs thresholding on output tensor values, in order to filter out
// low-confidence results.
table ScoreThresholdingOptions {
// The recommended global threshold below which results are considered
// low-confidence and should be filtered out.
global_score_threshold:float;
}
// Options that are used when processing the tensor.
union ProcessUnitOptions {
NormalizationOptions,
ScoreCalibrationOptions,
ScoreThresholdingOptions,
}
// A process unit that is used to process the tensor out-of-graph.
table ProcessUnit {
options:ProcessUnitOptions;
}
// Statistics to describe a tensor.
table Stats {
// Max and min are not currently used in tflite.support codegen. They mainly
// serve as references for users to better understand the model. They can also
// be used to validate model pre/post processing results.
// If there is only one value in max or min, we'll propogate the value to
// all channels.
// Per-channel maximum value of the tensor.
max:[float];
// Per-channel minimum value of the tensor.
min:[float];
}
// Detailed information of an input or output tensor.
table TensorMetadata {
// Name of the tensor.
//
// <Codegen usage>:
// The name of this tensor in the generated model interface.
name:string;
// A description of the tensor.
description:string;
// A list of names of the dimensions in this tensor. The length of
// dimension_names need to match the number of dimensions in this tensor.
//
// <Codegen usage>:
// The name of each dimension in the generated model interface. See "Case 2"
// in the comments of Content.range.
dimension_names:[string];
// The content that represents this tensor.
//
// <Codegen usage>:
// Determines how to process this tensor. See each item in ContentProperties
// for the default process units that will be applied to the tensor.
content:Content;
// The process units that are used to process the tensor out-of-graph.
//
// <Codegen usage>:
// Contains the parameters of the default processing pipeline for each content
// type, such as the normalization parameters in all content types. See the
// items under ContentProperties for the details of the default processing
// pipeline.
process_units:[ProcessUnit];
// The statistics of the tensor values.
stats:Stats;
// A list of associated files of this tensor.
//
// <Codegen usage>:
// Contains processing parameters of this tensor, such as normalization.
associated_files:[AssociatedFile];
}
table SubGraphMetadata {
// Name of the subgraph.
//
// Note that, since TFLite only support one subgraph at this moment, the
// Codegen tool will use the name in ModelMetadata in the generated model
// interface.
name:string;
// A description explains details about what the subgraph does.
description:string;
// Metadata of all input tensors used in this subgraph. It matches extactly
// the input tensors specified by `SubGraph.inputs` in the TFLite
// schema.fbs file[2]. The number of `TensorMetadata` in the array should
// equal to the number of indices in `SubGraph.inputs`.
//
// [2]: tensorflow/lite/schema/schema.fbs
// <Codegen usage>:
// Determines how to process the inputs.
input_tensor_metadata:[TensorMetadata];
// Metadata of all output tensors used in this subgraph. It matches extactly
// the output tensors specified by `SubGraph.outputs` in the TFLite
// schema.fbs file[2]. The number of `TensorMetadata` in the array should
// equal to the number of indices in `SubGraph.outputs`.
//
// <Codegen usage>:
// Determines how to process the outputs.
output_tensor_metadata:[TensorMetadata];
// A list of associated files of this subgraph.
associated_files:[AssociatedFile];
}
table ModelMetadata {
// Name of the model.
//
// <Codegen usage>:
// The name of the model in the generated model interface.
name:string;
// Model description in schema.
description:string;
// Version of the model that specified by model creators.
version:string;
// Noted that, the minimum required TFLite runtime version that the model is
// compatible with, has already been added as a metadata entry in tflite
// schema. We'll decide later if we want to move it here, and keep it with
// other metadata entries.
// Metadata of all the subgraphs of the model. The 0th is assumed to be the
// main subgraph.
//
// <Codegen usage>:
// Determines how to process the inputs and outputs.
subgraph_metadata:[SubGraphMetadata];
// The person who creates this model.
author:string;
// Licenses that may apply to this model.
license:string;
// A list of associated files of this model.
associated_files:[AssociatedFile];
// The minimum metadata parser version that can fully understand the fields in
// the metadata flatbuffer. The version is effectively the largest version
// number among the versions of all the fields populated and the smallest
// compatible version indicated by the file identifier.
//
// This field is automaticaly populated by the MetadataPopulator when
// the metadata is populated into a TFLite model.
min_parser_version:string;
}
// LINT.ThenChange(//tensorflow/lite/experimental/\
// support/metadata/cc/metadata_version.cc)
root_type ModelMetadata;

View File

@ -1,484 +0,0 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.lite.experimental.support.metadata.metadata."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import six
from flatbuffers.python import flatbuffers
from tensorflow.lite.experimental.support.metadata import metadata as _metadata
from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb
from tensorflow.python.framework import test_util
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
class MetadataTest(test_util.TensorFlowTestCase):
def setUp(self):
super(MetadataTest, self).setUp()
self._invalid_model_buf = None
self._invalid_file = "not_existed_file"
self._empty_model_buf = self._create_empty_model_buf()
self._empty_model_file = self.create_tempfile().full_path
with open(self._empty_model_file, "wb") as f:
f.write(self._empty_model_buf)
self._model_file = self._create_model_file_with_metadata_and_buf_fields()
self._metadata_file = self._create_metadata_file()
self._metadata_file_with_version = self._create_metadata_file_with_version(
self._metadata_file, "1.0.0")
self._file1 = self.create_tempfile("file1").full_path
self._file2 = self.create_tempfile("file2").full_path
self._file3 = self.create_tempfile("file3").full_path
def _create_empty_model_buf(self):
model = _schema_fb.ModelT()
model_builder = flatbuffers.Builder(0)
model_builder.Finish(
model.Pack(model_builder),
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
return model_builder.Output()
def _create_model_file_with_metadata_and_buf_fields(self):
metadata_field = _schema_fb.MetadataT()
metadata_field.name = "meta"
buffer_field = _schema_fb.BufferT()
model = _schema_fb.ModelT()
model.metadata = [metadata_field, metadata_field]
model.buffers = [buffer_field, buffer_field, buffer_field]
model_builder = flatbuffers.Builder(0)
model_builder.Finish(
model.Pack(model_builder),
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
mnodel_file = self.create_tempfile().full_path
with open(mnodel_file, "wb") as f:
f.write(model_builder.Output())
return mnodel_file
def _create_metadata_file(self):
associated_file1 = _metadata_fb.AssociatedFileT()
associated_file1.name = b"file1"
associated_file2 = _metadata_fb.AssociatedFileT()
associated_file2.name = b"file2"
self.expected_recorded_files = [
six.ensure_str(associated_file1.name),
six.ensure_str(associated_file2.name)
]
output_meta = _metadata_fb.TensorMetadataT()
output_meta.associatedFiles = [associated_file2]
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.outputTensorMetadata = [output_meta]
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "Mobilenet_quantized"
model_meta.associatedFiles = [associated_file1]
model_meta.subgraphMetadata = [subgraph]
b = flatbuffers.Builder(0)
b.Finish(
model_meta.Pack(b),
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_file = self.create_tempfile().full_path
with open(metadata_file, "wb") as f:
f.write(b.Output())
return metadata_file
def _create_model_buffer_with_wrong_identifier(self):
wrong_identifier = b"widn"
model = _schema_fb.ModelT()
model_builder = flatbuffers.Builder(0)
model_builder.Finish(model.Pack(model_builder), wrong_identifier)
return model_builder.Output()
def _create_metadata_buffer_with_wrong_identifier(self):
# Creates a metadata with wrong identifier
wrong_identifier = b"widn"
metadata = _metadata_fb.ModelMetadataT()
metadata_builder = flatbuffers.Builder(0)
metadata_builder.Finish(metadata.Pack(metadata_builder), wrong_identifier)
return metadata_builder.Output()
def _populate_metadata_with_identifier(self, model_buf, metadata_buf,
identifier):
# For testing purposes only. MetadataPopulator cannot populate metadata with
# wrong identifiers.
model = _schema_fb.ModelT.InitFromObj(
_schema_fb.Model.GetRootAsModel(model_buf, 0))
buffer_field = _schema_fb.BufferT()
buffer_field.data = metadata_buf
model.buffers = [buffer_field]
# Creates a new metadata field.
metadata_field = _schema_fb.MetadataT()
metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME
metadata_field.buffer = len(model.buffers) - 1
model.metadata = [metadata_field]
b = flatbuffers.Builder(0)
b.Finish(model.Pack(b), identifier)
return b.Output()
def _create_metadata_file_with_version(self, metadata_file, min_version):
# Creates a new metadata file with the specified min_version for testing
# purposes.
with open(metadata_file, "rb") as f:
metadata_buf = bytearray(f.read())
metadata = _metadata_fb.ModelMetadataT.InitFromObj(
_metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
metadata.minParserVersion = min_version
b = flatbuffers.Builder(0)
b.Finish(
metadata.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_file_with_version = self.create_tempfile().full_path
with open(metadata_file_with_version, "wb") as f:
f.write(b.Output())
return metadata_file_with_version
class MetadataPopulatorTest(MetadataTest):
def testToValidModelFile(self):
populator = _metadata.MetadataPopulator.with_model_file(
self._empty_model_file)
self.assertIsInstance(populator, _metadata.MetadataPopulator)
def testToInvalidModelFile(self):
with self.assertRaises(IOError) as error:
_metadata.MetadataPopulator.with_model_file(self._invalid_file)
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
str(error.exception))
def testToValidModelBuffer(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
self.assertIsInstance(populator, _metadata.MetadataPopulator)
def testToInvalidModelBuffer(self):
with self.assertRaises(ValueError) as error:
_metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf)
self.assertEqual("model_buf cannot be empty.", str(error.exception))
def testToModelBufferWithWrongIdentifier(self):
model_buf = self._create_model_buffer_with_wrong_identifier()
with self.assertRaises(ValueError) as error:
_metadata.MetadataPopulator.with_model_buffer(model_buf)
self.assertEqual(
"The model provided does not have the expected identifier, and "
"may not be a valid TFLite model.", str(error.exception))
def testSinglePopulateAssociatedFile(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
populator.load_associated_files([self._file1])
populator.populate()
packed_files = populator.get_packed_associated_file_list()
expected_packed_files = [os.path.basename(self._file1)]
self.assertEqual(set(packed_files), set(expected_packed_files))
def testRepeatedPopulateAssociatedFile(self):
populator = _metadata.MetadataPopulator.with_model_file(
self._empty_model_file)
populator.load_associated_files([self._file1, self._file2])
# Loads file2 multiple times.
populator.load_associated_files([self._file2])
populator.populate()
packed_files = populator.get_packed_associated_file_list()
expected_packed_files = [
os.path.basename(self._file1),
os.path.basename(self._file2)
]
self.assertEqual(len(packed_files), 2)
self.assertEqual(set(packed_files), set(expected_packed_files))
# Check if the model buffer read from file is the same as that read from
# get_model_buffer().
with open(self._empty_model_file, "rb") as f:
model_buf_from_file = f.read()
model_buf_from_getter = populator.get_model_buffer()
self.assertEqual(model_buf_from_file, model_buf_from_getter)
def testPopulateInvalidAssociatedFile(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
with self.assertRaises(IOError) as error:
populator.load_associated_files([self._invalid_file])
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
str(error.exception))
def testPopulatePackedAssociatedFile(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
populator.load_associated_files([self._file1])
populator.populate()
with self.assertRaises(ValueError) as error:
populator.load_associated_files([self._file1])
populator.populate()
self.assertEqual(
"File, '{0}', has already been packed.".format(
os.path.basename(self._file1)), str(error.exception))
def testGetPackedAssociatedFileList(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
packed_files = populator.get_packed_associated_file_list()
self.assertEqual(packed_files, [])
def testPopulateMetadataFileToEmptyModelFile(self):
populator = _metadata.MetadataPopulator.with_model_file(
self._empty_model_file)
populator.load_metadata_file(self._metadata_file)
populator.load_associated_files([self._file1, self._file2])
populator.populate()
with open(self._empty_model_file, "rb") as f:
model_buf_from_file = f.read()
model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
metadata_field = model.Metadata(0)
self.assertEqual(
six.ensure_str(metadata_field.Name()),
six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
buffer_index = metadata_field.Buffer()
buffer_data = model.Buffers(buffer_index)
metadata_buf_np = buffer_data.DataAsNumpy()
metadata_buf = metadata_buf_np.tobytes()
with open(self._metadata_file_with_version, "rb") as f:
expected_metadata_buf = bytearray(f.read())
self.assertEqual(metadata_buf, expected_metadata_buf)
recorded_files = populator.get_recorded_associated_file_list()
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
# Up to now, we've proved the correctness of the model buffer that read from
# file. Then we'll test if get_model_buffer() gives the same model buffer.
model_buf_from_getter = populator.get_model_buffer()
self.assertEqual(model_buf_from_file, model_buf_from_getter)
def testPopulateMetadataFileWithoutAssociatedFiles(self):
populator = _metadata.MetadataPopulator.with_model_file(
self._empty_model_file)
populator.load_metadata_file(self._metadata_file)
populator.load_associated_files([self._file1])
# Suppose to populate self._file2, because it is recorded in the metadta.
with self.assertRaises(ValueError) as error:
populator.populate()
self.assertEqual(("File, '{0}', is recorded in the metadata, but has "
"not been loaded into the populator.").format(
os.path.basename(self._file2)), str(error.exception))
def testPopulateMetadataBufferWithWrongIdentifier(self):
metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
with self.assertRaises(ValueError) as error:
populator.load_metadata_buffer(metadata_buf)
self.assertEqual(
"The metadata buffer does not have the expected identifier, and may not"
" be a valid TFLite Metadata.", str(error.exception))
def _assert_golden_metadata(self, model_file):
with open(model_file, "rb") as f:
model_buf_from_file = f.read()
model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
# There are two elements in model.Metadata array before the population.
# Metadata should be packed to the third element in the array.
metadata_field = model.Metadata(2)
self.assertEqual(
six.ensure_str(metadata_field.Name()),
six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
buffer_index = metadata_field.Buffer()
buffer_data = model.Buffers(buffer_index)
metadata_buf_np = buffer_data.DataAsNumpy()
metadata_buf = metadata_buf_np.tobytes()
with open(self._metadata_file_with_version, "rb") as f:
expected_metadata_buf = bytearray(f.read())
self.assertEqual(metadata_buf, expected_metadata_buf)
def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self):
# First, creates a dummy metadata. Populates it and the associated files
# into the model.
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "Mobilenet_quantized"
b = flatbuffers.Builder(0)
b.Finish(
model_meta.Pack(b),
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()
populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file)
populator1.load_metadata_buffer(metadata_buf)
populator1.load_associated_files([self._file1, self._file2])
populator1.populate()
# Then, populates the metadata again.
populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file)
populator2.load_metadata_file(self._metadata_file)
populator2.populate()
# Tests if the metadata is populated correctly.
self._assert_golden_metadata(self._model_file)
def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self):
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
populator.load_metadata_file(self._metadata_file)
populator.load_associated_files([self._file1, self._file2])
populator.populate()
# Tests if the metadata is populated correctly.
self._assert_golden_metadata(self._model_file)
recorded_files = populator.get_recorded_associated_file_list()
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
# Up to now, we've proved the correctness of the model buffer that read from
# file. Then we'll test if get_model_buffer() gives the same model buffer.
with open(self._model_file, "rb") as f:
model_buf_from_file = f.read()
model_buf_from_getter = populator.get_model_buffer()
self.assertEqual(model_buf_from_file, model_buf_from_getter)
def testPopulateInvalidMetadataFile(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
with self.assertRaises(IOError) as error:
populator.load_metadata_file(self._invalid_file)
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
str(error.exception))
def testPopulateInvalidMetadataBuffer(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
with self.assertRaises(ValueError) as error:
populator.load_metadata_buffer([])
self.assertEqual("The metadata to be populated is empty.",
str(error.exception))
def testGetModelBufferBeforePopulatingData(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
model_buf = populator.get_model_buffer()
expected_model_buf = self._empty_model_buf
self.assertEqual(model_buf, expected_model_buf)
class MetadataDisplayerTest(MetadataTest):
def setUp(self):
super(MetadataDisplayerTest, self).setUp()
self._model_file = self._create_model_with_metadata_and_associated_files()
def _create_model_with_metadata_and_associated_files(self):
model_buf = self._create_empty_model_buf()
model_file = self.create_tempfile().full_path
with open(model_file, "wb") as f:
f.write(model_buf)
populator = _metadata.MetadataPopulator.with_model_file(model_file)
populator.load_metadata_file(self._metadata_file)
populator.load_associated_files([self._file1, self._file2])
populator.populate()
return model_file
def test_load_model_buffer_metadataBufferWithWrongIdentifier_throwsException(
self):
model_buf = self._create_model_buffer_with_wrong_identifier()
metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
model_buf = self._populate_metadata_with_identifier(
model_buf, metadata_buf,
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
with self.assertRaises(ValueError) as error:
_metadata.MetadataDisplayer.with_model_buffer(model_buf)
self.assertEqual(
"The metadata buffer does not have the expected identifier, and may not"
" be a valid TFLite Metadata.", str(error.exception))
def test_load_model_buffer_modelBufferWithWrongIdentifier_throwsException(
self):
model_buf = self._create_model_buffer_with_wrong_identifier()
metadata_file = self._create_metadata_file()
wrong_identifier = b"widn"
with open(metadata_file, "rb") as f:
metadata_buf = bytearray(f.read())
model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf,
wrong_identifier)
with self.assertRaises(ValueError) as error:
_metadata.MetadataDisplayer.with_model_buffer(model_buf)
self.assertEqual(
"The model provided does not have the expected identifier, and "
"may not be a valid TFLite model.", str(error.exception))
def test_load_model_file_invalidModelFile_throwsException(self):
with self.assertRaises(IOError) as error:
_metadata.MetadataDisplayer.with_model_file(self._invalid_file)
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
str(error.exception))
def test_load_model_file_modelWithoutMetadata_throwsException(self):
with self.assertRaises(ValueError) as error:
_metadata.MetadataDisplayer.with_model_file(self._empty_model_file)
self.assertEqual("The model does not have metadata.", str(error.exception))
def test_load_model_file_modelWithMetadata(self):
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
self.assertIsInstance(displayer, _metadata.MetadataDisplayer)
def test_load_model_buffer_modelWithOutMetadata_throwsException(self):
with self.assertRaises(ValueError) as error:
_metadata.MetadataDisplayer.with_model_buffer(
self._create_empty_model_buf())
self.assertEqual("The model does not have metadata.", str(error.exception))
def test_load_model_buffer_modelWithMetadata(self):
displayer = _metadata.MetadataDisplayer.with_model_buffer(
open(self._model_file, "rb").read())
self.assertIsInstance(displayer, _metadata.MetadataDisplayer)
def test_get_metadata_json_modelWithMetadata(self):
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
actual = displayer.get_metadata_json()
# Verifies the generated json file.
golden_json_file_path = resource_loader.get_path_to_datafile(
"testdata/golden_json.json")
with open(golden_json_file_path, "r") as f:
expected = f.read()
self.assertEqual(actual, expected)
def test_get_packed_associated_file_list_modelWithMetadata(self):
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
packed_files = displayer.get_packed_associated_file_list()
expected_packed_files = [
os.path.basename(self._file1),
os.path.basename(self._file2)
]
self.assertEqual(len(packed_files), 2)
self.assertEqual(set(packed_files), set(expected_packed_files))
if __name__ == "__main__":
test.main()

View File

@ -1,22 +0,0 @@
{
"name": "Mobilenet_quantized",
"subgraph_metadata": [
{
"output_tensor_metadata": [
{
"associated_files": [
{
"name": "file2"
}
]
}
]
}
],
"associated_files": [
{
"name": "file1"
}
],
"min_parser_version": "1.0.0"
}