OSS TFLite Metadata library

PiperOrigin-RevId: 299966503
Change-Id: I08e6e7d424822c5f80f48b0f0c670f3c9c75f403
This commit is contained in:
Lu Wang 2020-03-09 16:40:25 -07:00 committed by TensorFlower Gardener
parent c645992e15
commit 9ebf4a223b
18 changed files with 3251 additions and 0 deletions

View File

@ -0,0 +1,87 @@
load("//tensorflow:tensorflow.bzl", "py_test")
load("@flatbuffers//:build_defs.bzl", "flatbuffer_android_library", "flatbuffer_cc_library", "flatbuffer_java_library", "flatbuffer_py_library")
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
exports_files(["metadata_schema.fbs"])
flatbuffer_py_library(
name = "schema_py",
srcs = ["//tensorflow/lite/schema:schema.fbs"],
)
# Generic schema for inference on device.
flatbuffer_android_library(
name = "schema_fbs_android",
srcs = ["//tensorflow/lite/schema:schema.fbs"],
custom_package = "org.tensorflow.lite.schema",
)
flatbuffer_java_library(
name = "schema_fbs_java",
srcs = ["//tensorflow/lite/schema:schema.fbs"],
custom_package = "org.tensorflow.lite.schema",
)
# Generic schema for model metadata.
flatbuffer_cc_library(
name = "metadata_schema_cc",
srcs = ["metadata_schema.fbs"],
)
flatbuffer_py_library(
name = "metadata_schema_py",
srcs = ["metadata_schema.fbs"],
)
flatbuffer_java_library(
name = "metadata_schema_java",
srcs = ["metadata_schema.fbs"],
custom_package = "org.tensorflow.lite.support.metadata.schema",
)
flatbuffer_android_library(
name = "metadata_schema_fbs_android",
srcs = ["metadata_schema.fbs"],
custom_package = "org.tensorflow.lite.support.metadata.schema",
)
py_library(
name = "metadata",
srcs = ["metadata.py"],
data = [
"//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs",
"@flatbuffers//:flatc",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":metadata_schema_py",
":schema_py",
"//tensorflow/python:platform",
"@flatbuffers//:runtime_py",
],
)
py_test(
name = "metadata_test",
srcs = ["metadata_test.py"],
data = ["testdata/golden_json.json"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":metadata",
":metadata_schema_py",
":schema_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"@flatbuffers//:runtime_py",
"@six_archive//:six",
],
)

View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.tensorflow.lite.support">
<uses-sdk android:minSdkVersion="19" />
</manifest>

View File

@ -0,0 +1,36 @@
# Description:
# TensorFlow Lite Support API in Java for metadata.
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
android_library(
name = "tensorflow-lite-support-metadata",
srcs = glob(["src/java/org/tensorflow/lite/support/metadata/**/*.java"]),
manifest = "AndroidManifest.xml",
deps = [
"//tensorflow/lite/experimental/support/java:tensorflow-lite-support-precondition-lib-android",
"//tensorflow/lite/experimental/support/metadata:metadata_schema_fbs_android",
"//tensorflow/lite/experimental/support/metadata:schema_fbs_android",
"//tensorflow/lite/java:tensorflowlite",
"@org_checkerframework_qual",
],
)
java_library(
name = "tensorflow-lite-support-metadata-lib",
srcs = glob(["src/java/org/tensorflow/lite/support/metadata/**/*.java"]),
javacopts = JAVACOPTS,
deps = [
"//tensorflow/lite/experimental/support/java:tensorflow-lite-support-precondition",
"//tensorflow/lite/experimental/support/metadata:metadata_schema_java",
"//tensorflow/lite/experimental/support/metadata:schema_fbs_java",
"//tensorflow/lite/java:tensorflowlitelib",
"@org_checkerframework_qual",
],
)

View File

@ -0,0 +1,116 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
import static org.tensorflow.lite.support.common.SupportPreconditions.checkElementIndex;
import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
/**
* An {@link InputStream} that wraps a section of a {@link SeekableByteChannelCompat}.
*
* <p><b>WARNING:</b> Similar as {@link InputStream}, instances of an {@link BoundedInputStream} are
* <b>not</b> thread-safe. If multiple threads concurrently reading from the same {@link
* BoundedInputStream}, it must be synchronized externally. Also, if multiple instances of {@link
* BoundedInputStream} are created on the same {@link SeekableByteChannelCompat}, it must be
* synchronized as well.
*/
final class BoundedInputStream extends InputStream {
private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
private final long end; // The valid data for the stream is between [start, end).
private long position;
private final SeekableByteChannelCompat channel;
/**
* Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}.
*
* @param channel the {@link SeekableByteChannelCompat} that backs up this {@link
* BoundedInputStream}
* @param start the starting position of this {@link BoundedInputStream} in the given {@link
* SeekableByteChannelCompat}
* @param remaining the length of this {@link BoundedInputStream}
* @throws IllegalArgumentException if {@code start} or {@code remaining} is negative
*/
BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) {
checkArgument(
remaining >= 0 && start >= 0,
String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
end = start + remaining;
this.channel = channel;
position = start;
}
@Override
public int available() throws IOException {
return (int) (Math.min(end, channel.size()) - position);
}
@Override
public int read() throws IOException {
if (position >= end) {
return -1;
}
singleByteBuffer.rewind();
int count = read(position, singleByteBuffer);
if (count < 0) {
return count;
}
position++;
return singleByteBuffer.get() & 0xff;
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
checkNotNull(b);
checkElementIndex(off, b.length, "The start offset");
checkElementIndex(len, b.length - off + 1, "The maximumn number of bytes to read");
if (len == 0) {
return 0;
}
if (len > end - position) {
if (position >= end) {
return -1;
}
len = (int) (end - position);
}
ByteBuffer buf = ByteBuffer.wrap(b, off, len);
int count = read(position, buf);
if (count > 0) {
position += count;
}
return count;
}
private int read(long position, ByteBuffer buf) throws IOException {
int count;
synchronized (channel) {
channel.position(position);
count = channel.read(buf);
}
buf.flip();
return count;
}
}

View File

@ -0,0 +1,130 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import static java.lang.Math.min;
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
import java.nio.ByteBuffer;
import java.nio.channels.NonWritableChannelException;
/** Implements the {@link SeekableByteChannelCompat} on top of {@link ByteBuffer}. */
final class ByteBufferChannel implements SeekableByteChannelCompat {
/** The ByteBuffer that holds the data. */
private final ByteBuffer buffer;
/**
* Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}.
*
* @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel}
* @throws NullPointerException if {@code buffer} is null
*/
public ByteBufferChannel(ByteBuffer buffer) {
checkNotNull(buffer, "The ByteBuffer cannot be null.");
this.buffer = buffer;
}
@Override
public void close() {}
@Override
public boolean isOpen() {
return true;
}
@Override
public long position() {
return buffer.position();
}
/**
* Sets this channel's position.
*
* @param newPosition the new position, a non-negative integer counting the number of bytes from
* the beginning of the entity
* @return this channel
* @throws IllegalArgumentException if the new position is negative, or greater than the size of
* the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE
*/
@Override
public synchronized ByteBufferChannel position(long newPosition) {
checkArgument(
(newPosition >= 0 && newPosition <= Integer.MAX_VALUE),
"The new position should be non-negative and be less than Integer.MAX_VALUE.");
buffer.position((int) newPosition);
return this;
}
/**
* {@inheritDoc}
*
* <p>Bytes are read starting at this channel's current position, and then the position is updated
* with the number of bytes actually read. Otherwise this method behaves exactly as specified in
* the {@link ReadableByteChannel} interface.
*/
@Override
public synchronized int read(ByteBuffer dst) {
if (buffer.remaining() == 0) {
return -1;
}
int count = min(dst.remaining(), buffer.remaining());
if (count > 0) {
ByteBuffer tempBuffer = buffer.slice();
tempBuffer.order(buffer.order()).limit(count);
dst.put(tempBuffer);
buffer.position(buffer.position() + count);
}
return count;
}
@Override
public long size() {
return buffer.limit();
}
@Override
public synchronized ByteBufferChannel truncate(long size) {
checkArgument(
(size >= 0 && size <= Integer.MAX_VALUE),
"The new size should be non-negative and be less than Integer.MAX_VALUE.");
if (size < buffer.limit()) {
buffer.limit((int) size);
if (buffer.position() > size) {
buffer.position((int) size);
}
}
return this;
}
@Override
public synchronized int write(ByteBuffer src) {
if (buffer.isReadOnly()) {
throw new NonWritableChannelException();
}
int count = min(src.remaining(), buffer.remaining());
if (count > 0) {
ByteBuffer tempBuffer = src.slice();
tempBuffer.order(buffer.order()).limit(count);
buffer.put(tempBuffer);
}
return count;
}
}

View File

