Opensource TFLite Support codegen.
PiperOrigin-RevId: 302011153 Change-Id: Idb2f649dc48fdc449fac2d6e9009719d29afb2ad
This commit is contained in:
parent
abf182a882
commit
b57d910db5
87
tensorflow/lite/experimental/support/codegen/BUILD
Normal file
87
tensorflow/lite/experimental/support/codegen/BUILD
Normal file
@ -0,0 +1,87 @@
|
||||
# The tools for generating wrapper classes for a TFLite model with metadata.
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "utils",
|
||||
srcs = [
|
||||
"utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"utils.h",
|
||||
],
|
||||
deps = [
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "code_generator",
|
||||
srcs = [
|
||||
"code_generator.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"code_generator.h",
|
||||
],
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "metadata_helper",
|
||||
srcs = [
|
||||
"metadata_helper.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"metadata_helper.h",
|
||||
],
|
||||
deps = [
|
||||
":utils",
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "android_java_generator",
|
||||
srcs = [
|
||||
"android_java_generator.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"android_java_generator.h",
|
||||
],
|
||||
deps = [
|
||||
":code_generator",
|
||||
":metadata_helper",
|
||||
":utils",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "code_generator_test",
|
||||
size = "small",
|
||||
srcs = ["code_generator_test.cc"],
|
||||
data = ["//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs"],
|
||||
deps = [
|
||||
":code_generator",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "utils_test",
|
||||
srcs = ["utils_test.cc"],
|
||||
deps = [
|
||||
":utils",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
13
tensorflow/lite/experimental/support/codegen/README.md
Normal file
13
tensorflow/lite/experimental/support/codegen/README.md
Normal file
@ -0,0 +1,13 @@
|
||||
# TensorFlow Lite Android Wrapper Code Generator
|
||||
|
||||
For TensorFlow Lite model enhanced with [metadata](https://www.tensorflow.org/lite/convert/metadata.md),
|
||||
developers can use the TensorFlow Lite Android wrapper code generator to create
|
||||
platform specific wrapper code. The wrapper code removes the need to interact
|
||||
directly with `ByteBuffer`. Instead, developers can interact with the TensorFlow
|
||||
Lite model with typed objects such as `Bitmap` and `Rect`.
|
||||
|
||||
The usefulness of the code generator depend on the completeness of the
|
||||
TensorFlow Lite model's metadata entry. Refer to the `<Codegen usage>` section
|
||||
under relevant fields in
|
||||
[metadata_schema.fbs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs),
|
||||
to see how the codegen tool parses each field.
|
@ -0,0 +1,978 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h"
|
||||
|
||||
#include <ctype.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
|
||||
#include "tensorflow/lite/experimental/support/codegen/metadata_helper.h"
|
||||
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
namespace codegen {
|
||||
|
||||
namespace {
|
||||
|
||||
using details_android_java::ModelInfo;
|
||||
using details_android_java::TensorInfo;
|
||||
|
||||
// Helper class to organize the C++ code block as a generated code block.
|
||||
// Using ctor and dtor to simulate an enter/exit schema like `with` in Python.
|
||||
class AsBlock {
|
||||
public:
|
||||
AsBlock(CodeWriter* code_writer, const std::string& before,
|
||||
bool trailing_blank_line = false)
|
||||
: code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) {
|
||||
code_writer_->AppendNoNewLine(before);
|
||||
code_writer_->Append(" {");
|
||||
code_writer_->Indent();
|
||||
}
|
||||
~AsBlock() {
|
||||
code_writer_->Outdent();
|
||||
code_writer_->Append("}");
|
||||
if (trailing_blank_line_) {
|
||||
code_writer_->NewLine();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
CodeWriter* code_writer_;
|
||||
bool trailing_blank_line_;
|
||||
};
|
||||
|
||||
// Declare the functions first, so that the functions can follow a logical
|
||||
// order.
|
||||
bool GenerateWrapperClass(CodeWriter*, const ModelInfo&, ErrorReporter*);
|
||||
bool GenerateWrapperImports(CodeWriter*, const ModelInfo&, ErrorReporter*);
|
||||
bool GenerateWrapperInputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
|
||||
bool GenerateWrapperOutputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
|
||||
bool GenerateWrapperMetadata(CodeWriter*, const ModelInfo&, ErrorReporter*);
|
||||
bool GenerateWrapperAPI(CodeWriter*, const ModelInfo&, ErrorReporter*);
|
||||
|
||||
std::string GetModelVersionedName(const ModelMetadata* metadata) {
|
||||
std::string model_name = "MyModel";
|
||||
if (metadata->name() != nullptr && !(metadata->name()->str().empty())) {
|
||||
model_name = metadata->name()->str();
|
||||
}
|
||||
std::string model_version = "unknown";
|
||||
if (metadata->version() != nullptr && !(metadata->version()->str().empty())) {
|
||||
model_version = metadata->version()->str();
|
||||
}
|
||||
return model_name + " (Version: " + model_version + ")";
|
||||
}
|
||||
|
||||
TensorInfo CreateTensorInfo(const TensorMetadata* metadata,
|
||||
const std::string& name, bool is_input, int index,
|
||||
ErrorReporter* err) {
|
||||
TensorInfo tensor_info;
|
||||
std::string tensor_identifier = is_input ? "input" : "output";
|
||||
tensor_identifier += " " + std::to_string(index);
|
||||
tensor_info.associated_axis_label_index = FindAssociatedFile(
|
||||
metadata, AssociatedFileType_TENSOR_AXIS_LABELS, tensor_identifier, err);
|
||||
tensor_info.associated_value_label_index = FindAssociatedFile(
|
||||
metadata, AssociatedFileType_TENSOR_VALUE_LABELS, tensor_identifier, err);
|
||||
if (is_input && (tensor_info.associated_axis_label_index >= 0 ||
|
||||
tensor_info.associated_value_label_index >= 0)) {
|
||||
err->Warning(
|
||||
"Found label file on input tensor (%s). Label file for input "
|
||||
"tensor is not supported yet. The "
|
||||
"file will be ignored.",
|
||||
tensor_identifier.c_str());
|
||||
}
|
||||
if (tensor_info.associated_axis_label_index >= 0 &&
|
||||
tensor_info.associated_value_label_index >= 0) {
|
||||
err->Warning(
|
||||
"Found both axis label file and value label file for tensor (%s), "
|
||||
"which is not supported. Only the axis label file will be used.",
|
||||
tensor_identifier.c_str());
|
||||
}
|
||||
tensor_info.is_input = is_input;
|
||||
tensor_info.name = SnakeCaseToCamelCase(name);
|
||||
tensor_info.upper_camel_name = tensor_info.name;
|
||||
tensor_info.upper_camel_name[0] = toupper(tensor_info.upper_camel_name[0]);
|
||||
tensor_info.normalization_unit =
|
||||
FindNormalizationUnit(metadata, tensor_identifier, err);
|
||||
if (metadata->content()->content_properties_type() ==
|
||||
ContentProperties_ImageProperties) {
|
||||
if (metadata->content()
|
||||
->content_properties_as_ImageProperties()
|
||||
->color_space() == ColorSpaceType_RGB) {
|
||||
tensor_info.content_type = "image";
|
||||
tensor_info.wrapper_type = "TensorImage";
|
||||
tensor_info.processor_type = "ImageProcessor";
|
||||
return tensor_info;
|
||||
} else {
|
||||
err->Warning(
|
||||
"Found Non-RGB image on tensor (%s). Codegen currently does not "
|
||||
"support it, and regard it as a plain numeric tensor.",
|
||||
tensor_identifier.c_str());
|
||||
}
|
||||
}
|
||||
tensor_info.content_type = "tensor";
|
||||
tensor_info.wrapper_type = "TensorBuffer";
|
||||
tensor_info.processor_type = "TensorProcessor";
|
||||
return tensor_info;
|
||||
}
|
||||
|
||||
ModelInfo CreateModelInfo(const ModelMetadata* metadata,
|
||||
const std::string& package_name,
|
||||
const std::string& model_class_name,
|
||||
const std::string& model_asset_path,
|
||||
ErrorReporter* err) {
|
||||
ModelInfo model_info;
|
||||
if (!CodeGenerator::VerifyMetadata(metadata, err)) {
|
||||
// TODO(b/150116380): Create dummy model info.
|
||||
err->Error("Validating metadata failed.");
|
||||
return model_info;
|
||||
}
|
||||
model_info.package_name = package_name;
|
||||
model_info.model_class_name = model_class_name;
|
||||
model_info.model_asset_path = model_asset_path;
|
||||
model_info.model_versioned_name = GetModelVersionedName(metadata);
|
||||
const auto* graph = metadata->subgraph_metadata()->Get(0);
|
||||
auto names = CodeGenerator::NameInputsAndOutputs(
|
||||
graph->input_tensor_metadata(), graph->output_tensor_metadata());
|
||||
std::vector<std::string> input_tensor_names = std::move(names.first);
|
||||
std::vector<std::string> output_tensor_names = std::move(names.second);
|
||||
for (int i = 0; i < graph->input_tensor_metadata()->size(); i++) {
|
||||
model_info.inputs.push_back(
|
||||
CreateTensorInfo(graph->input_tensor_metadata()->Get(i),
|
||||
input_tensor_names[i], true, i, err));
|
||||
}
|
||||
for (int i = 0; i < graph->output_tensor_metadata()->size(); i++) {
|
||||
model_info.outputs.push_back(
|
||||
CreateTensorInfo(graph->output_tensor_metadata()->Get(i),
|
||||
output_tensor_names[i], false, i, err));
|
||||
}
|
||||
return model_info;
|
||||
}
|
||||
|
||||
void SetCodeWriterWithTensorInfo(CodeWriter* code_writer,
|
||||
const TensorInfo& tensor_info) {
|
||||
code_writer->SetTokenValue("NAME", tensor_info.name);
|
||||
code_writer->SetTokenValue("NAME_U", tensor_info.upper_camel_name);
|
||||
code_writer->SetTokenValue("CONTENT_TYPE", tensor_info.content_type);
|
||||
code_writer->SetTokenValue("WRAPPER_TYPE", tensor_info.wrapper_type);
|
||||
std::string wrapper_name = tensor_info.wrapper_type;
|
||||
wrapper_name[0] = tolower(wrapper_name[0]);
|
||||
code_writer->SetTokenValue("WRAPPER_NAME", wrapper_name);
|
||||
code_writer->SetTokenValue("PROCESSOR_TYPE", tensor_info.processor_type);
|
||||
code_writer->SetTokenValue("NORMALIZATION_UNIT",
|
||||
std::to_string(tensor_info.normalization_unit));
|
||||
code_writer->SetTokenValue(
|
||||
"ASSOCIATED_AXIS_LABEL_INDEX",
|
||||
std::to_string(tensor_info.associated_axis_label_index));
|
||||
code_writer->SetTokenValue(
|
||||
"ASSOCIATED_VALUE_LABEL_INDEX",
|
||||
std::to_string(tensor_info.associated_value_label_index));
|
||||
}
|
||||
|
||||
void SetCodeWriterWithModelInfo(CodeWriter* code_writer,
|
||||
const ModelInfo& model_info) {
|
||||
code_writer->SetTokenValue("PACKAGE", model_info.package_name);
|
||||
code_writer->SetTokenValue("MODEL_PATH", model_info.model_asset_path);
|
||||
code_writer->SetTokenValue("MODEL_CLASS_NAME", model_info.model_class_name);
|
||||
}
|
||||
|
||||
constexpr char JAVA_DEFAULT_PACKAGE[] = "default";
|
||||
|
||||
std::string ConvertPackageToPath(const std::string& package) {
|
||||
if (package == JAVA_DEFAULT_PACKAGE) {
|
||||
return "";
|
||||
}
|
||||
std::string path = package;
|
||||
std::replace(path.begin(), path.end(), '.', '/');
|
||||
return path;
|
||||
}
|
||||
|
||||
bool IsImageUsed(const ModelInfo& model) {
|
||||
for (const auto& input : model.inputs) {
|
||||
if (input.content_type == "image") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
for (const auto& output : model.outputs) {
|
||||
if (output.content_type == "image") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
|
||||
ErrorReporter* err) {
|
||||
code_writer->Append("// Generated by TFLite Support.");
|
||||
code_writer->Append("package {{PACKAGE}};");
|
||||
code_writer->NewLine();
|
||||
|
||||
if (!GenerateWrapperImports(code_writer, model, err)) {
|
||||
err->Error("Fail to generate imports for wrapper class.");
|
||||
return false;
|
||||
}
|
||||
if (!GenerateWrapperClass(code_writer, model, err)) {
|
||||
err->Error("Fail to generate wrapper class.");
|
||||
return false;
|
||||
}
|
||||
code_writer->NewLine();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
|
||||
ErrorReporter* err) {
|
||||
const std::string support_pkg = "org.tensorflow.lite.support.";
|
||||
std::vector<std::string> imports{
|
||||
"android.content.Context",
|
||||
"java.io.IOException",
|
||||
"java.nio.ByteBuffer",
|
||||
"java.nio.FloatBuffer",
|
||||
"java.util.Arrays",
|
||||
"java.util.HashMap",
|
||||
"java.util.List",
|
||||
"java.util.Map",
|
||||
"org.checkerframework.checker.nullness.qual.Nullable",
|
||||
"org.tensorflow.lite.DataType",
|
||||
"org.tensorflow.lite.Tensor.QuantizationParams",
|
||||
support_pkg + "common.FileUtil",
|
||||
support_pkg + "common.TensorProcessor",
|
||||
support_pkg + "common.ops.CastOp",
|
||||
support_pkg + "common.ops.DequantizeOp",
|
||||
support_pkg + "common.ops.NormalizeOp",
|
||||
support_pkg + "common.ops.QuantizeOp",
|
||||
support_pkg + "label.TensorLabel",
|
||||
support_pkg + "metadata.MetadataExtractor",
|
||||
support_pkg + "metadata.schema.NormalizationOptions",
|
||||
support_pkg + "model.Model",
|
||||
support_pkg + "model.Model.Device",
|
||||
support_pkg + "tensorbuffer.TensorBuffer",
|
||||
};
|
||||
if (IsImageUsed(model)) {
|
||||
for (const auto& target :
|
||||
{"image.ImageProcessor", "image.TensorImage", "image.ops.ResizeOp",
|
||||
"image.ops.ResizeOp.ResizeMethod"}) {
|
||||
imports.push_back(support_pkg + target);
|
||||
}
|
||||
imports.push_back("android.graphics.Bitmap");
|
||||
}
|
||||
|
||||
std::sort(imports.begin(), imports.end());
|
||||
for (const auto target : imports) {
|
||||
code_writer->SetTokenValue("TARGET", target);
|
||||
code_writer->Append("import {{TARGET}};");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model,
|
||||
ErrorReporter* err) {
|
||||
code_writer->SetTokenValue("MODEL_VERSIONED_NAME",
|
||||
model.model_versioned_name);
|
||||
code_writer->Append(
|
||||
R"(/** Wrapper class of model {{MODEL_VERSIONED_NAME}} */)");
|
||||
const auto code_block =
|
||||
AsBlock(code_writer, "public class {{MODEL_CLASS_NAME}}");
|
||||
code_writer->Append(R"(private final Metadata metadata;
|
||||
private final Model model;
|
||||
private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
|
||||
for (const auto& tensor : model.outputs) {
|
||||
if (tensor.associated_axis_label_index >= 0) {
|
||||
code_writer->SetTokenValue("NAME", tensor.name);
|
||||
code_writer->Append("private final List<String> {{NAME}}Labels;");
|
||||
}
|
||||
}
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(
|
||||
"@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;");
|
||||
}
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(
|
||||
"@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
if (!GenerateWrapperInputs(code_writer, model, err)) {
|
||||
err->Error("Failed to generate input classes");
|
||||
return false;
|
||||
}
|
||||
code_writer->NewLine();
|
||||
if (!GenerateWrapperOutputs(code_writer, model, err)) {
|
||||
err->Error("Failed to generate output classes");
|
||||
return false;
|
||||
}
|
||||
code_writer->NewLine();
|
||||
if (!GenerateWrapperMetadata(code_writer, model, err)) {
|
||||
err->Error("Failed to generate the metadata class");
|
||||
return false;
|
||||
}
|
||||
code_writer->NewLine();
|
||||
if (!GenerateWrapperAPI(code_writer, model, err)) {
|
||||
err->Error("Failed to generate the common APIs");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GenerateWrapperInputs(CodeWriter* code_writer, const ModelInfo& model,
|
||||
ErrorReporter* err) {
|
||||
code_writer->Append("/** Input wrapper of {@link {{MODEL_CLASS_NAME}}} */");
|
||||
auto class_block = AsBlock(code_writer, "public class Inputs");
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append("private {{WRAPPER_TYPE}} {{NAME}};");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
// Ctor
|
||||
{
|
||||
auto ctor_block = AsBlock(code_writer, "public Inputs()");
|
||||
code_writer->Append(
|
||||
"Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;");
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
if (tensor.content_type == "image") {
|
||||
code_writer->Append(
|
||||
"{{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());");
|
||||
} else {
|
||||
code_writer->Append(
|
||||
"{{NAME}} = "
|
||||
"TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
|
||||
"metadata.get{{NAME_U}}Type());");
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const auto& tensor : model.inputs) {
|
||||
code_writer->NewLine();
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
// Loaders
|
||||
if (tensor.content_type == "image") {
|
||||
{
|
||||
auto bitmap_loader_block =
|
||||
AsBlock(code_writer, "public void load{{NAME_U}}(Bitmap bitmap)");
|
||||
code_writer->Append(R"({{NAME}}.load(bitmap);
|
||||
{{NAME}} = preprocess{{NAME_U}}({{NAME}});)");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
{
|
||||
auto tensor_image_loader_block = AsBlock(
|
||||
code_writer, "public void load{{NAME_U}}(TensorImage tensorImage)");
|
||||
code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorImage);");
|
||||
}
|
||||
} else { // content_type == "FEATURE" or "UNKNOWN"
|
||||
auto tensorbuffer_loader_block = AsBlock(
|
||||
code_writer, "public void load{{NAME_U}}(TensorBuffer tensorBuffer)");
|
||||
code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorBuffer);");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
// Processor
|
||||
code_writer->Append(
|
||||
R"(private {{WRAPPER_TYPE}} preprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}}) {
|
||||
if ({{NAME}}Preprocessor == null) {
|
||||
return {{WRAPPER_NAME}};
|
||||
}
|
||||
return {{NAME}}Preprocessor.process({{WRAPPER_NAME}});
|
||||
}
|
||||
)");
|
||||
}
|
||||
{
|
||||
const auto get_buffer_block = AsBlock(code_writer, "Object[] getBuffer()");
|
||||
code_writer->AppendNoNewLine("return new Object[] {");
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->AppendNoNewLine("{{NAME}}.getBuffer(), ");
|
||||
}
|
||||
code_writer->Backspace(2);
|
||||
code_writer->Append("};");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
|
||||
ErrorReporter* err) {
|
||||
code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
|
||||
auto class_block = AsBlock(code_writer, "public class Outputs");
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
{
|
||||
const auto ctor_block = AsBlock(code_writer, "public Outputs()");
|
||||
code_writer->Append(
|
||||
"Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;");
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
if (tensor.content_type == "image") {
|
||||
code_writer->Append(
|
||||
R"({{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());
|
||||
{{NAME}}.load(TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), metadata.get{{NAME_U}}Type()));)");
|
||||
} else { // FEATURE, UNKNOWN
|
||||
code_writer->Append(
|
||||
"{{NAME}} = "
|
||||
"TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
|
||||
"metadata.get{{NAME_U}}Type());");
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->NewLine();
|
||||
if (tensor.associated_axis_label_index >= 0) {
|
||||
if (tensor.content_type == "image") {
|
||||
err->Warning(
|
||||
"Axis label for images is not supported. The labels will "
|
||||
"be ignored.");
|
||||
} else {
|
||||
code_writer->Append(R"(public Map<String, Float> get{{NAME_U}}() {
|
||||
return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getMapWithFloatValue();
|
||||
})");
|
||||
}
|
||||
} else {
|
||||
code_writer->Append(R"(public {{WRAPPER_TYPE}} get{{NAME_U}}() {
|
||||
return postprocess{{NAME_U}}({{NAME}});
|
||||
})");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
{
|
||||
auto processor_block =
|
||||
AsBlock(code_writer,
|
||||
"private {{WRAPPER_TYPE}} "
|
||||
"postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})");
|
||||
code_writer->Append(R"(if ({{NAME}}Postprocessor == null) {
|
||||
return {{WRAPPER_NAME}};
|
||||
}
|
||||
return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});)");
|
||||
}
|
||||
}
|
||||
code_writer->NewLine();
|
||||
{
|
||||
const auto get_buffer_block =
|
||||
AsBlock(code_writer, "Map<Integer, Object> getBuffer()");
|
||||
code_writer->Append("Map<Integer, Object> outputs = new HashMap<>();");
|
||||
for (int i = 0; i < model.outputs.size(); i++) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
|
||||
code_writer->SetTokenValue("ID", std::to_string(i));
|
||||
code_writer->Append("outputs.put({{ID}}, {{NAME}}.getBuffer());");
|
||||
}
|
||||
code_writer->Append("return outputs;");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model,
|
||||
ErrorReporter* err) {
|
||||
code_writer->Append(
|
||||
"/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */");
|
||||
const auto class_block = AsBlock(code_writer, "public static class Metadata");
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(R"(private final int[] {{NAME}}Shape;
|
||||
private final DataType {{NAME}}DataType;
|
||||
private final QuantizationParams {{NAME}}QuantizationParams;)");
|
||||
if (tensor.normalization_unit >= 0) {
|
||||
code_writer->Append(R"(private final float[] {{NAME}}Mean;
|
||||
private final float[] {{NAME}}Stddev;)");
|
||||
}
|
||||
}
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(R"(private final int[] {{NAME}}Shape;
|
||||
private final DataType {{NAME}}DataType;
|
||||
private final QuantizationParams {{NAME}}QuantizationParams;)");
|
||||
if (tensor.normalization_unit >= 0) {
|
||||
code_writer->Append(R"(private final float[] {{NAME}}Mean;
|
||||
private final float[] {{NAME}}Stddev;)");
|
||||
}
|
||||
if (tensor.associated_axis_label_index >= 0 ||
|
||||
tensor.associated_value_label_index >= 0) {
|
||||
code_writer->Append("private final List<String> {{NAME}}Labels;");
|
||||
}
|
||||
}
|
||||
code_writer->NewLine();
|
||||
{
|
||||
const auto ctor_block = AsBlock(
|
||||
code_writer,
|
||||
"public Metadata(ByteBuffer buffer, Model model) throws IOException");
|
||||
code_writer->Append(
|
||||
"MetadataExtractor extractor = new MetadataExtractor(buffer);");
|
||||
for (int i = 0; i < model.inputs.size(); i++) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, model.inputs[i]);
|
||||
code_writer->SetTokenValue("ID", std::to_string(i));
|
||||
code_writer->Append(
|
||||
R"({{NAME}}Shape = extractor.getInputTensorShape({{ID}});
|
||||
{{NAME}}DataType = extractor.getInputTensorType({{ID}});
|
||||
{{NAME}}QuantizationParams = extractor.getInputTensorQuantizationParams({{ID}});)");
|
||||
if (model.inputs[i].normalization_unit >= 0) {
|
||||
code_writer->Append(
|
||||
R"(NormalizationOptions {{NAME}}NormalizationOptions =
|
||||
(NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
|
||||
FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
|
||||
{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
|
||||
{{NAME}}MeanBuffer.get({{NAME}}Mean);
|
||||
FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
|
||||
{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
|
||||
{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < model.outputs.size(); i++) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
|
||||
code_writer->SetTokenValue("ID", std::to_string(i));
|
||||
code_writer->Append(
|
||||
R"({{NAME}}Shape = model.getOutputTensorShape({{ID}});
|
||||
{{NAME}}DataType = extractor.getOutputTensorType({{ID}});
|
||||
{{NAME}}QuantizationParams = extractor.getOutputTensorQuantizationParams({{ID}});)");
|
||||
if (model.outputs[i].normalization_unit >= 0) {
|
||||
code_writer->Append(
|
||||
R"(NormalizationOptions {{NAME}}NormalizationOptions =
|
||||
(NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
|
||||
FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
|
||||
{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
|
||||
{{NAME}}MeanBuffer.get({{NAME}}Mean);
|
||||
FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
|
||||
{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
|
||||
{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
|
||||
}
|
||||
if (model.outputs[i].associated_axis_label_index >= 0) {
|
||||
code_writer->Append(R"(String {{NAME}}LabelsFileName =
|
||||
extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_AXIS_LABEL_INDEX}}).name();
|
||||
{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
|
||||
} else if (model.outputs[i].associated_value_label_index >= 0) {
|
||||
code_writer->Append(R"(String {{NAME}}LabelsFileName =
|
||||
extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_VALUE_LABEL_INDEX}}).name();
|
||||
{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(R"(
|
||||
public int[] get{{NAME_U}}Shape() {
|
||||
return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
|
||||
}
|
||||
|
||||
public DataType get{{NAME_U}}Type() {
|
||||
return {{NAME}}DataType;
|
||||
}
|
||||
|
||||
public QuantizationParams get{{NAME_U}}QuantizationParams() {
|
||||
return {{NAME}}QuantizationParams;
|
||||
})");
|
||||
if (tensor.normalization_unit >= 0) {
|
||||
code_writer->Append(R"(
|
||||
public float[] get{{NAME_U}}Mean() {
|
||||
return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
|
||||
}
|
||||
|
||||
public float[] get{{NAME_U}}Stddev() {
|
||||
return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
|
||||
})");
|
||||
}
|
||||
}
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(R"(
|
||||
public int[] get{{NAME_U}}Shape() {
|
||||
return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
|
||||
}
|
||||
|
||||
public DataType get{{NAME_U}}Type() {
|
||||
return {{NAME}}DataType;
|
||||
}
|
||||
|
||||
public QuantizationParams get{{NAME_U}}QuantizationParams() {
|
||||
return {{NAME}}QuantizationParams;
|
||||
})");
|
||||
if (tensor.normalization_unit >= 0) {
|
||||
code_writer->Append(R"(
|
||||
public float[] get{{NAME_U}}Mean() {
|
||||
return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
|
||||
}
|
||||
|
||||
public float[] get{{NAME_U}}Stddev() {
|
||||
return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
|
||||
})");
|
||||
}
|
||||
if (tensor.associated_axis_label_index >= 0 ||
|
||||
tensor.associated_value_label_index >= 0) {
|
||||
code_writer->Append(R"(
|
||||
public List<String> get{{NAME_U}}Labels() {
|
||||
return {{NAME}}Labels;
|
||||
})");
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model,
|
||||
ErrorReporter* err) {
|
||||
code_writer->Append(R"(public Metadata getMetadata() {
|
||||
return metadata;
|
||||
}
|
||||
)");
|
||||
code_writer->Append(R"(/**
|
||||
* Creates interpreter and loads associated files if needed.
|
||||
*
|
||||
* @throws IOException if an I/O error occurs when loading the tflite model.
|
||||
*/
|
||||
public {{MODEL_CLASS_NAME}}(Context context) throws IOException {
|
||||
this(context, MODEL_NAME, Device.CPU, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates interpreter and loads associated files if needed, but loading another model in the same
|
||||
* input / output structure with the original one.
|
||||
*
|
||||
* @throws IOException if an I/O error occurs when loading the tflite model.
|
||||
*/
|
||||
public {{MODEL_CLASS_NAME}}(Context context, String modelPath) throws IOException {
|
||||
this(context, modelPath, Device.CPU, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates interpreter and loads associated files if needed, with device and number of threads
|
||||
* configured.
|
||||
*
|
||||
* @throws IOException if an I/O error occurs when loading the tflite model.
|
||||
*/
|
||||
public {{MODEL_CLASS_NAME}}(Context context, Device device, int numThreads) throws IOException {
|
||||
this(context, MODEL_NAME, device, numThreads);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates interpreter for a user-specified model.
|
||||
*
|
||||
* @throws IOException if an I/O error occurs when loading the tflite model.
|
||||
*/
|
||||
public {{MODEL_CLASS_NAME}}(Context context, String modelPath, Device device, int numThreads) throws IOException {
|
||||
model = new Model.Builder(context, modelPath).setDevice(device).setNumThreads(numThreads).build();
|
||||
metadata = new Metadata(model.getData(), model);)");
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(R"(
|
||||
{{PROCESSOR_TYPE}}.Builder {{NAME}}PreprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder())");
|
||||
if (tensor.content_type == "image") {
|
||||
code_writer->Append(R"( .add(new ResizeOp(
|
||||
metadata.get{{NAME_U}}Shape()[1],
|
||||
metadata.get{{NAME_U}}Shape()[2],
|
||||
ResizeMethod.NEAREST_NEIGHBOR)))");
|
||||
}
|
||||
if (tensor.normalization_unit >= 0) {
|
||||
code_writer->Append(
|
||||
R"( .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
|
||||
}
|
||||
code_writer->Append(
|
||||
R"( .add(new QuantizeOp(
|
||||
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
|
||||
metadata.get{{NAME_U}}QuantizationParams().getScale()))
|
||||
.add(new CastOp(metadata.get{{NAME_U}}Type()));
|
||||
{{NAME}}Preprocessor = {{NAME}}PreprocessorBuilder.build();)");
|
||||
}
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->AppendNoNewLine(R"(
|
||||
{{PROCESSOR_TYPE}}.Builder {{NAME}}PostprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder()
|
||||
.add(new DequantizeOp(
|
||||
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
|
||||
metadata.get{{NAME_U}}QuantizationParams().getScale())))");
|
||||
if (tensor.normalization_unit >= 0) {
|
||||
code_writer->AppendNoNewLine(R"(
|
||||
.add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
|
||||
}
|
||||
code_writer->Append(R"(;
|
||||
{{NAME}}Postprocessor = {{NAME}}PostprocessorBuilder.build();)");
|
||||
if (tensor.associated_axis_label_index >= 0) {
|
||||
code_writer->Append(R"(
|
||||
{{NAME}}Labels = metadata.get{{NAME_U}}Labels();)");
|
||||
}
|
||||
}
|
||||
code_writer->Append("}");
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(R"(
|
||||
public void reset{{NAME_U}}Preprocessor(@Nullable {{PROCESSOR_TYPE}} processor) {
|
||||
{{NAME}}Preprocessor = processor;
|
||||
})");
|
||||
}
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(R"(
|
||||
public void reset{{NAME_U}}Postprocessor(@Nullable {{PROCESSOR_TYPE}} processor) {
|
||||
{{NAME}}Postprocessor = processor;
|
||||
})");
|
||||
}
|
||||
code_writer->Append(R"(
|
||||
/** Creates inputs */
|
||||
public Inputs createInputs() {
|
||||
return new Inputs();
|
||||
}
|
||||
|
||||
/** Triggers the model. */
|
||||
public Outputs run(Inputs inputs) {
|
||||
Outputs outputs = new Outputs();
|
||||
model.run(inputs.getBuffer(), outputs.getBuffer());
|
||||
return outputs;
|
||||
}
|
||||
|
||||
/** Closes the model. */
|
||||
public void close() {
|
||||
model.close();
|
||||
})");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GenerateBuildGradleContent(CodeWriter* code_writer,
|
||||
const ModelInfo& model_info) {
|
||||
code_writer->Append(R"(buildscript {
|
||||
repositories {
|
||||
google()
|
||||
jcenter()
|
||||
}
|
||||
dependencies {
|
||||
classpath 'com.android.tools.build:gradle:3.2.1'
|
||||
}
|
||||
}
|
||||
|
||||
allprojects {
|
||||
repositories {
|
||||
google()
|
||||
jcenter()
|
||||
flatDir {
|
||||
dirs 'libs'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
apply plugin: 'com.android.library'
|
||||
|
||||
android {
|
||||
compileSdkVersion 29
|
||||
defaultConfig {
|
||||
targetSdkVersion 29
|
||||
versionCode 1
|
||||
versionName "1.0"
|
||||
}
|
||||
aaptOptions {
|
||||
noCompress "tflite"
|
||||
}
|
||||
compileOptions {
|
||||
sourceCompatibility = '1.8'
|
||||
targetCompatibility = '1.8'
|
||||
}
|
||||
lintOptions {
|
||||
abortOnError false
|
||||
}
|
||||
}
|
||||
|
||||
configurations {
|
||||
libMetadata
|
||||
}
|
||||
|
||||
dependencies {
|
||||
libMetadata 'org.tensorflow:tensorflow-lite-support:0.0.0-experimental-metadata-monolithic'
|
||||
}
|
||||
|
||||
task downloadLibs(type: Sync) {
|
||||
from configurations.libMetadata
|
||||
into "$buildDir/libs"
|
||||
rename 'tensorflow-lite-support-0.0.0-experimental-metadata-monolithic.jar', "tensorflow-lite-support-metadata.jar"
|
||||
}
|
||||
|
||||
preBuild.dependsOn downloadLibs
|
||||
|
||||
dependencies {
|
||||
compileOnly 'org.checkerframework:checker-qual:2.5.8'
|
||||
api 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
|
||||
api 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
|
||||
api files("$buildDir/libs/tensorflow-lite-support-metadata.jar")
|
||||
implementation 'org.apache.commons:commons-compress:1.19'
|
||||
})");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GenerateAndroidManifestContent(CodeWriter* code_writer,
|
||||
const ModelInfo& model_info) {
|
||||
code_writer->Append(R"(<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="{{PACKAGE}}">
|
||||
</manifest>)");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GenerateDocContent(CodeWriter* code_writer, const ModelInfo& model_info) {
|
||||
code_writer->Append("# {{MODEL_CLASS_NAME}} Usage");
|
||||
code_writer->AppendNoNewLine(R"(
|
||||
```
|
||||
import {{PACKAGE}}.{{MODEL_CLASS_NAME}};
|
||||
|
||||
// 1. Initialize the Model
|
||||
{{MODEL_CLASS_NAME}} model = null;
|
||||
|
||||
try {
|
||||
model = new {{MODEL_CLASS_NAME}}(context); // android.content.Context
|
||||
// Create the input container.
|
||||
{{MODEL_CLASS_NAME}}.Inputs inputs = model.createInputs();
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
||||
if (model != null) {
|
||||
|
||||
// 2. Set the inputs)");
|
||||
for (const auto& t : model_info.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, t);
|
||||
if (t.content_type == "image") {
|
||||
code_writer->Append(R"(
|
||||
// Load input tensor "{{NAME}}" from a Bitmap with ARGB_8888 format.
|
||||
Bitmap bitmap = ...;
|
||||
inputs.load{{NAME_U}}(bitmap);
|
||||
// Alternatively, load the input tensor "{{NAME}}" from a TensorImage.
|
||||
// Check out TensorImage documentation to load other image data structures.
|
||||
// TensorImage tensorImage = ...;
|
||||
// inputs.load{{NAME_U}}(tensorImage);)");
|
||||
} else {
|
||||
code_writer->Append(R"(
|
||||
// Load input tensor "{{NAME}}" from a TensorBuffer.
|
||||
// Check out TensorBuffer documentation to load other data structures.
|
||||
TensorBuffer tensorBuffer = ...;
|
||||
inputs.load{{NAME_U}}(tensorBuffer);)");
|
||||
}
|
||||
}
|
||||
code_writer->Append(R"(
|
||||
// 3. Run the model
|
||||
{{MODEL_CLASS_NAME}}.Outputs outputs = model.run(inputs);)");
|
||||
code_writer->Append(R"(
|
||||
// 4. Retrieve the results)");
|
||||
for (const auto& t : model_info.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, t);
|
||||
if (t.associated_axis_label_index >= 0) {
|
||||
code_writer->SetTokenValue("WRAPPER_TYPE", "Map<String, Float>");
|
||||
}
|
||||
code_writer->Append(
|
||||
R"( {{WRAPPER_TYPE}} {{NAME}} = outputs.get{{NAME_U}}();)");
|
||||
}
|
||||
code_writer->Append(R"(}
|
||||
```)");
|
||||
return true;
|
||||
}
|
||||
|
||||
GenerationResult::File GenerateWrapperFile(const std::string& module_root,
|
||||
const ModelInfo& model_info,
|
||||
ErrorReporter* err) {
|
||||
const auto java_path = JoinPath(module_root, "src/main/java");
|
||||
const auto package_path =
|
||||
JoinPath(java_path, ConvertPackageToPath(model_info.package_name));
|
||||
const auto file_path =
|
||||
JoinPath(package_path, model_info.model_class_name + JAVA_EXT);
|
||||
|
||||
CodeWriter code_writer(err);
|
||||
code_writer.SetIndentString(" ");
|
||||
SetCodeWriterWithModelInfo(&code_writer, model_info);
|
||||
|
||||
if (!GenerateWrapperFileContent(&code_writer, model_info, err)) {
|
||||
err->Error("Generating Java wrapper content failed.");
|
||||
}
|
||||
|
||||
const auto java_file = code_writer.ToString();
|
||||
return GenerationResult::File{file_path, java_file};
|
||||
}
|
||||
|
||||
GenerationResult::File GenerateBuildGradle(const std::string& module_root,
|
||||
const ModelInfo& model_info,
|
||||
ErrorReporter* err) {
|
||||
const auto file_path = JoinPath(module_root, "build.gradle");
|
||||
CodeWriter code_writer(err);
|
||||
SetCodeWriterWithModelInfo(&code_writer, model_info);
|
||||
if (!GenerateBuildGradleContent(&code_writer, model_info)) {
|
||||
err->Error("Generating build.gradle failed.");
|
||||
}
|
||||
const auto content = code_writer.ToString();
|
||||
return GenerationResult::File{file_path, content};
|
||||
}
|
||||
|
||||
GenerationResult::File GenerateAndroidManifest(const std::string& module_root,
|
||||
const ModelInfo& model_info,
|
||||
ErrorReporter* err) {
|
||||
const auto file_path = JoinPath(module_root, "src/main/AndroidManifest.xml");
|
||||
CodeWriter code_writer(err);
|
||||
SetCodeWriterWithModelInfo(&code_writer, model_info);
|
||||
if (!GenerateAndroidManifestContent(&code_writer, model_info)) {
|
||||
err->Error("Generating AndroidManifest.xml failed.");
|
||||
}
|
||||
return GenerationResult::File{file_path, code_writer.ToString()};
|
||||
}
|
||||
|
||||
GenerationResult::File GenerateDoc(const std::string& module_root,
|
||||
const ModelInfo& model_info,
|
||||
ErrorReporter* err) {
|
||||
std::string lower = model_info.model_class_name;
|
||||
for (int i = 0; i < lower.length(); i++) {
|
||||
lower[i] = std::tolower(lower[i]);
|
||||
}
|
||||
const auto file_path = JoinPath(module_root, lower + ".md");
|
||||
CodeWriter code_writer(err);
|
||||
SetCodeWriterWithModelInfo(&code_writer, model_info);
|
||||
if (!GenerateDocContent(&code_writer, model_info)) {
|
||||
err->Error("Generating doc failed.");
|
||||
}
|
||||
return GenerationResult::File{file_path, code_writer.ToString()};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root)
|
||||
: CodeGenerator(), module_root_(module_root) {}
|
||||
|
||||
GenerationResult AndroidJavaGenerator::Generate(
|
||||
const Model* model, const std::string& package_name,
|
||||
const std::string& model_class_name, const std::string& model_asset_path) {
|
||||
GenerationResult result;
|
||||
const ModelMetadata* metadata = GetMetadataFromModel(model);
|
||||
if (metadata == nullptr) {
|
||||
err_.Error(
|
||||
"Cannot find TFLite Metadata in the model. Codegen will generate "
|
||||
"nothing.");
|
||||
return result;
|
||||
}
|
||||
details_android_java::ModelInfo model_info = CreateModelInfo(
|
||||
metadata, package_name, model_class_name, model_asset_path, &err_);
|
||||
result.files.push_back(GenerateWrapperFile(module_root_, model_info, &err_));
|
||||
result.files.push_back(GenerateBuildGradle(module_root_, model_info, &err_));
|
||||
result.files.push_back(
|
||||
GenerateAndroidManifest(module_root_, model_info, &err_));
|
||||
result.files.push_back(GenerateDoc(module_root_, model_info, &err_));
|
||||
return result;
|
||||
}
|
||||
|
||||
GenerationResult AndroidJavaGenerator::Generate(
|
||||
const char* model_storage, const std::string& package_name,
|
||||
const std::string& model_class_name, const std::string& model_asset_path) {
|
||||
const Model* model = GetModel(model_storage);
|
||||
return Generate(model, package_name, model_class_name, model_asset_path);
|
||||
}
|
||||
|
||||
std::string AndroidJavaGenerator::GetErrorMessage() {
|
||||
return err_.GetMessage();
|
||||
}
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace support
|
||||
} // namespace tflite
|
@ -0,0 +1,107 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
|
||||
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
namespace codegen {
|
||||
|
||||
namespace details_android_java {
|
||||
|
||||
/// The intermediate data structure for generating code from TensorMetadata.
|
||||
/// Should only be used as const reference when created.
|
||||
struct TensorInfo {
|
||||
std::string name;
|
||||
std::string upper_camel_name;
|
||||
std::string content_type;
|
||||
std::string wrapper_type;
|
||||
std::string processor_type;
|
||||
bool is_input;
|
||||
/// Optional. Set to -1 if not applicable.
|
||||
int normalization_unit;
|
||||
/// Optional. Set to -1 if associated_axis_label is empty.
|
||||
int associated_axis_label_index;
|
||||
/// Optional. Set to -1 if associated_value_label is empty.
|
||||
int associated_value_label_index;
|
||||
};
|
||||
|
||||
/// The intermediate data structure for generating code from ModelMetadata.
|
||||
/// Should only be used as const reference when created.
|
||||
struct ModelInfo {
|
||||
std::string package_name;
|
||||
std::string model_asset_path;
|
||||
std::string model_class_name;
|
||||
std::string model_versioned_name;
|
||||
std::vector<TensorInfo> inputs;
|
||||
std::vector<TensorInfo> outputs;
|
||||
};
|
||||
|
||||
} // namespace details_android_java
|
||||
|
||||
constexpr char JAVA_EXT[] = ".java";
|
||||
|
||||
/// Generates Android supporting codes and modules (in Java) based on TFLite
|
||||
/// metadata.
|
||||
class AndroidJavaGenerator : public CodeGenerator {
|
||||
public:
|
||||
/// Creates an AndroidJavaGenerator.
|
||||
/// Args:
|
||||
/// - module_root: The root of destination Java module.
|
||||
explicit AndroidJavaGenerator(const std::string& module_root);
|
||||
|
||||
/// Generates files. Returns the file paths and contents.
|
||||
/// Args:
|
||||
/// - model: The TFLite model with Metadata filled.
|
||||
/// - package_name: The name of the Java package which generated classes
|
||||
/// belong to.
|
||||
/// - model_class_name: A readable name of the generated wrapper class, such
|
||||
/// as "ImageClassifier", "MobileNetV2" or "MyModel".
|
||||
/// - model_asset_path: The relevant path to the model file in the asset.
|
||||
// TODO(b/141225157): Automatically generate model_class_name.
|
||||
GenerationResult Generate(const Model* model, const std::string& package_name,
|
||||
const std::string& model_class_name,
|
||||
const std::string& model_asset_path);
|
||||
|
||||
/// Generates files and returns the file paths and contents.
|
||||
/// It's mostly identical with the previous one, but the model here is
|
||||
/// provided as binary flatbuffer content without parsing.
|
||||
GenerationResult Generate(const char* model_storage,
|
||||
const std::string& package_name,
|
||||
const std::string& model_class_name,
|
||||
const std::string& model_asset_path);
|
||||
|
||||
std::string GetErrorMessage();
|
||||
|
||||
private:
|
||||
const std::string module_root_;
|
||||
ErrorReporter err_;
|
||||
};
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace support
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
|
179
tensorflow/lite/experimental/support/codegen/code_generator.cc
Normal file
179
tensorflow/lite/experimental/support/codegen/code_generator.cc
Normal file
@ -0,0 +1,179 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
|
||||
|
||||
#include <cctype>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
namespace codegen {
|
||||
|
||||
namespace {
|
||||
|
||||
void ResolveConflictedNamesByAddingIndex(std::vector<std::string>* names_ptr) {
|
||||
auto& names = *names_ptr;
|
||||
std::unordered_map<std::string, int> indexes;
|
||||
std::unordered_map<std::string, int> first_appearance;
|
||||
for (int i = 0; i < names.size(); i++) {
|
||||
if (indexes.find(names[i]) == indexes.end()) {
|
||||
indexes[names[i]] = 1;
|
||||
first_appearance[names[i]] = i;
|
||||
} else {
|
||||
indexes[names[i]] += 1;
|
||||
names[i].append(std::to_string(indexes[names[i]]));
|
||||
}
|
||||
}
|
||||
for (const auto it : first_appearance) {
|
||||
const auto& name = it.first;
|
||||
const auto i = it.second;
|
||||
if (indexes[name] > 1) {
|
||||
names[i].append("1");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
CodeGenerator::CodeGenerator() {}
|
||||
|
||||
bool CodeGenerator::VerifyMetadata(const ModelMetadata* metadata,
|
||||
ErrorReporter* err) {
|
||||
if (metadata == nullptr) {
|
||||
err->Error("Loading nullptr is not allowed");
|
||||
return false;
|
||||
}
|
||||
if (metadata->subgraph_metadata()->size() != 1) {
|
||||
err->Error("Only exact 1 subgraph is supported");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::pair<std::vector<std::string>, std::vector<std::string>>
|
||||
CodeGenerator::NameInputsAndOutputs(const TensorMetadataList* inputs,
|
||||
const TensorMetadataList* outputs) {
|
||||
std::vector<std::string> input_names;
|
||||
std::vector<std::string> output_names;
|
||||
if (inputs != nullptr) {
|
||||
input_names.reserve(inputs->size());
|
||||
for (const auto* tensor : *inputs) {
|
||||
input_names.push_back(NameTensor(*tensor, "input"));
|
||||
}
|
||||
}
|
||||
if (outputs != nullptr) {
|
||||
output_names.reserve(outputs->size());
|
||||
for (const auto* tensor : *outputs) {
|
||||
output_names.push_back(NameTensor(*tensor, "output"));
|
||||
}
|
||||
}
|
||||
// Solve conflict
|
||||
ResolveConflictedInputAndOutputNames(&input_names, &output_names);
|
||||
return std::make_pair(input_names, output_names);
|
||||
}
|
||||
|
||||
std::string CodeGenerator::ConvertToValidName(const std::string& name) {
|
||||
// lowercase all
|
||||
std::string result = name;
|
||||
for (int i = 0; i < result.size(); i++) {
|
||||
result[i] = std::tolower(result[i]);
|
||||
}
|
||||
// replace all non-alpha or non-numeric with underscores, except underscore
|
||||
// itself
|
||||
for (int i = 0; i < result.size(); i++) {
|
||||
if (result[i] != '_' && !std::isalnum(result[i])) {
|
||||
result[i] = '_';
|
||||
}
|
||||
}
|
||||
// remove leading underscores
|
||||
int leading_underscores = 0;
|
||||
while (leading_underscores < result.size() &&
|
||||
result[leading_underscores] == '_') {
|
||||
leading_underscores++;
|
||||
}
|
||||
result.erase(0, leading_underscores);
|
||||
if (result.empty()) {
|
||||
return "";
|
||||
}
|
||||
// first char should be alpha
|
||||
if (std::isalpha(result[0])) {
|
||||
return result;
|
||||
}
|
||||
return "tensor_" + result;
|
||||
}
|
||||
|
||||
std::string CodeGenerator::NameTensor(const TensorMetadata& tensor,
|
||||
const std::string& default_name) {
|
||||
if (tensor.name() != nullptr && tensor.name()->size() > 0) {
|
||||
// TODO(b/141225157) Validate tensor name. It should be in lower case.
|
||||
auto suggested_name = ConvertToValidName(tensor.name()->str());
|
||||
if (!suggested_name.empty()) {
|
||||
return suggested_name;
|
||||
}
|
||||
}
|
||||
auto* content = tensor.content();
|
||||
if (content == nullptr || content->content_properties() == nullptr) {
|
||||
return default_name;
|
||||
}
|
||||
switch (content->content_properties_type()) {
|
||||
case ContentProperties_ImageProperties:
|
||||
return "image";
|
||||
case ContentProperties_FeatureProperties:
|
||||
return "feature";
|
||||
default:
|
||||
return default_name;
|
||||
}
|
||||
}
|
||||
|
||||
void CodeGenerator::ResolveConflictedInputAndOutputNames(
|
||||
std::vector<std::string>* inputs, std::vector<std::string>* outputs) {
|
||||
std::unordered_set<std::string> io_conflict;
|
||||
auto& input_names = *inputs;
|
||||
auto& output_names = *outputs;
|
||||
for (const auto input : input_names) {
|
||||
if (io_conflict.find(input) != io_conflict.end()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto output : output_names) {
|
||||
if (input == output) {
|
||||
io_conflict.insert(input);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < input_names.size(); i++) {
|
||||
if (io_conflict.find(input_names[i]) != io_conflict.end()) {
|
||||
input_names[i] = "input_" + input_names[i];
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < output_names.size(); i++) {
|
||||
if (io_conflict.find(output_names[i]) != io_conflict.end()) {
|
||||
output_names[i] = "output_" + output_names[i];
|
||||
}
|
||||
}
|
||||
// 2. Second, add index if input[i] == input[j]
|
||||
ResolveConflictedNamesByAddingIndex(&input_names);
|
||||
ResolveConflictedNamesByAddingIndex(&output_names);
|
||||
}
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace support
|
||||
} // namespace tflite
|
@ -0,0 +1,80 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
namespace codegen {
|
||||
|
||||
struct GenerationResult {
|
||||
struct File {
|
||||
std::string path;
|
||||
std::string content;
|
||||
};
|
||||
std::vector<File> files;
|
||||
};
|
||||
|
||||
/// Defines language-independent codegen strategies, like class naming, .etc.
|
||||
/// Should not be used directly.
|
||||
class CodeGenerator {
|
||||
public:
|
||||
CodeGenerator();
|
||||
|
||||
using TensorMetadataList =
|
||||
typename flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>;
|
||||
|
||||
virtual ~CodeGenerator() {}
|
||||
|
||||
// Strategies.
|
||||
/// Names all the IO tensors. It's useful when they don't have names, or the
|
||||
/// names have conflicts. We have to name every tensor for code generation.
|
||||
// TODO(b/141225157): Add reserved keywords check.
|
||||
static std::pair<std::vector<std::string>, std::vector<std::string>>
|
||||
NameInputsAndOutputs(const TensorMetadataList* inputs,
|
||||
const TensorMetadataList* outputs);
|
||||
|
||||
/// Loads a metadata for code generation.
|
||||
/// Returns false if the metadata is not good for generation.
|
||||
static bool VerifyMetadata(const ModelMetadata* metadata, ErrorReporter* err);
|
||||
|
||||
protected:
|
||||
/// Converts a name into a valid form. Rules:
|
||||
/// - lower all letters.
|
||||
/// - replace all non alphabet nor numeric characters with underscores.
|
||||
/// - remove prefix underscores.
|
||||
/// - add prefix if the leading character is a number.
|
||||
/// Returns empty string if not possible.
|
||||
static std::string ConvertToValidName(const std::string& name);
|
||||
static std::string NameTensor(const TensorMetadata& tensor,
|
||||
const std::string& default_name);
|
||||
static void ResolveConflictedInputAndOutputNames(
|
||||
std::vector<std::string>* input, std::vector<std::string>* output);
|
||||
};
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace support
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_
|
@ -0,0 +1,126 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
namespace codegen {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
class CodeGeneratorTest : public ::testing::Test {
|
||||
public:
|
||||
class TestingCodeGenerator : public CodeGenerator {
|
||||
public:
|
||||
explicit TestingCodeGenerator() : CodeGenerator() {}
|
||||
|
||||
// Make tested method public.
|
||||
static std::string ConvertToValidName(const std::string& name) {
|
||||
return CodeGenerator::ConvertToValidName(name);
|
||||
}
|
||||
static void ResolveConflictedInputAndOutputNames(
|
||||
std::vector<std::string>* input, std::vector<std::string>* output) {
|
||||
CodeGenerator::ResolveConflictedInputAndOutputNames(input, output);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
TEST_F(CodeGeneratorTest, UpperCasesShouldLower) {
|
||||
EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("AlphaBetCOOL"),
|
||||
"alphabetcool");
|
||||
}
|
||||
|
||||
TEST_F(CodeGeneratorTest, NonAlphaNumShouldReplace) {
|
||||
EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("A+=B C\t"), "a__b_c_");
|
||||
}
|
||||
|
||||
TEST_F(CodeGeneratorTest, NoLeadingUnderscore) {
|
||||
EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("+KAI Z"), "kai_z");
|
||||
}
|
||||
|
||||
TEST_F(CodeGeneratorTest, NoLeadingNumbers) {
|
||||
EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("3000 Cool Tensors"),
|
||||
"tensor_3000_cool_tensors");
|
||||
}
|
||||
|
||||
TEST_F(CodeGeneratorTest, TestSimpleIONames) {
|
||||
std::vector<std::string> inputs = {"image"};
|
||||
std::vector<std::string> outputs = {"output"};
|
||||
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
|
||||
EXPECT_THAT(inputs, ElementsAreArray({"image"}));
|
||||
EXPECT_THAT(outputs, ElementsAreArray({"output"}));
|
||||
}
|
||||
|
||||
TEST_F(CodeGeneratorTest, TestIOConflict) {
|
||||
std::vector<std::string> inputs = {"image"};
|
||||
std::vector<std::string> outputs = {"image"};
|
||||
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
|
||||
EXPECT_THAT(inputs, ElementsAreArray({"input_image"}));
|
||||
EXPECT_THAT(outputs, ElementsAreArray({"output_image"}));
|
||||
}
|
||||
|
||||
TEST_F(CodeGeneratorTest, TestInternalConflict) {
|
||||
std::vector<std::string> inputs = {"image", "image"};
|
||||
std::vector<std::string> outputs = {"output"};
|
||||
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
|
||||
EXPECT_THAT(inputs, ElementsAreArray({"image1", "image2"}));
|
||||
EXPECT_THAT(outputs, ElementsAreArray({"output"}));
|
||||
}
|
||||
|
||||
TEST_F(CodeGeneratorTest, TestAllConflictNTo1) {
|
||||
std::vector<std::string> inputs = {"image", "image"};
|
||||
std::vector<std::string> outputs = {"image"};
|
||||
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
|
||||
EXPECT_THAT(inputs, ElementsAreArray({"input_image1", "input_image2"}));
|
||||
EXPECT_THAT(outputs, ElementsAreArray({"output_image"}));
|
||||
}
|
||||
|
||||
TEST_F(CodeGeneratorTest, TestAllConflict) {
|
||||
std::vector<std::string> inputs = {"image", "audio", "image", "audio",
|
||||
"audio"};
|
||||
std::vector<std::string> outputs = {"image", "image", "audio", "feature",
|
||||
"feature"};
|
||||
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
|
||||
EXPECT_THAT(inputs,
|
||||
ElementsAreArray({"input_image1", "input_audio1", "input_image2",
|
||||
"input_audio2", "input_audio3"}));
|
||||
EXPECT_THAT(outputs,
|
||||
ElementsAreArray({"output_image1", "output_image2",
|
||||
"output_audio", "feature1", "feature2"}));
|
||||
}
|
||||
|
||||
TEST_F(CodeGeneratorTest, TestAllConflictReversed) {
|
||||
std::vector<std::string> inputs = {"image", "image", "audio", "feature",
|
||||
"feature"};
|
||||
std::vector<std::string> outputs = {"image", "audio", "image", "audio",
|
||||
"audio"};
|
||||
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
|
||||
EXPECT_THAT(inputs,
|
||||
ElementsAreArray({"input_image1", "input_image2", "input_audio",
|
||||
"feature1", "feature2"}));
|
||||
EXPECT_THAT(outputs, ElementsAreArray({"output_image1", "output_audio1",
|
||||
"output_image2", "output_audio2",
|
||||
"output_audio3"}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace codegen
|
||||
} // namespace support
|
||||
} // namespace tflite
|
@ -0,0 +1,92 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/support/codegen/metadata_helper.h"
|
||||
|
||||
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
namespace codegen {
|
||||
|
||||
constexpr char BUFFER_KEY[] = "TFLITE_METADATA";
|
||||
const ModelMetadata* GetMetadataFromModel(const Model* model) {
|
||||
if (model->metadata() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
for (auto i = 0; i < model->metadata()->size(); i++) {
|
||||
if (model->metadata()->Get(i)->name()->str() == BUFFER_KEY) {
|
||||
const auto buffer_index = model->metadata()->Get(i)->buffer();
|
||||
const auto* buffer = model->buffers()->Get(buffer_index)->data()->data();
|
||||
return GetModelMetadata(buffer);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int FindAssociatedFile(const TensorMetadata* metadata,
|
||||
const AssociatedFileType file_type,
|
||||
const std::string& tensor_identifier,
|
||||
ErrorReporter* err) {
|
||||
int result = -1;
|
||||
if (metadata->associated_files() == nullptr ||
|
||||
metadata->associated_files()->size() == 0) {
|
||||
return result;
|
||||
}
|
||||
for (int i = 0; i < metadata->associated_files()->size(); i++) {
|
||||
const auto* file_metadata = metadata->associated_files()->Get(i);
|
||||
if (file_metadata->type() == file_type) {
|
||||
if (result >= 0) {
|
||||
err->Warning(
|
||||
"Multiple associated file of type %d found on tensor %s. Only the "
|
||||
"first one will be used.",
|
||||
file_type, tensor_identifier.c_str());
|
||||
continue;
|
||||
}
|
||||
result = i;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
int FindNormalizationUnit(const TensorMetadata* metadata,
|
||||
const std::string& tensor_identifier,
|
||||
ErrorReporter* err) {
|
||||
int result = -1;
|
||||
if (metadata->process_units() == nullptr ||
|
||||
metadata->process_units()->size() == 0) {
|
||||
return result;
|
||||
}
|
||||
for (int i = 0; i < metadata->process_units()->size(); i++) {
|
||||
const auto* process_uint = metadata->process_units()->Get(i);
|
||||
if (process_uint->options_type() ==
|
||||
ProcessUnitOptions_NormalizationOptions) {
|
||||
if (result >= 0) {
|
||||
err->Warning(
|
||||
"Multiple normalization unit found in tensor %s. Only the first "
|
||||
"one will be effective.",
|
||||
tensor_identifier.c_str());
|
||||
continue;
|
||||
}
|
||||
result = i;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace support
|
||||
} // namespace tflite
|
@ -0,0 +1,51 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
namespace codegen {
|
||||
|
||||
/// Parses a ModelMetadata out from a Model. The returned ModelMetadata's
|
||||
/// lifetime is scoped by the model.
|
||||
/// Returns nullptr if we cannot find any metadata.
|
||||
const ModelMetadata* GetMetadataFromModel(const Model* model);
|
||||
|
||||
/// Finds an associated file from a TensorMetadata of certain type. If there're
|
||||
/// multiple files meet the criteria, only the first one is used. If there's no
|
||||
/// file meets the criteria, -1 will be returned.
|
||||
int FindAssociatedFile(const TensorMetadata* metadata,
|
||||
const AssociatedFileType file_type,
|
||||
const std::string& tensor_identifier,
|
||||
ErrorReporter* err);
|
||||
|
||||
/// Find the first normalization unit. If none, return -1.
|
||||
int FindNormalizationUnit(const TensorMetadata* metadata,
|
||||
const std::string& tensor_identifier,
|
||||
ErrorReporter* err);
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace support
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_
|
38
tensorflow/lite/experimental/support/codegen/python/BUILD
Normal file
38
tensorflow/lite/experimental/support/codegen/python/BUILD
Normal file
@ -0,0 +1,38 @@
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_pywrap_codegen",
|
||||
srcs = [
|
||||
"codegen_lib.cc",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
module_name = "_pywrap_codegen",
|
||||
deps = [
|
||||
"//tensorflow/lite/experimental/support/codegen:android_java_generator",
|
||||
"//tensorflow/lite/experimental/support/codegen:code_generator",
|
||||
"//tensorflow/python:pybind11_lib",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "codegen",
|
||||
srcs = [
|
||||
"codegen.py",
|
||||
],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":_pywrap_codegen",
|
||||
"@absl_py//absl:app",
|
||||
"@absl_py//absl/flags",
|
||||
"@absl_py//absl/logging",
|
||||
],
|
||||
)
|
@ -0,0 +1,96 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Generates Android Java sources from a TFLite model with metadata."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
from tensorflow.lite.experimental.support.codegen.python import _pywrap_codegen
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string('model', None, 'Path to model (.tflite) flatbuffer file.')
|
||||
flags.DEFINE_string('destination', None, 'Path of destination of generation.')
|
||||
flags.DEFINE_string('package_name', 'org.tensorflow.lite.support',
|
||||
'Name of generated java package to put the wrapper class.')
|
||||
flags.DEFINE_string(
|
||||
'model_class_name', 'MyModel',
|
||||
'Name of generated wrapper class (should not contain package name).')
|
||||
flags.DEFINE_string(
|
||||
'model_asset_path', '',
|
||||
'(Optional) Path to the model in generated assets/ dir. If not set, '
|
||||
'generator will use base name of input model.'
|
||||
)
|
||||
|
||||
|
||||
def get_model_buffer(path):
|
||||
if not os.path.isfile(path):
|
||||
logging.error('Cannot find model at path %s.', path)
|
||||
with open(path, 'rb') as f:
|
||||
buf = f.read()
|
||||
return buf
|
||||
|
||||
|
||||
def prepare_directory_for_file(file_path):
|
||||
target_dir = os.path.dirname(file_path)
|
||||
if not os.path.exists(target_dir):
|
||||
os.makedirs(target_dir)
|
||||
return
|
||||
if not os.path.isdir(target_dir):
|
||||
logging.error('Cannot write to %s', target_dir)
|
||||
|
||||
|
||||
def main(argv):
|
||||
if len(argv) > 1:
|
||||
logging.error('None flag arguments found: [%s]', ', '.join(argv[1:]))
|
||||
|
||||
codegen = _pywrap_codegen.AndroidJavaGenerator(FLAGS.destination)
|
||||
model_buffer = get_model_buffer(FLAGS.model)
|
||||
model_asset_path = FLAGS.model_asset_path
|
||||
if not model_asset_path:
|
||||
model_asset_path = os.path.basename(FLAGS.model)
|
||||
result = codegen.generate(model_buffer, FLAGS.package_name,
|
||||
FLAGS.model_class_name, model_asset_path)
|
||||
error_message = codegen.get_error_message().strip()
|
||||
if error_message:
|
||||
logging.error(error_message)
|
||||
if not result.files:
|
||||
logging.error('Generation failed!')
|
||||
return
|
||||
|
||||
for each in result.files:
|
||||
prepare_directory_for_file(each.path)
|
||||
with open(each.path, 'w') as f:
|
||||
f.write(each.content)
|
||||
|
||||
logging.info('Generation succeeded!')
|
||||
model_asset_path = os.path.join(FLAGS.destination, 'src/main/assets',
|
||||
model_asset_path)
|
||||
prepare_directory_for_file(model_asset_path)
|
||||
shutil.copy(FLAGS.model, model_asset_path)
|
||||
logging.info('Model copied into assets!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('model')
|
||||
flags.mark_flag_as_required('destination')
|
||||
app.run(main)
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "include/pybind11/detail/common.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/pytypes.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h"
|
||||
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
namespace codegen {
|
||||
|
||||
template <typename... Args>
|
||||
using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;
|
||||
|
||||
PYBIND11_MODULE(_pywrap_codegen, m) {
|
||||
pybind11::class_<AndroidJavaGenerator>(m, "AndroidJavaGenerator")
|
||||
.def(pybind11::init<const std::string &>())
|
||||
.def("generate",
|
||||
overload_cast_<const char *, const std::string &,
|
||||
const std::string &, const std::string &>()(
|
||||
&AndroidJavaGenerator::Generate))
|
||||
.def("get_error_message", &AndroidJavaGenerator::GetErrorMessage);
|
||||
pybind11::class_<GenerationResult>(m, "GenerationResult")
|
||||
.def(pybind11::init<>())
|
||||
.def_readwrite("files", &GenerationResult::files);
|
||||
pybind11::class_<GenerationResult::File>(m, "GenerationResultFile")
|
||||
.def(pybind11::init<>())
|
||||
.def_readwrite("path", &GenerationResult::File::path)
|
||||
.def_readwrite("content", &GenerationResult::File::content);
|
||||
}
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace support
|
||||
} // namespace tflite
|
194
tensorflow/lite/experimental/support/codegen/utils.cc
Normal file
194
tensorflow/lite/experimental/support/codegen/utils.cc
Normal file
@ -0,0 +1,194 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||
|
||||
#include <cstdarg>
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
namespace codegen {
|
||||
|
||||
int ErrorReporter::Warning(const char* format, ...) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
return Report("[WARN] ", format, args);
|
||||
}
|
||||
|
||||
int ErrorReporter::Error(const char* format, ...) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
return Report("[ERROR] ", format, args);
|
||||
}
|
||||
|
||||
int ErrorReporter::Report(const char* prefix, const char* format,
|
||||
va_list args) {
|
||||
char buf[1024];
|
||||
int formatted = vsnprintf(buf, sizeof(buf), format, args);
|
||||
buffer_ << prefix << buf << std::endl;
|
||||
return formatted;
|
||||
}
|
||||
|
||||
std::string ErrorReporter::GetMessage() {
|
||||
std::string value = buffer_.str();
|
||||
buffer_.str("");
|
||||
return value;
|
||||
}
|
||||
|
||||
CodeWriter::CodeWriter(ErrorReporter* err) : indent_(0), err_(err) {}
|
||||
|
||||
void CodeWriter::SetTokenValue(const std::string& token,
|
||||
const std::string& value) {
|
||||
value_map_[token] = value;
|
||||
}
|
||||
|
||||
const std::string CodeWriter::GetTokenValue(const std::string& token) const {
|
||||
auto iter = value_map_.find(token);
|
||||
if (iter == value_map_.end()) {
|
||||
// Typically only Code Generator's call this function (or `Append`). It's
|
||||
// their duty to make sure the token is valid, and requesting for an invalid
|
||||
// token implicits flaws in the code generation logic.
|
||||
err_->Error("Internal: Cannot find value with token '%s'", token.c_str());
|
||||
return "";
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
void CodeWriter::SetIndentString(const std::string& indent_str) {
|
||||
indent_str_ = indent_str;
|
||||
}
|
||||
|
||||
void CodeWriter::Indent() { indent_++; }
|
||||
|
||||
void CodeWriter::Outdent() { indent_--; }
|
||||
|
||||
std::string CodeWriter::GenerateIndent() const {
|
||||
std::string res;
|
||||
res.reserve(indent_str_.size() * indent_);
|
||||
for (int i = 0; i < indent_; i++) {
|
||||
res.append(indent_str_);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
void CodeWriter::Append(const std::string& text) { AppendInternal(text, true); }
|
||||
|
||||
void CodeWriter::AppendNoNewLine(const std::string& text) {
|
||||
AppendInternal(text, false);
|
||||
}
|
||||
|
||||
void CodeWriter::AppendInternal(const std::string& text, bool newline) {
|
||||
// Prefix indent
|
||||
if ((buffer_.empty() // nothing in the buffer
|
||||
|| buffer_.back() == '\n') // is on new line
|
||||
// is writing on current line
|
||||
&& (!text.empty() && text[0] != '\n' && text[0] != '\r')) {
|
||||
buffer_.append(GenerateIndent());
|
||||
}
|
||||
// State machine variables
|
||||
bool in_token = false;
|
||||
int i = 0;
|
||||
// Rough memory reserve
|
||||
buffer_.reserve(buffer_.size() + text.size());
|
||||
std::string token_buffer;
|
||||
// A simple LL1 analysis
|
||||
while (i < text.size()) {
|
||||
char cur = text[i];
|
||||
char cur_next = i == text.size() - 1 ? '\0' : text[i + 1]; // Set guardian
|
||||
if (in_token == false) {
|
||||
if (cur == '{' && cur_next == '{') { // Enter token
|
||||
in_token = true;
|
||||
i += 2;
|
||||
} else if (cur == '\n') { // We need to apply global indent here
|
||||
buffer_.push_back(cur);
|
||||
if (cur_next != '\0' && cur_next != '\n' && cur_next != '\r') {
|
||||
buffer_.append(GenerateIndent());
|
||||
}
|
||||
i += 1;
|
||||
} else {
|
||||
buffer_.push_back(cur);
|
||||
i += 1;
|
||||
}
|
||||
} else {
|
||||
if (cur == '}' && cur_next == '}') { // Close token
|
||||
in_token = false;
|
||||
const auto value = GetTokenValue(token_buffer);
|
||||
buffer_.append(value);
|
||||
token_buffer.clear();
|
||||
i += 2;
|
||||
} else {
|
||||
token_buffer.push_back(cur);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!token_buffer.empty()) {
|
||||
// Typically only Code Generator's call this function. It's
|
||||
// their duty to make sure the code (or template) has valid syntax, and
|
||||
// unclosed "{{...}}" implicits severe error in the template.
|
||||
err_->Error("Internal: Invalid template: {{token}} is not closed.");
|
||||
}
|
||||
if (newline) {
|
||||
buffer_.push_back('\n');
|
||||
}
|
||||
}
|
||||
|
||||
void CodeWriter::NewLine() { Append(""); }
|
||||
|
||||
void CodeWriter::Backspace(int n) {
|
||||
buffer_.resize(buffer_.size() > n ? buffer_.size() - n : 0);
|
||||
}
|
||||
|
||||
std::string CodeWriter::ToString() const { return buffer_; }
|
||||
|
||||
bool CodeWriter::IsStreamEmpty() const { return buffer_.empty(); }
|
||||
|
||||
void CodeWriter::Clear() {
|
||||
buffer_.clear();
|
||||
value_map_.clear();
|
||||
indent_ = 0;
|
||||
}
|
||||
|
||||
std::string SnakeCaseToCamelCase(const std::string& s) {
|
||||
std::string t;
|
||||
t.reserve(s.length());
|
||||
size_t i = 0;
|
||||
// Note: Use simple string += for simplicity.
|
||||
bool cap = false;
|
||||
while (i < s.size()) {
|
||||
const char c = s[i++];
|
||||
if (c == '_') {
|
||||
cap = true;
|
||||
} else if (cap) {
|
||||
t += toupper(c);
|
||||
cap = false;
|
||||
} else {
|
||||
t += c;
|
||||
}
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
std::string JoinPath(const std::string& a, const std::string& b) {
|
||||
if (a.empty()) return b;
|
||||
std::string a_fixed = a;
|
||||
if (!a_fixed.empty() && a_fixed.back() == '/') a_fixed.pop_back();
|
||||
std::string b_fixed = b;
|
||||
if (!b_fixed.empty() && b_fixed.front() == '/') b_fixed.erase(0, 1);
|
||||
return a_fixed + "/" + b_fixed;
|
||||
}
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace support
|
||||
} // namespace tflite
|
127
tensorflow/lite/experimental/support/codegen/utils.h
Normal file
127
tensorflow/lite/experimental/support/codegen/utils.h
Normal file
@ -0,0 +1,127 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_
|
||||
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
namespace codegen {
|
||||
|
||||
/// Collects runtime error logs which could be showed later.
|
||||
// TODO(b/150538286): Consider a better mechanism to simplify callsite code.
|
||||
class ErrorReporter {
|
||||
public:
|
||||
int Warning(const char* format, ...);
|
||||
int Error(const char* format, ...);
|
||||
std::string GetMessage();
|
||||
|
||||
private:
|
||||
int Report(const char* prefix, const char* format, va_list args);
|
||||
std::stringstream buffer_;
|
||||
};
|
||||
|
||||
/// Implements basic code generating with text templates.
|
||||
///
|
||||
/// It could accept code templates and concatenate them into complete codes. A
|
||||
/// template could contain named values.
|
||||
///
|
||||
/// Example code:
|
||||
/// CodeWriter code;
|
||||
/// code.SetValue("NAME", "Foo");
|
||||
/// code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }");
|
||||
/// code.SetValue("NAME", "Bar");
|
||||
/// code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }");
|
||||
///
|
||||
/// Output:
|
||||
/// void Foo() { printf("%s", "Foo"); }
|
||||
/// void Bar() { printf("%s", "Bar"); }
|
||||
class CodeWriter {
|
||||
public:
|
||||
explicit CodeWriter(ErrorReporter* err);
|
||||
/// Sets value to a token. When generating code with template, a string in a
|
||||
/// pair of {{ and }} will be regarded as a token and replaced with the
|
||||
/// corresponding value in code generation.
|
||||
/// It rewrites if the token already has a value.
|
||||
void SetTokenValue(const std::string& token, const std::string& value);
|
||||
|
||||
/// Gets the current value set on the given token.
|
||||
const std::string GetTokenValue(const std::string& token) const;
|
||||
|
||||
/// Sets the unit indent string. For example, in Java it should be " ".
|
||||
void SetIndentString(const std::string& indent);
|
||||
|
||||
/// Increases the indent by a unit (the string set in SetIndentString).
|
||||
void Indent();
|
||||
|
||||
/// Decreases the indent by a unit (the string set in SetIndentString).
|
||||
void Outdent();
|
||||
|
||||
/// Generates the indentation string.
|
||||
std::string GenerateIndent() const;
|
||||
|
||||
/// Appends a piece of template codes to the stream. Every named value will be
|
||||
/// replaced via the real value. A new line will always be appended at the
|
||||
/// end.
|
||||
void Append(const std::string& text);
|
||||
|
||||
/// Appends a piece of template codes to the stream. Same with `Append`, but a
|
||||
/// new line will not be appended at the end.
|
||||
void AppendNoNewLine(const std::string& text);
|
||||
|
||||
/// Appends a new line to the stream.
|
||||
void NewLine();
|
||||
|
||||
/// Deletes the last N charaters in the stream. If the stream has less than N
|
||||
/// characters, deletes all.
|
||||
void Backspace(int n);
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
/// Checks if the internal string stream is empty. Note: This method has
|
||||
// overhead.
|
||||
bool IsStreamEmpty() const;
|
||||
|
||||
/// Clears all the internal string stream and value map.
|
||||
void Clear();
|
||||
|
||||
private:
|
||||
void AppendInternal(const std::string& text, bool newline);
|
||||
|
||||
std::string indent_str_;
|
||||
int indent_;
|
||||
|
||||
std::map<std::string, std::string> value_map_;
|
||||
std::string buffer_;
|
||||
|
||||
ErrorReporter* err_;
|
||||
};
|
||||
|
||||
/// Converts foo_bar_name to fooBarName. It's callers duty to make sure given
|
||||
/// string "s" is already in snake case; or unexpected behavior may occur.
|
||||
std::string SnakeCaseToCamelCase(const std::string& s);
|
||||
|
||||
/// Joins 2 parts of file path into one, connected by unix path seperator '/'.
|
||||
/// It's callers duty to ensure the two parts are valid.
|
||||
std::string JoinPath(const std::string& a, const std::string& b);
|
||||
|
||||
} // namespace codegen
|
||||
} // namespace support
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_
|
97
tensorflow/lite/experimental/support/codegen/utils_test.cc
Normal file
97
tensorflow/lite/experimental/support/codegen/utils_test.cc
Normal file
@ -0,0 +1,97 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace tflite {
|
||||
namespace support {
|
||||
namespace codegen {
|
||||
namespace {
|
||||
|
||||
TEST(ErrorReporterTest, TestReportError) {
|
||||
ErrorReporter err;
|
||||
err.Error("some text");
|
||||
EXPECT_EQ(err.GetMessage(), "[ERROR] some text\n");
|
||||
EXPECT_EQ(err.GetMessage(), "");
|
||||
}
|
||||
|
||||
TEST(CodeGeneratorTest, TestExample) {
|
||||
ErrorReporter err;
|
||||
CodeWriter writer(&err);
|
||||
writer.SetTokenValue("NAME", "Foo");
|
||||
const std::string text = R"(void {{NAME}}() { printf("%s", "{{NAME}}"); })";
|
||||
writer.Append(text);
|
||||
writer.SetTokenValue("NAME", "Bar");
|
||||
writer.Append(text);
|
||||
EXPECT_EQ(
|
||||
"void Foo() { printf(\"%s\", \"Foo\"); }\n"
|
||||
"void Bar() { printf(\"%s\", \"Bar\"); }\n",
|
||||
writer.ToString());
|
||||
}
|
||||
|
||||
TEST(CodeGeneratorTest, TestInexistentToken) {
|
||||
ErrorReporter err;
|
||||
CodeWriter writer(&err);
|
||||
writer.SetTokenValue("NAME", "Foo");
|
||||
const std::string text = R"(void {{name}}() {})";
|
||||
writer.Append(text);
|
||||
EXPECT_EQ(err.GetMessage(),
|
||||
"[ERROR] Internal: Cannot find value with token 'name'\n");
|
||||
}
|
||||
|
||||
TEST(CodeGeneratorTest, TestUnclosedToken) {
|
||||
ErrorReporter err;
|
||||
CodeWriter writer(&err);
|
||||
writer.SetTokenValue("NAME", "Foo");
|
||||
const std::string text = R"(void {{NAME}() {})";
|
||||
writer.Append(text);
|
||||
EXPECT_EQ(err.GetMessage(),
|
||||
"[ERROR] Internal: Invalid template: {{token}} is not closed.\n");
|
||||
}
|
||||
|
||||
TEST(CodeGeneratorTest, TestIndentControl) {
|
||||
ErrorReporter err;
|
||||
CodeWriter writer(&err);
|
||||
writer.SetIndentString(" ");
|
||||
writer.Indent();
|
||||
writer.AppendNoNewLine("abcde"); // Will indent
|
||||
EXPECT_EQ(" abcde", writer.ToString());
|
||||
writer.Clear();
|
||||
writer.Indent();
|
||||
writer.AppendNoNewLine("abc\n\nde");
|
||||
// The blank line will not indent
|
||||
EXPECT_EQ(" abc\n\n de", writer.ToString());
|
||||
writer.Clear();
|
||||
writer.Indent();
|
||||
writer.Append("abc");
|
||||
writer.Outdent();
|
||||
writer.AppendNoNewLine("def");
|
||||
EXPECT_EQ(" abc\ndef", writer.ToString());
|
||||
}
|
||||
|
||||
TEST(CaseConversionTest, TestSnakeToCamel) {
|
||||
EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel"));
|
||||
EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel_"));
|
||||
EXPECT_EQ("ImACamel", SnakeCaseToCamelCase("_im_a_camel"));
|
||||
EXPECT_EQ("", SnakeCaseToCamelCase("_"));
|
||||
EXPECT_EQ("camel", SnakeCaseToCamelCase("camel"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace codegen
|
||||
} // namespace support
|
||||
} // namespace tflite
|
Loading…
x
Reference in New Issue
Block a user