[tfls.codegen] Update Image Classifier generation to new API.
PiperOrigin-RevId: 316031289 Change-Id: I918dc5563f80e3bd163dec77b72374242484e936
This commit is contained in:
parent
938d22a218
commit
6642441bee
@ -13,6 +13,29 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file contains the logic of android model wrapper generation.
|
||||
//
|
||||
// At the beginning is the helper functions handling metadata and code writer.
|
||||
//
|
||||
// Codes are generated in every `Generate{FOO}` functions. Gradle and Manifest
|
||||
// files are simple. The wrapper file generation is a bit complex so we divided
|
||||
// it into several sub-functions.
|
||||
//
|
||||
// The structure of the wrapper file looks like:
|
||||
//
|
||||
// [ imports ]
|
||||
// [ class ]
|
||||
// [ inner "Outputs" class ]
|
||||
// [ innner "Metadata" class ]
|
||||
// [ APIs ] ( including ctors, public APIs and private APIs )
|
||||
//
|
||||
// We tried to mostly write it in a "template-generation" way. `CodeWriter` does
|
||||
// the job as a template renderer. To avoid repeatedly setting the token values,
|
||||
// helper functions `SetCodeWriterWith{Foo}Info` set the token values with info
|
||||
// structures (`TensorInfo` and `ModelInfo`) - the Info structures are
|
||||
// intermediate datastructures between Metadata (represented in Flatbuffers) and
|
||||
// generated code.
|
||||
|
||||
#include "tensorflow/lite/experimental/support/codegen/android_java_generator.h"
|
||||
|
||||
#include <ctype.h>
|
||||
@ -158,15 +181,32 @@ ModelInfo CreateModelInfo(const ModelMetadata* metadata,
|
||||
graph->input_tensor_metadata(), graph->output_tensor_metadata());
|
||||
std::vector<std::string> input_tensor_names = std::move(names.first);
|
||||
std::vector<std::string> output_tensor_names = std::move(names.second);
|
||||
|
||||
for (int i = 0; i < input_tensor_names.size(); i++) {
|
||||
model_info.inputs.push_back(
|
||||
CreateTensorInfo(graph->input_tensor_metadata()->Get(i),
|
||||
input_tensor_names[i], true, i, err));
|
||||
if (i < input_tensor_names.size() - 1) {
|
||||
model_info.inputs_list += ", ";
|
||||
model_info.input_type_param_list += ", ";
|
||||
}
|
||||
model_info.inputs_list += model_info.inputs[i].name;
|
||||
model_info.input_type_param_list +=
|
||||
model_info.inputs[i].wrapper_type + " " + model_info.inputs[i].name;
|
||||
}
|
||||
for (int i = 0; i < output_tensor_names.size(); i++) {
|
||||
model_info.outputs.push_back(
|
||||
CreateTensorInfo(graph->output_tensor_metadata()->Get(i),
|
||||
output_tensor_names[i], false, i, err));
|
||||
if (i < output_tensor_names.size() - 1) {
|
||||
model_info.postprocessor_type_param_list += ", ";
|
||||
model_info.postprocessors_list += ", ";
|
||||
}
|
||||
model_info.postprocessors_list +=
|
||||
model_info.outputs[i].name + "Postprocessor";
|
||||
model_info.postprocessor_type_param_list +=
|
||||
model_info.outputs[i].processor_type + " " +
|
||||
model_info.outputs[i].name + "Postprocessor";
|
||||
}
|
||||
return model_info;
|
||||
}
|
||||
@ -196,6 +236,14 @@ void SetCodeWriterWithModelInfo(CodeWriter* code_writer,
|
||||
code_writer->SetTokenValue("PACKAGE", model_info.package_name);
|
||||
code_writer->SetTokenValue("MODEL_PATH", model_info.model_asset_path);
|
||||
code_writer->SetTokenValue("MODEL_CLASS_NAME", model_info.model_class_name);
|
||||
// Extra info, half generated.
|
||||
code_writer->SetTokenValue("INPUT_TYPE_PARAM_LIST",
|
||||
model_info.input_type_param_list);
|
||||
code_writer->SetTokenValue("INPUTS_LIST", model_info.inputs_list);
|
||||
code_writer->SetTokenValue("POSTPROCESSORS_LIST",
|
||||
model_info.postprocessors_list);
|
||||
code_writer->SetTokenValue("POSTPROCESSOR_TYPE_PARAM_LIST",
|
||||
model_info.postprocessor_type_param_list);
|
||||
}
|
||||
|
||||
constexpr char JAVA_DEFAULT_PACKAGE[] = "default";
|
||||
@ -223,6 +271,8 @@ bool IsImageUsed(const ModelInfo& model) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The following functions generates the wrapper Java code for a model.
|
||||
|
||||
bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
|
||||
ErrorReporter* err) {
|
||||
code_writer->Append("// Generated by TFLite Support.");
|
||||
@ -253,8 +303,8 @@ bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
|
||||
"java.util.HashMap",
|
||||
"java.util.List",
|
||||
"java.util.Map",
|
||||
"org.checkerframework.checker.nullness.qual.Nullable",
|
||||
"org.tensorflow.lite.DataType",
|
||||
"org.tensorflow.lite.Tensor",
|
||||
"org.tensorflow.lite.Tensor.QuantizationParams",
|
||||
support_pkg + "common.FileUtil",
|
||||
support_pkg + "common.TensorProcessor",
|
||||
@ -262,11 +312,11 @@ bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
|
||||
support_pkg + "common.ops.DequantizeOp",
|
||||
support_pkg + "common.ops.NormalizeOp",
|
||||
support_pkg + "common.ops.QuantizeOp",
|
||||
support_pkg + "label.Category",
|
||||
support_pkg + "label.TensorLabel",
|
||||
support_pkg + "metadata.MetadataExtractor",
|
||||
support_pkg + "metadata.schema.NormalizationOptions",
|
||||
support_pkg + "model.Model",
|
||||
support_pkg + "model.Model.Device",
|
||||
support_pkg + "tensorbuffer.TensorBuffer",
|
||||
};
|
||||
if (IsImageUsed(model)) {
|
||||
@ -275,7 +325,6 @@ bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
|
||||
"image.ops.ResizeOp.ResizeMethod"}) {
|
||||
imports.push_back(support_pkg + target);
|
||||
}
|
||||
imports.push_back("android.graphics.Bitmap");
|
||||
}
|
||||
|
||||
std::sort(imports.begin(), imports.end());
|
||||
@ -298,26 +347,13 @@ bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model,
|
||||
code_writer->Append(R"(private final Metadata metadata;
|
||||
private final Model model;
|
||||
private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
|
||||
for (const auto& tensor : model.outputs) {
|
||||
if (tensor.associated_axis_label_index >= 0) {
|
||||
code_writer->SetTokenValue("NAME", tensor.name);
|
||||
code_writer->Append("private final List<String> {{NAME}}Labels;");
|
||||
}
|
||||
}
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(
|
||||
"@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;");
|
||||
code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;");
|
||||
}
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(
|
||||
"@Nullable private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
if (!GenerateWrapperInputs(code_writer, model, err)) {
|
||||
err->Error("Failed to generate input classes");
|
||||
return false;
|
||||
code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
if (!GenerateWrapperOutputs(code_writer, model, err)) {
|
||||
@ -337,92 +373,46 @@ private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GenerateWrapperInputs(CodeWriter* code_writer, const ModelInfo& model,
|
||||
ErrorReporter* err) {
|
||||
code_writer->Append("/** Input wrapper of {@link {{MODEL_CLASS_NAME}}} */");
|
||||
auto class_block = AsBlock(code_writer, "public class Inputs");
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append("private {{WRAPPER_TYPE}} {{NAME}};");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
// Ctor
|
||||
{
|
||||
auto ctor_block = AsBlock(code_writer, "public Inputs()");
|
||||
code_writer->Append(
|
||||
"Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;");
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
if (tensor.content_type == "image") {
|
||||
code_writer->Append(
|
||||
"{{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());");
|
||||
} else {
|
||||
code_writer->Append(
|
||||
"{{NAME}} = "
|
||||
"TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
|
||||
"metadata.get{{NAME_U}}Type());");
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const auto& tensor : model.inputs) {
|
||||
code_writer->NewLine();
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
// Loaders
|
||||
if (tensor.content_type == "image") {
|
||||
{
|
||||
auto bitmap_loader_block =
|
||||
AsBlock(code_writer, "public void load{{NAME_U}}(Bitmap bitmap)");
|
||||
code_writer->Append(R"({{NAME}}.load(bitmap);
|
||||
{{NAME}} = preprocess{{NAME_U}}({{NAME}});)");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
{
|
||||
auto tensor_image_loader_block = AsBlock(
|
||||
code_writer, "public void load{{NAME_U}}(TensorImage tensorImage)");
|
||||
code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorImage);");
|
||||
}
|
||||
} else { // content_type == "FEATURE" or "UNKNOWN"
|
||||
auto tensorbuffer_loader_block = AsBlock(
|
||||
code_writer, "public void load{{NAME_U}}(TensorBuffer tensorBuffer)");
|
||||
code_writer->Append("{{NAME}} = preprocess{{NAME_U}}(tensorBuffer);");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
// Processor
|
||||
code_writer->Append(
|
||||
R"(private {{WRAPPER_TYPE}} preprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}}) {
|
||||
if ({{NAME}}Preprocessor == null) {
|
||||
return {{WRAPPER_NAME}};
|
||||
}
|
||||
return {{NAME}}Preprocessor.process({{WRAPPER_NAME}});
|
||||
}
|
||||
)");
|
||||
}
|
||||
{
|
||||
const auto get_buffer_block = AsBlock(code_writer, "Object[] getBuffer()");
|
||||
code_writer->AppendNoNewLine("return new Object[] {");
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->AppendNoNewLine("{{NAME}}.getBuffer(), ");
|
||||
}
|
||||
code_writer->Backspace(2);
|
||||
code_writer->Append("};");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
|
||||
ErrorReporter* err) {
|
||||
code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
|
||||
auto class_block = AsBlock(code_writer, "public class Outputs");
|
||||
auto class_block = AsBlock(code_writer, "public static class Outputs");
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};");
|
||||
if (tensor.associated_axis_label_index >= 0) {
|
||||
code_writer->Append("private final List<String> {{NAME}}Labels;");
|
||||
}
|
||||
code_writer->Append(
|
||||
"private final {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
|
||||
}
|
||||
// Getters
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->NewLine();
|
||||
if (tensor.associated_axis_label_index >= 0) {
|
||||
if (tensor.content_type == "tensor") {
|
||||
code_writer->Append(
|
||||
R"(public List<Category> get{{NAME_U}}AsCategoryList() {
|
||||
return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getCategoryList();
|
||||
})");
|
||||
} else { // image
|
||||
err->Warning(
|
||||
"Axis label for images is not supported. The labels will "
|
||||
"be ignored.");
|
||||
}
|
||||
} else { // no label
|
||||
code_writer->Append(
|
||||
R"(public {{WRAPPER_TYPE}} get{{NAME_U}}As{{WRAPPER_TYPE}}() {
|
||||
return postprocess{{NAME_U}}({{NAME}});
|
||||
})");
|
||||
}
|
||||
}
|
||||
code_writer->NewLine();
|
||||
{
|
||||
const auto ctor_block = AsBlock(code_writer, "public Outputs()");
|
||||
code_writer->Append(
|
||||
"Metadata metadata = {{MODEL_CLASS_NAME}}.this.metadata;");
|
||||
const auto ctor_block = AsBlock(
|
||||
code_writer,
|
||||
"Outputs(Metadata metadata, {{POSTPROCESSOR_TYPE_PARAM_LIST}})");
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
if (tensor.content_type == "image") {
|
||||
@ -435,36 +425,11 @@ bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
|
||||
"TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
|
||||
"metadata.get{{NAME_U}}Type());");
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->NewLine();
|
||||
if (tensor.associated_axis_label_index >= 0) {
|
||||
if (tensor.content_type == "image") {
|
||||
err->Warning(
|
||||
"Axis label for images is not supported. The labels will "
|
||||
"be ignored.");
|
||||
} else {
|
||||
code_writer->Append(R"(public Map<String, Float> get{{NAME_U}}() {
|
||||
return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getMapWithFloatValue();
|
||||
})");
|
||||
if (tensor.associated_axis_label_index >= 0) {
|
||||
code_writer->Append("{{NAME}}Labels = metadata.get{{NAME_U}}Labels();");
|
||||
}
|
||||
} else {
|
||||
code_writer->Append(R"(public {{WRAPPER_TYPE}} get{{NAME_U}}() {
|
||||
return postprocess{{NAME_U}}({{NAME}});
|
||||
})");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
{
|
||||
auto processor_block =
|
||||
AsBlock(code_writer,
|
||||
"private {{WRAPPER_TYPE}} "
|
||||
"postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})");
|
||||
code_writer->Append(R"(if ({{NAME}}Postprocessor == null) {
|
||||
return {{WRAPPER_NAME}};
|
||||
}
|
||||
return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});)");
|
||||
code_writer->Append(
|
||||
"this.{{NAME}}Postprocessor = {{NAME}}Postprocessor;");
|
||||
}
|
||||
}
|
||||
code_writer->NewLine();
|
||||
@ -479,6 +444,18 @@ return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});)");
|
||||
}
|
||||
code_writer->Append("return outputs;");
|
||||
}
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->NewLine();
|
||||
{
|
||||
auto processor_block =
|
||||
AsBlock(code_writer,
|
||||
"private {{WRAPPER_TYPE}} "
|
||||
"postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})");
|
||||
code_writer->Append(
|
||||
"return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});");
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -522,9 +499,10 @@ private final float[] {{NAME}}Stddev;)");
|
||||
SetCodeWriterWithTensorInfo(code_writer, model.inputs[i]);
|
||||
code_writer->SetTokenValue("ID", std::to_string(i));
|
||||
code_writer->Append(
|
||||
R"({{NAME}}Shape = extractor.getInputTensorShape({{ID}});
|
||||
{{NAME}}DataType = extractor.getInputTensorType({{ID}});
|
||||
{{NAME}}QuantizationParams = extractor.getInputTensorQuantizationParams({{ID}});)");
|
||||
R"(Tensor {{NAME}}Tensor = model.getInputTensor({{ID}});
|
||||
{{NAME}}Shape = {{NAME}}Tensor.shape();
|
||||
{{NAME}}DataType = {{NAME}}Tensor.dataType();
|
||||
{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
|
||||
if (model.inputs[i].normalization_unit >= 0) {
|
||||
code_writer->Append(
|
||||
R"(NormalizationOptions {{NAME}}NormalizationOptions =
|
||||
@ -541,9 +519,10 @@ FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer(
|
||||
SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
|
||||
code_writer->SetTokenValue("ID", std::to_string(i));
|
||||
code_writer->Append(
|
||||
R"({{NAME}}Shape = model.getOutputTensorShape({{ID}});
|
||||
{{NAME}}DataType = extractor.getOutputTensorType({{ID}});
|
||||
{{NAME}}QuantizationParams = extractor.getOutputTensorQuantizationParams({{ID}});)");
|
||||
R"(Tensor {{NAME}}Tensor = model.getOutputTensor({{ID}});
|
||||
{{NAME}}Shape = {{NAME}}Tensor.shape();
|
||||
{{NAME}}DataType = {{NAME}}Tensor.dataType();
|
||||
{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
|
||||
if (model.outputs[i].normalization_unit >= 0) {
|
||||
code_writer->Append(
|
||||
R"(NormalizationOptions {{NAME}}NormalizationOptions =
|
||||
@ -637,8 +616,8 @@ bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model,
|
||||
*
|
||||
* @throws IOException if an I/O error occurs when loading the tflite model.
|
||||
*/
|
||||
public {{MODEL_CLASS_NAME}}(Context context) throws IOException {
|
||||
this(context, MODEL_NAME, Device.CPU, 1);
|
||||
public static {{MODEL_CLASS_NAME}} newInstance(Context context) throws IOException {
|
||||
return newInstance(context, MODEL_NAME, new Model.Options.Builder().build());
|
||||
}
|
||||
|
||||
/**
|
||||
@ -647,18 +626,17 @@ public {{MODEL_CLASS_NAME}}(Context context) throws IOException {
|
||||
*
|
||||
* @throws IOException if an I/O error occurs when loading the tflite model.
|
||||
*/
|
||||
public {{MODEL_CLASS_NAME}}(Context context, String modelPath) throws IOException {
|
||||
this(context, modelPath, Device.CPU, 1);
|
||||
public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath) throws IOException {
|
||||
return newInstance(context, modelPath, new Model.Options.Builder().build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates interpreter and loads associated files if needed, with device and number of threads
|
||||
* configured.
|
||||
* Creates interpreter and loads associated files if needed, with running options configured.
|
||||
*
|
||||
* @throws IOException if an I/O error occurs when loading the tflite model.
|
||||
*/
|
||||
public {{MODEL_CLASS_NAME}}(Context context, Device device, int numThreads) throws IOException {
|
||||
this(context, MODEL_NAME, device, numThreads);
|
||||
public static {{MODEL_CLASS_NAME}} newInstance(Context context, Model.Options runningOptions) throws IOException {
|
||||
return newInstance(context, MODEL_NAME, runningOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -666,80 +644,124 @@ public {{MODEL_CLASS_NAME}}(Context context, Device device, int numThreads) thro
|
||||
*
|
||||
* @throws IOException if an I/O error occurs when loading the tflite model.
|
||||
*/
|
||||
public {{MODEL_CLASS_NAME}}(Context context, String modelPath, Device device, int numThreads) throws IOException {
|
||||
model = new Model.Builder(context, modelPath).setDevice(device).setNumThreads(numThreads).build();
|
||||
metadata = new Metadata(model.getData(), model);)");
|
||||
public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath, Model.Options runningOptions) throws IOException {
|
||||
Model model = Model.createModel(context, modelPath, runningOptions);
|
||||
Metadata metadata = new Metadata(model.getData(), model);
|
||||
MyImageClassifier instance = new MyImageClassifier(model, metadata);)");
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(R"(
|
||||
{{PROCESSOR_TYPE}}.Builder {{NAME}}PreprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder())");
|
||||
if (tensor.content_type == "image") {
|
||||
code_writer->Append(R"( .add(new ResizeOp(
|
||||
metadata.get{{NAME_U}}Shape()[1],
|
||||
metadata.get{{NAME_U}}Shape()[2],
|
||||
ResizeMethod.NEAREST_NEIGHBOR)))");
|
||||
}
|
||||
if (tensor.normalization_unit >= 0) {
|
||||
code_writer->Append(
|
||||
R"( .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
|
||||
}
|
||||
code_writer->Append(
|
||||
R"( .add(new QuantizeOp(
|
||||
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
|
||||
metadata.get{{NAME_U}}QuantizationParams().getScale()))
|
||||
.add(new CastOp(metadata.get{{NAME_U}}Type()));
|
||||
{{NAME}}Preprocessor = {{NAME}}PreprocessorBuilder.build();)");
|
||||
"instance.reset{{NAME_U}}Preprocessor(instance.buildDefault{{NAME_U}}"
|
||||
"Preprocessor());");
|
||||
}
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->AppendNoNewLine(R"(
|
||||
{{PROCESSOR_TYPE}}.Builder {{NAME}}PostprocessorBuilder = new {{PROCESSOR_TYPE}}.Builder()
|
||||
.add(new DequantizeOp(
|
||||
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
|
||||
metadata.get{{NAME_U}}QuantizationParams().getScale())))");
|
||||
if (tensor.normalization_unit >= 0) {
|
||||
code_writer->AppendNoNewLine(R"(
|
||||
.add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
|
||||
}
|
||||
code_writer->Append(R"(;
|
||||
{{NAME}}Postprocessor = {{NAME}}PostprocessorBuilder.build();)");
|
||||
if (tensor.associated_axis_label_index >= 0) {
|
||||
code_writer->Append(R"(
|
||||
{{NAME}}Labels = metadata.get{{NAME_U}}Labels();)");
|
||||
}
|
||||
code_writer->Append(
|
||||
"instance.reset{{NAME_U}}Postprocessor(instance.buildDefault{{NAME_U}}"
|
||||
"Postprocessor());");
|
||||
}
|
||||
code_writer->Append("}");
|
||||
code_writer->Append(R"( return instance;
|
||||
}
|
||||
)");
|
||||
|
||||
// Pre, post processor setters
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(R"(
|
||||
public void reset{{NAME_U}}Preprocessor(@Nullable {{PROCESSOR_TYPE}} processor) {
|
||||
public void reset{{NAME_U}}Preprocessor({{PROCESSOR_TYPE}} processor) {
|
||||
{{NAME}}Preprocessor = processor;
|
||||
})");
|
||||
}
|
||||
for (const auto& tensor : model.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append(R"(
|
||||
public void reset{{NAME_U}}Postprocessor(@Nullable {{PROCESSOR_TYPE}} processor) {
|
||||
public void reset{{NAME_U}}Postprocessor({{PROCESSOR_TYPE}} processor) {
|
||||
{{NAME}}Postprocessor = processor;
|
||||
})");
|
||||
}
|
||||
// Process method
|
||||
code_writer->Append(R"(
|
||||
/** Creates inputs */
|
||||
public Inputs createInputs() {
|
||||
return new Inputs();
|
||||
}
|
||||
|
||||
/** Triggers the model. */
|
||||
public Outputs run(Inputs inputs) {
|
||||
Outputs outputs = new Outputs();
|
||||
model.run(inputs.getBuffer(), outputs.getBuffer());
|
||||
public Outputs process({{INPUT_TYPE_PARAM_LIST}}) {
|
||||
Outputs outputs = new Outputs(metadata, {{POSTPROCESSORS_LIST}});
|
||||
Object[] inputBuffers = preprocessInputs({{INPUTS_LIST}});
|
||||
model.run(inputBuffers, outputs.getBuffer());
|
||||
return outputs;
|
||||
}
|
||||
|
||||
/** Closes the model. */
|
||||
public void close() {
|
||||
model.close();
|
||||
})");
|
||||
}
|
||||
)");
|
||||
{
|
||||
auto block =
|
||||
AsBlock(code_writer,
|
||||
"private {{MODEL_CLASS_NAME}}(Model model, Metadata metadata)");
|
||||
code_writer->Append(R"(this.model = model;
|
||||
this.metadata = metadata;)");
|
||||
}
|
||||
for (const auto& tensor : model.inputs) {
|
||||
code_writer->NewLine();
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
auto block = AsBlock(
|
||||
code_writer,
|
||||
"private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Preprocessor()");
|
||||
code_writer->Append(
|
||||
"{{PROCESSOR_TYPE}}.Builder builder = new "
|
||||
"{{PROCESSOR_TYPE}}.Builder()");
|
||||
if (tensor.content_type == "image") {
|
||||
code_writer->Append(R"( .add(new ResizeOp(
|
||||
metadata.get{{NAME_U}}Shape()[1],
|
||||
metadata.get{{NAME_U}}Shape()[2],
|
||||
ResizeMethod.NEAREST_NEIGHBOR)))");
|
||||
}
|
||||
if (tensor.normalization_unit >= 0) {
|
||||
code_writer->Append(
|
||||
R"( .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
|
||||
}
|
||||
code_writer->Append(
|
||||
R"( .add(new QuantizeOp(
|
||||
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
|
||||
metadata.get{{NAME_U}}QuantizationParams().getScale()))
|
||||
.add(new CastOp(metadata.get{{NAME_U}}Type()));
|
||||
return builder.build();)");
|
||||
}
|
||||
for (const auto& tensor : model.outputs) {
|
||||
code_writer->NewLine();
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
auto block = AsBlock(
|
||||
code_writer,
|
||||
"private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Postprocessor()");
|
||||
code_writer->AppendNoNewLine(
|
||||
R"({{PROCESSOR_TYPE}}.Builder builder = new {{PROCESSOR_TYPE}}.Builder()
|
||||
.add(new DequantizeOp(
|
||||
metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
|
||||
metadata.get{{NAME_U}}QuantizationParams().getScale())))");
|
||||
if (tensor.normalization_unit >= 0) {
|
||||
code_writer->AppendNoNewLine(R"(
|
||||
.add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
|
||||
}
|
||||
code_writer->Append(R"(;
|
||||
return builder.build();)");
|
||||
}
|
||||
code_writer->NewLine();
|
||||
{
|
||||
const auto block =
|
||||
AsBlock(code_writer,
|
||||
"private Object[] preprocessInputs({{INPUT_TYPE_PARAM_LIST}})");
|
||||
CodeWriter param_list_gen(err);
|
||||
for (const auto& tensor : model.inputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, tensor);
|
||||
code_writer->Append("{{NAME}} = {{NAME}}Preprocessor.process({{NAME}});");
|
||||
SetCodeWriterWithTensorInfo(¶m_list_gen, tensor);
|
||||
param_list_gen.AppendNoNewLine("{{NAME}}.getBuffer(), ");
|
||||
}
|
||||
param_list_gen.Backspace(2);
|
||||
code_writer->AppendNoNewLine("return new Object[] {");
|
||||
code_writer->AppendNoNewLine(param_list_gen.ToString());
|
||||
code_writer->Append("};");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -823,6 +845,7 @@ bool GenerateAndroidManifestContent(CodeWriter* code_writer,
|
||||
|
||||
bool GenerateDocContent(CodeWriter* code_writer, const ModelInfo& model_info) {
|
||||
code_writer->Append("# {{MODEL_CLASS_NAME}} Usage");
|
||||
// TODO(b/158651848) Generate imports for TFLS util types like TensorImage.
|
||||
code_writer->AppendNoNewLine(R"(
|
||||
```
|
||||
import {{PACKAGE}}.{{MODEL_CLASS_NAME}};
|
||||
@ -831,9 +854,7 @@ import {{PACKAGE}}.{{MODEL_CLASS_NAME}};
|
||||
{{MODEL_CLASS_NAME}} model = null;
|
||||
|
||||
try {
|
||||
model = new {{MODEL_CLASS_NAME}}(context); // android.content.Context
|
||||
// Create the input container.
|
||||
{{MODEL_CLASS_NAME}}.Inputs inputs = model.createInputs();
|
||||
model = {{MODEL_CLASS_NAME}}.newInstance(context); // android.content.Context
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
@ -845,33 +866,42 @@ if (model != null) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, t);
|
||||
if (t.content_type == "image") {
|
||||
code_writer->Append(R"(
|
||||
// Load input tensor "{{NAME}}" from a Bitmap with ARGB_8888 format.
|
||||
// Prepare tensor "{{NAME}}" from a Bitmap with ARGB_8888 format.
|
||||
Bitmap bitmap = ...;
|
||||
inputs.load{{NAME_U}}(bitmap);
|
||||
// Alternatively, load the input tensor "{{NAME}}" from a TensorImage.
|
||||
TensorImage {{MAME}} = TensorImage.fromBitmap(bitmap);
|
||||
// Alternatively, load the input tensor "{{NAME}}" from pixel values.
|
||||
// Check out TensorImage documentation to load other image data structures.
|
||||
// TensorImage tensorImage = ...;
|
||||
// inputs.load{{NAME_U}}(tensorImage);)");
|
||||
// int[] pixelValues = ...;
|
||||
// int[] shape = ...;
|
||||
// TensorImage {{NAME}} = new TensorImage();
|
||||
// {{NAME}}.load(pixelValues, shape);)");
|
||||
} else {
|
||||
code_writer->Append(R"(
|
||||
// Load input tensor "{{NAME}}" from a TensorBuffer.
|
||||
// Prepare input tensor "{{NAME}}" from an array.
|
||||
// Check out TensorBuffer documentation to load other data structures.
|
||||
TensorBuffer tensorBuffer = ...;
|
||||
inputs.load{{NAME_U}}(tensorBuffer);)");
|
||||
TensorBuffer {{NAME}} = ...;
|
||||
int[] values = ...;
|
||||
int[] shape = ...;
|
||||
{{NAME}}.load(values, shape);)");
|
||||
}
|
||||
}
|
||||
code_writer->Append(R"(
|
||||
// 3. Run the model
|
||||
{{MODEL_CLASS_NAME}}.Outputs outputs = model.run(inputs);)");
|
||||
{{MODEL_CLASS_NAME}}.Outputs outputs = model.process({{INPUTS_LIST}});)");
|
||||
code_writer->Append(R"(
|
||||
// 4. Retrieve the results)");
|
||||
for (const auto& t : model_info.outputs) {
|
||||
SetCodeWriterWithTensorInfo(code_writer, t);
|
||||
if (t.associated_axis_label_index >= 0) {
|
||||
code_writer->SetTokenValue("WRAPPER_TYPE", "Map<String, Float>");
|
||||
code_writer->SetTokenValue("WRAPPER_TYPE", "List<Category>");
|
||||
code_writer->Append(
|
||||
" List<Category> {{NAME}} = "
|
||||
"outputs.get{{NAME_U}}AsCategoryList();");
|
||||
} else {
|
||||
code_writer->Append(
|
||||
" {{WRAPPER_TYPE}} {{NAME}} = "
|
||||
"outputs.get{{NAME_U}}As{{WRAPPER_TYPE}}();");
|
||||
}
|
||||
code_writer->Append(
|
||||
R"( {{WRAPPER_TYPE}} {{NAME}} = outputs.get{{NAME_U}}();)");
|
||||
}
|
||||
code_writer->Append(R"(}
|
||||
```)");
|
||||
|
@ -57,6 +57,15 @@ struct ModelInfo {
|
||||
std::string model_versioned_name;
|
||||
std::vector<TensorInfo> inputs;
|
||||
std::vector<TensorInfo> outputs;
|
||||
// Extra helper fields. For models with inputs "a", "b" and outputs "x", "y":
|
||||
std::string input_type_param_list;
|
||||
// e.g. "TensorImage a, TensorBuffer b"
|
||||
std::string inputs_list;
|
||||
// e.g. "a, b"
|
||||
std::string postprocessor_type_param_list;
|
||||
// e.g. "ImageProcessor xPostprocessor, TensorProcessor yPostprocessor"
|
||||
std::string postprocessors_list;
|
||||
// e.g. "xPostprocessor, yPostprocessor"
|
||||
};
|
||||
|
||||
} // namespace details_android_java
|
||||
|
@ -106,7 +106,7 @@ void CodeWriter::AppendInternal(const std::string& text, bool newline) {
|
||||
while (i < text.size()) {
|
||||
char cur = text[i];
|
||||
char cur_next = i == text.size() - 1 ? '\0' : text[i + 1]; // Set guardian
|
||||
if (in_token == false) {
|
||||
if (!in_token) {
|
||||
if (cur == '{' && cur_next == '{') { // Enter token
|
||||
in_token = true;
|
||||
i += 2;
|
||||
|
Loading…
Reference in New Issue
Block a user