Opensource TFLite Support codegen.

PiperOrigin-RevId: 302011153
Change-Id: Idb2f649dc48fdc449fac2d6e9009719d29afb2ad
This commit is contained in:
Xunkai Zhang 2020-03-20 05:52:10 -07:00 committed by TensorFlower Gardener
parent abf182a882
commit b57d910db5
15 changed files with 2314 additions and 0 deletions

View File

@ -0,0 +1,87 @@
# The tools for generating wrapper classes for a TFLite model with metadata.
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "utils",
srcs = [
"utils.cc",
],
hdrs = [
"utils.h",
],
deps = [
],
)
cc_library(
name = "code_generator",
srcs = [
"code_generator.cc",
],
hdrs = [
"code_generator.h",
],
deps = [
":utils",
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
],
)
cc_library(
name = "metadata_helper",
srcs = [
"metadata_helper.cc",
],
hdrs = [
"metadata_helper.h",
],
deps = [
":utils",
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
"//tensorflow/lite/schema:schema_fbs",
],
)
cc_library(
name = "android_java_generator",
srcs = [
"android_java_generator.cc",
],
hdrs = [
"android_java_generator.h",
],
deps = [
":code_generator",
":metadata_helper",
":utils",
"//tensorflow/core/platform:logging",
"//tensorflow/lite/experimental/support/metadata:metadata_schema_cc",
"//tensorflow/lite/schema:schema_fbs",
],
)
cc_test(
name = "code_generator_test",
size = "small",
srcs = ["code_generator_test.cc"],
data = ["//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs"],
deps = [
":code_generator",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "utils_test",
srcs = ["utils_test.cc"],
deps = [
":utils",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -0,0 +1,13 @@
# TensorFlow Lite Android Wrapper Code Generator
For TensorFlow Lite model enhanced with [metadata](https://www.tensorflow.org/lite/convert/metadata.md),
developers can use the TensorFlow Lite Android wrapper code generator to create
platform specific wrapper code. The wrapper code removes the need to interact
directly with `ByteBuffer`. Instead, developers can interact with the TensorFlow
Lite model with typed objects such as `Bitmap` and `Rect`.
The usefulness of the code generator depend on the completeness of the
TensorFlow Lite model's metadata entry. Refer to the `<Codegen usage>` section
under relevant fields in
[metadata_schema.fbs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs),
to see how the codegen tool parses each field.

View File

@ -0,0 +1,978 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h"
#include <ctype.h>
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
#include "tensorflow/lite/experimental/support/codegen/metadata_helper.h"
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace support {
namespace codegen {
namespace {
using details_android_java::ModelInfo;
using details_android_java::TensorInfo;
// Helper class to organize the C++ code block as a generated code block.
// Using ctor and dtor to simulate an enter/exit schema like `with` in Python.
class AsBlock {
public:
AsBlock(CodeWriter* code_writer, const std::string& before,
bool trailing_blank_line = false)
: code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) {
code_writer_->AppendNoNewLine(before);
code_writer_->Append(" {");
code_writer_->Indent();
}
~AsBlock() {
code_writer_->Outdent();
code_writer_->Append("}");
if (trailing_blank_line_) {
code_writer_->NewLine();
}
}
private:
CodeWriter* code_writer_;
bool trailing_blank_line_;
};
// Declare the functions first, so that the functions can follow a logical
// order.
bool GenerateWrapperClass(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperImports(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperInputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperOutputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperMetadata(CodeWriter*, const ModelInfo&, ErrorReporter*);
bool GenerateWrapperAPI(CodeWriter*, const ModelInfo&, ErrorReporter*);
std::string GetModelVersionedName(const ModelMetadata* metadata) {
std::string model_name = "MyModel";
if (metadata->name() != nullptr && !(metadata->name()->str().empty())) {
model_name = metadata->name()->str();
}
std::string model_version = "unknown";
if (metadata->version() != nullptr && !(metadata->version()->str().empty())) {
model_version = metadata->version()->str();
}
return model_name + " (Version: " + model_version + ")";
}
TensorInfo CreateTensorInfo(const TensorMetadata* metadata,
const std::string& name, bool is_input, int index,
ErrorReporter* err) {
TensorInfo tensor_info;
std::string tensor_identifier = is_input ? "input" : "output";
tensor_identifier += " " + std::to_string(index);
tensor_info.associated_axis_label_index = FindAssociatedFile(
metadata, AssociatedFileType_TENSOR_AXIS_LABELS, tensor_identifier, err);
tensor_info.associated_value_label_index = FindAssociatedFile(
metadata, AssociatedFileType_TENSOR_VALUE_LABELS, tensor_identifier, err);
if (is_input && (tensor_info.associated_axis_label_index >= 0 ||
tensor_info.associated_value_label_index >= 0)) {
err->Warning(
"Found label file on input tensor (%s). Label file for input "
"tensor is not supported yet. The "
"file will be ignored.",
tensor_identifier.c_str());
}
if (tensor_info.associated_axis_label_index >= 0 &&
tensor_info.associated_value_label_index >= 0) {
err->Warning(
"Found both axis label file and value label file for tensor (%s), "
"which is not supported. Only the axis label file will be used.",
tensor_identifier.c_str());
}
tensor_info.is_input = is_input;
tensor_info.name = SnakeCaseToCamelCase(name);
tensor_info.upper_camel_name = tensor_info.name;
tensor_info.upper_camel_name[0] = toupper(tensor_info.upper_camel_name[0]);
tensor_info.normalization_unit =
FindNormalizationUnit(metadata, tensor_identifier, err);
if (metadata->content()->content_properties_type() ==
ContentProperties_ImageProperties) {
if (metadata->content()
->content_properties_as_ImageProperties()
->color_space() == ColorSpaceType_RGB) {
tensor_info.content_type = "image";
tensor_info.wrapper_type = "TensorImage";
tensor_info.processor_type = "ImageProcessor";
return tensor_info;
} else {
err->Warning(
"Found Non-RGB image on tensor (%s). Codegen currently does not "
"support it, and regard it as a plain numeric tensor.",
tensor_identifier.c_str());
}
}
tensor_info.content_type = "tensor";
tensor_info.wrapper_type = "TensorBuffer";
tensor_info.processor_type = "TensorProcessor";
return tensor_info;
}
ModelInfo CreateModelInfo(const ModelMetadata* metadata,
const std::string& package_name,
const std::string& model_class_name,
const std::string& model_asset_path,
ErrorReporter* err) {
ModelInfo model_info;
if (!CodeGenerator::VerifyMetadata(metadata, err)) {
// TODO(b/150116380): Create dummy model info.
err->Error("Validating metadata failed.");
return model_info;
}
model_info.package_name = package_name;
model_info.model_class_name = model_class_name;
model_info.model_asset_path = model_asset_path;
model_info.model_versioned_name = GetModelVersionedName(metadata);
const auto* graph = metadata->subgraph_metadata()->Get(0);
auto names = CodeGenerator::NameInputsAndOutputs(
graph->input_tensor_metadata(), graph->output_tensor_metadata());
std::vector<std::string> input_tensor_names = std::move(names.first);
std::vector<std::string> output_tensor_names = std::move(names.second);
for (int i = 0; i < graph->input_tensor_metadata()->size(); i++) {
model_info.inputs.push_back(
CreateTensorInfo(graph->input_tensor_metadata()->Get(i),
input_tensor_names[i], true, i, err));
}
for (int i = 0; i < graph->output_tensor_metadata()->size(); i++) {
model_info.outputs.push_back(
CreateTensorInfo(graph->output_tensor_metadata()->Get(i),
output_tensor_names[i], false, i, err));
}
return model_info;
}
void SetCodeWriterWithTensorInfo(CodeWriter* code_writer,
const TensorInfo& tensor_info) {
code_writer->SetTokenValue("NAME", tensor_info.name);
code_writer->SetTokenValue("NAME_U", tensor_info.upper_camel_name);
code_writer->SetTokenValue("CONTENT_TYPE", tensor_info.content_type);
code_writer->SetTokenValue("WRAPPER_TYPE", tensor_info.wrapper_type);
std::string wrapper_name = tensor_info.wrapper_type;
wrapper_name[0] = tolower(wrapper_name[0]);
code_writer->SetTokenValue("WRAPPER_NAME", wrapper_name);
code_writer->SetTokenValue("PROCESSOR_TYPE", tensor_info.processor_type);
code_writer->SetTokenValue("NORMALIZATION_UNIT",
std::to_string(tensor_info.normalization_unit));
code_writer->SetTokenValue(
"ASSOCIATED_AXIS_LABEL_INDEX",
std::to_string(tensor_info.associated_axis_label_index));
code_writer->SetTokenValue(
"ASSOCIATED_VALUE_LABEL_INDEX",
std::to_string(tensor_info.associated_value_label_index));
}
void SetCodeWriterWithModelInfo(CodeWriter* code_writer,
const ModelInfo& model_info) {
code_writer->SetTokenValue("PACKAGE", model_info.package_name);
code_writer->SetTokenValue("MODEL_PATH", model_info.model_asset_path);
code_writer->SetTokenValue("MODEL_CLASS_NAME", model_info.model_class_name);
}
constexpr char JAVA_DEFAULT_PACKAGE[] = "default";
std::string ConvertPackageToPath(const std::string& package) {
if (package == JAVA_DEFAULT_PACKAGE) {
return "";
}
std::string path = package;
std::replace(path.begin(), path.end(), '.', '/');
return path;
}
bool IsImageUsed(const ModelInfo& model) {
for (const auto& input : model.inputs) {
if (input.content_type == "image") {
return true;
}
}
for (const auto& output : model.outputs) {
if (output.content_type == "image") {
return true;
}
}
return false;
}
bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append("// Generated by TFLite Support.");
code_writer->Append("package {{PACKAGE}};");
code_writer->NewLine();
if (!GenerateWrapperImports(code_writer, model, err)) {
err->Error("Fail to generate imports for wrapper class.");
return false;
}
if (!GenerateWrapperClass(code_writer, model, err)) {
err->Error("Fail to generate wrapper class.");
return false;
}
code_writer->NewLine();
return true;
}
bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
const std::string support_pkg = "org.tensorflow.lite.support.";
std::vector<std::string> imports{
"android.content.Context",
"java.io.IOException",
"java.nio.ByteBuffer",
"java.nio.FloatBuffer",
"java.util.Arrays",
"java.util.HashMap",
"java.util.List",
"java.util.Map",
"org.checkerframework.checker.nullness.qual.Nullable",
"org.tensorflow.lite.DataType",
"org.tensorflow.lite.Tensor.QuantizationParams",
support_pkg + "common.FileUtil",
support_pkg + "common.TensorProcessor",
support_pkg + "common.ops.CastOp",
support_pkg + "common.ops.DequantizeOp",
support_pkg + "common.ops.NormalizeOp",
support_pkg + "common.ops.QuantizeOp",
support_pkg + "label.TensorLabel",
support_pkg + "metadata.MetadataExtractor",
support_pkg + "metadata.schema.NormalizationOptions",
support_pkg + "model.Model",
support_pkg + "model.Model.Device",
support_pkg + "tensorbuffer.TensorBuffer",
};
if (IsImageUsed(model)) {
for (const auto& target :
{"image.ImageProcessor", "image.TensorImage", "image.ops.ResizeOp",
"image.ops.ResizeOp.ResizeMethod"}) {
imports.push_back(support_pkg + target);
}
imports.push_back("android.graphics.Bitmap");
}
std::sort(imports.begin(), imports.end());
for (const auto target : imports) {
code_writer->SetTokenValue("TARGET", target);
code_writer->Append("import {{TARGET}};");
}
code_writer->NewLine();
return true;
}
bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->SetTokenValue("MODEL_VERSIONED_NAME",
model.model_versioned_name);
code_writer->Append(
R"(/** Wrapper class of model {{MODEL_VERSIONED_NAME}} */)");
const auto code_block =
AsBlock(code_writer, "public class {{MODEL_CLASS_NAME}}");
code_writer->Append(R"(private final Metadata metadata;
private final Model model;
private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
for (const auto& tensor : model.outputs) {
if (tensor.associated_axis_label_index >= 0) {
code_writer->SetTokenValue("NAME", tensor.name);
code_writer->Append("private final List<String> {{NAME}}Labels;");
}
}
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(
"@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;");
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(
"@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
}
code_writer->NewLine();
if (!GenerateWrapperInputs(code_writer, model, err)) {
err->Error("Failed to generate input classes");
return false;
}
code_writer->NewLine();
if (!GenerateWrapperOutputs(code_writer, model, err)) {
err->Error("Failed to generate output classes");
return false;
}
code_writer->NewLine();
if (!GenerateWrapperMetadata(code_writer, model, err)) {
err->Error("Failed to generate the metadata class");
return false;
}
code_writer->NewLine();
if (!GenerateWrapperAPI(code_writer, model, err)) {
err->Error("Failed to generate the common APIs");
return false;
}
return true;
}
bool GenerateWrapperInputs(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append("/** Input wrapper of {@link {{MODEL_CLASS_NAME}}} */");
auto class_block = AsBlock(code_writer, "public class Inputs");
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append("private {{WRAPPER_TYPE}} {{NAME}};");
}
code_writer->NewLine();
// Ctor
{
auto ctor_block = AsBlock(code_writer, "public Inputs()");
code_writer->Append(
"Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;");
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
if (tensor.content_type == "image") {
code_writer->Append(
"{{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());");
} else {
code_writer->Append(
"{{NAME}} = "
"TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
"metadata.get{{NAME_U}}Type());");
}
}
}
for (const auto& tensor : model.inputs) {
code_writer->NewLine();
SetCodeWriterWithTensorInfo(code_writer, tensor);
// Loaders
if (tensor.content_type == "image") {
{
auto bitmap_loader_block =
AsBlock(code_writer, "public void load{{NAME_U}}(Bitmap bitmap)");
code_writer->Append(R"({{NAME}}.load(bitmap);
{{NAME}} = preprocess{{NAME_U}}({{NAME}});)");
}
code_writer->NewLine();
{
auto tensor_image_loader_block = AsBlock(
code_writer, "public void load{{NAME_U}}(TensorImage tensorImage)");
code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorImage);");
}
} else { // content_type == "FEATURE" or "UNKNOWN"
auto tensorbuffer_loader_block = AsBlock(
code_writer, "public void load{{NAME_U}}(TensorBuffer tensorBuffer)");
code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorBuffer);");
}
code_writer->NewLine();
// Processor
code_writer->Append(
R"(private {{WRAPPER_TYPE}} preprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}}) {
if ({{NAME}}Preprocessor == null) {
return {{WRAPPER_NAME}};
}
return {{NAME}}Preprocessor.process({{WRAPPER_NAME}});
}
)");
}
{
const auto get_buffer_block = AsBlock(code_writer, "Object[] getBuffer()");
code_writer->AppendNoNewLine("return new Object[] {");
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->AppendNoNewLine("{{NAME}}.getBuffer(), ");
}
code_writer->Backspace(2);
code_writer->Append("};");
}
return true;
}
bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
auto class_block = AsBlock(code_writer, "public class Outputs");
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};");
}
code_writer->NewLine();
{
const auto ctor_block = AsBlock(code_writer, "public Outputs()");
code_writer->Append(
"Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;");
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
if (tensor.content_type == "image") {
code_writer->Append(
R"({{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());
{{NAME}}.load(TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), metadata.get{{NAME_U}}Type()));)");
} else { // FEATURE, UNKNOWN
code_writer->Append(
"{{NAME}} = "
"TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
"metadata.get{{NAME_U}}Type());");
}
}
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->NewLine();
if (tensor.associated_axis_label_index >= 0) {
if (tensor.content_type == "image") {
err->Warning(
"Axis label for images is not supported. The labels will "
"be ignored.");
} else {
code_writer->Append(R"(public Map<String, Float> get{{NAME_U}}() {
return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getMapWithFloatValue();
})");
}
} else {
code_writer->Append(R"(public {{WRAPPER_TYPE}} get{{NAME_U}}() {
return postprocess{{NAME_U}}({{NAME}});
})");
}
code_writer->NewLine();
{
auto processor_block =
AsBlock(code_writer,
"private {{WRAPPER_TYPE}} "
"postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})");
code_writer->Append(R"(if ({{NAME}}Postprocessor == null) {
return {{WRAPPER_NAME}};
}
return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});)");
}
}
code_writer->NewLine();
{
const auto get_buffer_block =
AsBlock(code_writer, "Map<Integer, Object> getBuffer()");
code_writer->Append("Map<Integer, Object> outputs = new HashMap<>();");
for (int i = 0; i < model.outputs.size(); i++) {
SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
code_writer->SetTokenValue("ID", std::to_string(i));
code_writer->Append("outputs.put({{ID}}, {{NAME}}.getBuffer());");
}
code_writer->Append("return outputs;");
}
return true;
}
bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append(
"/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */");
const auto class_block = AsBlock(code_writer, "public static class Metadata");
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(private final int[] {{NAME}}Shape;
private final DataType {{NAME}}DataType;
private final QuantizationParams {{NAME}}QuantizationParams;)");
if (tensor.normalization_unit >= 0) {
code_writer->Append(R"(private final float[] {{NAME}}Mean;
private final float[] {{NAME}}Stddev;)");
}
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(private final int[] {{NAME}}Shape;
private final DataType {{NAME}}DataType;
private final QuantizationParams {{NAME}}QuantizationParams;)");
if (tensor.normalization_unit >= 0) {
code_writer->Append(R"(private final float[] {{NAME}}Mean;
private final float[] {{NAME}}Stddev;)");
}
if (tensor.associated_axis_label_index >= 0 ||
tensor.associated_value_label_index >= 0) {
code_writer->Append("private final List<String> {{NAME}}Labels;");
}
}
code_writer->NewLine();
{
const auto ctor_block = AsBlock(
code_writer,
"public Metadata(ByteBuffer buffer, Model model) throws IOException");
code_writer->Append(
"MetadataExtractor extractor = new MetadataExtractor(buffer);");
for (int i = 0; i < model.inputs.size(); i++) {
SetCodeWriterWithTensorInfo(code_writer, model.inputs[i]);
code_writer->SetTokenValue("ID", std::to_string(i));
code_writer->Append(
R"({{NAME}}Shape = extractor.getInputTensorShape({{ID}});
{{NAME}}DataType = extractor.getInputTensorType({{ID}});
{{NAME}}QuantizationParams = extractor.getInputTensorQuantizationParams({{ID}});)");
if (model.inputs[i].normalization_unit >= 0) {
code_writer->Append(
R"(NormalizationOptions {{NAME}}NormalizationOptions =
(NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
{{NAME}}MeanBuffer.get({{NAME}}Mean);
FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
}
}
for (int i = 0; i < model.outputs.size(); i++) {
SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
code_writer->SetTokenValue("ID", std::to_string(i));
code_writer->Append(
R"({{NAME}}Shape = model.getOutputTensorShape({{ID}});
{{NAME}}DataType = extractor.getOutputTensorType({{ID}});
{{NAME}}QuantizationParams = extractor.getOutputTensorQuantizationParams({{ID}});)");
if (model.outputs[i].normalization_unit >= 0) {
code_writer->Append(
R"(NormalizationOptions {{NAME}}NormalizationOptions =
(NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
{{NAME}}MeanBuffer.get({{NAME}}Mean);
FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
}
if (model.outputs[i].associated_axis_label_index >= 0) {
code_writer->Append(R"(String {{NAME}}LabelsFileName =
extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_AXIS_LABEL_INDEX}}).name();
{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
} else if (model.outputs[i].associated_value_label_index >= 0) {
code_writer->Append(R"(String {{NAME}}LabelsFileName =
extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_VALUE_LABEL_INDEX}}).name();
{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
}
}
}
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(
public int[] get{{NAME_U}}Shape() {
return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
}
public DataType get{{NAME_U}}Type() {
return {{NAME}}DataType;
}
public QuantizationParams get{{NAME_U}}QuantizationParams() {
return {{NAME}}QuantizationParams;
})");
if (tensor.normalization_unit >= 0) {
code_writer->Append(R"(
public float[] get{{NAME_U}}Mean() {
return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
}
public float[] get{{NAME_U}}Stddev() {
return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
})");
}
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(
public int[] get{{NAME_U}}Shape() {
return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
}
public DataType get{{NAME_U}}Type() {
return {{NAME}}DataType;
}
public QuantizationParams get{{NAME_U}}QuantizationParams() {
return {{NAME}}QuantizationParams;
})");
if (tensor.normalization_unit >= 0) {
code_writer->Append(R"(
public float[] get{{NAME_U}}Mean() {
return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
}
public float[] get{{NAME_U}}Stddev() {
return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
})");
}
if (tensor.associated_axis_label_index >= 0 ||
tensor.associated_value_label_index >= 0) {
code_writer->Append(R"(
public List<String> get{{NAME_U}}Labels() {
return {{NAME}}Labels;
})");
}
}
return true;
}
bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append(R"(public Metadata getMetadata() {
return metadata;
}
)");
code_writer->Append(R"(/**
* Creates interpreter and loads associated files if needed.
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public {{MODEL_CLASS_NAME}}(Context context) throws IOException {
this(context, MODEL_NAME, Device.CPU, 1);
}
/**
* Creates interpreter and loads associated files if needed, but loading another model in the same
* input / output structure with the original one.
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public {{MODEL_CLASS_NAME}}(Context context, String modelPath) throws IOException {
this(context, modelPath, Device.CPU, 1);
}
/**
* Creates interpreter and loads associated files if needed, with device and number of threads
* configured.
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public {{MODEL_CLASS_NAME}}(Context context, Device device, int numThreads) throws IOException {
this(context, MODEL_NAME, device, numThreads);
}
/**
* Creates interpreter for a user-specified model.
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public {{MODEL_CLASS_NAME}}(Context context, String modelPath, Device device, int numThreads) throws IOException {
model = new Model.Builder(context, modelPath).setDevice(device).setNumThreads(numThreads).build();
metadata = new Metadata(model.getData(), model);)");
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(
{{PROCESSOR_TYPE}}.Builder {{NAME}}PreprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder())");
if (tensor.content_type == "image") {
code_writer->Append(R"( .add(new ResizeOp(
metadata.get{{NAME_U}}Shape()[1],
metadata.get{{NAME_U}}Shape()[2],
ResizeMethod.NEAREST_NEIGHBOR)))");
}
if (tensor.normalization_unit >= 0) {
code_writer->Append(
R"( .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
}
code_writer->Append(
R"( .add(new QuantizeOp(
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
metadata.get{{NAME_U}}QuantizationParams().getScale()))
.add(new CastOp(metadata.get{{NAME_U}}Type()));
{{NAME}}Preprocessor = {{NAME}}PreprocessorBuilder.build();)");
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->AppendNoNewLine(R"(
{{PROCESSOR_TYPE}}.Builder {{NAME}}PostprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder()
.add(new DequantizeOp(
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
metadata.get{{NAME_U}}QuantizationParams().getScale())))");
if (tensor.normalization_unit >= 0) {
code_writer->AppendNoNewLine(R"(
.add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
}
code_writer->Append(R"(;
{{NAME}}Postprocessor = {{NAME}}PostprocessorBuilder.build();)");
if (tensor.associated_axis_label_index >= 0) {
code_writer->Append(R"(
{{NAME}}Labels = metadata.get{{NAME_U}}Labels();)");
}
}
code_writer->Append("}");
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(
public void reset{{NAME_U}}Preprocessor(@Nullable {{PROCESSOR_TYPE}} processor) {
{{NAME}}Preprocessor = processor;
})");
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(
public void reset{{NAME_U}}Postprocessor(@Nullable {{PROCESSOR_TYPE}} processor) {
{{NAME}}Postprocessor = processor;
})");
}
code_writer->Append(R"(
/** Creates inputs */
public Inputs createInputs() {
return new Inputs();
}
/** Triggers the model. */
public Outputs run(Inputs inputs) {
Outputs outputs = new Outputs();
model.run(inputs.getBuffer(), outputs.getBuffer());
return outputs;
}
/** Closes the model. */
public void close() {
model.close();
})");
return true;
}
bool GenerateBuildGradleContent(CodeWriter* code_writer,
const ModelInfo& model_info) {
code_writer->Append(R"(buildscript {
repositories {
google()
jcenter()
}
dependencies {
classpath 'com.android.tools.build:gradle:3.2.1'
}
}
allprojects {
repositories {
google()
jcenter()
flatDir {
dirs 'libs'
}
}
}
apply plugin: 'com.android.library'
android {
compileSdkVersion 29
defaultConfig {
targetSdkVersion 29
versionCode 1
versionName "1.0"
}
aaptOptions {
noCompress "tflite"
}
compileOptions {
sourceCompatibility = '1.8'
targetCompatibility = '1.8'
}
lintOptions {
abortOnError false
}
}
configurations {
libMetadata
}
dependencies {
libMetadata 'org.tensorflow:tensorflow-lite-support:0.0.0-experimental-metadata-monolithic'
}
task downloadLibs(type: Sync) {
from configurations.libMetadata
into "$buildDir/libs"
rename 'tensorflow-lite-support-0.0.0-experimental-metadata-monolithic.jar', "tensorflow-lite-support-metadata.jar"
}
preBuild.dependsOn downloadLibs
dependencies {
compileOnly 'org.checkerframework:checker-qual:2.5.8'
api 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
api 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
api files("$buildDir/libs/tensorflow-lite-support-metadata.jar")
implementation 'org.apache.commons:commons-compress:1.19'
})");
return true;
}
bool GenerateAndroidManifestContent(CodeWriter* code_writer,
const ModelInfo& model_info) {
code_writer->Append(R"(<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="{{PACKAGE}}">
</manifest>)");
return true;
}
bool GenerateDocContent(CodeWriter* code_writer, const ModelInfo& model_info) {
code_writer->Append("# {{MODEL_CLASS_NAME}} Usage");
code_writer->AppendNoNewLine(R"(
```
import {{PACKAGE}}.{{MODEL_CLASS_NAME}};
// 1. Initialize the Model
{{MODEL_CLASS_NAME}} model = null;
try {
model = new {{MODEL_CLASS_NAME}}(context); // android.content.Context
// Create the input container.
{{MODEL_CLASS_NAME}}.Inputs inputs = model.createInputs();
} catch (IOException e) {
e.printStackTrace();
}
if (model != null) {
// 2. Set the inputs)");
for (const auto& t : model_info.inputs) {
SetCodeWriterWithTensorInfo(code_writer, t);
if (t.content_type == "image") {
code_writer->Append(R"(
// Load input tensor "{{NAME}}" from a Bitmap with ARGB_8888 format.
Bitmap bitmap = ...;
inputs.load{{NAME_U}}(bitmap);
// Alternatively, load the input tensor "{{NAME}}" from a TensorImage.
// Check out TensorImage documentation to load other image data structures.
// TensorImage tensorImage = ...;
// inputs.load{{NAME_U}}(tensorImage);)");
} else {
code_writer->Append(R"(
// Load input tensor "{{NAME}}" from a TensorBuffer.
// Check out TensorBuffer documentation to load other data structures.
TensorBuffer tensorBuffer = ...;
inputs.load{{NAME_U}}(tensorBuffer);)");
}
}
code_writer->Append(R"(
// 3. Run the model
{{MODEL_CLASS_NAME}}.Outputs outputs = model.run(inputs);)");
code_writer->Append(R"(
// 4. Retrieve the results)");
for (const auto& t : model_info.outputs) {
SetCodeWriterWithTensorInfo(code_writer, t);
if (t.associated_axis_label_index >= 0) {
code_writer->SetTokenValue("WRAPPER_TYPE", "Map<String, Float>");
}
code_writer->Append(
R"( {{WRAPPER_TYPE}} {{NAME}} = outputs.get{{NAME_U}}();)");
}
code_writer->Append(R"(}
```)");
return true;
}
GenerationResult::File GenerateWrapperFile(const std::string& module_root,
const ModelInfo& model_info,
ErrorReporter* err) {
const auto java_path = JoinPath(module_root, "src/main/java");
const auto package_path =
JoinPath(java_path, ConvertPackageToPath(model_info.package_name));
const auto file_path =
JoinPath(package_path, model_info.model_class_name + JAVA_EXT);
CodeWriter code_writer(err);
code_writer.SetIndentString(" ");
SetCodeWriterWithModelInfo(&code_writer, model_info);
if (!GenerateWrapperFileContent(&code_writer, model_info, err)) {
err->Error("Generating Java wrapper content failed.");
}
const auto java_file = code_writer.ToString();
return GenerationResult::File{file_path, java_file};
}
GenerationResult::File GenerateBuildGradle(const std::string& module_root,
const ModelInfo& model_info,
ErrorReporter* err) {
const auto file_path = JoinPath(module_root, "build.gradle");
CodeWriter code_writer(err);
SetCodeWriterWithModelInfo(&code_writer, model_info);
if (!GenerateBuildGradleContent(&code_writer, model_info)) {
err->Error("Generating build.gradle failed.");
}
const auto content = code_writer.ToString();
return GenerationResult::File{file_path, content};
}
GenerationResult::File GenerateAndroidManifest(const std::string& module_root,
const ModelInfo& model_info,
ErrorReporter* err) {
const auto file_path = JoinPath(module_root, "src/main/AndroidManifest.xml");
CodeWriter code_writer(err);
SetCodeWriterWithModelInfo(&code_writer, model_info);
if (!GenerateAndroidManifestContent(&code_writer, model_info)) {
err->Error("Generating AndroidManifest.xml failed.");
}
return GenerationResult::File{file_path, code_writer.ToString()};
}
GenerationResult::File GenerateDoc(const std::string& module_root,
const ModelInfo& model_info,
ErrorReporter* err) {
std::string lower = model_info.model_class_name;
for (int i = 0; i < lower.length(); i++) {
lower[i] = std::tolower(lower[i]);
}
const auto file_path = JoinPath(module_root, lower + ".md");
CodeWriter code_writer(err);
SetCodeWriterWithModelInfo(&code_writer, model_info);
if (!GenerateDocContent(&code_writer, model_info)) {
err->Error("Generating doc failed.");
}
return GenerationResult::File{file_path, code_writer.ToString()};
}
} // namespace
AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root)
: CodeGenerator(), module_root_(module_root) {}
GenerationResult AndroidJavaGenerator::Generate(
const Model* model, const std::string& package_name,
const std::string& model_class_name, const std::string& model_asset_path) {
GenerationResult result;
const ModelMetadata* metadata = GetMetadataFromModel(model);
if (metadata == nullptr) {
err_.Error(
"Cannot find TFLite Metadata in the model. Codegen will generate "
"nothing.");
return result;
}
details_android_java::ModelInfo model_info = CreateModelInfo(
metadata, package_name, model_class_name, model_asset_path, &err_);
result.files.push_back(GenerateWrapperFile(module_root_, model_info, &err_));
result.files.push_back(GenerateBuildGradle(module_root_, model_info, &err_));
result.files.push_back(
GenerateAndroidManifest(module_root_, model_info, &err_));
result.files.push_back(GenerateDoc(module_root_, model_info, &err_));
return result;
}
GenerationResult AndroidJavaGenerator::Generate(
const char* model_storage, const std::string& package_name,
const std::string& model_class_name, const std::string& model_asset_path) {
const Model* model = GetModel(model_storage);
return Generate(model, package_name, model_class_name, model_asset_path);
}
std::string AndroidJavaGenerator::GetErrorMessage() {
return err_.GetMessage();
}
} // namespace codegen
} // namespace support
} // namespace tflite

