OSS TFLite Metadata library
PiperOrigin-RevId: 299966503 Change-Id: I08e6e7d424822c5f80f48b0f0c670f3c9c75f403
This commit is contained in:
parent
c645992e15
commit
9ebf4a223b
tensorflow/lite
experimental/support/metadata
BUILD
java
metadata.pymetadata_schema.fbsmetadata_test.pytestdata
tools
third_party/flatbuffers
87
tensorflow/lite/experimental/support/metadata/BUILD
Normal file
87
tensorflow/lite/experimental/support/metadata/BUILD
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||||
|
load("@flatbuffers//:build_defs.bzl", "flatbuffer_android_library", "flatbuffer_cc_library", "flatbuffer_java_library", "flatbuffer_py_library")
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = [
|
||||||
|
"//visibility:public",
|
||||||
|
],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
exports_files(["metadata_schema.fbs"])
|
||||||
|
|
||||||
|
flatbuffer_py_library(
|
||||||
|
name = "schema_py",
|
||||||
|
srcs = ["//tensorflow/lite/schema:schema.fbs"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generic schema for inference on device.
|
||||||
|
flatbuffer_android_library(
|
||||||
|
name = "schema_fbs_android",
|
||||||
|
srcs = ["//tensorflow/lite/schema:schema.fbs"],
|
||||||
|
custom_package = "org.tensorflow.lite.schema",
|
||||||
|
)
|
||||||
|
|
||||||
|
flatbuffer_java_library(
|
||||||
|
name = "schema_fbs_java",
|
||||||
|
srcs = ["//tensorflow/lite/schema:schema.fbs"],
|
||||||
|
custom_package = "org.tensorflow.lite.schema",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generic schema for model metadata.
|
||||||
|
flatbuffer_cc_library(
|
||||||
|
name = "metadata_schema_cc",
|
||||||
|
srcs = ["metadata_schema.fbs"],
|
||||||
|
)
|
||||||
|
|
||||||
|
flatbuffer_py_library(
|
||||||
|
name = "metadata_schema_py",
|
||||||
|
srcs = ["metadata_schema.fbs"],
|
||||||
|
)
|
||||||
|
|
||||||
|
flatbuffer_java_library(
|
||||||
|
name = "metadata_schema_java",
|
||||||
|
srcs = ["metadata_schema.fbs"],
|
||||||
|
custom_package = "org.tensorflow.lite.support.metadata.schema",
|
||||||
|
)
|
||||||
|
|
||||||
|
flatbuffer_android_library(
|
||||||
|
name = "metadata_schema_fbs_android",
|
||||||
|
srcs = ["metadata_schema.fbs"],
|
||||||
|
custom_package = "org.tensorflow.lite.support.metadata.schema",
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "metadata",
|
||||||
|
srcs = ["metadata.py"],
|
||||||
|
data = [
|
||||||
|
"//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs",
|
||||||
|
"@flatbuffers//:flatc",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":metadata_schema_py",
|
||||||
|
":schema_py",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
"@flatbuffers//:runtime_py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "metadata_test",
|
||||||
|
srcs = ["metadata_test.py"],
|
||||||
|
data = ["testdata/golden_json.json"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":metadata",
|
||||||
|
":metadata_schema_py",
|
||||||
|
":schema_py",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
"@flatbuffers//:runtime_py",
|
||||||
|
"@six_archive//:six",
|
||||||
|
],
|
||||||
|
)
|
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="org.tensorflow.lite.support">
|
||||||
|
<uses-sdk android:minSdkVersion="19" />
|
||||||
|
</manifest>
|
||||||
|
|
36
tensorflow/lite/experimental/support/metadata/java/BUILD
Normal file
36
tensorflow/lite/experimental/support/metadata/java/BUILD
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# Description:
|
||||||
|
# TensorFlow Lite Support API in Java for metadata.
|
||||||
|
|
||||||
|
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||||
|
load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//visibility:public"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "tensorflow-lite-support-metadata",
|
||||||
|
srcs = glob(["src/java/org/tensorflow/lite/support/metadata/**/*.java"]),
|
||||||
|
manifest = "AndroidManifest.xml",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite/experimental/support/java:tensorflow-lite-support-precondition-lib-android",
|
||||||
|
"//tensorflow/lite/experimental/support/metadata:metadata_schema_fbs_android",
|
||||||
|
"//tensorflow/lite/experimental/support/metadata:schema_fbs_android",
|
||||||
|
"//tensorflow/lite/java:tensorflowlite",
|
||||||
|
"@org_checkerframework_qual",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
java_library(
|
||||||
|
name = "tensorflow-lite-support-metadata-lib",
|
||||||
|
srcs = glob(["src/java/org/tensorflow/lite/support/metadata/**/*.java"]),
|
||||||
|
javacopts = JAVACOPTS,
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite/experimental/support/java:tensorflow-lite-support-precondition",
|
||||||
|
"//tensorflow/lite/experimental/support/metadata:metadata_schema_java",
|
||||||
|
"//tensorflow/lite/experimental/support/metadata:schema_fbs_java",
|
||||||
|
"//tensorflow/lite/java:tensorflowlitelib",
|
||||||
|
"@org_checkerframework_qual",
|
||||||
|
],
|
||||||
|
)
|
@ -0,0 +1,116 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.lite.support.metadata;
|
||||||
|
|
||||||
|
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
|
||||||
|
import static org.tensorflow.lite.support.common.SupportPreconditions.checkElementIndex;
|
||||||
|
import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An {@link InputStream} that wraps a section of a {@link SeekableByteChannelCompat}.
|
||||||
|
*
|
||||||
|
* <p><b>WARNING:</b> Similar as {@link InputStream}, instances of an {@link BoundedInputStream} are
|
||||||
|
* <b>not</b> thread-safe. If multiple threads concurrently reading from the same {@link
|
||||||
|
* BoundedInputStream}, it must be synchronized externally. Also, if multiple instances of {@link
|
||||||
|
* BoundedInputStream} are created on the same {@link SeekableByteChannelCompat}, it must be
|
||||||
|
* synchronized as well.
|
||||||
|
*/
|
||||||
|
final class BoundedInputStream extends InputStream {
|
||||||
|
private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
|
||||||
|
private final long end; // The valid data for the stream is between [start, end).
|
||||||
|
private long position;
|
||||||
|
private final SeekableByteChannelCompat channel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}.
|
||||||
|
*
|
||||||
|
* @param channel the {@link SeekableByteChannelCompat} that backs up this {@link
|
||||||
|
* BoundedInputStream}
|
||||||
|
* @param start the starting position of this {@link BoundedInputStream} in the given {@link
|
||||||
|
* SeekableByteChannelCompat}
|
||||||
|
* @param remaining the length of this {@link BoundedInputStream}
|
||||||
|
* @throws IllegalArgumentException if {@code start} or {@code remaining} is negative
|
||||||
|
*/
|
||||||
|
BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) {
|
||||||
|
checkArgument(
|
||||||
|
remaining >= 0 && start >= 0,
|
||||||
|
String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
|
||||||
|
|
||||||
|
end = start + remaining;
|
||||||
|
this.channel = channel;
|
||||||
|
position = start;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int available() throws IOException {
|
||||||
|
return (int) (Math.min(end, channel.size()) - position);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int read() throws IOException {
|
||||||
|
if (position >= end) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
singleByteBuffer.rewind();
|
||||||
|
int count = read(position, singleByteBuffer);
|
||||||
|
if (count < 0) {
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
position++;
|
||||||
|
return singleByteBuffer.get() & 0xff;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int read(byte[] b, int off, int len) throws IOException {
|
||||||
|
checkNotNull(b);
|
||||||
|
checkElementIndex(off, b.length, "The start offset");
|
||||||
|
checkElementIndex(len, b.length - off + 1, "The maximumn number of bytes to read");
|
||||||
|
|
||||||
|
if (len == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (len > end - position) {
|
||||||
|
if (position >= end) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
len = (int) (end - position);
|
||||||
|
}
|
||||||
|
|
||||||
|
ByteBuffer buf = ByteBuffer.wrap(b, off, len);
|
||||||
|
int count = read(position, buf);
|
||||||
|
if (count > 0) {
|
||||||
|
position += count;
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
private int read(long position, ByteBuffer buf) throws IOException {
|
||||||
|
int count;
|
||||||
|
synchronized (channel) {
|
||||||
|
channel.position(position);
|
||||||
|
count = channel.read(buf);
|
||||||
|
}
|
||||||
|
buf.flip();
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,130 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.lite.support.metadata;
|
||||||
|
|
||||||
|
import static java.lang.Math.min;
|
||||||
|
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
|
||||||
|
import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.channels.NonWritableChannelException;
|
||||||
|
|
||||||
|
/** Implements the {@link SeekableByteChannelCompat} on top of {@link ByteBuffer}. */
|
||||||
|
final class ByteBufferChannel implements SeekableByteChannelCompat {
|
||||||
|
|
||||||
|
/** The ByteBuffer that holds the data. */
|
||||||
|
private final ByteBuffer buffer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}.
|
||||||
|
*
|
||||||
|
* @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel}
|
||||||
|
* @throws NullPointerException if {@code buffer} is null
|
||||||
|
*/
|
||||||
|
public ByteBufferChannel(ByteBuffer buffer) {
|
||||||
|
checkNotNull(buffer, "The ByteBuffer cannot be null.");
|
||||||
|
this.buffer = buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isOpen() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long position() {
|
||||||
|
return buffer.position();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets this channel's position.
|
||||||
|
*
|
||||||
|
* @param newPosition the new position, a non-negative integer counting the number of bytes from
|
||||||
|
* the beginning of the entity
|
||||||
|
* @return this channel
|
||||||
|
* @throws IllegalArgumentException if the new position is negative, or greater than the size of
|
||||||
|
* the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public synchronized ByteBufferChannel position(long newPosition) {
|
||||||
|
checkArgument(
|
||||||
|
(newPosition >= 0 && newPosition <= Integer.MAX_VALUE),
|
||||||
|
"The new position should be non-negative and be less than Integer.MAX_VALUE.");
|
||||||
|
buffer.position((int) newPosition);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*
|
||||||
|
* <p>Bytes are read starting at this channel's current position, and then the position is updated
|
||||||
|
* with the number of bytes actually read. Otherwise this method behaves exactly as specified in
|
||||||
|
* the {@link ReadableByteChannel} interface.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public synchronized int read(ByteBuffer dst) {
|
||||||
|
if (buffer.remaining() == 0) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int count = min(dst.remaining(), buffer.remaining());
|
||||||
|
if (count > 0) {
|
||||||
|
ByteBuffer tempBuffer = buffer.slice();
|
||||||
|
tempBuffer.order(buffer.order()).limit(count);
|
||||||
|
dst.put(tempBuffer);
|
||||||
|
buffer.position(buffer.position() + count);
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long size() {
|
||||||
|
return buffer.limit();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public synchronized ByteBufferChannel truncate(long size) {
|
||||||
|
checkArgument(
|
||||||
|
(size >= 0 && size <= Integer.MAX_VALUE),
|
||||||
|
"The new size should be non-negative and be less than Integer.MAX_VALUE.");
|
||||||
|
|
||||||
|
if (size < buffer.limit()) {
|
||||||
|
buffer.limit((int) size);
|
||||||
|
if (buffer.position() > size) {
|
||||||
|
buffer.position((int) size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public synchronized int write(ByteBuffer src) {
|
||||||
|
if (buffer.isReadOnly()) {
|
||||||
|
throw new NonWritableChannelException();
|
||||||
|
}
|
||||||
|
|
||||||
|
int count = min(src.remaining(), buffer.remaining());
|
||||||
|
if (count > 0) {
|
||||||
|
ByteBuffer tempBuffer = src.slice();
|
||||||
|
tempBuffer.order(buffer.order()).limit(count);
|
||||||
|
buffer.put(tempBuffer);
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,247 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.lite.support.metadata;
|
||||||
|
|
||||||
|
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.zip.ZipException;
|
||||||
|
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||||
|
import org.tensorflow.lite.DataType;
|
||||||
|
import org.tensorflow.lite.Tensor.QuantizationParams;
|
||||||
|
import org.tensorflow.lite.schema.Tensor;
|
||||||
|
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads metadata from TFLite Model FlatBuffer.
|
||||||
|
*
|
||||||
|
* <p>TFLite Model FlatBuffer can be generated using the <a
|
||||||
|
* href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
|
||||||
|
* Model schema file.</a>
|
||||||
|
*
|
||||||
|
* <p>Some models contain a TFLite Metadata Flatbuffer, which records more information about what
|
||||||
|
* the model does and how to interprete the model. TFLite Metadata Flatbuffer can be generated using
|
||||||
|
* the <a
|
||||||
|
* href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/metadata_schema.fbs">TFLite
|
||||||
|
* Metadata schema file.</a>
|
||||||
|
*
|
||||||
|
* <p>It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking methods
|
||||||
|
* that read from TFLite metadata will cause runtime errors.
|
||||||
|
*
|
||||||
|
* <p>Similarly, it is allowed to pass in a model FlatBuffer without associated files. However,
|
||||||
|
* invoking methods that read the associated files will cause runtime errors.
|
||||||
|
*
|
||||||
|
* <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports a
|
||||||
|
* single subgraph so far. See the <a
|
||||||
|
* href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
|
||||||
|
* of how to specify subgraph during convertion for more information.</a> Therefore, {@link
|
||||||
|
* MetadataExtractor} omits subgraph index as an input in its methods.
|
||||||
|
*/
|
||||||
|
public class MetadataExtractor {
|
||||||
|
/** The helper class to load metadata from TFLite model FlatBuffer. */
|
||||||
|
private final ModelInfo modelInfo;
|
||||||
|
|
||||||
|
/** The helper class to load metadata from TFLite metadata FlatBuffer. */
|
||||||
|
@Nullable private final ModelMetadataInfo metadataInfo;
|
||||||
|
|
||||||
|
/** The handler to load associated files through zip. */
|
||||||
|
@Nullable private final ZipFile zipFile;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link MetadataExtractor} with TFLite model FlatBuffer.
|
||||||
|
*
|
||||||
|
* @param buffer the TFLite model FlatBuffer
|
||||||
|
* @throws IllegalArgumentException if the number of input or output tensors in the model does not
|
||||||
|
* match that in the metadata
|
||||||
|
* @throws IOException if an error occurs while reading the model as a Zip file
|
||||||
|
*/
|
||||||
|
public MetadataExtractor(ByteBuffer buffer) throws IOException {
|
||||||
|
modelInfo = new ModelInfo(buffer);
|
||||||
|
ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer();
|
||||||
|
if (metadataBuffer != null) {
|
||||||
|
metadataInfo = new ModelMetadataInfo(metadataBuffer);
|
||||||
|
checkArgument(
|
||||||
|
modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(),
|
||||||
|
String.format(
|
||||||
|
"The number of input tensors in the model is %d. The number of input tensors that"
|
||||||
|
+ " recorded in the metadata is %d. These two values does not match.",
|
||||||
|
modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount()));
|
||||||
|
checkArgument(
|
||||||
|
modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(),
|
||||||
|
String.format(
|
||||||
|
"The number of output tensors in the model is %d. The number of output tensors that"
|
||||||
|
+ " recorded in the metadata is %d. These two values does not match.",
|
||||||
|
modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount()));
|
||||||
|
} else {
|
||||||
|
// It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking
|
||||||
|
// methods that read from TFLite metadata will cause runtime errors.
|
||||||
|
metadataInfo = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
zipFile = createZipFile(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the packed associated file with the specified {@code fileName}.
|
||||||
|
*
|
||||||
|
* @param fileName the name of the associated file
|
||||||
|
* @return the raw input stream containing specified file
|
||||||
|
* @throws IllegalStateException if the model is not a zip file
|
||||||
|
* @throws IllegalArgumentException if the specified file does not exist in the model
|
||||||
|
*/
|
||||||
|
public InputStream getAssociatedFile(String fileName) {
|
||||||
|
assertZipFile();
|
||||||
|
return zipFile.getRawInputStream(fileName);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets the count of input tensors in the model. */
|
||||||
|
public int getInputTensorCount() {
|
||||||
|
return modelInfo.getInputTensorCount();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the metadata for the input tensor specified by {@code inputIndex}.
|
||||||
|
*
|
||||||
|
* @param inputIndex the index of the desired input tensor
|
||||||
|
* @throws IllegalStateException if this model does not contain model metadata
|
||||||
|
*/
|
||||||
|
@Nullable
|
||||||
|
public TensorMetadata getInputTensorMetadata(int inputIndex) {
|
||||||
|
assertMetadataInfo();
|
||||||
|
return metadataInfo.getInputTensorMetadata(inputIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the quantization parameters for the input tensor specified by {@code inputIndex}.
|
||||||
|
*
|
||||||
|
* @param inputIndex the index of the desired input tensor
|
||||||
|
*/
|
||||||
|
public QuantizationParams getInputTensorQuantizationParams(int inputIndex) {
|
||||||
|
Tensor tensor = modelInfo.getInputTensor(inputIndex);
|
||||||
|
return modelInfo.getQuantizationParams(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the shape of the input tensor with {@code inputIndex}.
|
||||||
|
*
|
||||||
|
* @param inputIndex the index of the desired input tensor
|
||||||
|
*/
|
||||||
|
public int[] getInputTensorShape(int inputIndex) {
|
||||||
|
return modelInfo.getInputTensorShape(inputIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the {@link DataType} of the input tensor with {@code inputIndex}.
|
||||||
|
*
|
||||||
|
* @param inputIndex the index of the desired input tensor
|
||||||
|
*/
|
||||||
|
public DataType getInputTensorType(int inputIndex) {
|
||||||
|
return modelInfo.getInputTensorType(inputIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets the count of output tensors in the model. */
|
||||||
|
public int getOutputTensorCount() {
|
||||||
|
return modelInfo.getOutputTensorCount();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the metadata for the output tensor specified by {@code outputIndex}.
|
||||||
|
*
|
||||||
|
* @param outputIndex the index of the desired output tensor
|
||||||
|
* @throws IllegalStateException if this model does not contain model metadata
|
||||||
|
*/
|
||||||
|
@Nullable
|
||||||
|
public TensorMetadata getOutputTensorMetadata(int outputIndex) {
|
||||||
|
assertMetadataInfo();
|
||||||
|
return metadataInfo.getOutputTensorMetadata(outputIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the quantization parameters for the output tensor specified by {@code outputIndex}.
|
||||||
|
*
|
||||||
|
* @param outputIndex the index of the desired output tensor
|
||||||
|
*/
|
||||||
|
public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) {
|
||||||
|
Tensor tensor = modelInfo.getOutputTensor(outputIndex);
|
||||||
|
return modelInfo.getQuantizationParams(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the shape of the output tensor with {@code outputIndex}.
|
||||||
|
*
|
||||||
|
* @param outputIndex the index of the desired output tensor
|
||||||
|
*/
|
||||||
|
public int[] getOutputTensorShape(int outputIndex) {
|
||||||
|
return modelInfo.getOutputTensorShape(outputIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the {@link DataType} of the output tensor with {@code outputIndex}.
|
||||||
|
*
|
||||||
|
* @param outputIndex the index of the desired output tensor
|
||||||
|
*/
|
||||||
|
public DataType getOutputTensorType(int outputIndex) {
|
||||||
|
return modelInfo.getOutputTensorType(outputIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Asserts if {@link metdadataInfo} is not initialized. Some models may not have metadata and this
|
||||||
|
* is allowed. However, invoking methods that reads the metadata is not allowed.
|
||||||
|
*
|
||||||
|
* @throws IllegalStateException if this model does not contain model metadata
|
||||||
|
*/
|
||||||
|
private void assertMetadataInfo() {
|
||||||
|
if (metadataInfo == null) {
|
||||||
|
throw new IllegalStateException("This model does not contain model metadata.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Asserts if {@link #zipFile} is not initialized. Some models may not have associated files, thus
|
||||||
|
* are not Zip files. This is allowed. However, invoking methods that reads those associated files
|
||||||
|
* is not allowed.
|
||||||
|
*
|
||||||
|
* @throws IllegalStateException if this model is not a Zip file
|
||||||
|
*/
|
||||||
|
private void assertZipFile() {
|
||||||
|
if (zipFile == null) {
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"This model does not contain associated files, and is not a Zip file.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e.
|
||||||
|
* it does not have associated files, return a null handler.
|
||||||
|
*
|
||||||
|
* @param buffer the TFLite model FlatBuffer
|
||||||
|
* @throws IOException if an error occurs while reading the model as a Zip file
|
||||||
|
*/
|
||||||
|
@Nullable
|
||||||
|
private static ZipFile createZipFile(ByteBuffer buffer) throws IOException {
|
||||||
|
try {
|
||||||
|
// Creates the handler to hold the associated files through the Zip.
|
||||||
|
ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer);
|
||||||
|
return ZipFile.createFrom(byteBufferChannel);
|
||||||
|
} catch (ZipException e) {
|
||||||
|
// Some models may not have associate files. Therefore, Those models are not zip files.
|
||||||
|
// However, invoking methods that read associated files later will lead into errors.
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,281 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.lite.support.metadata;
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||||
|
import org.tensorflow.lite.DataType;
|
||||||
|
import org.tensorflow.lite.Tensor.QuantizationParams;
|
||||||
|
import org.tensorflow.lite.schema.Buffer;
|
||||||
|
import org.tensorflow.lite.schema.Metadata;
|
||||||
|
import org.tensorflow.lite.schema.Model;
|
||||||
|
import org.tensorflow.lite.schema.QuantizationParameters;
|
||||||
|
import org.tensorflow.lite.schema.SubGraph;
|
||||||
|
import org.tensorflow.lite.schema.Tensor;
|
||||||
|
import org.tensorflow.lite.schema.TensorType;
|
||||||
|
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||||
|
|
||||||
|
/** Extracts model information out of TFLite model FLatBuffer. */
|
||||||
|
final class ModelInfo {
|
||||||
|
/** The model that is loaded from TFLite model FlatBuffer. */
|
||||||
|
private final Model model;
|
||||||
|
|
||||||
|
/** A list of input tensors. */
|
||||||
|
private final List</* @Nullable */ Tensor> inputTensors;
|
||||||
|
|
||||||
|
/** A list of output tensors. */
|
||||||
|
private final List</* @Nullable */ Tensor> outputTensors;
|
||||||
|
|
||||||
|
/** Identifier of the TFLite model metadata in the Metadata array. */
|
||||||
|
static final String METADATA_FIELD_NAME = "TFLITE_METADATA";
|
||||||
|
|
||||||
|
/** Maps from TensorType in TFlite FlatBuffer to {@link DataType} in Java. */
|
||||||
|
private final Map<Byte, DataType> tensorTypeToDataTypeMap;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}.
|
||||||
|
*
|
||||||
|
* <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports
|
||||||
|
* single subgraph so far. See the <a
|
||||||
|
* href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
|
||||||
|
* of how to specify subgraph during convertion for more information.</a> Therefore, all methods
|
||||||
|
* in {@link ModelInfo} retrieves metadata of the first subgrpah as default.
|
||||||
|
*
|
||||||
|
* @param buffer The TFLite model FlatBuffer.
|
||||||
|
* @throws NullPointerException if {@code buffer} is null.
|
||||||
|
* @throws IllegalArgumentException if the model does not contain any subgraph.
|
||||||
|
*/
|
||||||
|
ModelInfo(ByteBuffer buffer) {
|
||||||
|
SupportPreconditions.checkNotNull(buffer, "Model flatbuffer cannot be null.");
|
||||||
|
|
||||||
|
model = Model.getRootAsModel(buffer);
|
||||||
|
SupportPreconditions.checkArgument(
|
||||||
|
model.subgraphsLength() > 0, "The model does not contain any subgraph.");
|
||||||
|
|
||||||
|
inputTensors = getInputTensors(model);
|
||||||
|
outputTensors = getOutputTensors(model);
|
||||||
|
tensorTypeToDataTypeMap = createTensorTypeToDataTypeMap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the input tensor with {@code inputIndex}.
|
||||||
|
*
|
||||||
|
* @param inputIndex The index of the desired input tensor.
|
||||||
|
* @throws IllegalArgumentException if the inputIndex specified is invalid.
|
||||||
|
*/
|
||||||
|
@Nullable
|
||||||
|
Tensor getInputTensor(int inputIndex) {
|
||||||
|
SupportPreconditions.checkArgument(
|
||||||
|
inputIndex >= 0 && inputIndex < inputTensors.size(),
|
||||||
|
"The inputIndex specified is invalid.");
|
||||||
|
return inputTensors.get(inputIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
int getInputTensorCount() {
|
||||||
|
return inputTensors.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets shape of the input tensor with {@code inputIndex}.
|
||||||
|
*
|
||||||
|
* @param inputIndex The index of the desired intput tensor.
|
||||||
|
*/
|
||||||
|
int[] getInputTensorShape(int inputIndex) {
|
||||||
|
Tensor tensor = getInputTensor(inputIndex);
|
||||||
|
return getShape(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets {@link DataType} of the input tensor with {@code inputIndex}.
|
||||||
|
*
|
||||||
|
* @param inputIndex The index of the desired intput tensor.
|
||||||
|
*/
|
||||||
|
DataType getInputTensorType(int inputIndex) {
|
||||||
|
Tensor tensor = getInputTensor(inputIndex);
|
||||||
|
return getDataType(tensor.type());
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets the metadata FlatBuffer from the model FlatBuffer. */
|
||||||
|
@Nullable
|
||||||
|
ByteBuffer getMetadataBuffer() {
|
||||||
|
// Some models may not have metadata, and this is allowed.
|
||||||
|
if (model.metadataLength() == 0) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < model.metadataLength(); i++) {
|
||||||
|
Metadata meta = model.metadata(i);
|
||||||
|
if (METADATA_FIELD_NAME.equals(meta.name())) {
|
||||||
|
long bufferIndex = meta.buffer();
|
||||||
|
Buffer metadataBuf = model.buffers((int) bufferIndex);
|
||||||
|
return metadataBuf.dataAsByteBuffer();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the output tensor with {@code outputIndex}.
|
||||||
|
*
|
||||||
|
* @param outputIndex The index of the desired outtput tensor.
|
||||||
|
* @throws IllegalArgumentException if the outputIndex specified is invalid.
|
||||||
|
*/
|
||||||
|
@Nullable
|
||||||
|
Tensor getOutputTensor(int outputIndex) {
|
||||||
|
SupportPreconditions.checkArgument(
|
||||||
|
outputIndex >= 0 && outputIndex < outputTensors.size(),
|
||||||
|
"The outputIndex specified is invalid.");
|
||||||
|
return outputTensors.get(outputIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
int getOutputTensorCount() {
|
||||||
|
return outputTensors.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets shape of the output tensor with {@code outputIndex}.
|
||||||
|
*
|
||||||
|
* @param outputIndex The index of the desired outtput tensor.
|
||||||
|
*/
|
||||||
|
int[] getOutputTensorShape(int outputIndex) {
|
||||||
|
Tensor tensor = getOutputTensor(outputIndex);
|
||||||
|
return getShape(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets {@link DataType} of the output tensor {@code outputIndex}.
|
||||||
|
*
|
||||||
|
* @param outputIndex The index of the desired outtput tensor.
|
||||||
|
*/
|
||||||
|
DataType getOutputTensorType(int outputIndex) {
|
||||||
|
Tensor tensor = getOutputTensor(outputIndex);
|
||||||
|
return getDataType(tensor.type());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Map<Byte, DataType> createTensorTypeToDataTypeMap() {
|
||||||
|
Map<Byte, DataType> map = new HashMap<>();
|
||||||
|
map.put(TensorType.FLOAT32, DataType.FLOAT32);
|
||||||
|
map.put(TensorType.INT32, DataType.INT32);
|
||||||
|
map.put(TensorType.UINT8, DataType.UINT8);
|
||||||
|
map.put(TensorType.INT64, DataType.INT64);
|
||||||
|
map.put(TensorType.STRING, DataType.STRING);
|
||||||
|
return Collections.unmodifiableMap(map);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Gets the quantization parameters of a tensor.
 *
 * <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensors that are not
 * quantized, the values of scale and zero_point are both 0.
 *
 * @param tensor The tensor whose quantization parameters are desired.
 * @throws NullPointerException if the tensor is null.
 * @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's {@link
 *     QuantizationParameters} are not single values.
 */
QuantizationParams getQuantizationParams(Tensor tensor) {
  SupportPreconditions.checkNotNull(tensor, "Tensor cannot be null.");

  float scale;
  int zeroPoint;
  QuantizationParameters quantization = tensor.quantization();

  // Tensors that are not quantized do not have quantization parameters, which can be null when
  // being extracted from the flatbuffer.
  if (quantization == null) {
    scale = 0.0f;
    zeroPoint = 0;
    return new QuantizationParams(scale, zeroPoint);
  }

  // Tensors that are not quantized do not have quantization parameters.
  // quantization.scaleLength() and quantization.zeroPointLength() may both return 0.
  // Lengths greater than 1 would mean per-channel quantization, which is rejected here.
  SupportPreconditions.checkArgument(
      quantization.scaleLength() <= 1,
      "Input and output tensors do not support per-channel quantization.");
  SupportPreconditions.checkArgument(
      quantization.zeroPointLength() <= 1,
      "Input and output tensors do not support per-channel quantization.");

  // For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0) will
  // both be the default value in flatbuffer, 0. This behavior is consistent with the TFlite C++
  // runtime.
  scale = quantization.scale(0);
  // zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep it
  // consistent with the C++ runtime.
  zeroPoint = (int) quantization.zeroPoint(0);

  return new QuantizationParams(scale, zeroPoint);
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Transforms from TensorType in TFlite FlatBuffer to {@link DataType} in Java.
|
||||||
|
*
|
||||||
|
* @param tensorType The tensor type to be converted.
|
||||||
|
* @throws IllegalArgumentException if the tensor type is not supported.
|
||||||
|
*/
|
||||||
|
private DataType getDataType(byte tensorType) {
|
||||||
|
SupportPreconditions.checkArgument(
|
||||||
|
tensorTypeToDataTypeMap.containsKey(tensorType),
|
||||||
|
String.format("Tensor type %d is not supported.", tensorType));
|
||||||
|
return tensorTypeToDataTypeMap.get(tensorType);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the shape of a tensor.
|
||||||
|
*
|
||||||
|
* @param tensor The tensor whoes shape is desired.
|
||||||
|
* @throws NullPointerException if the tensor is null.
|
||||||
|
*/
|
||||||
|
private static int[] getShape(Tensor tensor) {
|
||||||
|
SupportPreconditions.checkNotNull(tensor, "Tensor cannot be null.");
|
||||||
|
int shapeDim = tensor.shapeLength();
|
||||||
|
int[] tensorShape = new int[shapeDim];
|
||||||
|
for (int i = 0; i < shapeDim; i++) {
|
||||||
|
tensorShape[i] = tensor.shape(i);
|
||||||
|
}
|
||||||
|
return tensorShape;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets input tensors from a model. */
|
||||||
|
private static List<Tensor> getInputTensors(Model model) {
|
||||||
|
// TFLite only support one subgraph currently.
|
||||||
|
SubGraph subgraph = model.subgraphs(0);
|
||||||
|
int tensorNum = subgraph.inputsLength();
|
||||||
|
ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum);
|
||||||
|
for (int i = 0; i < tensorNum; i++) {
|
||||||
|
inputTensors.add(subgraph.tensors(subgraph.inputs(i)));
|
||||||
|
}
|
||||||
|
return Collections.unmodifiableList(inputTensors);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Gets output tensors from a model. */
|
||||||
|
private static List<Tensor> getOutputTensors(Model model) {
|
||||||
|
// TFLite only support one subgraph currently.
|
||||||
|
SubGraph subgraph = model.subgraphs(0);
|
||||||
|
int tensorNum = subgraph.outputsLength();
|
||||||
|
ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum);
|
||||||
|
for (int i = 0; i < tensorNum; i++) {
|
||||||
|
outputTensors.add(subgraph.tensors(subgraph.outputs(i)));
|
||||||
|
}
|
||||||
|
return Collections.unmodifiableList(outputTensors);
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,114 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.lite.support.metadata;
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import org.checkerframework.checker.nullness.qual.Nullable;
|
||||||
|
import org.tensorflow.lite.support.common.SupportPreconditions;
|
||||||
|
import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
|
||||||
|
import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata;
|
||||||
|
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
|
||||||
|
|
||||||
|
/** Extracts model metadata information out of TFLite metadata FlatBuffer. */
final class ModelMetadataInfo {
  // Metadata array of input tensors; entries may be null when a tensor has no metadata.
  private final List</* @Nullable */ TensorMetadata> inputsMetadata;

  // Metadata array of output tensors; entries may be null when a tensor has no metadata.
  private final List</* @Nullable */ TensorMetadata> outputsMetadata;

  /**
   * Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}.
   *
   * @param buffer The TFLite metadata FlatBuffer.
   * @throws NullPointerException if {@code buffer} is null.
   * @throws IllegalArgumentException if the metadata does not contain any subgraph metadata.
   */
  ModelMetadataInfo(ByteBuffer buffer) {
    SupportPreconditions.checkNotNull(buffer, "Metadata flatbuffer cannot be null.");

    ModelMetadata modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer);
    SupportPreconditions.checkArgument(
        modelMetadata.subgraphMetadataLength() > 0,
        "The metadata flatbuffer does not contain any subgraph metadata.");

    // Cache both tensor-metadata lists eagerly so the accessors below are cheap.
    inputsMetadata = getInputsMetadata(modelMetadata);
    outputsMetadata = getOutputsMetadata(modelMetadata);
  }

  /** Gets the count of input tensors with metadata in the metadata FlatBuffer. */
  int getInputTensorCount() {
    return inputsMetadata.size();
  }

  /**
   * Gets the metadata for the input tensor specified by {@code inputIndex}.
   *
   * @param inputIndex The index of the desired input tensor.
   * @throws IllegalArgumentException if the inputIndex specified is invalid.
   */
  @Nullable
  TensorMetadata getInputTensorMetadata(int inputIndex) {
    SupportPreconditions.checkArgument(
        inputIndex >= 0 && inputIndex < inputsMetadata.size(),
        "The inputIndex specified is invalid.");
    return inputsMetadata.get(inputIndex);
  }

  /** Gets the count of output tensors with metadata in the metadata FlatBuffer. */
  int getOutputTensorCount() {
    return outputsMetadata.size();
  }

  /**
   * Gets the metadata for the output tensor specified by {@code outputIndex}.
   *
   * @param outputIndex The index of the desired output tensor.
   * @throws IllegalArgumentException if the outputIndex specified is invalid.
   */
  @Nullable
  TensorMetadata getOutputTensorMetadata(int outputIndex) {
    SupportPreconditions.checkArgument(
        outputIndex >= 0 && outputIndex < outputsMetadata.size(),
        "The outputIndex specified is invalid.");
    return outputsMetadata.get(outputIndex);
  }

  /** Gets metadata for all input tensors (of the first, and only supported, subgraph). */
  private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) {
    SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
    int tensorNum = subgraphMetadata.inputTensorMetadataLength();
    ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum);
    for (int i = 0; i < tensorNum; i++) {
      inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i));
    }
    return Collections.unmodifiableList(inputsMetadata);
  }

  /** Gets metadata for all output tensors (of the first, and only supported, subgraph). */
  private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) {
    SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0);
    int tensorNum = subgraphMetadata.outputTensorMetadataLength();
    ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum);
    for (int i = 0; i < tensorNum; i++) {
      outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i));
    }
    return Collections.unmodifiableList(outputsMetadata);
  }
}
|
@ -0,0 +1,107 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.lite.support.metadata;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.channels.Channel;
|
||||||
|
|
||||||
|
/**
 * A byte channel that maintains a current <i>position</i> and allows the position to be changed.
 * {@link SeekableByteChannelCompat} is compatible with {@link
 * java.nio.channels.SeekableByteChannel}.
 *
 * <p>{@link java.nio.channels.SeekableByteChannel} is not available in Android API 23 and under.
 * Therefore, {@link SeekableByteChannelCompat} is introduced here to make the interfaces used in
 * the MetadataExtractor library consistent with the commonly used Java libraries.
 */
interface SeekableByteChannelCompat extends Channel {
  /**
   * Reads a sequence of bytes from this channel into the given buffer.
   *
   * @param dst The buffer into which bytes are to be transferred
   * @return The number of bytes read, possibly zero, or {@code -1} if the channel has reached
   *     end-of-stream
   * @throws NonReadableChannelException If this channel was not opened for reading
   * @throws ClosedChannelException If this channel is closed
   * @throws AsynchronousCloseException If another thread closes this channel while the read
   *     operation is in progress
   * @throws ClosedByInterruptException If another thread interrupts the current thread while the
   *     read operation is in progress, thereby closing the channel and setting the current thread's
   *     interrupt status
   * @throws IOException If some other I/O error occurs
   */
  int read(ByteBuffer dst) throws IOException;

  /**
   * Writes a sequence of bytes to this channel from the given buffer.
   *
   * @param src The buffer from which bytes are to be retrieved
   * @return The number of bytes written, possibly zero
   * @throws NonWritableChannelException If this channel was not opened for writing
   * @throws ClosedChannelException If this channel is closed
   * @throws AsynchronousCloseException If another thread closes this channel while the write
   *     operation is in progress
   * @throws ClosedByInterruptException If another thread interrupts the current thread while the
   *     write operation is in progress, thereby closing the channel and setting the current
   *     thread's interrupt status
   * @throws IOException If some other I/O error occurs
   */
  int write(ByteBuffer src) throws IOException;

  /**
   * Returns this channel's position.
   *
   * @return This channel's position, a non-negative integer counting the number of bytes from the
   *     beginning of the entity to the current position
   * @throws ClosedChannelException If this channel is closed
   * @throws IOException If some other I/O error occurs
   */
  long position() throws IOException;

  /**
   * Sets this channel's position.
   *
   * @param newPosition The new position, a non-negative integer counting the number of bytes from
   *     the beginning of the entity
   * @return This channel
   * @throws ClosedChannelException If this channel is closed
   * @throws IllegalArgumentException If the new position is negative
   * @throws IOException If some other I/O error occurs
   */
  SeekableByteChannelCompat position(long newPosition) throws IOException;

  /**
   * Returns the current size of entity to which this channel is connected.
   *
   * @return The current size, measured in bytes
   * @throws ClosedChannelException If this channel is closed
   * @throws IOException If some other I/O error occurs
   */
  long size() throws IOException;

  /**
   * Truncates the entity, to which this channel is connected, to the given size.
   *
   * @param size The new size, a non-negative byte count
   * @return This channel
   * @throws NonWritableChannelException If this channel was not opened for writing
   * @throws ClosedChannelException If this channel is closed
   * @throws IllegalArgumentException If the new size is negative
   * @throws IOException If some other I/O error occurs
   */
  SeekableByteChannelCompat truncate(long size) throws IOException;
}
|
@ -0,0 +1,427 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.lite.support.metadata;
|
||||||
|
|
||||||
|
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
|
||||||
|
import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
|
||||||
|
|
||||||
|
import java.io.Closeable;
|
||||||
|
import java.io.EOFException;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.ByteOrder;
|
||||||
|
import java.nio.charset.Charset;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.zip.ZipException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reads uncompressed files from the TFLite model, a zip file.
|
||||||
|
*
|
||||||
|
* <p>TODO(b/150237111): add a link to the webpage of MetadataPopulator once it's available.
|
||||||
|
*
|
||||||
|
* <p>A TFLite model file becomes a zip file when it contains associated files. The associated files
|
||||||
|
* can be packed to a TFLite model file using the MetadataPopulator. The associated files are not
|
||||||
|
* compressed when being added to the model file.
|
||||||
|
*
|
||||||
|
* <p>{@link ZipFile} does not support Zip64 format, because TFLite models are much smaller than the
|
||||||
|
* size limit for Zip64, which is 4GB.
|
||||||
|
*/
|
||||||
|
final class ZipFile implements Closeable {
|
||||||
|
/** Maps String to list of ZipEntrys, name -> actual entries. */
|
||||||
|
private final Map<String, List<ZipEntry>> nameMap;
|
||||||
|
|
||||||
|
/** The actual data source. */
|
||||||
|
private final ByteBufferChannel archive;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link
|
||||||
|
* ZipFile} does not synchronized over the buffer that is passed into it.
|
||||||
|
*
|
||||||
|
* @param channel the archive
|
||||||
|
* @throws IOException if an error occurs while creating this {@link ZipFile}
|
||||||
|
* @throws ZipException if the channel is not a zip archive
|
||||||
|
* @throws NullPointerException if the archive is null
|
||||||
|
*/
|
||||||
|
public static ZipFile createFrom(ByteBufferChannel channel) throws IOException {
|
||||||
|
checkNotNull(channel);
|
||||||
|
ZipParser zipParser = new ZipParser(channel);
|
||||||
|
Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries();
|
||||||
|
return new ZipFile(channel, nameMap);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
archive.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Exposes the raw stream of the archive entry.
|
||||||
|
*
|
||||||
|
* <p>Since the associated files will not be compressed when being packed to the zip file, the raw
|
||||||
|
* stream represents the non-compressed files.
|
||||||
|
*
|
||||||
|
* <p><b>WARNING:</b> The returned {@link InputStream}, is <b>not</b> thread-safe. If multiple
|
||||||
|
* threads concurrently reading from the returned {@link InputStream}, it must be synchronized
|
||||||
|
* externally.
|
||||||
|
*
|
||||||
|
* @param name name of the entry to get the stream for
|
||||||
|
* @return the raw input stream containing data
|
||||||
|
* @throws IllegalArgumentException if the specified file does not exist in the zip file
|
||||||
|
*/
|
||||||
|
public InputStream getRawInputStream(String name) {
|
||||||
|
checkArgument(
|
||||||
|
nameMap.containsKey(name),
|
||||||
|
String.format("The file, %s, does not exist in the zip file.", name));
|
||||||
|
|
||||||
|
List<ZipEntry> entriesWithTheSameName = nameMap.get(name);
|
||||||
|
ZipEntry entry = entriesWithTheSameName.get(0);
|
||||||
|
long start = entry.getDataOffset();
|
||||||
|
long remaining = entry.getSize();
|
||||||
|
return new BoundedInputStream(archive, start, remaining);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Wraps an already-parsed archive; instances are created via {@link #createFrom}. */
private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) {
  this.archive = channel;
  this.nameMap = nameMap;
}
|
||||||
|
|
||||||
|
/* Parses a Zip archive and gets the information for each {@link ZipEntry}. */
private static class ZipParser {
  // The underlying archive to be parsed.
  private final ByteBufferChannel archive;

  // Cached buffers that will only be used locally in the class to reduce garbage collection.
  // ZIP integers are little-endian, hence the explicit byte order.
  private final ByteBuffer longBuffer =
      ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
  private final ByteBuffer intBuffer =
      ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);
  private final ByteBuffer shortBuffer =
      ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN);

  private ZipParser(ByteBufferChannel archive) {
    this.archive = archive;
  }

  /**
   * Parses the underlying {@code archive} and returns the information as a list of {@link
   * ZipEntry}.
   */
  private Map<String, List<ZipEntry>> parseEntries() throws IOException {
    List<ZipEntry> entries = parseCentralDirectory();
    return parseLocalFileHeaderData(entries);
  }

  /**
   * Checks if the current position contains a central file header signature, {@link
   * ZipConstants#CENSIG}. Note: this advances the archive position by four bytes.
   */
  private boolean foundCentralFileheaderSignature() {
    long signature = (long) getInt();
    return signature == ZipConstants.CENSIG;
  }

  /**
   * Gets the value as a Java int from two bytes starting at the current position of the archive.
   */
  private int getShort() {
    shortBuffer.rewind();
    archive.read(shortBuffer);
    shortBuffer.flip();
    return (int) shortBuffer.getShort();
  }

  /**
   * Gets the value as a Java int from four bytes starting at the current position of the
   * archive.
   */
  private int getInt() {
    intBuffer.rewind();
    archive.read(intBuffer);
    intBuffer.flip();
    return intBuffer.getInt();
  }

  /**
   * Gets the value as a Java long from eight bytes starting at the current position of the
   * archive.
   */
  private long getLong() {
    longBuffer.rewind();
    archive.read(longBuffer);
    longBuffer.flip();
    return longBuffer.getLong();
  }

  /**
   * Positions the archive at the start of the central directory.
   *
   * <p>First, it searches for the signature of the "end of central directory record", {@link
   * ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory
   * record". The zip file are created without archive comments, thus {@link ZipConstants#ENDSIG}
   * should appear exactly at {@link ZipConstants#ENDHDR} from the end of the zip file.
   *
   * <p>Then, parse the "end of central dir record" and position the archive at the start of the
   * central directory.
   *
   * @throws ZipException if the archive is too small or the end signature is not found
   */
  private void locateCentralDirectory() throws IOException {
    if (archive.size() < ZipConstants.ENDHDR) {
      throw new ZipException("The archive is not a ZIP archive.");
    }

    // Positions the archive at the start of the "end of central directory record".
    long offsetRecord = archive.size() - ZipConstants.ENDHDR;
    archive.position(offsetRecord);

    // Checks for the signature, {@link ZipConstants#ENDSIG}.
    long endSig = getLong();
    if (endSig != ZipConstants.ENDSIG) {
      throw new ZipException("The archive is not a ZIP archive.");
    }

    // Positions the archive at the "offset of central directory".
    // getLong() above consumed ENDSUB bytes of header already, hence the relative skip.
    skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB);
    // Gets the offset to central directory
    long offsetDirectory = getInt();
    // Goes to the central directory.
    archive.position(offsetDirectory);
  }

  /**
   * Reads the central directory of the given archive and populates the internal tables with
   * {@link ZipEntry} instances.
   */
  private List<ZipEntry> parseCentralDirectory() throws IOException {
    /** List of entries in the order they appear inside the central directory. */
    List<ZipEntry> entries = new ArrayList<>();
    locateCentralDirectory();

    while (foundCentralFileheaderSignature()) {
      ZipEntry entry = parseCentralDirectoryEntry();
      entries.add(entry);
    }

    return entries;
  }

  /**
   * Reads an individual entry of the central directory, creates a ZipEntry from it and adds it to
   * the global maps.
   *
   * <p>On entry, the archive is positioned just past the entry's signature; the skips below are
   * relative offsets between fixed header fields.
   */
  private ZipEntry parseCentralDirectoryEntry() throws IOException {
    // Positions the archive at the "compressed size" and read the value.
    skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM);
    long compressSize = getInt();

    // Positions the archive at the "filename length" and read the value.
    skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN);
    int fileNameLen = getShort();

    // Reads the extra field length and the comment length.
    int extraLen = getShort();
    int commentLen = getShort();

    // Positions the archive at the "local file header offset" and read the value.
    skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK);
    long localHeaderOffset = getInt();

    // Reads the file name.
    byte[] fileNameBuf = new byte[fileNameLen];
    archive.read(ByteBuffer.wrap(fileNameBuf));
    String fileName = new String(fileNameBuf, Charset.forName("UTF-8"));

    // Skips the extra field and the comment.
    skipBytes(extraLen + commentLen);

    ZipEntry entry = new ZipEntry();
    entry.setSize(compressSize);
    entry.setLocalHeaderOffset(localHeaderOffset);
    entry.setName(fileName);

    return entry;
  }

  /** Walks through all recorded entries and records the offsets for the entry data. */
  private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) {
    /** Maps String to list of ZipEntrys, name -> actual entries. */
    Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>();

    for (ZipEntry entry : entries) {
      long offset = entry.getLocalHeaderOffset();
      archive.position(offset + ZipConstants.LOCNAM);

      // Gets the data offset of this entry: entry data starts after the local header's fixed
      // part plus the variable-length file name and extra field.
      int fileNameLen = getShort();
      int extraFieldLen = getShort();
      long dataOffset =
          offset
              + ZipConstants.LOCEXT
              + ZipConstants.SHORT_BYTE_SIZE
              + fileNameLen
              + extraFieldLen;
      entry.setDataOffset(dataOffset);

      // Puts the entry into the nameMap.
      String name = entry.getName();
      List<ZipEntry> entriesWithTheSameName;
      if (nameMap.containsKey(name)) {
        entriesWithTheSameName = nameMap.get(name);
      } else {
        entriesWithTheSameName = new ArrayList<>();
        nameMap.put(name, entriesWithTheSameName);
      }
      entriesWithTheSameName.add(entry);
    }

    return nameMap;
  }

  /** Skips the given number of bytes or throws an EOFException if skipping failed. */
  private void skipBytes(int count) throws IOException {
    long currentPosition = archive.position();
    long newPosition = currentPosition + count;
    if (newPosition > archive.size()) {
      throw new EOFException();
    }
    archive.position(newPosition);
  }
}
|
||||||
|
|
||||||
|
/** Stores the data offset and the size of an entry in the archive. */
|
||||||
|
private static class ZipEntry {
|
||||||
|
|
||||||
|
private String name;
|
||||||
|
private long dataOffset = -1;
|
||||||
|
private long size = -1;
|
||||||
|
private long localHeaderOffset = -1;
|
||||||
|
|
||||||
|
public long getSize() {
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
public long getDataOffset() {
|
||||||
|
return dataOffset;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getName() {
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
public long getLocalHeaderOffset() {
|
||||||
|
return localHeaderOffset;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setSize(long size) {
|
||||||
|
this.size = size;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setDataOffset(long dataOffset) {
|
||||||
|
this.dataOffset = dataOffset;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setName(String name) {
|
||||||
|
this.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setLocalHeaderOffset(long localHeaderOffset) {
|
||||||
|
this.localHeaderOffset = localHeaderOffset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
  /**
   * Various constants for this {@link ZipFile}.
   *
   * <p>Referenced from {@link java.util.zip.ZipConstants}.
   */
  private static class ZipConstants {
    /** length of Java short in bytes. */
    static final int SHORT_BYTE_SIZE = Short.SIZE / 8;

    /** length of Java int in bytes. */
    static final int INT_BYTE_SIZE = Integer.SIZE / 8;

    /** length of Java long in bytes. */
    static final int LONG_BYTE_SIZE = Long.SIZE / 8;

    /*
     * Header signatures (little-endian "PK.." magic numbers marking each record type).
     */
    static final long LOCSIG = 0x04034b50L; // "PK\003\004"
    static final long EXTSIG = 0x08074b50L; // "PK\007\008"
    static final long CENSIG = 0x02014b50L; // "PK\001\002"
    static final long ENDSIG = 0x06054b50L; // "PK\005\006"

    /*
     * Header sizes in bytes (including signatures)
     */
    static final int LOCHDR = 30; // LOC header size
    static final int EXTHDR = 16; // EXT header size
    static final int CENHDR = 46; // CEN header size
    static final int ENDHDR = 22; // END header size

    /*
     * Local file (LOC) header field offsets, in bytes from the start of the record
     */
    static final int LOCVER = 4; // version needed to extract
    static final int LOCFLG = 6; // general purpose bit flag
    static final int LOCHOW = 8; // compression method
    static final int LOCTIM = 10; // modification time
    static final int LOCCRC = 14; // uncompressed file crc-32 value
    static final int LOCSIZ = 18; // compressed size
    static final int LOCLEN = 22; // uncompressed size
    static final int LOCNAM = 26; // filename length
    static final int LOCEXT = 28; // extra field length

    /*
     * Extra local (EXT) header field offsets
     */
    static final int EXTCRC = 4; // uncompressed file crc-32 value
    static final int EXTSIZ = 8; // compressed size
    static final int EXTLEN = 12; // uncompressed size

    /*
     * Central directory (CEN) header field offsets
     */
    static final int CENVEM = 4; // version made by
    static final int CENVER = 6; // version needed to extract
    static final int CENFLG = 8; // encrypt, decrypt flags
    static final int CENHOW = 10; // compression method
    static final int CENTIM = 12; // modification time
    static final int CENCRC = 16; // uncompressed file crc-32 value
    static final int CENSIZ = 20; // compressed size
    static final int CENLEN = 24; // uncompressed size
    static final int CENNAM = 28; // filename length
    static final int CENEXT = 30; // extra field length
    static final int CENCOM = 32; // comment length
    static final int CENDSK = 34; // disk number start
    static final int CENATT = 36; // internal file attributes
    static final int CENATX = 38; // external file attributes
    static final int CENOFF = 42; // LOC header offset

    /*
     * End of central directory (END) header field offsets
     */
    static final int ENDSUB = 8; // number of entries on this disk
    static final int ENDTOT = 10; // total number of entries
    static final int ENDSIZ = 12; // central directory size in bytes
    static final int ENDOFF = 16; // offset of first CEN header
    static final int ENDCOM = 20; // zip file comment length

    // Pure constant holder; never instantiated.
    private ZipConstants() {}
  }
|
||||||
|
}
|
542
tensorflow/lite/experimental/support/metadata/metadata.py
Normal file
542
tensorflow/lite/experimental/support/metadata/metadata.py
Normal file
@ -0,0 +1,542 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""TensorFlow Lite metadata tools."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import warnings
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
from flatbuffers.python import flatbuffers
|
||||||
|
from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb
|
||||||
|
from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb
|
||||||
|
from tensorflow.python.platform import resource_loader
|
||||||
|
|
||||||
|
# Path to the flatc compiler binary, resolved relative to this file. Used by
# MetadataDisplayer.export_metadata_json_file to convert metadata flatbuffers
# into JSON.
_FLATC_BINARY_PATH = resource_loader.get_path_to_datafile(
    "../../../../../external/flatbuffers/flatc")
# Path to the metadata schema (.fbs) that flatc interprets the metadata with.
_FLATC_TFLITE_METADATA_SCHEMA_FILE = resource_loader.get_path_to_datafile(
    "metadata_schema.fbs")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(b/141467403): add delete method for associated files.
class MetadataPopulator(object):
  """Packs metadata and associated files into TensorFlow Lite model file.

  MetadataPopulator can be used to populate metadata and model associated files
  into a model file or a model buffer (in bytearray). It can also help to
  inspect list of files that have been packed into the model or are supposed to
  be packed into the model.

  The metadata file (or buffer) should be generated based on the metadata
  schema:
  third_party/tensorflow/lite/schema/metadata_schema.fbs

  Example usage:
  Populate metadata and label file into an image classifier model.

  First, based on metadata_schema.fbs, generate the metadata for this image
  classifier model using Flatbuffers API. Attach the label file onto the output
  tensor (the tensor of probabilities) in the metadata.

  Then, pack the metadata and label file into the model as follows.

    ```python
    # Populating a metadata file (or a metadata buffer) and associated files to
    a model file:
    populator = MetadataPopulator.with_model_file(model_file)
    # For metadata buffer (bytearray read from the metadata file), use:
    # populator.load_metadata_buffer(metadata_buf)
    populator.load_metadata_file(metadata_file)
    populator.load_associated_files([label.txt])
    populator.populate()

    # Populating a metadata file (or a metadata buffer) and associated files to
    a model buffer:
    populator = MetadataPopulator.with_model_buffer(model_buf)
    populator.load_metadata_file(metadata_file)
    populator.load_associated_files([label.txt])
    populator.populate()
    # Writing the updated model buffer into a file.
    updated_model_buf = populator.get_model_buffer()
    with open("updated_model.tflite", "wb") as f:
      f.write(updated_model_buf)
    ```
  """
  # As Zip API is used to concatenate associated files after tflite model file,
  # the populating operation is developed based on a model file. For in-memory
  # model buffer, we create a tempfile to serve the populating operation.
  # Creating and deleting such a tempfile is handled by the class,
  # _MetadataPopulatorWithBuffer.

  # Name of the model.metadata entry that holds the metadata flatbuffer.
  METADATA_FIELD_NAME = "TFLITE_METADATA"
  # Flatbuffers file identifier of the TFLite model schema.
  TFLITE_FILE_IDENTIFIER = b"TFL3"
  # Flatbuffers file identifier of the metadata schema (schema version).
  METADATA_FILE_IDENTIFIER = b"M001"

  def __init__(self, model_file):
    """Constructor for MetadataPopulator.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Raises:
      IOError: File not found.
    """
    _assert_exist(model_file)
    self._model_file = model_file
    self._metadata_buf = None
    # Paths of the files to be packed; a set so repeated loads are idempotent.
    self._associated_files = set()

  @classmethod
  def with_model_file(cls, model_file):
    """Creates a MetadataPopulator object that populates data to a model file.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Returns:
      MetadataPopulator object.

    Raises:
      IOError: File not found.
    """
    return cls(model_file)

  # TODO(b/141468993): investigate if type check can be applied to model_buf for
  # FB.
  @classmethod
  def with_model_buffer(cls, model_buf):
    """Creates a MetadataPopulator object that populates data to a model buffer.

    Args:
      model_buf: TensorFlow Lite model buffer in bytearray.

    Returns:
      A MetadataPopulator(_MetadataPopulatorWithBuffer) object.
    """
    return _MetadataPopulatorWithBuffer(model_buf)

  def get_model_buffer(self):
    """Gets the buffer of the model with packed metadata and associated files.

    Returns:
      Model buffer (in bytearray).
    """
    with open(self._model_file, "rb") as f:
      return f.read()

  def get_packed_associated_file_list(self):
    """Gets a list of associated files packed to the model file.

    Returns:
      List of packed associated files.
    """
    # Before anything is packed the model is a plain flatbuffer, not a zip.
    if not zipfile.is_zipfile(self._model_file):
      return []

    with zipfile.ZipFile(self._model_file, "r") as zf:
      return zf.namelist()

  def get_recorded_associated_file_list(self):
    """Gets a list of associated files recorded in metadata of the model file.

    Associated files may be attached to a model, a subgraph, or an input/output
    tensor.

    Returns:
      List of recorded associated files.
    """
    recorded_files = []

    if not self._metadata_buf:
      return recorded_files

    metadata = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
        self._metadata_buf, 0)

    # Add associated files attached to ModelMetadata
    self._get_associated_files_from_metadata_struct(metadata, recorded_files)

    # Add associated files attached to each SubgraphMetadata
    for j in range(metadata.SubgraphMetadataLength()):
      subgraph = metadata.SubgraphMetadata(j)
      self._get_associated_files_from_metadata_struct(subgraph, recorded_files)

      # Add associated files attached to each input tensor
      for k in range(subgraph.InputTensorMetadataLength()):
        tensor = subgraph.InputTensorMetadata(k)
        self._get_associated_files_from_metadata_struct(tensor, recorded_files)

      # Add associated files attached to each output tensor
      for k in range(subgraph.OutputTensorMetadataLength()):
        tensor = subgraph.OutputTensorMetadata(k)
        self._get_associated_files_from_metadata_struct(tensor, recorded_files)

    return recorded_files

  def load_associated_files(self, associated_files):
    """Loads associated files that are to be concatenated after the model file.

    Args:
      associated_files: list of file paths.

    Raises:
      IOError:
        File not found.
    """
    for af in associated_files:
      _assert_exist(af)
      self._associated_files.add(af)

  def load_metadata_buffer(self, metadata_buf):
    """Loads the metadata buffer (in bytearray) to be populated.

    Args:
      metadata_buf: metadata buffer (in bytearray) to be populated.

    Raises:
      ValueError:
        The metadata to be populated is empty.
    """
    if not metadata_buf:
      raise ValueError("The metadata to be populated is empty.")

    self._metadata_buf = metadata_buf

  def load_metadata_file(self, metadata_file):
    """Loads the metadata file to be populated.

    Args:
      metadata_file: path to the metadata file to be populated.

    Raises:
      IOError:
        File not found.
    """
    _assert_exist(metadata_file)
    with open(metadata_file, "rb") as f:
      metadata_buf = f.read()
    self.load_metadata_buffer(bytearray(metadata_buf))

  def populate(self):
    """Populates loaded metadata and associated files into the model file."""
    self._assert_validate()
    # Metadata must go in first: _populate_metadata_buffer rewrites the whole
    # flatbuffer, which would drop files appended by _populate_associated_files
    # if the order were reversed (it re-packs already-packed files instead).
    self._populate_metadata_buffer()
    self._populate_associated_files()

  def _assert_validate(self):
    """Validates the metadata and associated files to be populated.

    Raises:
      ValueError:
        File is recorded in the metadata, but is not going to be populated.
        File has already been packed.
    """
    # Gets files that are recorded in metadata.
    recorded_files = self.get_recorded_associated_file_list()

    # Gets files that have been packed to self._model_file.
    packed_files = self.get_packed_associated_file_list()

    # Gets the file name of those associated files to be populated.
    to_be_populated_files = []
    for af in self._associated_files:
      to_be_populated_files.append(os.path.basename(af))

    # Checks all files recorded in the metadata will be populated.
    for rf in recorded_files:
      if rf not in to_be_populated_files and rf not in packed_files:
        raise ValueError("File, '{0}', is recorded in the metadata, but has "
                         "not been loaded into the populator.".format(rf))

    for f in to_be_populated_files:
      if f in packed_files:
        raise ValueError("File, '{0}', has already been packed.".format(f))

      # Packing an unrecorded file is legal but suspicious; warn, don't fail.
      if f not in recorded_files:
        warnings.warn(
            "File, '{0}', does not exsit in the metadata. But packing it to "
            "tflite model is still allowed.".format(f))

  def _copy_archived_files(self, src_zip, dst_zip, file_list):
    """Copies archived files in file_list from src_zip to dst_zip."""

    if not zipfile.is_zipfile(src_zip):
      raise ValueError("File, '{0}', is not a zipfile.".format(src_zip))

    with zipfile.ZipFile(src_zip,
                         "r") as src_zf, zipfile.ZipFile(dst_zip,
                                                         "a") as dst_zf:
      src_list = src_zf.namelist()
      for f in file_list:
        if f not in src_list:
          raise ValueError(
              "File, '{0}', does not exist in the zipfile, {1}.".format(
                  f, src_zip))
        file_buffer = src_zf.read(f)
        dst_zf.writestr(f, file_buffer)

  def _get_associated_files_from_metadata_struct(self, file_holder, file_list):
    # Appends the names of all AssociatedFile entries of `file_holder` (a
    # ModelMetadata/SubgraphMetadata/TensorMetadata flatbuffer) to `file_list`.
    for j in range(file_holder.AssociatedFilesLength()):
      file_list.append(file_holder.AssociatedFiles(j).Name().decode("utf-8"))

  def _populate_associated_files(self):
    """Concatenates associated files after TensorFlow Lite model file.

    If the MetadataPopulator object is created using the method,
    with_model_file(model_file), the model file will be updated.
    """
    # Opens up the model file in "appending" mode.
    # If self._model_file already has pack files, zipfile will concatenate
    # addition files after self._model_file. For example, suppose we have
    # self._model_file = old_tflite_file | label1.txt | label2.txt
    # Then after trigger populate() to add label3.txt, self._model_file becomes
    # self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt
    with zipfile.ZipFile(self._model_file, "a") as zf:
      for af in self._associated_files:
        filename = os.path.basename(af)
        zf.write(af, filename)

  def _populate_metadata_buffer(self):
    """Populates the metadata buffer (in bytearray) into the model file.

    Inserts metadata_buf into the metadata field of schema.Model. If the
    MetadataPopulator object is created using the method,
    with_model_file(model_file), the model file will be updated.
    """

    with open(self._model_file, "rb") as f:
      model_buf = f.read()

    # Unpacks the model into a mutable object tree so fields can be edited.
    model = _schema_fb.ModelT.InitFromObj(
        _schema_fb.Model.GetRootAsModel(model_buf, 0))
    buffer_field = _schema_fb.BufferT()
    buffer_field.data = self._metadata_buf

    is_populated = False
    if not model.metadata:
      model.metadata = []
    else:
      # Check if metadata has already been populated.
      for meta in model.metadata:
        if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME:
          is_populated = True
          # Reuses the existing buffer slot rather than appending a new one.
          model.buffers[meta.buffer] = buffer_field

    if not is_populated:
      if not model.buffers:
        model.buffers = []
      model.buffers.append(buffer_field)
      # Creates a new metadata field pointing at the buffer just appended.
      metadata_field = _schema_fb.MetadataT()
      metadata_field.name = self.METADATA_FIELD_NAME
      metadata_field.buffer = len(model.buffers) - 1
      model.metadata.append(metadata_field)

    # Packs model back to a flatbuffer binary file.
    b = flatbuffers.Builder(0)
    b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER)
    model_buf = b.Output()

    # Saves the updated model buffer to model file.
    # Gets files that have been packed to self._model_file.
    packed_files = self.get_packed_associated_file_list()
    if packed_files:
      # Writes the updated model buffer and associated files into a new model
      # file. Then overwrites the original model file.
      # NOTE(review): grabbing temp.name after the NamedTemporaryFile context
      # closes deletes the file first; reusing the freed name is racy — confirm
      # and consider tempfile.mkstemp instead.
      with tempfile.NamedTemporaryFile() as temp:
        new_file = temp.name
      with open(new_file, "wb") as f:
        f.write(model_buf)
      self._copy_archived_files(self._model_file, new_file, packed_files)
      shutil.copy(new_file, self._model_file)
      os.remove(new_file)
    else:
      with open(self._model_file, "wb") as f:
        f.write(model_buf)
|
||||||
|
|
||||||
|
|
||||||
|
class _MetadataPopulatorWithBuffer(MetadataPopulator):
  """Subclass of MetadataPopulator that populates metadata to a model buffer.

  This class is used to populate metadata into a in-memory model buffer. As we
  use Zip API to concatenate associated files after tflite model file, the
  populating operation is developed based on a model file. For in-memory model
  buffer, we create a tempfile to serve the populating operation. This class is
  then used to generate this tempfile, and delete the file when the
  MetadataPopulator object is deleted.
  """

  def __init__(self, model_buf):
    """Constructor for _MetadataPopulatorWithBuffer.

    Args:
      model_buf: TensorFlow Lite model buffer in bytearray.

    Raises:
      ValueError: model_buf is empty.
    """
    if not model_buf:
      raise ValueError("model_buf cannot be empty.")

    # Creates a closed temporary file that survives until __del__. Using
    # mkstemp (rather than taking NamedTemporaryFile().name after the file has
    # been deleted) avoids the race where another process claims the freed
    # name, and also works on Windows where a NamedTemporaryFile cannot be
    # reopened while held open.
    fd, model_file = tempfile.mkstemp()
    os.close(fd)

    with open(model_file, "wb") as f:
      f.write(model_buf)

    MetadataPopulator.__init__(self, model_file)

  def __del__(self):
    """Destructor of _MetadataPopulatorWithBuffer.

    Deletes the tempfile.
    """
    if os.path.exists(self._model_file):
      os.remove(self._model_file)
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataDisplayer(object):
  """Displays metadata and associated file info in human-readable format."""

  def __init__(self, model_file, metadata_file, associated_file_list):
    """Constructor for MetadataDisplayer.

    Args:
      model_file: valid path to the model file.
      metadata_file: valid path to the metadata file.
      associated_file_list: list of associated files in the model file.
    """
    self._model_file = model_file
    self._metadata_file = metadata_file
    self._associated_file_list = associated_file_list

  @classmethod
  def with_model_file(cls, model_file):
    """Creates a MetadataDisplayer object for the model file.

    Args:
      model_file: valid path to a TensorFlow Lite model file.

    Returns:
      MetadataDisplayer object.

    Raises:
      IOError: File not found.
      ValueError: The model does not have metadata.
    """
    _assert_exist(model_file)
    metadata_file = cls._save_temporary_metadata_file(model_file)
    associated_file_list = cls._parse_packed_associated_file_list(model_file)
    return cls(model_file, metadata_file, associated_file_list)

  def export_metadata_json_file(self, export_dir):
    """Converts the metadata into a json file.

    Args:
      export_dir: the directory that the json file will be exported to. The json
        file will be named after the model file, but with ".json" as extension.
    """
    subprocess.check_call([
        _FLATC_BINARY_PATH, "-o", export_dir, "--json",
        _FLATC_TFLITE_METADATA_SCHEMA_FILE, "--", self._metadata_file,
        "--strict-json"
    ])
    # flatc names its output after the input metadata tempfile; rename it to
    # match the model file as documented above.
    temp_name = os.path.join(
        export_dir,
        os.path.splitext(os.path.basename(self._metadata_file))[0] + ".json")
    expected_name = os.path.join(
        export_dir,
        os.path.splitext(os.path.basename(self._model_file))[0] + ".json")
    os.rename(temp_name, expected_name)

  def get_packed_associated_file_list(self):
    """Returns a list of associated files that are packed in the model.

    Returns:
      A name list of associated files.
    """
    # Deep-copied so callers cannot mutate the displayer's internal state.
    return copy.deepcopy(self._associated_file_list)

  @staticmethod
  def _save_temporary_metadata_file(model_file):
    """Saves the metadata in the model file to a temporary file.

    Args:
      model_file: valid path to the model file.

    Returns:
      Path to the metadata temporary file.

    Raises:
      ValueError: The model does not have metadata.
    """
    with open(model_file, "rb") as f:
      model_buf = f.read()

    tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0)

    # Gets metadata from the model file.
    for i in range(tflite_model.MetadataLength()):
      meta = tflite_model.Metadata(i)
      if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME:
        buffer_index = meta.Buffer()
        metadata = tflite_model.Buffers(buffer_index)
        metadata_buf = metadata.DataAsNumpy().tobytes()
        # Creates a closed temporary file to store the metadata. mkstemp avoids
        # the race of reusing a NamedTemporaryFile name after its deletion.
        fd, metadata_file = tempfile.mkstemp()
        os.close(fd)
        # Saves the metadata into the temporary file.
        with open(metadata_file, "wb") as f:
          f.write(metadata_buf)
        return metadata_file

    raise ValueError("The model does not have metadata.")

  @staticmethod
  def _parse_packed_associated_file_list(model_file):
    """Gets a list of associated files packed to the model file.

    Args:
      model_file: valid path to the model file.

    Returns:
      List of packed associated files.
    """
    # A model without packed files is a plain flatbuffer, not a zip.
    if not zipfile.is_zipfile(model_file):
      return []

    with zipfile.ZipFile(model_file, "r") as zf:
      return zf.namelist()

  def __del__(self):
    """Destructor of MetadataDisplayer.

    Deletes the tempfile.
    """
    if os.path.exists(self._metadata_file):
      os.remove(self._metadata_file)
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_exist(filename):
|
||||||
|
"""Checks if a file exists."""
|
||||||
|
if not os.path.exists(filename):
|
||||||
|
raise IOError("File, '{0}', does not exist.".format(filename))
|
@ -0,0 +1,499 @@
|
|||||||
|
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
namespace tflite;
|
||||||
|
|
||||||
|
// TFLite metadata contains both human readable and machine readable information
|
||||||
|
// about what the model does and how to use the model. It can be used as a
|
||||||
|
// README file, which elaborates the details of the model, each input/output
|
||||||
|
// tensor, and each associated file.
|
||||||
|
//
|
||||||
|
// An important use case of TFLite metadata is the TFLite codegen tool, which
|
||||||
|
// automatically generates the model interface based on the properties of the
|
||||||
|
// model and the tensors. The model interface provides high-level APIs to
|
||||||
|
// interact with the model, such as preprocessing the input data and running
|
||||||
|
// inferences.
|
||||||
|
//
|
||||||
|
// Entries marked with "<Codegen usage>" are used in TFLite codegen tool to
|
||||||
|
// generate the model interface. It is recommended to fill in at least those
|
||||||
|
// entries to boost the codegen performance.
|
||||||
|
|
||||||
|
// This corresponds to the schema version.
|
||||||
|
file_identifier "M001";
|
||||||
|
// File extension of any written files.
|
||||||
|
file_extension "tflitemeta";
|
||||||
|
|
||||||
|
// Type of an AssociatedFile; drives how consumers (e.g. the codegen tool)
// interpret the packed file.
enum AssociatedFileType : byte {
  UNKNOWN = 0,
  // Files such as readme.txt
  DESCRIPTIONS = 1,

  // Contains labels that annotate certain axis of the tensor. For example,
  // the label file in image classification. Those labels annotate the
  // the output tensor, such that each value in the output tensor is the
  // probability of that corresponding category specified by the label.
  //
  // <Codegen usage>:
  // If an output tensor has an associated file as TENSOR_AXIS_LABELS, return
  // the output as a mapping between the labels and probability in the model
  // interface.
  // If multiple files of the same type are present, the first one is used by
  // default; additional ones are to be distinguished from one another by their
  // specified locale.
  TENSOR_AXIS_LABELS = 2,

  // Contains labels that tensor values correspond to. For example, in
  // the object detection model, one of the output tensors is the detected
  // classes. And each value in the tensor refers to the index of label in the
  // category label file.
  //
  // <Codegen usage>:
  // If an output tensor has an associated file as TENSOR_VALUE_LABELS, convert
  // the tensor values into labels, and return a list of string as the output.
  // If multiple files of the same type are present, the first one is used by
  // default; additional ones are to be distinguished from one another by their
  // specified locale.
  TENSOR_VALUE_LABELS = 3,

  // Contains sigmoid-based score calibration parameters, formatted as CSV.
  // Lines contain for each index of an output tensor the scale, slope, offset
  // and min_score parameters to be used for sigmoid fitting (in this order and
  // in `strtof`-compatible [1] format).
  // A line may be left empty to default calibrated scores for this index to
  // default_score. See documentation for ScoreCalibrationOptions for details.
  //
  // [1]: https://en.cppreference.com/w/c/string/byte/strtof
  TENSOR_AXIS_SCORE_CALIBRATION = 4,
}
|
||||||
|
|
||||||
|
// A file packed alongside the model (e.g. label files, calibration data).
table AssociatedFile {
  // Name of this file. Needs to be exactly the same as the name of the actual
  // file packed into the TFLite model as a zip file.
  //
  // <Codegen usage>:
  // Locates the actual file in the TFLite model.
  name:string;

  // A description of what the file is.
  description:string;

  // Type of the associated file. There may be special pre/post processing for
  // some types. For example in image classification, a label file of the output
  // will be used to convert object index into string.
  //
  // <Codegen usage>:
  // Determines how to process the corresponding tensor.
  type:AssociatedFileType;

  // An optional locale for this associated file (if applicable). It is
  // recommended to use an ISO 639-1 letter code (e.g. "en" for English),
  // optionally completed by a two letter region code (e.g. "en-US" for US
  // English and "en-CA" for Canadian English).
  // Leverage this in order to specify e.g. multiple label files translated in
  // different languages.
  locale:string;
}
|
||||||
|
|
||||||
|
// The basic content type for all tensors.
//
// <Codegen usage>:
// Input feature tensors:
// 1. Generates the method to load data from a TensorBuffer.
// 2. Creates the preprocessing logic. The default processing pipeline is:
// [NormalizeOp, QuantizeOp].
// Output feature tensors:
// 1. Generates the method to return the output data to a TensorBuffer.
// 2. Creates the post-processing logic. The default processing pipeline is:
// [DeQuantizeOp].
table FeatureProperties {
}
|
||||||
|
|
||||||
|
// The type of color space of an image.
enum ColorSpaceType : byte {
  UNKNOWN = 0,
  RGB = 1,
  GRAYSCALE = 2,
}
|
||||||
|
|
||||||
|
// Width and height of an image, in pixels.
table ImageSize {
  width:uint;
  height:uint;
}
|
||||||
|
|
||||||
|
// The properties for image tensors.
//
// <Codegen usage>:
// Input image tensors:
// 1. Generates the method to load an image from a TensorImage.
// 2. Creates the preprocessing logic. The default processing pipeline is:
// [ResizeOp, NormalizeOp, QuantizeOp].
// Output image tensors:
// 1. Generates the method to return the output data to a TensorImage.
// 2. Creates the post-processing logic. The default processing pipeline is:
// [DeQuantizeOp].
table ImageProperties {
  // The color space of the image.
  //
  // <Codegen usage>:
  // Determines how to convert the color space of a given image from users.
  color_space:ColorSpaceType;

  // Indicates the default value of image width and height if the tensor shape
  // is dynamic. For fixed-size tensor, this size will be consistent with the
  // expected size.
  default_size:ImageSize;
}
|
||||||
|
|
||||||
|
// The properties for tensors representing bounding boxes.
//
// <Codegen usage>:
// Input image tensors: NA.
// Output image tensors: parses the values into a data structure that represents
// bounding boxes. For example, in the generated wrapper for Android, it returns
// the output as android.graphics.Rect objects.
enum BoundingBoxType : byte {
  UNKNOWN = 0,
  // Represents the bounding box by using the combination of boundaries,
  // {left, top, right, bottom}.
  // The default order is {left, top, right, bottom}. Other orders can be
  // indicated by BoundingBoxProperties.index.
  BOUNDARIES = 1,

  // Represents the bounding box by using the upper_left corner, width and
  // height.
  // The default order is {upper_left_x, upper_left_y, width, height}. Other
  // orders can be indicated by BoundingBoxProperties.index.
  UPPER_LEFT = 2,

  // Represents the bounding box by using the center of the box, width and
  // height. The default order is {center_x, center_y, width, height}. Other
  // orders can be indicated by BoundingBoxProperties.index.
  CENTER = 3,

}
|
||||||
|
|
||||||
|
// How bounding-box coordinate values relate to the original image dimensions.
enum CoordinateType : byte {
  // The coordinates are float values from 0 to 1.
  RATIO = 0,
  // The coordinates are integers.
  PIXEL = 1,
}
|
||||||
|
|
||||||
|
// Describes how bounding-box values are laid out in a tensor.
table BoundingBoxProperties {
  // Denotes the order of the elements defined in each bounding box type. An
  // empty index array represents the default order of each bounding box type.
  // For example, to denote the default order of BOUNDARIES, {left, top, right,
  // bottom}, the index should be {0, 1, 2, 3}. To denote the order {left,
  // right, top, bottom}, the order should be {0, 2, 1, 3}.
  //
  // The index array can be applied to all bounding box types to adjust the
  // order of their corresponding underlying elements.
  //
  // <Codegen usage>:
  // Indicates how to parse the bounding box values.
  index:[uint];

  // <Codegen usage>:
  // Indicates how to parse the bounding box values.
  type:BoundingBoxType;

  // <Codegen usage>:
  // Indicates how to convert the bounding box back to the original image in
  // pixels.
  coordinate_type:CoordinateType;
}
|
||||||
|
|
||||||
|
// The semantic type of a tensor's content; exactly one variant applies.
union ContentProperties {
  FeatureProperties,
  ImageProperties,
  BoundingBoxProperties,
}
|
||||||
|
|
||||||
|
// An inclusive [min, max] range of dimension indices (see Content.range).
table ValueRange {
  min:int;
  max:int;
}
|
||||||
|
|
||||||
|
// Semantic description of what a tensor (or a sub-range of its dimensions)
// contains.
table Content {
  // The properties that the content may have, indicating the type of the
  // Content.
  //
  // <Codegen usage>:
  // Indicates how to process the tensor.
  content_properties:ContentProperties;

  // The range of dimensions that the content corresponds to. A NULL
  // "range" indicates that the content uses up all dimensions,
  // except the batch axis if applied.
  //
  // Here are all the possible situations of how a tensor is composed.
  // Case 1: The tensor is a single object, such as an image.
  // For example, the input of an image classifier
  // (https://www.tensorflow.org/lite/models/image_classification/overview),
  // a tensor of shape [1, 224, 224, 3]. Dimensions 1 to 3 correspond to the
  // image. Since dimension 0 is a batch axis, which can be ignored,
  // "range" can be left as NULL.
  //
  // Case 2: The tensor contains multiple instances of the same object.
  // For example, the output tensor of detected bounding boxes of an object
  // detection model
  // (https://www.tensorflow.org/lite/models/object_detection/overview).
  // The tensor shape is [1, 10, 4]. Here is what the three dimensions
  // represent for:
  // dimension 0: the batch axis.
  // dimension 1: the 10 objects detected with the highest confidence.
  // dimension 2: the bounding boxes of the 10 detected objects.
  // The tensor is essentially 10 bounding boxes. In this case,
  // "range" should be {min=2; max=2;}.
  // Another example is the pose estimation model
  // (https://www.tensorflow.org/lite/models/pose_estimation/overview).
  // The output tensor of heatmaps is in the shape of [1, 9, 9, 17].
  // Here is what the four dimensions represent for:
  // dimension 0: the batch axis.
  // dimension 1/2: the heatmap image.
  // dimension 3: 17 body parts of a person.
  // Even though the last axis is body part, the real content of this tensor is
  // the heatmap. "range" should be [min=2; max=3].
  //
  // Case 3: The tensor contains multiple different objects. (Not supported by
  // Content at this point).
  // Sometimes a tensor may contain multiple different objects, thus different
  // contents. It is very common for regression models. For example, a model
  // to predict the fuel efficiency
  // (https://www.tensorflow.org/tutorials/keras/regression).
  // The input tensor has shape [1, 9], consisting of 9 features, such as
  // "Cylinders", "Displacement", "Weight", etc. In this case, dimension 1
  // contains 9 different contents. However, since these sub-dimension objects
  // barely need to be specifically processed, their contents are not recorded
  // in the metadata. Though, the name of each dimension can be set through
  // TensorMetadata.dimension_names.
  //
  // Note that if it is not case 3, a tensor can only have one content type.
  //
  // <Codegen usage>:
  // Case 1: return a processed single object of certain content type.
  // Case 2: return a list of processed objects of certain content type. The
  // generated model interface has APIs to randomly access those objects from
  // the output.
  range:ValueRange;
}
|
||||||
|
|
||||||
|
// Parameters that are used when normalizing the tensor.
table NormalizationOptions{
  // mean and std are normalization parameters. Tensor values are normalized
  // per channel by,
  // (x - mean) / std.
  // For example, a float MobileNet model will have
  // mean = 127.5f and std = 127.5f.
  // A quantized MobileNet model will have
  // mean = 0.0f and std = 1.0f.
  // If there is only one value in mean or std, we'll propagate the value to
  // all channels.

  // Per-channel mean of the possible values used in normalization.
  //
  // <Codegen usage>:
  // Apply normalization to input tensors accordingly.
  mean:[float];

  // Per-channel standard dev. of the possible values used in normalization.
  //
  // <Codegen usage>:
  // Apply normalization to input tensors accordingly.
  std:[float];
}
|
||||||
|
|
||||||
|
// The different possible score transforms to apply to uncalibrated scores
// before applying score calibration.
enum ScoreTransformationType : byte {
  // Identity function: g(x) = x.
  IDENTITY = 0,
  // Log function: g(x) = log(x).
  LOG = 1,
  // Inverse logistic function: g(x) = log(x) - log(1-x).
  INVERSE_LOGISTIC = 2,
}
|
||||||
|
|
||||||
|
// Options to perform score calibration on an output tensor through sigmoid
// functions. One of the main purposes of score calibration is to make scores
// across classes comparable, so that a common threshold can be used for all
// output classes. This is meant for models producing class predictions as
// output, e.g. image classification or detection models.
//
// For each index in the output tensor, this applies:
// * `f(x) = scale / (1 + e^-(slope*g(x)+offset))` if `x > min_score`,
// * `f(x) = default_score` otherwise or if no scale, slope, offset and
//   min_score have been specified.
// Where:
// * scale, slope, offset and min_score are index-specific parameters
// * g(x) is an index-independent transform among those defined in
//   ScoreTransformationType
// * default_score is an index-independent parameter.
// An AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION specifying the
// index-specific parameters must be associated with the corresponding
// TensorMetadata for score calibration to be applied.
table ScoreCalibrationOptions {
  // The function to use for transforming the uncalibrated score before
  // applying score calibration.
  score_transformation:ScoreTransformationType;

  // The default calibrated score to apply if the uncalibrated score is
  // below min_score or if no parameters were specified for a given index.
  default_score:float;
}
|
||||||
|
|
||||||
|
// Performs thresholding on output tensor values, in order to filter out
// low-confidence results.
table ScoreThresholdingOptions {
  // The recommended global threshold below which results are considered
  // low-confidence and should be filtered out.
  global_score_threshold:float;
}
|
||||||
|
|
||||||
|
// Options that are used when processing the tensor.
union ProcessUnitOptions {
  NormalizationOptions,
  ScoreCalibrationOptions,
  ScoreThresholdingOptions,
}
|
||||||
|
|
||||||
|
// A process unit that is used to process the tensor out-of-graph.
table ProcessUnit {
  options:ProcessUnitOptions;
}
|
||||||
|
|
||||||
|
|
||||||
|
// Statistics to describe a tensor.
table Stats {
  // Max and min are not currently used in tflite.support codegen. They mainly
  // serve as references for users to better understand the model. They can also
  // be used to validate model pre/post processing results.
  // If there is only one value in max or min, we'll propagate the value to
  // all channels.

  // Per-channel maximum value of the tensor.
  max:[float];

  // Per-channel minimum value of the tensor.
  min:[float];
}
|
||||||
|
|
||||||
|
// Detailed information of an input or output tensor.
table TensorMetadata {
  // Name of the tensor.
  //
  // <Codegen usage>:
  // The name of this tensor in the generated model interface.
  name:string;

  // A description of the tensor.
  description:string;

  // A list of names of the dimensions in this tensor. The length of
  // dimension_names needs to match the number of dimensions in this tensor.
  //
  // <Codegen usage>:
  // The name of each dimension in the generated model interface. See "Case 2"
  // in the comments of Content.range.
  dimension_names:[string];

  // The content that represents this tensor.
  //
  // <Codegen usage>:
  // Determines how to process this tensor. See each item in ContentProperties
  // for the default process units that will be applied to the tensor.
  content:Content;

  // The process units that are used to process the tensor out-of-graph.
  //
  // <Codegen usage>:
  // Contains the parameters of the default processing pipeline for each content
  // type, such as the normalization parameters in all content types. See the
  // items under ContentProperties for the details of the default processing
  // pipeline.
  process_units:[ProcessUnit];

  // The statistics of the tensor values.
  stats:Stats;

  // A list of associated files of this tensor.
  //
  // <Codegen usage>:
  // Contains processing parameters of this tensor, such as normalization.
  associated_files:[AssociatedFile];
}
|
||||||
|
|
||||||
|
// Metadata for a single subgraph of the model.
table SubGraphMetadata {
  // Name of the subgraph.
  //
  // Note that, since TFLite only supports one subgraph at this moment, the
  // Codegen tool will use the name in ModelMetadata in the generated model
  // interface.
  name:string;

  // A description that explains details about what the subgraph does.
  description:string;

  // Metadata of all input tensors used in this subgraph.
  //
  // <Codegen usage>:
  // Determines how to process the inputs.
  input_tensor_metadata:[TensorMetadata];

  // Metadata of all output tensors used in this subgraph.
  //
  // <Codegen usage>:
  // Determines how to process the outputs.
  output_tensor_metadata:[TensorMetadata];

  // A list of associated files of this subgraph.
  associated_files:[AssociatedFile];
}
|
||||||
|
|
||||||
|
// Top-level metadata describing the whole model.
table ModelMetadata {
  // Name of the model.
  //
  // <Codegen usage>:
  // The name of the model in the generated model interface.
  name:string;

  // Model description in schema.
  description:string;

  // Version of the model that is specified by model creators.
  version:string;

  // Note that, the minimum required TFLite runtime version that the model is
  // compatible with, has already been added as a metadata entry in tflite
  // schema. We'll decide later if we want to move it here, and keep it with
  // other metadata entries.

  // Metadata of all the subgraphs of the model. The 0th is assumed to be the
  // main subgraph.
  //
  // <Codegen usage>:
  // Determines how to process the inputs and outputs.
  subgraph_metadata:[SubGraphMetadata];

  // The person who created this model.
  author:string;

  // Licenses that may apply to this model.
  license:string;

  // A list of associated files of this model.
  associated_files:[AssociatedFile];
}
|
||||||
|
|
||||||
|
// The buffer root is a ModelMetadata table.
root_type ModelMetadata;
|
381
tensorflow/lite/experimental/support/metadata/metadata_test.py
Normal file
381
tensorflow/lite/experimental/support/metadata/metadata_test.py
Normal file
@ -0,0 +1,381 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tensorflow.lite.experimental.support.metadata.metadata."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
from flatbuffers.python import flatbuffers
|
||||||
|
from tensorflow.lite.experimental.support.metadata import metadata as _metadata
|
||||||
|
from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb
|
||||||
|
from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.platform import resource_loader
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataTest(test_util.TensorFlowTestCase):
  """Base fixture: builds temp model files, metadata, and associated files."""

  def setUp(self):
    super(MetadataTest, self).setUp()
    # Inputs that should be rejected by the populator factories.
    self._invalid_model_buf = None
    self._invalid_file = "not_existed_file"
    # A serialized TFLite model with no metadata/buffers, in memory and on disk.
    self._empty_model_buf = self._create_empty_model_buf()
    self._empty_model_file = self.create_tempfile().full_path
    with open(self._empty_model_file, "wb") as f:
      f.write(self._empty_model_buf)
    # A model file that already carries metadata and buffer entries.
    self._model_file = self._create_model_file_with_metadata_and_buf_fields()
    self._metadata_file = self._create_metadata_file()
    # Dummy associated files to be packed into models.
    self._file1 = self.create_tempfile("file1").full_path
    self._file2 = self.create_tempfile("file2").full_path
    self._file3 = self.create_tempfile("file3").full_path

  def _create_empty_model_buf(self):
    """Returns a serialized TFLite model containing no metadata or buffers."""
    model = _schema_fb.ModelT()
    model_builder = flatbuffers.Builder(0)
    # Finish with the TFLite file identifier so the buffer is a valid model.
    model_builder.Finish(
        model.Pack(model_builder),
        _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
    return model_builder.Output()

  def _create_model_file_with_metadata_and_buf_fields(self):
    """Writes a model with 2 pre-existing metadata entries and 3 buffers.

    Returns:
      Path of the written temp model file.
    """
    metadata_field = _schema_fb.MetadataT()
    metadata_field.name = "meta"
    buffer_field = _schema_fb.BufferT()
    model = _schema_fb.ModelT()
    # Reusing the same objects is fine; flatbuffers serializes them by value.
    model.metadata = [metadata_field, metadata_field]
    model.buffers = [buffer_field, buffer_field, buffer_field]
    model_builder = flatbuffers.Builder(0)
    model_builder.Finish(
        model.Pack(model_builder),
        _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)

    mnodel_file = self.create_tempfile().full_path
    with open(mnodel_file, "wb") as f:
      f.write(model_builder.Output())

    return mnodel_file

  def _create_metadata_file(self):
    """Writes a golden metadata flatbuffer referencing file1 and file2.

    Also records the expected associated-file names in
    ``self.expected_recorded_files`` for later assertions.

    Returns:
      Path of the written temp metadata file.
    """
    associated_file1 = _metadata_fb.AssociatedFileT()
    associated_file1.name = b"file1"
    associated_file2 = _metadata_fb.AssociatedFileT()
    associated_file2.name = b"file2"
    self.expected_recorded_files = [
        six.ensure_str(associated_file1.name),
        six.ensure_str(associated_file2.name)
    ]

    # file2 is attached to an output tensor; file1 to the model itself.
    output_meta = _metadata_fb.TensorMetadataT()
    output_meta.associatedFiles = [associated_file2]
    subgraph = _metadata_fb.SubGraphMetadataT()
    subgraph.outputTensorMetadata = [output_meta]

    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.name = "Mobilenet_quantized"
    model_meta.associatedFiles = [associated_file1]
    model_meta.subgraphMetadata = [subgraph]
    b = flatbuffers.Builder(0)
    b.Finish(
        model_meta.Pack(b),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)

    metadata_file = self.create_tempfile().full_path
    with open(metadata_file, "wb") as f:
      f.write(b.Output())
    return metadata_file
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataPopulatorTest(MetadataTest):
|
||||||
|
|
||||||
|
def testToValidModelFile(self):
  """with_model_file on a valid model path yields a MetadataPopulator."""
  result = _metadata.MetadataPopulator.with_model_file(self._empty_model_file)
  self.assertIsInstance(result, _metadata.MetadataPopulator)
|
||||||
|
|
||||||
|
def testToInvalidModelFile(self):
  """A non-existent model path raises IOError with a descriptive message."""
  with self.assertRaises(IOError) as error:
    _metadata.MetadataPopulator.with_model_file(self._invalid_file)
  expected_message = "File, '{0}', does not exist.".format(self._invalid_file)
  self.assertEqual(expected_message, str(error.exception))
|
||||||
|
|
||||||
|
def testToValidModelBuffer(self):
  """with_model_buffer on a valid buffer yields a MetadataPopulator."""
  result = _metadata.MetadataPopulator.with_model_buffer(self._empty_model_buf)
  self.assertIsInstance(result, _metadata.MetadataPopulator)
|
||||||
|
|
||||||
|
def testToInvalidModelBuffer(self):
  """An empty/None model buffer raises ValueError."""
  with self.assertRaises(ValueError) as error:
    _metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf)
  self.assertEqual("model_buf cannot be empty.", str(error.exception))
|
||||||
|
|
||||||
|
def testSinglePopulateAssociatedFile(self):
  """A single loaded file appears in the packed-file listing after populate."""
  populator = _metadata.MetadataPopulator.with_model_buffer(
      self._empty_model_buf)
  populator.load_associated_files([self._file1])
  populator.populate()

  self.assertEqual(
      set(populator.get_packed_associated_file_list()),
      {os.path.basename(self._file1)})
|
||||||
|
|
||||||
|
def testRepeatedPopulateAssociatedFile(self):
  """Loading the same file twice packs it only once."""
  populator = _metadata.MetadataPopulator.with_model_file(
      self._empty_model_file)
  populator.load_associated_files([self._file1, self._file2])
  # Loads file2 multiple times.
  populator.load_associated_files([self._file2])
  populator.populate()

  packed_files = populator.get_packed_associated_file_list()
  expected_packed_files = [
      os.path.basename(self._file1),
      os.path.basename(self._file2)
  ]
  # Exactly two files should have been packed, despite the repeated load.
  self.assertEqual(len(packed_files), 2)
  self.assertEqual(set(packed_files), set(expected_packed_files))

  # Check if the model buffer read from file is the same as that read from
  # get_model_buffer().
  with open(self._empty_model_file, "rb") as f:
    model_buf_from_file = f.read()
  model_buf_from_getter = populator.get_model_buffer()
  self.assertEqual(model_buf_from_file, model_buf_from_getter)
|
||||||
|
|
||||||
|
def testPopulateInvalidAssociatedFile(self):
  """Loading a non-existent associated file raises IOError."""
  populator = _metadata.MetadataPopulator.with_model_buffer(
      self._empty_model_buf)
  with self.assertRaises(IOError) as error:
    populator.load_associated_files([self._invalid_file])
  expected_message = "File, '{0}', does not exist.".format(self._invalid_file)
  self.assertEqual(expected_message, str(error.exception))
|
||||||
|
|
||||||
|
def testPopulatePackedAssociatedFile(self):
  """Populating a file that is already packed raises ValueError."""
  populator = _metadata.MetadataPopulator.with_model_buffer(
      self._empty_model_buf)
  populator.load_associated_files([self._file1])
  populator.populate()
  # Second populate of the same file must be rejected.
  with self.assertRaises(ValueError) as error:
    populator.load_associated_files([self._file1])
    populator.populate()
  self.assertEqual(
      "File, '{0}', has already been packed.".format(
          os.path.basename(self._file1)), str(error.exception))
|
||||||
|
|
||||||
|
def testGetPackedAssociatedFileList(self):
  """A fresh populator reports an empty packed-file list."""
  populator = _metadata.MetadataPopulator.with_model_buffer(
      self._empty_model_buf)
  self.assertEqual([], populator.get_packed_associated_file_list())
|
||||||
|
|
||||||
|
def testPopulateMetadataFileToEmptyModelFile(self):
  """Populates metadata and associated files into a metadata-free model."""
  populator = _metadata.MetadataPopulator.with_model_file(
      self._empty_model_file)
  populator.load_metadata_file(self._metadata_file)
  populator.load_associated_files([self._file1, self._file2])
  populator.populate()

  with open(self._empty_model_file, "rb") as f:
    model_buf_from_file = f.read()
  model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
  # The model had no metadata, so ours lands at index 0 under the
  # well-known metadata field name.
  metadata_field = model.Metadata(0)
  self.assertEqual(
      six.ensure_str(metadata_field.Name()),
      six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))

  # The metadata buffer packed into the model must equal the golden file.
  buffer_index = metadata_field.Buffer()
  buffer_data = model.Buffers(buffer_index)
  metadata_buf_np = buffer_data.DataAsNumpy()
  metadata_buf = metadata_buf_np.tobytes()
  with open(self._metadata_file, "rb") as f:
    expected_metadata_buf = bytearray(f.read())
  self.assertEqual(metadata_buf, expected_metadata_buf)

  recorded_files = populator.get_recorded_associated_file_list()
  self.assertEqual(set(recorded_files), set(self.expected_recorded_files))

  # Up to now, we've proved the correctness of the model buffer that was read
  # from file. Then we'll test if get_model_buffer() gives the same buffer.
  model_buf_from_getter = populator.get_model_buffer()
  self.assertEqual(model_buf_from_file, model_buf_from_getter)
|
||||||
|
|
||||||
|
def testPopulateMetadataFileWithoutAssociatedFiles(self):
  """Populate fails if a file recorded in the metadata was never loaded."""
  populator = _metadata.MetadataPopulator.with_model_file(
      self._empty_model_file)
  populator.load_metadata_file(self._metadata_file)
  populator.load_associated_files([self._file1])
  # Supposed to also populate self._file2, because it is recorded in the
  # metadata.
  with self.assertRaises(ValueError) as error:
    populator.populate()
  self.assertEqual(("File, '{0}', is recorded in the metadata, but has "
                    "not been loaded into the populator.").format(
                        os.path.basename(self._file2)), str(error.exception))
|
||||||
|
|
||||||
|
def _assert_golden_metadata(self, model_file):
  """Asserts model_file carries the golden metadata as its third entry.

  Args:
    model_file: path of a model that started with two metadata entries
      (see _create_model_file_with_metadata_and_buf_fields).
  """
  with open(model_file, "rb") as f:
    model_buf_from_file = f.read()
  model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
  # There are two elements in model.Metadata array before the population.
  # Metadata should be packed to the third element in the array.
  metadata_field = model.Metadata(2)
  self.assertEqual(
      six.ensure_str(metadata_field.Name()),
      six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))

  # The packed metadata buffer must be byte-identical to the golden file.
  buffer_index = metadata_field.Buffer()
  buffer_data = model.Buffers(buffer_index)
  metadata_buf_np = buffer_data.DataAsNumpy()
  metadata_buf = metadata_buf_np.tobytes()
  with open(self._metadata_file, "rb") as f:
    expected_metadata_buf = bytearray(f.read())
  self.assertEqual(metadata_buf, expected_metadata_buf)
|
||||||
|
|
||||||
|
def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self):
|
||||||
|
# First, creates a dummy metadata. Populates it and the associated files
|
||||||
|
# into the model.
|
||||||
|
model_meta = _metadata_fb.ModelMetadataT()
|
||||||
|
model_meta.name = "Mobilenet_quantized"
|
||||||
|
b = flatbuffers.Builder(0)
|
||||||
|
b.Finish(
|
||||||
|
model_meta.Pack(b),
|
||||||
|
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||||
|
metadata_buf = b.Output()
|
||||||
|
|
||||||
|
populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||||
|
populator1.load_metadata_buffer(metadata_buf)
|
||||||
|
populator1.load_associated_files([self._file1, self._file2])
|
||||||
|
populator1.populate()
|
||||||
|
|
||||||
|
# Then, populates the metadata again.
|
||||||
|
populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||||
|
populator2.load_metadata_file(self._metadata_file)
|
||||||
|
populator2.populate()
|
||||||
|
|
||||||
|
# Tests if the metadata is populated correctly.
|
||||||
|
self._assert_golden_metadata(self._model_file)
|
||||||
|
|
||||||
|
def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self):
|
||||||
|
populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
|
||||||
|
populator.load_metadata_file(self._metadata_file)
|
||||||
|
populator.load_associated_files([self._file1, self._file2])
|
||||||
|
populator.populate()
|
||||||
|
|
||||||
|
# Tests if the metadata is populated correctly.
|
||||||
|
self._assert_golden_metadata(self._model_file)
|
||||||
|
|
||||||
|
recorded_files = populator.get_recorded_associated_file_list()
|
||||||
|
self.assertEqual(set(recorded_files), set(self.expected_recorded_files))
|
||||||
|
|
||||||
|
# Up to now, we've proved the correctness of the model buffer that read from
|
||||||
|
# file. Then we'll test if get_model_buffer() gives the same model buffer.
|
||||||
|
with open(self._model_file, "rb") as f:
|
||||||
|
model_buf_from_file = f.read()
|
||||||
|
model_buf_from_getter = populator.get_model_buffer()
|
||||||
|
self.assertEqual(model_buf_from_file, model_buf_from_getter)
|
||||||
|
|
||||||
|
def testPopulateInvalidMetadataFile(self):
|
||||||
|
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||||
|
self._empty_model_buf)
|
||||||
|
with self.assertRaises(IOError) as error:
|
||||||
|
populator.load_metadata_file(self._invalid_file)
|
||||||
|
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
|
||||||
|
str(error.exception))
|
||||||
|
|
||||||
|
def testPopulateInvalidMetadataBuffer(self):
|
||||||
|
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||||
|
self._empty_model_buf)
|
||||||
|
with self.assertRaises(ValueError) as error:
|
||||||
|
populator.load_metadata_buffer([])
|
||||||
|
self.assertEqual("The metadata to be populated is empty.",
|
||||||
|
str(error.exception))
|
||||||
|
|
||||||
|
def testGetModelBufferBeforePopulatingData(self):
|
||||||
|
populator = _metadata.MetadataPopulator.with_model_buffer(
|
||||||
|
self._empty_model_buf)
|
||||||
|
model_buf = populator.get_model_buffer()
|
||||||
|
expected_model_buf = self._empty_model_buf
|
||||||
|
self.assertEqual(model_buf, expected_model_buf)
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataDisplayerTest(MetadataTest):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(MetadataDisplayerTest, self).setUp()
|
||||||
|
self._model_file = self._create_model_with_metadata_and_associated_files()
|
||||||
|
|
||||||
|
def _create_model_with_metadata_and_associated_files(self):
|
||||||
|
model_buf = self._create_empty_model_buf()
|
||||||
|
model_file = self.create_tempfile().full_path
|
||||||
|
with open(model_file, "wb") as f:
|
||||||
|
f.write(model_buf)
|
||||||
|
|
||||||
|
populator = _metadata.MetadataPopulator.with_model_file(model_file)
|
||||||
|
populator.load_metadata_file(self._metadata_file)
|
||||||
|
populator.load_associated_files([self._file1, self._file2])
|
||||||
|
populator.populate()
|
||||||
|
return model_file
|
||||||
|
|
||||||
|
def test_load_model_file_invalidModelFile_throwsException(self):
|
||||||
|
with self.assertRaises(IOError) as error:
|
||||||
|
_metadata.MetadataDisplayer.with_model_file(self._invalid_file)
|
||||||
|
self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
|
||||||
|
str(error.exception))
|
||||||
|
|
||||||
|
def test_load_model_file_modelWithoutMetadata_throwsException(self):
|
||||||
|
with self.assertRaises(ValueError) as error:
|
||||||
|
_metadata.MetadataDisplayer.with_model_file(self._empty_model_file)
|
||||||
|
self.assertEqual("The model does not have metadata.", str(error.exception))
|
||||||
|
|
||||||
|
def test_load_model_file_modelWithMetadata(self):
|
||||||
|
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
|
||||||
|
self.assertIsInstance(displayer, _metadata.MetadataDisplayer)
|
||||||
|
|
||||||
|
def test_export_metadata_json_file_modelWithMetadata(self):
|
||||||
|
export_dir = self.create_tempdir().full_path
|
||||||
|
|
||||||
|
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
|
||||||
|
displayer.export_metadata_json_file(export_dir)
|
||||||
|
|
||||||
|
# Verifies the generated json file.
|
||||||
|
golden_json_file_path = resource_loader.get_path_to_datafile(
|
||||||
|
"testdata/golden_json.json")
|
||||||
|
json_file_path = os.path.join(
|
||||||
|
export_dir,
|
||||||
|
os.path.splitext(os.path.basename(self._model_file))[0] + ".json")
|
||||||
|
with open(json_file_path, "r") as json_file, open(golden_json_file_path,
|
||||||
|
"r") as golden_json_file:
|
||||||
|
json_contents = json_file.read()
|
||||||
|
golden_json_contents = golden_json_file.read()
|
||||||
|
self.assertEqual(json_contents, golden_json_contents)
|
||||||
|
|
||||||
|
def test_get_packed_associated_file_list_modelWithMetadata(self):
|
||||||
|
displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file)
|
||||||
|
packed_files = displayer.get_packed_associated_file_list()
|
||||||
|
|
||||||
|
expected_packed_files = [
|
||||||
|
os.path.basename(self._file1),
|
||||||
|
os.path.basename(self._file2)
|
||||||
|
]
|
||||||
|
self.assertEqual(len(packed_files), 2)
|
||||||
|
self.assertEqual(set(packed_files), set(expected_packed_files))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
21
tensorflow/lite/experimental/support/metadata/testdata/golden_json.json
vendored
Normal file
21
tensorflow/lite/experimental/support/metadata/testdata/golden_json.json
vendored
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"name": "Mobilenet_quantized",
|
||||||
|
"subgraph_metadata": [
|
||||||
|
{
|
||||||
|
"output_tensor_metadata": [
|
||||||
|
{
|
||||||
|
"associated_files": [
|
||||||
|
{
|
||||||
|
"name": "file2"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"associated_files": [
|
||||||
|
{
|
||||||
|
"name": "file1"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
@ -206,4 +206,12 @@ cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_binary(
|
||||||
|
name = "zip_files",
|
||||||
|
srcs = ["zip_files.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = ["@absl_py//absl:app"],
|
||||||
|
)
|
||||||
|
|
||||||
tflite_portable_test_suite()
|
tflite_portable_test_suite()
|
||||||
|
41
tensorflow/lite/tools/zip_files.py
Normal file
41
tensorflow/lite/tools/zip_files.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
# Lint as: python3
|
||||||
|
"""Creates a zip package of the files passed in."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_string("export_zip_path", None, "Path to zip file.")
|
||||||
|
flags.DEFINE_string("file_directory", None, "Path to the files to be zipped.")
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
with zipfile.ZipFile(FLAGS.export_zip_path, mode="w") as zf:
|
||||||
|
for root, _, files in os.walk(FLAGS.file_directory):
|
||||||
|
for f in files:
|
||||||
|
if f.endswith(".java"):
|
||||||
|
zf.write(os.path.join(root, f))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app.run(main)
|
21
third_party/flatbuffers/BUILD.bazel
vendored
21
third_party/flatbuffers/BUILD.bazel
vendored
@ -1,3 +1,7 @@
|
|||||||
|
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||||
|
|
||||||
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
exports_files(["LICENSE.txt"])
|
exports_files(["LICENSE.txt"])
|
||||||
@ -116,3 +120,20 @@ py_library(
|
|||||||
srcs = [":runtime_py_srcs"],
|
srcs = [":runtime_py_srcs"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "runtime_java_srcs",
|
||||||
|
srcs = glob(["java/com/google/flatbuffers/**/*.java"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
java_library(
|
||||||
|
name = "runtime_java",
|
||||||
|
srcs = [":runtime_java_srcs"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "runtime_android",
|
||||||
|
srcs = [":runtime_java_srcs"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
187
third_party/flatbuffers/build_defs.bzl
vendored
187
third_party/flatbuffers/build_defs.bzl
vendored
@ -1,6 +1,15 @@
|
|||||||
"""BUILD rules for generating flatbuffer files."""
|
"""BUILD rules for generating flatbuffer files."""
|
||||||
|
|
||||||
|
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||||
|
|
||||||
flatc_path = "@flatbuffers//:flatc"
|
flatc_path = "@flatbuffers//:flatc"
|
||||||
|
zip_files = "//tensorflow/lite/tools:zip_files"
|
||||||
|
|
||||||
|
DEFAULT_INCLUDE_PATHS = [
|
||||||
|
"./",
|
||||||
|
"$(GENDIR)",
|
||||||
|
"$(BINDIR)",
|
||||||
|
]
|
||||||
|
|
||||||
DEFAULT_FLATC_ARGS = [
|
DEFAULT_FLATC_ARGS = [
|
||||||
"--no-union-value-namespacing",
|
"--no-union-value-namespacing",
|
||||||
@ -422,3 +431,181 @@ def flatbuffer_py_library(
|
|||||||
"@flatbuffers//:runtime_py",
|
"@flatbuffers//:runtime_py",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def flatbuffer_java_library(
|
||||||
|
name,
|
||||||
|
srcs,
|
||||||
|
custom_package = "",
|
||||||
|
package_prefix = "",
|
||||||
|
include_paths = DEFAULT_INCLUDE_PATHS,
|
||||||
|
flatc_args = DEFAULT_FLATC_ARGS,
|
||||||
|
visibility = None):
|
||||||
|
"""A java library with the generated reader/writers for the given flatbuffer definitions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Rule name. (required)
|
||||||
|
srcs: List of source .fbs files including all includes. (required)
|
||||||
|
custom_package: Package name of generated Java files. If not specified
|
||||||
|
namespace in the schema files will be used. (optional)
|
||||||
|
package_prefix: like custom_package, but prefixes to the existing
|
||||||
|
namespace. (optional)
|
||||||
|
include_paths: List of paths that includes files can be found in. (optional)
|
||||||
|
flatc_args: List of additional arguments to pass to flatc. (optional)
|
||||||
|
visibility: Visibility setting for the java_library rule. (optional)
|
||||||
|
"""
|
||||||
|
out_srcjar = "java_%s_all.srcjar" % name
|
||||||
|
flatbuffer_java_srcjar(
|
||||||
|
name = "%s_srcjar" % name,
|
||||||
|
srcs = srcs,
|
||||||
|
out = out_srcjar,
|
||||||
|
custom_package = custom_package,
|
||||||
|
flatc_args = flatc_args,
|
||||||
|
include_paths = include_paths,
|
||||||
|
package_prefix = package_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
native.filegroup(
|
||||||
|
name = "%s.srcjar" % name,
|
||||||
|
srcs = [out_srcjar],
|
||||||
|
)
|
||||||
|
|
||||||
|
native.java_library(
|
||||||
|
name = name,
|
||||||
|
srcs = [out_srcjar],
|
||||||
|
deps = [
|
||||||
|
"@flatbuffers//:runtime_java",
|
||||||
|
],
|
||||||
|
visibility = visibility,
|
||||||
|
)
|
||||||
|
|
||||||
|
def flatbuffer_java_srcjar(
|
||||||
|
name,
|
||||||
|
srcs,
|
||||||
|
out,
|
||||||
|
custom_package = "",
|
||||||
|
package_prefix = "",
|
||||||
|
include_paths = DEFAULT_INCLUDE_PATHS,
|
||||||
|
flatc_args = DEFAULT_FLATC_ARGS):
|
||||||
|
"""Generate flatbuffer Java source files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Rule name. (required)
|
||||||
|
srcs: List of source .fbs files including all includes. (required)
|
||||||
|
out: Output file name. (required)
|
||||||
|
custom_package: Package name of generated Java files. If not specified
|
||||||
|
namespace in the schema files will be used. (optional)
|
||||||
|
package_prefix: like custom_package, but prefixes to the existing
|
||||||
|
namespace. (optional)
|
||||||
|
include_paths: List of paths that includes files can be found in. (optional)
|
||||||
|
flatc_args: List of additional arguments to pass to flatc. (optional)
|
||||||
|
"""
|
||||||
|
command_fmt = """set -e
|
||||||
|
tmpdir=$(@D)
|
||||||
|
schemas=$$tmpdir/schemas
|
||||||
|
java_root=$$tmpdir/java
|
||||||
|
rm -rf $$schemas
|
||||||
|
rm -rf $$java_root
|
||||||
|
mkdir -p $$schemas
|
||||||
|
mkdir -p $$java_root
|
||||||
|
|
||||||
|
for src in $(SRCS); do
|
||||||
|
dest=$$schemas/$$src
|
||||||
|
rm -rf $$(dirname $$dest)
|
||||||
|
mkdir -p $$(dirname $$dest)
|
||||||
|
if [ -z "{custom_package}" ] && [ -z "{package_prefix}" ]; then
|
||||||
|
cp -f $$src $$dest
|
||||||
|
else
|
||||||
|
if [ -z "{package_prefix}" ]; then
|
||||||
|
sed -e "s/namespace\\s.*/namespace {custom_package};/" $$src > $$dest
|
||||||
|
else
|
||||||
|
sed -e "s/namespace \\([^;]\\+\\);/namespace {package_prefix}.\\1;/" $$src > $$dest
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
flatc_arg_I="-I $$tmpdir/schemas"
|
||||||
|
for include_path in {include_paths}; do
|
||||||
|
flatc_arg_I="$$flatc_arg_I -I $$schemas/$$include_path"
|
||||||
|
done
|
||||||
|
|
||||||
|
flatc_additional_args=
|
||||||
|
for arg in {flatc_args}; do
|
||||||
|
flatc_additional_args="$$flatc_additional_args $$arg"
|
||||||
|
done
|
||||||
|
|
||||||
|
for src in $(SRCS); do
|
||||||
|
$(location {flatc_path}) $$flatc_arg_I --java $$flatc_additional_args -o $$java_root $$schemas/$$src
|
||||||
|
done
|
||||||
|
|
||||||
|
$(location {zip_files}) -export_zip_path=$@ -file_directory=$$java_root
|
||||||
|
"""
|
||||||
|
genrule_cmd = command_fmt.format(
|
||||||
|
package_name = native.package_name(),
|
||||||
|
custom_package = custom_package,
|
||||||
|
package_prefix = package_prefix,
|
||||||
|
flatc_path = flatc_path,
|
||||||
|
zip_files = zip_files,
|
||||||
|
include_paths = " ".join(include_paths),
|
||||||
|
flatc_args = " ".join(flatc_args),
|
||||||
|
)
|
||||||
|
|
||||||
|
native.genrule(
|
||||||
|
name = name,
|
||||||
|
srcs = srcs,
|
||||||
|
outs = [out],
|
||||||
|
tools = [flatc_path, zip_files],
|
||||||
|
cmd = genrule_cmd,
|
||||||
|
)
|
||||||
|
|
||||||
|
def flatbuffer_android_library(
|
||||||
|
name,
|
||||||
|
srcs,
|
||||||
|
custom_package = "",
|
||||||
|
package_prefix = "",
|
||||||
|
javacopts = None,
|
||||||
|
include_paths = DEFAULT_INCLUDE_PATHS,
|
||||||
|
flatc_args = DEFAULT_FLATC_ARGS,
|
||||||
|
visibility = None):
|
||||||
|
"""An android_library with the generated reader/writers for the given flatbuffer definitions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Rule name. (required)
|
||||||
|
srcs: List of source .fbs files including all includes. (required)
|
||||||
|
custom_package: Package name of generated Java files. If not specified
|
||||||
|
namespace in the schema files will be used. (optional)
|
||||||
|
package_prefix: like custom_package, but prefixes to the existing
|
||||||
|
namespace. (optional)
|
||||||
|
javacopts: List of options to pass to javac.
|
||||||
|
include_paths: List of paths that includes files can be found in. (optional)
|
||||||
|
flatc_args: List of additional arguments to pass to flatc. (optional)
|
||||||
|
visibility: Visibility setting for the android_library rule. (optional)
|
||||||
|
"""
|
||||||
|
out_srcjar = "android_%s_all.srcjar" % name
|
||||||
|
flatbuffer_java_srcjar(
|
||||||
|
name = "%s_srcjar" % name,
|
||||||
|
srcs = srcs,
|
||||||
|
out = out_srcjar,
|
||||||
|
custom_package = custom_package,
|
||||||
|
flatc_args = flatc_args,
|
||||||
|
include_paths = include_paths,
|
||||||
|
package_prefix = package_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
native.filegroup(
|
||||||
|
name = "%s.srcjar" % name,
|
||||||
|
srcs = [out_srcjar],
|
||||||
|
)
|
||||||
|
|
||||||
|
# To support org.checkerframework.dataflow.qual.Pure.
|
||||||
|
checkerframework_annotations = [
|
||||||
|
"@org_checkerframework_qual",
|
||||||
|
] if "--java-checkerframework" in flatc_args else []
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = name,
|
||||||
|
srcs = [out_srcjar],
|
||||||
|
visibility = visibility,
|
||||||
|
deps = [
|
||||||
|
"@flatbuffers//:runtime_android",
|
||||||
|
] + checkerframework_annotations,
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user