@ -0,0 +1,247 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.zip.ZipException;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Tensor.QuantizationParams;
import org.tensorflow.lite.schema.Tensor;
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
/**
* Loads metadata from TFLite Model FlatBuffer.
*
* <p>TFLite Model FlatBuffer can be generated using the <a
* href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
* Model schema file.</a>
*
* <p>Some models contain a TFLite Metadata Flatbuffer, which records more information about what
* the model does and how to interprete the model. TFLite Metadata Flatbuffer can be generated using
* the <a
* href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/metadata_schema.fbs">TFLite
* Metadata schema file.</a>
*
* <p>It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking methods
* that read from TFLite metadata will cause runtime errors.
*
* <p>Similarly, it is allowed to pass in a model FlatBuffer without associated files. However,
* invoking methods that read the associated files will cause runtime errors.
*
* <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports a
* single subgraph so far. See the <a
* href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
* of how to specify subgraph during convertion for more information.</a> Therefore, {@link
* MetadataExtractor} omits subgraph index as an input in its methods.
*/
public class MetadataExtractor {
/** The helper class to load metadata from TFLite model FlatBuffer. */
private final ModelInfo modelInfo;
/** The helper class to load metadata from TFLite metadata FlatBuffer. */
@Nullable private final ModelMetadataInfo metadataInfo;
/** The handler to load associated files through zip. */
@Nullable private final ZipFile zipFile;
/**
* Creates a {@link MetadataExtractor} with TFLite model FlatBuffer.
*
* @param buffer the TFLite model FlatBuffer
* @throws IllegalArgumentException if the number of input or output tensors in the model does not
* match that in the metadata
* @throws IOException if an error occurs while reading the model as a Zip file
*/
public MetadataExtractor(ByteBuffer buffer) throws IOException {
modelInfo = new ModelInfo(buffer);
ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer();
if (metadataBuffer != null) {
metadataInfo = new ModelMetadataInfo(metadataBuffer);
checkArgument(
modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(),
String.format(
"The number of input tensors in the model is %d. The number of input tensors that"
+ " recorded in the metadata is %d. These two values does not match.",
modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount()));
checkArgument(
modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(),
String.format(
"The number of output tensors in the model is %d. The number of output tensors that"
+ " recorded in the metadata is %d. These two values does not match.",
modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount()));
} else {
// It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking
// methods that read from TFLite metadata will cause runtime errors.
metadataInfo = null;
}
zipFile = createZipFile(buffer);
}
/**
* Gets the packed associated file with the specified {@code fileName}.
*
* @param fileName the name of the associated file
* @return the raw input stream containing specified file
* @throws IllegalStateException if the model is not a zip file
* @throws IllegalArgumentException if the specified file does not exist in the model
*/
public InputStream getAssociatedFile(String fileName) {
assertZipFile();
return zipFile.getRawInputStream(fileName);
}
/** Gets the count of input tensors in the model. */
public int getInputTensorCount() {
return modelInfo.getInputTensorCount();
}
/**
* Gets the metadata for the input tensor specified by {@code inputIndex}.
*
* @param inputIndex the index of the desired input tensor
* @throws IllegalStateException if this model does not contain model metadata
*/
@Nullable
public TensorMetadata getInputTensorMetadata(int inputIndex) {
assertMetadataInfo();
return metadataInfo.getInputTensorMetadata(inputIndex);
}
/**
* Gets the quantization parameters for the input tensor specified by {@code inputIndex}.
*
* @param inputIndex the index of the desired input tensor
*/
public QuantizationParams getInputTensorQuantizationParams(int inputIndex) {
Tensor tensor = modelInfo.getInputTensor(inputIndex);
return modelInfo.getQuantizationParams(tensor);
}
/**
* Gets the shape of the input tensor with {@code inputIndex}.
*
* @param inputIndex the index of the desired input tensor
*/
public int[] getInputTensorShape(int inputIndex) {
return modelInfo.getInputTensorShape(inputIndex);
}
/**
* Gets the {@link DataType} of the input tensor with {@code inputIndex}.
*
* @param inputIndex the index of the desired input tensor
*/
public DataType getInputTensorType(int inputIndex) {
return modelInfo.getInputTensorType(inputIndex);
}
/** Gets the count of output tensors in the model. */
public int getOutputTensorCount() {
return modelInfo.getOutputTensorCount();
}
/**
* Gets the metadata for the output tensor specified by {@code outputIndex}.
*
* @param outputIndex the index of the desired output tensor
* @throws IllegalStateException if this model does not contain model metadata
*/
@Nullable
public TensorMetadata getOutputTensorMetadata(int outputIndex) {
assertMetadataInfo();
return metadataInfo.getOutputTensorMetadata(outputIndex);
}
/**
* Gets the quantization parameters for the output tensor specified by {@code outputIndex}.
*
* @param outputIndex the index of the desired output tensor
*/
public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) {
Tensor tensor = modelInfo.getOutputTensor(outputIndex);
return modelInfo.getQuantizationParams(tensor);
}
/**
* Gets the shape of the output tensor with {@code outputIndex}.
*
* @param outputIndex the index of the desired output tensor
*/
public int[] getOutputTensorShape(int outputIndex) {
return modelInfo.getOutputTensorShape(outputIndex);
}
/**
* Gets the {@link DataType} of the output tensor with {@code outputIndex}.
*
* @param outputIndex the index of the desired output tensor
*/
public DataType getOutputTensorType(int outputIndex) {
return modelInfo.getOutputTensorType(outputIndex);
}
/**
* Asserts if {@link metdadataInfo} is not initialized. Some models may not have metadata and this
* is allowed. However, invoking methods that reads the metadata is not allowed.
*
* @throws IllegalStateException if this model does not contain model metadata
*/
private void assertMetadataInfo() {
if (metadataInfo == null) {
throw new IllegalStateException("This model does not contain model metadata.");
}
}
/**
* Asserts if {@link #zipFile} is not initialized. Some models may not have associated files, thus
* are not Zip files. This is allowed. However, invoking methods that reads those associated files
* is not allowed.
*
* @throws IllegalStateException if this model is not a Zip file
*/
private void assertZipFile() {
if (zipFile == null) {
throw new IllegalStateException(
"This model does not contain associated files, and is not a Zip file.");
}
}
/**
* Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e.
* it does not have associated files, return a null handler.
*
* @param buffer the TFLite model FlatBuffer
* @throws IOException if an error occurs while reading the model as a Zip file
*/
@Nullable
private static ZipFile createZipFile(ByteBuffer buffer) throws IOException {
try {
// Creates the handler to hold the associated files through the Zip.
ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer);
return ZipFile.createFrom(byteBufferChannel);
} catch (ZipException e) {
// Some models may not have associate files. Therefore, Those models are not zip files.
// However, invoking methods that read associated files later will lead into errors.
return null;
}
}
}

View File

@ -0,0 +1,281 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Tensor.QuantizationParams;
import org.tensorflow.lite.schema.Buffer;
import org.tensorflow.lite.schema.Metadata;
import org.tensorflow.lite.schema.Model;
import org.tensorflow.lite.schema.QuantizationParameters;
import org.tensorflow.lite.schema.SubGraph;
import org.tensorflow.lite.schema.Tensor;
import org.tensorflow.lite.schema.TensorType;
import org.tensorflow.lite.support.common.SupportPreconditions;
/** Extracts model information out of TFLite model FLatBuffer. */
final class ModelInfo {
/** The model that is loaded from TFLite model FlatBuffer. */
private final Model model;
/** A list of input tensors. */
private final List</* @Nullable */ Tensor> inputTensors;
/** A list of output tensors. */
private final List</* @Nullable */ Tensor> outputTensors;
/** Identifier of the TFLite model metadata in the Metadata array. */
static final String METADATA_FIELD_NAME = "TFLITE_METADATA";
/** Maps from TensorType in TFlite FlatBuffer to {@link DataType} in Java. */
private final Map<Byte, DataType> tensorTypeToDataTypeMap;
/**
* Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}.
*
* <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports
* single subgraph so far. See the <a
* href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
* of how to specify subgraph during convertion for more information.</a> Therefore, all methods
* in {@link ModelInfo} retrieves metadata of the first subgrpah as default.
*
* @param buffer The TFLite model FlatBuffer.
* @throws NullPointerException if {@code buffer} is null.
* @throws IllegalArgumentException if the model does not contain any subgraph.
*/
ModelInfo(ByteBuffer buffer) {
SupportPreconditions.checkNotNull(buffer, "Model flatbuffer cannot be null.");
model = Model.getRootAsModel(buffer);
SupportPreconditions.checkArgument(
model.subgraphsLength() > 0, "The model does not contain any subgraph.");
inputTensors = getInputTensors(model);
outputTensors = getOutputTensors(model);
tensorTypeToDataTypeMap = createTensorTypeToDataTypeMap();
}
/**
* Gets the input tensor with {@code inputIndex}.
*
* @param inputIndex The index of the desired input tensor.
* @throws IllegalArgumentException if the inputIndex specified is invalid.
*/
@Nullable
Tensor getInputTensor(int inputIndex) {
SupportPreconditions.checkArgument(
inputIndex >= 0 && inputIndex < inputTensors.size(),
"The inputIndex specified is invalid.");
return inputTensors.get(inputIndex);
}
int getInputTensorCount() {
return inputTensors.size();
}
/**
* Gets shape of the input tensor with {@code inputIndex}.
*
* @param inputIndex The index of the desired intput tensor.
*/
int[] getInputTensorShape(int inputIndex) {
Tensor tensor = getInputTensor(inputIndex);
return getShape(tensor);
}
/**
* Gets {@link DataType} of the input tensor with {@code inputIndex}.
*
* @param inputIndex The index of the desired intput tensor.
*/
DataType getInputTensorType(int inputIndex) {
Tensor tensor = getInputTensor(inputIndex);
return getDataType(tensor.type());
}
/** Gets the metadata FlatBuffer from the model FlatBuffer. */
@Nullable
ByteBuffer getMetadataBuffer() {
// Some models may not have metadata, and this is allowed.
if (model.metadataLength() == 0) {
return null;
}
for (int i = 0; i < model.metadataLength(); i++) {
Metadata meta = model.metadata(i);
if (METADATA_FIELD_NAME.equals(meta.name())) {
long bufferIndex = meta.buffer();
Buffer metadataBuf = model.buffers((int) bufferIndex);
return metadataBuf.dataAsByteBuffer();
}
}
return null;
}
/**
* Gets the output tensor with {@code outputIndex}.
*
* @param outputIndex The index of the desired outtput tensor.
* @throws IllegalArgumentException if the outputIndex specified is invalid.
*/
@Nullable
Tensor getOutputTensor(int outputIndex) {
SupportPreconditions.checkArgument(
outputIndex >= 0 && outputIndex < outputTensors.size(),
"The outputIndex specified is invalid.");
return outputTensors.get(outputIndex);
}
int getOutputTensorCount() {
return outputTensors.size();
}
/**
* Gets shape of the output tensor with {@code outputIndex}.
*
* @param outputIndex The index of the desired outtput tensor.
*/
int[] getOutputTensorShape(int outputIndex) {
Tensor tensor = getOutputTensor(outputIndex);
return getShape(tensor);
}
/**
* Gets {@link DataType} of the output tensor {@code outputIndex}.
*
* @param outputIndex The index of the desired outtput tensor.
*/
DataType getOutputTensorType(int outputIndex) {
Tensor tensor = getOutputTensor(outputIndex);
return getDataType(tensor.type());
}
private static Map<Byte, DataType> createTensorTypeToDataTypeMap() {
Map<Byte, DataType> map = new HashMap<>();
map.put(TensorType.FLOAT32, DataType.FLOAT32);
map.put(TensorType.INT32, DataType.INT32);
map.put(TensorType.UINT8, DataType.UINT8);
map.put(TensorType.INT64, DataType.INT64);
map.put(TensorType.STRING, DataType.STRING);
return Collections.unmodifiableMap(map);
}
/**
* Gets the quantization parameters of a tensor.
*
* <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensor that are not
* quantized, the values of scale and zero_point are both 0.
*
* @param tensor The tensor whoes quantization parameters is desired.
* @throws NullPointerException if the tensor is null.
* @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's {@link
* QuantizationParameters} are not single values.
*/
QuantizationParams getQuantizationParams(Tensor tensor) {
SupportPreconditions.checkNotNull(tensor, "Tensor cannot be null.");
float scale;
int zeroPoint;
QuantizationParameters quantization = tensor.quantization();
// Tensors that are not quantized do not have quantization parameters, which can be null when
// being extracted from the flatbuffer.
if (quantization == null) {
scale = 0.0f;
zeroPoint = 0;
return new QuantizationParams(scale, zeroPoint);
}
// Tensors that are not quantized do not have quantization parameters.
// quantization.scaleLength() and quantization.zeroPointLength() may both return 0.
SupportPreconditions.checkArgument(
quantization.scaleLength() <= 1,
"Input and output tensors do not support per-channel quantization.");
SupportPreconditions.checkArgument(
quantization.zeroPointLength() <= 1,
"Input and output tensors do not support per-channel quantization.");
// For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0) will
// both be the default value in flatbuffer, 0. This behavior is consistent with the TFlite C++
// runtime.
scale = quantization.scale(0);
// zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep it
// consistent with the C++ runtime.
zeroPoint = (int) quantization.zeroPoint(0);
return new QuantizationParams(scale, zeroPoint);
}
/**
* Transforms from TensorType in TFlite FlatBuffer to {@link DataType} in Java.
*
* @param tensorType The tensor type to be converted.
* @throws IllegalArgumentException if the tensor type is not supported.
*/
private DataType getDataType(byte tensorType) {
SupportPreconditions.checkArgument(
tensorTypeToDataTypeMap.containsKey(tensorType),
String.format("Tensor type %d is not supported.", tensorType));
return tensorTypeToDataTypeMap.get(tensorType);
}
/**
* Gets the shape of a tensor.
*
* @param tensor The tensor whoes shape is desired.
* @throws NullPointerException if the tensor is null.
*/
private static int[] getShape(Tensor tensor) {
SupportPreconditions.checkNotNull(tensor, "Tensor cannot be null.");
int shapeDim = tensor.shapeLength();
int[] tensorShape = new int[shapeDim];
for (int i = 0; i < shapeDim; i++) {
tensorShape[i] = tensor.shape(i);
}
return tensorShape;
}
/** Gets input tensors from a model. */
private static List<Tensor> getInputTensors(Model model) {
// TFLite only support one subgraph currently.
SubGraph subgraph = model.subgraphs(0);
int tensorNum = subgraph.inputsLength();
ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum);
for (int i = 0; i < tensorNum; i++) {
inputTensors.add(subgraph.tensors(subgraph.inputs(i)));
}
return Collections.unmodifiableList(inputTensors);
}
/** Gets output tensors from a model. */
private static List<Tensor> getOutputTensors(Model model) {
// TFLite only support one subgraph currently.
SubGraph subgraph = model.subgraphs(0);
int tensorNum = subgraph.outputsLength();
ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum);
for (int i = 0; i < tensorNum; i++) {
outputTensors.add(subgraph.tensors(subgraph.outputs(i)));
}
return Collections.unmodifiableList(outputTensors);
}
}