View File

@ -0,0 +1,107 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace support {
namespace codegen {
namespace details_android_java {
/// The intermediate data structure for generating code from TensorMetadata.
/// Should only be used as const reference when created.
struct TensorInfo {
std::string name;
std::string upper_camel_name;
std::string content_type;
std::string wrapper_type;
std::string processor_type;
bool is_input;
/// Optional. Set to -1 if not applicable.
int normalization_unit;
/// Optional. Set to -1 if associated_axis_label is empty.
int associated_axis_label_index;
/// Optional. Set to -1 if associated_value_label is empty.
int associated_value_label_index;
};
/// The intermediate data structure for generating code from ModelMetadata.
/// Should only be used as const reference when created.
struct ModelInfo {
std::string package_name;
std::string model_asset_path;
std::string model_class_name;
std::string model_versioned_name;
std::vector<TensorInfo> inputs;
std::vector<TensorInfo> outputs;
};
} // namespace details_android_java
constexpr char JAVA_EXT[] = ".java";
/// Generates Android supporting codes and modules (in Java) based on TFLite
/// metadata.
class AndroidJavaGenerator : public CodeGenerator {
public:
/// Creates an AndroidJavaGenerator.
/// Args:
/// - module_root: The root of destination Java module.
explicit AndroidJavaGenerator(const std::string& module_root);
/// Generates files. Returns the file paths and contents.
/// Args:
/// - model: The TFLite model with Metadata filled.
/// - package_name: The name of the Java package which generated classes
/// belong to.
/// - model_class_name: A readable name of the generated wrapper class, such
/// as "ImageClassifier", "MobileNetV2" or "MyModel".
/// - model_asset_path: The relevant path to the model file in the asset.
// TODO(b/141225157): Automatically generate model_class_name.
GenerationResult Generate(const Model* model, const std::string& package_name,
const std::string& model_class_name,
const std::string& model_asset_path);
/// Generates files and returns the file paths and contents.
/// It's mostly identical with the previous one, but the model here is
/// provided as binary flatbuffer content without parsing.
GenerationResult Generate(const char* model_storage,
const std::string& package_name,
const std::string& model_class_name,
const std::string& model_asset_path);
std::string GetErrorMessage();
private:
const std::string module_root_;
ErrorReporter err_;
};
} // namespace codegen
} // namespace support
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_

