Opensource TFLite Support codegen.
PiperOrigin-RevId: 302011153 Change-Id: Idb2f649dc48fdc449fac2d6e9009719d29afb2ad
This commit is contained in:
parent
abf182a882
commit
b57d910db5
87
tensorflow/lite/experimental/support/codegen/BUILD
Normal file
87
tensorflow/lite/experimental/support/codegen/BUILD
Normal file
# The tools for generating wrapper classes for a TFLite model with metadata.

package(
    default_visibility = [
        "//visibility:public",
    ],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "utils",
    srcs = [
        "utils.cc",
    ],
    hdrs = [
        "utils.h",
    ],
    deps = [
    ],
)

cc_library(
    name = "code_generator",
    srcs = [
        "code_generator.cc",
    ],
    hdrs = [
        "code_generator.h",
    ],
    deps = [
        ":utils",
        "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
    ],
)

cc_library(
    name = "metadata_helper",
    srcs = [
        "metadata_helper.cc",
    ],
    hdrs = [
        "metadata_helper.h",
    ],
    deps = [
        ":utils",
        "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
        "//tensorflow/lite/schema:schema_fbs",
    ],
)

cc_library(
    name = "android_java_generator",
    srcs = [
        "android_java_generator.cc",
    ],
    hdrs = [
        "android_java_generator.h",
    ],
    deps = [
        ":code_generator",
        ":metadata_helper",
        ":utils",
        "//tensorflow/core/platform:logging",
        "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
        "//tensorflow/lite/schema:schema_fbs",
    ],
)