View File

@ -0,0 +1,114 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.support.common.SupportPreconditions;
import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata;
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
/** Extracts model metadata information out of TFLite metadata FlatBuffer. */
final class ModelMetadataInfo {
/** Metadata array of input tensors. */
private final List</* @Nullable */ TensorMetadata> inputsMetadata;
/** Metadata array of output tensors. */
private final List</* @Nullable */ TensorMetadata> outputsMetadata;
/**
* Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}.
*
* @param buffer The TFLite metadata FlatBuffer.
* @throws NullPointerException if {@code buffer} is null.
* @throws IllegalArgumentException if the metadata does not contain any subgraph metadata.
*/
ModelMetadataInfo(ByteBuffer buffer) {
SupportPreconditions.checkNotNull(buffer, "Metadata flatbuffer cannot be null.");
ModelMetadata modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
SupportPreconditions.checkArgument(
modelMetadata.subgraphMetadataLength() > 0,
"The metadata flatbuffer does not contain any subgraph metadata.");
inputsMetadata = getInputsMetadata(modelMetadata);
outputsMetadata = getOutputsMetadata(modelMetadata);
}
/** Gets the count of input tensors with metadata in the metadata FlatBuffer. */
int getInputTensorCount() {
return inputsMetadata.size();
}
/**
* Gets the metadata for the input tensor specified by {@code inputIndex}.
*
* @param inputIndex The index of the desired intput tensor.
* @throws IllegalArgumentException if the inputIndex specified is invalid.
*/
@Nullable
TensorMetadata getInputTensorMetadata(int inputIndex) {
SupportPreconditions.checkArgument(
inputIndex >= 0 && inputIndex < inputsMetadata.size(),
"The inputIndex specified is invalid.");
return inputsMetadata.get(inputIndex);
}
/** Gets the count of output tensors with metadata in the metadata FlatBuffer. */
int getOutputTensorCount() {
return outputsMetadata.size();
}
/**
* Gets the metadata for the output tensor specified by {@code outputIndex}.
*
* @param outputIndex The index of the desired output tensor.
* @throws IllegalArgumentException if the outputIndex specified is invalid.
*/
@Nullable
TensorMetadata getOutputTensorMetadata(int outputIndex) {
SupportPreconditions.checkArgument(
outputIndex >= 0 && outputIndex < outputsMetadata.size(),
"The outputIndex specified is invalid.");
return outputsMetadata.get(outputIndex);
}
/** Gets metadata for all input tensors. */
private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) {
SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
int tensorNum = subgraphMetadata.inputTensorMetadataLength();
ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum);
for (int i = 0; i < tensorNum; i++) {
inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i));
}
return Collections.unmodifiableList(inputsMetadata);
}
/** Gets metadata for all output tensors. */
private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) {
SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
int tensorNum = subgraphMetadata.outputTensorMetadataLength();
ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum);
for (int i = 0; i < tensorNum; i++) {
outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i));
}
return Collections.unmodifiableList(outputsMetadata);
}
}

View File

@ -0,0 +1,107 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channel;
/**
* A byte channel that maintains a current <i>position</i> and allows the position to be changed.
* {@link SeekableByteChannelCompat} is compatible with {@link
* java.nio.channels.SeekableByteChannel}.
*
* <p>{@link java.nio.channels.SeekableByteChannel} is not available in Android API 23 and under.
* Therefore, {@link SeekableByteChannelCompat} is introduced here to make the interfaces used in
* the MetadtaExtractor library consistent with the common used Java libraries.
*/
interface SeekableByteChannelCompat extends Channel {
/**
* Reads a sequence of bytes from this channel into the given buffer.
*
* @param dst The buffer into which bytes are to be transferred
* @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached
* end-of-stream
* @throws NonReadableChannelException If this channel was not opened for reading
* @throws ClosedChannelException If this channel is closed
* @throws AsynchronousCloseException If another thread closes this channel while the read
* operation is in progress
* @throws ClosedByInterruptException If another thread interrupts the current thread while the
* read operation is in progress, thereby closing the channel and setting the current thread's
* interrupt status
* @throws IOException If some other I/O error occurs
*/
int read(ByteBuffer dst) throws IOException;
/**
* Writes a sequence of bytes to this channel from the given buffer.
*
* @param src The buffer from which bytes are to be retrieved
* @return The number of bytes written, possibly zero
* @throws NonWritableChannelException If this channel was not opened for writing
* @throws ClosedChannelException If this channel is closed
* @throws AsynchronousCloseException If another thread closes this channel while the write
* operation is in progress
* @throws ClosedByInterruptException If another thread interrupts the current thread while the
* write operation is in progress, thereby closing the channel and setting the current
* thread's interrupt status
* @throws IOException If some other I/O error occurs
*/
int write(ByteBuffer src) throws IOException;
/**
* Returns this channel's position.
*
* @return This channel's position, a non-negative integer counting the number of bytes from the
* beginning of the entity to the current position
* @throws ClosedChannelException If this channel is closed
* @throws IOException If some other I/O error occurs
*/
long position() throws IOException;
/**
* Sets this channel's position.
*
* @param newPosition The new position, a non-negative integer counting the number of bytes from
* the beginning of the entity
* @return This channel
* @throws ClosedChannelException If this channel is closed
* @throws IllegalArgumentException If the new position is negative
* @throws IOException If some other I/O error occurs
*/
SeekableByteChannelCompat position(long newPosition) throws IOException;
/**
* Returns the current size of entity to which this channel is connected.
*
* @return The current size, measured in bytes
* @throws ClosedChannelException If this channel is closed
* @throws IOException If some other I/O error occurs
*/
long size() throws IOException;
/**
* Truncates the entity, to which this channel is connected, to the given size.
*
* @param size The new size, a non-negative byte count
* @return This channel
* @throws NonWritableChannelException If this channel was not opened for writing
* @throws ClosedChannelException If this channel is closed
* @throws IllegalArgumentException If the new size is negative
* @throws IOException If some other I/O error occurs
*/
SeekableByteChannelCompat truncate(long size) throws IOException;
}

View File