View File

@ -0,0 +1,179 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
#include <cctype>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace support {
namespace codegen {
namespace {
void ResolveConflictedNamesByAddingIndex(std::vector<std::string>* names_ptr) {
auto& names = *names_ptr;
std::unordered_map<std::string, int> indexes;
std::unordered_map<std::string, int> first_appearance;
for (int i = 0; i < names.size(); i++) {
if (indexes.find(names[i]) == indexes.end()) {
indexes[names[i]] = 1;
first_appearance[names[i]] = i;
} else {
indexes[names[i]] += 1;
names[i].append(std::to_string(indexes[names[i]]));
}
}
for (const auto it : first_appearance) {
const auto& name = it.first;
const auto i = it.second;
if (indexes[name] > 1) {
names[i].append("1");
}
}
}
} // namespace
CodeGenerator::CodeGenerator() {}
bool CodeGenerator::VerifyMetadata(const ModelMetadata* metadata,
ErrorReporter* err) {
if (metadata == nullptr) {
err->Error("Loading nullptr is not allowed");
return false;
}
if (metadata->subgraph_metadata()->size() != 1) {
err->Error("Only exact 1 subgraph is supported");
return false;
}
return true;
}
std::pair<std::vector<std::string>, std::vector<std::string>>
CodeGenerator::NameInputsAndOutputs(const TensorMetadataList* inputs,
const TensorMetadataList* outputs) {
std::vector<std::string> input_names;
std::vector<std::string> output_names;
if (inputs != nullptr) {
input_names.reserve(inputs->size());
for (const auto* tensor : *inputs) {
input_names.push_back(NameTensor(*tensor, "input"));
}
}
if (outputs != nullptr) {
output_names.reserve(outputs->size());
for (const auto* tensor : *outputs) {
output_names.push_back(NameTensor(*tensor, "output"));
}
}
// Solve conflict
ResolveConflictedInputAndOutputNames(&input_names, &output_names);
return std::make_pair(input_names, output_names);
}
std::string CodeGenerator::ConvertToValidName(const std::string& name) {
// lowercase all
std::string result = name;
for (int i = 0; i < result.size(); i++) {
result[i] = std::tolower(result[i]);
}
// replace all non-alpha or non-numeric with underscores, except underscore
// itself
for (int i = 0; i < result.size(); i++) {
if (result[i] != '_' && !std::isalnum(result[i])) {
result[i] = '_';
}
}
// remove leading underscores
int leading_underscores = 0;
while (leading_underscores < result.size() &&
result[leading_underscores] == '_') {
leading_underscores++;
}
result.erase(0, leading_underscores);
if (result.empty()) {
return "";
}
// first char should be alpha
if (std::isalpha(result[0])) {
return result;
}
return "tensor_" + result;
}
std::string CodeGenerator::NameTensor(const TensorMetadata& tensor,
const std::string& default_name) {
if (tensor.name() != nullptr && tensor.name()->size() > 0) {
// TODO(b/141225157) Validate tensor name. It should be in lower case.
auto suggested_name = ConvertToValidName(tensor.name()->str());
if (!suggested_name.empty()) {
return suggested_name;
}
}
auto* content = tensor.content();
if (content == nullptr || content->content_properties() == nullptr) {
return default_name;
}
switch (content->content_properties_type()) {
case ContentProperties_ImageProperties:
return "image";
case ContentProperties_FeatureProperties:
return "feature";
default:
return default_name;
}
}
void CodeGenerator::ResolveConflictedInputAndOutputNames(
std::vector<std::string>* inputs, std::vector<std::string>* outputs) {
std::unordered_set<std::string> io_conflict;
auto& input_names = *inputs;
auto& output_names = *outputs;
for (const auto input : input_names) {
if (io_conflict.find(input) != io_conflict.end()) {
continue;
}
for (const auto output : output_names) {
if (input == output) {
io_conflict.insert(input);
break;
}
}
}
for (int i = 0; i < input_names.size(); i++) {
if (io_conflict.find(input_names[i]) != io_conflict.end()) {
input_names[i] = "input_" + input_names[i];
}
}
for (int i = 0; i < output_names.size(); i++) {
if (io_conflict.find(output_names[i]) != io_conflict.end()) {
output_names[i] = "output_" + output_names[i];
}
}
// 2. Second, add index if input[i] == input[j]
ResolveConflictedNamesByAddingIndex(&input_names);
ResolveConflictedNamesByAddingIndex(&output_names);
}
} // namespace codegen
} // namespace support
} // namespace tflite

