[tfls.codegen] Update Image Classifier generation to new API.

PiperOrigin-RevId: 316031289
Change-Id: I918dc5563f80e3bd163dec77b72374242484e936
This commit is contained in:
Xunkai Zhang 2020-06-11 20:15:49 -07:00 committed by TensorFlower Gardener
parent 938d22a218
commit 6642441bee
3 changed files with 240 additions and 201 deletions

View File

@ -13,6 +13,29 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file contains the logic of android model wrapper generation.
//
// At the beginning is the helper functions handling metadata and code writer.
//
// Codes are generated in every `Generate{FOO}` functions. Gradle and Manifest
// files are simple. The wrapper file generation is a bit complex so we divided
// it into several sub-functions.
//
// The structure of the wrapper file looks like:
//
// [ imports ]
// [ class ]
// [ inner "Outputs" class ]
// [ innner "Metadata" class ]
// [ APIs ] ( including ctors, public APIs and private APIs )
//
// We tried to mostly write it in a "template-generation" way. `CodeWriter` does
// the job as a template renderer. To avoid repeatedly setting the token values,
// helper functions `SetCodeWriterWith{Foo}Info` set the token values with info
// structures (`TensorInfo` and `ModelInfo`) - the Info structures are
// intermediate datastructures between Metadata (represented in Flatbuffers) and
// generated code.
#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h"
#include <ctype.h>
@ -158,15 +181,32 @@ ModelInfo CreateModelInfo(const ModelMetadata* metadata,
graph->input_tensor_metadata(), graph->output_tensor_metadata());
std::vector<std::string> input_tensor_names = std::move(names.first);
std::vector<std::string> output_tensor_names = std::move(names.second);
for (int i = 0; i < input_tensor_names.size(); i++) {
model_info.inputs.push_back(
CreateTensorInfo(graph->input_tensor_metadata()->Get(i),
input_tensor_names[i], true, i, err));
if (i < input_tensor_names.size() - 1) {
model_info.inputs_list += ", ";
model_info.input_type_param_list += ", ";
}
model_info.inputs_list += model_info.inputs[i].name;
model_info.input_type_param_list +=
model_info.inputs[i].wrapper_type + " " + model_info.inputs[i].name;
}
for (int i = 0; i < output_tensor_names.size(); i++) {
model_info.outputs.push_back(
CreateTensorInfo(graph->output_tensor_metadata()->Get(i),
output_tensor_names[i], false, i, err));
if (i < output_tensor_names.size() - 1) {
model_info.postprocessor_type_param_list += ", ";
model_info.postprocessors_list += ", ";
}
model_info.postprocessors_list +=
model_info.outputs[i].name + "Postprocessor";
model_info.postprocessor_type_param_list +=
model_info.outputs[i].processor_type + " " +
model_info.outputs[i].name + "Postprocessor";
}
return model_info;
}
@ -196,6 +236,14 @@ void SetCodeWriterWithModelInfo(CodeWriter* code_writer,
code_writer->SetTokenValue("PACKAGE", model_info.package_name);
code_writer->SetTokenValue("MODEL_PATH", model_info.model_asset_path);
code_writer->SetTokenValue("MODEL_CLASS_NAME", model_info.model_class_name);
// Extra info, half generated.
code_writer->SetTokenValue("INPUT_TYPE_PARAM_LIST",
model_info.input_type_param_list);
code_writer->SetTokenValue("INPUTS_LIST", model_info.inputs_list);
code_writer->SetTokenValue("POSTPROCESSORS_LIST",
model_info.postprocessors_list);
code_writer->SetTokenValue("POSTPROCESSOR_TYPE_PARAM_LIST",
model_info.postprocessor_type_param_list);
}
constexpr char JAVA_DEFAULT_PACKAGE[] = "default";
@ -223,6 +271,8 @@ bool IsImageUsed(const ModelInfo& model) {
return false;
}
// The following functions generates the wrapper Java code for a model.
bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append("// Generated by TFLite Support.");
@ -253,8 +303,8 @@ bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
"java.util.HashMap",
"java.util.List",
"java.util.Map",
"org.checkerframework.checker.nullness.qual.Nullable",
"org.tensorflow.lite.DataType",
"org.tensorflow.lite.Tensor",
"org.tensorflow.lite.Tensor.QuantizationParams",
support_pkg + "common.FileUtil",
support_pkg + "common.TensorProcessor",
@ -262,11 +312,11 @@ bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
support_pkg + "common.ops.DequantizeOp",
support_pkg + "common.ops.NormalizeOp",
support_pkg + "common.ops.QuantizeOp",
support_pkg + "label.Category",
support_pkg + "label.TensorLabel",
support_pkg + "metadata.MetadataExtractor",
support_pkg + "metadata.schema.NormalizationOptions",
support_pkg + "model.Model",
support_pkg + "model.Model.Device",
support_pkg + "tensorbuffer.TensorBuffer",
};
if (IsImageUsed(model)) {
@ -275,7 +325,6 @@ bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
"image.ops.ResizeOp.ResizeMethod"}) {
imports.push_back(support_pkg + target);
}
imports.push_back("android.graphics.Bitmap");
}
std::sort(imports.begin(), imports.end());
@ -298,26 +347,13 @@ bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model,
code_writer->Append(R"(private final Metadata metadata;
private final Model model;
private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
for (const auto& tensor : model.outputs) {
if (tensor.associated_axis_label_index >= 0) {
code_writer->SetTokenValue("NAME", tensor.name);
code_writer->Append("private final List<String> {{NAME}}Labels;");
}
}
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(
"@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;");
code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;");
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(
"@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
}
code_writer->NewLine();
if (!GenerateWrapperInputs(code_writer, model, err)) {
err->Error("Failed to generate input classes");
return false;
code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
}
code_writer->NewLine();
if (!GenerateWrapperOutputs(code_writer, model, err)) {
@ -337,92 +373,46 @@ private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
return true;
}
bool GenerateWrapperInputs(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append("/** Input wrapper of {@link {{MODEL_CLASS_NAME}}} */");
auto class_block = AsBlock(code_writer, "public class Inputs");
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append("private {{WRAPPER_TYPE}} {{NAME}};");
}
code_writer->NewLine();
// Ctor
{
auto ctor_block = AsBlock(code_writer, "public Inputs()");
code_writer->Append(
"Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;");
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
if (tensor.content_type == "image") {
code_writer->Append(
"{{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());");
} else {
code_writer->Append(
"{{NAME}} = "
"TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
"metadata.get{{NAME_U}}Type());");
}
}
}
for (const auto& tensor : model.inputs) {
code_writer->NewLine();
SetCodeWriterWithTensorInfo(code_writer, tensor);
// Loaders
if (tensor.content_type == "image") {
{
auto bitmap_loader_block =
AsBlock(code_writer, "public void load{{NAME_U}}(Bitmap bitmap)");
code_writer->Append(R"({{NAME}}.load(bitmap);
{{NAME}} = preprocess{{NAME_U}}({{NAME}});)");
}
code_writer->NewLine();
{
auto tensor_image_loader_block = AsBlock(
code_writer, "public void load{{NAME_U}}(TensorImage tensorImage)");
code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorImage);");
}
} else { // content_type == "FEATURE" or "UNKNOWN"
auto tensorbuffer_loader_block = AsBlock(
code_writer, "public void load{{NAME_U}}(TensorBuffer tensorBuffer)");
code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorBuffer);");
}
code_writer->NewLine();
// Processor
code_writer->Append(
R"(private {{WRAPPER_TYPE}} preprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}}) {
if ({{NAME}}Preprocessor == null) {
return {{WRAPPER_NAME}};
}
return {{NAME}}Preprocessor.process({{WRAPPER_NAME}});
}
)");
}
{
const auto get_buffer_block = AsBlock(code_writer, "Object[] getBuffer()");
code_writer->AppendNoNewLine("return new Object[] {");
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->AppendNoNewLine("{{NAME}}.getBuffer(), ");
}
code_writer->Backspace(2);
code_writer->Append("};");
}
return true;
}
bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
ErrorReporter* err) {
code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
auto class_block = AsBlock(code_writer, "public class Outputs");
auto class_block = AsBlock(code_writer, "public static class Outputs");
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};");
if (tensor.associated_axis_label_index >= 0) {
code_writer->Append("private final List<String> {{NAME}}Labels;");
}
code_writer->Append(
"private final {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
}
// Getters
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->NewLine();
if (tensor.associated_axis_label_index >= 0) {
if (tensor.content_type == "tensor") {
code_writer->Append(
R"(public List<Category> get{{NAME_U}}AsCategoryList() {
return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getCategoryList();
})");
} else { // image
err->Warning(
"Axis label for images is not supported. The labels will "
"be ignored.");
}
} else { // no label
code_writer->Append(
R"(public {{WRAPPER_TYPE}} get{{NAME_U}}As{{WRAPPER_TYPE}}() {
return postprocess{{NAME_U}}({{NAME}});
})");
}
}
code_writer->NewLine();
{
const auto ctor_block = AsBlock(code_writer, "public Outputs()");
code_writer->Append(
"Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;");
const auto ctor_block = AsBlock(
code_writer,
"Outputs(Metadata metadata, {{POSTPROCESSOR_TYPE_PARAM_LIST}})");
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
if (tensor.content_type == "image") {
@ -435,36 +425,11 @@ bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
"TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
"metadata.get{{NAME_U}}Type());");
}
}
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->NewLine();
if (tensor.associated_axis_label_index >= 0) {
if (tensor.content_type == "image") {
err->Warning(
"Axis label for images is not supported. The labels will "
"be ignored.");
} else {
code_writer->Append(R"(public Map<String, Float> get{{NAME_U}}() {
return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getMapWithFloatValue();
})");
if (tensor.associated_axis_label_index >= 0) {
code_writer->Append("{{NAME}}Labels = metadata.get{{NAME_U}}Labels();");
}
} else {
code_writer->Append(R"(public {{WRAPPER_TYPE}} get{{NAME_U}}() {
return postprocess{{NAME_U}}({{NAME}});
})");
}
code_writer->NewLine();
{
auto processor_block =
AsBlock(code_writer,
"private {{WRAPPER_TYPE}} "
"postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})");
code_writer->Append(R"(if ({{NAME}}Postprocessor == null) {
return {{WRAPPER_NAME}};
}
return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});)");
code_writer->Append(
"this.{{NAME}}Postprocessor = {{NAME}}Postprocessor;");
}
}
code_writer->NewLine();
@ -479,6 +444,18 @@ return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});)");
}
code_writer->Append("return outputs;");
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->NewLine();
{
auto processor_block =
AsBlock(code_writer,
"private {{WRAPPER_TYPE}} "
"postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})");
code_writer->Append(
"return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});");
}
}
return true;
}
@ -522,9 +499,10 @@ private final float[] {{NAME}}Stddev;)");
SetCodeWriterWithTensorInfo(code_writer, model.inputs[i]);
code_writer->SetTokenValue("ID", std::to_string(i));
code_writer->Append(
R"({{NAME}}Shape = extractor.getInputTensorShape({{ID}});
{{NAME}}DataType = extractor.getInputTensorType({{ID}});
{{NAME}}QuantizationParams = extractor.getInputTensorQuantizationParams({{ID}});)");
R"(Tensor {{NAME}}Tensor = model.getInputTensor({{ID}});
{{NAME}}Shape = {{NAME}}Tensor.shape();
{{NAME}}DataType = {{NAME}}Tensor.dataType();
{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
if (model.inputs[i].normalization_unit >= 0) {
code_writer->Append(
R"(NormalizationOptions {{NAME}}NormalizationOptions =
@ -541,9 +519,10 @@ FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer(
SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
code_writer->SetTokenValue("ID", std::to_string(i));
code_writer->Append(
R"({{NAME}}Shape = model.getOutputTensorShape({{ID}});
{{NAME}}DataType = extractor.getOutputTensorType({{ID}});
{{NAME}}QuantizationParams = extractor.getOutputTensorQuantizationParams({{ID}});)");
R"(Tensor {{NAME}}Tensor = model.getOutputTensor({{ID}});
{{NAME}}Shape = {{NAME}}Tensor.shape();
{{NAME}}DataType = {{NAME}}Tensor.dataType();
{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
if (model.outputs[i].normalization_unit >= 0) {
code_writer->Append(
R"(NormalizationOptions {{NAME}}NormalizationOptions =
@ -637,8 +616,8 @@ bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model,
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public {{MODEL_CLASS_NAME}}(Context context) throws IOException {
this(context, MODEL_NAME, Device.CPU, 1);
public static {{MODEL_CLASS_NAME}} newInstance(Context context) throws IOException {
return newInstance(context, MODEL_NAME, new Model.Options.Builder().build());
}
/**
@ -647,18 +626,17 @@ public {{MODEL_CLASS_NAME}}(Context context) throws IOException {
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public {{MODEL_CLASS_NAME}}(Context context, String modelPath) throws IOException {
this(context, modelPath, Device.CPU, 1);
public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath) throws IOException {
return newInstance(context, modelPath, new Model.Options.Builder().build());
}
/**
* Creates interpreter and loads associated files if needed, with device and number of threads
* configured.
* Creates interpreter and loads associated files if needed, with running options configured.
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public {{MODEL_CLASS_NAME}}(Context context, Device device, int numThreads) throws IOException {
this(context, MODEL_NAME, device, numThreads);
public static {{MODEL_CLASS_NAME}} newInstance(Context context, Model.Options runningOptions) throws IOException {
return newInstance(context, MODEL_NAME, runningOptions);
}
/**
@ -666,80 +644,124 @@ public {{MODEL_CLASS_NAME}}(Context context, Device device, int numThreads) thro
*
* @throws IOException if an I/O error occurs when loading the tflite model.
*/
public {{MODEL_CLASS_NAME}}(Context context, String modelPath, Device device, int numThreads) throws IOException {
model = new Model.Builder(context, modelPath).setDevice(device).setNumThreads(numThreads).build();
metadata = new Metadata(model.getData(), model);)");
public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath, Model.Options runningOptions) throws IOException {
Model model = Model.createModel(context, modelPath, runningOptions);
Metadata metadata = new Metadata(model.getData(), model);
MyImageClassifier instance = new MyImageClassifier(model, metadata);)");
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(
{{PROCESSOR_TYPE}}.Builder {{NAME}}PreprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder())");
if (tensor.content_type == "image") {
code_writer->Append(R"( .add(new ResizeOp(
metadata.get{{NAME_U}}Shape()[1],
metadata.get{{NAME_U}}Shape()[2],
ResizeMethod.NEAREST_NEIGHBOR)))");
}
if (tensor.normalization_unit >= 0) {
code_writer->Append(
R"( .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
}
code_writer->Append(
R"( .add(new QuantizeOp(
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
metadata.get{{NAME_U}}QuantizationParams().getScale()))
.add(new CastOp(metadata.get{{NAME_U}}Type()));
{{NAME}}Preprocessor = {{NAME}}PreprocessorBuilder.build();)");
"instance.reset{{NAME_U}}Preprocessor(instance.buildDefault{{NAME_U}}"
"Preprocessor());");
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->AppendNoNewLine(R"(
{{PROCESSOR_TYPE}}.Builder {{NAME}}PostprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder()
.add(new DequantizeOp(
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
metadata.get{{NAME_U}}QuantizationParams().getScale())))");
if (tensor.normalization_unit >= 0) {
code_writer->AppendNoNewLine(R"(
.add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
}
code_writer->Append(R"(;
{{NAME}}Postprocessor = {{NAME}}PostprocessorBuilder.build();)");
if (tensor.associated_axis_label_index >= 0) {
code_writer->Append(R"(
{{NAME}}Labels = metadata.get{{NAME_U}}Labels();)");
}
code_writer->Append(
"instance.reset{{NAME_U}}Postprocessor(instance.buildDefault{{NAME_U}}"
"Postprocessor());");
}
code_writer->Append("}");
code_writer->Append(R"( return instance;
}
)");
// Pre, post processor setters
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(
public void reset{{NAME_U}}Preprocessor(@Nullable {{PROCESSOR_TYPE}} processor) {
public void reset{{NAME_U}}Preprocessor({{PROCESSOR_TYPE}} processor) {
{{NAME}}Preprocessor = processor;
})");
}
for (const auto& tensor : model.outputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append(R"(
public void reset{{NAME_U}}Postprocessor(@Nullable {{PROCESSOR_TYPE}} processor) {
public void reset{{NAME_U}}Postprocessor({{PROCESSOR_TYPE}} processor) {
{{NAME}}Postprocessor = processor;
})");
}
// Process method
code_writer->Append(R"(
/** Creates inputs */
public Inputs createInputs() {
return new Inputs();
}
/** Triggers the model. */
public Outputs run(Inputs inputs) {
Outputs outputs = new Outputs();
model.run(inputs.getBuffer(), outputs.getBuffer());
public Outputs process({{INPUT_TYPE_PARAM_LIST}}) {
Outputs outputs = new Outputs(metadata, {{POSTPROCESSORS_LIST}});
Object[] inputBuffers = preprocessInputs({{INPUTS_LIST}});
model.run(inputBuffers, outputs.getBuffer());
return outputs;
}
/** Closes the model. */
public void close() {
model.close();
})");
}
)");
{
auto block =
AsBlock(code_writer,
"private {{MODEL_CLASS_NAME}}(Model model, Metadata metadata)");
code_writer->Append(R"(this.model = model;
this.metadata = metadata;)");
}
for (const auto& tensor : model.inputs) {
code_writer->NewLine();
SetCodeWriterWithTensorInfo(code_writer, tensor);
auto block = AsBlock(
code_writer,
"private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Preprocessor()");
code_writer->Append(
"{{PROCESSOR_TYPE}}.Builder builder = new "
"{{PROCESSOR_TYPE}}.Builder()");
if (tensor.content_type == "image") {
code_writer->Append(R"( .add(new ResizeOp(
metadata.get{{NAME_U}}Shape()[1],
metadata.get{{NAME_U}}Shape()[2],
ResizeMethod.NEAREST_NEIGHBOR)))");
}
if (tensor.normalization_unit >= 0) {
code_writer->Append(
R"( .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
}
code_writer->Append(
R"( .add(new QuantizeOp(
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
metadata.get{{NAME_U}}QuantizationParams().getScale()))
.add(new CastOp(metadata.get{{NAME_U}}Type()));
return builder.build();)");
}
for (const auto& tensor : model.outputs) {
code_writer->NewLine();
SetCodeWriterWithTensorInfo(code_writer, tensor);
auto block = AsBlock(
code_writer,
"private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Postprocessor()");
code_writer->AppendNoNewLine(
R"({{PROCESSOR_TYPE}}.Builder builder = new {{PROCESSOR_TYPE}}.Builder()
.add(new DequantizeOp(
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
metadata.get{{NAME_U}}QuantizationParams().getScale())))");
if (tensor.normalization_unit >= 0) {
code_writer->AppendNoNewLine(R"(
.add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
}
code_writer->Append(R"(;
return builder.build();)");
}
code_writer->NewLine();
{
const auto block =
AsBlock(code_writer,
"private Object[] preprocessInputs({{INPUT_TYPE_PARAM_LIST}})");
CodeWriter param_list_gen(err);
for (const auto& tensor : model.inputs) {
SetCodeWriterWithTensorInfo(code_writer, tensor);
code_writer->Append("{{NAME}} = {{NAME}}Preprocessor.process({{NAME}});");
SetCodeWriterWithTensorInfo(&param_list_gen, tensor);
param_list_gen.AppendNoNewLine("{{NAME}}.getBuffer(), ");
}
param_list_gen.Backspace(2);
code_writer->AppendNoNewLine("return new Object[] {");
code_writer->AppendNoNewLine(param_list_gen.ToString());
code_writer->Append("};");
}
return true;
}
@ -823,6 +845,7 @@ bool GenerateAndroidManifestContent(CodeWriter* code_writer,
bool GenerateDocContent(CodeWriter* code_writer, const ModelInfo& model_info) {
code_writer->Append("# {{MODEL_CLASS_NAME}} Usage");
// TODO(b/158651848) Generate imports for TFLS util types like TensorImage.
code_writer->AppendNoNewLine(R"(
```
import {{PACKAGE}}.{{MODEL_CLASS_NAME}};
@ -831,9 +854,7 @@ import {{PACKAGE}}.{{MODEL_CLASS_NAME}};
{{MODEL_CLASS_NAME}} model = null;
try {
model = new {{MODEL_CLASS_NAME}}(context); // android.content.Context
// Create the input container.
{{MODEL_CLASS_NAME}}.Inputs inputs = model.createInputs();
model = {{MODEL_CLASS_NAME}}.newInstance(context); // android.content.Context
} catch (IOException e) {
e.printStackTrace();
}
@ -845,33 +866,42 @@ if (model != null) {
SetCodeWriterWithTensorInfo(code_writer, t);
if (t.content_type == "image") {
code_writer->Append(R"(
// Load input tensor "{{NAME}}" from a Bitmap with ARGB_8888 format.
// Prepare tensor "{{NAME}}" from a Bitmap with ARGB_8888 format.
Bitmap bitmap = ...;
inputs.load{{NAME_U}}(bitmap);
// Alternatively, load the input tensor "{{NAME}}" from a TensorImage.
TensorImage {{MAME}} = TensorImage.fromBitmap(bitmap);
// Alternatively, load the input tensor "{{NAME}}" from pixel values.
// Check out TensorImage documentation to load other image data structures.
// TensorImage tensorImage = ...;
// inputs.load{{NAME_U}}(tensorImage);)");
// int[] pixelValues = ...;
// int[] shape = ...;
// TensorImage {{NAME}} = new TensorImage();
// {{NAME}}.load(pixelValues, shape);)");
} else {
code_writer->Append(R"(
// Load input tensor "{{NAME}}" from a TensorBuffer.
// Prepare input tensor "{{NAME}}" from an array.
// Check out TensorBuffer documentation to load other data structures.
TensorBuffer tensorBuffer = ...;
inputs.load{{NAME_U}}(tensorBuffer);)");
TensorBuffer {{NAME}} = ...;
int[] values = ...;
int[] shape = ...;
{{NAME}}.load(values, shape);)");
}
}
code_writer->Append(R"(
// 3. Run the model
{{MODEL_CLASS_NAME}}.Outputs outputs = model.run(inputs);)");
{{MODEL_CLASS_NAME}}.Outputs outputs = model.process({{INPUTS_LIST}});)");
code_writer->Append(R"(
// 4. Retrieve the results)");
for (const auto& t : model_info.outputs) {
SetCodeWriterWithTensorInfo(code_writer, t);
if (t.associated_axis_label_index >= 0) {
code_writer->SetTokenValue("WRAPPER_TYPE", "Map<String, Float>");
code_writer->SetTokenValue("WRAPPER_TYPE", "List<Category>");
code_writer->Append(
" List<Category> {{NAME}} = "
"outputs.get{{NAME_U}}AsCategoryList();");
} else {
code_writer->Append(
" {{WRAPPER_TYPE}} {{NAME}} = "
"outputs.get{{NAME_U}}As{{WRAPPER_TYPE}}();");
}
code_writer->Append(
R"( {{WRAPPER_TYPE}} {{NAME}} = outputs.get{{NAME_U}}();)");
}
code_writer->Append(R"(}
```)");

View File

@ -57,6 +57,15 @@ struct ModelInfo {
std::string model_versioned_name;
std::vector<TensorInfo> inputs;
std::vector<TensorInfo> outputs;
// Extra helper fields. For models with inputs "a", "b" and outputs "x", "y":
std::string input_type_param_list;
// e.g. "TensorImage a, TensorBuffer b"
std::string inputs_list;
// e.g. "a, b"
std::string postprocessor_type_param_list;
// e.g. "ImageProcessor xPostprocessor, TensorProcessor yPostprocessor"
std::string postprocessors_list;
// e.g. "xPostprocessor, yPostprocessor"
};
} // namespace details_android_java

View File

@ -106,7 +106,7 @@ void CodeWriter::AppendInternal(const std::string& text, bool newline) {
while (i < text.size()) {
char cur = text[i];
char cur_next = i == text.size() - 1 ? '\0' : text[i + 1]; // Set guardian
if (in_token == false) {
if (!in_token) {
if (cur == '{' && cur_next == '{') { // Enter token
in_token = true;
i += 2;