cc_test(
    name = "code_generator_test",
    size = "small",
    srcs = ["code_generator_test.cc"],
    data = ["//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs"],
    deps = [
        ":code_generator",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_test(
    name = "utils_test",
    srcs = ["utils_test.cc"],
    deps = [
        ":utils",
        "@com_google_googletest//:gtest_main",
    ],
)
|
13
tensorflow/lite/experimental/support/codegen/README.md
Normal file
13
tensorflow/lite/experimental/support/codegen/README.md
Normal file
# TensorFlow Lite Android Wrapper Code Generator

For a TensorFlow Lite model enhanced with [metadata](https://www.tensorflow.org/lite/convert/metadata.md),
developers can use the TensorFlow Lite Android wrapper code generator to create
platform-specific wrapper code. The wrapper code removes the need to interact
directly with `ByteBuffer`. Instead, developers can interact with the TensorFlow
Lite model with typed objects such as `Bitmap` and `Rect`.

The usefulness of the code generator depends on the completeness of the
TensorFlow Lite model's metadata entry. Refer to the `<Codegen usage>` section
under relevant fields in
[metadata_schema.fbs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs)
to see how the codegen tool parses each field.
@ -0,0 +1,978 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h"
|
||||||
|
|
||||||
|
#include <ctype.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/metadata_helper.h"
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||||
|
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace support {
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using details_android_java::ModelInfo;
|
||||||
|
using details_android_java::TensorInfo;
|
||||||
|
|
||||||
|
// Helper class to organize the C++ code block as a generated code block.
|
||||||
|
// Using ctor and dtor to simulate an enter/exit schema like `with` in Python.
|
||||||
|
class AsBlock {
|
||||||
|
public:
|
||||||
|
AsBlock(CodeWriter* code_writer, const std::string& before,
|
||||||
|
bool trailing_blank_line = false)
|
||||||
|
: code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) {
|
||||||
|
code_writer_->AppendNoNewLine(before);
|
||||||
|
code_writer_->Append(" {");
|
||||||
|
code_writer_->Indent();
|
||||||
|
}
|
||||||
|
~AsBlock() {
|
||||||
|
code_writer_->Outdent();
|
||||||
|
code_writer_->Append("}");
|
||||||
|
if (trailing_blank_line_) {
|
||||||
|
code_writer_->NewLine();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
CodeWriter* code_writer_;
|
||||||
|
bool trailing_blank_line_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Declare the functions first, so that the functions can follow a logical
|
||||||
|
// order.
|
||||||
|
bool GenerateWrapperClass(CodeWriter*, const ModelInfo&, ErrorReporter*);
|
||||||
|
bool GenerateWrapperImports(CodeWriter*, const ModelInfo&, ErrorReporter*);
|
||||||
|
bool GenerateWrapperInputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
|
||||||
|
bool GenerateWrapperOutputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
|
||||||
|
bool GenerateWrapperMetadata(CodeWriter*, const ModelInfo&, ErrorReporter*);
|
||||||
|
bool GenerateWrapperAPI(CodeWriter*, const ModelInfo&, ErrorReporter*);
|
||||||
|
|
||||||
|
std::string GetModelVersionedName(const ModelMetadata* metadata) {
|
||||||
|
std::string model_name = "MyModel";
|
||||||
|
if (metadata->name() != nullptr && !(metadata->name()->str().empty())) {
|
||||||
|
model_name = metadata->name()->str();
|
||||||
|
}
|
||||||
|
std::string model_version = "unknown";
|
||||||
|
if (metadata->version() != nullptr && !(metadata->version()->str().empty())) {
|
||||||
|
model_version = metadata->version()->str();
|
||||||
|
}
|
||||||
|
return model_name + " (Version: " + model_version + ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collects per-tensor codegen info from tensor metadata: the camel-case Java
// name, associated label-file indices, normalization process unit, and the
// Java wrapper/processor types ("image" tensors get TensorImage/ImageProcessor,
// everything else gets TensorBuffer/TensorProcessor).
TensorInfo CreateTensorInfo(const TensorMetadata* metadata,
                            const std::string& name, bool is_input, int index,
                            ErrorReporter* err) {
  TensorInfo tensor_info;
  // Identifier used only in diagnostics, e.g. "input 0" / "output 2".
  std::string tensor_identifier = is_input ? "input" : "output";
  tensor_identifier += " " + std::to_string(index);
  const int axis_label_index = FindAssociatedFile(
      metadata, AssociatedFileType_TENSOR_AXIS_LABELS, tensor_identifier, err);
  const int value_label_index = FindAssociatedFile(
      metadata, AssociatedFileType_TENSOR_VALUE_LABELS, tensor_identifier, err);
  tensor_info.associated_axis_label_index = axis_label_index;
  tensor_info.associated_value_label_index = value_label_index;
  if (is_input && (axis_label_index >= 0 || value_label_index >= 0)) {
    err->Warning(
        "Found label file on input tensor (%s). Label file for input "
        "tensor is not supported yet. The "
        "file will be ignored.",
        tensor_identifier.c_str());
  }
  if (axis_label_index >= 0 && value_label_index >= 0) {
    err->Warning(
        "Found both axis label file and value label file for tensor (%s), "
        "which is not supported. Only the axis label file will be used.",
        tensor_identifier.c_str());
  }
  tensor_info.is_input = is_input;
  tensor_info.name = SnakeCaseToCamelCase(name);
  tensor_info.upper_camel_name = tensor_info.name;
  tensor_info.upper_camel_name[0] = toupper(tensor_info.upper_camel_name[0]);
  tensor_info.normalization_unit =
      FindNormalizationUnit(metadata, tensor_identifier, err);
  // NOTE(review): assumes metadata->content() is non-null here — presumably
  // guaranteed by upstream metadata verification; confirm against callers.
  if (metadata->content()->content_properties_type() ==
      ContentProperties_ImageProperties) {
    if (metadata->content()
            ->content_properties_as_ImageProperties()
            ->color_space() == ColorSpaceType_RGB) {
      tensor_info.content_type = "image";
      tensor_info.wrapper_type = "TensorImage";
      tensor_info.processor_type = "ImageProcessor";
      return tensor_info;
    }
    // Non-RGB images fall through to the plain numeric-tensor treatment.
    err->Warning(
        "Found Non-RGB image on tensor (%s). Codegen currently does not "
        "support it, and regard it as a plain numeric tensor.",
        tensor_identifier.c_str());
  }
  tensor_info.content_type = "tensor";
  tensor_info.wrapper_type = "TensorBuffer";
  tensor_info.processor_type = "TensorProcessor";
  return tensor_info;
}
|
||||||
|
|
||||||
|
// Assembles the ModelInfo that drives code generation: verifies the metadata,
// copies the caller-supplied naming/path fields, then builds a TensorInfo per
// input and output tensor of the first (and only consulted) subgraph.
// On verification failure, reports an error and returns a default ModelInfo.
ModelInfo CreateModelInfo(const ModelMetadata* metadata,
                          const std::string& package_name,
                          const std::string& model_class_name,
                          const std::string& model_asset_path,
                          ErrorReporter* err) {
  ModelInfo model_info;
  if (!CodeGenerator::VerifyMetadata(metadata, err)) {
    // TODO(b/150116380): Create dummy model info.
    err->Error("Validating metadata failed.");
    return model_info;
  }
  model_info.package_name = package_name;
  model_info.model_class_name = model_class_name;
  model_info.model_asset_path = model_asset_path;
  model_info.model_versioned_name = GetModelVersionedName(metadata);
  // Only subgraph 0 is used for codegen.
  const auto* graph = metadata->subgraph_metadata()->Get(0);
  auto names = CodeGenerator::NameInputsAndOutputs(
      graph->input_tensor_metadata(), graph->output_tensor_metadata());
  std::vector<std::string> input_tensor_names = std::move(names.first);
  std::vector<std::string> output_tensor_names = std::move(names.second);
  for (int i = 0; i < graph->input_tensor_metadata()->size(); i++) {
    model_info.inputs.push_back(
        CreateTensorInfo(graph->input_tensor_metadata()->Get(i),
                         input_tensor_names[i], /*is_input=*/true, i, err));
  }
  for (int i = 0; i < graph->output_tensor_metadata()->size(); i++) {
    model_info.outputs.push_back(
        CreateTensorInfo(graph->output_tensor_metadata()->Get(i),
                         output_tensor_names[i], /*is_input=*/false, i, err));
  }
  return model_info;
}
|
||||||
|
|
||||||
|
// Registers the per-tensor template tokens ({{NAME}}, {{NAME_U}}, etc.) on the
// writer so subsequent Append("... {{TOKEN}} ...") calls expand correctly.
void SetCodeWriterWithTensorInfo(CodeWriter* code_writer,
                                 const TensorInfo& tensor_info) {
  code_writer->SetTokenValue("NAME", tensor_info.name);
  code_writer->SetTokenValue("NAME_U", tensor_info.upper_camel_name);
  code_writer->SetTokenValue("CONTENT_TYPE", tensor_info.content_type);
  code_writer->SetTokenValue("WRAPPER_TYPE", tensor_info.wrapper_type);
  // Lower-camel form of the wrapper type, e.g. "TensorImage" -> "tensorImage".
  std::string wrapper_name = tensor_info.wrapper_type;
  wrapper_name[0] = tolower(wrapper_name[0]);
  code_writer->SetTokenValue("WRAPPER_NAME", wrapper_name);
  code_writer->SetTokenValue("PROCESSOR_TYPE", tensor_info.processor_type);
  code_writer->SetTokenValue("NORMALIZATION_UNIT",
                             std::to_string(tensor_info.normalization_unit));
  code_writer->SetTokenValue(
      "ASSOCIATED_AXIS_LABEL_INDEX",
      std::to_string(tensor_info.associated_axis_label_index));
  code_writer->SetTokenValue(
      "ASSOCIATED_VALUE_LABEL_INDEX",
      std::to_string(tensor_info.associated_value_label_index));
}
|
||||||
|
|
||||||
|
// Registers the model-level template tokens on the writer.
void SetCodeWriterWithModelInfo(CodeWriter* code_writer,
                                const ModelInfo& model_info) {
  code_writer->SetTokenValue("PACKAGE", model_info.package_name);
  code_writer->SetTokenValue("MODEL_PATH", model_info.model_asset_path);
  code_writer->SetTokenValue("MODEL_CLASS_NAME", model_info.model_class_name);
}
|
||||||
|
|
||||||
|
// Sentinel package name for Java's default (unnamed) package.
constexpr char JAVA_DEFAULT_PACKAGE[] = "default";

// Maps a Java package name to its directory path ("com.foo" -> "com/foo").
// The default package maps to the empty path.
std::string ConvertPackageToPath(const std::string& package) {
  if (package == JAVA_DEFAULT_PACKAGE) {
    return "";
  }
  std::string path = package;
  std::replace(path.begin(), path.end(), '.', '/');
  return path;
}
|
||||||
|
|
||||||
|
bool IsImageUsed(const ModelInfo& model) {
|
||||||
|
for (const auto& input : model.inputs) {
|
||||||
|
if (input.content_type == "image") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const auto& output : model.outputs) {
|
||||||
|
if (output.content_type == "image") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emits the whole wrapper .java file: header comment, package declaration,
// imports, then the wrapper class itself. Returns false (after reporting via
// `err`) if any stage fails.
bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
                                ErrorReporter* err) {
  code_writer->Append("// Generated by TFLite Support.");
  code_writer->Append("package {{PACKAGE}};");
  code_writer->NewLine();

  if (!GenerateWrapperImports(code_writer, model, err)) {
    err->Error("Fail to generate imports for wrapper class.");
    return false;
  }
  if (!GenerateWrapperClass(code_writer, model, err)) {
    err->Error("Fail to generate wrapper class.");
    return false;
  }
  code_writer->NewLine();
  return true;
}
|
||||||
|
|
||||||
|
// Emits the sorted Java import list for the wrapper class. Image-processing
// imports are added only when the model actually uses an image tensor.
bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
                            ErrorReporter* err) {
  const std::string support_pkg = "org.tensorflow.lite.support.";
  std::vector<std::string> imports{
      "android.content.Context",
      "java.io.IOException",
      "java.nio.ByteBuffer",
      "java.nio.FloatBuffer",
      "java.util.Arrays",
      "java.util.HashMap",
      "java.util.List",
      "java.util.Map",
      "org.checkerframework.checker.nullness.qual.Nullable",
      "org.tensorflow.lite.DataType",
      "org.tensorflow.lite.Tensor.QuantizationParams",
      support_pkg + "common.FileUtil",
      support_pkg + "common.TensorProcessor",
      support_pkg + "common.ops.CastOp",
      support_pkg + "common.ops.DequantizeOp",
      support_pkg + "common.ops.NormalizeOp",
      support_pkg + "common.ops.QuantizeOp",
      support_pkg + "label.TensorLabel",
      support_pkg + "metadata.MetadataExtractor",
      support_pkg + "metadata.schema.NormalizationOptions",
      support_pkg + "model.Model",
      support_pkg + "model.Model.Device",
      support_pkg + "tensorbuffer.TensorBuffer",
  };
  if (IsImageUsed(model)) {
    for (const auto& target :
         {"image.ImageProcessor", "image.TensorImage", "image.ops.ResizeOp",
          "image.ops.ResizeOp.ResizeMethod"}) {
      imports.push_back(support_pkg + target);
    }
    imports.push_back("android.graphics.Bitmap");
  }

  std::sort(imports.begin(), imports.end());
  // By const reference: iterating by value copied each std::string.
  for (const auto& target : imports) {
    code_writer->SetTokenValue("TARGET", target);
    code_writer->Append("import {{TARGET}};");
  }
  code_writer->NewLine();
  return true;
}
|
||||||
|
|
||||||
|
// Emits the outer "public class {{MODEL_CLASS_NAME}}" body: fields (model,
// metadata, label lists, pre/post-processors), then delegates to the Inputs,
// Outputs, Metadata, and common-API generators in order.
bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model,
                          ErrorReporter* err) {
  code_writer->SetTokenValue("MODEL_VERSIONED_NAME",
                             model.model_versioned_name);
  code_writer->Append(
      R"(/** Wrapper class of model {{MODEL_VERSIONED_NAME}} */)");
  const auto code_block =
      AsBlock(code_writer, "public class {{MODEL_CLASS_NAME}}");
  code_writer->Append(R"(private final Metadata metadata;
private final Model model;
private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
  // Label-list fields: only output tensors with an axis label file get one.
  for (const auto& tensor : model.outputs) {
    if (tensor.associated_axis_label_index >= 0) {
      code_writer->SetTokenValue("NAME", tensor.name);
      code_writer->Append("private final List<String> {{NAME}}Labels;");
    }
  }
  for (const auto& tensor : model.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(
        "@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;");
  }
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(
        "@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
  }
  code_writer->NewLine();
  if (!GenerateWrapperInputs(code_writer, model, err)) {
    err->Error("Failed to generate input classes");
    return false;
  }
  code_writer->NewLine();
  if (!GenerateWrapperOutputs(code_writer, model, err)) {
    err->Error("Failed to generate output classes");
    return false;
  }
  code_writer->NewLine();
  if (!GenerateWrapperMetadata(code_writer, model, err)) {
    err->Error("Failed to generate the metadata class");
    return false;
  }
  code_writer->NewLine();
  if (!GenerateWrapperAPI(code_writer, model, err)) {
    err->Error("Failed to generate the common APIs");
    return false;
  }
  return true;
}
|
||||||
|
|
||||||
|
// Emits the inner "public class Inputs": one wrapper field per input tensor,
// a constructor that sizes each wrapper from Metadata, per-tensor loaders
// (Bitmap/TensorImage for images, TensorBuffer otherwise), the private
// preprocess hooks, and getBuffer() which flattens all inputs into Object[].
// NOTE(review): interior indentation of the Java raw-string templates was
// reconstructed from context — confirm against upstream.
bool GenerateWrapperInputs(CodeWriter* code_writer, const ModelInfo& model,
                           ErrorReporter* err) {
  code_writer->Append("/** Input wrapper of {@link {{MODEL_CLASS_NAME}}} */");
  auto class_block = AsBlock(code_writer, "public class Inputs");
  for (const auto& tensor : model.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append("private {{WRAPPER_TYPE}} {{NAME}};");
  }
  code_writer->NewLine();
  // Ctor
  {
    auto ctor_block = AsBlock(code_writer, "public Inputs()");
    code_writer->Append(
        "Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;");
    for (const auto& tensor : model.inputs) {
      SetCodeWriterWithTensorInfo(code_writer, tensor);
      if (tensor.content_type == "image") {
        code_writer->Append(
            "{{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());");
      } else {
        code_writer->Append(
            "{{NAME}} = "
            "TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
            "metadata.get{{NAME_U}}Type());");
      }
    }
  }
  for (const auto& tensor : model.inputs) {
    code_writer->NewLine();
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    // Loaders
    if (tensor.content_type == "image") {
      {
        auto bitmap_loader_block =
            AsBlock(code_writer, "public void load{{NAME_U}}(Bitmap bitmap)");
        code_writer->Append(R"({{NAME}}.load(bitmap);
{{NAME}} = preprocess{{NAME_U}}({{NAME}});)");
      }
      code_writer->NewLine();
      {
        auto tensor_image_loader_block = AsBlock(
            code_writer, "public void load{{NAME_U}}(TensorImage tensorImage)");
        code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorImage);");
      }
    } else {  // content_type == "FEATURE" or "UNKNOWN"
      auto tensorbuffer_loader_block = AsBlock(
          code_writer, "public void load{{NAME_U}}(TensorBuffer tensorBuffer)");
      code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorBuffer);");
    }
    code_writer->NewLine();
    // Processor
    code_writer->Append(
        R"(private {{WRAPPER_TYPE}} preprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}}) {
  if ({{NAME}}Preprocessor == null) {
    return {{WRAPPER_NAME}};
  }
  return {{NAME}}Preprocessor.process({{WRAPPER_NAME}});
}
)");
  }
  {
    const auto get_buffer_block = AsBlock(code_writer, "Object[] getBuffer()");
    code_writer->AppendNoNewLine("return new Object[] {");
    for (const auto& tensor : model.inputs) {
      SetCodeWriterWithTensorInfo(code_writer, tensor);
      code_writer->AppendNoNewLine("{{NAME}}.getBuffer(), ");
    }
    // Drop the trailing ", " left by the loop above.
    code_writer->Backspace(2);
    code_writer->Append("};");
  }
  return true;
}
|
||||||
|
|
||||||
|
// Emits the inner "public class Outputs": one final wrapper field per output
// tensor, a constructor sizing each from Metadata, getters (a labeled
// Map<String, Float> view when an axis label file exists, the raw wrapper
// otherwise), the private postprocess hooks, and getBuffer() mapping output
// index -> buffer for Interpreter.runForMultipleInputsOutputs-style use.
// NOTE(review): interior indentation of the Java raw-string templates was
// reconstructed from context — confirm against upstream.
bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
                            ErrorReporter* err) {
  code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
  auto class_block = AsBlock(code_writer, "public class Outputs");
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};");
  }
  code_writer->NewLine();
  {
    const auto ctor_block = AsBlock(code_writer, "public Outputs()");
    code_writer->Append(
        "Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;");
    for (const auto& tensor : model.outputs) {
      SetCodeWriterWithTensorInfo(code_writer, tensor);
      if (tensor.content_type == "image") {
        code_writer->Append(
            R"({{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());
{{NAME}}.load(TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), metadata.get{{NAME_U}}Type()));)");
      } else {  // FEATURE, UNKNOWN
        code_writer->Append(
            "{{NAME}} = "
            "TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
            "metadata.get{{NAME_U}}Type());");
      }
    }
  }
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->NewLine();
    if (tensor.associated_axis_label_index >= 0) {
      if (tensor.content_type == "image") {
        // No getter is emitted in this case — original behavior preserved.
        err->Warning(
            "Axis label for images is not supported. The labels will "
            "be ignored.");
      } else {
        code_writer->Append(R"(public Map<String, Float> get{{NAME_U}}() {
  return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getMapWithFloatValue();
})");
      }
    } else {
      code_writer->Append(R"(public {{WRAPPER_TYPE}} get{{NAME_U}}() {
  return postprocess{{NAME_U}}({{NAME}});
})");
    }
    code_writer->NewLine();
    {
      auto processor_block =
          AsBlock(code_writer,
                  "private {{WRAPPER_TYPE}} "
                  "postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})");
      code_writer->Append(R"(if ({{NAME}}Postprocessor == null) {
  return {{WRAPPER_NAME}};
}
return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});)");
    }
  }
  code_writer->NewLine();
  {
    const auto get_buffer_block =
        AsBlock(code_writer, "Map<Integer, Object> getBuffer()");
    code_writer->Append("Map<Integer, Object> outputs = new HashMap<>();");
    for (int i = 0; i < model.outputs.size(); i++) {
      SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
      code_writer->SetTokenValue("ID", std::to_string(i));
      code_writer->Append("outputs.put({{ID}}, {{NAME}}.getBuffer());");
    }
    code_writer->Append("return outputs;");
  }
  return true;
}
|
||||||
|
|
||||||
|
// Emits the static "Metadata" class: shape/type/quantization fields per tensor
// (plus mean/stddev when a normalization unit exists, and label lists for
// labeled outputs), a constructor that pulls everything out of a
// MetadataExtractor, and read-only getters (arrays returned as defensive
// copies via Arrays.copyOf).
// NOTE(review): interior indentation of the Java raw-string templates was
// reconstructed from context — confirm against upstream.
bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model,
                             ErrorReporter* err) {
  code_writer->Append(
      "/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */");
  const auto class_block = AsBlock(code_writer, "public static class Metadata");
  for (const auto& tensor : model.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(private final int[] {{NAME}}Shape;
private final DataType {{NAME}}DataType;
private final QuantizationParams {{NAME}}QuantizationParams;)");
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(R"(private final float[] {{NAME}}Mean;
private final float[] {{NAME}}Stddev;)");
    }
  }
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(private final int[] {{NAME}}Shape;
private final DataType {{NAME}}DataType;
private final QuantizationParams {{NAME}}QuantizationParams;)");
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(R"(private final float[] {{NAME}}Mean;
private final float[] {{NAME}}Stddev;)");
    }
    if (tensor.associated_axis_label_index >= 0 ||
        tensor.associated_value_label_index >= 0) {
      code_writer->Append("private final List<String> {{NAME}}Labels;");
    }
  }
  code_writer->NewLine();
  {
    const auto ctor_block = AsBlock(
        code_writer,
        "public Metadata(ByteBuffer buffer, Model model) throws IOException");
    code_writer->Append(
        "MetadataExtractor extractor = new MetadataExtractor(buffer);");
    for (int i = 0; i < model.inputs.size(); i++) {
      SetCodeWriterWithTensorInfo(code_writer, model.inputs[i]);
      code_writer->SetTokenValue("ID", std::to_string(i));
      code_writer->Append(
          R"({{NAME}}Shape = extractor.getInputTensorShape({{ID}});
{{NAME}}DataType = extractor.getInputTensorType({{ID}});
{{NAME}}QuantizationParams = extractor.getInputTensorQuantizationParams({{ID}});)");
      if (model.inputs[i].normalization_unit >= 0) {
        code_writer->Append(
            R"(NormalizationOptions {{NAME}}NormalizationOptions =
    (NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
{{NAME}}MeanBuffer.get({{NAME}}Mean);
FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
      }
    }
    for (int i = 0; i < model.outputs.size(); i++) {
      SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
      code_writer->SetTokenValue("ID", std::to_string(i));
      code_writer->Append(
          R"({{NAME}}Shape = model.getOutputTensorShape({{ID}});
{{NAME}}DataType = extractor.getOutputTensorType({{ID}});
{{NAME}}QuantizationParams = extractor.getOutputTensorQuantizationParams({{ID}});)");
      if (model.outputs[i].normalization_unit >= 0) {
        // BUGFIX: previously read getInputTensorMetadata here — an output
        // tensor's normalization options must come from its own (output)
        // tensor metadata, matching the Output accessors used around it.
        code_writer->Append(
            R"(NormalizationOptions {{NAME}}NormalizationOptions =
    (NormalizationOptions) extractor.getOutputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
{{NAME}}MeanBuffer.get({{NAME}}Mean);
FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
      }
      // Axis labels win over value labels when both exist (warned earlier in
      // CreateTensorInfo).
      if (model.outputs[i].associated_axis_label_index >= 0) {
        code_writer->Append(R"(String {{NAME}}LabelsFileName =
    extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_AXIS_LABEL_INDEX}}).name();
{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
      } else if (model.outputs[i].associated_value_label_index >= 0) {
        code_writer->Append(R"(String {{NAME}}LabelsFileName =
    extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_VALUE_LABEL_INDEX}}).name();
{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
      }
    }
  }
  for (const auto& tensor : model.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(
public int[] get{{NAME_U}}Shape() {
  return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
}

public DataType get{{NAME_U}}Type() {
  return {{NAME}}DataType;
}

public QuantizationParams get{{NAME_U}}QuantizationParams() {
  return {{NAME}}QuantizationParams;
})");
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(R"(
public float[] get{{NAME_U}}Mean() {
  return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
}

public float[] get{{NAME_U}}Stddev() {
  return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
})");
    }
  }
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(
public int[] get{{NAME_U}}Shape() {
  return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
}

public DataType get{{NAME_U}}Type() {
  return {{NAME}}DataType;
}

public QuantizationParams get{{NAME_U}}QuantizationParams() {
  return {{NAME}}QuantizationParams;
})");
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(R"(
public float[] get{{NAME_U}}Mean() {
  return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
}

public float[] get{{NAME_U}}Stddev() {
  return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
})");
    }
    if (tensor.associated_axis_label_index >= 0 ||
        tensor.associated_value_label_index >= 0) {
      code_writer->Append(R"(
public List<String> get{{NAME_U}}Labels() {
  return {{NAME}}Labels;
})");
    }
  }
  return true;
}
|
||||||
|
|
||||||
|
// Emits the public API of the generated wrapper class: the metadata accessor,
// all constructors (which build the Model and per-tensor pre/post-processors),
// processor-reset hooks, and the createInputs/run/close entry points.
// NOTE(review): `err` is unused here; presumably kept so all Generate*
// helpers share the same signature — confirm before removing.
bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model,
                        ErrorReporter* err) {
  // Accessor for the Metadata object created in the constructor below.
  code_writer->Append(R"(public Metadata getMetadata() {
return metadata;
}
)");
  // Convenience constructors; each delegates to the four-argument form.
  code_writer->Append(R"(/**
* Creates interpreter and loads associated files if needed.
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public {{MODEL_CLASS_NAME}}(Context context) throws IOException {
this(context, MODEL_NAME, Device.CPU, 1);
}

/**
* Creates interpreter and loads associated files if needed, but loading another model in the same
* input / output structure with the original one.
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public {{MODEL_CLASS_NAME}}(Context context, String modelPath) throws IOException {
this(context, modelPath, Device.CPU, 1);
}

/**
* Creates interpreter and loads associated files if needed, with device and number of threads
* configured.
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public {{MODEL_CLASS_NAME}}(Context context, Device device, int numThreads) throws IOException {
this(context, MODEL_NAME, device, numThreads);
}

/**
* Creates interpreter for a user-specified model.
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public {{MODEL_CLASS_NAME}}(Context context, String modelPath, Device device, int numThreads) throws IOException {
model = new Model.Builder(context, modelPath).setDevice(device).setNumThreads(numThreads).build();
metadata = new Metadata(model.getData(), model);)");
  // Per-input preprocessor: optional resize (image inputs), optional
  // normalization, then quantize and cast to the tensor's data type.
  for (const auto& tensor : model.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(
{{PROCESSOR_TYPE}}.Builder {{NAME}}PreprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder())");
    if (tensor.content_type == "image") {
      code_writer->Append(R"( .add(new ResizeOp(
metadata.get{{NAME_U}}Shape()[1],
metadata.get{{NAME_U}}Shape()[2],
ResizeMethod.NEAREST_NEIGHBOR)))");
    }
    if (tensor.normalization_unit >= 0) {
      code_writer->Append(
          R"( .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
    }
    code_writer->Append(
        R"( .add(new QuantizeOp(
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
metadata.get{{NAME_U}}QuantizationParams().getScale()))
.add(new CastOp(metadata.get{{NAME_U}}Type()));
{{NAME}}Preprocessor = {{NAME}}PreprocessorBuilder.build();)");
  }
  // Per-output postprocessor: dequantize, optional de-normalization; label
  // files are attached when the output has associated axis labels.
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->AppendNoNewLine(R"(
{{PROCESSOR_TYPE}}.Builder {{NAME}}PostprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder()
.add(new DequantizeOp(
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
metadata.get{{NAME_U}}QuantizationParams().getScale())))");
    if (tensor.normalization_unit >= 0) {
      code_writer->AppendNoNewLine(R"(
.add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
    }
    code_writer->Append(R"(;
{{NAME}}Postprocessor = {{NAME}}PostprocessorBuilder.build();)");
    if (tensor.associated_axis_label_index >= 0) {
      code_writer->Append(R"(
{{NAME}}Labels = metadata.get{{NAME_U}}Labels();)");
    }
  }
  // Closes the four-argument constructor opened above.
  code_writer->Append("}");
  // Hooks to replace the default processors with user-supplied ones.
  for (const auto& tensor : model.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(
public void reset{{NAME_U}}Preprocessor(@Nullable {{PROCESSOR_TYPE}} processor) {
{{NAME}}Preprocessor = processor;
})");
  }
  for (const auto& tensor : model.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, tensor);
    code_writer->Append(R"(
public void reset{{NAME_U}}Postprocessor(@Nullable {{PROCESSOR_TYPE}} processor) {
{{NAME}}Postprocessor = processor;
})");
  }
  // Inference entry points: input factory, run, and resource cleanup.
  code_writer->Append(R"(
/** Creates inputs */
public Inputs createInputs() {
return new Inputs();
}

/** Triggers the model. */
public Outputs run(Inputs inputs) {
Outputs outputs = new Outputs();
model.run(inputs.getBuffer(), outputs.getBuffer());
return outputs;
}

/** Closes the model. */
public void close() {
model.close();
})");
  return true;
}
|
||||||
|
|
||||||
|
// Writes the complete build.gradle for the generated Android library module:
// AGP setup, a downloadLibs task that fetches the experimental metadata jar,
// and the TFLite runtime/support dependencies. The content is fully static;
// `model_info` is unused but kept so all content generators share one shape.
bool GenerateBuildGradleContent(CodeWriter* code_writer,
                                const ModelInfo& model_info) {
  code_writer->Append(R"(buildscript {
repositories {
google()
jcenter()
}
dependencies {
classpath 'com.android.tools.build:gradle:3.2.1'
}
}

allprojects {
repositories {
google()
jcenter()
flatDir {
dirs 'libs'
}
}
}

apply plugin: 'com.android.library'

android {
compileSdkVersion 29
defaultConfig {
targetSdkVersion 29
versionCode 1
versionName "1.0"
}
aaptOptions {
noCompress "tflite"
}
compileOptions {
sourceCompatibility = '1.8'
targetCompatibility = '1.8'
}
lintOptions {
abortOnError false
}
}

configurations {
libMetadata
}

dependencies {
libMetadata 'org.tensorflow:tensorflow-lite-support:0.0.0-experimental-metadata-monolithic'
}

task downloadLibs(type: Sync) {
from configurations.libMetadata
into "$buildDir/libs"
rename 'tensorflow-lite-support-0.0.0-experimental-metadata-monolithic.jar', "tensorflow-lite-support-metadata.jar"
}

preBuild.dependsOn downloadLibs

dependencies {
compileOnly 'org.checkerframework:checker-qual:2.5.8'
api 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
api 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
api files("$buildDir/libs/tensorflow-lite-support-metadata.jar")
implementation 'org.apache.commons:commons-compress:1.19'
})");
  return true;
}
|
||||||
|
|
||||||
|
// Writes a minimal AndroidManifest.xml for the generated library module.
// Only the package name is templated; `model_info` itself is unused here
// (the {{PACKAGE}} token is filled in by the CodeWriter).
bool GenerateAndroidManifestContent(CodeWriter* code_writer,
                                    const ModelInfo& model_info) {
  code_writer->Append(R"(<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="{{PACKAGE}}">
</manifest>)");
  return true;
}
|
||||||
|
|
||||||
|
// Writes a markdown usage guide for the generated wrapper: a Java snippet
// showing initialization, input loading (per input tensor), running the
// model, and retrieving each output.
bool GenerateDocContent(CodeWriter* code_writer, const ModelInfo& model_info) {
  code_writer->Append("# {{MODEL_CLASS_NAME}} Usage");
  code_writer->AppendNoNewLine(R"(
```
import {{PACKAGE}}.{{MODEL_CLASS_NAME}};

// 1. Initialize the Model
{{MODEL_CLASS_NAME}} model = null;

try {
model = new {{MODEL_CLASS_NAME}}(context); // android.content.Context
// Create the input container.
{{MODEL_CLASS_NAME}}.Inputs inputs = model.createInputs();
} catch (IOException e) {
e.printStackTrace();
}

if (model != null) {

// 2. Set the inputs)");
  // One load-snippet per input; images get Bitmap/TensorImage examples,
  // everything else a TensorBuffer example.
  for (const auto& t : model_info.inputs) {
    SetCodeWriterWithTensorInfo(code_writer, t);
    if (t.content_type == "image") {
      code_writer->Append(R"(
// Load input tensor "{{NAME}}" from a Bitmap with ARGB_8888 format.
Bitmap bitmap = ...;
inputs.load{{NAME_U}}(bitmap);
// Alternatively, load the input tensor "{{NAME}}" from a TensorImage.
// Check out TensorImage documentation to load other image data structures.
// TensorImage tensorImage = ...;
// inputs.load{{NAME_U}}(tensorImage);)");
    } else {
      code_writer->Append(R"(
// Load input tensor "{{NAME}}" from a TensorBuffer.
// Check out TensorBuffer documentation to load other data structures.
TensorBuffer tensorBuffer = ...;
inputs.load{{NAME_U}}(tensorBuffer);)");
    }
  }
  code_writer->Append(R"(
// 3. Run the model
{{MODEL_CLASS_NAME}}.Outputs outputs = model.run(inputs);)");
  code_writer->Append(R"(
// 4. Retrieve the results)");
  // One getter line per output; labeled outputs are surfaced as a map.
  for (const auto& t : model_info.outputs) {
    SetCodeWriterWithTensorInfo(code_writer, t);
    if (t.associated_axis_label_index >= 0) {
      code_writer->SetTokenValue("WRAPPER_TYPE", "Map<String, Float>");
    }
    code_writer->Append(
        R"( {{WRAPPER_TYPE}} {{NAME}} = outputs.get{{NAME_U}}();)");
  }
  code_writer->Append(R"(}
```)");
  return true;
}
|
||||||
|
|
||||||
|
// Renders the Java wrapper class and returns it as an in-memory file.
// On content-generation failure the error is recorded in `err` and the
// (possibly partial) output is still returned.
GenerationResult::File GenerateWrapperFile(const std::string& module_root,
                                           const ModelInfo& model_info,
                                           ErrorReporter* err) {
  // Destination: <module_root>/src/main/java/<package path>/<ClassName>.java
  const auto wrapper_path =
      JoinPath(JoinPath(JoinPath(module_root, "src/main/java"),
                        ConvertPackageToPath(model_info.package_name)),
               model_info.model_class_name + JAVA_EXT);

  CodeWriter writer(err);
  writer.SetIndentString("  ");
  SetCodeWriterWithModelInfo(&writer, model_info);
  if (!GenerateWrapperFileContent(&writer, model_info, err)) {
    err->Error("Generating Java wrapper content failed.");
  }
  return GenerationResult::File{wrapper_path, writer.ToString()};
}
|
||||||
|
|
||||||
|
// Renders build.gradle for the generated module and returns it in memory.
// Errors are reported through `err`; the output is returned regardless.
GenerationResult::File GenerateBuildGradle(const std::string& module_root,
                                           const ModelInfo& model_info,
                                           ErrorReporter* err) {
  CodeWriter writer(err);
  SetCodeWriterWithModelInfo(&writer, model_info);
  if (!GenerateBuildGradleContent(&writer, model_info)) {
    err->Error("Generating build.gradle failed.");
  }
  return GenerationResult::File{JoinPath(module_root, "build.gradle"),
                                writer.ToString()};
}
|
||||||
|
|
||||||
|
// Renders src/main/AndroidManifest.xml for the generated module.
// Errors are reported through `err`; the output is returned regardless.
GenerationResult::File GenerateAndroidManifest(const std::string& module_root,
                                               const ModelInfo& model_info,
                                               ErrorReporter* err) {
  CodeWriter writer(err);
  SetCodeWriterWithModelInfo(&writer, model_info);
  if (!GenerateAndroidManifestContent(&writer, model_info)) {
    err->Error("Generating AndroidManifest.xml failed.");
  }
  return GenerationResult::File{
      JoinPath(module_root, "src/main/AndroidManifest.xml"),
      writer.ToString()};
}
|
||||||
|
|
||||||
|
// Renders the markdown usage doc for the generated module. The file name is
// the lowercased class name, e.g. "MyModel" -> "mymodel.md".
// Errors are reported through `err`; the output is returned regardless.
GenerationResult::File GenerateDoc(const std::string& module_root,
                                   const ModelInfo& model_info,
                                   ErrorReporter* err) {
  std::string lower = model_info.model_class_name;
  for (size_t i = 0; i < lower.length(); i++) {
    // Cast through unsigned char: calling std::tolower with a negative char
    // value (possible for non-ASCII bytes) is undefined behavior.
    lower[i] =
        static_cast<char>(std::tolower(static_cast<unsigned char>(lower[i])));
  }
  const auto file_path = JoinPath(module_root, lower + ".md");
  CodeWriter code_writer(err);
  SetCodeWriterWithModelInfo(&code_writer, model_info);
  if (!GenerateDocContent(&code_writer, model_info)) {
    err->Error("Generating doc failed.");
  }
  return GenerationResult::File{file_path, code_writer.ToString()};
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Creates a generator that writes all output under `module_root`; the base
// class is default-constructed implicitly.
AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root)
    : module_root_(module_root) {}