View File

@ -0,0 +1,80 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace support {
namespace codegen {
struct GenerationResult {
struct File {
std::string path;
std::string content;
};
std::vector<File> files;
};
/// Defines language-independent codegen strategies, like class naming, .etc.
/// Should not be used directly.
class CodeGenerator {
public:
CodeGenerator();
using TensorMetadataList =
typename flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>;
virtual ~CodeGenerator() {}
// Strategies.
/// Names all the IO tensors. It's useful when they don't have names, or the
/// names have conflicts. We have to name every tensor for code generation.
// TODO(b/141225157): Add reserved keywords check.
static std::pair<std::vector<std::string>, std::vector<std::string>>
NameInputsAndOutputs(const TensorMetadataList* inputs,
const TensorMetadataList* outputs);
/// Loads a metadata for code generation.
/// Returns false if the metadata is not good for generation.
static bool VerifyMetadata(const ModelMetadata* metadata, ErrorReporter* err);
protected:
/// Converts a name into a valid form. Rules:
/// - lower all letters.
/// - replace all non alphabet nor numeric characters with underscores.
/// - remove prefix underscores.
/// - add prefix if the leading character is a number.
/// Returns empty string if not possible.
static std::string ConvertToValidName(const std::string& name);
static std::string NameTensor(const TensorMetadata& tensor,
const std::string& default_name);
static void ResolveConflictedInputAndOutputNames(
std::vector<std::string>* input, std::vector<std::string>* output);
};
} // namespace codegen
} // namespace support
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_