@ -0,0 +1,427 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.metadata;
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
import java.io.Closeable;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.ZipException;
/**
* Reads uncompressed files from the TFLite model, a zip file.
*
* <p>TODO(b/150237111): add a link to the webpage of MetadataPopulator once it's available.
*
* <p>A TFLite model file becomes a zip file when it contains associated files. The associated files
* can be packed to a TFLite model file using the MetadataPopulator. The associated files are not
* compressed when being added to the model file.
*
* <p>{@link ZipFile} does not support Zip64 format, because TFLite models are much smaller than the
* size limit for Zip64, which is 4GB.
*/
final class ZipFile implements Closeable {
/** Maps String to list of ZipEntrys, name -> actual entries. */
private final Map<String, List<ZipEntry>> nameMap;
/** The actual data source. */
private final ByteBufferChannel archive;
/**
* Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link
* ZipFile} does not synchronized over the buffer that is passed into it.
*
* @param channel the archive
* @throws IOException if an error occurs while creating this {@link ZipFile}
* @throws ZipException if the channel is not a zip archive
* @throws NullPointerException if the archive is null
*/
public static ZipFile createFrom(ByteBufferChannel channel) throws IOException {
checkNotNull(channel);
ZipParser zipParser = new ZipParser(channel);
Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries();
return new ZipFile(channel, nameMap);
}
@Override
public void close() {
archive.close();
}
/**
* Exposes the raw stream of the archive entry.
*
* <p>Since the associated files will not be compressed when being packed to the zip file, the raw
* stream represents the non-compressed files.
*
* <p><b>WARNING:</b> The returned {@link InputStream}, is <b>not</b> thread-safe. If multiple
* threads concurrently reading from the returned {@link InputStream}, it must be synchronized
* externally.
*
* @param name name of the entry to get the stream for
* @return the raw input stream containing data
* @throws IllegalArgumentException if the specified file does not exist in the zip file
*/
public InputStream getRawInputStream(String name) {
checkArgument(
nameMap.containsKey(name),
String.format("The file, %s, does not exist in the zip file.", name));
List<ZipEntry> entriesWithTheSameName = nameMap.get(name);
ZipEntry entry = entriesWithTheSameName.get(0);
long start = entry.getDataOffset();
long remaining = entry.getSize();
return new BoundedInputStream(archive, start, remaining);
}
private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) {
archive = channel;
this.nameMap = nameMap;
}
/* Parses a Zip archive and gets the information for each {@link ZipEntry}. */
private static class ZipParser {
private final ByteBufferChannel archive;
// Cached buffers that will only be used locally in the class to reduce garbage collection.
private final ByteBuffer longBuffer =
ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
private final ByteBuffer intBuffer =
ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
private final ByteBuffer shortBuffer =
ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
private ZipParser(ByteBufferChannel archive) {
this.archive = archive;
}
/**
* Parses the underlying {@code archive} and returns the information as a list of {@link
* ZipEntry}.
*/
private Map<String, List<ZipEntry>> parseEntries() throws IOException {
List<ZipEntry> entries = parseCentralDirectory();
return parseLocalFileHeaderData(entries);
}
/**
* Checks if the current position contains a central file header signature, {@link
* ZipConstants#CENSIG}.
*/
private boolean foundCentralFileheaderSignature() {
long signature = (long) getInt();
return signature == ZipConstants.CENSIG;
}
/**
* Gets the value as a Java int from two bytes starting at the current position of the archive.
*/
private int getShort() {
shortBuffer.rewind();
archive.read(shortBuffer);
shortBuffer.flip();
return (int) shortBuffer.getShort();
}
/**
* Gets the value as a Java long from four bytes starting at the current position of the
* archive.
*/
private int getInt() {
intBuffer.rewind();
archive.read(intBuffer);
intBuffer.flip();
return intBuffer.getInt();
}
/**
* Gets the value as a Java long from four bytes starting at the current position of the
* archive.
*/
private long getLong() {
longBuffer.rewind();
archive.read(longBuffer);
longBuffer.flip();
return longBuffer.getLong();
}
/**
* Positions the archive at the start of the central directory.
*
* <p>First, it searches for the signature of the "end of central directory record", {@link
* ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory
* record". The zip file are created without archive comments, thus {@link ZipConstants#ENDSIG}
* should appear exactly at {@link ZipConstants#ENDHDR} from the end of the zip file.
*
* <p>Then, parse the "end of central dir record" and position the archive at the start of the
* central directory.
*/
private void locateCentralDirectory() throws IOException {
if (archive.size() < ZipConstants.ENDHDR) {
throw new ZipException("The archive is not a ZIP archive.");
}
// Positions the archive at the start of the "end of central directory record".
long offsetRecord = archive.size() - ZipConstants.ENDHDR;
archive.position(offsetRecord);
// Checks for the signature, {@link ZipConstants#ENDSIG}.
long endSig = getLong();
if (endSig != ZipConstants.ENDSIG) {
throw new ZipException("The archive is not a ZIP archive.");
}
// Positions the archive at the offset of central directory.
skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB);
// Gets the offset to central directory
long offsetDirectory = getInt();
// Goes to the central directory.
archive.position(offsetDirectory);
}
/**
* Reads the central directory of the given archive and populates the internal tables with
* {@link ZipEntry} instances.
*/
private List<ZipEntry> parseCentralDirectory() throws IOException {
/** List of entries in the order they appear inside the central directory. */
List<ZipEntry> entries = new ArrayList<>();
locateCentralDirectory();
while (foundCentralFileheaderSignature()) {
ZipEntry entry = parseCentralDirectoryEntry();
entries.add(entry);
}
return entries;
}
/**
* Reads an individual entry of the central directory, creats an ZipEntry from it and adds it to
* the global maps.
*/
private ZipEntry parseCentralDirectoryEntry() throws IOException {
// Positions the archive at the "compressed size" and read the value.
skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM);
long compressSize = getInt();
// Positions the archive at the "filename length" and read the value.
skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN);
int fileNameLen = getShort();
// Reads the extra field length and the comment length.
int extraLen = getShort();
int commentLen = getShort();
// Positions the archive at the "local file header offset" and read the value.
skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK);
long localHeaderOffset = getInt();
// Reads the file name.
byte[] fileNameBuf = new byte[fileNameLen];
archive.read(ByteBuffer.wrap(fileNameBuf));
String fileName = new String(fileNameBuf, Charset.forName("UTF-8"));
// Skips the extra field and the comment.
skipBytes(extraLen + commentLen);
ZipEntry entry = new ZipEntry();
entry.setSize(compressSize);
entry.setLocalHeaderOffset(localHeaderOffset);
entry.setName(fileName);
return entry;
}
/** Walks through all recorded entries and records the offsets for the entry data. */
private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) {
/** Maps String to list of ZipEntrys, name -> actual entries. */
Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>();
for (ZipEntry entry : entries) {
long offset = entry.getLocalHeaderOffset();
archive.position(offset + ZipConstants.LOCNAM);
// Gets the data offset of this entry.
int fileNameLen = getShort();
int extraFieldLen = getShort();
long dataOffset =
offset
+ ZipConstants.LOCEXT
+ ZipConstants.SHORT_BYTE_SIZE
+ fileNameLen
+ extraFieldLen;
entry.setDataOffset(dataOffset);
// Puts the entry into the nameMap.
String name = entry.getName();
List<ZipEntry> entriesWithTheSameName;
if (nameMap.containsKey(name)) {
entriesWithTheSameName = nameMap.get(name);
} else {
entriesWithTheSameName = new ArrayList<>();
nameMap.put(name, entriesWithTheSameName);
}
entriesWithTheSameName.add(entry);
}
return nameMap;
}
/** Skips the given number of bytes or throws an EOFException if skipping failed. */
private void skipBytes(int count) throws IOException {
long currentPosition = archive.position();
long newPosition = currentPosition + count;
if (newPosition > archive.size()) {
throw new EOFException();
}
archive.position(newPosition);
}
}
/** Stores the data offset and the size of an entry in the archive. */
private static class ZipEntry {
private String name;
private long dataOffset = -1;
private long size = -1;
private long localHeaderOffset = -1;
public long getSize() {
return size;
}
public long getDataOffset() {
return dataOffset;
}
public String getName() {
return name;
}
public long getLocalHeaderOffset() {
return localHeaderOffset;
}
public void setSize(long size) {
this.size = size;
}
public void setDataOffset(long dataOffset) {
this.dataOffset = dataOffset;
}
public void setName(String name) {
this.name = name;
}
public void setLocalHeaderOffset(long localHeaderOffset) {
this.localHeaderOffset = localHeaderOffset;
}
}
/**
* Various constants for this {@link ZipFile}.
*
* <p>Referenced from {@link java.util.zip.ZipConstants}.
*/
private static class ZipConstants {
/** length of Java short in bytes. */
static final int SHORT_BYTE_SIZE = Short.SIZE / 8;
/** length of Java int in bytes. */
static final int INT_BYTE_SIZE = Integer.SIZE / 8;
/** length of Java long in bytes. */
static final int LONG_BYTE_SIZE = Long.SIZE / 8;
/*
* Header signatures
*/
static final long LOCSIG = 0x04034b50L; // "PK\003\004"
static final long EXTSIG = 0x08074b50L; // "PK\007\008"
static final long CENSIG = 0x02014b50L; // "PK\001\002"
static final long ENDSIG = 0x06054b50L; // "PK\005\006"
/*
* Header sizes in bytes (including signatures)
*/
static final int LOCHDR = 30; // LOC header size
static final int EXTHDR = 16; // EXT header size
static final int CENHDR = 46; // CEN header size
static final int ENDHDR = 22; // END header size
/*
* Local file (LOC) header field offsets
*/
static final int LOCVER = 4; // version needed to extract
static final int LOCFLG = 6; // general purpose bit flag
static final int LOCHOW = 8; // compression method
static final int LOCTIM = 10; // modification time
static final int LOCCRC = 14; // uncompressed file crc-32 value
static final int LOCSIZ = 18; // compressed size
static final int LOCLEN = 22; // uncompressed size
static final int LOCNAM = 26; // filename length
static final int LOCEXT = 28; // extra field length
/*
* Extra local (EXT) header field offsets
*/
static final int EXTCRC = 4; // uncompressed file crc-32 value
static final int EXTSIZ = 8; // compressed size
static final int EXTLEN = 12; // uncompressed size
/*
* Central directory (CEN) header field offsets
*/
static final int CENVEM = 4; // version made by
static final int CENVER = 6; // version needed to extract
static final int CENFLG = 8; // encrypt, decrypt flags
static final int CENHOW = 10; // compression method
static final int CENTIM = 12; // modification time
static final int CENCRC = 16; // uncompressed file crc-32 value
static final int CENSIZ = 20; // compressed size
static final int CENLEN = 24; // uncompressed size
static final int CENNAM = 28; // filename length
static final int CENEXT = 30; // extra field length
static final int CENCOM = 32; // comment length
static final int CENDSK = 34; // disk number start
static final int CENATT = 36; // internal file attributes
static final int CENATX = 38; // external file attributes
static final int CENOFF = 42; // LOC header offset
/*
* End of central directory (END) header field offsets
*/
static final int ENDSUB = 8; // number of entries on this disk
static final int ENDTOT = 10; // total number of entries
static final int ENDSIZ = 12; // central directory size in bytes
static final int ENDOFF = 16; // offset of first CEN header
static final int ENDCOM = 20; // zip file comment length
private ZipConstants() {}
}
}

View File

