From b57d910db53d8f91c6a611b00a184e55fcaee06a Mon Sep 17 00:00:00 2001 From: Xunkai Zhang Date: Fri, 20 Mar 2020 05:52:10 -0700 Subject: [PATCH] Opensource TFLite Support codegen. PiperOrigin-RevId: 302011153 Change-Id: Idb2f649dc48fdc449fac2d6e9009719d29afb2ad --- .../lite/experimental/support/codegen/BUILD | 87 ++ .../experimental/support/codegen/README.md | 13 + .../support/codegen/android_java_generator.cc | 978 ++++++++++++++++++ .../support/codegen/android_java_generator.h | 107 ++ .../support/codegen/code_generator.cc | 179 ++++ .../support/codegen/code_generator.h | 80 ++ .../support/codegen/code_generator_test.cc | 126 +++ .../support/codegen/metadata_helper.cc | 92 ++ .../support/codegen/metadata_helper.h | 51 + .../experimental/support/codegen/python/BUILD | 38 + .../support/codegen/python/codegen.py | 96 ++ .../support/codegen/python/codegen_lib.cc | 49 + .../experimental/support/codegen/utils.cc | 194 ++++ .../lite/experimental/support/codegen/utils.h | 127 +++ .../support/codegen/utils_test.cc | 97 ++ 15 files changed, 2314 insertions(+) create mode 100644 tensorflow/lite/experimental/support/codegen/BUILD create mode 100644 tensorflow/lite/experimental/support/codegen/README.md create mode 100644 tensorflow/lite/experimental/support/codegen/android_java_generator.cc create mode 100644 tensorflow/lite/experimental/support/codegen/android_java_generator.h create mode 100644 tensorflow/lite/experimental/support/codegen/code_generator.cc create mode 100644 tensorflow/lite/experimental/support/codegen/code_generator.h create mode 100644 tensorflow/lite/experimental/support/codegen/code_generator_test.cc create mode 100644 tensorflow/lite/experimental/support/codegen/metadata_helper.cc create mode 100644 tensorflow/lite/experimental/support/codegen/metadata_helper.h create mode 100644 tensorflow/lite/experimental/support/codegen/python/BUILD create mode 100644 tensorflow/lite/experimental/support/codegen/python/codegen.py create mode 100644 tensorflow/lite/experimental/support/codegen/python/codegen_lib.cc create mode 100644 tensorflow/lite/experimental/support/codegen/utils.cc create mode 100644 tensorflow/lite/experimental/support/codegen/utils.h create mode 100644 tensorflow/lite/experimental/support/codegen/utils_test.cc diff --git a/tensorflow/lite/experimental/support/codegen/BUILD b/tensorflow/lite/experimental/support/codegen/BUILD new file mode 100644 index 00000000000..96bb3e35952 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/BUILD @@ -0,0 +1,87 @@ +# The tools for generating wrapper classes for a TFLite model with metadata. + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "utils", + srcs = [ + "utils.cc", + ], + hdrs = [ + "utils.h", + ], + deps = [ + ], +) + +cc_library( + name = "code_generator", + srcs = [ + "code_generator.cc", + ], + hdrs = [ + "code_generator.h", + ], + deps = [ + ":utils", + "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc", + ], +) + +cc_library( + name = "metadata_helper", + srcs = [ + "metadata_helper.cc", + ], + hdrs = [ + "metadata_helper.h", + ], + deps = [ + ":utils", + "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc", + "//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_library( + name = "android_java_generator", + srcs = [ + "android_java_generator.cc", + ], + hdrs = [ + "android_java_generator.h", + ], + deps = [ + ":code_generator", + ":metadata_helper", + ":utils", + "//tensorflow/core/platform:logging", + "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc", + "//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_test( + name = "code_generator_test", + size = "small", + srcs = ["code_generator_test.cc"], + data = ["//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs"], + deps = [ + ":code_generator", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + deps = [ + ":utils", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/experimental/support/codegen/README.md b/tensorflow/lite/experimental/support/codegen/README.md new file mode 100644 index 00000000000..425dab37b04 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/README.md @@ -0,0 +1,13 @@ +# TensorFlow Lite Android Wrapper Code Generator + +For TensorFlow Lite model enhanced with [metadata](https://www.tensorflow.org/lite/convert/metadata.md), +developers can use the TensorFlow Lite Android wrapper code generator to create +platform specific wrapper code. The wrapper code removes the need to interact +directly with `ByteBuffer`. Instead, developers can interact with the TensorFlow +Lite model with typed objects such as `Bitmap` and `Rect`. + +The usefulness of the code generator depend on the completeness of the +TensorFlow Lite model's metadata entry. Refer to the `` section +under relevant fields in +[metadata_schema.fbs](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs), +to see how the codegen tool parses each field. diff --git a/tensorflow/lite/experimental/support/codegen/android_java_generator.cc b/tensorflow/lite/experimental/support/codegen/android_java_generator.cc new file mode 100644 index 00000000000..b16db570aaa --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/android_java_generator.cc @@ -0,0 +1,978 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h" + +#include + +#include +#include +#include +#include + +#include "tensorflow/lite/experimental/support/codegen/code_generator.h" +#include "tensorflow/lite/experimental/support/codegen/metadata_helper.h" +#include "tensorflow/lite/experimental/support/codegen/utils.h" +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +namespace { + +using details_android_java::ModelInfo; +using details_android_java::TensorInfo; + +// Helper class to organize the C++ code block as a generated code block. +// Using ctor and dtor to simulate an enter/exit schema like `with` in Python. +class AsBlock { + public: + AsBlock(CodeWriter* code_writer, const std::string& before, + bool trailing_blank_line = false) + : code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) { + code_writer_->AppendNoNewLine(before); + code_writer_->Append(" {"); + code_writer_->Indent(); + } + ~AsBlock() { + code_writer_->Outdent(); + code_writer_->Append("}"); + if (trailing_blank_line_) { + code_writer_->NewLine(); + } + } + + private: + CodeWriter* code_writer_; + bool trailing_blank_line_; +}; + +// Declare the functions first, so that the functions can follow a logical +// order. +bool GenerateWrapperClass(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperImports(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperInputs(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperOutputs(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperMetadata(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperAPI(CodeWriter*, const ModelInfo&, ErrorReporter*); + +std::string GetModelVersionedName(const ModelMetadata* metadata) { + std::string model_name = "MyModel"; + if (metadata->name() != nullptr && !(metadata->name()->str().empty())) { + model_name = metadata->name()->str(); + } + std::string model_version = "unknown"; + if (metadata->version() != nullptr && !(metadata->version()->str().empty())) { + model_version = metadata->version()->str(); + } + return model_name + " (Version: " + model_version + ")"; +} + +TensorInfo CreateTensorInfo(const TensorMetadata* metadata, + const std::string& name, bool is_input, int index, + ErrorReporter* err) { + TensorInfo tensor_info; + std::string tensor_identifier = is_input ? "input" : "output"; + tensor_identifier += " " + std::to_string(index); + tensor_info.associated_axis_label_index = FindAssociatedFile( + metadata, AssociatedFileType_TENSOR_AXIS_LABELS, tensor_identifier, err); + tensor_info.associated_value_label_index = FindAssociatedFile( + metadata, AssociatedFileType_TENSOR_VALUE_LABELS, tensor_identifier, err); + if (is_input && (tensor_info.associated_axis_label_index >= 0 || + tensor_info.associated_value_label_index >= 0)) { + err->Warning( + "Found label file on input tensor (%s). Label file for input " + "tensor is not supported yet. The " + "file will be ignored.", + tensor_identifier.c_str()); + } + if (tensor_info.associated_axis_label_index >= 0 && + tensor_info.associated_value_label_index >= 0) { + err->Warning( + "Found both axis label file and value label file for tensor (%s), " + "which is not supported. Only the axis label file will be used.", + tensor_identifier.c_str()); + } + tensor_info.is_input = is_input; + tensor_info.name = SnakeCaseToCamelCase(name); + tensor_info.upper_camel_name = tensor_info.name; + tensor_info.upper_camel_name[0] = toupper(tensor_info.upper_camel_name[0]); + tensor_info.normalization_unit = + FindNormalizationUnit(metadata, tensor_identifier, err); + if (metadata->content()->content_properties_type() == + ContentProperties_ImageProperties) { + if (metadata->content() + ->content_properties_as_ImageProperties() + ->color_space() == ColorSpaceType_RGB) { + tensor_info.content_type = "image"; + tensor_info.wrapper_type = "TensorImage"; + tensor_info.processor_type = "ImageProcessor"; + return tensor_info; + } else { + err->Warning( + "Found Non-RGB image on tensor (%s). Codegen currently does not " + "support it, and regard it as a plain numeric tensor.", + tensor_identifier.c_str()); + } + } + tensor_info.content_type = "tensor"; + tensor_info.wrapper_type = "TensorBuffer"; + tensor_info.processor_type = "TensorProcessor"; + return tensor_info; +} + +ModelInfo CreateModelInfo(const ModelMetadata* metadata, + const std::string& package_name, + const std::string& model_class_name, + const std::string& model_asset_path, + ErrorReporter* err) { + ModelInfo model_info; + if (!CodeGenerator::VerifyMetadata(metadata, err)) { + // TODO(b/150116380): Create dummy model info. + err->Error("Validating metadata failed."); + return model_info; + } + model_info.package_name = package_name; + model_info.model_class_name = model_class_name; + model_info.model_asset_path = model_asset_path; + model_info.model_versioned_name = GetModelVersionedName(metadata); + const auto* graph = metadata->subgraph_metadata()->Get(0); + auto names = CodeGenerator::NameInputsAndOutputs( + graph->input_tensor_metadata(), graph->output_tensor_metadata()); + std::vector input_tensor_names = std::move(names.first); + std::vector output_tensor_names = std::move(names.second); + for (int i = 0; i < graph->input_tensor_metadata()->size(); i++) { + model_info.inputs.push_back( + CreateTensorInfo(graph->input_tensor_metadata()->Get(i), + input_tensor_names[i], true, i, err)); + } + for (int i = 0; i < graph->output_tensor_metadata()->size(); i++) { + model_info.outputs.push_back( + CreateTensorInfo(graph->output_tensor_metadata()->Get(i), + output_tensor_names[i], false, i, err)); + } + return model_info; +} + +void SetCodeWriterWithTensorInfo(CodeWriter* code_writer, + const TensorInfo& tensor_info) { + code_writer->SetTokenValue("NAME", tensor_info.name); + code_writer->SetTokenValue("NAME_U", tensor_info.upper_camel_name); + code_writer->SetTokenValue("CONTENT_TYPE", tensor_info.content_type); + code_writer->SetTokenValue("WRAPPER_TYPE", tensor_info.wrapper_type); + std::string wrapper_name = tensor_info.wrapper_type; + wrapper_name[0] = tolower(wrapper_name[0]); + code_writer->SetTokenValue("WRAPPER_NAME", wrapper_name); + code_writer->SetTokenValue("PROCESSOR_TYPE", tensor_info.processor_type); + code_writer->SetTokenValue("NORMALIZATION_UNIT", + std::to_string(tensor_info.normalization_unit)); + code_writer->SetTokenValue( + "ASSOCIATED_AXIS_LABEL_INDEX", + std::to_string(tensor_info.associated_axis_label_index)); + code_writer->SetTokenValue( + "ASSOCIATED_VALUE_LABEL_INDEX", + std::to_string(tensor_info.associated_value_label_index)); +} + +void SetCodeWriterWithModelInfo(CodeWriter* code_writer, + const ModelInfo& model_info) { + code_writer->SetTokenValue("PACKAGE", model_info.package_name); + code_writer->SetTokenValue("MODEL_PATH", model_info.model_asset_path); + code_writer->SetTokenValue("MODEL_CLASS_NAME", model_info.model_class_name); +} + +constexpr char JAVA_DEFAULT_PACKAGE[] = "default"; + +std::string ConvertPackageToPath(const std::string& package) { + if (package == JAVA_DEFAULT_PACKAGE) { + return ""; + } + std::string path = package; + std::replace(path.begin(), path.end(), '.', '/'); + return path; +} + +bool IsImageUsed(const ModelInfo& model) { + for (const auto& input : model.inputs) { + if (input.content_type == "image") { + return true; + } + } + for (const auto& output : model.outputs) { + if (output.content_type == "image") { + return true; + } + } + return false; +} + +bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append("// Generated by TFLite Support."); + code_writer->Append("package {{PACKAGE}};"); + code_writer->NewLine(); + + if (!GenerateWrapperImports(code_writer, model, err)) { + err->Error("Fail to generate imports for wrapper class."); + return false; + } + if (!GenerateWrapperClass(code_writer, model, err)) { + err->Error("Fail to generate wrapper class."); + return false; + } + code_writer->NewLine(); + return true; +} + +bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + const std::string support_pkg = "org.tensorflow.lite.support."; + std::vector imports{ + "android.content.Context", + "java.io.IOException", + "java.nio.ByteBuffer", + "java.nio.FloatBuffer", + "java.util.Arrays", + "java.util.HashMap", + "java.util.List", + "java.util.Map", + "org.checkerframework.checker.nullness.qual.Nullable", + "org.tensorflow.lite.DataType", + "org.tensorflow.lite.Tensor.QuantizationParams", + support_pkg + "common.FileUtil", + support_pkg + "common.TensorProcessor", + support_pkg + "common.ops.CastOp", + support_pkg + "common.ops.DequantizeOp", + support_pkg + "common.ops.NormalizeOp", + support_pkg + "common.ops.QuantizeOp", + support_pkg + "label.TensorLabel", + support_pkg + "metadata.MetadataExtractor", + support_pkg + "metadata.schema.NormalizationOptions", + support_pkg + "model.Model", + support_pkg + "model.Model.Device", + support_pkg + "tensorbuffer.TensorBuffer", + }; + if (IsImageUsed(model)) { + for (const auto& target : + {"image.ImageProcessor", "image.TensorImage", "image.ops.ResizeOp", + "image.ops.ResizeOp.ResizeMethod"}) { + imports.push_back(support_pkg + target); + } + imports.push_back("android.graphics.Bitmap"); + } + + std::sort(imports.begin(), imports.end()); + for (const auto target : imports) { + code_writer->SetTokenValue("TARGET", target); + code_writer->Append("import {{TARGET}};"); + } + code_writer->NewLine(); + return true; +} + +bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->SetTokenValue("MODEL_VERSIONED_NAME", + model.model_versioned_name); + code_writer->Append( + R"(/** Wrapper class of model {{MODEL_VERSIONED_NAME}} */)"); + const auto code_block = + AsBlock(code_writer, "public class {{MODEL_CLASS_NAME}}"); + code_writer->Append(R"(private final Metadata metadata; +private final Model model; +private static final String MODEL_NAME = "{{MODEL_PATH}}";)"); + for (const auto& tensor : model.outputs) { + if (tensor.associated_axis_label_index >= 0) { + code_writer->SetTokenValue("NAME", tensor.name); + code_writer->Append("private final List {{NAME}}Labels;"); + } + } + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append( + "@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;"); + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append( + "@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;"); + } + code_writer->NewLine(); + if (!GenerateWrapperInputs(code_writer, model, err)) { + err->Error("Failed to generate input classes"); + return false; + } + code_writer->NewLine(); + if (!GenerateWrapperOutputs(code_writer, model, err)) { + err->Error("Failed to generate output classes"); + return false; + } + code_writer->NewLine(); + if (!GenerateWrapperMetadata(code_writer, model, err)) { + err->Error("Failed to generate the metadata class"); + return false; + } + code_writer->NewLine(); + if (!GenerateWrapperAPI(code_writer, model, err)) { + err->Error("Failed to generate the common APIs"); + return false; + } + return true; +} + +bool GenerateWrapperInputs(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append("/** Input wrapper of {@link {{MODEL_CLASS_NAME}}} */"); + auto class_block = AsBlock(code_writer, "public class Inputs"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append("private {{WRAPPER_TYPE}} {{NAME}};"); + } + code_writer->NewLine(); + // Ctor + { + auto ctor_block = AsBlock(code_writer, "public Inputs()"); + code_writer->Append( + "Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + if (tensor.content_type == "image") { + code_writer->Append( + "{{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());"); + } else { + code_writer->Append( + "{{NAME}} = " + "TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), " + "metadata.get{{NAME_U}}Type());"); + } + } + } + for (const auto& tensor : model.inputs) { + code_writer->NewLine(); + SetCodeWriterWithTensorInfo(code_writer, tensor); + // Loaders + if (tensor.content_type == "image") { + { + auto bitmap_loader_block = + AsBlock(code_writer, "public void load{{NAME_U}}(Bitmap bitmap)"); + code_writer->Append(R"({{NAME}}.load(bitmap); +{{NAME}} = preprocess{{NAME_U}}({{NAME}});)"); + } + code_writer->NewLine(); + { + auto tensor_image_loader_block = AsBlock( + code_writer, "public void load{{NAME_U}}(TensorImage tensorImage)"); + code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorImage);"); + } + } else { // content_type == "FEATURE" or "UNKNOWN" + auto tensorbuffer_loader_block = AsBlock( + code_writer, "public void load{{NAME_U}}(TensorBuffer tensorBuffer)"); + code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorBuffer);"); + } + code_writer->NewLine(); + // Processor + code_writer->Append( + R"(private {{WRAPPER_TYPE}} preprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}}) { + if ({{NAME}}Preprocessor == null) { + return {{WRAPPER_NAME}}; + } + return {{NAME}}Preprocessor.process({{WRAPPER_NAME}}); +} +)"); + } + { + const auto get_buffer_block = AsBlock(code_writer, "Object[] getBuffer()"); + code_writer->AppendNoNewLine("return new Object[] {"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->AppendNoNewLine("{{NAME}}.getBuffer(), "); + } + code_writer->Backspace(2); + code_writer->Append("};"); + } + return true; +} + +bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */"); + auto class_block = AsBlock(code_writer, "public class Outputs"); + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};"); + } + code_writer->NewLine(); + { + const auto ctor_block = AsBlock(code_writer, "public Outputs()"); + code_writer->Append( + "Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;"); + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + if (tensor.content_type == "image") { + code_writer->Append( + R"({{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type()); +{{NAME}}.load(TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), metadata.get{{NAME_U}}Type()));)"); + } else { // FEATURE, UNKNOWN + code_writer->Append( + "{{NAME}} = " + "TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), " + "metadata.get{{NAME_U}}Type());"); + } + } + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->NewLine(); + if (tensor.associated_axis_label_index >= 0) { + if (tensor.content_type == "image") { + err->Warning( + "Axis label for images is not supported. The labels will " + "be ignored."); + } else { + code_writer->Append(R"(public Map get{{NAME_U}}() { + return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getMapWithFloatValue(); +})"); + } + } else { + code_writer->Append(R"(public {{WRAPPER_TYPE}} get{{NAME_U}}() { + return postprocess{{NAME_U}}({{NAME}}); +})"); + } + code_writer->NewLine(); + { + auto processor_block = + AsBlock(code_writer, + "private {{WRAPPER_TYPE}} " + "postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})"); + code_writer->Append(R"(if ({{NAME}}Postprocessor == null) { + return {{WRAPPER_NAME}}; +} +return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});)"); + } + } + code_writer->NewLine(); + { + const auto get_buffer_block = + AsBlock(code_writer, "Map getBuffer()"); + code_writer->Append("Map outputs = new HashMap<>();"); + for (int i = 0; i < model.outputs.size(); i++) { + SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]); + code_writer->SetTokenValue("ID", std::to_string(i)); + code_writer->Append("outputs.put({{ID}}, {{NAME}}.getBuffer());"); + } + code_writer->Append("return outputs;"); + } + return true; +} + +bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append( + "/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */"); + const auto class_block = AsBlock(code_writer, "public static class Metadata"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"(private final int[] {{NAME}}Shape; +private final DataType {{NAME}}DataType; +private final QuantizationParams {{NAME}}QuantizationParams;)"); + if (tensor.normalization_unit >= 0) { + code_writer->Append(R"(private final float[] {{NAME}}Mean; +private final float[] {{NAME}}Stddev;)"); + } + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"(private final int[] {{NAME}}Shape; +private final DataType {{NAME}}DataType; +private final QuantizationParams {{NAME}}QuantizationParams;)"); + if (tensor.normalization_unit >= 0) { + code_writer->Append(R"(private final float[] {{NAME}}Mean; +private final float[] {{NAME}}Stddev;)"); + } + if (tensor.associated_axis_label_index >= 0 || + tensor.associated_value_label_index >= 0) { + code_writer->Append("private final List {{NAME}}Labels;"); + } + } + code_writer->NewLine(); + { + const auto ctor_block = AsBlock( + code_writer, + "public Metadata(ByteBuffer buffer, Model model) throws IOException"); + code_writer->Append( + "MetadataExtractor extractor = new MetadataExtractor(buffer);"); + for (int i = 0; i < model.inputs.size(); i++) { + SetCodeWriterWithTensorInfo(code_writer, model.inputs[i]); + code_writer->SetTokenValue("ID", std::to_string(i)); + code_writer->Append( + R"({{NAME}}Shape = extractor.getInputTensorShape({{ID}}); +{{NAME}}DataType = extractor.getInputTensorType({{ID}}); +{{NAME}}QuantizationParams = extractor.getInputTensorQuantizationParams({{ID}});)"); + if (model.inputs[i].normalization_unit >= 0) { + code_writer->Append( + R"(NormalizationOptions {{NAME}}NormalizationOptions = + (NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions()); +FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer(); +{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()]; +{{NAME}}MeanBuffer.get({{NAME}}Mean); +FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer(); +{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()]; +{{NAME}}StddevBuffer.get({{NAME}}Stddev);)"); + } + } + for (int i = 0; i < model.outputs.size(); i++) { + SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]); + code_writer->SetTokenValue("ID", std::to_string(i)); + code_writer->Append( + R"({{NAME}}Shape = model.getOutputTensorShape({{ID}}); +{{NAME}}DataType = extractor.getOutputTensorType({{ID}}); +{{NAME}}QuantizationParams = extractor.getOutputTensorQuantizationParams({{ID}});)"); + if (model.outputs[i].normalization_unit >= 0) { + code_writer->Append( + R"(NormalizationOptions {{NAME}}NormalizationOptions = + (NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions()); +FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer(); +{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()]; +{{NAME}}MeanBuffer.get({{NAME}}Mean); +FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer(); +{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()]; +{{NAME}}StddevBuffer.get({{NAME}}Stddev);)"); + } + if (model.outputs[i].associated_axis_label_index >= 0) { + code_writer->Append(R"(String {{NAME}}LabelsFileName = + extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_AXIS_LABEL_INDEX}}).name(); +{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)"); + } else if (model.outputs[i].associated_value_label_index >= 0) { + code_writer->Append(R"(String {{NAME}}LabelsFileName = + extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_VALUE_LABEL_INDEX}}).name(); +{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)"); + } + } + } + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( +public int[] get{{NAME_U}}Shape() { + return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length); +} + +public DataType get{{NAME_U}}Type() { + return {{NAME}}DataType; +} + +public QuantizationParams get{{NAME_U}}QuantizationParams() { + return {{NAME}}QuantizationParams; +})"); + if (tensor.normalization_unit >= 0) { + code_writer->Append(R"( +public float[] get{{NAME_U}}Mean() { + return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length); +} + +public float[] get{{NAME_U}}Stddev() { + return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length); +})"); + } + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( +public int[] get{{NAME_U}}Shape() { + return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length); +} + +public DataType get{{NAME_U}}Type() { + return {{NAME}}DataType; +} + +public QuantizationParams get{{NAME_U}}QuantizationParams() { + return {{NAME}}QuantizationParams; +})"); + if (tensor.normalization_unit >= 0) { + code_writer->Append(R"( +public float[] get{{NAME_U}}Mean() { + return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length); +} + +public float[] get{{NAME_U}}Stddev() { + return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length); +})"); + } + if (tensor.associated_axis_label_index >= 0 || + tensor.associated_value_label_index >= 0) { + code_writer->Append(R"( +public List get{{NAME_U}}Labels() { + return {{NAME}}Labels; +})"); + } + } + return true; +} + +bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append(R"(public Metadata getMetadata() { + return metadata; +} +)"); + code_writer->Append(R"(/** + * Creates interpreter and loads associated files if needed. + * + * @throws IOException if an I/O error occurs when loading the tflite model. + */ +public {{MODEL_CLASS_NAME}}(Context context) throws IOException { + this(context, MODEL_NAME, Device.CPU, 1); +} + +/** + * Creates interpreter and loads associated files if needed, but loading another model in the same + * input / output structure with the original one. + * + * @throws IOException if an I/O error occurs when loading the tflite model. + */ +public {{MODEL_CLASS_NAME}}(Context context, String modelPath) throws IOException { + this(context, modelPath, Device.CPU, 1); +} + +/** + * Creates interpreter and loads associated files if needed, with device and number of threads + * configured. + * + * @throws IOException if an I/O error occurs when loading the tflite model. + */ +public {{MODEL_CLASS_NAME}}(Context context, Device device, int numThreads) throws IOException { + this(context, MODEL_NAME, device, numThreads); +} + +/** + * Creates interpreter for a user-specified model. + * + * @throws IOException if an I/O error occurs when loading the tflite model. + */ +public {{MODEL_CLASS_NAME}}(Context context, String modelPath, Device device, int numThreads) throws IOException { + model = new Model.Builder(context, modelPath).setDevice(device).setNumThreads(numThreads).build(); + metadata = new Metadata(model.getData(), model);)"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( + {{PROCESSOR_TYPE}}.Builder {{NAME}}PreprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder())"); + if (tensor.content_type == "image") { + code_writer->Append(R"( .add(new ResizeOp( + metadata.get{{NAME_U}}Shape()[1], + metadata.get{{NAME_U}}Shape()[2], + ResizeMethod.NEAREST_NEIGHBOR)))"); + } + if (tensor.normalization_unit >= 0) { + code_writer->Append( + R"( .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))"); + } + code_writer->Append( + R"( .add(new QuantizeOp( + metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(), + metadata.get{{NAME_U}}QuantizationParams().getScale())) + .add(new CastOp(metadata.get{{NAME_U}}Type())); + {{NAME}}Preprocessor = {{NAME}}PreprocessorBuilder.build();)"); + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->AppendNoNewLine(R"( + {{PROCESSOR_TYPE}}.Builder {{NAME}}PostprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder() + .add(new DequantizeOp( + metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(), + metadata.get{{NAME_U}}QuantizationParams().getScale())))"); + if (tensor.normalization_unit >= 0) { + code_writer->AppendNoNewLine(R"( + .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))"); + } + code_writer->Append(R"(; + {{NAME}}Postprocessor = {{NAME}}PostprocessorBuilder.build();)"); + if (tensor.associated_axis_label_index >= 0) { + code_writer->Append(R"( + {{NAME}}Labels = metadata.get{{NAME_U}}Labels();)"); + } + } + code_writer->Append("}"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( +public void reset{{NAME_U}}Preprocessor(@Nullable {{PROCESSOR_TYPE}} processor) { + {{NAME}}Preprocessor = processor; +})"); + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( +public void reset{{NAME_U}}Postprocessor(@Nullable {{PROCESSOR_TYPE}} processor) { + {{NAME}}Postprocessor = processor; +})"); + } + code_writer->Append(R"( +/** Creates inputs */ +public Inputs createInputs() { + return new Inputs(); +} + +/** Triggers the model. */ +public Outputs run(Inputs inputs) { + Outputs outputs = new Outputs(); + model.run(inputs.getBuffer(), outputs.getBuffer()); + return outputs; +} + +/** Closes the model. */ +public void close() { + model.close(); +})"); + return true; +} + +bool GenerateBuildGradleContent(CodeWriter* code_writer, + const ModelInfo& model_info) { + code_writer->Append(R"(buildscript { + repositories { + google() + jcenter() + } + dependencies { + classpath 'com.android.tools.build:gradle:3.2.1' + } +} + +allprojects { + repositories { + google() + jcenter() + flatDir { + dirs 'libs' + } + } +} + +apply plugin: 'com.android.library' + +android { + compileSdkVersion 29 + defaultConfig { + targetSdkVersion 29 + versionCode 1 + versionName "1.0" + } + aaptOptions { + noCompress "tflite" + } + compileOptions { + sourceCompatibility = '1.8' + targetCompatibility = '1.8' + } + lintOptions { + abortOnError false + } +} + +configurations { + libMetadata +} + +dependencies { + libMetadata 'org.tensorflow:tensorflow-lite-support:0.0.0-experimental-metadata-monolithic' +} + +task downloadLibs(type: Sync) { + from configurations.libMetadata + into "$buildDir/libs" + rename 'tensorflow-lite-support-0.0.0-experimental-metadata-monolithic.jar', "tensorflow-lite-support-metadata.jar" +} + +preBuild.dependsOn downloadLibs + +dependencies { + compileOnly 'org.checkerframework:checker-qual:2.5.8' + api 'org.tensorflow:tensorflow-lite:0.0.0-nightly' + api 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly' + api files("$buildDir/libs/tensorflow-lite-support-metadata.jar") + implementation 'org.apache.commons:commons-compress:1.19' +})"); + return true; +} + +bool GenerateAndroidManifestContent(CodeWriter* code_writer, + const ModelInfo& model_info) { + code_writer->Append(R"( + +)"); + return true; +} + +bool GenerateDocContent(CodeWriter* code_writer, const ModelInfo& model_info) { + code_writer->Append("# {{MODEL_CLASS_NAME}} Usage"); + code_writer->AppendNoNewLine(R"( +``` +import {{PACKAGE}}.{{MODEL_CLASS_NAME}}; + +// 1. Initialize the Model +{{MODEL_CLASS_NAME}} model = null; + +try { + model = new {{MODEL_CLASS_NAME}}(context); // android.content.Context + // Create the input container. + {{MODEL_CLASS_NAME}}.Inputs inputs = model.createInputs(); +} catch (IOException e) { + e.printStackTrace(); +} + +if (model != null) { + + // 2. Set the inputs)"); + for (const auto& t : model_info.inputs) { + SetCodeWriterWithTensorInfo(code_writer, t); + if (t.content_type == "image") { + code_writer->Append(R"( + // Load input tensor "{{NAME}}" from a Bitmap with ARGB_8888 format. + Bitmap bitmap = ...; + inputs.load{{NAME_U}}(bitmap); + // Alternatively, load the input tensor "{{NAME}}" from a TensorImage. + // Check out TensorImage documentation to load other image data structures. + // TensorImage tensorImage = ...; + // inputs.load{{NAME_U}}(tensorImage);)"); + } else { + code_writer->Append(R"( + // Load input tensor "{{NAME}}" from a TensorBuffer. + // Check out TensorBuffer documentation to load other data structures. + TensorBuffer tensorBuffer = ...; + inputs.load{{NAME_U}}(tensorBuffer);)"); + } + } + code_writer->Append(R"( + // 3. Run the model + {{MODEL_CLASS_NAME}}.Outputs outputs = model.run(inputs);)"); + code_writer->Append(R"( + // 4. Retrieve the results)"); + for (const auto& t : model_info.outputs) { + SetCodeWriterWithTensorInfo(code_writer, t); + if (t.associated_axis_label_index >= 0) { + code_writer->SetTokenValue("WRAPPER_TYPE", "Map"); + } + code_writer->Append( + R"( {{WRAPPER_TYPE}} {{NAME}} = outputs.get{{NAME_U}}();)"); + } + code_writer->Append(R"(} +```)"); + return true; +} + +GenerationResult::File GenerateWrapperFile(const std::string& module_root, + const ModelInfo& model_info, + ErrorReporter* err) { + const auto java_path = JoinPath(module_root, "src/main/java"); + const auto package_path = + JoinPath(java_path, ConvertPackageToPath(model_info.package_name)); + const auto file_path = + JoinPath(package_path, model_info.model_class_name + JAVA_EXT); + + CodeWriter code_writer(err); + code_writer.SetIndentString(" "); + SetCodeWriterWithModelInfo(&code_writer, model_info); + + if (!GenerateWrapperFileContent(&code_writer, model_info, err)) { + err->Error("Generating Java wrapper content failed."); + } + + const auto java_file = code_writer.ToString(); + return GenerationResult::File{file_path, java_file}; +} + +GenerationResult::File GenerateBuildGradle(const std::string& module_root, + const ModelInfo& model_info, + ErrorReporter* err) { + const auto file_path = JoinPath(module_root, "build.gradle"); + CodeWriter code_writer(err); + SetCodeWriterWithModelInfo(&code_writer, model_info); + if (!GenerateBuildGradleContent(&code_writer, model_info)) { + err->Error("Generating build.gradle failed."); + } + const auto content = code_writer.ToString(); + return GenerationResult::File{file_path, content}; +} + +GenerationResult::File GenerateAndroidManifest(const std::string& module_root, + const ModelInfo& model_info, + ErrorReporter* err) { + const auto file_path = JoinPath(module_root, "src/main/AndroidManifest.xml"); + CodeWriter code_writer(err); + SetCodeWriterWithModelInfo(&code_writer, model_info); + if (!GenerateAndroidManifestContent(&code_writer, model_info)) { + err->Error("Generating AndroidManifest.xml failed."); + } + return GenerationResult::File{file_path, code_writer.ToString()}; +} + +GenerationResult::File GenerateDoc(const std::string& module_root, + const ModelInfo& model_info, + ErrorReporter* err) { + std::string lower = model_info.model_class_name; + for (int i = 0; i < lower.length(); i++) { + lower[i] = std::tolower(lower[i]); + } + const auto file_path = JoinPath(module_root, lower + ".md"); + CodeWriter code_writer(err); + SetCodeWriterWithModelInfo(&code_writer, model_info); + if (!GenerateDocContent(&code_writer, model_info)) { + err->Error("Generating doc failed."); + } + return GenerationResult::File{file_path, code_writer.ToString()}; +} + +} // namespace + +AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root) + : CodeGenerator(), module_root_(module_root) {} + +GenerationResult AndroidJavaGenerator::Generate( + const Model* model, const std::string& package_name, + const std::string& model_class_name, const std::string& model_asset_path) { + GenerationResult result; + const ModelMetadata* metadata = GetMetadataFromModel(model); + if (metadata == nullptr) { + err_.Error( + "Cannot find TFLite Metadata in the model. Codegen will generate " + "nothing."); + return result; + } + details_android_java::ModelInfo model_info = CreateModelInfo( + metadata, package_name, model_class_name, model_asset_path, &err_); + result.files.push_back(GenerateWrapperFile(module_root_, model_info, &err_)); + result.files.push_back(GenerateBuildGradle(module_root_, model_info, &err_)); + result.files.push_back( + GenerateAndroidManifest(module_root_, model_info, &err_)); + result.files.push_back(GenerateDoc(module_root_, model_info, &err_)); + return result; +} + +GenerationResult AndroidJavaGenerator::Generate( + const char* model_storage, const std::string& package_name, + const std::string& model_class_name, const std::string& model_asset_path) { + const Model* model = GetModel(model_storage); + return Generate(model, package_name, model_class_name, model_asset_path); +} + +std::string AndroidJavaGenerator::GetErrorMessage() { + return err_.GetMessage(); +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/codegen/android_java_generator.h b/tensorflow/lite/experimental/support/codegen/android_java_generator.h new file mode 100644 index 00000000000..f8821a0de70 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/android_java_generator.h @@ -0,0 +1,107 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_ + +#include +#include +#include + +#include "tensorflow/lite/experimental/support/codegen/code_generator.h" +#include "tensorflow/lite/experimental/support/codegen/utils.h" +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +namespace details_android_java { + +/// The intermediate data structure for generating code from TensorMetadata. +/// Should only be used as const reference when created. +struct TensorInfo { + std::string name; + std::string upper_camel_name; + std::string content_type; + std::string wrapper_type; + std::string processor_type; + bool is_input; + /// Optional. Set to -1 if not applicable. + int normalization_unit; + /// Optional. Set to -1 if associated_axis_label is empty. + int associated_axis_label_index; + /// Optional. Set to -1 if associated_value_label is empty. + int associated_value_label_index; +}; + +/// The intermediate data structure for generating code from ModelMetadata. +/// Should only be used as const reference when created. +struct ModelInfo { + std::string package_name; + std::string model_asset_path; + std::string model_class_name; + std::string model_versioned_name; + std::vector inputs; + std::vector outputs; +}; + +} // namespace details_android_java + +constexpr char JAVA_EXT[] = ".java"; + +/// Generates Android supporting codes and modules (in Java) based on TFLite +/// metadata. +class AndroidJavaGenerator : public CodeGenerator { + public: + /// Creates an AndroidJavaGenerator. + /// Args: + /// - module_root: The root of destination Java module. + explicit AndroidJavaGenerator(const std::string& module_root); + + /// Generates files. Returns the file paths and contents. + /// Args: + /// - model: The TFLite model with Metadata filled. + /// - package_name: The name of the Java package which generated classes + /// belong to. + /// - model_class_name: A readable name of the generated wrapper class, such + /// as "ImageClassifier", "MobileNetV2" or "MyModel". + /// - model_asset_path: The relevant path to the model file in the asset. + // TODO(b/141225157): Automatically generate model_class_name. + GenerationResult Generate(const Model* model, const std::string& package_name, + const std::string& model_class_name, + const std::string& model_asset_path); + + /// Generates files and returns the file paths and contents. + /// It's mostly identical with the previous one, but the model here is + /// provided as binary flatbuffer content without parsing. + GenerationResult Generate(const char* model_storage, + const std::string& package_name, + const std::string& model_class_name, + const std::string& model_asset_path); + + std::string GetErrorMessage(); + + private: + const std::string module_root_; + ErrorReporter err_; +}; + +} // namespace codegen +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_ diff --git a/tensorflow/lite/experimental/support/codegen/code_generator.cc b/tensorflow/lite/experimental/support/codegen/code_generator.cc new file mode 100644 index 00000000000..687724815ef --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/code_generator.cc @@ -0,0 +1,179 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/support/codegen/code_generator.h" + +#include +#include +#include +#include + +#include "tensorflow/lite/experimental/support/codegen/utils.h" +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +namespace { + +void ResolveConflictedNamesByAddingIndex(std::vector* names_ptr) { + auto& names = *names_ptr; + std::unordered_map indexes; + std::unordered_map first_appearance; + for (int i = 0; i < names.size(); i++) { + if (indexes.find(names[i]) == indexes.end()) { + indexes[names[i]] = 1; + first_appearance[names[i]] = i; + } else { + indexes[names[i]] += 1; + names[i].append(std::to_string(indexes[names[i]])); + } + } + for (const auto it : first_appearance) { + const auto& name = it.first; + const auto i = it.second; + if (indexes[name] > 1) { + names[i].append("1"); + } + } +} + +} // namespace + +CodeGenerator::CodeGenerator() {} + +bool CodeGenerator::VerifyMetadata(const ModelMetadata* metadata, + ErrorReporter* err) { + if (metadata == nullptr) { + err->Error("Loading nullptr is not allowed"); + return false; + } + if (metadata->subgraph_metadata()->size() != 1) { + err->Error("Only exact 1 subgraph is supported"); + return false; + } + return true; +} + +std::pair, std::vector> +CodeGenerator::NameInputsAndOutputs(const TensorMetadataList* inputs, + const TensorMetadataList* outputs) { + std::vector input_names; + std::vector output_names; + if (inputs != nullptr) { + input_names.reserve(inputs->size()); + for (const auto* tensor : *inputs) { + input_names.push_back(NameTensor(*tensor, "input")); + } + } + if (outputs != nullptr) { + output_names.reserve(outputs->size()); + for (const auto* tensor : *outputs) { + output_names.push_back(NameTensor(*tensor, "output")); + } + } + // Solve conflict + ResolveConflictedInputAndOutputNames(&input_names, &output_names); + return std::make_pair(input_names, output_names); +} + +std::string CodeGenerator::ConvertToValidName(const std::string& name) { + // lowercase all + std::string result = name; + for (int i = 0; i < result.size(); i++) { + result[i] = std::tolower(result[i]); + } + // replace all non-alpha or non-numeric with underscores, except underscore + // itself + for (int i = 0; i < result.size(); i++) { + if (result[i] != '_' && !std::isalnum(result[i])) { + result[i] = '_'; + } + } + // remove leading underscores + int leading_underscores = 0; + while (leading_underscores < result.size() && + result[leading_underscores] == '_') { + leading_underscores++; + } + result.erase(0, leading_underscores); + if (result.empty()) { + return ""; + } + // first char should be alpha + if (std::isalpha(result[0])) { + return result; + } + return "tensor_" + result; +} + +std::string CodeGenerator::NameTensor(const TensorMetadata& tensor, + const std::string& default_name) { + if (tensor.name() != nullptr && tensor.name()->size() > 0) { + // TODO(b/141225157) Validate tensor name. It should be in lower case. + auto suggested_name = ConvertToValidName(tensor.name()->str()); + if (!suggested_name.empty()) { + return suggested_name; + } + } + auto* content = tensor.content(); + if (content == nullptr || content->content_properties() == nullptr) { + return default_name; + } + switch (content->content_properties_type()) { + case ContentProperties_ImageProperties: + return "image"; + case ContentProperties_FeatureProperties: + return "feature"; + default: + return default_name; + } +} + +void CodeGenerator::ResolveConflictedInputAndOutputNames( + std::vector* inputs, std::vector* outputs) { + std::unordered_set io_conflict; + auto& input_names = *inputs; + auto& output_names = *outputs; + for (const auto input : input_names) { + if (io_conflict.find(input) != io_conflict.end()) { + continue; + } + for (const auto output : output_names) { + if (input == output) { + io_conflict.insert(input); + break; + } + } + } + for (int i = 0; i < input_names.size(); i++) { + if (io_conflict.find(input_names[i]) != io_conflict.end()) { + input_names[i] = "input_" + input_names[i]; + } + } + for (int i = 0; i < output_names.size(); i++) { + if (io_conflict.find(output_names[i]) != io_conflict.end()) { + output_names[i] = "output_" + output_names[i]; + } + } + // 2. Second, add index if input[i] == input[j] + ResolveConflictedNamesByAddingIndex(&input_names); + ResolveConflictedNamesByAddingIndex(&output_names); +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/codegen/code_generator.h b/tensorflow/lite/experimental/support/codegen/code_generator.h new file mode 100644 index 00000000000..5bb151e50a0 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/code_generator.h @@ -0,0 +1,80 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_ + +#include +#include +#include +#include + +#include "tensorflow/lite/experimental/support/codegen/utils.h" +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +struct GenerationResult { + struct File { + std::string path; + std::string content; + }; + std::vector files; +}; + +/// Defines language-independent codegen strategies, like class naming, .etc. +/// Should not be used directly. +class CodeGenerator { + public: + CodeGenerator(); + + using TensorMetadataList = + typename flatbuffers::Vector>; + + virtual ~CodeGenerator() {} + + // Strategies. + /// Names all the IO tensors. It's useful when they don't have names, or the + /// names have conflicts. We have to name every tensor for code generation. + // TODO(b/141225157): Add reserved keywords check. + static std::pair, std::vector> + NameInputsAndOutputs(const TensorMetadataList* inputs, + const TensorMetadataList* outputs); + + /// Loads a metadata for code generation. + /// Returns false if the metadata is not good for generation. + static bool VerifyMetadata(const ModelMetadata* metadata, ErrorReporter* err); + + protected: + /// Converts a name into a valid form. Rules: + /// - lower all letters. + /// - replace all non alphabet nor numeric characters with underscores. + /// - remove prefix underscores. + /// - add prefix if the leading character is a number. + /// Returns empty string if not possible. + static std::string ConvertToValidName(const std::string& name); + static std::string NameTensor(const TensorMetadata& tensor, + const std::string& default_name); + static void ResolveConflictedInputAndOutputNames( + std::vector* input, std::vector* output); +}; + +} // namespace codegen +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_CODE_GENERATOR_H_ diff --git a/tensorflow/lite/experimental/support/codegen/code_generator_test.cc b/tensorflow/lite/experimental/support/codegen/code_generator_test.cc new file mode 100644 index 00000000000..57c5cec60e4 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/code_generator_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/support/codegen/code_generator.h" + +#include +#include + +namespace tflite { +namespace support { +namespace codegen { +namespace { + +using ::testing::ElementsAreArray; + +class CodeGeneratorTest : public ::testing::Test { + public: + class TestingCodeGenerator : public CodeGenerator { + public: + explicit TestingCodeGenerator() : CodeGenerator() {} + + // Make tested method public. + static std::string ConvertToValidName(const std::string& name) { + return CodeGenerator::ConvertToValidName(name); + } + static void ResolveConflictedInputAndOutputNames( + std::vector* input, std::vector* output) { + CodeGenerator::ResolveConflictedInputAndOutputNames(input, output); + } + }; +}; + +TEST_F(CodeGeneratorTest, UpperCasesShouldLower) { + EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("AlphaBetCOOL"), + "alphabetcool"); +} + +TEST_F(CodeGeneratorTest, NonAlphaNumShouldReplace) { + EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("A+=B C\t"), "a__b_c_"); +} + +TEST_F(CodeGeneratorTest, NoLeadingUnderscore) { + EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("+KAI Z"), "kai_z"); +} + +TEST_F(CodeGeneratorTest, NoLeadingNumbers) { + EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("3000 Cool Tensors"), + "tensor_3000_cool_tensors"); +} + +TEST_F(CodeGeneratorTest, TestSimpleIONames) { + std::vector inputs = {"image"}; + std::vector outputs = {"output"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, ElementsAreArray({"image"})); + EXPECT_THAT(outputs, ElementsAreArray({"output"})); +} + +TEST_F(CodeGeneratorTest, TestIOConflict) { + std::vector inputs = {"image"}; + std::vector outputs = {"image"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, ElementsAreArray({"input_image"})); + EXPECT_THAT(outputs, ElementsAreArray({"output_image"})); +} + +TEST_F(CodeGeneratorTest, TestInternalConflict) { + std::vector inputs = {"image", "image"}; + std::vector outputs = {"output"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, ElementsAreArray({"image1", "image2"})); + EXPECT_THAT(outputs, ElementsAreArray({"output"})); +} + +TEST_F(CodeGeneratorTest, TestAllConflictNTo1) { + std::vector inputs = {"image", "image"}; + std::vector outputs = {"image"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, ElementsAreArray({"input_image1", "input_image2"})); + EXPECT_THAT(outputs, ElementsAreArray({"output_image"})); +} + +TEST_F(CodeGeneratorTest, TestAllConflict) { + std::vector inputs = {"image", "audio", "image", "audio", + "audio"}; + std::vector outputs = {"image", "image", "audio", "feature", + "feature"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, + ElementsAreArray({"input_image1", "input_audio1", "input_image2", + "input_audio2", "input_audio3"})); + EXPECT_THAT(outputs, + ElementsAreArray({"output_image1", "output_image2", + "output_audio", "feature1", "feature2"})); +} + +TEST_F(CodeGeneratorTest, TestAllConflictReversed) { + std::vector inputs = {"image", "image", "audio", "feature", + "feature"}; + std::vector outputs = {"image", "audio", "image", "audio", + "audio"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, + ElementsAreArray({"input_image1", "input_image2", "input_audio", + "feature1", "feature2"})); + EXPECT_THAT(outputs, ElementsAreArray({"output_image1", "output_audio1", + "output_image2", "output_audio2", + "output_audio3"})); +} + +} // namespace +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/codegen/metadata_helper.cc b/tensorflow/lite/experimental/support/codegen/metadata_helper.cc new file mode 100644 index 00000000000..3fcc7aee3bf --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/metadata_helper.cc @@ -0,0 +1,92 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/support/codegen/metadata_helper.h" + +#include "tensorflow/lite/experimental/support/codegen/utils.h" +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +constexpr char BUFFER_KEY[] = "TFLITE_METADATA"; +const ModelMetadata* GetMetadataFromModel(const Model* model) { + if (model->metadata() == nullptr) { + return nullptr; + } + for (auto i = 0; i < model->metadata()->size(); i++) { + if (model->metadata()->Get(i)->name()->str() == BUFFER_KEY) { + const auto buffer_index = model->metadata()->Get(i)->buffer(); + const auto* buffer = model->buffers()->Get(buffer_index)->data()->data(); + return GetModelMetadata(buffer); + } + } + return nullptr; +} + +int FindAssociatedFile(const TensorMetadata* metadata, + const AssociatedFileType file_type, + const std::string& tensor_identifier, + ErrorReporter* err) { + int result = -1; + if (metadata->associated_files() == nullptr || + metadata->associated_files()->size() == 0) { + return result; + } + for (int i = 0; i < metadata->associated_files()->size(); i++) { + const auto* file_metadata = metadata->associated_files()->Get(i); + if (file_metadata->type() == file_type) { + if (result >= 0) { + err->Warning( + "Multiple associated file of type %d found on tensor %s. Only the " + "first one will be used.", + file_type, tensor_identifier.c_str()); + continue; + } + result = i; + } + } + return result; +} + +int FindNormalizationUnit(const TensorMetadata* metadata, + const std::string& tensor_identifier, + ErrorReporter* err) { + int result = -1; + if (metadata->process_units() == nullptr || + metadata->process_units()->size() == 0) { + return result; + } + for (int i = 0; i < metadata->process_units()->size(); i++) { + const auto* process_uint = metadata->process_units()->Get(i); + if (process_uint->options_type() == + ProcessUnitOptions_NormalizationOptions) { + if (result >= 0) { + err->Warning( + "Multiple normalization unit found in tensor %s. Only the first " + "one will be effective.", + tensor_identifier.c_str()); + continue; + } + result = i; + } + } + return result; +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/codegen/metadata_helper.h b/tensorflow/lite/experimental/support/codegen/metadata_helper.h new file mode 100644 index 00000000000..0d5e06b4506 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/metadata_helper.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_ + +#include + +#include "tensorflow/lite/experimental/support/codegen/utils.h" +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +/// Parses a ModelMetadata out from a Model. The returned ModelMetadata's +/// lifetime is scoped by the model. +/// Returns nullptr if we cannot find any metadata. +const ModelMetadata* GetMetadataFromModel(const Model* model); + +/// Finds an associated file from a TensorMetadata of certain type. If there're +/// multiple files meet the criteria, only the first one is used. If there's no +/// file meets the criteria, -1 will be returned. +int FindAssociatedFile(const TensorMetadata* metadata, + const AssociatedFileType file_type, + const std::string& tensor_identifier, + ErrorReporter* err); + +/// Find the first normalization unit. If none, return -1. +int FindNormalizationUnit(const TensorMetadata* metadata, + const std::string& tensor_identifier, + ErrorReporter* err); + +} // namespace codegen +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_METADATA_HELPER_H_ diff --git a/tensorflow/lite/experimental/support/codegen/python/BUILD b/tensorflow/lite/experimental/support/codegen/python/BUILD new file mode 100644 index 00000000000..d364d82eaeb --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/python/BUILD @@ -0,0 +1,38 @@ +load("//tensorflow:tensorflow.bzl", "pybind_extension") + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +pybind_extension( + name = "_pywrap_codegen", + srcs = [ + "codegen_lib.cc", + ], + features = ["-use_header_modules"], + module_name = "_pywrap_codegen", + deps = [ + "//tensorflow/lite/experimental/support/codegen:android_java_generator", + "//tensorflow/lite/experimental/support/codegen:code_generator", + "//tensorflow/python:pybind11_lib", + "//third_party/python_runtime:headers", + "@pybind11", + ], +) + +py_binary( + name = "codegen", + srcs = [ + "codegen.py", + ], + python_version = "PY3", + deps = [ + ":_pywrap_codegen", + "@absl_py//absl:app", + "@absl_py//absl/flags", + "@absl_py//absl/logging", + ], +) diff --git a/tensorflow/lite/experimental/support/codegen/python/codegen.py b/tensorflow/lite/experimental/support/codegen/python/codegen.py new file mode 100644 index 00000000000..f28bafe5cff --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/python/codegen.py @@ -0,0 +1,96 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generates Android Java sources from a TFLite model with metadata.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +from absl import app +from absl import flags +from absl import logging + +from tensorflow.lite.experimental.support.codegen.python import _pywrap_codegen + +FLAGS = flags.FLAGS + +flags.DEFINE_string('model', None, 'Path to model (.tflite) flatbuffer file.') +flags.DEFINE_string('destination', None, 'Path of destination of generation.') +flags.DEFINE_string('package_name', 'org.tensorflow.lite.support', + 'Name of generated java package to put the wrapper class.') +flags.DEFINE_string( + 'model_class_name', 'MyModel', + 'Name of generated wrapper class (should not contain package name).') +flags.DEFINE_string( + 'model_asset_path', '', + '(Optional) Path to the model in generated assets/ dir. If not set, ' + 'generator will use base name of input model.' +) + + +def get_model_buffer(path): + if not os.path.isfile(path): + logging.error('Cannot find model at path %s.', path) + with open(path, 'rb') as f: + buf = f.read() + return buf + + +def prepare_directory_for_file(file_path): + target_dir = os.path.dirname(file_path) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + return + if not os.path.isdir(target_dir): + logging.error('Cannot write to %s', target_dir) + + +def main(argv): + if len(argv) > 1: + logging.error('None flag arguments found: [%s]', ', '.join(argv[1:])) + + codegen = _pywrap_codegen.AndroidJavaGenerator(FLAGS.destination) + model_buffer = get_model_buffer(FLAGS.model) + model_asset_path = FLAGS.model_asset_path + if not model_asset_path: + model_asset_path = os.path.basename(FLAGS.model) + result = codegen.generate(model_buffer, FLAGS.package_name, + FLAGS.model_class_name, model_asset_path) + error_message = codegen.get_error_message().strip() + if error_message: + logging.error(error_message) + if not result.files: + logging.error('Generation failed!') + return + + for each in result.files: + prepare_directory_for_file(each.path) + with open(each.path, 'w') as f: + f.write(each.content) + + logging.info('Generation succeeded!') + model_asset_path = os.path.join(FLAGS.destination, 'src/main/assets', + model_asset_path) + prepare_directory_for_file(model_asset_path) + shutil.copy(FLAGS.model, model_asset_path) + logging.info('Model copied into assets!') + + +if __name__ == '__main__': + flags.mark_flag_as_required('model') + flags.mark_flag_as_required('destination') + app.run(main) diff --git a/tensorflow/lite/experimental/support/codegen/python/codegen_lib.cc b/tensorflow/lite/experimental/support/codegen/python/codegen_lib.cc new file mode 100644 index 00000000000..e3db29b1959 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/python/codegen_lib.cc @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "include/pybind11/detail/common.h" +#include "include/pybind11/pybind11.h" +#include "include/pybind11/pytypes.h" +#include "include/pybind11/stl.h" +#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h" +#include "tensorflow/lite/experimental/support/codegen/code_generator.h" + +namespace tflite { +namespace support { +namespace codegen { + +template +using overload_cast_ = pybind11::detail::overload_cast_impl; + +PYBIND11_MODULE(_pywrap_codegen, m) { + pybind11::class_(m, "AndroidJavaGenerator") + .def(pybind11::init()) + .def("generate", + overload_cast_()( + &AndroidJavaGenerator::Generate)) + .def("get_error_message", &AndroidJavaGenerator::GetErrorMessage); + pybind11::class_(m, "GenerationResult") + .def(pybind11::init<>()) + .def_readwrite("files", &GenerationResult::files); + pybind11::class_(m, "GenerationResultFile") + .def(pybind11::init<>()) + .def_readwrite("path", &GenerationResult::File::path) + .def_readwrite("content", &GenerationResult::File::content); +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/codegen/utils.cc b/tensorflow/lite/experimental/support/codegen/utils.cc new file mode 100644 index 00000000000..394c147a33f --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/utils.cc @@ -0,0 +1,194 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/support/codegen/utils.h" + +#include + +namespace tflite { +namespace support { +namespace codegen { + +int ErrorReporter::Warning(const char* format, ...) { + va_list args; + va_start(args, format); + return Report("[WARN] ", format, args); +} + +int ErrorReporter::Error(const char* format, ...) { + va_list args; + va_start(args, format); + return Report("[ERROR] ", format, args); +} + +int ErrorReporter::Report(const char* prefix, const char* format, + va_list args) { + char buf[1024]; + int formatted = vsnprintf(buf, sizeof(buf), format, args); + buffer_ << prefix << buf << std::endl; + return formatted; +} + +std::string ErrorReporter::GetMessage() { + std::string value = buffer_.str(); + buffer_.str(""); + return value; +} + +CodeWriter::CodeWriter(ErrorReporter* err) : indent_(0), err_(err) {} + +void CodeWriter::SetTokenValue(const std::string& token, + const std::string& value) { + value_map_[token] = value; +} + +const std::string CodeWriter::GetTokenValue(const std::string& token) const { + auto iter = value_map_.find(token); + if (iter == value_map_.end()) { + // Typically only Code Generator's call this function (or `Append`). It's + // their duty to make sure the token is valid, and requesting for an invalid + // token implicits flaws in the code generation logic. + err_->Error("Internal: Cannot find value with token '%s'", token.c_str()); + return ""; + } + return iter->second; +} + +void CodeWriter::SetIndentString(const std::string& indent_str) { + indent_str_ = indent_str; +} + +void CodeWriter::Indent() { indent_++; } + +void CodeWriter::Outdent() { indent_--; } + +std::string CodeWriter::GenerateIndent() const { + std::string res; + res.reserve(indent_str_.size() * indent_); + for (int i = 0; i < indent_; i++) { + res.append(indent_str_); + } + return res; +} + +void CodeWriter::Append(const std::string& text) { AppendInternal(text, true); } + +void CodeWriter::AppendNoNewLine(const std::string& text) { + AppendInternal(text, false); +} + +void CodeWriter::AppendInternal(const std::string& text, bool newline) { + // Prefix indent + if ((buffer_.empty() // nothing in the buffer + || buffer_.back() == '\n') // is on new line + // is writing on current line + && (!text.empty() && text[0] != '\n' && text[0] != '\r')) { + buffer_.append(GenerateIndent()); + } + // State machine variables + bool in_token = false; + int i = 0; + // Rough memory reserve + buffer_.reserve(buffer_.size() + text.size()); + std::string token_buffer; + // A simple LL1 analysis + while (i < text.size()) { + char cur = text[i]; + char cur_next = i == text.size() - 1 ? '\0' : text[i + 1]; // Set guardian + if (in_token == false) { + if (cur == '{' && cur_next == '{') { // Enter token + in_token = true; + i += 2; + } else if (cur == '\n') { // We need to apply global indent here + buffer_.push_back(cur); + if (cur_next != '\0' && cur_next != '\n' && cur_next != '\r') { + buffer_.append(GenerateIndent()); + } + i += 1; + } else { + buffer_.push_back(cur); + i += 1; + } + } else { + if (cur == '}' && cur_next == '}') { // Close token + in_token = false; + const auto value = GetTokenValue(token_buffer); + buffer_.append(value); + token_buffer.clear(); + i += 2; + } else { + token_buffer.push_back(cur); + i += 1; + } + } + } + if (!token_buffer.empty()) { + // Typically only Code Generator's call this function. It's + // their duty to make sure the code (or template) has valid syntax, and + // unclosed "{{...}}" implicits severe error in the template. + err_->Error("Internal: Invalid template: {{token}} is not closed."); + } + if (newline) { + buffer_.push_back('\n'); + } +} + +void CodeWriter::NewLine() { Append(""); } + +void CodeWriter::Backspace(int n) { + buffer_.resize(buffer_.size() > n ? buffer_.size() - n : 0); +} + +std::string CodeWriter::ToString() const { return buffer_; } + +bool CodeWriter::IsStreamEmpty() const { return buffer_.empty(); } + +void CodeWriter::Clear() { + buffer_.clear(); + value_map_.clear(); + indent_ = 0; +} + +std::string SnakeCaseToCamelCase(const std::string& s) { + std::string t; + t.reserve(s.length()); + size_t i = 0; + // Note: Use simple string += for simplicity. + bool cap = false; + while (i < s.size()) { + const char c = s[i++]; + if (c == '_') { + cap = true; + } else if (cap) { + t += toupper(c); + cap = false; + } else { + t += c; + } + } + return t; +} + +std::string JoinPath(const std::string& a, const std::string& b) { + if (a.empty()) return b; + std::string a_fixed = a; + if (!a_fixed.empty() && a_fixed.back() == '/') a_fixed.pop_back(); + std::string b_fixed = b; + if (!b_fixed.empty() && b_fixed.front() == '/') b_fixed.erase(0, 1); + return a_fixed + "/" + b_fixed; +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/codegen/utils.h b/tensorflow/lite/experimental/support/codegen/utils.h new file mode 100644 index 00000000000..17153bd6ad0 --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/utils.h @@ -0,0 +1,127 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_ + +#include +#include +#include + +namespace tflite { +namespace support { +namespace codegen { + +/// Collects runtime error logs which could be showed later. +// TODO(b/150538286): Consider a better mechanism to simplify callsite code. +class ErrorReporter { + public: + int Warning(const char* format, ...); + int Error(const char* format, ...); + std::string GetMessage(); + + private: + int Report(const char* prefix, const char* format, va_list args); + std::stringstream buffer_; +}; + +/// Implements basic code generating with text templates. +/// +/// It could accept code templates and concatenate them into complete codes. A +/// template could contain named values. +/// +/// Example code: +/// CodeWriter code; +/// code.SetValue("NAME", "Foo"); +/// code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }"); +/// code.SetValue("NAME", "Bar"); +/// code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }"); +/// +/// Output: +/// void Foo() { printf("%s", "Foo"); } +/// void Bar() { printf("%s", "Bar"); } +class CodeWriter { + public: + explicit CodeWriter(ErrorReporter* err); + /// Sets value to a token. When generating code with template, a string in a + /// pair of {{ and }} will be regarded as a token and replaced with the + /// corresponding value in code generation. + /// It rewrites if the token already has a value. + void SetTokenValue(const std::string& token, const std::string& value); + + /// Gets the current value set on the given token. + const std::string GetTokenValue(const std::string& token) const; + + /// Sets the unit indent string. For example, in Java it should be " ". + void SetIndentString(const std::string& indent); + + /// Increases the indent by a unit (the string set in SetIndentString). + void Indent(); + + /// Decreases the indent by a unit (the string set in SetIndentString). + void Outdent(); + + /// Generates the indentation string. + std::string GenerateIndent() const; + + /// Appends a piece of template codes to the stream. Every named value will be + /// replaced via the real value. A new line will always be appended at the + /// end. + void Append(const std::string& text); + + /// Appends a piece of template codes to the stream. Same with `Append`, but a + /// new line will not be appended at the end. + void AppendNoNewLine(const std::string& text); + + /// Appends a new line to the stream. + void NewLine(); + + /// Deletes the last N charaters in the stream. If the stream has less than N + /// characters, deletes all. + void Backspace(int n); + + std::string ToString() const; + + /// Checks if the internal string stream is empty. Note: This method has + // overhead. + bool IsStreamEmpty() const; + + /// Clears all the internal string stream and value map. + void Clear(); + + private: + void AppendInternal(const std::string& text, bool newline); + + std::string indent_str_; + int indent_; + + std::map value_map_; + std::string buffer_; + + ErrorReporter* err_; +}; + +/// Converts foo_bar_name to fooBarName. It's callers duty to make sure given +/// string "s" is already in snake case; or unexpected behavior may occur. +std::string SnakeCaseToCamelCase(const std::string& s); + +/// Joins 2 parts of file path into one, connected by unix path seperator '/'. +/// It's callers duty to ensure the two parts are valid. +std::string JoinPath(const std::string& a, const std::string& b); + +} // namespace codegen +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_CODEGEN_UTILS_H_ diff --git a/tensorflow/lite/experimental/support/codegen/utils_test.cc b/tensorflow/lite/experimental/support/codegen/utils_test.cc new file mode 100644 index 00000000000..8cdb838129c --- /dev/null +++ b/tensorflow/lite/experimental/support/codegen/utils_test.cc @@ -0,0 +1,97 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/support/codegen/utils.h" + +#include + +namespace tflite { +namespace support { +namespace codegen { +namespace { + +TEST(ErrorReporterTest, TestReportError) { + ErrorReporter err; + err.Error("some text"); + EXPECT_EQ(err.GetMessage(), "[ERROR] some text\n"); + EXPECT_EQ(err.GetMessage(), ""); +} + +TEST(CodeGeneratorTest, TestExample) { + ErrorReporter err; + CodeWriter writer(&err); + writer.SetTokenValue("NAME", "Foo"); + const std::string text = R"(void {{NAME}}() { printf("%s", "{{NAME}}"); })"; + writer.Append(text); + writer.SetTokenValue("NAME", "Bar"); + writer.Append(text); + EXPECT_EQ( + "void Foo() { printf(\"%s\", \"Foo\"); }\n" + "void Bar() { printf(\"%s\", \"Bar\"); }\n", + writer.ToString()); +} + +TEST(CodeGeneratorTest, TestInexistentToken) { + ErrorReporter err; + CodeWriter writer(&err); + writer.SetTokenValue("NAME", "Foo"); + const std::string text = R"(void {{name}}() {})"; + writer.Append(text); + EXPECT_EQ(err.GetMessage(), + "[ERROR] Internal: Cannot find value with token 'name'\n"); +} + +TEST(CodeGeneratorTest, TestUnclosedToken) { + ErrorReporter err; + CodeWriter writer(&err); + writer.SetTokenValue("NAME", "Foo"); + const std::string text = R"(void {{NAME}() {})"; + writer.Append(text); + EXPECT_EQ(err.GetMessage(), + "[ERROR] Internal: Invalid template: {{token}} is not closed.\n"); +} + +TEST(CodeGeneratorTest, TestIndentControl) { + ErrorReporter err; + CodeWriter writer(&err); + writer.SetIndentString(" "); + writer.Indent(); + writer.AppendNoNewLine("abcde"); // Will indent + EXPECT_EQ(" abcde", writer.ToString()); + writer.Clear(); + writer.Indent(); + writer.AppendNoNewLine("abc\n\nde"); + // The blank line will not indent + EXPECT_EQ(" abc\n\n de", writer.ToString()); + writer.Clear(); + writer.Indent(); + writer.Append("abc"); + writer.Outdent(); + writer.AppendNoNewLine("def"); + EXPECT_EQ(" abc\ndef", writer.ToString()); +} + +TEST(CaseConversionTest, TestSnakeToCamel) { + EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel")); + EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel_")); + EXPECT_EQ("ImACamel", SnakeCaseToCamelCase("_im_a_camel")); + EXPECT_EQ("", SnakeCaseToCamelCase("_")); + EXPECT_EQ("camel", SnakeCaseToCamelCase("camel")); +} + +} // namespace +} // namespace codegen +} // namespace support +} // namespace tflite