View File

@ -0,0 +1,126 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace tflite {
namespace support {
namespace codegen {
namespace {
using ::testing::ElementsAreArray;
class CodeGeneratorTest : public ::testing::Test {
public:
class TestingCodeGenerator : public CodeGenerator {
public:
explicit TestingCodeGenerator() : CodeGenerator() {}
// Make tested method public.
static std::string ConvertToValidName(const std::string& name) {
return CodeGenerator::ConvertToValidName(name);
}
static void ResolveConflictedInputAndOutputNames(
std::vector<std::string>* input, std::vector<std::string>* output) {
CodeGenerator::ResolveConflictedInputAndOutputNames(input, output);
}
};
};
TEST_F(CodeGeneratorTest, UpperCasesShouldLower) {
EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("AlphaBetCOOL"),
"alphabetcool");
}
TEST_F(CodeGeneratorTest, NonAlphaNumShouldReplace) {
EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("A+=B C\t"), "a__b_c_");
}
TEST_F(CodeGeneratorTest, NoLeadingUnderscore) {
EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("+KAI Z"), "kai_z");
}
TEST_F(CodeGeneratorTest, NoLeadingNumbers) {
EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("3000 Cool Tensors"),
"tensor_3000_cool_tensors");
}
TEST_F(CodeGeneratorTest, TestSimpleIONames) {
std::vector<std::string> inputs = {"image"};
std::vector<std::string> outputs = {"output"};
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
EXPECT_THAT(inputs, ElementsAreArray({"image"}));
EXPECT_THAT(outputs, ElementsAreArray({"output"}));
}
TEST_F(CodeGeneratorTest, TestIOConflict) {
std::vector<std::string> inputs = {"image"};
std::vector<std::string> outputs = {"image"};
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
EXPECT_THAT(inputs, ElementsAreArray({"input_image"}));
EXPECT_THAT(outputs, ElementsAreArray({"output_image"}));
}
TEST_F(CodeGeneratorTest, TestInternalConflict) {
std::vector<std::string> inputs = {"image", "image"};
std::vector<std::string> outputs = {"output"};
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
EXPECT_THAT(inputs, ElementsAreArray({"image1", "image2"}));
EXPECT_THAT(outputs, ElementsAreArray({"output"}));
}
TEST_F(CodeGeneratorTest, TestAllConflictNTo1) {
std::vector<std::string> inputs = {"image", "image"};
std::vector<std::string> outputs = {"image"};
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
EXPECT_THAT(inputs, ElementsAreArray({"input_image1", "input_image2"}));
EXPECT_THAT(outputs, ElementsAreArray({"output_image"}));
}
TEST_F(CodeGeneratorTest, TestAllConflict) {
std::vector<std::string> inputs = {"image", "audio", "image", "audio",
"audio"};
std::vector<std::string> outputs = {"image", "image", "audio", "feature",
"feature"};
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
EXPECT_THAT(inputs,
ElementsAreArray({"input_image1", "input_audio1", "input_image2",
"input_audio2", "input_audio3"}));
EXPECT_THAT(outputs,
ElementsAreArray({"output_image1", "output_image2",
"output_audio", "feature1", "feature2"}));
}
TEST_F(CodeGeneratorTest, TestAllConflictReversed) {
std::vector<std::string> inputs = {"image", "image", "audio", "feature",
"feature"};
std::vector<std::string> outputs = {"image", "audio", "image", "audio",
"audio"};
TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
EXPECT_THAT(inputs,
ElementsAreArray({"input_image1", "input_image2", "input_audio",
"feature1", "feature2"}));
EXPECT_THAT(outputs, ElementsAreArray({"output_image1", "output_audio1",
"output_image2", "output_audio2",
"output_audio3"}));
}
} // namespace
} // namespace codegen
} // namespace support
} // namespace tflite