@ -0,0 +1,542 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite metadata tools."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import os
import shutil
import subprocess
import tempfile
import warnings
import zipfile
from flatbuffers.python import flatbuffers
from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb
from tensorflow.python.platform import resource_loader
_FLATC_BINARY_PATH = resource_loader.get_path_to_datafile(
"../../../../../external/flatbuffers/flatc")
_FLATC_TFLITE_METADATA_SCHEMA_FILE = resource_loader.get_path_to_datafile(
"metadata_schema.fbs")
# TODO(b/141467403): add delete method for associated files.
class MetadataPopulator(object):
"""Packs metadata and associated files into TensorFlow Lite model file.
MetadataPopulator can be used to populate metadata and model associated files
into a model file or a model buffer (in bytearray). It can also help to
inspect list of files that have been packed into the model or are supposed to
be packed into the model.
The metadata file (or buffer) should be generated based on the metadata
schema:
third_party/tensorflow/lite/schema/metadata_schema.fbs
Example usage:
Populate matadata and label file into an image classifier model.
First, based on metadata_schema.fbs, generate the metadata for this image
classifer model using Flatbuffers API. Attach the label file onto the ouput
tensor (the tensor of probabilities) in the metadata.
Then, pack the metadata and lable file into the model as follows.
```python
# Populating a metadata file (or a metadta buffer) and associated files to
a model file:
populator = MetadataPopulator.with_model_file(model_file)
# For metadata buffer (bytearray read from the metadata file), use:
# populator.load_metadata_buffer(metadata_buf)
populator.load_metadata_file(metadata_file)
populator.load_associated_files([label.txt])
populator.populate()
# Populating a metadata file (or a metadta buffer) and associated files to
a model buffer:
populator = MetadataPopulator.with_model_buffer(model_buf)
populator.load_metadata_file(metadata_file)
populator.load_associated_files([label.txt])
populator.populate()
# Writing the updated model buffer into a file.
updated_model_buf = populator.get_model_buffer()
with open("updated_model.tflite", "wb") as f:
f.write(updated_model_buf)
```
"""
# As Zip API is used to concatenate associated files after tflite model file,
# the populating operation is developed based on a model file. For in-memory
# model buffer, we create a tempfile to serve the populating operation.
# Creating the deleting such a tempfile is handled by the class,
# _MetadataPopulatorWithBuffer.
METADATA_FIELD_NAME = "TFLITE_METADATA"
TFLITE_FILE_IDENTIFIER = b"TFL3"
METADATA_FILE_IDENTIFIER = b"M001"
def __init__(self, model_file):
"""Constructor for MetadataPopulator.
Args:
model_file: valid path to a TensorFlow Lite model file.
Raises:
IOError: File not found.
"""
_assert_exist(model_file)
self._model_file = model_file
self._metadata_buf = None
self._associated_files = set()
@classmethod
def with_model_file(cls, model_file):
"""Creates a MetadataPopulator object that populates data to a model file.
Args:
model_file: valid path to a TensorFlow Lite model file.
Returns:
MetadataPopulator object.
Raises:
IOError: File not found.
"""
return cls(model_file)
# TODO(b/141468993): investigate if type check can be applied to model_buf for
# FB.
@classmethod
def with_model_buffer(cls, model_buf):
"""Creates a MetadataPopulator object that populates data to a model buffer.
Args:
model_buf: TensorFlow Lite model buffer in bytearray.
Returns:
A MetadataPopulator(_MetadataPopulatorWithBuffer) object.
"""
return _MetadataPopulatorWithBuffer(model_buf)
def get_model_buffer(self):
"""Gets the buffer of the model with packed metadata and associated files.
Returns:
Model buffer (in bytearray).
"""
with open(self._model_file, "rb") as f:
return f.read()
def get_packed_associated_file_list(self):
"""Gets a list of associated files packed to the model file.
Returns:
List of packed associated files.
"""
if not zipfile.is_zipfile(self._model_file):
return []
with zipfile.ZipFile(self._model_file, "r") as zf:
return zf.namelist()
def get_recorded_associated_file_list(self):
"""Gets a list of associated files recorded in metadata of the model file.
Associated files may be attached to a model, a subgraph, or an input/output
tensor.
Returns:
List of recorded associated files.
"""
recorded_files = []
if not self._metadata_buf:
return recorded_files
metadata = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
self._metadata_buf, 0)
# Add associated files attached to ModelMetadata
self._get_associated_files_from_metadata_struct(metadata, recorded_files)
# Add associated files attached to each SubgraphMetadata
for j in range(metadata.SubgraphMetadataLength()):
subgraph = metadata.SubgraphMetadata(j)
self._get_associated_files_from_metadata_struct(subgraph, recorded_files)
# Add associated files attached to each input tensor
for k in range(subgraph.InputTensorMetadataLength()):
tensor = subgraph.InputTensorMetadata(k)
self._get_associated_files_from_metadata_struct(tensor, recorded_files)
# Add associated files attached to each output tensor
for k in range(subgraph.OutputTensorMetadataLength()):
tensor = subgraph.OutputTensorMetadata(k)
self._get_associated_files_from_metadata_struct(tensor, recorded_files)
return recorded_files
def load_associated_files(self, associated_files):
"""Loads associated files that to be concatenated after the model file.
Args:
associated_files: list of file paths.
Raises:
IOError:
File not found.
"""
for af in associated_files:
_assert_exist(af)
self._associated_files.add(af)
def load_metadata_buffer(self, metadata_buf):
"""Loads the metadata buffer (in bytearray) to be populated.
Args:
metadata_buf: metadata buffer (in bytearray) to be populated.
Raises:
ValueError:
The metadata to be populated is empty.
"""
if not metadata_buf:
raise ValueError("The metadata to be populated is empty.")
self._metadata_buf = metadata_buf
def load_metadata_file(self, metadata_file):
"""Loads the metadata file to be populated.
Args:
metadata_file: path to the metadata file to be populated.
Raises:
IOError:
File not found.
"""
_assert_exist(metadata_file)
with open(metadata_file, "rb") as f:
metadata_buf = f.read()
self.load_metadata_buffer(bytearray(metadata_buf))
def populate(self):
"""Populates loaded metadata and associated files into the model file."""
self._assert_validate()
self._populate_metadata_buffer()
self._populate_associated_files()
def _assert_validate(self):
"""Validates the metadata and associated files to be populated.
Raises:
ValueError:
File is recorded in the metadata, but is not going to be populated.
File has already been packed.
"""
# Gets files that are recorded in metadata.
recorded_files = self.get_recorded_associated_file_list()
# Gets files that have been packed to self._model_file.
packed_files = self.get_packed_associated_file_list()
# Gets the file name of those associated files to be populated.
to_be_populated_files = []
for af in self._associated_files:
to_be_populated_files.append(os.path.basename(af))
# Checks all files recorded in the metadata will be populated.
for rf in recorded_files:
if rf not in to_be_populated_files and rf not in packed_files:
raise ValueError("File, '{0}', is recorded in the metadata, but has "
"not been loaded into the populator.".format(rf))
for f in to_be_populated_files:
if f in packed_files:
raise ValueError("File, '{0}', has already been packed.".format(f))
if f not in recorded_files:
warnings.warn(
"File, '{0}', does not exsit in the metadata. But packing it to "
"tflite model is still allowed.".format(f))
def _copy_archived_files(self, src_zip, dst_zip, file_list):
"""Copy archieved files in file_list from src_zip ro dst_zip."""
if not zipfile.is_zipfile(src_zip):
raise ValueError("File, '{0}', is not a zipfile.".format(src_zip))
with zipfile.ZipFile(src_zip,
"r") as src_zf, zipfile.ZipFile(dst_zip,
"a") as dst_zf:
src_list = src_zf.namelist()
for f in file_list:
if f not in src_list:
raise ValueError(
"File, '{0}', does not exist in the zipfile, {1}.".format(
f, src_zip))
file_buffer = src_zf.read(f)
dst_zf.writestr(f, file_buffer)
def _get_associated_files_from_metadata_struct(self, file_holder, file_list):
for j in range(file_holder.AssociatedFilesLength()):
file_list.append(file_holder.AssociatedFiles(j).Name().decode("utf-8"))
def _populate_associated_files(self):
"""Concatenates associated files after TensorFlow Lite model file.
If the MetadataPopulator object is created using the method,
with_model_file(model_file), the model file will be updated.
"""
# Opens up the model file in "appending" mode.
# If self._model_file already has pack files, zipfile will concatenate
# addition files after self._model_file. For example, suppose we have
# self._model_file = old_tflite_file | label1.txt | label2.txt
# Then after trigger populate() to add label3.txt, self._model_file becomes
# self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt
with zipfile.ZipFile(self._model_file, "a") as zf:
for af in self._associated_files:
filename = os.path.basename(af)
zf.write(af, filename)
def _populate_metadata_buffer(self):
"""Populates the metadata buffer (in bytearray) into the model file.
Inserts metadata_buf into the metadata field of schema.Model. If the
MetadataPopulator object is created using the method,
with_model_file(model_file), the model file will be updated.
"""
with open(self._model_file, "rb") as f:
model_buf = f.read()
model = _schema_fb.ModelT.InitFromObj(
_schema_fb.Model.GetRootAsModel(model_buf, 0))
buffer_field = _schema_fb.BufferT()
buffer_field.data = self._metadata_buf
is_populated = False
if not model.metadata:
model.metadata = []
else:
# Check if metadata has already been populated.
for meta in model.metadata:
if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME:
is_populated = True
model.buffers[meta.buffer] = buffer_field
if not is_populated:
if not model.buffers:
model.buffers = []
model.buffers.append(buffer_field)
# Creates a new metadata field.
metadata_field = _schema_fb.MetadataT()
metadata_field.name = self.METADATA_FIELD_NAME
metadata_field.buffer = len(model.buffers) - 1
model.metadata.append(metadata_field)
# Packs model back to a flatbuffer binaray file.
b = flatbuffers.Builder(0)
b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER)
model_buf = b.Output()
# Saves the updated model buffer to model file.
# Gets files that have been packed to self._model_file.
packed_files = self.get_packed_associated_file_list()
if packed_files:
# Writes the updated model buffer and associated files into a new model
# file. Then overwrites the original model file.
with tempfile.NamedTemporaryFile() as temp:
new_file = temp.name
with open(new_file, "wb") as f:
f.write(model_buf)
self._copy_archived_files(self._model_file, new_file, packed_files)
shutil.copy(new_file, self._model_file)
os.remove(new_file)
else:
with open(self._model_file, "wb") as f:
f.write(model_buf)
class _MetadataPopulatorWithBuffer(MetadataPopulator):
"""Subclass of MetadtaPopulator that populates metadata to a model buffer.
This class is used to populate metadata into a in-memory model buffer. As we
use Zip API to concatenate associated files after tflite model file, the
populating operation is developed based on a model file. For in-memory model
buffer, we create a tempfile to serve the populating operation. This class is
then used to generate this tempfile, and delete the file when the
MetadataPopulator object is deleted.
"""
def __init__(self, model_buf):
"""Constructor for _MetadataPopulatorWithBuffer.
Args:
model_buf: TensorFlow Lite model buffer in bytearray.
Raises:
ValueError: model_buf is empty.
"""
if not model_buf:
raise ValueError("model_buf cannot be empty.")
with tempfile.NamedTemporaryFile() as temp:
model_file = temp.name
with open(model_file, "wb") as f:
f.write(model_buf)
MetadataPopulator.__init__(self, model_file)
def __del__(self):
"""Destructor of _MetadataPopulatorWithBuffer.
Deletes the tempfile.
"""
if os.path.exists(self._model_file):
os.remove(self._model_file)
class MetadataDisplayer(object):
"""Displays metadata and associated file info in human-readable format."""
def __init__(self, model_file, metadata_file, associated_file_list):
"""Constructor for MetadataDisplayer.
Args:
model_file: valid path to the model file.
metadata_file: valid path to the metadata file.
associated_file_list: list of associate files in the model file.
"""
self._model_file = model_file
self._metadata_file = metadata_file
self._associated_file_list = associated_file_list
@classmethod
def with_model_file(cls, model_file):
"""Creates a MetadataDisplayer object for the model file.
Args:
model_file: valid path to a TensorFlow Lite model file.
Returns:
MetadataDisplayer object.
Raises:
IOError: File not found.
ValueError: The model does not have metadata.
"""
_assert_exist(model_file)
metadata_file = cls._save_temporary_metadata_file(model_file)
associated_file_list = cls._parse_packed_associted_file_list(model_file)
return cls(model_file, metadata_file, associated_file_list)
def export_metadata_json_file(self, export_dir):
"""Converts the metadata into a json file.
Args:
export_dir: the directory that the json file will be exported to. The json
file will be named after the model file, but with ".json" as extension.
"""
subprocess.check_call([
_FLATC_BINARY_PATH, "-o", export_dir, "--json",
_FLATC_TFLITE_METADATA_SCHEMA_FILE, "--", self._metadata_file,
"--strict-json"
])
temp_name = os.path.join(
export_dir,
os.path.splitext(os.path.basename(self._metadata_file))[0] + ".json")
expected_name = os.path.join(
export_dir,
os.path.splitext(os.path.basename(self._model_file))[0] + ".json")
os.rename(temp_name, expected_name)
def get_packed_associated_file_list(self):
"""Returns a list of associated files that are packed in the model.
Returns:
A name list of associated files.
"""
return copy.deepcopy(self._associated_file_list)
@staticmethod
def _save_temporary_metadata_file(model_file):
"""Saves the metadata in the model file to a temporary file.
Args:
model_file: valid path to the model file.
Returns:
Path to the metadata temporary file.
Raises:
ValueError: The model does not have metadata.
"""
with open(model_file, "rb") as f:
model_buf = f.read()
tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
# Gets metadata from the model file.
for i in range(tflite_model.MetadataLength()):
meta = tflite_model.Metadata(i)
if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME:
buffer_index = meta.Buffer()
metadata = tflite_model.Buffers(buffer_index)
metadata_buf = metadata.DataAsNumpy().tobytes()
# Creates a temporary file to store the metadata.
with tempfile.NamedTemporaryFile() as temp:
metadata_file = temp.name
# Saves the metadata into the temporary file.
with open(metadata_file, "wb") as f:
f.write(metadata_buf)
return metadata_file
raise ValueError("The model does not have metadata.")
@staticmethod
def _parse_packed_associted_file_list(model_file):
"""Gets a list of associated files packed to the model file.
Args:
model_file: valid path to the model file.
Returns:
List of packed associated files.
"""
if not zipfile.is_zipfile(model_file):
return []
with zipfile.ZipFile(model_file, "r") as zf:
return zf.namelist()
def __del__(self):
"""Destructor of MetadataDisplayer.
Deletes the tempfile.
"""
if os.path.exists(self._metadata_file):
os.remove(self._metadata_file)
def _assert_exist(filename):
"""Checks if a file exists."""
if not os.path.exists(filename):
raise IOError("File, '{0}', does not exist.".format(filename))

