OSS TFLite Metadata library
PiperOrigin-RevId: 299966503 Change-Id: I08e6e7d424822c5f80f48b0f0c670f3c9c75f403
This commit is contained in:
parent
c645992e15
commit
9ebf4a223b
87
tensorflow/lite/experimental/support/metadata/BUILD
Normal file
87
tensorflow/lite/experimental/support/metadata/BUILD
Normal file
@ -0,0 +1,87 @@
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
load("@flatbuffers//:build_defs.bzl", "flatbuffer_android_library", "flatbuffer_cc_library", "flatbuffer_java_library", "flatbuffer_py_library")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["metadata_schema.fbs"])
|
||||
|
||||
flatbuffer_py_library(
|
||||
name = "schema_py",
|
||||
srcs = ["//tensorflow/lite/schema:schema.fbs"],
|
||||
)
|
||||
|
||||
# Generic schema for inference on device.
|
||||
flatbuffer_android_library(
|
||||
name = "schema_fbs_android",
|
||||
srcs = ["//tensorflow/lite/schema:schema.fbs"],
|
||||
custom_package = "org.tensorflow.lite.schema",
|
||||
)
|
||||
|
||||
flatbuffer_java_library(
|
||||
name = "schema_fbs_java",
|
||||
srcs = ["//tensorflow/lite/schema:schema.fbs"],
|
||||
custom_package = "org.tensorflow.lite.schema",
|
||||
)
|
||||
|
||||
# Generic schema for model metadata.
|
||||
flatbuffer_cc_library(
|
||||
name = "metadata_schema_cc",
|
||||
srcs = ["metadata_schema.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_py_library(
|
||||
name = "metadata_schema_py",
|
||||
srcs = ["metadata_schema.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_java_library(
|
||||
name = "metadata_schema_java",
|
||||
srcs = ["metadata_schema.fbs"],
|
||||
custom_package = "org.tensorflow.lite.support.metadata.schema",
|
||||
)
|
||||
|
||||
flatbuffer_android_library(
|
||||
name = "metadata_schema_fbs_android",
|
||||
srcs = ["metadata_schema.fbs"],
|
||||
custom_package = "org.tensorflow.lite.support.metadata.schema",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "metadata",
|
||||
srcs = ["metadata.py"],
|
||||
data = [
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs",
|
||||
"@flatbuffers//:flatc",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":metadata_schema_py",
|
||||
":schema_py",
|
||||
"//tensorflow/python:platform",
|
||||
"@flatbuffers//:runtime_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "metadata_test",
|
||||
srcs = ["metadata_test.py"],
|
||||
data = ["testdata/golden_json.json"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":metadata",
|
||||
":metadata_schema_py",
|
||||
":schema_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:platform_test",
|
||||
"@flatbuffers//:runtime_py",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="org.tensorflow.lite.support">
|
||||
<uses-sdk android:minSdkVersion="19" />
|
||||
</manifest>
|
||||
|
36
tensorflow/lite/experimental/support/metadata/java/BUILD
Normal file
36
tensorflow/lite/experimental/support/metadata/java/BUILD
Normal file
@ -0,0 +1,36 @@
|
||||
# Description:
|
||||
# TensorFlow Lite Support API in Java for metadata.
|
||||
|
||||
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||
load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "tensorflow-lite-support-metadata",
|
||||
srcs = glob(["src/java/org/tensorflow/lite/support/metadata/**/*.java"]),
|
||||
manifest = "AndroidManifest.xml",
|
||||
deps = [
|
||||
"//tensorflow/lite/experimental/support/java:tensorflow-lite-support-precondition-lib-android",
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema_fbs_android",
|
||||
"//tensorflow/lite/experimental/support/metadata:schema_fbs_android",
|
||||
"//tensorflow/lite/java:tensorflowlite",
|
||||
"@org_checkerframework_qual",
|
||||
],
|
||||
)
|
||||
|
||||
java_library(
|
||||
name = "tensorflow-lite-support-metadata-lib",
|
||||
srcs = glob(["src/java/org/tensorflow/lite/support/metadata/**/*.java"]),
|
||||
javacopts = JAVACOPTS,
|
||||
deps = [
|
||||
"//tensorflow/lite/experimental/support/java:tensorflow-lite-support-precondition",
|
||||
"//tensorflow/lite/experimental/support/metadata:metadata_schema_java",
|
||||
"//tensorflow/lite/experimental/support/metadata:schema_fbs_java",
|
||||
"//tensorflow/lite/java:tensorflowlitelib",
|
||||
"@org_checkerframework_qual",
|
||||
],
|
||||
)
|
@ -0,0 +1,116 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
|
||||
import static org.tensorflow.lite.support.common.SupportPreconditions.checkElementIndex;
|
||||
import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
/**
|
||||
* An {@link InputStream} that wraps a section of a {@link SeekableByteChannelCompat}.
|
||||
*
|
||||
* <p><b>WARNING:</b> Similar as {@link InputStream}, instances of an {@link BoundedInputStream} are
|
||||
* <b>not</b> thread-safe. If multiple threads concurrently reading from the same {@link
|
||||
* BoundedInputStream}, it must be synchronized externally. Also, if multiple instances of {@link
|
||||
* BoundedInputStream} are created on the same {@link SeekableByteChannelCompat}, it must be
|
||||
* synchronized as well.
|
||||
*/
|
||||
final class BoundedInputStream extends InputStream {
|
||||
private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
|
||||
private final long end; // The valid data for the stream is between [start, end).
|
||||
private long position;
|
||||
private final SeekableByteChannelCompat channel;
|
||||
|
||||
/**
|
||||
* Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}.
|
||||
*
|
||||
* @param channel the {@link SeekableByteChannelCompat} that backs up this {@link
|
||||
* BoundedInputStream}
|
||||
* @param start the starting position of this {@link BoundedInputStream} in the given {@link
|
||||
* SeekableByteChannelCompat}
|
||||
* @param remaining the length of this {@link BoundedInputStream}
|
||||
* @throws IllegalArgumentException if {@code start} or {@code remaining} is negative
|
||||
*/
|
||||
BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) {
|
||||
checkArgument(
|
||||
remaining >= 0 && start >= 0,
|
||||
String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
|
||||
|
||||
end = start + remaining;
|
||||
this.channel = channel;
|
||||
position = start;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int available() throws IOException {
|
||||
return (int) (Math.min(end, channel.size()) - position);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int read() throws IOException {
|
||||
if (position >= end) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
singleByteBuffer.rewind();
|
||||
int count = read(position, singleByteBuffer);
|
||||
if (count < 0) {
|
||||
return count;
|
||||
}
|
||||
|
||||
position++;
|
||||
return singleByteBuffer.get() & 0xff;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int read(byte[] b, int off, int len) throws IOException {
|
||||
checkNotNull(b);
|
||||
checkElementIndex(off, b.length, "The start offset");
|
||||
checkElementIndex(len, b.length - off + 1, "The maximumn number of bytes to read");
|
||||
|
||||
if (len == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (len > end - position) {
|
||||
if (position >= end) {
|
||||
return -1;
|
||||
}
|
||||
len = (int) (end - position);
|
||||
}
|
||||
|
||||
ByteBuffer buf = ByteBuffer.wrap(b, off, len);
|
||||
int count = read(position, buf);
|
||||
if (count > 0) {
|
||||
position += count;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
private int read(long position, ByteBuffer buf) throws IOException {
|
||||
int count;
|
||||
synchronized (channel) {
|
||||
channel.position(position);
|
||||
count = channel.read(buf);
|
||||
}
|
||||
buf.flip();
|
||||
return count;
|
||||
}
|
||||
}
|
@ -0,0 +1,130 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import static java.lang.Math.min;
|
||||
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
|
||||
import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.NonWritableChannelException;
|
||||
|
||||
/** Implements the {@link SeekableByteChannelCompat} on top of {@link ByteBuffer}. */
|
||||
final class ByteBufferChannel implements SeekableByteChannelCompat {
|
||||
|
||||
/** The ByteBuffer that holds the data. */
|
||||
private final ByteBuffer buffer;
|
||||
|
||||
/**
|
||||
* Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}.
|
||||
*
|
||||
* @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel}
|
||||
* @throws NullPointerException if {@code buffer} is null
|
||||
*/
|
||||
public ByteBufferChannel(ByteBuffer buffer) {
|
||||
checkNotNull(buffer, "The ByteBuffer cannot be null.");
|
||||
this.buffer = buffer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {}
|
||||
|
||||
@Override
|
||||
public boolean isOpen() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long position() {
|
||||
return buffer.position();
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets this channel's position.
|
||||
*
|
||||
* @param newPosition the new position, a non-negative integer counting the number of bytes from
|
||||
* the beginning of the entity
|
||||
* @return this channel
|
||||
* @throws IllegalArgumentException if the new position is negative, or greater than the size of
|
||||
* the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE
|
||||
*/
|
||||
@Override
|
||||
public synchronized ByteBufferChannel position(long newPosition) {
|
||||
checkArgument(
|
||||
(newPosition >= 0 && newPosition <= Integer.MAX_VALUE),
|
||||
"The new position should be non-negative and be less than Integer.MAX_VALUE.");
|
||||
buffer.position((int) newPosition);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>Bytes are read starting at this channel's current position, and then the position is updated
|
||||
* with the number of bytes actually read. Otherwise this method behaves exactly as specified in
|
||||
* the {@link ReadableByteChannel} interface.
|
||||
*/
|
||||
@Override
|
||||
public synchronized int read(ByteBuffer dst) {
|
||||
if (buffer.remaining() == 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
int count = min(dst.remaining(), buffer.remaining());
|
||||
if (count > 0) {
|
||||
ByteBuffer tempBuffer = buffer.slice();
|
||||
tempBuffer.order(buffer.order()).limit(count);
|
||||
dst.put(tempBuffer);
|
||||
buffer.position(buffer.position() + count);
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long size() {
|
||||
return buffer.limit();
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized ByteBufferChannel truncate(long size) {
|
||||
checkArgument(
|
||||
(size >= 0 && size <= Integer.MAX_VALUE),
|
||||
"The new size should be non-negative and be less than Integer.MAX_VALUE.");
|
||||
|
||||
if (size < buffer.limit()) {
|
||||
buffer.limit((int) size);
|
||||
if (buffer.position() > size) {
|
||||
buffer.position((int) size);
|
||||
}
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized int write(ByteBuffer src) {
|
||||
if (buffer.isReadOnly()) {
|
||||
throw new NonWritableChannelException();
|
||||
}
|
||||
|
||||
int count = min(src.remaining(), buffer.remaining());
|
||||
if (count > 0) {
|
||||
ByteBuffer tempBuffer = src.slice();
|
||||
tempBuffer.order(buffer.order()).limit(count);
|
||||
buffer.put(tempBuffer);
|
||||
}
|
||||
return count;
|
||||
}
|
||||
}
|
@ -0,0 +1,247 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.zip.ZipException;
|
||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||
import org.tensorflow.lite.DataType;
|
||||
import org.tensorflow.lite.Tensor.QuantizationParams;
|
||||
import org.tensorflow.lite.schema.Tensor;
|
||||
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
|
||||
|
||||
/**
|
||||
* Loads metadata from TFLite Model FlatBuffer.
|
||||
*
|
||||
* <p>TFLite Model FlatBuffer can be generated using the <a
|
||||
* href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
|
||||
* Model schema file.</a>
|
||||
*
|
||||
* <p>Some models contain a TFLite Metadata Flatbuffer, which records more information about what
|
||||
* the model does and how to interprete the model. TFLite Metadata Flatbuffer can be generated using
|
||||
* the <a
|
||||
* href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/metadata_schema.fbs">TFLite
|
||||
* Metadata schema file.</a>
|
||||
*
|
||||
* <p>It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking methods
|
||||
* that read from TFLite metadata will cause runtime errors.
|
||||
*
|
||||
* <p>Similarly, it is allowed to pass in a model FlatBuffer without associated files. However,
|
||||
* invoking methods that read the associated files will cause runtime errors.
|
||||
*
|
||||
* <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports a
|
||||
* single subgraph so far. See the <a
|
||||
* href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
|
||||
* of how to specify subgraph during convertion for more information.</a> Therefore, {@link
|
||||
* MetadataExtractor} omits subgraph index as an input in its methods.
|
||||
*/
|
||||
public class MetadataExtractor {
|
||||
/** The helper class to load metadata from TFLite model FlatBuffer. */
|
||||
private final ModelInfo modelInfo;
|
||||
|
||||
/** The helper class to load metadata from TFLite metadata FlatBuffer. */
|
||||
@Nullable private final ModelMetadataInfo metadataInfo;
|
||||
|
||||
/** The handler to load associated files through zip. */
|
||||
@Nullable private final ZipFile zipFile;
|
||||
|
||||
/**
|
||||
* Creates a {@link MetadataExtractor} with TFLite model FlatBuffer.
|
||||
*
|
||||
* @param buffer the TFLite model FlatBuffer
|
||||
* @throws IllegalArgumentException if the number of input or output tensors in the model does not
|
||||
* match that in the metadata
|
||||
* @throws IOException if an error occurs while reading the model as a Zip file
|
||||
*/
|
||||
public MetadataExtractor(ByteBuffer buffer) throws IOException {
|
||||
modelInfo = new ModelInfo(buffer);
|
||||
ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer();
|
||||
if (metadataBuffer != null) {
|
||||
metadataInfo = new ModelMetadataInfo(metadataBuffer);
|
||||
checkArgument(
|
||||
modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(),
|
||||
String.format(
|
||||
"The number of input tensors in the model is %d. The number of input tensors that"
|
||||
+ " recorded in the metadata is %d. These two values does not match.",
|
||||
modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount()));
|
||||
checkArgument(
|
||||
modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(),
|
||||
String.format(
|
||||
"The number of output tensors in the model is %d. The number of output tensors that"
|
||||
+ " recorded in the metadata is %d. These two values does not match.",
|
||||
modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount()));
|
||||
} else {
|
||||
// It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking
|
||||
// methods that read from TFLite metadata will cause runtime errors.
|
||||
metadataInfo = null;
|
||||
}
|
||||
|
||||
zipFile = createZipFile(buffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the packed associated file with the specified {@code fileName}.
|
||||
*
|
||||
* @param fileName the name of the associated file
|
||||
* @return the raw input stream containing specified file
|
||||
* @throws IllegalStateException if the model is not a zip file
|
||||
* @throws IllegalArgumentException if the specified file does not exist in the model
|
||||
*/
|
||||
public InputStream getAssociatedFile(String fileName) {
|
||||
assertZipFile();
|
||||
return zipFile.getRawInputStream(fileName);
|
||||
}
|
||||
|
||||
/** Gets the count of input tensors in the model. */
|
||||
public int getInputTensorCount() {
|
||||
return modelInfo.getInputTensorCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the metadata for the input tensor specified by {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex the index of the desired input tensor
|
||||
* @throws IllegalStateException if this model does not contain model metadata
|
||||
*/
|
||||
@Nullable
|
||||
public TensorMetadata getInputTensorMetadata(int inputIndex) {
|
||||
assertMetadataInfo();
|
||||
return metadataInfo.getInputTensorMetadata(inputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the quantization parameters for the input tensor specified by {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex the index of the desired input tensor
|
||||
*/
|
||||
public QuantizationParams getInputTensorQuantizationParams(int inputIndex) {
|
||||
Tensor tensor = modelInfo.getInputTensor(inputIndex);
|
||||
return modelInfo.getQuantizationParams(tensor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the shape of the input tensor with {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex the index of the desired input tensor
|
||||
*/
|
||||
public int[] getInputTensorShape(int inputIndex) {
|
||||
return modelInfo.getInputTensorShape(inputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the {@link DataType} of the input tensor with {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex the index of the desired input tensor
|
||||
*/
|
||||
public DataType getInputTensorType(int inputIndex) {
|
||||
return modelInfo.getInputTensorType(inputIndex);
|
||||
}
|
||||
|
||||
/** Gets the count of output tensors in the model. */
|
||||
public int getOutputTensorCount() {
|
||||
return modelInfo.getOutputTensorCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the metadata for the output tensor specified by {@code outputIndex}.
|
||||
*
|
||||
* @param outputIndex the index of the desired output tensor
|
||||
* @throws IllegalStateException if this model does not contain model metadata
|
||||
*/
|
||||
@Nullable
|
||||
public TensorMetadata getOutputTensorMetadata(int outputIndex) {
|
||||
assertMetadataInfo();
|
||||
return metadataInfo.getOutputTensorMetadata(outputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the quantization parameters for the output tensor specified by {@code outputIndex}.
|
||||
*
|
||||
* @param outputIndex the index of the desired output tensor
|
||||
*/
|
||||
public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) {
|
||||
Tensor tensor = modelInfo.getOutputTensor(outputIndex);
|
||||
return modelInfo.getQuantizationParams(tensor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the shape of the output tensor with {@code outputIndex}.
|
||||
*
|
||||
* @param outputIndex the index of the desired output tensor
|
||||
*/
|
||||
public int[] getOutputTensorShape(int outputIndex) {
|
||||
return modelInfo.getOutputTensorShape(outputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the {@link DataType} of the output tensor with {@code outputIndex}.
|
||||
*
|
||||
* @param outputIndex the index of the desired output tensor
|
||||
*/
|
||||
public DataType getOutputTensorType(int outputIndex) {
|
||||
return modelInfo.getOutputTensorType(outputIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Asserts if {@link metdadataInfo} is not initialized. Some models may not have metadata and this
|
||||
* is allowed. However, invoking methods that reads the metadata is not allowed.
|
||||
*
|
||||
* @throws IllegalStateException if this model does not contain model metadata
|
||||
*/
|
||||
private void assertMetadataInfo() {
|
||||
if (metadataInfo == null) {
|
||||
throw new IllegalStateException("This model does not contain model metadata.");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Asserts if {@link #zipFile} is not initialized. Some models may not have associated files, thus
|
||||
* are not Zip files. This is allowed. However, invoking methods that reads those associated files
|
||||
* is not allowed.
|
||||
*
|
||||
* @throws IllegalStateException if this model is not a Zip file
|
||||
*/
|
||||
private void assertZipFile() {
|
||||
if (zipFile == null) {
|
||||
throw new IllegalStateException(
|
||||
"This model does not contain associated files, and is not a Zip file.");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e.
|
||||
* it does not have associated files, return a null handler.
|
||||
*
|
||||
* @param buffer the TFLite model FlatBuffer
|
||||
* @throws IOException if an error occurs while reading the model as a Zip file
|
||||
*/
|
||||
@Nullable
|
||||
private static ZipFile createZipFile(ByteBuffer buffer) throws IOException {
|
||||
try {
|
||||
// Creates the handler to hold the associated files through the Zip.
|
||||
ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer);
|
||||
return ZipFile.createFrom(byteBufferChannel);
|
||||
} catch (ZipException e) {
|
||||
// Some models may not have associate files. Therefore, Those models are not zip files.
|
||||
// However, invoking methods that read associated files later will lead into errors.
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,281 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||
import org.tensorflow.lite.DataType;
|
||||
import org.tensorflow.lite.Tensor.QuantizationParams;
|
||||
import org.tensorflow.lite.schema.Buffer;
|
||||
import org.tensorflow.lite.schema.Metadata;
|
||||
import org.tensorflow.lite.schema.Model;
|
||||
import org.tensorflow.lite.schema.QuantizationParameters;
|
||||
import org.tensorflow.lite.schema.SubGraph;
|
||||
import org.tensorflow.lite.schema.Tensor;
|
||||
import org.tensorflow.lite.schema.TensorType;
|
||||
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||
|
||||
/** Extracts model information out of TFLite model FLatBuffer. */
|
||||
final class ModelInfo {
|
||||
/** The model that is loaded from TFLite model FlatBuffer. */
|
||||
private final Model model;
|
||||
|
||||
/** A list of input tensors. */
|
||||
private final List</* @Nullable */ Tensor> inputTensors;
|
||||
|
||||
/** A list of output tensors. */
|
||||
private final List</* @Nullable */ Tensor> outputTensors;
|
||||
|
||||
/** Identifier of the TFLite model metadata in the Metadata array. */
|
||||
static final String METADATA_FIELD_NAME = "TFLITE_METADATA";
|
||||
|
||||
/** Maps from TensorType in TFlite FlatBuffer to {@link DataType} in Java. */
|
||||
private final Map<Byte, DataType> tensorTypeToDataTypeMap;
|
||||
|
||||
/**
|
||||
* Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}.
|
||||
*
|
||||
* <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports
|
||||
* single subgraph so far. See the <a
|
||||
* href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
|
||||
* of how to specify subgraph during convertion for more information.</a> Therefore, all methods
|
||||
* in {@link ModelInfo} retrieves metadata of the first subgrpah as default.
|
||||
*
|
||||
* @param buffer The TFLite model FlatBuffer.
|
||||
* @throws NullPointerException if {@code buffer} is null.
|
||||
* @throws IllegalArgumentException if the model does not contain any subgraph.
|
||||
*/
|
||||
ModelInfo(ByteBuffer buffer) {
|
||||
SupportPreconditions.checkNotNull(buffer, "Model flatbuffer cannot be null.");
|
||||
|
||||
model = Model.getRootAsModel(buffer);
|
||||
SupportPreconditions.checkArgument(
|
||||
model.subgraphsLength() > 0, "The model does not contain any subgraph.");
|
||||
|
||||
inputTensors = getInputTensors(model);
|
||||
outputTensors = getOutputTensors(model);
|
||||
tensorTypeToDataTypeMap = createTensorTypeToDataTypeMap();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the input tensor with {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex The index of the desired input tensor.
|
||||
* @throws IllegalArgumentException if the inputIndex specified is invalid.
|
||||
*/
|
||||
@Nullable
|
||||
Tensor getInputTensor(int inputIndex) {
|
||||
SupportPreconditions.checkArgument(
|
||||
inputIndex >= 0 && inputIndex < inputTensors.size(),
|
||||
"The inputIndex specified is invalid.");
|
||||
return inputTensors.get(inputIndex);
|
||||
}
|
||||
|
||||
int getInputTensorCount() {
|
||||
return inputTensors.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets shape of the input tensor with {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex The index of the desired intput tensor.
|
||||
*/
|
||||
int[] getInputTensorShape(int inputIndex) {
|
||||
Tensor tensor = getInputTensor(inputIndex);
|
||||
return getShape(tensor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets {@link DataType} of the input tensor with {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex The index of the desired intput tensor.
|
||||
*/
|
||||
DataType getInputTensorType(int inputIndex) {
|
||||
Tensor tensor = getInputTensor(inputIndex);
|
||||
return getDataType(tensor.type());
|
||||
}
|
||||
|
||||
/** Gets the metadata FlatBuffer from the model FlatBuffer. */
|
||||
@Nullable
|
||||
ByteBuffer getMetadataBuffer() {
|
||||
// Some models may not have metadata, and this is allowed.
|
||||
if (model.metadataLength() == 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
for (int i = 0; i < model.metadataLength(); i++) {
|
||||
Metadata meta = model.metadata(i);
|
||||
if (METADATA_FIELD_NAME.equals(meta.name())) {
|
||||
long bufferIndex = meta.buffer();
|
||||
Buffer metadataBuf = model.buffers((int) bufferIndex);
|
||||
return metadataBuf.dataAsByteBuffer();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the output tensor with {@code outputIndex}.
|
||||
*
|
||||
* @param outputIndex The index of the desired outtput tensor.
|
||||
* @throws IllegalArgumentException if the outputIndex specified is invalid.
|
||||
*/
|
||||
@Nullable
|
||||
Tensor getOutputTensor(int outputIndex) {
|
||||
SupportPreconditions.checkArgument(
|
||||
outputIndex >= 0 && outputIndex < outputTensors.size(),
|
||||
"The outputIndex specified is invalid.");
|
||||
return outputTensors.get(outputIndex);
|
||||
}
|
||||
|
||||
int getOutputTensorCount() {
|
||||
return outputTensors.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets shape of the output tensor with {@code outputIndex}.
|
||||
*
|
||||
* @param outputIndex The index of the desired outtput tensor.
|
||||
*/
|
||||
int[] getOutputTensorShape(int outputIndex) {
|
||||
Tensor tensor = getOutputTensor(outputIndex);
|
||||
return getShape(tensor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets {@link DataType} of the output tensor {@code outputIndex}.
|
||||
*
|
||||
* @param outputIndex The index of the desired outtput tensor.
|
||||
*/
|
||||
DataType getOutputTensorType(int outputIndex) {
|
||||
Tensor tensor = getOutputTensor(outputIndex);
|
||||
return getDataType(tensor.type());
|
||||
}
|
||||
|
||||
private static Map<Byte, DataType> createTensorTypeToDataTypeMap() {
|
||||
Map<Byte, DataType> map = new HashMap<>();
|
||||
map.put(TensorType.FLOAT32, DataType.FLOAT32);
|
||||
map.put(TensorType.INT32, DataType.INT32);
|
||||
map.put(TensorType.UINT8, DataType.UINT8);
|
||||
map.put(TensorType.INT64, DataType.INT64);
|
||||
map.put(TensorType.STRING, DataType.STRING);
|
||||
return Collections.unmodifiableMap(map);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the quantization parameters of a tensor.
|
||||
*
|
||||
* <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensor that are not
|
||||
* quantized, the values of scale and zero_point are both 0.
|
||||
*
|
||||
* @param tensor The tensor whoes quantization parameters is desired.
|
||||
* @throws NullPointerException if the tensor is null.
|
||||
* @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's {@link
|
||||
* QuantizationParameters} are not single values.
|
||||
*/
|
||||
QuantizationParams getQuantizationParams(Tensor tensor) {
|
||||
SupportPreconditions.checkNotNull(tensor, "Tensor cannot be null.");
|
||||
|
||||
float scale;
|
||||
int zeroPoint;
|
||||
QuantizationParameters quantization = tensor.quantization();
|
||||
|
||||
// Tensors that are not quantized do not have quantization parameters, which can be null when
|
||||
// being extracted from the flatbuffer.
|
||||
if (quantization == null) {
|
||||
scale = 0.0f;
|
||||
zeroPoint = 0;
|
||||
return new QuantizationParams(scale, zeroPoint);
|
||||
}
|
||||
|
||||
// Tensors that are not quantized do not have quantization parameters.
|
||||
// quantization.scaleLength() and quantization.zeroPointLength() may both return 0.
|
||||
SupportPreconditions.checkArgument(
|
||||
quantization.scaleLength() <= 1,
|
||||
"Input and output tensors do not support per-channel quantization.");
|
||||
SupportPreconditions.checkArgument(
|
||||
quantization.zeroPointLength() <= 1,
|
||||
"Input and output tensors do not support per-channel quantization.");
|
||||
|
||||
// For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0) will
|
||||
// both be the default value in flatbuffer, 0. This behavior is consistent with the TFlite C++
|
||||
// runtime.
|
||||
scale = quantization.scale(0);
|
||||
// zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep it
|
||||
// consistent with the C++ runtime.
|
||||
zeroPoint = (int) quantization.zeroPoint(0);
|
||||
|
||||
return new QuantizationParams(scale, zeroPoint);
|
||||
}
|
||||
|
||||
/**
|
||||
* Transforms from TensorType in TFlite FlatBuffer to {@link DataType} in Java.
|
||||
*
|
||||
* @param tensorType The tensor type to be converted.
|
||||
* @throws IllegalArgumentException if the tensor type is not supported.
|
||||
*/
|
||||
private DataType getDataType(byte tensorType) {
|
||||
SupportPreconditions.checkArgument(
|
||||
tensorTypeToDataTypeMap.containsKey(tensorType),
|
||||
String.format("Tensor type %d is not supported.", tensorType));
|
||||
return tensorTypeToDataTypeMap.get(tensorType);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the shape of a tensor.
|
||||
*
|
||||
* @param tensor The tensor whoes shape is desired.
|
||||
* @throws NullPointerException if the tensor is null.
|
||||
*/
|
||||
private static int[] getShape(Tensor tensor) {
|
||||
SupportPreconditions.checkNotNull(tensor, "Tensor cannot be null.");
|
||||
int shapeDim = tensor.shapeLength();
|
||||
int[] tensorShape = new int[shapeDim];
|
||||
for (int i = 0; i < shapeDim; i++) {
|
||||
tensorShape[i] = tensor.shape(i);
|
||||
}
|
||||
return tensorShape;
|
||||
}
|
||||
|
||||
/** Gets input tensors from a model. */
|
||||
private static List<Tensor> getInputTensors(Model model) {
|
||||
// TFLite only support one subgraph currently.
|
||||
SubGraph subgraph = model.subgraphs(0);
|
||||
int tensorNum = subgraph.inputsLength();
|
||||
ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum);
|
||||
for (int i = 0; i < tensorNum; i++) {
|
||||
inputTensors.add(subgraph.tensors(subgraph.inputs(i)));
|
||||
}
|
||||
return Collections.unmodifiableList(inputTensors);
|
||||
}
|
||||
|
||||
/** Gets output tensors from a model. */
|
||||
private static List<Tensor> getOutputTensors(Model model) {
|
||||
// TFLite only support one subgraph currently.
|
||||
SubGraph subgraph = model.subgraphs(0);
|
||||
int tensorNum = subgraph.outputsLength();
|
||||
ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum);
|
||||
for (int i = 0; i < tensorNum; i++) {
|
||||
outputTensors.add(subgraph.tensors(subgraph.outputs(i)));
|
||||
}
|
||||
return Collections.unmodifiableList(outputTensors);
|
||||
}
|
||||
}
|
@ -0,0 +1,114 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||
import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
|
||||
import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata;
|
||||
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
|
||||
|
||||
/** Extracts model metadata information out of TFLite metadata FlatBuffer. */
|
||||
final class ModelMetadataInfo {
|
||||
/** Metadata array of input tensors. */
|
||||
private final List</* @Nullable */ TensorMetadata> inputsMetadata;
|
||||
|
||||
/** Metadata array of output tensors. */
|
||||
private final List</* @Nullable */ TensorMetadata> outputsMetadata;
|
||||
|
||||
/**
|
||||
* Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}.
|
||||
*
|
||||
* @param buffer The TFLite metadata FlatBuffer.
|
||||
* @throws NullPointerException if {@code buffer} is null.
|
||||
* @throws IllegalArgumentException if the metadata does not contain any subgraph metadata.
|
||||
*/
|
||||
ModelMetadataInfo(ByteBuffer buffer) {
|
||||
SupportPreconditions.checkNotNull(buffer, "Metadata flatbuffer cannot be null.");
|
||||
|
||||
ModelMetadata modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
|
||||
SupportPreconditions.checkArgument(
|
||||
modelMetadata.subgraphMetadataLength() > 0,
|
||||
"The metadata flatbuffer does not contain any subgraph metadata.");
|
||||
|
||||
inputsMetadata = getInputsMetadata(modelMetadata);
|
||||
outputsMetadata = getOutputsMetadata(modelMetadata);
|
||||
}
|
||||
|
||||
/** Gets the count of input tensors with metadata in the metadata FlatBuffer. */
|
||||
int getInputTensorCount() {
|
||||
return inputsMetadata.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the metadata for the input tensor specified by {@code inputIndex}.
|
||||
*
|
||||
* @param inputIndex The index of the desired intput tensor.
|
||||
* @throws IllegalArgumentException if the inputIndex specified is invalid.
|
||||
*/
|
||||
@Nullable
|
||||
TensorMetadata getInputTensorMetadata(int inputIndex) {
|
||||
SupportPreconditions.checkArgument(
|
||||
inputIndex >= 0 && inputIndex < inputsMetadata.size(),
|
||||
"The inputIndex specified is invalid.");
|
||||
return inputsMetadata.get(inputIndex);
|
||||
}
|
||||
|
||||
/** Gets the count of output tensors with metadata in the metadata FlatBuffer. */
|
||||
int getOutputTensorCount() {
|
||||
return outputsMetadata.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the metadata for the output tensor specified by {@code outputIndex}.
|
||||
*
|
||||
* @param outputIndex The index of the desired output tensor.
|
||||
* @throws IllegalArgumentException if the outputIndex specified is invalid.
|
||||
*/
|
||||
@Nullable
|
||||
TensorMetadata getOutputTensorMetadata(int outputIndex) {
|
||||
SupportPreconditions.checkArgument(
|
||||
outputIndex >= 0 && outputIndex < outputsMetadata.size(),
|
||||
"The outputIndex specified is invalid.");
|
||||
return outputsMetadata.get(outputIndex);
|
||||
}
|
||||
|
||||
/** Gets metadata for all input tensors. */
|
||||
private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) {
|
||||
SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
|
||||
int tensorNum = subgraphMetadata.inputTensorMetadataLength();
|
||||
ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum);
|
||||
for (int i = 0; i < tensorNum; i++) {
|
||||
inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i));
|
||||
}
|
||||
return Collections.unmodifiableList(inputsMetadata);
|
||||
}
|
||||
|
||||
/** Gets metadata for all output tensors. */
|
||||
private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) {
|
||||
SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
|
||||
int tensorNum = subgraphMetadata.outputTensorMetadataLength();
|
||||
ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum);
|
||||
for (int i = 0; i < tensorNum; i++) {
|
||||
outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i));
|
||||
}
|
||||
return Collections.unmodifiableList(outputsMetadata);
|
||||
}
|
||||
}
|
@ -0,0 +1,107 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.Channel;
|
||||
|
||||
/**
|
||||
* A byte channel that maintains a current <i>position</i> and allows the position to be changed.
|
||||
* {@link SeekableByteChannelCompat} is compatible with {@link
|
||||
* java.nio.channels.SeekableByteChannel}.
|
||||
*
|
||||
* <p>{@link java.nio.channels.SeekableByteChannel} is not available in Android API 23 and under.
|
||||
* Therefore, {@link SeekableByteChannelCompat} is introduced here to make the interfaces used in
|
||||
* the MetadtaExtractor library consistent with the common used Java libraries.
|
||||
*/
|
||||
interface SeekableByteChannelCompat extends Channel {
|
||||
/**
|
||||
* Reads a sequence of bytes from this channel into the given buffer.
|
||||
*
|
||||
* @param dst The buffer into which bytes are to be transferred
|
||||
* @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached
|
||||
* end-of-stream
|
||||
* @throws NonReadableChannelException If this channel was not opened for reading
|
||||
* @throws ClosedChannelException If this channel is closed
|
||||
* @throws AsynchronousCloseException If another thread closes this channel while the read
|
||||
* operation is in progress
|
||||
* @throws ClosedByInterruptException If another thread interrupts the current thread while the
|
||||
* read operation is in progress, thereby closing the channel and setting the current thread's
|
||||
* interrupt status
|
||||
* @throws IOException If some other I/O error occurs
|
||||
*/
|
||||
int read(ByteBuffer dst) throws IOException;
|
||||
|
||||
/**
|
||||
* Writes a sequence of bytes to this channel from the given buffer.
|
||||
*
|
||||
* @param src The buffer from which bytes are to be retrieved
|
||||
* @return The number of bytes written, possibly zero
|
||||
* @throws NonWritableChannelException If this channel was not opened for writing
|
||||
* @throws ClosedChannelException If this channel is closed
|
||||
* @throws AsynchronousCloseException If another thread closes this channel while the write
|
||||
* operation is in progress
|
||||
* @throws ClosedByInterruptException If another thread interrupts the current thread while the
|
||||
* write operation is in progress, thereby closing the channel and setting the current
|
||||
* thread's interrupt status
|
||||
* @throws IOException If some other I/O error occurs
|
||||
*/
|
||||
int write(ByteBuffer src) throws IOException;
|
||||
|
||||
/**
|
||||
* Returns this channel's position.
|
||||
*
|
||||
* @return This channel's position, a non-negative integer counting the number of bytes from the
|
||||
* beginning of the entity to the current position
|
||||
* @throws ClosedChannelException If this channel is closed
|
||||
* @throws IOException If some other I/O error occurs
|
||||
*/
|
||||
long position() throws IOException;
|
||||
|
||||
/**
|
||||
* Sets this channel's position.
|
||||
*
|
||||
* @param newPosition The new position, a non-negative integer counting the number of bytes from
|
||||
* the beginning of the entity
|
||||
* @return This channel
|
||||
* @throws ClosedChannelException If this channel is closed
|
||||
* @throws IllegalArgumentException If the new position is negative
|
||||
* @throws IOException If some other I/O error occurs
|
||||
*/
|
||||
SeekableByteChannelCompat position(long newPosition) throws IOException;
|
||||
|
||||
/**
|
||||
* Returns the current size of entity to which this channel is connected.
|
||||
*
|
||||
* @return The current size, measured in bytes
|
||||
* @throws ClosedChannelException If this channel is closed
|
||||
* @throws IOException If some other I/O error occurs
|
||||
*/
|
||||
long size() throws IOException;
|
||||
|
||||
/**
|
||||
* Truncates the entity, to which this channel is connected, to the given size.
|
||||
*
|
||||
* @param size The new size, a non-negative byte count
|
||||
* @return This channel
|
||||
* @throws NonWritableChannelException If this channel was not opened for writing
|
||||
* @throws ClosedChannelException If this channel is closed
|
||||
* @throws IllegalArgumentException If the new size is negative
|
||||
* @throws IOException If some other I/O error occurs
|
||||
*/
|
||||
SeekableByteChannelCompat truncate(long size) throws IOException;
|
||||
}
|
@ -0,0 +1,427 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.metadata;
|
||||
|
||||
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
|
||||
import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.io.EOFException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.zip.ZipException;
|
||||
|
||||
/**
|
||||
* Reads uncompressed files from the TFLite model, a zip file.
|
||||
*
|
||||
* <p>TODO(b/150237111): add a link to the webpage of MetadataPopulator once it's available.
|
||||
*
|
||||
* <p>A TFLite model file becomes a zip file when it contains associated files. The associated files
|
||||
* can be packed to a TFLite model file using the MetadataPopulator. The associated files are not
|
||||
* compressed when being added to the model file.
|
||||
*
|
||||
* <p>{@link ZipFile} does not support Zip64 format, because TFLite models are much smaller than the
|
||||
* size limit for Zip64, which is 4GB.
|
||||
*/
|
||||
final class ZipFile implements Closeable {
|
||||
/** Maps String to list of ZipEntrys, name -> actual entries. */
|
||||
private final Map<String, List<ZipEntry>> nameMap;
|
||||
|
||||
/** The actual data source. */
|
||||
private final ByteBufferChannel archive;
|
||||
|
||||
/**
|
||||
* Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link
|
||||
* ZipFile} does not synchronized over the buffer that is passed into it.
|
||||
*
|
||||
* @param channel the archive
|
||||
* @throws IOException if an error occurs while creating this {@link ZipFile}
|
||||
* @throws ZipException if the channel is not a zip archive
|
||||
* @throws NullPointerException if the archive is null
|
||||
*/
|
||||
public static ZipFile createFrom(ByteBufferChannel channel) throws IOException {
|
||||
checkNotNull(channel);
|
||||
ZipParser zipParser = new ZipParser(channel);
|
||||
Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries();
|
||||
return new ZipFile(channel, nameMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
archive.close();
|
||||
}
|
||||
|
||||
/**
|
||||
* Exposes the raw stream of the archive entry.
|
||||
*
|
||||
* <p>Since the associated files will not be compressed when being packed to the zip file, the raw
|
||||
* stream represents the non-compressed files.
|
||||
*
|
||||
* <p><b>WARNING:</b> The returned {@link InputStream}, is <b>not</b> thread-safe. If multiple
|
||||
* threads concurrently reading from the returned {@link InputStream}, it must be synchronized
|
||||
* externally.
|
||||
*
|
||||
* @param name name of the entry to get the stream for
|
||||
* @return the raw input stream containing data
|
||||
* @throws IllegalArgumentException if the specified file does not exist in the zip file
|
||||
*/
|
||||
public InputStream getRawInputStream(String name) {
|
||||
checkArgument(
|
||||
nameMap.containsKey(name),
|
||||
String.format("The file, %s, does not exist in the zip file.", name));
|
||||
|
||||
List<ZipEntry> entriesWithTheSameName = nameMap.get(name);
|
||||
ZipEntry entry = entriesWithTheSameName.get(0);
|
||||
long start = entry.getDataOffset();
|
||||
long remaining = entry.getSize();
|
||||
return new BoundedInputStream(archive, start, remaining);
|
||||
}
|
||||
|
||||
private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) {
|
||||
archive = channel;
|
||||
this.nameMap = nameMap;
|
||||
}
|
||||
|
||||
/* Parses a Zip archive and gets the information for each {@link ZipEntry}. */
|
||||
private static class ZipParser {
|
||||
private final ByteBufferChannel archive;
|
||||
|
||||
// Cached buffers that will only be used locally in the class to reduce garbage collection.
|
||||
private final ByteBuffer longBuffer =
|
||||
ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
|
||||
private final ByteBuffer intBuffer =
|
||||
ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
|
||||
private final ByteBuffer shortBuffer =
|
||||
ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
|
||||
|
||||
private ZipParser(ByteBufferChannel archive) {
|
||||
this.archive = archive;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses the underlying {@code archive} and returns the information as a list of {@link
|
||||
* ZipEntry}.
|
||||
*/
|
||||
private Map<String, List<ZipEntry>> parseEntries() throws IOException {
|
||||
List<ZipEntry> entries = parseCentralDirectory();
|
||||
return parseLocalFileHeaderData(entries);
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the current position contains a central file header signature, {@link
|
||||
* ZipConstants#CENSIG}.
|
||||
*/
|
||||
private boolean foundCentralFileheaderSignature() {
|
||||
long signature = (long) getInt();
|
||||
return signature == ZipConstants.CENSIG;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the value as a Java int from two bytes starting at the current position of the archive.
|
||||
*/
|
||||
private int getShort() {
|
||||
shortBuffer.rewind();
|
||||
archive.read(shortBuffer);
|
||||
shortBuffer.flip();
|
||||
return (int) shortBuffer.getShort();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the value as a Java long from four bytes starting at the current position of the
|
||||
* archive.
|
||||
*/
|
||||
private int getInt() {
|
||||
intBuffer.rewind();
|
||||
archive.read(intBuffer);
|
||||
intBuffer.flip();
|
||||
return intBuffer.getInt();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the value as a Java long from four bytes starting at the current position of the
|
||||
* archive.
|
||||
*/
|
||||
private long getLong() {
|
||||
longBuffer.rewind();
|
||||
archive.read(longBuffer);
|
||||
longBuffer.flip();
|
||||
return longBuffer.getLong();
|
||||
}
|
||||
|
||||
/**
|
||||
* Positions the archive at the start of the central directory.
|
||||
*
|
||||
* <p>First, it searches for the signature of the "end of central directory record", {@link
|
||||
* ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory
|
||||
* record". The zip file are created without archive comments, thus {@link ZipConstants#ENDSIG}
|
||||
* should appear exactly at {@link ZipConstants#ENDHDR} from the end of the zip file.
|
||||
*
|
||||
* <p>Then, parse the "end of central dir record" and position the archive at the start of the
|
||||
* central directory.
|
||||
*/
|
||||
private void locateCentralDirectory() throws IOException {
|
||||
if (archive.size() < ZipConstants.ENDHDR) {
|
||||
throw new ZipException("The archive is not a ZIP archive.");
|
||||
}
|
||||
|
||||
// Positions the archive at the start of the "end of central directory record".
|
||||
long offsetRecord = archive.size() - ZipConstants.ENDHDR;
|
||||
archive.position(offsetRecord);
|
||||
|
||||
// Checks for the signature, {@link ZipConstants#ENDSIG}.
|
||||
long endSig = getLong();
|
||||
if (endSig != ZipConstants.ENDSIG) {
|
||||
throw new ZipException("The archive is not a ZIP archive.");
|
||||
}
|
||||
|
||||
// Positions the archive at the “offset of central directory”.
|
||||
skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB);
|
||||
// Gets the offset to central directory
|
||||
long offsetDirectory = getInt();
|
||||
// Goes to the central directory.
|
||||
archive.position(offsetDirectory);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads the central directory of the given archive and populates the internal tables with
|
||||
* {@link ZipEntry} instances.
|
||||
*/
|
||||
private List<ZipEntry> parseCentralDirectory() throws IOException {
|
||||
/** List of entries in the order they appear inside the central directory. */
|
||||
List<ZipEntry> entries = new ArrayList<>();
|
||||
locateCentralDirectory();
|
||||
|
||||
while (foundCentralFileheaderSignature()) {
|
||||
ZipEntry entry = parseCentralDirectoryEntry();
|
||||
entries.add(entry);
|
||||
}
|
||||
|
||||
return entries;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads an individual entry of the central directory, creats an ZipEntry from it and adds it to
|
||||
* the global maps.
|
||||
*/
|
||||
private ZipEntry parseCentralDirectoryEntry() throws IOException {
|
||||
// Positions the archive at the "compressed size" and read the value.
|
||||
skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM);
|
||||
long compressSize = getInt();
|
||||
|
||||
// Positions the archive at the "filename length" and read the value.
|
||||
skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN);
|
||||
int fileNameLen = getShort();
|
||||
|
||||
// Reads the extra field length and the comment length.
|
||||
int extraLen = getShort();
|
||||
int commentLen = getShort();
|
||||
|
||||
// Positions the archive at the "local file header offset" and read the value.
|
||||
skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK);
|
||||
long localHeaderOffset = getInt();
|
||||
|
||||
// Reads the file name.
|
||||
byte[] fileNameBuf = new byte[fileNameLen];
|
||||
archive.read(ByteBuffer.wrap(fileNameBuf));
|
||||
String fileName = new String(fileNameBuf, Charset.forName("UTF-8"));
|
||||
|
||||
// Skips the extra field and the comment.
|
||||
skipBytes(extraLen + commentLen);
|
||||
|
||||
ZipEntry entry = new ZipEntry();
|
||||
entry.setSize(compressSize);
|
||||
entry.setLocalHeaderOffset(localHeaderOffset);
|
||||
entry.setName(fileName);
|
||||
|
||||
return entry;
|
||||
}
|
||||
|
||||
/** Walks through all recorded entries and records the offsets for the entry data. */
|
||||
private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) {
|
||||
/** Maps String to list of ZipEntrys, name -> actual entries. */
|
||||
Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>();
|
||||
|
||||
for (ZipEntry entry : entries) {
|
||||
long offset = entry.getLocalHeaderOffset();
|
||||
archive.position(offset + ZipConstants.LOCNAM);
|
||||
|
||||
// Gets the data offset of this entry.
|
||||
int fileNameLen = getShort();
|
||||
int extraFieldLen = getShort();
|
||||
long dataOffset =
|
||||
offset
|
||||
+ ZipConstants.LOCEXT
|
||||
+ ZipConstants.SHORT_BYTE_SIZE
|
||||
+ fileNameLen
|
||||
+ extraFieldLen;
|
||||
entry.setDataOffset(dataOffset);
|
||||
|
||||
// Puts the entry into the nameMap.
|
||||
String name = entry.getName();
|
||||
List<ZipEntry> entriesWithTheSameName;
|
||||
if (nameMap.containsKey(name)) {
|
||||
entriesWithTheSameName = nameMap.get(name);
|
||||
} else {
|
||||
entriesWithTheSameName = new ArrayList<>();
|
||||
nameMap.put(name, entriesWithTheSameName);
|
||||
}
|
||||
entriesWithTheSameName.add(entry);
|
||||
}
|
||||
|
||||
return nameMap;
|
||||
}
|
||||
|
||||
/** Skips the given number of bytes or throws an EOFException if skipping failed. */
|
||||
private void skipBytes(int count) throws IOException {
|
||||
long currentPosition = archive.position();
|
||||
long newPosition = currentPosition + count;
|
||||
if (newPosition > archive.size()) {
|
||||
throw new EOFException();
|
||||
}
|
||||
archive.position(newPosition);
|
||||
}
|
||||
}
|
||||
|
||||
/** Stores the data offset and the size of an entry in the archive. */
|
||||
private static class ZipEntry {
|
||||
|
||||
private String name;
|
||||
private long dataOffset = -1;
|
||||
private long size = -1;
|
||||
private long localHeaderOffset = -1;
|
||||
|
||||
public long getSize() {
|
||||
return size;
|
||||
}
|
||||
|
||||
public long getDataOffset() {
|
||||
return dataOffset;
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public long getLocalHeaderOffset() {
|
||||
return localHeaderOffset;
|
||||
}
|
||||
|
||||
public void setSize(long size) {
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
public void setDataOffset(long dataOffset) {
|
||||
this.dataOffset = dataOffset;
|
||||
}
|
||||
|
||||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public void setLocalHeaderOffset(long localHeaderOffset) {
|
||||
this.localHeaderOffset = localHeaderOffset;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Various constants for this {@link ZipFile}.
|
||||
*
|
||||
* <p>Referenced from {@link java.util.zip.ZipConstants}.
|
||||
*/
|
||||
private static class ZipConstants {
|
||||
/** length of Java short in bytes. */
|
||||
static final int SHORT_BYTE_SIZE = Short.SIZE / 8;
|
||||
|
||||
/** length of Java int in bytes. */
|
||||
static final int INT_BYTE_SIZE = Integer.SIZE / 8;
|
||||
|
||||
/** length of Java long in bytes. */
|
||||
static final int LONG_BYTE_SIZE = Long.SIZE / 8;
|
||||
|
||||
/*
|
||||
* Header signatures
|
||||
*/
|
||||
static final long LOCSIG = 0x04034b50L; // "PK\003\004"
|
||||
static final long EXTSIG = 0x08074b50L; // "PK\007\008"
|
||||
static final long CENSIG = 0x02014b50L; // "PK\001\002"
|
||||
static final long ENDSIG = 0x06054b50L; // "PK\005\006"
|
||||
|
||||
/*
|
||||
* Header sizes in bytes (including signatures)
|
||||
*/
|
||||
static final int LOCHDR = 30; // LOC header size
|
||||
static final int EXTHDR = 16; // EXT header size
|
||||
static final int CENHDR = 46; // CEN header size
|
||||
static final int ENDHDR = 22; // END header size
|
||||
|
||||
/*
|
||||
* Local file (LOC) header field offsets
|
||||
*/
|
||||
static final int LOCVER = 4; // version needed to extract
|
||||
static final int LOCFLG = 6; // general purpose bit flag
|
||||
static final int LOCHOW = 8; // compression method
|
||||
static final int LOCTIM = 10; // modification time
|
||||
static final int LOCCRC = 14; // uncompressed file crc-32 value
|
||||
static final int LOCSIZ = 18; // compressed size
|
||||
static final int LOCLEN = 22; // uncompressed size
|
||||
static final int LOCNAM = 26; // filename length
|
||||
static final int LOCEXT = 28; // extra field length
|
||||
|
||||
/*
|
||||
* Extra local (EXT) header field offsets
|
||||
*/
|
||||
static final int EXTCRC = 4; // uncompressed file crc-32 value
|
||||
static final int EXTSIZ = 8; // compressed size
|
||||
static final int EXTLEN = 12; // uncompressed size
|
||||
|
||||
/*
|
||||
* Central directory (CEN) header field offsets
|
||||
*/
|
||||
static final int CENVEM = 4; // version made by
|
||||
static final int CENVER = 6; // version needed to extract
|
||||
static final int CENFLG = 8; // encrypt, decrypt flags
|
||||
static final int CENHOW = 10; // compression method
|
||||
static final int CENTIM = 12; // modification time
|
||||
static final int CENCRC = 16; // uncompressed file crc-32 value
|
||||
static final int CENSIZ = 20; // compressed size
|
||||
static final int CENLEN = 24; // uncompressed size
|
||||
static final int CENNAM = 28; // filename length
|
||||
static final int CENEXT = 30; // extra field length
|
||||
static final int CENCOM = 32; // comment length
|
||||
static final int CENDSK = 34; // disk number start
|
||||
static final int CENATT = 36; // internal file attributes
|
||||
static final int CENATX = 38; // external file attributes
|
||||
static final int CENOFF = 42; // LOC header offset
|
||||
|
||||
/*
|
||||
* End of central directory (END) header field offsets
|
||||
*/
|
||||
static final int ENDSUB = 8; // number of entries on this disk
|
||||
static final int ENDTOT = 10; // total number of entries
|
||||
static final int ENDSIZ = 12; // central directory size in bytes
|
||||
static final int ENDOFF = 16; // offset of first CEN header
|
||||
static final int ENDCOM = 20; // zip file comment length
|
||||
|
||||
private ZipConstants() {}
|
||||
}
|
||||
}
|
542
tensorflow/lite/experimental/support/metadata/metadata.py
Normal file
542
tensorflow/lite/experimental/support/metadata/metadata.py
Normal file
@ -0,0 +1,542 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TensorFlow Lite metadata tools."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import warnings
|
||||
import zipfile
|
||||
|
||||
from flatbuffers.python import flatbuffers
|
||||
from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb
|
||||
from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
_FLATC_BINARY_PATH = resource_loader.get_path_to_datafile(
|
||||
"../../../../../external/flatbuffers/flatc")
|
||||
_FLATC_TFLITE_METADATA_SCHEMA_FILE = resource_loader.get_path_to_datafile(
|
||||
"metadata_schema.fbs")
|
||||
|
||||
|
||||
# TODO(b/141467403): add delete method for associated files.
|
||||
class MetadataPopulator(object):
|
||||
"""Packs metadata and associated files into TensorFlow Lite model file.
|
||||
|
||||
MetadataPopulator can be used to populate metadata and model associated files
|
||||
into a model file or a model buffer (in bytearray). It can also help to
|
||||
inspect list of files that have been packed into the model or are supposed to
|
||||
be packed into the model.
|
||||
|
||||
The metadata file (or buffer) should be generated based on the metadata
|
||||
schema:
|
||||
third_party/tensorflow/lite/schema/metadata_schema.fbs
|
||||
|
||||
Example usage:
|
||||
Populate matadata and label file into an image classifier model.
|
||||
|
||||
First, based on metadata_schema.fbs, generate the metadata for this image
|
||||
classifer model using Flatbuffers API. Attach the label file onto the ouput
|
||||
tensor (the tensor of probabilities) in the metadata.
|
||||
|
||||
Then, pack the metadata and lable file into the model as follows.
|
||||
|
||||
```python
|
||||
# Populating a metadata file (or a metadta buffer) and associated files to
|
||||
a model file:
|
||||
populator = MetadataPopulator.with_model_file(model_file)
|
||||
# For metadata buffer (bytearray read from the metadata file), use:
|
||||
# populator.load_metadata_buffer(metadata_buf)
|
||||
populator.load_metadata_file(metadata_file)
|
||||
populator.load_associated_files([label.txt])
|
||||
populator.populate()
|
||||
|
||||
# Populating a metadata file (or a metadta buffer) and associated files to
|
||||
a model buffer:
|
||||
populator = MetadataPopulator.with_model_buffer(model_buf)
|
||||
populator.load_metadata_file(metadata_file)
|
||||
populator.load_associated_files([label.txt])
|
||||
populator.populate()
|
||||
# Writing the updated model buffer into a file.
|
||||
updated_model_buf = populator.get_model_buffer()
|
||||
with open("updated_model.tflite", "wb") as f:
|
||||
f.write(updated_model_buf)
|
||||
```
|
||||
"""
|
||||
# As Zip API is used to concatenate associated files after tflite model file,
|
||||
# the populating operation is developed based on a model file. For in-memory
|
||||
# model buffer, we create a tempfile to serve the populating operation.
|
||||
# Creating the deleting such a tempfile is handled by the class,
|
||||
# _MetadataPopulatorWithBuffer.
|
||||
|
||||
METADATA_FIELD_NAME = "TFLITE_METADATA"
|
||||
TFLITE_FILE_IDENTIFIER = b"TFL3"
|
||||
METADATA_FILE_IDENTIFIER = b"M001"
|
||||
|
||||
def __init__(self, model_file):
|
||||
"""Constructor for MetadataPopulator.
|
||||
|
||||
Args:
|
||||
model_file: valid path to a TensorFlow Lite model file.
|
||||
|
||||
Raises:
|
||||
IOError: File not found.
|
||||
"""
|
||||
_assert_exist(model_file)
|
||||
self._model_file = model_file
|
||||
self._metadata_buf = None
|
||||
self._associated_files = set()
|
||||
|
||||
@classmethod
|
||||
def with_model_file(cls, model_file):
|
||||
"""Creates a MetadataPopulator object that populates data to a model file.
|
||||
|
||||
Args:
|
||||
model_file: valid path to a TensorFlow Lite model file.
|
||||
|
||||
Returns:
|
||||
MetadataPopulator object.
|
||||
|
||||
Raises:
|
||||
IOError: File not found.
|
||||
"""
|
||||
return cls(model_file)
|
||||
|
||||
# TODO(b/141468993): investigate if type check can be applied to model_buf for
|
||||
# FB.
|
||||
@classmethod
|
||||
def with_model_buffer(cls, model_buf):
|
||||
"""Creates a MetadataPopulator object that populates data to a model buffer.
|
||||
|
||||
Args:
|
||||
model_buf: TensorFlow Lite model buffer in bytearray.
|
||||
|
||||
Returns:
|
||||
A MetadataPopulator(_MetadataPopulatorWithBuffer) object.
|
||||
"""
|
||||
return _MetadataPopulatorWithBuffer(model_buf)
|
||||
|
||||
def get_model_buffer(self):
|
||||
"""Gets the buffer of the model with packed metadata and associated files.
|
||||
|
||||
Returns:
|
||||
Model buffer (in bytearray).
|
||||
"""
|
||||
with open(self._model_file, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def get_packed_associated_file_list(self):
|
||||
"""Gets a list of associated files packed to the model file.
|
||||
|
||||
Returns:
|
||||
List of packed associated files.
|
||||
"""
|
||||
if not zipfile.is_zipfile(self._model_file):
|
||||
return []
|
||||
|
||||
with zipfile.ZipFile(self._model_file, "r") as zf:
|
||||
return zf.namelist()
|
||||
|
||||
def get_recorded_associated_file_list(self):
|
||||
"""Gets a list of associated files recorded in metadata of the model file.
|
||||
|
||||
Associated files may be attached to a model, a subgraph, or an input/output
|
||||
tensor.
|
||||
|
||||
Returns:
|
||||
List of recorded associated files.
|
||||
"""
|
||||
recorded_files = []
|
||||
|
||||
if not self._metadata_buf:
|
||||
return recorded_files
|
||||
|
||||
metadata = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
|
||||
self._metadata_buf, 0)
|
||||
|
||||
# Add associated files attached to ModelMetadata
|
||||
self._get_associated_files_from_metadata_struct(metadata, recorded_files)
|
||||
|
||||
# Add associated files attached to each SubgraphMetadata
|
||||
for j in range(metadata.SubgraphMetadataLength()):
|
||||
subgraph = metadata.SubgraphMetadata(j)
|
||||
self._get_associated_files_from_metadata_struct(subgraph, recorded_files)
|
||||
|
||||
# Add associated files attached to each input tensor
|
||||
for k in range(subgraph.InputTensorMetadataLength()):
|
||||
tensor = subgraph.InputTensorMetadata(k)
|
||||
self._get_associated_files_from_metadata_struct(tensor, recorded_files)
|
||||
|
||||
# Add associated files attached to each output tensor
|
||||
for k in range(subgraph.OutputTensorMetadataLength()):
|
||||
tensor = subgraph.OutputTensorMetadata(k)
|
||||
self._get_associated_files_from_metadata_struct(tensor, recorded_files)
|
||||
|
||||
return recorded_files
|
||||
|
||||
def load_associated_files(self, associated_files):
|
||||
"""Loads associated files that to be concatenated after the model file.
|
||||
|
||||
Args:
|
||||
associated_files: list of file paths.
|
||||
|
||||
Raises:
|
||||
IOError:
|
||||
File not found.
|
||||
"""
|
||||
for af in associated_files:
|
||||
_assert_exist(af)
|
||||
self._associated_files.add(af)
|
||||
|
||||
def load_metadata_buffer(self, metadata_buf):
|
||||
"""Loads the metadata buffer (in bytearray) to be populated.
|
||||
|
||||
Args:
|
||||
metadata_buf: metadata buffer (in bytearray) to be populated.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
The metadata to be populated is empty.
|
||||
"""
|
||||
if not metadata_buf:
|
||||
raise ValueError("The metadata to be populated is empty.")
|
||||
|
||||
self._metadata_buf = metadata_buf
|
||||
|
||||
def load_metadata_file(self, metadata_file):
|
||||
"""Loads the metadata file to be populated.
|
||||
|
||||
Args:
|
||||
metadata_file: path to the metadata file to be populated.
|
||||
|
||||
Raises:
|
||||
IOError:
|
||||
File not found.
|
||||
"""
|
||||
_assert_exist(metadata_file)
|
||||
with open(metadata_file, "rb") as f:
|
||||
metadata_buf = f.read()
|
||||
self.load_metadata_buffer(bytearray(metadata_buf))
|
||||
|
||||
def populate(self):
|
||||
"""Populates loaded metadata and associated files into the model file."""
|
||||
self._assert_validate()
|
||||
self._populate_metadata_buffer()
|
||||
self._populate_associated_files()
|
||||
|
||||
def _assert_validate(self):
|
||||
"""Validates the metadata and associated files to be populated.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
File is recorded in the metadata, but is not going to be populated.
|
||||
File has already been packed.
|
||||
"""
|
||||
# Gets files that are recorded in metadata.
|
||||
recorded_files = self.get_recorded_associated_file_list()
|
||||
|
||||
# Gets files that have been packed to self._model_file.
|
||||
packed_files = self.get_packed_associated_file_list()
|
||||
|
||||
# Gets the file name of those associated files to be populated.
|
||||
to_be_populated_files = []
|
||||
for af in self._associated_files:
|
||||
to_be_populated_files.append(os.path.basename(af))
|
||||
|
||||
# Checks all files recorded in the metadata will be populated.
|
||||
for rf in recorded_files:
|
||||
if rf not in to_be_populated_files and rf not in packed_files:
|
||||
raise ValueError("File, '{0}', is recorded in the metadata, but has "
|
||||
"not been loaded into the populator.".format(rf))
|
||||
|
||||
for f in to_be_populated_files:
|
||||
if f in packed_files:
|
||||
raise ValueError("File, '{0}', has already been packed.".format(f))
|
||||
|
||||
if f not in recorded_files:
|
||||
warnings.warn(
|
||||
"File, '{0}', does not exsit in the metadata. But packing it to "
|
||||
"tflite model is still allowed.".format(f))
|
||||
|
||||
def _copy_archived_files(self, src_zip, dst_zip, file_list):
|
||||
"""Copy archieved files in file_list from src_zip ro dst_zip."""
|
||||
|
||||
if not zipfile.is_zipfile(src_zip):
|
||||
raise ValueError("File, '{0}', is not a zipfile.".format(src_zip))
|
||||
|
||||
with zipfile.ZipFile(src_zip,
|
||||
"r") as src_zf, zipfile.ZipFile(dst_zip,
|
||||
"a") as dst_zf:
|
||||
src_list = src_zf.namelist()
|
||||
for f in file_list:
|
||||
if f not in src_list:
|
||||
raise ValueError(
|
||||
"File, '{0}', does not exist in the zipfile, {1}.".format(
|
||||
f, src_zip))
|
||||
file_buffer = src_zf.read(f)
|
||||
dst_zf.writestr(f, file_buffer)
|
||||
|
||||
def _get_associated_files_from_metadata_struct(self, file_holder, file_list):
|
||||
for j in range(file_holder.AssociatedFilesLength()):
|
||||
file_list.append(file_holder.AssociatedFiles(j).Name().decode("utf-8"))
|
||||
|
||||
def _populate_associated_files(self):
|
||||
"""Concatenates associated files after TensorFlow Lite model file.
|
||||
|
||||
If the MetadataPopulator object is created using the method,
|
||||
with_model_file(model_file), the model file will be updated.
|
||||
"""
|
||||
# Opens up the model file in "appending" mode.
|
||||
# If self._model_file already has pack files, zipfile will concatenate
|
||||
# addition files after self._model_file. For example, suppose we have
|
||||
# self._model_file = old_tflite_file | label1.txt | label2.txt
|
||||
# Then after trigger populate() to add label3.txt, self._model_file becomes
|
||||
# self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt
|
||||
with zipfile.ZipFile(self._model_file, "a") as zf:
|
||||
for af in self._associated_files:
|
||||
filename = os.path.basename(af)
|
||||
zf.write(af, filename)
|
||||
|
||||
def _populate_metadata_buffer(self):
|
||||
"""Populates the metadata buffer (in bytearray) into the model file.
|
||||
|
||||
Inserts metadata_buf into the metadata field of schema.Model. If the
|
||||
MetadataPopulator object is created using the method,
|
||||
with_model_file(model_file), the model file will be updated.
|
||||
"""
|
||||
|
||||
with open(self._model_file, "rb") as f:
|
||||
model_buf = f.read()
|
||||
|
||||
model = _schema_fb.ModelT.InitFromObj(
|
||||
_schema_fb.Model.GetRootAsModel(model_buf, 0))
|
||||
buffer_field = _schema_fb.BufferT()
|
||||
buffer_field.data = self._metadata_buf
|
||||
|
||||
is_populated = False
|
||||
if not model.metadata:
|
||||
model.metadata = []
|
||||
else:
|
||||
# Check if metadata has already been populated.
|
||||
for meta in model.metadata:
|
||||
if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME:
|
||||
is_populated = True
|
||||
model.buffers[meta.buffer] = buffer_field
|
||||
|
||||
if not is_populated:
|
||||
if not model.buffers:
|
||||
model.buffers = []
|
||||
model.buffers.append(buffer_field)
|
||||
# Creates a new metadata field.
|
||||
metadata_field = _schema_fb.MetadataT()
|
||||
metadata_field.name = self.METADATA_FIELD_NAME
|
||||
metadata_field.buffer = len(model.buffers) - 1
|
||||
model.metadata.append(metadata_field)
|
||||
|
||||
# Packs model back to a flatbuffer binaray file.
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER)
|
||||
model_buf = b.Output()
|
||||
|
||||
# Saves the updated model buffer to model file.
|
||||
# Gets files that have been packed to self._model_file.
|
||||
packed_files = self.get_packed_associated_file_list()
|
||||
if packed_files:
|
||||
# Writes the updated model buffer and associated files into a new model
|
||||
# file. Then overwrites the original model file.
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
new_file = temp.name
|
||||
with open(new_file, "wb") as f:
|
||||
f.write(model_buf)
|
||||
self._copy_archived_files(self._model_file, new_file, packed_files)
|
||||
shutil.copy(new_file, self._model_file)
|
||||
os.remove(new_file)
|
||||
else:
|
||||
with open(self._model_file, "wb") as f:
|
||||
f.write(model_buf)
|
||||
|
||||
|
||||
class _MetadataPopulatorWithBuffer(MetadataPopulator):
|
||||
"""Subclass of MetadtaPopulator that populates metadata to a model buffer.
|
||||
|
||||
This class is used to populate metadata into a in-memory model buffer. As we
|
||||
use Zip API to concatenate associated files after tflite model file, the
|
||||
populating operation is developed based on a model file. For in-memory model
|
||||
buffer, we create a tempfile to serve the populating operation. This class is
|
||||
then used to generate this tempfile, and delete the file when the
|
||||
MetadataPopulator object is deleted.
|
||||
"""
|
||||
|
||||
def __init__(self, model_buf):
|
||||
"""Constructor for _MetadataPopulatorWithBuffer.
|
||||
|
||||
Args:
|
||||
model_buf: TensorFlow Lite model buffer in bytearray.
|
||||
|
||||
Raises:
|
||||
ValueError: model_buf is empty.
|
||||
"""
|
||||
if not model_buf:
|
||||
raise ValueError("model_buf cannot be empty.")
|
||||
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
model_file = temp.name
|
||||
|
||||
with open(model_file, "wb") as f:
|
||||
f.write(model_buf)
|
||||
|
||||
MetadataPopulator.__init__(self, model_file)
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor of _MetadataPopulatorWithBuffer.
|
||||
|
||||
Deletes the tempfile.
|
||||
"""
|
||||
if os.path.exists(self._model_file):
|
||||
os.remove(self._model_file)
|
||||
|
||||
|
||||
class MetadataDisplayer(object):
|
||||
"""Displays metadata and associated file info in human-readable format."""
|
||||
|
||||
def __init__(self, model_file, metadata_file, associated_file_list):
|
||||
"""Constructor for MetadataDisplayer.
|
||||
|
||||
Args:
|
||||
model_file: valid path to the model file.
|
||||
metadata_file: valid path to the metadata file.
|
||||
associated_file_list: list of associate files in the model file.
|
||||
"""
|
||||
self._model_file = model_file
|
||||
self._metadata_file = metadata_file
|
||||
self._associated_file_list = associated_file_list
|
||||
|
||||
@classmethod
|
||||
def with_model_file(cls, model_file):
|
||||
"""Creates a MetadataDisplayer object for the model file.
|
||||
|
||||
Args:
|
||||
model_file: valid path to a TensorFlow Lite model file.
|
||||
|
||||
Returns:
|
||||
MetadataDisplayer object.
|
||||
|
||||
Raises:
|
||||
IOError: File not found.
|
||||
ValueError: The model does not have metadata.
|
||||
"""
|
||||
_assert_exist(model_file)
|
||||
metadata_file = cls._save_temporary_metadata_file(model_file)
|
||||
associated_file_list = cls._parse_packed_associted_file_list(model_file)
|
||||
return cls(model_file, metadata_file, associated_file_list)
|
||||
|
||||
def export_metadata_json_file(self, export_dir):
|
||||
"""Converts the metadata into a json file.
|
||||
|
||||
Args:
|
||||
export_dir: the directory that the json file will be exported to. The json
|
||||
file will be named after the model file, but with ".json" as extension.
|
||||
"""
|
||||
subprocess.check_call([
|
||||
_FLATC_BINARY_PATH, "-o", export_dir, "--json",
|
||||
_FLATC_TFLITE_METADATA_SCHEMA_FILE, "--", self._metadata_file,
|
||||
"--strict-json"
|
||||
])
|
||||
temp_name = os.path.join(
|
||||
export_dir,
|
||||
os.path.splitext(os.path.basename(self._metadata_file))[0] + ".json")
|
||||
expected_name = os.path.join(
|
||||
export_dir,
|
||||
os.path.splitext(os.path.basename(self._model_file))[0] + ".json")
|
||||
os.rename(temp_name, expected_name)
|
||||
|
||||
def get_packed_associated_file_list(self):
|
||||
"""Returns a list of associated files that are packed in the model.
|
||||
|
||||
Returns:
|
||||
A name list of associated files.
|
||||
"""
|
||||
return copy.deepcopy(self._associated_file_list)
|
||||
|
||||
@staticmethod
|
||||
def _save_temporary_metadata_file(model_file):
|
||||
"""Saves the metadata in the model file to a temporary file.
|
||||
|
||||
Args:
|
||||
model_file: valid path to the model file.
|
||||
|
||||
Returns:
|
||||
Path to the metadata temporary file.
|
||||
|
||||
Raises:
|
||||
ValueError: The model does not have metadata.
|
||||
"""
|
||||
with open(model_file, "rb") as f:
|
||||
model_buf = f.read()
|
||||
|
||||
tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
|
||||
|
||||
# Gets metadata from the model file.
|
||||
for i in range(tflite_model.MetadataLength()):
|
||||
meta = tflite_model.Metadata(i)
|
||||
if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME:
|
||||
buffer_index = meta.Buffer()
|
||||
metadata = tflite_model.Buffers(buffer_index)
|
||||
metadata_buf = metadata.DataAsNumpy().tobytes()
|
||||
# Creates a temporary file to store the metadata.
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
metadata_file = temp.name
|
||||
# Saves the metadata into the temporary file.
|
||||
with open(metadata_file, "wb") as f:
|
||||
f.write(metadata_buf)
|
||||
return metadata_file
|
||||
|
||||
raise ValueError("The model does not have metadata.")
|
||||
|
||||
@staticmethod
|
||||
def _parse_packed_associted_file_list(model_file):
|
||||
"""Gets a list of associated files packed to the model file.
|
||||
|
||||
Args:
|
||||
model_file: valid path to the model file.
|
||||
|
||||
Returns:
|
||||
List of packed associated files.
|
||||
"""
|
||||
if not zipfile.is_zipfile(model_file):
|
||||
return []
|
||||
|
||||
with zipfile.ZipFile(model_file, "r") as zf:
|
||||
return zf.namelist()
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor of MetadataDisplayer.
|
||||
|
||||
Deletes the tempfile.
|
||||
"""
|
||||
if os.path.exists(self._metadata_file):
|
||||
os.remove(self._metadata_file)
|
||||
|
||||
|
||||
def _assert_exist(filename):
|
||||
"""Checks if a file exists."""
|
||||
if not os.path.exists(filename):
|
||||
raise IOError("File, '{0}', does not exist.".format(filename))
|
@ -0,0 +1,499 @@
|
||||
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
namespace tflite;
|
||||
|
||||
// TFLite metadata contains both human readable and machine readable information
|
||||
// about what the model does and how to use the model. It can be used as a
|
||||
// README file, which elaborates the details of the model, each input/ouput
|
||||
// tensor, and each associated file.
|
||||
//
|
||||
// An important use case of TFLite metadata is the TFLite codegen tool, which
|
||||
// automatically generates the model interface based on the properties of the
|
||||
// model and the tensors. The model interface provides high-level APIs to
|
||||
// interact with the model, such as preprocessing the input data and running
|
||||
// inferences.
|
||||
//
|
||||
// Entries marked with "<Codegen usage>" are used in TFLite codegen tool to
|
||||
// generate the model interface. It is recommended to fill in at least those
|
||||
// enties to boost the codegen performance.
|
||||
|
||||
// This corresponds to the schema version.
|
||||
file_identifier "M001";
|
||||
// File extension of any written files.
|
||||
file_extension "tflitemeta";
|
||||
|
||||
enum AssociatedFileType : byte {
|
||||
UNKNOWN = 0,
|
||||
// Files such as readme.txt
|
||||
DESCRIPTIONS = 1,
|
||||
|
||||
// Contains labels that annotate certain axis of the tensor. For example,
|
||||
// the label file in image classification. Those labels annotate the
|
||||
// the output tensor, such that each value in the output tensor is the
|
||||
// probability of that corresponding category specified by the label.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// If an output tensor has an associated file as TENSOR_AXIS_LABELS, return
|
||||
// the output as a mapping between the labels and probability in the model
|
||||
// interface.
|
||||
// If multiple files of the same type are present, the first one is used by
|
||||
// default; additional ones are to be distinguished from one another by their
|
||||
// specified locale.
|
||||
TENSOR_AXIS_LABELS = 2,
|
||||
|
||||
// Contains labels that tensor values correspond to. For example, in
|
||||
// the object detection model, one of the output tensors is the detected
|
||||
// classes. And each value in the tensor refers to the index of label in the
|
||||
// category label file.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// If an output tensor has an associated file as TENSOR_VALUE_LABELS, convert
|
||||
// the tensor values into labels, and return a list of string as the output.
|
||||
// If multiple files of the same type are present, the first one is used by
|
||||
// default; additional ones are to be distinguished from one another by their
|
||||
// specified locale.
|
||||
TENSOR_VALUE_LABELS = 3,
|
||||
|
||||
// Contains sigmoid-based score calibration parameters, formatted as CSV.
|
||||
// Lines contain for each index of an output tensor the scale, slope, offset
|
||||
// and min_score parameters to be used for sigmoid fitting (in this order and
|
||||
// in `strtof`-compatible [1] format).
|
||||
// A line may be left empty to default calibrated scores for this index to
|
||||
// default_score. See documentation for ScoreCalibrationOptions for details.
|
||||
//
|
||||
// [1]: https://en.cppreference.com/w/c/string/byte/strtof
|
||||
TENSOR_AXIS_SCORE_CALIBRATION = 4,
|
||||
}
|
||||
|
||||
table AssociatedFile {
|
||||
// Name of this file. Need to be exact the same as the name of the actual file
|
||||
// packed into the TFLite model as a zip file.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Locates to the actual file in the TFLite model.
|
||||
name:string;
|
||||
|
||||
// A description of what the file is.
|
||||
description:string;
|
||||
|
||||
// Type of the associated file. There may be special pre/post processing for
|
||||
// some types. For example in image classification, a label file of the output
|
||||
// will be used to convert object index into string.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Determines how to process the corresponding tensor.
|
||||
type:AssociatedFileType;
|
||||
|
||||
// An optional locale for this associated file (if applicable). It is
|
||||
// recommended to use an ISO 639-1 letter code (e.g. "en" for English),
|
||||
// optionally completed by a two letter region code (e.g. "en-US" for US
|
||||
// English and "en-CA" for Canadian English).
|
||||
// Leverage this in order to specify e.g multiple label files translated in
|
||||
// different languages.
|
||||
locale:string;
|
||||
}
|
||||
|
||||
// The basic content type for all tensors.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Input feature tensors:
|
||||
// 1. Generates the method to load data from a TensorBuffer.
|
||||
// 2. Creates the preprocessing logic. The default processing pipeline is:
|
||||
// [NormalizeOp, QuantizeOp].
|
||||
// Output feature tensors:
|
||||
// 1. Generates the method to return the output data to a TensorBuffer.
|
||||
// 2. Creates the post-processing logic. The default processing pipeline is:
|
||||
// [DeQuantizeOp].
|
||||
table FeatureProperties {
|
||||
}
|
||||
|
||||
// The type of color space of an image.
|
||||
enum ColorSpaceType : byte {
|
||||
UNKNOWN = 0,
|
||||
RGB = 1,
|
||||
GRAYSCALE = 2,
|
||||
}
|
||||
|
||||
table ImageSize {
|
||||
width:uint;
|
||||
height:uint;
|
||||
}
|
||||
|
||||
// The properties for image tensors.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Input image tensors:
|
||||
// 1. Generates the method to load an image from a TensorImage.
|
||||
// 2. Creates the preprocessing logic. The default processing pipeline is:
|
||||
// [ResizeOp, NormalizeOp, QuantizeOp].
|
||||
// Output image tensors:
|
||||
// 1. Generates the method to return the output data to a TensorImage.
|
||||
// 2. Creates the post-processing logic. The default processing pipeline is:
|
||||
// [DeQuantizeOp].
|
||||
table ImageProperties {
|
||||
// The color space of the image.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Determines how to convert the color space of a given image from users.
|
||||
color_space:ColorSpaceType;
|
||||
|
||||
// Indicates the default value of image width and height if the tensor shape
|
||||
// is dynamic. For fixed-size tensor, this size will be consistent with the
|
||||
// expected size.
|
||||
default_size:ImageSize;
|
||||
}
|
||||
|
||||
// The properties for tensors representing bounding boxes.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Input image tensors: NA.
|
||||
// Output image tensors: parses the values into a data stucture that represents
|
||||
// bounding boxes. For example, in the generated wrapper for Android, it returns
|
||||
// the output as android.graphics.Rect objects.
|
||||
enum BoundingBoxType : byte {
|
||||
UNKNOWN = 0,
|
||||
// Represents the bounding box by using the combination of boundaries,
|
||||
// {left, top, right, bottom}.
|
||||
// The default order is {left, top, right, bottom}. Other orders can be
|
||||
// indicated by BoundingBoxProperties.index.
|
||||
BOUNDARIES = 1,
|
||||
|
||||
// Represents the bounding box by using the upper_left corner, width and
|
||||
// height.
|
||||
// The default order is {upper_left_x, upper_left_y, width, height}. Other
|
||||
// orders can be indicated by BoundingBoxProperties.index.
|
||||
UPPER_LEFT = 2,
|
||||
|
||||
// Represents the bounding box by using the center of the box, width and
|
||||
// height. The default order is {center_x, center_y, width, height}. Other
|
||||
// orders can be indicated by BoundingBoxProperties.index.
|
||||
CENTER = 3,
|
||||
|
||||
}
|
||||
|
||||
enum CoordinateType : byte {
|
||||
// The coordinates are float values from 0 to 1.
|
||||
RATIO = 0,
|
||||
// The coordinates are integers.
|
||||
PIXEL = 1,
|
||||
}
|
||||
|
||||
table BoundingBoxProperties {
|
||||
// Denotes the order of the elements defined in each bounding box type. An
|
||||
// empty index array represent the defualt order of each bounding box type.
|
||||
// For example, to denote the default order of BOUNDARIES, {left, top, right,
|
||||
// bottom}, the index should be {0, 1, 2, 3}. To denote the order {left,
|
||||
// right, top, bottom}, the order should be {0, 2, 1, 3}.
|
||||
//
|
||||
// The index array can be applied to all bounding box types to adjust the
|
||||
// order of their corresponding underlying elements.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Indicates how to parse the bounding box values.
|
||||
index:[uint];
|
||||
|
||||
// <Codegen usage>:
|
||||
// Indicates how to parse the bounding box values.
|
||||
type:BoundingBoxType;
|
||||
|
||||
// <Codegen usage>:
|
||||
// Indicates how to convert the bounding box back to the original image in
|
||||
// pixels.
|
||||
coordinate_type:CoordinateType;
|
||||
}
|
||||
|
||||
union ContentProperties {
|
||||
FeatureProperties,
|
||||
ImageProperties,
|
||||
BoundingBoxProperties,
|
||||
}
|
||||
|
||||
table ValueRange {
|
||||
min:int;
|
||||
max:int;
|
||||
}
|
||||
|
||||
table Content {
|
||||
// The properties that the content may have, indicating the type of the
|
||||
// Content.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Indicates how to process the tensor.
|
||||
content_properties:ContentProperties;
|
||||
|
||||
// The range of dimensions that the content corresponds to. A NULL
|
||||
// "range" indicates that the content uses up all dimensions,
|
||||
// except the batch axis if applied.
|
||||
//
|
||||
// Here are all the possible situations of how a tensor is composed.
|
||||
// Case 1: The tensor is a single object, such as an image.
|
||||
// For example, the input of an image classifier
|
||||
// (https://www.tensorflow.org/lite/models/image_classification/overview),
|
||||
// a tensor of shape [1, 224, 224, 3]. Dimensions 1 to 3 correspond to the
|
||||
// image. Since dimension 0 is a batch axis, which can be ignored,
|
||||
// "range" can be left as NULL.
|
||||
//
|
||||
// Case 2: The tensor contains multiple instances of the same object.
|
||||
// For example, the output tensor of detected bounding boxes of an object
|
||||
// detection model
|
||||
// (https://www.tensorflow.org/lite/models/object_detection/overview).
|
||||
// The tensor shape is [1, 10, 4]. Here is the what the three dimensions
|
||||
// represent for:
|
||||
// dimension 0: the batch axis.
|
||||
// dimension 1: the 10 objects detected with the highest confidence.
|
||||
// dimension 2: the bounding boxes of the 10 detected objects.
|
||||
// The tensor is essentially 10 bounding boxes. In this case,
|
||||
// "range" should be {min=2; max=2;}.
|
||||
// Another example is the pose estimation model
|
||||
// (https://www.tensorflow.org/lite/models/pose_estimation/overview).
|
||||
// The output tensor of heatmaps is in the shape of [1, 9, 9, 17].
|
||||
// Here is the what the four dimensions represent for:
|
||||
// dimension 0: the batch axis.
|
||||
// dimension 1/2: the heatmap image.
|
||||
// dimension 3: 17 body parts of a person.
|
||||
// Even though the last axis is body part, the real content of this tensor is
|
||||
// the heatmap. "range" should be [min=2; max=3].
|
||||
//
|
||||
// Case 3: The tensor contains multiple different objects. (Not supported by
|
||||
// Content at this point).
|
||||
// Sometimes a tensor may contain multiple different objects, thus different
|
||||
// contents. It is very common for regression models. For example, a model
|
||||
// to predict the fuel efficiency
|
||||
// (https://www.tensorflow.org/tutorials/keras/regression).
|
||||
// The input tensor has shape [1, 9], consisting of 9 features, such as
|
||||
// "Cylinders", "Displacement", "Weight", etc. In this case, dimension 1
|
||||
// contains 9 different contents. However, since these sub-dimension objects
|
||||
// barely need to be specifically processed, their contents are not recorded
|
||||
// in the metadata. Through, the name of each dimension can be set through
|
||||
// TensorMetadata.dimension_names.
|
||||
//
|
||||
// Note that if it is not case 3, a tensor can only have one content type.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Case 1: return a processed single object of certain content type.
|
||||
// Case 2: return a list of processed objects of certain content type. The
|
||||
// generated model interface have API to random access those objects from
|
||||
// the output.
|
||||
range:ValueRange;
|
||||
}
|
||||
|
||||
// Parameters that are used when normalizing the tensor.
|
||||
table NormalizationOptions{
|
||||
// mean and std are normalization parameters. Tensor values are normailzed
|
||||
// per-channelly by,
|
||||
// (x - mean) / std.
|
||||
// For example, a float MobileNet model will have
|
||||
// mean = 127.5f and std = 127.5f.
|
||||
// A quantized MobileNet model will have
|
||||
// mean = 0.0f and std = 1.0f.
|
||||
// If there is only one value in mean or std, we'll propogate the value to
|
||||
// all channels.
|
||||
|
||||
// Per-channel mean of the possible values used in normalization.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Apply normalization to input tensors accordingly.
|
||||
mean:[float];
|
||||
|
||||
// Per-channel standard dev. of the possible values used in normalization.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Apply normalization to input tensors accordingly.
|
||||
std:[float];
|
||||
}
|
||||
|
||||
// The different possible score transforms to apply to uncalibrated scores
|
||||
// before applying score calibration.
|
||||
enum ScoreTransformationType : byte {
|
||||
// Identity function: g(x) = x.
|
||||
IDENTITY = 0,
|
||||
// Log function: g(x) = log(x).
|
||||
LOG = 1,
|
||||
// Inverse logistic function: g(x) = log(x) - log(1-x).
|
||||
INVERSE_LOGISTIC = 2,
|
||||
}
|
||||
|
||||
// Options to perform score calibration on an output tensor through sigmoid
|
||||
// functions. One of the main purposes of score calibration is to make scores
|
||||
// across classes comparable, so that a common threshold can be used for all
|
||||
// output classes. This is meant for models producing class predictions as
|
||||
// output, e.g. image classification or detection models.
|
||||
//
|
||||
// For each index in the output tensor, this applies:
|
||||
// * `f(x) = scale / (1 + e^-(slope*g(x)+offset))` if `x > min_score`,
|
||||
// * `f(x) = default_score` otherwise or if no scale, slope, offset and
|
||||
// min_score have been specified.
|
||||
// Where:
|
||||
// * scale, slope, offset and min_score are index-specific parameters
|
||||
// * g(x) is an index-independent transform among those defined in
|
||||
// ScoreTransformationType
|
||||
// * default_score is an index-independent parameter.
|
||||
// An AssociatedFile with type TANSOR_AXIS_SCORE_CALIBRATION specifying the
|
||||
// index-specific parameters must be associated with the corresponding
|
||||
// TensorMetadata for score calibration be applied.
|
||||
table ScoreCalibrationOptions {
|
||||
// The function to use for transforming the uncalibrated score before
|
||||
// applying score calibration.
|
||||
score_transformation:ScoreTransformationType;
|
||||
|
||||
// The default calibrated score to apply if the uncalibrated score is
|
||||
// below min_score or if no parameters were specified for a given index.
|
||||
default_score:float;
|
||||
}
|
||||
|
||||
// Performs thresholding on output tensor values, in order to filter out
|
||||
// low-confidence results.
|
||||
table ScoreThresholdingOptions {
|
||||
// The recommended global threshold below which results are considered
|
||||
// low-confidence and should be filtered out.
|
||||
global_score_threshold:float;
|
||||
}
|
||||
|
||||
// Options that are used when processing the tensor.
|
||||
union ProcessUnitOptions {
|
||||
NormalizationOptions,
|
||||
ScoreCalibrationOptions,
|
||||
ScoreThresholdingOptions,
|
||||
}
|
||||
|
||||
// A process unit that is used to process the tensor out-of-graph.
|
||||
table ProcessUnit {
|
||||
options:ProcessUnitOptions;
|
||||
}
|
||||
|
||||
|
||||
// Statistics to describe a tensor.
|
||||
table Stats {
|
||||
// Max and min are not currently used in tflite.support codegen. They mainly
|
||||
// serve as references for users to better understand the model. They can also
|
||||
// be used to validate model pre/post processing results.
|
||||
// If there is only one value in max or min, we'll propogate the value to
|
||||
// all channels.
|
||||
|
||||
// Per-channel maximum value of the tensor.
|
||||
max:[float];
|
||||
|
||||
// Per-channel minimum value of the tensor.
|
||||
min:[float];
|
||||
}
|
||||
|
||||
// Detailed information of an input or output tensor.
|
||||
table TensorMetadata {
|
||||
// Name of the tensor.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// The name of this tensor in the generated model interface.
|
||||
name:string;
|
||||
|
||||
// A description of the tensor.
|
||||
description:string;
|
||||
|
||||
// A list of names of the dimensions in this tentor. The length of
|
||||
// dimension_names need to match the number of dimensions in this tensor.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// The name of each dimension in the generated model interface. See "Case 2"
|
||||
// in the comments of Content.range.
|
||||
dimension_names:[string];
|
||||
|
||||
// The content that represents this tensor.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Determines how to process this tensor. See each item in ContentProperties
|
||||
// for the default process units that will be applied to the tensor.
|
||||
content:Content;
|
||||
|
||||
// The process units that are used to process the tensor out-of-graph.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Contains the parameters of the default processing pipeline for each content
|
||||
// type, such as the normalization parameters in all content types. See the
|
||||
// items under ContentProperties for the details of the default processing
|
||||
// pipeline.
|
||||
process_units:[ProcessUnit];
|
||||
|
||||
// The statistics of the tensor values.
|
||||
stats:Stats;
|
||||
|
||||
// A list of associated files of this tensor.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Contains processing parameters of this tensor, such as normalization.
|
||||
associated_files:[AssociatedFile];
|
||||
}
|
||||
|
||||
table SubGraphMetadata {
|
||||
// Name of the subgraph.
|
||||
//
|
||||
// Note that, since TFLite only support one subgraph at this moment, the
|
||||
// Codegen tool will use the name in ModelMetadata in the generated model
|
||||
// interface.
|
||||
name:string;
|
||||
|
||||
// A description explains details about what the subgraph does.
|
||||
description:string;
|
||||
|
||||
// Metadata of all input tensors used in this subgraph.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Determines how to process the inputs.
|
||||
input_tensor_metadata:[TensorMetadata];
|
||||
|
||||
// Metadata of all output tensors used in this subgraph.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Determines how to process the outputs.
|
||||
output_tensor_metadata:[TensorMetadata];
|
||||
|
||||
// A list of associated files of this subgraph.
|
||||
associated_files:[AssociatedFile];
|
||||
}
|
||||
|
||||
table ModelMetadata {
|
||||
// Name of the model.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// The name of the model in the generated model interface.
|
||||
name:string;
|
||||
|
||||
// Model description in schema.
|
||||
description:string;
|
||||
|
||||
// Version of the model that specified by model creators.
|
||||
version:string;
|
||||
|
||||
// Noted that, the minimum required TFLite runtime version that the model is
|
||||
// compatible with, has already been added as a metadata entry in tflite
|
||||
// schema. We'll decide later if we want to move it here, and keep it with
|
||||
// other metadata entries.
|
||||
|
||||
// Metadata of all the subgraphs of the model. The 0th is assumed to be the
|
||||
// main subgraph.
|
||||
//
|
||||
// <Codegen usage>:
|
||||
// Determines how to process the inputs and outputs.
|
||||
subgraph_metadata:[SubGraphMetadata];
|
||||
|
||||
// The person who creates this model.
|
||||
author:string;
|
||||
|
||||
// Licenses that may apply to this model.
|
||||
license:string;
|
||||
|
||||
// A list of associated files of this model.
|
||||
associated_files:[AssociatedFile];
|
||||
}
|
||||
|
||||
root_type ModelMetadata;
|
381
tensorflow/lite/experimental/support/metadata/metadata_test.py
Normal file
381
tensorflow/lite/experimental/support/metadata/metadata_test.py
Normal file
@ -0,0 +1,381 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.lite.experimental.support.metadata.metadata."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import six
|
||||
|
||||
from flatbuffers.python import flatbuffers
|
||||
from tensorflow.lite.experimental.support.metadata import metadata as _metadata
|
||||
from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb
|
||||
from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MetadataTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(MetadataTest, self).setUp()
|
||||
self._invalid_model_buf = None
|
||||
self._invalid_file = "not_existed_file"
|
||||
self._empty_model_buf = self._create_empty_model_buf()
|
||||
self._empty_model_file = self.create_tempfile().full_path
|
||||
with open(self._empty_model_file, "wb") as f:
|
||||
f.write(self._empty_model_buf)
|
||||
self._model_file = self._create_model_file_with_metadata_and_buf_fields()
|
||||
self._metadata_file = self._create_metadata_file()
|
||||
self._file1 = self.create_tempfile("file1").full_path
|
||||
self._file2 = self.create_tempfile("file2").full_path
|
||||
self._file3 = self.create_tempfile("file3").full_path
|
||||
|
||||
def _create_empty_model_buf(self):
|
||||
model = _schema_fb.ModelT()
|
||||
model_builder = flatbuffers.Builder(0)
|
||||
model_builder.Finish(
|
||||
model.Pack(model_builder),
|
||||
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
|
||||
return model_builder.Output()
|
||||
|
||||
def _create_model_file_with_metadata_and_buf_fields(self):
|
||||
metadata_field = _schema_fb.MetadataT()
|
||||
metadata_field.name = "meta"
|
||||
buffer_field = _schema_fb.BufferT()
|
||||
model = _schema_fb.ModelT()
|
||||
model.metadata = [metadata_field, metadata_field]
|
||||
model.buffers = [buffer_field, buffer_field, buffer_field]
|
||||
model_builder = flatbuffers.Builder(0)
|
||||
model_builder.Finish(
|
||||
model.Pack(model_builder),
|
||||
_metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
|
||||
|
||||
mnodel_file = self.create_tempfile().full_path
|
||||
with open(mnodel_file, "wb") as f:
|
||||
f.write(model_builder.Output())
|
||||
|
||||
return mnodel_file
|
||||
|
||||
def _create_metadata_file(self):
|
||||
associated_file1 = _metadata_fb.AssociatedFileT()
|
||||
associated_file1.name = b"file1"
|
||||
associated_file2 = _metadata_fb.AssociatedFileT()
|
||||
associated_file2.name = b"file2"
|
||||
self.expected_recorded_files = [
|
||||
six.ensure_str(associated_file1.name),
|
||||
six.ensure_str(associated_file2.name)
|
||||
]
|
||||
|
||||
output_meta = _metadata_fb.TensorMetadataT()
|
||||
output_meta.associatedFiles = [associated_file2]
|
||||
subgraph = _metadata_fb.SubGraphMetadataT()
|
||||
subgraph.outputTensorMetadata = [output_meta]
|
||||
|
||||
model_meta = _metadata_fb.ModelMetadataT()
|
||||
model_meta.name = "Mobilenet_quantized"
|
||||
model_meta.associatedFiles = [associated_file1]
|
||||
model_meta.subgraphMetadata = [subgraph]
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(
|
||||
model_meta.Pack(b),
|
||||
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
|
||||
metadata_file = self.create_tempfile().full_path
|
||||
with open(metadata_file, "wb") as f:
|
||||
f.write(b.Output())
|
||||
return metadata_file
|
||||
|
||||
|
||||
class MetadataPopulatorTest(MetadataTest):
|
||||
|
||||
def testToValidModelFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_file(
|
||||
self._empty_model_file)
|
||||
self.assertIsInstance(populator, _metadata.MetadataPopulator)
|
||||
|
||||
def testToInvalidModelFile(self):
|
||||
with self.assertRaises(IOError) as error:
|
||||
_metadata.MetadataPopulator.with_model_file(self._invalid_file)
|
||||
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
|
||||
str(error.exception))
|
||||
|
||||
def testToValidModelBuffer(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
self.assertIsInstance(populator, _metadata.MetadataPopulator)
|
||||
|
||||
def testToInvalidModelBuffer(self):
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf)
|
||||
self.assertEqual("model_buf cannot be empty.", str(error.exception))
|
||||
|
||||
def testSinglePopulateAssociatedFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
populator.load_associated_files([self._file1])
|
||||
populator.populate()
|
||||
|
||||
packed_files = populator.get_packed_associated_file_list()
|
||||
expected_packed_files = [os.path.basename(self._file1)]
|
||||
self.assertEqual(set(packed_files), set(expected_packed_files))
|
||||
|
||||
def testRepeatedPopulateAssociatedFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_file(
|
||||
self._empty_model_file)
|
||||
populator.load_associated_files([self._file1, self._file2])
|
||||
# Loads file2 multiple times.
|
||||
populator.load_associated_files([self._file2])
|
||||
populator.populate()
|
||||
|
||||
packed_files = populator.get_packed_associated_file_list()
|
||||
expected_packed_files = [
|
||||
os.path.basename(self._file1),
|
||||
os.path.basename(self._file2)
|
||||
]
|
||||
self.assertEqual(len(packed_files), 2)
|
||||
self.assertEqual(set(packed_files), set(expected_packed_files))
|
||||
|
||||
# Check if the model buffer read from file is the same as that read from
|
||||
# get_model_buffer().
|
||||
with open(self._empty_model_file, "rb") as f:
|
||||
model_buf_from_file = f.read()
|
||||
model_buf_from_getter = populator.get_model_buffer()
|
||||
self.assertEqual(model_buf_from_file, model_buf_from_getter)
|
||||
|
||||
def testPopulateInvalidAssociatedFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
with self.assertRaises(IOError) as error:
|
||||
populator.load_associated_files([self._invalid_file])
|
||||
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
|
||||
str(error.exception))
|
||||
|
||||
def testPopulatePackedAssociatedFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
populator.load_associated_files([self._file1])
|
||||
populator.populate()
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.load_associated_files([self._file1])
|
||||
populator.populate()
|
||||
self.assertEqual(
|
||||
"File, '{0}', has already been packed.".format(
|
||||
os.path.basename(self._file1)), str(error.exception))
|
||||
|
||||
def testGetPackedAssociatedFileList(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
packed_files = populator.get_packed_associated_file_list()
|
||||
self.assertEqual(packed_files, [])
|
||||
|
||||
def testPopulateMetadataFileToEmptyModelFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_file(
|
||||
self._empty_model_file)
|
||||
populator.load_metadata_file(self._metadata_file)
|
||||
populator.load_associated_files([self._file1, self._file2])
|
||||
populator.populate()
|
||||
|
||||
with open(self._empty_model_file, "rb") as f:
|
||||
model_buf_from_file = f.read()
|
||||
model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
|
||||
metadata_field = model.Metadata(0)
|
||||
self.assertEqual(
|
||||
six.ensure_str(metadata_field.Name()),
|
||||
six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
|
||||
|
||||
buffer_index = metadata_field.Buffer()
|
||||
buffer_data = model.Buffers(buffer_index)
|
||||
metadata_buf_np = buffer_data.DataAsNumpy()
|
||||
metadata_buf = metadata_buf_np.tobytes()
|
||||
with open(self._metadata_file, "rb") as f:
|
||||
expected_metadata_buf = bytearray(f.read())
|
||||
self.assertEqual(metadata_buf, expected_metadata_buf)
|
||||
|
||||
recorded_files = populator.get_recorded_associated_file_list()
|
||||
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
|
||||
|
||||
# Up to now, we've proved the correctness of the model buffer that read from
|
||||
# file. Then we'll test if get_model_buffer() gives the same model buffer.
|
||||
model_buf_from_getter = populator.get_model_buffer()
|
||||
self.assertEqual(model_buf_from_file, model_buf_from_getter)
|
||||
|
||||
def testPopulateMetadataFileWithoutAssociatedFiles(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_file(
|
||||
self._empty_model_file)
|
||||
populator.load_metadata_file(self._metadata_file)
|
||||
populator.load_associated_files([self._file1])
|
||||
# Suppose to populate self._file2, because it is recorded in the metadta.
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.populate()
|
||||
self.assertEqual(("File, '{0}', is recorded in the metadata, but has "
|
||||
"not been loaded into the populator.").format(
|
||||
os.path.basename(self._file2)), str(error.exception))
|
||||
|
||||
def _assert_golden_metadata(self, model_file):
|
||||
with open(model_file, "rb") as f:
|
||||
model_buf_from_file = f.read()
|
||||
model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
|
||||
# There are two elements in model.Metadata array before the population.
|
||||
# Metadata should be packed to the third element in the array.
|
||||
metadata_field = model.Metadata(2)
|
||||
self.assertEqual(
|
||||
six.ensure_str(metadata_field.Name()),
|
||||
six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))
|
||||
|
||||
buffer_index = metadata_field.Buffer()
|
||||
buffer_data = model.Buffers(buffer_index)
|
||||
metadata_buf_np = buffer_data.DataAsNumpy()
|
||||
metadata_buf = metadata_buf_np.tobytes()
|
||||
with open(self._metadata_file, "rb") as f:
|
||||
expected_metadata_buf = bytearray(f.read())
|
||||
self.assertEqual(metadata_buf, expected_metadata_buf)
|
||||
|
||||
def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self):
|
||||
# First, creates a dummy metadata. Populates it and the associated files
|
||||
# into the model.
|
||||
model_meta = _metadata_fb.ModelMetadataT()
|
||||
model_meta.name = "Mobilenet_quantized"
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(
|
||||
model_meta.Pack(b),
|
||||
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
metadata_buf = b.Output()
|
||||
|
||||
populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator1.load_metadata_buffer(metadata_buf)
|
||||
populator1.load_associated_files([self._file1, self._file2])
|
||||
populator1.populate()
|
||||
|
||||
# Then, populates the metadata again.
|
||||
populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator2.load_metadata_file(self._metadata_file)
|
||||
populator2.populate()
|
||||
|
||||
# Tests if the metadata is populated correctly.
|
||||
self._assert_golden_metadata(self._model_file)
|
||||
|
||||
def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||
populator.load_metadata_file(self._metadata_file)
|
||||
populator.load_associated_files([self._file1, self._file2])
|
||||
populator.populate()
|
||||
|
||||
# Tests if the metadata is populated correctly.
|
||||
self._assert_golden_metadata(self._model_file)
|
||||
|
||||
recorded_files = populator.get_recorded_associated_file_list()
|
||||
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
|
||||
|
||||
# Up to now, we've proved the correctness of the model buffer that read from
|
||||
# file. Then we'll test if get_model_buffer() gives the same model buffer.
|
||||
with open(self._model_file, "rb") as f:
|
||||
model_buf_from_file = f.read()
|
||||
model_buf_from_getter = populator.get_model_buffer()
|
||||
self.assertEqual(model_buf_from_file, model_buf_from_getter)
|
||||
|
||||
def testPopulateInvalidMetadataFile(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
with self.assertRaises(IOError) as error:
|
||||
populator.load_metadata_file(self._invalid_file)
|
||||
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
|
||||
str(error.exception))
|
||||
|
||||
def testPopulateInvalidMetadataBuffer(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
populator.load_metadata_buffer([])
|
||||
self.assertEqual("The metadata to be populated is empty.",
|
||||
str(error.exception))
|
||||
|
||||
def testGetModelBufferBeforePopulatingData(self):
|
||||
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||
self._empty_model_buf)
|
||||
model_buf = populator.get_model_buffer()
|
||||
expected_model_buf = self._empty_model_buf
|
||||
self.assertEqual(model_buf, expected_model_buf)
|
||||
|
||||
|
||||
class MetadataDisplayerTest(MetadataTest):
|
||||
|
||||
def setUp(self):
|
||||
super(MetadataDisplayerTest, self).setUp()
|
||||
self._model_file = self._create_model_with_metadata_and_associated_files()
|
||||
|
||||
def _create_model_with_metadata_and_associated_files(self):
|
||||
model_buf = self._create_empty_model_buf()
|
||||
model_file = self.create_tempfile().full_path
|
||||
with open(model_file, "wb") as f:
|
||||
f.write(model_buf)
|
||||
|
||||
populator = _metadata.MetadataPopulator.with_model_file(model_file)
|
||||
populator.load_metadata_file(self._metadata_file)
|
||||
populator.load_associated_files([self._file1, self._file2])
|
||||
populator.populate()
|
||||
return model_file
|
||||
|
||||
def test_load_model_file_invalidModelFile_throwsException(self):
|
||||
with self.assertRaises(IOError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_file(self._invalid_file)
|
||||
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
|
||||
str(error.exception))
|
||||
|
||||
def test_load_model_file_modelWithoutMetadata_throwsException(self):
|
||||
with self.assertRaises(ValueError) as error:
|
||||
_metadata.MetadataDisplayer.with_model_file(self._empty_model_file)
|
||||
self.assertEqual("The model does not have metadata.", str(error.exception))
|
||||
|
||||
def test_load_model_file_modelWithMetadata(self):
|
||||
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
|
||||
self.assertIsInstance(displayer, _metadata.MetadataDisplayer)
|
||||
|
||||
def test_export_metadata_json_file_modelWithMetadata(self):
|
||||
export_dir = self.create_tempdir().full_path
|
||||
|
||||
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
|
||||
displayer.export_metadata_json_file(export_dir)
|
||||
|
||||
# Verifies the generated json file.
|
||||
golden_json_file_path = resource_loader.get_path_to_datafile(
|
||||
"testdata/golden_json.json")
|
||||
json_file_path = os.path.join(
|
||||
export_dir,
|
||||
os.path.splitext(os.path.basename(self._model_file))[0] + ".json")
|
||||
with open(json_file_path, "r") as json_file, open(golden_json_file_path,
|
||||
"r") as golden_json_file:
|
||||
json_contents = json_file.read()
|
||||
golden_json_contents = golden_json_file.read()
|
||||
self.assertEqual(json_contents, golden_json_contents)
|
||||
|
||||
def test_get_packed_associated_file_list_modelWithMetadata(self):
|
||||
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
|
||||
packed_files = displayer.get_packed_associated_file_list()
|
||||
|
||||
expected_packed_files = [
|
||||
os.path.basename(self._file1),
|
||||
os.path.basename(self._file2)
|
||||
]
|
||||
self.assertEqual(len(packed_files), 2)
|
||||
self.assertEqual(set(packed_files), set(expected_packed_files))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
21
tensorflow/lite/experimental/support/metadata/testdata/golden_json.json
vendored
Normal file
21
tensorflow/lite/experimental/support/metadata/testdata/golden_json.json
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
{
|
||||
"name": "Mobilenet_quantized",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "file2"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "file1"
|
||||
}
|
||||
]
|
||||
}
|
@ -206,4 +206,12 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "zip_files",
|
||||
srcs = ["zip_files.py"],
|
||||
python_version = "PY3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["@absl_py//absl:app"],
|
||||
)
|
||||
|
||||
tflite_portable_test_suite()
|
||||
|
41
tensorflow/lite/tools/zip_files.py
Normal file
41
tensorflow/lite/tools/zip_files.py
Normal file
@ -0,0 +1,41 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# Lint as: python3
|
||||
"""Creates a zip package of the files passed in."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("export_zip_path", None, "Path to zip file.")
|
||||
flags.DEFINE_string("file_directory", None, "Path to the files to be zipped.")
|
||||
|
||||
|
||||
def main(_):
|
||||
with zipfile.ZipFile(FLAGS.export_zip_path, mode="w") as zf:
|
||||
for root, _, files in os.walk(FLAGS.file_directory):
|
||||
for f in files:
|
||||
if f.endswith(".java"):
|
||||
zf.write(os.path.join(root, f))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
21
third_party/flatbuffers/BUILD.bazel
vendored
21
third_party/flatbuffers/BUILD.bazel
vendored
@ -1,3 +1,7 @@
|
||||
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE.txt"])
|
||||
@ -116,3 +120,20 @@ py_library(
|
||||
srcs = [":runtime_py_srcs"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "runtime_java_srcs",
|
||||
srcs = glob(["java/com/google/flatbuffers/**/*.java"]),
|
||||
)
|
||||
|
||||
java_library(
|
||||
name = "runtime_java",
|
||||
srcs = [":runtime_java_srcs"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "runtime_android",
|
||||
srcs = [":runtime_java_srcs"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
187
third_party/flatbuffers/build_defs.bzl
vendored
187
third_party/flatbuffers/build_defs.bzl
vendored
@ -1,6 +1,15 @@
|
||||
"""BUILD rules for generating flatbuffer files."""
|
||||
|
||||
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||
|
||||
flatc_path = "@flatbuffers//:flatc"
|
||||
zip_files = "//tensorflow/lite/tools:zip_files"
|
||||
|
||||
DEFAULT_INCLUDE_PATHS = [
|
||||
"./",
|
||||
"$(GENDIR)",
|
||||
"$(BINDIR)",
|
||||
]
|
||||
|
||||
DEFAULT_FLATC_ARGS = [
|
||||
"--no-union-value-namespacing",
|
||||
@ -422,3 +431,181 @@ def flatbuffer_py_library(
|
||||
"@flatbuffers//:runtime_py",
|
||||
],
|
||||
)
|
||||
|
||||
def flatbuffer_java_library(
|
||||
name,
|
||||
srcs,
|
||||
custom_package = "",
|
||||
package_prefix = "",
|
||||
include_paths = DEFAULT_INCLUDE_PATHS,
|
||||
flatc_args = DEFAULT_FLATC_ARGS,
|
||||
visibility = None):
|
||||
"""A java library with the generated reader/writers for the given flatbuffer definitions.
|
||||
|
||||
Args:
|
||||
name: Rule name. (required)
|
||||
srcs: List of source .fbs files including all includes. (required)
|
||||
custom_package: Package name of generated Java files. If not specified
|
||||
namespace in the schema files will be used. (optional)
|
||||
package_prefix: like custom_package, but prefixes to the existing
|
||||
namespace. (optional)
|
||||
include_paths: List of paths that includes files can be found in. (optional)
|
||||
flatc_args: List of additional arguments to pass to flatc. (optional)
|
||||
visibility: Visibility setting for the java_library rule. (optional)
|
||||
"""
|
||||
out_srcjar = "java_%s_all.srcjar" % name
|
||||
flatbuffer_java_srcjar(
|
||||
name = "%s_srcjar" % name,
|
||||
srcs = srcs,
|
||||
out = out_srcjar,
|
||||
custom_package = custom_package,
|
||||
flatc_args = flatc_args,
|
||||
include_paths = include_paths,
|
||||
package_prefix = package_prefix,
|
||||
)
|
||||
|
||||
native.filegroup(
|
||||
name = "%s.srcjar" % name,
|
||||
srcs = [out_srcjar],
|
||||
)
|
||||
|
||||
native.java_library(
|
||||
name = name,
|
||||
srcs = [out_srcjar],
|
||||
deps = [
|
||||
"@flatbuffers//:runtime_java",
|
||||
],
|
||||
visibility = visibility,
|
||||
)
|
||||
|
||||
def flatbuffer_java_srcjar(
|
||||
name,
|
||||
srcs,
|
||||
out,
|
||||
custom_package = "",
|
||||
package_prefix = "",
|
||||
include_paths = DEFAULT_INCLUDE_PATHS,
|
||||
flatc_args = DEFAULT_FLATC_ARGS):
|
||||
"""Generate flatbuffer Java source files.
|
||||
|
||||
Args:
|
||||
name: Rule name. (required)
|
||||
srcs: List of source .fbs files including all includes. (required)
|
||||
out: Output file name. (required)
|
||||
custom_package: Package name of generated Java files. If not specified
|
||||
namespace in the schema files will be used. (optional)
|
||||
package_prefix: like custom_package, but prefixes to the existing
|
||||
namespace. (optional)
|
||||
include_paths: List of paths that includes files can be found in. (optional)
|
||||
flatc_args: List of additional arguments to pass to flatc. (optional)
|
||||
"""
|
||||
command_fmt = """set -e
|
||||
tmpdir=$(@D)
|
||||
schemas=$$tmpdir/schemas
|
||||
java_root=$$tmpdir/java
|
||||
rm -rf $$schemas
|
||||
rm -rf $$java_root
|
||||
mkdir -p $$schemas
|
||||
mkdir -p $$java_root
|
||||
|
||||
for src in $(SRCS); do
|
||||
dest=$$schemas/$$src
|
||||
rm -rf $$(dirname $$dest)
|
||||
mkdir -p $$(dirname $$dest)
|
||||
if [ -z "{custom_package}" ] && [ -z "{package_prefix}" ]; then
|
||||
cp -f $$src $$dest
|
||||
else
|
||||
if [ -z "{package_prefix}" ]; then
|
||||
sed -e "s/namespace\\s.*/namespace {custom_package};/" $$src > $$dest
|
||||
else
|
||||
sed -e "s/namespace \\([^;]\\+\\);/namespace {package_prefix}.\\1;/" $$src > $$dest
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
flatc_arg_I="-I $$tmpdir/schemas"
|
||||
for include_path in {include_paths}; do
|
||||
flatc_arg_I="$$flatc_arg_I -I $$schemas/$$include_path"
|
||||
done
|
||||
|
||||
flatc_additional_args=
|
||||
for arg in {flatc_args}; do
|
||||
flatc_additional_args="$$flatc_additional_args $$arg"
|
||||
done
|
||||
|
||||
for src in $(SRCS); do
|
||||
$(location {flatc_path}) $$flatc_arg_I --java $$flatc_additional_args -o $$java_root $$schemas/$$src
|
||||
done
|
||||
|
||||
$(location {zip_files}) -export_zip_path=$@ -file_directory=$$java_root
|
||||
"""
|
||||
genrule_cmd = command_fmt.format(
|
||||
package_name = native.package_name(),
|
||||
custom_package = custom_package,
|
||||
package_prefix = package_prefix,
|
||||
flatc_path = flatc_path,
|
||||
zip_files = zip_files,
|
||||
include_paths = " ".join(include_paths),
|
||||
flatc_args = " ".join(flatc_args),
|
||||
)
|
||||
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = srcs,
|
||||
outs = [out],
|
||||
tools = [flatc_path, zip_files],
|
||||
cmd = genrule_cmd,
|
||||
)
|
||||
|
||||
def flatbuffer_android_library(
|
||||
name,
|
||||
srcs,
|
||||
custom_package = "",
|
||||
package_prefix = "",
|
||||
javacopts = None,
|
||||
include_paths = DEFAULT_INCLUDE_PATHS,
|
||||
flatc_args = DEFAULT_FLATC_ARGS,
|
||||
visibility = None):
|
||||
"""An android_library with the generated reader/writers for the given flatbuffer definitions.
|
||||
|
||||
Args:
|
||||
name: Rule name. (required)
|
||||
srcs: List of source .fbs files including all includes. (required)
|
||||
custom_package: Package name of generated Java files. If not specified
|
||||
namespace in the schema files will be used. (optional)
|
||||
package_prefix: like custom_package, but prefixes to the existing
|
||||
namespace. (optional)
|
||||
javacopts: List of options to pass to javac.
|
||||
include_paths: List of paths that includes files can be found in. (optional)
|
||||
flatc_args: List of additional arguments to pass to flatc. (optional)
|
||||
visibility: Visibility setting for the android_library rule. (optional)
|
||||
"""
|
||||
out_srcjar = "android_%s_all.srcjar" % name
|
||||
flatbuffer_java_srcjar(
|
||||
name = "%s_srcjar" % name,
|
||||
srcs = srcs,
|
||||
out = out_srcjar,
|
||||
custom_package = custom_package,
|
||||
flatc_args = flatc_args,
|
||||
include_paths = include_paths,
|
||||
package_prefix = package_prefix,
|
||||
)
|
||||
|
||||
native.filegroup(
|
||||
name = "%s.srcjar" % name,
|
||||
srcs = [out_srcjar],
|
||||
)
|
||||
|
||||
# To support org.checkerframework.dataflow.qual.Pure.
|
||||
checkerframework_annotations = [
|
||||
"@org_checkerframework_qual",
|
||||
] if "--java-checkerframework" in flatc_args else []
|
||||
|
||||
android_library(
|
||||
name = name,
|
||||
srcs = [out_srcjar],
|
||||
visibility = visibility,
|
||||
deps = [
|
||||
"@flatbuffers//:runtime_android",
|
||||
] + checkerframework_annotations,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user