View File

@ -0,0 +1,92 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/codegen/metadata_helper.h"
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace support {
namespace codegen {
constexpr char BUFFER_KEY[] = "TFLITE_METADATA";
const ModelMetadata* GetMetadataFromModel(const Model* model) {
if (model->metadata() == nullptr) {
return nullptr;
}
for (auto i = 0; i < model->metadata()->size(); i++) {
if (model->metadata()->Get(i)->name()->str() == BUFFER_KEY) {
const auto buffer_index = model->metadata()->Get(i)->buffer();
const auto* buffer = model->buffers()->Get(buffer_index)->data()->data();
return GetModelMetadata(buffer);
}
}
return nullptr;
}
int FindAssociatedFile(const TensorMetadata* metadata,
const AssociatedFileType file_type,
const std::string& tensor_identifier,
ErrorReporter* err) {
int result = -1;
if (metadata->associated_files() == nullptr ||
metadata->associated_files()->size() == 0) {
return result;
}
for (int i = 0; i < metadata->associated_files()->size(); i++) {
const auto* file_metadata = metadata->associated_files()->Get(i);
if (file_metadata->type() == file_type) {
if (result >= 0) {
err->Warning(
"Multiple associated file of type %d found on tensor %s. Only the "
"first one will be used.",
file_type, tensor_identifier.c_str());
continue;
}
result = i;
}
}
return result;
}
int FindNormalizationUnit(const TensorMetadata* metadata,
const std::string& tensor_identifier,
ErrorReporter* err) {
int result = -1;
if (metadata->process_units() == nullptr ||
metadata->process_units()->size() == 0) {
return result;
}
for (int i = 0; i < metadata->process_units()->size(); i++) {
const auto* process_uint = metadata->process_units()->Get(i);
if (process_uint->options_type() ==
ProcessUnitOptions_NormalizationOptions) {
if (result >= 0) {
err->Warning(
"Multiple normalization unit found in tensor %s. Only the first "
"one will be effective.",
tensor_identifier.c_str());
continue;
}
result = i;
}
}
return result;
}
} // namespace codegen
} // namespace support
} // namespace tflite

View File