View File

@ -0,0 +1,499 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace tflite;
// TFLite metadata contains both human readable and machine readable information
// about what the model does and how to use the model. It can be used as a
// README file, which elaborates the details of the model, each input/ouput
// tensor, and each associated file.
//
// An important use case of TFLite metadata is the TFLite codegen tool, which
// automatically generates the model interface based on the properties of the
// model and the tensors. The model interface provides high-level APIs to
// interact with the model, such as preprocessing the input data and running
// inferences.
//
// Entries marked with "<Codegen usage>" are used in TFLite codegen tool to
// generate the model interface. It is recommended to fill in at least those
// enties to boost the codegen performance.
// This corresponds to the schema version.
file_identifier "M001";
// File extension of any written files.
file_extension "tflitemeta";
enum AssociatedFileType : byte {
UNKNOWN = 0,
// Files such as readme.txt
DESCRIPTIONS = 1,
// Contains labels that annotate certain axis of the tensor. For example,
// the label file in image classification. Those labels annotate the
// the output tensor, such that each value in the output tensor is the
// probability of that corresponding category specified by the label.
//
// <Codegen usage>:
// If an output tensor has an associated file as TENSOR_AXIS_LABELS, return
// the output as a mapping between the labels and probability in the model
// interface.
// If multiple files of the same type are present, the first one is used by
// default; additional ones are to be distinguished from one another by their
// specified locale.
TENSOR_AXIS_LABELS = 2,
// Contains labels that tensor values correspond to. For example, in
// the object detection model, one of the output tensors is the detected
// classes. And each value in the tensor refers to the index of label in the
// category label file.
//
// <Codegen usage>:
// If an output tensor has an associated file as TENSOR_VALUE_LABELS, convert
// the tensor values into labels, and return a list of string as the output.
// If multiple files of the same type are present, the first one is used by
// default; additional ones are to be distinguished from one another by their
// specified locale.
TENSOR_VALUE_LABELS = 3,
// Contains sigmoid-based score calibration parameters, formatted as CSV.
// Lines contain for each index of an output tensor the scale, slope, offset
// and min_score parameters to be used for sigmoid fitting (in this order and
// in `strtof`-compatible [1] format).
// A line may be left empty to default calibrated scores for this index to
// default_score. See documentation for ScoreCalibrationOptions for details.
//
// [1]: https://en.cppreference.com/w/c/string/byte/strtof
TENSOR_AXIS_SCORE_CALIBRATION = 4,
}
table AssociatedFile {
// Name of this file. Need to be exact the same as the name of the actual file
// packed into the TFLite model as a zip file.
//
// <Codegen usage>:
// Locates to the actual file in the TFLite model.
name:string;
// A description of what the file is.
description:string;
// Type of the associated file. There may be special pre/post processing for
// some types. For example in image classification, a label file of the output
// will be used to convert object index into string.
//
// <Codegen usage>:
// Determines how to process the corresponding tensor.
type:AssociatedFileType;
// An optional locale for this associated file (if applicable). It is
// recommended to use an ISO 639-1 letter code (e.g. "en" for English),
// optionally completed by a two letter region code (e.g. "en-US" for US
// English and "en-CA" for Canadian English).
// Leverage this in order to specify e.g multiple label files translated in
// different languages.
locale:string;
}
// The basic content type for all tensors.
//
// <Codegen usage>:
// Input feature tensors:
// 1. Generates the method to load data from a TensorBuffer.
// 2. Creates the preprocessing logic. The default processing pipeline is:
// [NormalizeOp, QuantizeOp].
// Output feature tensors:
// 1. Generates the method to return the output data to a TensorBuffer.
// 2. Creates the post-processing logic. The default processing pipeline is:
// [DeQuantizeOp].
table FeatureProperties {
}
// The type of color space of an image.
enum ColorSpaceType : byte {
UNKNOWN = 0,
RGB = 1,
GRAYSCALE = 2,
}
table ImageSize {
width:uint;
height:uint;
}
// The properties for image tensors.
//
// <Codegen usage>:
// Input image tensors:
// 1. Generates the method to load an image from a TensorImage.
// 2. Creates the preprocessing logic. The default processing pipeline is:
// [ResizeOp, NormalizeOp, QuantizeOp].
// Output image tensors:
// 1. Generates the method to return the output data to a TensorImage.
// 2. Creates the post-processing logic. The default processing pipeline is:
// [DeQuantizeOp].
table ImageProperties {
// The color space of the image.
//
// <Codegen usage>:
// Determines how to convert the color space of a given image from users.
color_space:ColorSpaceType;
// Indicates the default value of image width and height if the tensor shape
// is dynamic. For fixed-size tensor, this size will be consistent with the
// expected size.
default_size:ImageSize;
}
// The properties for tensors representing bounding boxes.
//
// <Codegen usage>:
// Input image tensors: NA.
// Output image tensors: parses the values into a data stucture that represents
// bounding boxes. For example, in the generated wrapper for Android, it returns
// the output as android.graphics.Rect objects.
enum BoundingBoxType : byte {
UNKNOWN = 0,
// Represents the bounding box by using the combination of boundaries,
// {left, top, right, bottom}.
// The default order is {left, top, right, bottom}. Other orders can be
// indicated by BoundingBoxProperties.index.
BOUNDARIES = 1,
// Represents the bounding box by using the upper_left corner, width and
// height.
// The default order is {upper_left_x, upper_left_y, width, height}. Other
// orders can be indicated by BoundingBoxProperties.index.
UPPER_LEFT = 2,
// Represents the bounding box by using the center of the box, width and
// height. The default order is {center_x, center_y, width, height}. Other
// orders can be indicated by BoundingBoxProperties.index.
CENTER = 3,
}
enum CoordinateType : byte {
// The coordinates are float values from 0 to 1.
RATIO = 0,
// The coordinates are integers.
PIXEL = 1,
}
table BoundingBoxProperties {
// Denotes the order of the elements defined in each bounding box type. An
// empty index array represent the defualt order of each bounding box type.
// For example, to denote the default order of BOUNDARIES, {left, top, right,
// bottom}, the index should be {0, 1, 2, 3}. To denote the order {left,
// right, top, bottom}, the order should be {0, 2, 1, 3}.
//
// The index array can be applied to all bounding box types to adjust the
// order of their corresponding underlying elements.
//
// <Codegen usage>:
// Indicates how to parse the bounding box values.
index:[uint];
// <Codegen usage>:
// Indicates how to parse the bounding box values.
type:BoundingBoxType;
// <Codegen usage>:
// Indicates how to convert the bounding box back to the original image in
// pixels.
coordinate_type:CoordinateType;
}
union ContentProperties {
FeatureProperties,
ImageProperties,
BoundingBoxProperties,
}
table ValueRange {
min:int;
max:int;
}
table Content {
// The properties that the content may have, indicating the type of the
// Content.
//
// <Codegen usage>:
// Indicates how to process the tensor.
content_properties:ContentProperties;
// The range of dimensions that the content corresponds to. A NULL
// "range" indicates that the content uses up all dimensions,
// except the batch axis if applied.
//
// Here are all the possible situations of how a tensor is composed.
// Case 1: The tensor is a single object, such as an image.
// For example, the input of an image classifier
// (https://www.tensorflow.org/lite/models/image_classification/overview),
// a tensor of shape [1, 224, 224, 3]. Dimensions 1 to 3 correspond to the
// image. Since dimension 0 is a batch axis, which can be ignored,
// "range" can be left as NULL.
//
// Case 2: The tensor contains multiple instances of the same object.
// For example, the output tensor of detected bounding boxes of an object
// detection model
// (https://www.tensorflow.org/lite/models/object_detection/overview).
// The tensor shape is [1, 10, 4]. Here is the what the three dimensions
// represent for:
// dimension 0: the batch axis.
// dimension 1: the 10 objects detected with the highest confidence.
// dimension 2: the bounding boxes of the 10 detected objects.
// The tensor is essentially 10 bounding boxes. In this case,
// "range" should be {min=2; max=2;}.
// Another example is the pose estimation model
// (https://www.tensorflow.org/lite/models/pose_estimation/overview).
// The output tensor of heatmaps is in the shape of [1, 9, 9, 17].
// Here is the what the four dimensions represent for:
// dimension 0: the batch axis.
// dimension 1/2: the heatmap image.
// dimension 3: 17 body parts of a person.
// Even though the last axis is body part, the real content of this tensor is
// the heatmap. "range" should be [min=2; max=3].
//
// Case 3: The tensor contains multiple different objects. (Not supported by
// Content at this point).
// Sometimes a tensor may contain multiple different objects, thus different
// contents. It is very common for regression models. For example, a model
// to predict the fuel efficiency
// (https://www.tensorflow.org/tutorials/keras/regression).
// The input tensor has shape [1, 9], consisting of 9 features, such as
// "Cylinders", "Displacement", "Weight", etc. In this case, dimension 1
// contains 9 different contents. However, since these sub-dimension objects
// barely need to be specifically processed, their contents are not recorded
// in the metadata. Through, the name of each dimension can be set through
// TensorMetadata.dimension_names.
//
// Note that if it is not case 3, a tensor can only have one content type.
//
// <Codegen usage>:
// Case 1: return a processed single object of certain content type.
// Case 2: return a list of processed objects of certain content type. The
// generated model interface have API to random access those objects from
// the output.
range:ValueRange;
}
// Parameters that are used when normalizing the tensor.
table NormalizationOptions{
// mean and std are normalization parameters. Tensor values are normailzed
// per-channelly by,
// (x - mean) / std.
// For example, a float MobileNet model will have
// mean = 127.5f and std = 127.5f.
// A quantized MobileNet model will have
// mean = 0.0f and std = 1.0f.
// If there is only one value in mean or std, we'll propogate the value to
// all channels.
// Per-channel mean of the possible values used in normalization.
//
// <Codegen usage>:
// Apply normalization to input tensors accordingly.
mean:[float];
// Per-channel standard dev. of the possible values used in normalization.
//
// <Codegen usage>:
// Apply normalization to input tensors accordingly.
std:[float];
}
// The different possible score transforms to apply to uncalibrated scores
// before applying score calibration.
enum ScoreTransformationType : byte {
// Identity function: g(x) = x.
IDENTITY = 0,
// Log function: g(x) = log(x).
LOG = 1,
// Inverse logistic function: g(x) = log(x) - log(1-x).
INVERSE_LOGISTIC = 2,
}
// Options to perform score calibration on an output tensor through sigmoid
// functions. One of the main purposes of score calibration is to make scores
// across classes comparable, so that a common threshold can be used for all
// output classes. This is meant for models producing class predictions as
// output, e.g. image classification or detection models.
//
// For each index in the output tensor, this applies:
// * `f(x) = scale / (1 + e^-(slope*g(x)+offset))` if `x > min_score`,
// * `f(x) = default_score` otherwise or if no scale, slope, offset and
// min_score have been specified.
// Where:
// * scale, slope, offset and min_score are index-specific parameters
// * g(x) is an index-independent transform among those defined in
// ScoreTransformationType
// * default_score is an index-independent parameter.
// An AssociatedFile with type TANSOR_AXIS_SCORE_CALIBRATION specifying the
// index-specific parameters must be associated with the corresponding
// TensorMetadata for score calibration be applied.
table ScoreCalibrationOptions {
// The function to use for transforming the uncalibrated score before
// applying score calibration.
score_transformation:ScoreTransformationType;
// The default calibrated score to apply if the uncalibrated score is
// below min_score or if no parameters were specified for a given index.
default_score:float;
}
// Performs thresholding on output tensor values, in order to filter out
// low-confidence results.
table ScoreThresholdingOptions {
// The recommended global threshold below which results are considered
// low-confidence and should be filtered out.
global_score_threshold:float;
}
// Options that are used when processing the tensor.
union ProcessUnitOptions {
NormalizationOptions,
ScoreCalibrationOptions,
ScoreThresholdingOptions,
}
// A process unit that is used to process the tensor out-of-graph.
table ProcessUnit {
options:ProcessUnitOptions;
}
// Statistics to describe a tensor.
table Stats {
// Max and min are not currently used in tflite.support codegen. They mainly
// serve as references for users to better understand the model. They can also
// be used to validate model pre/post processing results.
// If there is only one value in max or min, we'll propogate the value to
// all channels.
// Per-channel maximum value of the tensor.
max:[float];
// Per-channel minimum value of the tensor.
min:[float];
}
// Detailed information of an input or output tensor.
table TensorMetadata {
// Name of the tensor.
//
// <Codegen usage>:
// The name of this tensor in the generated model interface.
name:string;
// A description of the tensor.
description:string;
// A list of names of the dimensions in this tentor. The length of
// dimension_names need to match the number of dimensions in this tensor.
//
// <Codegen usage>:
// The name of each dimension in the generated model interface. See "Case 2"
// in the comments of Content.range.
dimension_names:[string];
// The content that represents this tensor.
//
// <Codegen usage>:
// Determines how to process this tensor. See each item in ContentProperties
// for the default process units that will be applied to the tensor.
content:Content;
// The process units that are used to process the tensor out-of-graph.
//
// <Codegen usage>:
// Contains the parameters of the default processing pipeline for each content
// type, such as the normalization parameters in all content types. See the
// items under ContentProperties for the details of the default processing
// pipeline.
process_units:[ProcessUnit];
// The statistics of the tensor values.
stats:Stats;
// A list of associated files of this tensor.
//
// <Codegen usage>:
// Contains processing parameters of this tensor, such as normalization.
associated_files:[AssociatedFile];
}
table SubGraphMetadata {
// Name of the subgraph.
//
// Note that, since TFLite only support one subgraph at this moment, the
// Codegen tool will use the name in ModelMetadata in the generated model
// interface.
name:string;
// A description explains details about what the subgraph does.
description:string;
// Metadata of all input tensors used in this subgraph.
//
// <Codegen usage>:
// Determines how to process the inputs.
input_tensor_metadata:[TensorMetadata];
// Metadata of all output tensors used in this subgraph.
//
// <Codegen usage>:
// Determines how to process the outputs.
output_tensor_metadata:[TensorMetadata];
// A list of associated files of this subgraph.
associated_files:[AssociatedFile];
}
table ModelMetadata {
// Name of the model.
//
// <Codegen usage>:
// The name of the model in the generated model interface.
name:string;
// Model description in schema.
description:string;
// Version of the model that specified by model creators.
version:string;
// Noted that, the minimum required TFLite runtime version that the model is
// compatible with, has already been added as a metadata entry in tflite
// schema. We'll decide later if we want to move it here, and keep it with
// other metadata entries.
// Metadata of all the subgraphs of the model. The 0th is assumed to be the
// main subgraph.
//
// <Codegen usage>:
// Determines how to process the inputs and outputs.
subgraph_metadata:[SubGraphMetadata];
// The person who creates this model.
author:string;
// Licenses that may apply to this model.
license:string;
// A list of associated files of this model.
associated_files:[AssociatedFile];
}
root_type ModelMetadata;