|
||||||
|
|
||||||
|
// Generates all module files for a parsed model. Returns an empty result
// (and records an error) when the model carries no TFLite Metadata.
GenerationResult AndroidJavaGenerator::Generate(
    const Model* model, const std::string& package_name,
    const std::string& model_class_name, const std::string& model_asset_path) {
  GenerationResult result;
  const ModelMetadata* metadata = GetMetadataFromModel(model);
  if (metadata == nullptr) {
    err_.Error(
        "Cannot find TFLite Metadata in the model. Codegen will generate "
        "nothing.");
    return result;
  }
  const details_android_java::ModelInfo model_info = CreateModelInfo(
      metadata, package_name, model_class_name, model_asset_path, &err_);
  // Emit every artifact in a fixed order: wrapper class, build file,
  // manifest, then the usage doc.
  for (const auto& generate :
       {&GenerateWrapperFile, &GenerateBuildGradle, &GenerateAndroidManifest,
        &GenerateDoc}) {
    result.files.push_back(generate(module_root_, model_info, &err_));
  }
  return result;
}
|
||||||
|
|
||||||
|
// Convenience overload: parses the raw flatbuffer buffer and delegates to
// the Model* overload.
GenerationResult AndroidJavaGenerator::Generate(
    const char* model_storage, const std::string& package_name,
    const std::string& model_class_name, const std::string& model_asset_path) {
  return Generate(GetModel(model_storage), package_name, model_class_name,
                  model_asset_path);
}
|
||||||
|
|
||||||
|
// Returns the error/warning messages accumulated by the internal reporter
// during previous Generate() calls.
std::string AndroidJavaGenerator::GetErrorMessage() {
  return err_.GetMessage();
}
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace support
|
||||||
|
} // namespace tflite
|
@ -0,0 +1,107 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
|
||||||
|
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||||
|
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace support {
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
namespace details_android_java {
|
||||||
|
|
||||||
|
/// The intermediate data structure for generating code from TensorMetadata.
|
||||||
|
/// Should only be used as const reference when created.
|
||||||
|
/// The intermediate data structure for generating code from TensorMetadata.
/// Should only be used as const reference when created.
/// All fields are brace-initialized so a default-constructed instance honors
/// the documented "-1 means absent" contract instead of holding garbage.
struct TensorInfo {
  std::string name;              // Sanitized tensor name used in identifiers.
  std::string upper_camel_name;  // UpperCamel variant for Java method names.
  std::string content_type;      // E.g. "image" or "feature".
  std::string wrapper_type;      // Java type exposed to wrapper users.
  std::string processor_type;    // TensorProcessor/ImageProcessor type name.
  bool is_input = false;         // True for input tensors, false for outputs.
  /// Optional. Set to -1 if not applicable.
  int normalization_unit = -1;
  /// Optional. Set to -1 if associated_axis_label is empty.
  int associated_axis_label_index = -1;
  /// Optional. Set to -1 if associated_value_label is empty.
  int associated_value_label_index = -1;
};
|
||||||
|
|
||||||
|
/// The intermediate data structure for generating code from ModelMetadata.
|
||||||
|
/// Should only be used as const reference when created.
|
||||||
|
/// The intermediate data structure for generating code from ModelMetadata.
/// Should only be used as const reference when created.
struct ModelInfo {
  std::string package_name;          // Java package of the generated classes.
  std::string model_asset_path;      // Path of the .tflite file in assets.
  std::string model_class_name;      // Name of the generated wrapper class.
  std::string model_versioned_name;  // Name plus version, for display.
  std::vector<TensorInfo> inputs;    // One entry per model input tensor.
  std::vector<TensorInfo> outputs;   // One entry per model output tensor.
};
|
||||||
|
|
||||||
|
} // namespace details_android_java
|
||||||
|
|
||||||
|
// File extension appended to the generated wrapper class file name.
constexpr char JAVA_EXT[] = ".java";
|
||||||
|
|
||||||
|
/// Generates Android supporting codes and modules (in Java) based on TFLite
|
||||||
|
/// metadata.
|
||||||
|
/// Generates Android supporting codes and modules (in Java) based on TFLite
/// metadata.
class AndroidJavaGenerator : public CodeGenerator {
 public:
  /// Creates an AndroidJavaGenerator.
  /// Args:
  /// - module_root: The root of destination Java module.
  explicit AndroidJavaGenerator(const std::string& module_root);

  /// Generates files. Returns the file paths and contents.
  /// Args:
  /// - model: The TFLite model with Metadata filled.
  /// - package_name: The name of the Java package which generated classes
  /// belong to.
  /// - model_class_name: A readable name of the generated wrapper class, such
  /// as "ImageClassifier", "MobileNetV2" or "MyModel".
  /// - model_asset_path: The relevant path to the model file in the asset.
  // TODO(b/141225157): Automatically generate model_class_name.
  GenerationResult Generate(const Model* model, const std::string& package_name,
                            const std::string& model_class_name,
                            const std::string& model_asset_path);

  /// Generates files and returns the file paths and contents.
  /// It's mostly identical with the previous one, but the model here is
  /// provided as binary flatbuffer content without parsing.
  GenerationResult Generate(const char* model_storage,
                            const std::string& package_name,
                            const std::string& model_class_name,
                            const std::string& model_asset_path);

  /// Returns the error messages accumulated during previous Generate() calls.
  std::string GetErrorMessage();

 private:
  const std::string module_root_;  // Destination module root directory.
  ErrorReporter err_;              // Collects errors/warnings during codegen.
};
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace support
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
|
179
tensorflow/lite/experimental/support/codegen/code_generator.cc
Normal file
179
tensorflow/lite/experimental/support/codegen/code_generator.cc
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
|
||||||
|
|
||||||
|
#include <cctype>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||||
|
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace support {
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/// Deduplicates names in place by appending 1-based occurrence indexes to
/// duplicated names, e.g. {"a", "b", "a"} -> {"a1", "b", "a2"}.
/// Names that occur only once are left untouched.
void ResolveConflictedNamesByAddingIndex(std::vector<std::string>* names_ptr) {
  auto& names = *names_ptr;
  // Occurrence count so far, keyed by the original (unsuffixed) name.
  std::unordered_map<std::string, int> occurrences;
  // Position of each name's first occurrence; it is suffixed with "1" at the
  // end, but only if the name turned out to be duplicated.
  std::unordered_map<std::string, size_t> first_appearance;
  for (size_t i = 0; i < names.size(); i++) {
    auto it = occurrences.find(names[i]);
    if (it == occurrences.end()) {
      occurrences[names[i]] = 1;
      first_appearance[names[i]] = i;
    } else {
      ++it->second;
      names[i].append(std::to_string(it->second));
    }
  }
  // Iterate by const reference: copying the map's pair per iteration was a
  // needless string copy (clang-tidy performance-for-range-copy).
  for (const auto& entry : first_appearance) {
    if (occurrences[entry.first] > 1) {
      names[entry.second].append("1");
    }
  }
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Default constructor; the base strategy class carries no state.
CodeGenerator::CodeGenerator() {}
|
||||||
|
|
||||||
|
bool CodeGenerator::VerifyMetadata(const ModelMetadata* metadata,
|
||||||
|
ErrorReporter* err) {
|
||||||
|
if (metadata == nullptr) {
|
||||||
|
err->Error("Loading nullptr is not allowed");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (metadata->subgraph_metadata()->size() != 1) {
|
||||||
|
err->Error("Only exact 1 subgraph is supported");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<std::string>, std::vector<std::string>>
|
||||||
|
CodeGenerator::NameInputsAndOutputs(const TensorMetadataList* inputs,
|
||||||
|
const TensorMetadataList* outputs) {
|
||||||
|
std::vector<std::string> input_names;
|
||||||
|
std::vector<std::string> output_names;
|
||||||
|
if (inputs != nullptr) {
|
||||||
|
input_names.reserve(inputs->size());
|
||||||
|
for (const auto* tensor : *inputs) {
|
||||||
|
input_names.push_back(NameTensor(*tensor, "input"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (outputs != nullptr) {
|
||||||
|
output_names.reserve(outputs->size());
|
||||||
|
for (const auto* tensor : *outputs) {
|
||||||
|
output_names.push_back(NameTensor(*tensor, "output"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Solve conflict
|
||||||
|
ResolveConflictedInputAndOutputNames(&input_names, &output_names);
|
||||||
|
return std::make_pair(input_names, output_names);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitizes `name` into a valid lower_snake identifier:
// - lowercases every letter;
// - replaces every character that is neither alphanumeric nor '_' with '_';
// - strips leading underscores;
// - prefixes "tensor_" when the first character is not a letter.
// Returns "" when nothing remains after stripping.
std::string CodeGenerator::ConvertToValidName(const std::string& name) {
  std::string result = name;
  // Cast through unsigned char before any <cctype> call: passing a negative
  // plain-char value (possible for non-ASCII bytes) is undefined behavior.
  for (size_t i = 0; i < result.size(); i++) {
    const int lowered = std::tolower(static_cast<unsigned char>(result[i]));
    if (lowered == '_' || std::isalnum(lowered)) {
      result[i] = static_cast<char>(lowered);
    } else {
      result[i] = '_';
    }
  }
  // Remove leading underscores; an all-underscore name becomes empty.
  const auto first_valid = result.find_first_not_of('_');
  if (first_valid == std::string::npos) {
    return "";
  }
  result.erase(0, first_valid);
  // First char should be alphabetic for a valid identifier.
  if (std::isalpha(static_cast<unsigned char>(result[0]))) {
    return result;
  }
  return "tensor_" + result;
}
|
||||||
|
|
||||||
|
std::string CodeGenerator::NameTensor(const TensorMetadata& tensor,
|
||||||
|
const std::string& default_name) {
|
||||||
|
if (tensor.name() != nullptr && tensor.name()->size() > 0) {
|
||||||
|
// TODO(b/141225157) Validate tensor name. It should be in lower case.
|
||||||
|
auto suggested_name = ConvertToValidName(tensor.name()->str());
|
||||||
|
if (!suggested_name.empty()) {
|
||||||
|
return suggested_name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto* content = tensor.content();
|
||||||
|
if (content == nullptr || content->content_properties() == nullptr) {
|
||||||
|
return default_name;
|
||||||
|
}
|
||||||
|
switch (content->content_properties_type()) {
|
||||||
|
case ContentProperties_ImageProperties:
|
||||||
|
return "image";
|
||||||
|
case ContentProperties_FeatureProperties:
|
||||||
|
return "feature";
|
||||||
|
default:
|
||||||
|
return default_name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CodeGenerator::ResolveConflictedInputAndOutputNames(
|
||||||
|
std::vector<std::string>* inputs, std::vector<std::string>* outputs) {
|
||||||
|
std::unordered_set<std::string> io_conflict;
|
||||||
|
auto& input_names = *inputs;
|
||||||
|
auto& output_names = *outputs;
|
||||||
|
for (const auto input : input_names) {
|
||||||
|
if (io_conflict.find(input) != io_conflict.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (const auto output : output_names) {
|
||||||
|
if (input == output) {
|
||||||
|
io_conflict.insert(input);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < input_names.size(); i++) {
|
||||||
|
if (io_conflict.find(input_names[i]) != io_conflict.end()) {
|
||||||
|
input_names[i] = "input_" + input_names[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < output_names.size(); i++) {
|
||||||
|
if (io_conflict.find(output_names[i]) != io_conflict.end()) {
|
||||||
|
output_names[i] = "output_" + output_names[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 2. Second, add index if input[i] == input[j]
|
||||||
|
ResolveConflictedNamesByAddingIndex(&input_names);
|
||||||
|
ResolveConflictedNamesByAddingIndex(&output_names);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace support
|
||||||
|
} // namespace tflite
|
@ -0,0 +1,80 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_
|
||||||
|
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_
|
||||||
|
|
||||||
|
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace support {
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
/// The result of one code-generation run: the set of files to write out.
struct GenerationResult {
  struct File {
    std::string path;     // Destination path, relative to the module root.
    std::string content;  // Full text content of the file.
  };
  std::vector<File> files;
};
|
||||||
|
|
||||||
|
/// Defines language-independent codegen strategies, like class naming, .etc.
|
||||||
|
/// Should not be used directly.
|
||||||
|
/// Defines language-independent codegen strategies, like class naming, .etc.
/// Should not be used directly.
class CodeGenerator {
 public:
  CodeGenerator();

  using TensorMetadataList =
      typename flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>;

  virtual ~CodeGenerator() {}

  // Strategies.
  /// Names all the IO tensors. It's useful when they don't have names, or the
  /// names have conflicts. We have to name every tensor for code generation.
  // TODO(b/141225157): Add reserved keywords check.
  static std::pair<std::vector<std::string>, std::vector<std::string>>
  NameInputsAndOutputs(const TensorMetadataList* inputs,
                       const TensorMetadataList* outputs);

  /// Loads a metadata for code generation.
  /// Returns false if the metadata is not good for generation.
  static bool VerifyMetadata(const ModelMetadata* metadata, ErrorReporter* err);

 protected:
  /// Converts a name into a valid form. Rules:
  /// - lower all letters.
  /// - replace all non alphabet nor numeric characters with underscores.
  /// - remove prefix underscores.
  /// - add prefix if the leading character is a number.
  /// Returns empty string if not possible.
  static std::string ConvertToValidName(const std::string& name);

  /// Names one tensor: the sanitized metadata name when usable, else a name
  /// derived from its content type, else `default_name`.
  static std::string NameTensor(const TensorMetadata& tensor,
                                const std::string& default_name);

  /// Disambiguates names across and within the two lists, in place.
  static void ResolveConflictedInputAndOutputNames(
      std::vector<std::string>* input, std::vector<std::string>* output);
};
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace support
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_
|
@ -0,0 +1,126 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace support {
|
||||||
|
namespace codegen {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::ElementsAreArray;
|
||||||
|
|
||||||
|
class CodeGeneratorTest : public ::testing::Test {
 public:
  // Thin subclass whose only job is to re-export CodeGenerator's protected
  // static helpers so the tests below can call them directly.
  class TestingCodeGenerator : public CodeGenerator {
   public:
    explicit TestingCodeGenerator() : CodeGenerator() {}

    // Make tested method public.
    static std::string ConvertToValidName(const std::string& name) {
      return CodeGenerator::ConvertToValidName(name);
    }
    static void ResolveConflictedInputAndOutputNames(
        std::vector<std::string>* input, std::vector<std::string>* output) {
      CodeGenerator::ResolveConflictedInputAndOutputNames(input, output);
    }
  };
};
|
||||||
|
|
||||||
|
// ConvertToValidName lower-cases every letter.
TEST_F(CodeGeneratorTest, UpperCasesShouldLower) {
  // EXPECT_EQ is the canonical assertion for exact string equality;
  // EXPECT_THAT with a bare string relied on implicit matcher conversion.
  EXPECT_EQ(TestingCodeGenerator::ConvertToValidName("AlphaBetCOOL"),
            "alphabetcool");
}
|
||||||
|
|
||||||
|
// Non-alphanumeric characters are replaced with underscores.
TEST_F(CodeGeneratorTest, NonAlphaNumShouldReplace) {
  // EXPECT_EQ is the canonical assertion for exact string equality.
  EXPECT_EQ(TestingCodeGenerator::ConvertToValidName("A+=B C\t"), "a__b_c_");
}
|
||||||
|
|
||||||
|
// Underscores produced at the front of the name are stripped.
TEST_F(CodeGeneratorTest, NoLeadingUnderscore) {
  // EXPECT_EQ is the canonical assertion for exact string equality.
  EXPECT_EQ(TestingCodeGenerator::ConvertToValidName("+KAI Z"), "kai_z");
}
|
||||||
|
|
||||||
|
// A name starting with a digit is prefixed so it is a valid identifier.
TEST_F(CodeGeneratorTest, NoLeadingNumbers) {
  // EXPECT_EQ is the canonical assertion for exact string equality.
  EXPECT_EQ(TestingCodeGenerator::ConvertToValidName("3000 Cool Tensors"),
            "tensor_3000_cool_tensors");
}
|
||||||
|
|
||||||
|
// Distinct input/output names pass through unchanged.
TEST_F(CodeGeneratorTest, TestSimpleIONames) {
  std::vector<std::string> inputs = {"image"};
  std::vector<std::string> outputs = {"output"};
  TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
  EXPECT_THAT(inputs, ElementsAreArray({"image"}));
  EXPECT_THAT(outputs, ElementsAreArray({"output"}));
}
|
||||||
|
|
||||||
|
// A name shared across input and output gets "input_"/"output_" prefixes.
TEST_F(CodeGeneratorTest, TestIOConflict) {
  std::vector<std::string> inputs = {"image"};
  std::vector<std::string> outputs = {"image"};
  TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
  EXPECT_THAT(inputs, ElementsAreArray({"input_image"}));
  EXPECT_THAT(outputs, ElementsAreArray({"output_image"}));
}
|
||||||
|
|
||||||
|
// Duplicates within one list are disambiguated with numeric suffixes.
TEST_F(CodeGeneratorTest, TestInternalConflict) {
  std::vector<std::string> inputs = {"image", "image"};
  std::vector<std::string> outputs = {"output"};
  TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
  EXPECT_THAT(inputs, ElementsAreArray({"image1", "image2"}));
  EXPECT_THAT(outputs, ElementsAreArray({"output"}));
}
|
||||||
|
|
||||||
|
// A name conflicting both internally and across lists gets a prefix AND a
// numeric suffix.
TEST_F(CodeGeneratorTest, TestAllConflictNTo1) {
  std::vector<std::string> inputs = {"image", "image"};
  std::vector<std::string> outputs = {"image"};
  TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
  EXPECT_THAT(inputs, ElementsAreArray({"input_image1", "input_image2"}));
  EXPECT_THAT(outputs, ElementsAreArray({"output_image"}));
}
|
||||||
|
|
||||||
|
// Mixed conflicts: numeric suffixes only where a list has duplicates,
// "input_"/"output_" prefixes only where the two lists clash.
TEST_F(CodeGeneratorTest, TestAllConflict) {
  std::vector<std::string> inputs = {"image", "audio", "image", "audio",
                                     "audio"};
  std::vector<std::string> outputs = {"image", "image", "audio", "feature",
                                      "feature"};
  TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
  EXPECT_THAT(inputs,
              ElementsAreArray({"input_image1", "input_audio1", "input_image2",
                                "input_audio2", "input_audio3"}));
  EXPECT_THAT(outputs,
              ElementsAreArray({"output_image1", "output_image2",
                                "output_audio", "feature1", "feature2"}));
}
|
||||||
|
|
||||||
|
// Same as TestAllConflict with the roles of the two lists swapped; the
// resolution must be symmetric.
TEST_F(CodeGeneratorTest, TestAllConflictReversed) {
  std::vector<std::string> inputs = {"image", "image", "audio", "feature",
                                     "feature"};
  std::vector<std::string> outputs = {"image", "audio", "image", "audio",
                                      "audio"};
  TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
  EXPECT_THAT(inputs,
              ElementsAreArray({"input_image1", "input_image2", "input_audio",
                                "feature1", "feature2"}));
  EXPECT_THAT(outputs, ElementsAreArray({"output_image1", "output_audio1",
                                         "output_image2", "output_audio2",
                                         "output_audio3"}));
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace support
|
||||||
|
} // namespace tflite
|
@ -0,0 +1,92 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/metadata_helper.h"
|
||||||
|
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||||
|
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace support {
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
constexpr char BUFFER_KEY[] = "TFLITE_METADATA";
|
||||||
|
const ModelMetadata* GetMetadataFromModel(const Model* model) {
  // The TFLite metadata flatbuffer is stored in the model buffer whose
  // metadata entry is named BUFFER_KEY ("TFLITE_METADATA"). Returns nullptr
  // when the model carries no such entry.
  if (model->metadata() == nullptr) {
    return nullptr;
  }
  for (auto i = 0; i < model->metadata()->size(); i++) {
    if (model->metadata()->Get(i)->name()->str() == BUFFER_KEY) {
      const auto buffer_index = model->metadata()->Get(i)->buffer();
      // NOTE(review): assumes the referenced buffer exists and has data; a
      // malformed model could make data() null here — confirm upstream
      // validation.
      const auto* buffer = model->buffers()->Get(buffer_index)->data()->data();
      return GetModelMetadata(buffer);
    }
  }
  return nullptr;
}
|
||||||
|
|
||||||
|
int FindAssociatedFile(const TensorMetadata* metadata,
                       const AssociatedFileType file_type,
                       const std::string& tensor_identifier,
                       ErrorReporter* err) {
  // Linear scan for the first associated file whose type matches `file_type`.
  // Returns its index, or -1 when the tensor has no matching file.
  int result = -1;
  if (metadata->associated_files() == nullptr ||
      metadata->associated_files()->size() == 0) {
    return result;
  }
  for (int i = 0; i < metadata->associated_files()->size(); i++) {
    const auto* file_metadata = metadata->associated_files()->Get(i);
    if (file_metadata->type() == file_type) {
      if (result >= 0) {
        // A match was already recorded: keep the first one and warn (once per
        // extra matching file).
        err->Warning(
            "Multiple associated file of type %d found on tensor %s. Only the "
            "first one will be used.",
            file_type, tensor_identifier.c_str());
        continue;
      }
      result = i;
    }
  }
  return result;
}
|
||||||
|
|
||||||
|
int FindNormalizationUnit(const TensorMetadata* metadata,
|
||||||
|
const std::string& tensor_identifier,
|
||||||
|
ErrorReporter* err) {
|
||||||
|
int result = -1;
|
||||||
|
if (metadata->process_units() == nullptr ||
|
||||||
|
metadata->process_units()->size() == 0) {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < metadata->process_units()->size(); i++) {
|
||||||
|
const auto* process_uint = metadata->process_units()->Get(i);
|
||||||
|
if (process_uint->options_type() ==
|
||||||
|
ProcessUnitOptions_NormalizationOptions) {
|
||||||
|
if (result >= 0) {
|
||||||
|
err->Warning(
|
||||||
|
"Multiple normalization unit found in tensor %s. Only the first "
|
||||||
|
"one will be effective.",
|
||||||
|
tensor_identifier.c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
result = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace support
|
||||||
|
} // namespace tflite
|
@ -0,0 +1,51 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_
|
||||||
|
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||||
|
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace support {
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
/// Parses a ModelMetadata out from a Model. The returned ModelMetadata's
|
||||||
|
/// lifetime is scoped by the model.
|
||||||
|
/// Returns nullptr if we cannot find any metadata.
|
||||||
|
const ModelMetadata* GetMetadataFromModel(const Model* model);
|
||||||
|
|
||||||
|
/// Finds an associated file of a certain type from a TensorMetadata. If
/// multiple files meet the criteria, only the first one is used. If no file
/// meets the criteria, -1 will be returned.
|
||||||
|
int FindAssociatedFile(const TensorMetadata* metadata,
|
||||||
|
const AssociatedFileType file_type,
|
||||||
|
const std::string& tensor_identifier,
|
||||||
|
ErrorReporter* err);
|
||||||
|
|
||||||
|
/// Find the first normalization unit. If none, return -1.
|
||||||
|
int FindNormalizationUnit(const TensorMetadata* metadata,
|
||||||
|
const std::string& tensor_identifier,
|
||||||
|
ErrorReporter* err);
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace support
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_
|
38
tensorflow/lite/experimental/support/codegen/python/BUILD
Normal file
38
tensorflow/lite/experimental/support/codegen/python/BUILD
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = [
|
||||||
|
"//visibility:public",
|
||||||
|
],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
pybind_extension(
|
||||||
|
name = "_pywrap_codegen",
|
||||||
|
srcs = [
|
||||||
|
"codegen_lib.cc",
|
||||||
|
],
|
||||||
|
features = ["-use_header_modules"],
|
||||||
|
module_name = "_pywrap_codegen",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite/experimental/support/codegen:android_java_generator",
|
||||||
|
"//tensorflow/lite/experimental/support/codegen:code_generator",
|
||||||
|
"//tensorflow/python:pybind11_lib",
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_binary(
|
||||||
|
name = "codegen",
|
||||||
|
srcs = [
|
||||||
|
"codegen.py",
|
||||||
|
],
|
||||||
|
python_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":_pywrap_codegen",
|
||||||
|
"@absl_py//absl:app",
|
||||||
|
"@absl_py//absl/flags",
|
||||||
|
"@absl_py//absl/logging",
|
||||||
|
],
|
||||||
|
)
|
@ -0,0 +1,96 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Generates Android Java sources from a TFLite model with metadata."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
from absl import logging
|
||||||
|
|
||||||
|
from tensorflow.lite.experimental.support.codegen.python import _pywrap_codegen
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
flags.DEFINE_string('model', None, 'Path to model (.tflite) flatbuffer file.')
|
||||||
|
flags.DEFINE_string('destination', None, 'Path of destination of generation.')
|
||||||
|
flags.DEFINE_string('package_name', 'org.tensorflow.lite.support',
|
||||||
|
'Name of generated java package to put the wrapper class.')
|
||||||
|
flags.DEFINE_string(
|
||||||
|
'model_class_name', 'MyModel',
|
||||||
|
'Name of generated wrapper class (should not contain package name).')
|
||||||
|
flags.DEFINE_string(
|
||||||
|
'model_asset_path', '',
|
||||||
|
'(Optional) Path to the model in generated assets/ dir. If not set, '
|
||||||
|
'generator will use base name of input model.'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_buffer(path):
  """Reads the model file at `path` and returns its raw bytes.

  Logs an error (but still attempts the read) when `path` is not a file.
  """
  if not os.path.isfile(path):
    logging.error('Cannot find model at path %s.', path)
  with open(path, 'rb') as model_file:
    return model_file.read()
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_directory_for_file(file_path):
  """Ensures the parent directory of `file_path` exists.

  Creates the directory (including intermediate directories) when missing;
  logs an error if the target path exists but is not a directory.

  Args:
    file_path: path of a file about to be written.
  """
  target_dir = os.path.dirname(file_path)
  if not target_dir:
    # Bare filename in the current directory: nothing to create. (Without
    # this guard, os.makedirs('') raises FileNotFoundError.)
    return
  if os.path.isdir(target_dir):
    return
  if os.path.exists(target_dir):
    # A non-directory entry occupies the target path; creation would fail.
    logging.error('Cannot write to %s', target_dir)
    return
  # exist_ok avoids a race if the directory appears between check and create.
  os.makedirs(target_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
  # CLI entry point: generates Java wrapper sources from a TFLite model with
  # metadata, writes them under FLAGS.destination, and copies the model into
  # the Android assets directory.
  if len(argv) > 1:
    logging.error('None flag arguments found: [%s]', ', '.join(argv[1:]))

  codegen = _pywrap_codegen.AndroidJavaGenerator(FLAGS.destination)
  model_buffer = get_model_buffer(FLAGS.model)
  model_asset_path = FLAGS.model_asset_path
  if not model_asset_path:
    # Default the in-APK asset path to the model file's base name.
    model_asset_path = os.path.basename(FLAGS.model)
  result = codegen.generate(model_buffer, FLAGS.package_name,
                            FLAGS.model_class_name, model_asset_path)
  # Errors/warnings accumulated by the C++ generator are surfaced here.
  error_message = codegen.get_error_message().strip()
  if error_message:
    logging.error(error_message)
  if not result.files:
    # No files produced means generation failed outright; stop before
    # touching the destination tree.
    logging.error('Generation failed!')
    return

  for each in result.files:
    prepare_directory_for_file(each.path)
    with open(each.path, 'w') as f:
      f.write(each.content)

  logging.info('Generation succeeded!')
  # Place the model itself where the generated wrapper expects to load it.
  model_asset_path = os.path.join(FLAGS.destination, 'src/main/assets',
                                  model_asset_path)
  prepare_directory_for_file(model_asset_path)
  shutil.copy(FLAGS.model, model_asset_path)
  logging.info('Model copied into assets!')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
flags.mark_flag_as_required('model')
|
||||||
|
flags.mark_flag_as_required('destination')
|
||||||
|
app.run(main)
|
@ -0,0 +1,49 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "include/pybind11/detail/common.h"
|
||||||
|
#include "include/pybind11/pybind11.h"
|
||||||
|
#include "include/pybind11/pytypes.h"
|
||||||
|
#include "include/pybind11/stl.h"
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h"
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace support {
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
// Shorthand for pybind11's overload-resolution helper, used below to select
// a specific Generate overload.
template <typename... Args>
using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;

// Python bindings consumed by codegen.py: exposes AndroidJavaGenerator plus
// the GenerationResult value types.
PYBIND11_MODULE(_pywrap_codegen, m) {
  pybind11::class_<AndroidJavaGenerator>(m, "AndroidJavaGenerator")
      .def(pybind11::init<const std::string &>())
      .def("generate",
           // Bind the (model buffer, package name, class name, asset path)
           // overload of Generate.
           overload_cast_<const char *, const std::string &,
                          const std::string &, const std::string &>()(
               &AndroidJavaGenerator::Generate))
      .def("get_error_message", &AndroidJavaGenerator::GetErrorMessage);
  pybind11::class_<GenerationResult>(m, "GenerationResult")
      .def(pybind11::init<>())
      .def_readwrite("files", &GenerationResult::files);
  pybind11::class_<GenerationResult::File>(m, "GenerationResultFile")
      .def(pybind11::init<>())
      .def_readwrite("path", &GenerationResult::File::path)
      .def_readwrite("content", &GenerationResult::File::content);
}
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace support
|
||||||
|
} // namespace tflite
|
194
tensorflow/lite/experimental/support/codegen/utils.cc
Normal file
194
tensorflow/lite/experimental/support/codegen/utils.cc
Normal file
@ -0,0 +1,194 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/utils.h"

#include <cctype>
#include <cstdarg>
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace support {
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
int ErrorReporter::Warning(const char* format, ...) {
|
||||||
|
va_list args;
|
||||||
|
va_start(args, format);
|
||||||
|
return Report("[WARN] ", format, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
int ErrorReporter::Error(const char* format, ...) {
|
||||||
|
va_list args;
|
||||||
|
va_start(args, format);
|
||||||
|
return Report("[ERROR] ", format, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
int ErrorReporter::Report(const char* prefix, const char* format,
                          va_list args) {
  // Shared implementation behind Warning/Error: formats into a fixed 1 KiB
  // buffer (longer messages are truncated by vsnprintf, whose return value
  // still reports the untruncated length) and appends one line to buffer_.
  char buf[1024];
  int formatted = vsnprintf(buf, sizeof(buf), format, args);
  buffer_ << prefix << buf << std::endl;
  return formatted;
}
|
||||||
|
|
||||||
|
std::string ErrorReporter::GetMessage() {
  // Returns all accumulated log lines and resets the stream, so each call
  // yields only the messages reported since the previous call.
  std::string value = buffer_.str();
  buffer_.str("");
  return value;
}
|
||||||
|
|
||||||
|
// Starts at zero indentation; the error reporter is borrowed, not owned.
CodeWriter::CodeWriter(ErrorReporter* err) : indent_(0), err_(err) {}
|
||||||
|
|
||||||
|
void CodeWriter::SetTokenValue(const std::string& token,
                               const std::string& value) {
  // Overwrites any value previously bound to `token`.
  value_map_[token] = value;
}
|
||||||
|
|
||||||
|
const std::string CodeWriter::GetTokenValue(const std::string& token) const {
  // Looks up the substitution value for `token`; an unknown token is reported
  // as an internal error and expands to the empty string.
  auto iter = value_map_.find(token);
  if (iter == value_map_.end()) {
    // Typically only Code Generator's call this function (or `Append`). It's
    // their duty to make sure the token is valid, and requesting for an invalid
    // token implicits flaws in the code generation logic.
    err_->Error("Internal: Cannot find value with token '%s'", token.c_str());
    return "";
  }
  return iter->second;
}
|
||||||
|
|
||||||
|
void CodeWriter::SetIndentString(const std::string& indent_str) {
  // One unit of indentation, e.g. two or four spaces depending on the target
  // language's convention.
  indent_str_ = indent_str;
}
|
||||||
|
|
||||||
|
// Deepens indentation by one unit.
void CodeWriter::Indent() { indent_++; }
|
||||||
|
|
||||||
|
// Shallows indentation by one unit. No underflow guard: callers must pair
// every Outdent() with a prior Indent().
void CodeWriter::Outdent() { indent_--; }
|
||||||
|
|
||||||
|
std::string CodeWriter::GenerateIndent() const {
  // Builds the current indentation prefix: indent_str_ repeated indent_
  // times.
  std::string res;
  res.reserve(indent_str_.size() * indent_);
  for (int i = 0; i < indent_; i++) {
    res.append(indent_str_);
  }
  return res;
}
|
||||||
|
|
||||||
|
// Expands `text` (see AppendInternal) and terminates it with a newline.
void CodeWriter::Append(const std::string& text) { AppendInternal(text, true); }
|
||||||
|
|
||||||
|
// Same expansion as Append, but without the trailing newline.
void CodeWriter::AppendNoNewLine(const std::string& text) {
  AppendInternal(text, false);
}
|
||||||
|
|
||||||
|
void CodeWriter::AppendInternal(const std::string& text, bool newline) {
  // Core template expansion: copies `text` into buffer_, replacing every
  // {{token}} with its registered value and inserting the current indent at
  // the start of each non-blank line. Statement order matters here; the code
  // is kept byte-identical and only commented.
  // Prefix indent
  if ((buffer_.empty()  // nothing in the buffer
       || buffer_.back() == '\n')  // is on new line
      // is writing on current line
      && (!text.empty() && text[0] != '\n' && text[0] != '\r')) {
    buffer_.append(GenerateIndent());
  }
  // State machine variables
  bool in_token = false;
  int i = 0;
  // Rough memory reserve
  buffer_.reserve(buffer_.size() + text.size());
  std::string token_buffer;
  // A simple LL1 analysis
  while (i < text.size()) {
    char cur = text[i];
    char cur_next = i == text.size() - 1 ? '\0' : text[i + 1];  // Set guardian
    if (in_token == false) {
      if (cur == '{' && cur_next == '{') {  // Enter token
        in_token = true;
        i += 2;
      } else if (cur == '\n') {  // We need to apply global indent here
        buffer_.push_back(cur);
        // Blank lines (next char is another newline/CR) get no indent.
        if (cur_next != '\0' && cur_next != '\n' && cur_next != '\r') {
          buffer_.append(GenerateIndent());
        }
        i += 1;
      } else {
        buffer_.push_back(cur);
        i += 1;
      }
    } else {
      if (cur == '}' && cur_next == '}') {  // Close token
        in_token = false;
        // Unknown tokens report an error in GetTokenValue and expand to "".
        const auto value = GetTokenValue(token_buffer);
        buffer_.append(value);
        token_buffer.clear();
        i += 2;
      } else {
        token_buffer.push_back(cur);
        i += 1;
      }
    }
  }
  if (!token_buffer.empty()) {
    // Typically only Code Generator's call this function. It's
    // their duty to make sure the code (or template) has valid syntax, and
    // unclosed "{{...}}" implicits severe error in the template.
    err_->Error("Internal: Invalid template: {{token}} is not closed.");
  }
  if (newline) {
    buffer_.push_back('\n');
  }
}
|
||||||
|
|
||||||
|
// Appending the empty string still emits Append's trailing '\n'.
void CodeWriter::NewLine() { Append(""); }
|
||||||
|
|
||||||
|
// Deletes the last `n` characters from the stream; if fewer exist, deletes
// all of them. Non-positive `n` is a no-op.
void CodeWriter::Backspace(int n) {
  if (n <= 0) {
    // Guard: in the original signed/unsigned comparison a negative `n` was
    // converted to a huge size_t, which silently wiped the whole buffer.
    return;
  }
  const auto count = static_cast<std::string::size_type>(n);
  buffer_.resize(buffer_.size() > count ? buffer_.size() - count : 0);
}
|
||||||
|
|
||||||
|
// Returns a copy of everything written so far.
std::string CodeWriter::ToString() const { return buffer_; }
|
||||||
|
|
||||||
|
// True when nothing has been written (or everything was cleared).
bool CodeWriter::IsStreamEmpty() const { return buffer_.empty(); }
|
||||||
|
|
||||||
|
void CodeWriter::Clear() {
  // Resets output, token values, and indentation level; the indent unit
  // string (indent_str_) is intentionally kept.
  buffer_.clear();
  value_map_.clear();
  indent_ = 0;
}
|
||||||
|
|
||||||
|
// Converts "foo_bar_name" to "fooBarName": underscores are dropped and the
// character following each underscore is upper-cased. Assumes `s` is already
// snake case (see header).
std::string SnakeCaseToCamelCase(const std::string& s) {
  std::string t;
  t.reserve(s.length());
  size_t i = 0;
  // Note: Use simple string += for simplicity.
  bool cap = false;
  while (i < s.size()) {
    const char c = s[i++];
    if (c == '_') {
      cap = true;
    } else if (cap) {
      // Cast through unsigned char: passing a possibly-negative char to
      // toupper is undefined behavior.
      t += static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
      cap = false;
    } else {
      t += c;
    }
  }
  return t;
}
|
||||||
|
|
||||||
|
// Joins `a` and `b` with exactly one '/' between them: one trailing slash is
// trimmed from `a` and one leading slash from `b`. An empty `a` yields `b`
// unchanged.
std::string JoinPath(const std::string& a, const std::string& b) {
  if (a.empty()) {
    return b;
  }
  const bool trim_tail = a.back() == '/';
  std::string joined = trim_tail ? a.substr(0, a.size() - 1) : a;
  joined += '/';
  const bool trim_head = !b.empty() && b.front() == '/';
  joined.append(trim_head ? b.substr(1) : b);
  return joined;
}
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace support
|
||||||
|
} // namespace tflite
|
127
tensorflow/lite/experimental/support/codegen/utils.h
Normal file
127
tensorflow/lite/experimental/support/codegen/utils.h
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_
|
||||||
|
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_
|
||||||
|
|
||||||
|
#include <cstdarg>
#include <map>
#include <sstream>
#include <string>
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace support {
|
||||||
|
namespace codegen {
|
||||||
|
|
||||||
|
/// Collects runtime error logs which could be showed later.
// TODO(b/150538286): Consider a better mechanism to simplify callsite code.
class ErrorReporter {
 public:
  /// Formats and records a warning message. Returns the formatted length.
  int Warning(const char* format, ...);
  /// Formats and records an error message. Returns the formatted length.
  int Error(const char* format, ...);
  /// Returns all messages recorded so far and clears the internal buffer.
  std::string GetMessage();

 private:
  // Shared implementation: prepends `prefix` and appends one line to buffer_.
  int Report(const char* prefix, const char* format, va_list args);
  std::stringstream buffer_;
};
|
||||||
|
|
||||||
|
/// Implements basic code generating with text templates.
///
/// It could accept code templates and concatenate them into complete codes. A
/// template could contain named values.
///
/// Example code:
///   CodeWriter code;
///   code.SetValue("NAME", "Foo");
///   code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }");
///   code.SetValue("NAME", "Bar");
///   code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }");
///
/// Output:
///   void Foo() { printf("%s", "Foo"); }
///   void Bar() { printf("%s", "Bar"); }
class CodeWriter {
 public:
  /// The reporter receives template/token errors; it is borrowed, not owned.
  explicit CodeWriter(ErrorReporter* err);
  /// Sets value to a token. When generating code with template, a string in a
  /// pair of {{ and }} will be regarded as a token and replaced with the
  /// corresponding value in code generation.
  /// It rewrites if the token already has a value.
  void SetTokenValue(const std::string& token, const std::string& value);

  /// Gets the current value set on the given token.
  const std::string GetTokenValue(const std::string& token) const;

  /// Sets the unit indent string. For example, in Java it should be " ".
  void SetIndentString(const std::string& indent);

  /// Increases the indent by a unit (the string set in SetIndentString).
  void Indent();

  /// Decreases the indent by a unit (the string set in SetIndentString).
  void Outdent();

  /// Generates the indentation string.
  std::string GenerateIndent() const;

  /// Appends a piece of template codes to the stream. Every named value will be
  /// replaced via the real value. A new line will always be appended at the
  /// end.
  void Append(const std::string& text);

  /// Appends a piece of template codes to the stream. Same with `Append`, but a
  /// new line will not be appended at the end.
  void AppendNoNewLine(const std::string& text);

  /// Appends a new line to the stream.
  void NewLine();

  /// Deletes the last N characters in the stream. If the stream has less than N
  /// characters, deletes all.
  void Backspace(int n);

  /// Returns a copy of everything written so far.
  std::string ToString() const;

  /// Checks if the internal string stream is empty. Note: This method has
  /// overhead.
  bool IsStreamEmpty() const;

  /// Clears all the internal string stream and value map.
  void Clear();

 private:
  // Shared expansion routine behind Append/AppendNoNewLine.
  void AppendInternal(const std::string& text, bool newline);

  std::string indent_str_;
  int indent_;

  std::map<std::string, std::string> value_map_;
  std::string buffer_;

  ErrorReporter* err_;
};
|
||||||
|
|
||||||
|
/// Converts foo_bar_name to fooBarName. It's the caller's duty to make sure
/// the given string "s" is already in snake case; otherwise unexpected
/// behavior may occur.
|
||||||
|
std::string SnakeCaseToCamelCase(const std::string& s);
|
||||||
|
|
||||||
|
/// Joins 2 parts of file path into one, connected by unix path seperator '/'.
|
||||||
|
/// It's callers duty to ensure the two parts are valid.
|
||||||
|
std::string JoinPath(const std::string& a, const std::string& b);
|
||||||
|
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace support
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_
|
97
tensorflow/lite/experimental/support/codegen/utils_test.cc
Normal file
97
tensorflow/lite/experimental/support/codegen/utils_test.cc
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/experimental/support/codegen/utils.h"
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace support {
|
||||||
|
namespace codegen {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// GetMessage() hands back the buffered "[ERROR] ..." text and consumes it,
// so reading twice yields the formatted message and then an empty string.
TEST(ErrorReporterTest, TestReportError) {
  ErrorReporter err;
  err.Error("some text");
  const std::string first_read = err.GetMessage();
  const std::string second_read = err.GetMessage();
  EXPECT_EQ(first_read, "[ERROR] some text\n");
  EXPECT_EQ(second_read, "");
}
|
||||||
|
|
||||||
|
// Expands the same template twice with different NAME bindings; each Append
// substitutes the token value current at the time of the call.
TEST(CodeGeneratorTest, TestExample) {
  ErrorReporter err;
  CodeWriter writer(&err);
  const std::string kTemplate =
      R"(void {{NAME}}() { printf("%s", "{{NAME}}"); })";
  const char* kNames[] = {"Foo", "Bar"};
  for (const char* name : kNames) {
    writer.SetTokenValue("NAME", name);
    writer.Append(kTemplate);
  }
  EXPECT_EQ(
      "void Foo() { printf(\"%s\", \"Foo\"); }\n"
      "void Bar() { printf(\"%s\", \"Bar\"); }\n",
      writer.ToString());
}
|
||||||
|
|
||||||
|
// A template referencing an unregistered token ("name" vs the registered
// "NAME") reports an internal error instead of expanding.
TEST(CodeGeneratorTest, TestInexistentToken) {
  ErrorReporter err;
  CodeWriter writer(&err);
  writer.SetTokenValue("NAME", "Foo");
  writer.Append(R"(void {{name}}() {})");
  EXPECT_EQ("[ERROR] Internal: Cannot find value with token 'name'\n",
            err.GetMessage());
}
|
||||||
|
|
||||||
|
// A "{{" with no matching "}}" is an invalid template and is reported.
TEST(CodeGeneratorTest, TestUnclosedToken) {
  ErrorReporter err;
  CodeWriter writer(&err);
  writer.SetTokenValue("NAME", "Foo");
  writer.Append(R"(void {{NAME}() {})");
  EXPECT_EQ("[ERROR] Internal: Invalid template: {{token}} is not closed.\n",
            err.GetMessage());
}
|
||||||
|
|
||||||
|
// Verifies indentation handling: appended lines pick up the current indent,
// blank lines are left unindented, and Outdent lowers the level for later
// appends.
TEST(CodeGeneratorTest, TestIndentControl) {
  ErrorReporter err;
  CodeWriter writer(&err);
  writer.SetIndentString(" ");
  writer.Indent();
  writer.AppendNoNewLine("abcde");  // Will indent
  EXPECT_EQ(" abcde", writer.ToString());
  writer.Clear();  // Clear also resets the indent level (re-Indent below).
  writer.Indent();
  writer.AppendNoNewLine("abc\n\nde");
  // The blank line will not indent
  EXPECT_EQ(" abc\n\n de", writer.ToString());
  writer.Clear();
  writer.Indent();
  writer.Append("abc");
  writer.Outdent();  // Subsequent appends are back at level zero.
  writer.AppendNoNewLine("def");
  EXPECT_EQ(" abc\ndef", writer.ToString());
}
|
||||||
|
|
||||||
|
// SnakeCaseToCamelCase drops underscores and capitalizes the letter that
// follows each one; a leading underscore capitalizes the first letter.
TEST(CaseConversionTest, TestSnakeToCamel) {
  struct Case {
    const char* input;
    const char* expected;
  };
  const Case kCases[] = {
      {"im_a_camel", "imACamel"},   // Plain snake case.
      {"im_a_camel_", "imACamel"},  // A trailing underscore is dropped.
      {"_im_a_camel", "ImACamel"},  // A leading underscore capitalizes.
      {"_", ""},                    // Underscores only yield an empty string.
      {"camel", "camel"},           // No underscores: unchanged.
  };
  for (const Case& c : kCases) {
    EXPECT_EQ(c.expected, SnakeCaseToCamelCase(c.input));
  }
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace codegen
|
||||||
|
} // namespace support
|
||||||
|
} // namespace tflite
|
Loading…
x
Reference in New Issue
Block a user