@ -0,0 +1,51 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_
#include <string>
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace support {
namespace codegen {
/// Parses a ModelMetadata out from a Model. The returned ModelMetadata's
/// lifetime is scoped by the model.
/// Returns nullptr if we cannot find any metadata.
const ModelMetadata* GetMetadataFromModel(const Model* model);
/// Finds an associated file from a TensorMetadata of certain type. If there're
/// multiple files meet the criteria, only the first one is used. If there's no
/// file meets the criteria, -1 will be returned.
int FindAssociatedFile(const TensorMetadata* metadata,
const AssociatedFileType file_type,
const std::string& tensor_identifier,
ErrorReporter* err);
/// Find the first normalization unit. If none, return -1.
int FindNormalizationUnit(const TensorMetadata* metadata,
const std::string& tensor_identifier,
ErrorReporter* err);
} // namespace codegen
} // namespace support
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_

View File

@ -0,0 +1,38 @@
load("//tensorflow:tensorflow.bzl", "pybind_extension")
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
pybind_extension(
name = "_pywrap_codegen",
srcs = [
"codegen_lib.cc",
],
features = ["-use_header_modules"],
module_name = "_pywrap_codegen",
deps = [
"//tensorflow/lite/experimental/support/codegen:android_java_generator",
"//tensorflow/lite/experimental/support/codegen:code_generator",
"//tensorflow/python:pybind11_lib",
"//third_party/python_runtime:headers",
"@pybind11",
],
)
py_binary(
name = "codegen",
srcs = [
"codegen.py",
],
python_version = "PY3",
deps = [
":_pywrap_codegen",
"@absl_py//absl:app",
"@absl_py//absl/flags",
"@absl_py//absl/logging",
],
)

View File

@ -0,0 +1,96 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generates Android Java sources from a TFLite model with metadata."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
from absl import app
from absl import flags
from absl import logging
from tensorflow.lite.experimental.support.codegen.python import _pywrap_codegen
FLAGS = flags.FLAGS
flags.DEFINE_string('model', None, 'Path to model (.tflite) flatbuffer file.')
flags.DEFINE_string('destination', None, 'Path of destination of generation.')
flags.DEFINE_string('package_name', 'org.tensorflow.lite.support',
'Name of generated java package to put the wrapper class.')
flags.DEFINE_string(
'model_class_name', 'MyModel',
'Name of generated wrapper class (should not contain package name).')
flags.DEFINE_string(
'model_asset_path', '',
'(Optional) Path to the model in generated assets/ dir. If not set, '
'generator will use base name of input model.'
)
def get_model_buffer(path):
if not os.path.isfile(path):
logging.error('Cannot find model at path %s.', path)
with open(path, 'rb') as f:
buf = f.read()
return buf
def prepare_directory_for_file(file_path):
target_dir = os.path.dirname(file_path)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
return
if not os.path.isdir(target_dir):
logging.error('Cannot write to %s', target_dir)
def main(argv):
if len(argv) > 1:
logging.error('None flag arguments found: [%s]', ', '.join(argv[1:]))
codegen = _pywrap_codegen.AndroidJavaGenerator(FLAGS.destination)
model_buffer = get_model_buffer(FLAGS.model)
model_asset_path = FLAGS.model_asset_path
if not model_asset_path:
model_asset_path = os.path.basename(FLAGS.model)
result = codegen.generate(model_buffer, FLAGS.package_name,
FLAGS.model_class_name, model_asset_path)
error_message = codegen.get_error_message().strip()
if error_message:
logging.error(error_message)
if not result.files:
logging.error('Generation failed!')
return
for each in result.files:
prepare_directory_for_file(each.path)
with open(each.path, 'w') as f:
f.write(each.content)
logging.info('Generation succeeded!')
model_asset_path = os.path.join(FLAGS.destination, 'src/main/assets',
model_asset_path)
prepare_directory_for_file(model_asset_path)
shutil.copy(FLAGS.model, model_asset_path)
logging.info('Model copied into assets!')
if __name__ == '__main__':
flags.mark_flag_as_required('model')
flags.mark_flag_as_required('destination')
app.run(main)

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "include/pybind11/detail/common.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/pytypes.h"
#include "include/pybind11/stl.h"
#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h"
#include "tensorflow/lite/experimental/support/codegen/code_generator.h"
namespace tflite {
namespace support {
namespace codegen {
template <typename... Args>
using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;
PYBIND11_MODULE(_pywrap_codegen, m) {
pybind11::class_<AndroidJavaGenerator>(m, "AndroidJavaGenerator")
.def(pybind11::init<const std::string &>())
.def("generate",
overload_cast_<const char *, const std::string &,
const std::string &, const std::string &>()(
&AndroidJavaGenerator::Generate))
.def("get_error_message", &AndroidJavaGenerator::GetErrorMessage);
pybind11::class_<GenerationResult>(m, "GenerationResult")
.def(pybind11::init<>())
.def_readwrite("files", &GenerationResult::files);
pybind11::class_<GenerationResult::File>(m, "GenerationResultFile")
.def(pybind11::init<>())
.def_readwrite("path", &GenerationResult::File::path)
.def_readwrite("content", &GenerationResult::File::content);
}
} // namespace codegen
} // namespace support
} // namespace tflite

View File

@ -0,0 +1,194 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include <cstdarg>
namespace tflite {
namespace support {
namespace codegen {
int ErrorReporter::Warning(const char* format, ...) {
va_list args;
va_start(args, format);
return Report("[WARN] ", format, args);
}
int ErrorReporter::Error(const char* format, ...) {
va_list args;
va_start(args, format);
return Report("[ERROR] ", format, args);
}
int ErrorReporter::Report(const char* prefix, const char* format,
va_list args) {
char buf[1024];
int formatted = vsnprintf(buf, sizeof(buf), format, args);
buffer_ << prefix << buf << std::endl;
return formatted;
}
std::string ErrorReporter::GetMessage() {
std::string value = buffer_.str();
buffer_.str("");
return value;
}
CodeWriter::CodeWriter(ErrorReporter* err) : indent_(0), err_(err) {}
void CodeWriter::SetTokenValue(const std::string& token,
const std::string& value) {
value_map_[token] = value;
}
const std::string CodeWriter::GetTokenValue(const std::string& token) const {
auto iter = value_map_.find(token);
if (iter == value_map_.end()) {
// Typically only Code Generator's call this function (or `Append`). It's
// their duty to make sure the token is valid, and requesting for an invalid
// token implicits flaws in the code generation logic.
err_->Error("Internal: Cannot find value with token '%s'", token.c_str());
return "";
}
return iter->second;
}
void CodeWriter::SetIndentString(const std::string& indent_str) {
indent_str_ = indent_str;
}
void CodeWriter::Indent() { indent_++; }
void CodeWriter::Outdent() { indent_--; }
std::string CodeWriter::GenerateIndent() const {
std::string res;
res.reserve(indent_str_.size() * indent_);
for (int i = 0; i < indent_; i++) {
res.append(indent_str_);
}
return res;
}
void CodeWriter::Append(const std::string& text) { AppendInternal(text, true); }
void CodeWriter::AppendNoNewLine(const std::string& text) {
AppendInternal(text, false);
}
void CodeWriter::AppendInternal(const std::string& text, bool newline) {
// Prefix indent
if ((buffer_.empty() // nothing in the buffer
|| buffer_.back() == '\n') // is on new line
// is writing on current line
&& (!text.empty() && text[0] != '\n' && text[0] != '\r')) {
buffer_.append(GenerateIndent());
}
// State machine variables
bool in_token = false;
int i = 0;
// Rough memory reserve
buffer_.reserve(buffer_.size() + text.size());
std::string token_buffer;
// A simple LL1 analysis
while (i < text.size()) {
char cur = text[i];
char cur_next = i == text.size() - 1 ? '\0' : text[i + 1]; // Set guardian
if (in_token == false) {
if (cur == '{' && cur_next == '{') { // Enter token
in_token = true;
i += 2;
} else if (cur == '\n') { // We need to apply global indent here
buffer_.push_back(cur);
if (cur_next != '\0' && cur_next != '\n' && cur_next != '\r') {
buffer_.append(GenerateIndent());
}
i += 1;
} else {
buffer_.push_back(cur);
i += 1;
}
} else {
if (cur == '}' && cur_next == '}') { // Close token
in_token = false;
const auto value = GetTokenValue(token_buffer);
buffer_.append(value);
token_buffer.clear();
i += 2;
} else {
token_buffer.push_back(cur);
i += 1;
}
}
}
if (!token_buffer.empty()) {
// Typically only Code Generator's call this function. It's
// their duty to make sure the code (or template) has valid syntax, and
// unclosed "{{...}}" implicits severe error in the template.
err_->Error("Internal: Invalid template: {{token}} is not closed.");
}
if (newline) {
buffer_.push_back('\n');
}
}
void CodeWriter::NewLine() { Append(""); }
void CodeWriter::Backspace(int n) {
buffer_.resize(buffer_.size() > n ? buffer_.size() - n : 0);
}
std::string CodeWriter::ToString() const { return buffer_; }
bool CodeWriter::IsStreamEmpty() const { return buffer_.empty(); }
void CodeWriter::Clear() {
buffer_.clear();
value_map_.clear();
indent_ = 0;
}
std::string SnakeCaseToCamelCase(const std::string& s) {
std::string t;
t.reserve(s.length());
size_t i = 0;
// Note: Use simple string += for simplicity.
bool cap = false;
while (i < s.size()) {
const char c = s[i++];
if (c == '_') {
cap = true;
} else if (cap) {
t += toupper(c);
cap = false;
} else {
t += c;
}
}
return t;
}
std::string JoinPath(const std::string& a, const std::string& b) {
if (a.empty()) return b;
std::string a_fixed = a;
if (!a_fixed.empty() && a_fixed.back() == '/') a_fixed.pop_back();
std::string b_fixed = b;
if (!b_fixed.empty() && b_fixed.front() == '/') b_fixed.erase(0, 1);
return a_fixed + "/" + b_fixed;
}
} // namespace codegen
} // namespace support
} // namespace tflite