View File

@ -0,0 +1,381 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.lite.experimental.support.metadata.metadata."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import six
from flatbuffers.python import flatbuffers
from tensorflow.lite.experimental.support.metadata import metadata as _metadata
from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb
from tensorflow.python.framework import test_util
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
class MetadataTest(test_util.TensorFlowTestCase):
def setUp(self):
super(MetadataTest, self).setUp()
self._invalid_model_buf = None
self._invalid_file = "not_existed_file"
self._empty_model_buf = self._create_empty_model_buf()
self._empty_model_file = self.create_tempfile().full_path
with open(self._empty_model_file, "wb") as f:
f.write(self._empty_model_buf)
self._model_file = self._create_model_file_with_metadata_and_buf_fields()
self._metadata_file = self._create_metadata_file()
self._file1 = self.create_tempfile("file1").full_path
self._file2 = self.create_tempfile("file2").full_path
self._file3 = self.create_tempfile("file3").full_path
def _create_empty_model_buf(self):
model = _schema_fb.ModelT()
model_builder = flatbuffers.Builder(0)
model_builder.Finish(
model.Pack(model_builder),
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
return model_builder.Output()
def _create_model_file_with_metadata_and_buf_fields(self):
metadata_field = _schema_fb.MetadataT()
metadata_field.name = "meta"
buffer_field = _schema_fb.BufferT()
model = _schema_fb.ModelT()
model.metadata = [metadata_field, metadata_field]
model.buffers = [buffer_field, buffer_field, buffer_field]
model_builder = flatbuffers.Builder(0)
model_builder.Finish(
model.Pack(model_builder),
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
mnodel_file = self.create_tempfile().full_path
with open(mnodel_file, "wb") as f:
f.write(model_builder.Output())
return mnodel_file
def _create_metadata_file(self):
associated_file1 = _metadata_fb.AssociatedFileT()
associated_file1.name = b"file1"
associated_file2 = _metadata_fb.AssociatedFileT()
associated_file2.name = b"file2"
self.expected_recorded_files = [
six.ensure_str(associated_file1.name),
six.ensure_str(associated_file2.name)
]
output_meta = _metadata_fb.TensorMetadataT()
output_meta.associatedFiles = [associated_file2]
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.outputTensorMetadata = [output_meta]
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "Mobilenet_quantized"
model_meta.associatedFiles = [associated_file1]
model_meta.subgraphMetadata = [subgraph]
b = flatbuffers.Builder(0)
b.Finish(
model_meta.Pack(b),
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_file = self.create_tempfile().full_path
with open(metadata_file, "wb") as f:
f.write(b.Output())
return metadata_file
class MetadataPopulatorTest(MetadataTest):
def testToValidModelFile(self):
populator = _metadata.MetadataPopulator.with_model_file(
self._empty_model_file)
self.assertIsInstance(populator, _metadata.MetadataPopulator)
def testToInvalidModelFile(self):
with self.assertRaises(IOError) as error:
_metadata.MetadataPopulator.with_model_file(self._invalid_file)
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
str(error.exception))
def testToValidModelBuffer(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
self.assertIsInstance(populator, _metadata.MetadataPopulator)
def testToInvalidModelBuffer(self):
with self.assertRaises(ValueError) as error:
_metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf)
self.assertEqual("model_buf cannot be empty.", str(error.exception))
def testSinglePopulateAssociatedFile(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
populator.load_associated_files([self._file1])
populator.populate()
packed_files = populator.get_packed_associated_file_list()
expected_packed_files = [os.path.basename(self._file1)]
self.assertEqual(set(packed_files), set(expected_packed_files))
def testRepeatedPopulateAssociatedFile(self):
populator = _metadata.MetadataPopulator.with_model_file(
self._empty_model_file)
populator.load_associated_files([self._file1, self._file2])
# Loads file2 multiple times.
populator.load_associated_files([self._file2])
populator.populate()
packed_files = populator.get_packed_associated_file_list()
expected_packed_files = [
os.path.basename(self._file1),
os.path.basename(self._file2)
]
self.assertEqual(len(packed_files), 2)
self.assertEqual(set(packed_files), set(expected_packed_files))
# Check if the model buffer read from file is the same as that read from
# get_model_buffer().
with open(self._empty_model_file, "rb") as f:
model_buf_from_file = f.read()
model_buf_from_getter = populator.get_model_buffer()
self.assertEqual(model_buf_from_file, model_buf_from_getter)
def testPopulateInvalidAssociatedFile(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
with self.assertRaises(IOError) as error:
populator.load_associated_files([self._invalid_file])
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
str(error.exception))
def testPopulatePackedAssociatedFile(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
populator.load_associated_files([self._file1])
populator.populate()
with self.assertRaises(ValueError) as error:
populator.load_associated_files([self._file1])
populator.populate()
self.assertEqual(
"File, '{0}', has already been packed.".format(
os.path.basename(self._file1)), str(error.exception))
def testGetPackedAssociatedFileList(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
packed_files = populator.get_packed_associated_file_list()
self.assertEqual(packed_files, [])
def testPopulateMetadataFileToEmptyModelFile(self):
populator = _metadata.MetadataPopulator.with_model_file(
self._empty_model_file)
populator.load_metadata_file(self._metadata_file)
populator.load_associated_files([self._file1, self._file2])
populator.populate()
with open(self._empty_model_file, "rb") as f:
model_buf_from_file = f.read()
model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
metadata_field = model.Metadata(0)
self.assertEqual(
six.ensure_str(metadata_field.Name()),
six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
buffer_index = metadata_field.Buffer()
buffer_data = model.Buffers(buffer_index)
metadata_buf_np = buffer_data.DataAsNumpy()
metadata_buf = metadata_buf_np.tobytes()
with open(self._metadata_file, "rb") as f:
expected_metadata_buf = bytearray(f.read())
self.assertEqual(metadata_buf, expected_metadata_buf)
recorded_files = populator.get_recorded_associated_file_list()
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
# Up to now, we've proved the correctness of the model buffer that read from
# file. Then we'll test if get_model_buffer() gives the same model buffer.
model_buf_from_getter = populator.get_model_buffer()
self.assertEqual(model_buf_from_file, model_buf_from_getter)
def testPopulateMetadataFileWithoutAssociatedFiles(self):
populator = _metadata.MetadataPopulator.with_model_file(
self._empty_model_file)
populator.load_metadata_file(self._metadata_file)
populator.load_associated_files([self._file1])
# Suppose to populate self._file2, because it is recorded in the metadta.
with self.assertRaises(ValueError) as error:
populator.populate()
self.assertEqual(("File, '{0}', is recorded in the metadata, but has "
"not been loaded into the populator.").format(
os.path.basename(self._file2)), str(error.exception))
def _assert_golden_metadata(self, model_file):
with open(model_file, "rb") as f:
model_buf_from_file = f.read()
model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
# There are two elements in model.Metadata array before the population.
# Metadata should be packed to the third element in the array.
metadata_field = model.Metadata(2)
self.assertEqual(
six.ensure_str(metadata_field.Name()),
six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
buffer_index = metadata_field.Buffer()
buffer_data = model.Buffers(buffer_index)
metadata_buf_np = buffer_data.DataAsNumpy()
metadata_buf = metadata_buf_np.tobytes()
with open(self._metadata_file, "rb") as f:
expected_metadata_buf = bytearray(f.read())
self.assertEqual(metadata_buf, expected_metadata_buf)
def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self):
# First, creates a dummy metadata. Populates it and the associated files
# into the model.
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "Mobilenet_quantized"
b = flatbuffers.Builder(0)
b.Finish(
model_meta.Pack(b),
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()
populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file)
populator1.load_metadata_buffer(metadata_buf)
populator1.load_associated_files([self._file1, self._file2])
populator1.populate()
# Then, populates the metadata again.
populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file)
populator2.load_metadata_file(self._metadata_file)
populator2.populate()
# Tests if the metadata is populated correctly.
self._assert_golden_metadata(self._model_file)
def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self):
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
populator.load_metadata_file(self._metadata_file)
populator.load_associated_files([self._file1, self._file2])
populator.populate()
# Tests if the metadata is populated correctly.
self._assert_golden_metadata(self._model_file)
recorded_files = populator.get_recorded_associated_file_list()
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
# Up to now, we've proved the correctness of the model buffer that read from
# file. Then we'll test if get_model_buffer() gives the same model buffer.
with open(self._model_file, "rb") as f:
model_buf_from_file = f.read()
model_buf_from_getter = populator.get_model_buffer()
self.assertEqual(model_buf_from_file, model_buf_from_getter)
def testPopulateInvalidMetadataFile(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
with self.assertRaises(IOError) as error:
populator.load_metadata_file(self._invalid_file)
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
str(error.exception))
def testPopulateInvalidMetadataBuffer(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
with self.assertRaises(ValueError) as error:
populator.load_metadata_buffer([])
self.assertEqual("The metadata to be populated is empty.",
str(error.exception))
def testGetModelBufferBeforePopulatingData(self):
populator = _metadata.MetadataPopulator.with_model_buffer(
self._empty_model_buf)
model_buf = populator.get_model_buffer()
expected_model_buf = self._empty_model_buf
self.assertEqual(model_buf, expected_model_buf)
class MetadataDisplayerTest(MetadataTest):
def setUp(self):
super(MetadataDisplayerTest, self).setUp()
self._model_file = self._create_model_with_metadata_and_associated_files()
def _create_model_with_metadata_and_associated_files(self):
model_buf = self._create_empty_model_buf()
model_file = self.create_tempfile().full_path
with open(model_file, "wb") as f:
f.write(model_buf)
populator = _metadata.MetadataPopulator.with_model_file(model_file)
populator.load_metadata_file(self._metadata_file)
populator.load_associated_files([self._file1, self._file2])
populator.populate()
return model_file
def test_load_model_file_invalidModelFile_throwsException(self):
with self.assertRaises(IOError) as error:
_metadata.MetadataDisplayer.with_model_file(self._invalid_file)
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
str(error.exception))
def test_load_model_file_modelWithoutMetadata_throwsException(self):
with self.assertRaises(ValueError) as error:
_metadata.MetadataDisplayer.with_model_file(self._empty_model_file)
self.assertEqual("The model does not have metadata.", str(error.exception))
def test_load_model_file_modelWithMetadata(self):
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
self.assertIsInstance(displayer, _metadata.MetadataDisplayer)
def test_export_metadata_json_file_modelWithMetadata(self):
export_dir = self.create_tempdir().full_path
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
displayer.export_metadata_json_file(export_dir)
# Verifies the generated json file.
golden_json_file_path = resource_loader.get_path_to_datafile(
"testdata/golden_json.json")
json_file_path = os.path.join(
export_dir,
os.path.splitext(os.path.basename(self._model_file))[0] + ".json")
with open(json_file_path, "r") as json_file, open(golden_json_file_path,
"r") as golden_json_file:
json_contents = json_file.read()
golden_json_contents = golden_json_file.read()
self.assertEqual(json_contents, golden_json_contents)
def test_get_packed_associated_file_list_modelWithMetadata(self):
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
packed_files = displayer.get_packed_associated_file_list()
expected_packed_files = [
os.path.basename(self._file1),
os.path.basename(self._file2)
]
self.assertEqual(len(packed_files), 2)
self.assertEqual(set(packed_files), set(expected_packed_files))
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,21 @@
{
"name": "Mobilenet_quantized",
"subgraph_metadata": [
{
"output_tensor_metadata": [
{
"associated_files": [
{
"name": "file2"
}
]
}
]
}
],
"associated_files": [
{
"name": "file1"
}
]
}

View File

@ -206,4 +206,12 @@ cc_test(
],
)
py_binary(
name = "zip_files",
srcs = ["zip_files.py"],
python_version = "PY3",
visibility = ["//visibility:public"],
deps = ["@absl_py//absl:app"],
)
tflite_portable_test_suite()

View File

@ -0,0 +1,41 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Creates a zip package of the files passed in."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import zipfile
from absl import app
from absl import flags
FLAGS = flags.FLAGS
flags.DEFINE_string("export_zip_path", None, "Path to zip file.")
flags.DEFINE_string("file_directory", None, "Path to the files to be zipped.")
def main(_):
with zipfile.ZipFile(FLAGS.export_zip_path, mode="w") as zf:
for root, _, files in os.walk(FLAGS.file_directory):
for f in files:
if f.endswith(".java"):
zf.write(os.path.join(root, f))
if __name__ == "__main__":
app.run(main)

View File

@ -1,3 +1,7 @@
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE.txt"])
@ -116,3 +120,20 @@ py_library(
srcs = [":runtime_py_srcs"],
visibility = ["//visibility:public"],
)
filegroup(
name = "runtime_java_srcs",
srcs = glob(["java/com/google/flatbuffers/**/*.java"]),
)
java_library(
name = "runtime_java",
srcs = [":runtime_java_srcs"],
visibility = ["//visibility:public"],
)
android_library(
name = "runtime_android",
srcs = [":runtime_java_srcs"],
visibility = ["//visibility:public"],
)

View File

@ -1,6 +1,15 @@
"""BUILD rules for generating flatbuffer files."""
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
flatc_path = "@flatbuffers//:flatc"
zip_files = "//tensorflow/lite/tools:zip_files"
DEFAULT_INCLUDE_PATHS = [
"./",
"$(GENDIR)",
"$(BINDIR)",
]
DEFAULT_FLATC_ARGS = [
"--no-union-value-namespacing",
@ -422,3 +431,181 @@ def flatbuffer_py_library(
"@flatbuffers//:runtime_py",
],
)
def flatbuffer_java_library(
name,
srcs,
custom_package = "",
package_prefix = "",
include_paths = DEFAULT_INCLUDE_PATHS,
flatc_args = DEFAULT_FLATC_ARGS,
visibility = None):
"""A java library with the generated reader/writers for the given flatbuffer definitions.
Args:
name: Rule name. (required)
srcs: List of source .fbs files including all includes. (required)
custom_package: Package name of generated Java files. If not specified
namespace in the schema files will be used. (optional)
package_prefix: like custom_package, but prefixes to the existing
namespace. (optional)
include_paths: List of paths that includes files can be found in. (optional)
flatc_args: List of additional arguments to pass to flatc. (optional)
visibility: Visibility setting for the java_library rule. (optional)
"""
out_srcjar = "java_%s_all.srcjar" % name
flatbuffer_java_srcjar(
name = "%s_srcjar" % name,
srcs = srcs,
out = out_srcjar,
custom_package = custom_package,
flatc_args = flatc_args,
include_paths = include_paths,
package_prefix = package_prefix,
)
native.filegroup(
name = "%s.srcjar" % name,
srcs = [out_srcjar],
)
native.java_library(
name = name,
srcs = [out_srcjar],
deps = [
"@flatbuffers//:runtime_java",
],
visibility = visibility,
)
def flatbuffer_java_srcjar(
name,
srcs,
out,
custom_package = "",
package_prefix = "",
include_paths = DEFAULT_INCLUDE_PATHS,
flatc_args = DEFAULT_FLATC_ARGS):
"""Generate flatbuffer Java source files.
Args:
name: Rule name. (required)
srcs: List of source .fbs files including all includes. (required)
out: Output file name. (required)
custom_package: Package name of generated Java files. If not specified
namespace in the schema files will be used. (optional)
package_prefix: like custom_package, but prefixes to the existing
namespace. (optional)
include_paths: List of paths that includes files can be found in. (optional)
flatc_args: List of additional arguments to pass to flatc. (optional)
"""
command_fmt = """set -e
tmpdir=$(@D)
schemas=$$tmpdir/schemas
java_root=$$tmpdir/java
rm -rf $$schemas
rm -rf $$java_root
mkdir -p $$schemas
mkdir -p $$java_root
for src in $(SRCS); do
dest=$$schemas/$$src
rm -rf $$(dirname $$dest)
mkdir -p $$(dirname $$dest)
if [ -z "{custom_package}" ] && [ -z "{package_prefix}" ]; then
cp -f $$src $$dest
else
if [ -z "{package_prefix}" ]; then
sed -e "s/namespace\\s.*/namespace {custom_package};/" $$src > $$dest
else
sed -e "s/namespace \\([^;]\\+\\);/namespace {package_prefix}.\\1;/" $$src > $$dest
fi
fi
done
flatc_arg_I="-I $$tmpdir/schemas"
for include_path in {include_paths}; do
flatc_arg_I="$$flatc_arg_I -I $$schemas/$$include_path"
done
flatc_additional_args=
for arg in {flatc_args}; do
flatc_additional_args="$$flatc_additional_args $$arg"
done
for src in $(SRCS); do
$(location {flatc_path}) $$flatc_arg_I --java $$flatc_additional_args -o $$java_root $$schemas/$$src
done
$(location {zip_files}) -export_zip_path=$@ -file_directory=$$java_root
"""
genrule_cmd = command_fmt.format(
package_name = native.package_name(),
custom_package = custom_package,
package_prefix = package_prefix,
flatc_path = flatc_path,
zip_files = zip_files,
include_paths = " ".join(include_paths),
flatc_args = " ".join(flatc_args),
)
native.genrule(
name = name,
srcs = srcs,
outs = [out],
tools = [flatc_path, zip_files],
cmd = genrule_cmd,
)
def flatbuffer_android_library(
name,
srcs,
custom_package = "",
package_prefix = "",
javacopts = None,
include_paths = DEFAULT_INCLUDE_PATHS,
flatc_args = DEFAULT_FLATC_ARGS,
visibility = None):
"""An android_library with the generated reader/writers for the given flatbuffer definitions.
Args:
name: Rule name. (required)
srcs: List of source .fbs files including all includes. (required)
custom_package: Package name of generated Java files. If not specified
namespace in the schema files will be used. (optional)
package_prefix: like custom_package, but prefixes to the existing
namespace. (optional)
javacopts: List of options to pass to javac.
include_paths: List of paths that includes files can be found in. (optional)
flatc_args: List of additional arguments to pass to flatc. (optional)
visibility: Visibility setting for the android_library rule. (optional)
"""
out_srcjar = "android_%s_all.srcjar" % name
flatbuffer_java_srcjar(
name = "%s_srcjar" % name,
srcs = srcs,
out = out_srcjar,
custom_package = custom_package,
flatc_args = flatc_args,
include_paths = include_paths,
package_prefix = package_prefix,
)
native.filegroup(
name = "%s.srcjar" % name,
srcs = [out_srcjar],
)
# To support org.checkerframework.dataflow.qual.Pure.
checkerframework_annotations = [
"@org_checkerframework_qual",
] if "--java-checkerframework" in flatc_args else []
android_library(
name = name,
srcs = [out_srcjar],
visibility = visibility,
deps = [
"@flatbuffers//:runtime_android",
] + checkerframework_annotations,
)