View File

@ -0,0 +1,127 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_
#include <map>
#include <sstream>
#include <string>
namespace tflite {
namespace support {
namespace codegen {
/// Collects runtime error logs which could be showed later.
// TODO(b/150538286): Consider a better mechanism to simplify callsite code.
class ErrorReporter {
public:
int Warning(const char* format, ...);
int Error(const char* format, ...);
std::string GetMessage();
private:
int Report(const char* prefix, const char* format, va_list args);
std::stringstream buffer_;
};
/// Implements basic code generating with text templates.
///
/// It could accept code templates and concatenate them into complete codes. A
/// template could contain named values.
///
/// Example code:
/// CodeWriter code;
/// code.SetValue("NAME", "Foo");
/// code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }");
/// code.SetValue("NAME", "Bar");
/// code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }");
///
/// Output:
/// void Foo() { printf("%s", "Foo"); }
/// void Bar() { printf("%s", "Bar"); }
class CodeWriter {
public:
explicit CodeWriter(ErrorReporter* err);
/// Sets value to a token. When generating code with template, a string in a
/// pair of {{ and }} will be regarded as a token and replaced with the
/// corresponding value in code generation.
/// It rewrites if the token already has a value.
void SetTokenValue(const std::string& token, const std::string& value);
/// Gets the current value set on the given token.
const std::string GetTokenValue(const std::string& token) const;
/// Sets the unit indent string. For example, in Java it should be " ".
void SetIndentString(const std::string& indent);
/// Increases the indent by a unit (the string set in SetIndentString).
void Indent();
/// Decreases the indent by a unit (the string set in SetIndentString).
void Outdent();
/// Generates the indentation string.
std::string GenerateIndent() const;
/// Appends a piece of template codes to the stream. Every named value will be
/// replaced via the real value. A new line will always be appended at the
/// end.
void Append(const std::string& text);
/// Appends a piece of template codes to the stream. Same with `Append`, but a
/// new line will not be appended at the end.
void AppendNoNewLine(const std::string& text);
/// Appends a new line to the stream.
void NewLine();
/// Deletes the last N charaters in the stream. If the stream has less than N
/// characters, deletes all.
void Backspace(int n);
std::string ToString() const;
/// Checks if the internal string stream is empty. Note: This method has
// overhead.
bool IsStreamEmpty() const;
/// Clears all the internal string stream and value map.
void Clear();
private:
void AppendInternal(const std::string& text, bool newline);
std::string indent_str_;
int indent_;
std::map<std::string, std::string> value_map_;
std::string buffer_;
ErrorReporter* err_;
};
/// Converts foo_bar_name to fooBarName. It's callers duty to make sure given
/// string "s" is already in snake case; or unexpected behavior may occur.
std::string SnakeCaseToCamelCase(const std::string& s);
/// Joins 2 parts of file path into one, connected by unix path seperator '/'.
/// It's callers duty to ensure the two parts are valid.
std::string JoinPath(const std::string& a, const std::string& b);
} // namespace codegen
} // namespace support
} // namespace tflite
#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_

View File

@ -0,0 +1,97 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/support/codegen/utils.h"
#include <gtest/gtest.h>
namespace tflite {
namespace support {
namespace codegen {
namespace {
TEST(ErrorReporterTest, TestReportError) {
ErrorReporter err;
err.Error("some text");
EXPECT_EQ(err.GetMessage(), "[ERROR] some text\n");
EXPECT_EQ(err.GetMessage(), "");
}
TEST(CodeGeneratorTest, TestExample) {
ErrorReporter err;
CodeWriter writer(&err);
writer.SetTokenValue("NAME", "Foo");
const std::string text = R"(void {{NAME}}() { printf("%s", "{{NAME}}"); })";
writer.Append(text);
writer.SetTokenValue("NAME", "Bar");
writer.Append(text);
EXPECT_EQ(
"void Foo() { printf(\"%s\", \"Foo\"); }\n"
"void Bar() { printf(\"%s\", \"Bar\"); }\n",
writer.ToString());
}
TEST(CodeGeneratorTest, TestInexistentToken) {
ErrorReporter err;
CodeWriter writer(&err);
writer.SetTokenValue("NAME", "Foo");
const std::string text = R"(void {{name}}() {})";
writer.Append(text);
EXPECT_EQ(err.GetMessage(),
"[ERROR] Internal: Cannot find value with token 'name'\n");
}
TEST(CodeGeneratorTest, TestUnclosedToken) {
ErrorReporter err;
CodeWriter writer(&err);
writer.SetTokenValue("NAME", "Foo");
const std::string text = R"(void {{NAME}() {})";
writer.Append(text);
EXPECT_EQ(err.GetMessage(),
"[ERROR] Internal: Invalid template: {{token}} is not closed.\n");
}
TEST(CodeGeneratorTest, TestIndentControl) {
ErrorReporter err;
CodeWriter writer(&err);
writer.SetIndentString(" ");
writer.Indent();
writer.AppendNoNewLine("abcde"); // Will indent
EXPECT_EQ(" abcde", writer.ToString());
writer.Clear();
writer.Indent();
writer.AppendNoNewLine("abc\n\nde");
// The blank line will not indent
EXPECT_EQ(" abc\n\n de", writer.ToString());
writer.Clear();
writer.Indent();
writer.Append("abc");
writer.Outdent();
writer.AppendNoNewLine("def");
EXPECT_EQ(" abc\ndef", writer.ToString());
}
TEST(CaseConversionTest, TestSnakeToCamel) {
EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel"));
EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel_"));
EXPECT_EQ("ImACamel", SnakeCaseToCamelCase("_im_a_camel"));
EXPECT_EQ("", SnakeCaseToCamelCase("_"));
EXPECT_EQ("camel", SnakeCaseToCamelCase("camel"));
}
} // namespace
} // namespace codegen
} // namespace support
} // namespace tflite