diff --git a/tensorflow/lite/experimental/support/metadata/BUILD b/tensorflow/lite/experimental/support/metadata/BUILD
new file mode 100644
index 00000000000..7478aca3b57
--- /dev/null
+++ b/tensorflow/lite/experimental/support/metadata/BUILD
@@ -0,0 +1,87 @@
+load("//tensorflow:tensorflow.bzl", "py_test")
+load("@flatbuffers//:build_defs.bzl", "flatbuffer_android_library", "flatbuffer_cc_library", "flatbuffer_java_library", "flatbuffer_py_library")
+
+package(
+    default_visibility = [
+        "//visibility:public",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+exports_files(["metadata_schema.fbs"])
+
+flatbuffer_py_library(
+    name = "schema_py",
+    srcs = ["//tensorflow/lite/schema:schema.fbs"],
+)
+
+# Android and Java bindings for the generic TFLite model schema, used for inference on device.
+flatbuffer_android_library(
+    name = "schema_fbs_android",
+    srcs = ["//tensorflow/lite/schema:schema.fbs"],
+    custom_package = "org.tensorflow.lite.schema",
+)
+
+flatbuffer_java_library(
+    name = "schema_fbs_java",
+    srcs = ["//tensorflow/lite/schema:schema.fbs"],
+    custom_package = "org.tensorflow.lite.schema",
+)
+
+# C++, Python, Java and Android bindings for the model metadata schema (metadata_schema.fbs).
+flatbuffer_cc_library(
+    name = "metadata_schema_cc",
+    srcs = ["metadata_schema.fbs"],
+)
+
+flatbuffer_py_library(
+    name = "metadata_schema_py",
+    srcs = ["metadata_schema.fbs"],
+)
+
+flatbuffer_java_library(
+    name = "metadata_schema_java",
+    srcs = ["metadata_schema.fbs"],
+    custom_package = "org.tensorflow.lite.support.metadata.schema",
+)
+
+flatbuffer_android_library(
+    name = "metadata_schema_fbs_android",
+    srcs = ["metadata_schema.fbs"],
+    custom_package = "org.tensorflow.lite.support.metadata.schema",
+)
+
+py_library(
+    name = "metadata",
+    srcs = ["metadata.py"],
+    data = [
+        "//tensorflow/lite/experimental/support/metadata:metadata_schema.fbs",
+        "@flatbuffers//:flatc",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        ":metadata_schema_py",
+        ":schema_py",
+        "//tensorflow/python:platform",
+        "@flatbuffers//:runtime_py",
+    ],
+)
+
+py_test(
+    name = "metadata_test",
+    srcs = ["metadata_test.py"],
+    data = ["testdata/golden_json.json"],
+    python_version = "PY3",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":metadata",
+        ":metadata_schema_py",
+        ":schema_py",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:platform_test",
+        "@flatbuffers//:runtime_py",
+        "@six_archive//:six",
+    ],
+)
diff --git a/tensorflow/lite/experimental/support/metadata/java/AndroidManifest.xml b/tensorflow/lite/experimental/support/metadata/java/AndroidManifest.xml
new file mode 100644
index 00000000000..b2e22628db6
--- /dev/null
+++ b/tensorflow/lite/experimental/support/metadata/java/AndroidManifest.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+    package="org.tensorflow.lite.support">
+    <uses-sdk android:minSdkVersion="19" />
+</manifest>
+
diff --git a/tensorflow/lite/experimental/support/metadata/java/BUILD b/tensorflow/lite/experimental/support/metadata/java/BUILD
new file mode 100644
index
00000000000..fb7a9cd9c65
--- /dev/null
+++ b/tensorflow/lite/experimental/support/metadata/java/BUILD
@@ -0,0 +1,36 @@
+# Description:
+# TensorFlow Lite Support API in Java for metadata.
+
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+android_library(
+    name = "tensorflow-lite-support-metadata",
+    srcs = glob(["src/java/org/tensorflow/lite/support/metadata/**/*.java"]),
+    manifest = "AndroidManifest.xml",
+    deps = [
+        "//tensorflow/lite/experimental/support/java:tensorflow-lite-support-precondition-lib-android",
+        "//tensorflow/lite/experimental/support/metadata:metadata_schema_fbs_android",
+        "//tensorflow/lite/experimental/support/metadata:schema_fbs_android",
+        "//tensorflow/lite/java:tensorflowlite",
+        "@org_checkerframework_qual",
+    ],
+)
+
+java_library(
+    name = "tensorflow-lite-support-metadata-lib",
+    srcs = glob(["src/java/org/tensorflow/lite/support/metadata/**/*.java"]),
+    javacopts = JAVACOPTS,
+    deps = [
+        "//tensorflow/lite/experimental/support/java:tensorflow-lite-support-precondition",
+        "//tensorflow/lite/experimental/support/metadata:metadata_schema_java",
+        "//tensorflow/lite/experimental/support/metadata:schema_fbs_java",
+        "//tensorflow/lite/java:tensorflowlitelib",
+        "@org_checkerframework_qual",
+    ],
+)
diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java
new file mode 100644
index 00000000000..1dd504cec52
--- /dev/null
+++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/BoundedInputStream.java
@@ -0,0 +1,141 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.metadata;
+
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkElementIndex;
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+/**
+ * An {@link InputStream} that wraps a section of a {@link SeekableByteChannelCompat}.
+ *
+ * <p><b>WARNING:</b> Like {@link InputStream}, instances of {@link BoundedInputStream} are
+ * <b>not</b> thread-safe. If multiple threads read concurrently from the same {@link
+ * BoundedInputStream}, access must be synchronized externally. Also, if multiple instances of
+ * {@link BoundedInputStream} are created on the same {@link SeekableByteChannelCompat}, access to
+ * them must be synchronized as well.
+ */
+final class BoundedInputStream extends InputStream {
+  private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
+  private final long end; // The valid data for the stream is between [start, end).
+  private long position;
+  private final SeekableByteChannelCompat channel;
+
+  /**
+   * Creates a {@link BoundedInputStream} with a {@link SeekableByteChannelCompat}.
+   *
+   * @param channel the {@link SeekableByteChannelCompat} that backs up this {@link
+   *     BoundedInputStream}
+   * @param start the starting position of this {@link BoundedInputStream} in the given {@link
+   *     SeekableByteChannelCompat}
+   * @param remaining the length of this {@link BoundedInputStream}
+   * @throws NullPointerException if {@code channel} is null
+   * @throws IllegalArgumentException if {@code start} or {@code remaining} is negative, or if
+   *     {@code start + remaining} overflows a {@code long}
+   */
+  BoundedInputStream(SeekableByteChannelCompat channel, long start, long remaining) {
+    checkNotNull(channel, "The channel cannot be null.");
+    checkArgument(
+        remaining >= 0 && start >= 0 && remaining <= Long.MAX_VALUE - start,
+        String.format("Invalid length of stream at offset=%d, length=%d", start, remaining));
+
+    end = start + remaining;
+    this.channel = channel;
+    position = start;
+  }
+
+  /**
+   * Returns the number of bytes left between the current position and the end of this stream or of
+   * the underlying channel, whichever comes first.
+   *
+   * <p>The result is clamped to {@code [0, Integer.MAX_VALUE]}: it is never negative when the
+   * channel is shorter than the current position, and it does not wrap around for sections larger
+   * than 2 GiB.
+   */
+  @Override
+  public int available() throws IOException {
+    long remaining = Math.min(end, channel.size()) - position;
+    return (int) Math.max(0L, Math.min(remaining, (long) Integer.MAX_VALUE));
+  }
+
+  @Override
+  public int read() throws IOException {
+    if (position >= end) {
+      return -1;
+    }
+
+    // clear() rather than rewind(): it also restores the limit, which the flip() in
+    // read(long, ByteBuffer) leaves at 0 after a read that transferred no bytes.
+    singleByteBuffer.clear();
+    int count = read(position, singleByteBuffer);
+    if (count < 0) {
+      return count;
+    }
+
+    position++;
+    return singleByteBuffer.get() & 0xff;
+  }
+
+  @Override
+  public int read(byte[] b, int off, int len) throws IOException {
+    checkNotNull(b);
+    // The size passed to checkElementIndex acts as an exclusive upper bound (the pre-existing
+    // "+ 1" on the length check relies on that), so "+ 1" here admits off == b.length, which
+    // java.io.InputStream allows when len == 0 (e.g. reading into an empty array).
+    checkElementIndex(off, b.length + 1, "The start offset");
+    checkElementIndex(len, b.length - off + 1, "The maximum number of bytes to read");
+
+    if (len == 0) {
+      return 0;
+    }
+
+    if (len > end - position) {
+      if (position >= end) {
+        return -1;
+      }
+      len = (int) (end - position);
+    }
+
+    ByteBuffer buf = ByteBuffer.wrap(b, off, len);
+    int count = read(position, buf);
+    if (count > 0) {
+      position += count;
+    }
+    return count;
+  }
+
+  /**
+   * Moves the channel to {@code position}, reads into {@code buf}, and then flips {@code buf}.
+   *
+   * <p>Seeking and reading happen while holding the channel's monitor, so the two steps are not
+   * interleaved with other code that synchronizes on the same channel.
+   *
+   * @return the value returned by the channel's {@code read}
+   */
+  private int read(long position, ByteBuffer buf) throws IOException {
+    int count;
+    synchronized (channel) {
+      channel.position(position);
+      count = channel.read(buf);
+    }
+    buf.flip();
+    return count;
+  }
+}
diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java
b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java
new file mode 100644
index 00000000000..9df816c7ff5
--- /dev/null
+++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ByteBufferChannel.java
@@ -0,0 +1,145 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.metadata;
+
+import static java.lang.Math.min;
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
+
+import java.nio.ByteBuffer;
+import java.nio.channels.NonWritableChannelException;
+
+/** Implements the {@link SeekableByteChannelCompat} on top of {@link ByteBuffer}. */
+final class ByteBufferChannel implements SeekableByteChannelCompat {
+
+  /** The ByteBuffer that holds the data. */
+  private final ByteBuffer buffer;
+
+  /**
+   * Creates a {@link ByteBufferChannel} that wraps a {@link ByteBuffer}.
+   *
+   * @param buffer the {@link ByteBuffer} that backs this {@link ByteBufferChannel}
+   * @throws NullPointerException if {@code buffer} is null
+   */
+  public ByteBufferChannel(ByteBuffer buffer) {
+    checkNotNull(buffer, "The ByteBuffer cannot be null.");
+    this.buffer = buffer;
+  }
+
+  @Override
+  public void close() {}
+
+  @Override
+  public boolean isOpen() {
+    return true;
+  }
+
+  @Override
+  public long position() {
+    return buffer.position();
+  }
+
+  /**
+   * Sets this channel's position.
+   *
+   * @param newPosition the new position, a non-negative integer counting the number of bytes from
+   *     the beginning of the entity
+   * @return this channel
+   * @throws IllegalArgumentException if the new position is negative, or greater than the size of
+   *     the underlying {@link ByteBuffer}, or greater than Integer.MAX_VALUE
+   */
+  @Override
+  public synchronized ByteBufferChannel position(long newPosition) {
+    checkArgument(
+        (newPosition >= 0 && newPosition <= Integer.MAX_VALUE),
+        "The new position should be non-negative and be less than Integer.MAX_VALUE.");
+    buffer.position((int) newPosition);
+    return this;
+  }
+
+  /**
+   * {@inheritDoc}
+   *
+   * <p>Bytes are read starting at this channel's current position, and then the position is updated
+   * with the number of bytes actually read. Otherwise this method behaves exactly as specified in
+   * the {@link java.nio.channels.ReadableByteChannel} interface.
+   */
+  @Override
+  public synchronized int read(ByteBuffer dst) {
+    if (buffer.remaining() == 0) {
+      return -1;
+    }
+
+    int count = min(dst.remaining(), buffer.remaining());
+    if (count > 0) {
+      ByteBuffer tempBuffer = buffer.slice();
+      tempBuffer.order(buffer.order()).limit(count);
+      dst.put(tempBuffer);
+      buffer.position(buffer.position() + count);
+    }
+    return count;
+  }
+
+  @Override
+  public long size() {
+    return buffer.limit();
+  }
+
+  @Override
+  public synchronized ByteBufferChannel truncate(long size) {
+    checkArgument(
+        (size >= 0 && size <= Integer.MAX_VALUE),
+        "The new size should be non-negative and be less than Integer.MAX_VALUE.");
+
+    if (size < buffer.limit()) {
+      buffer.limit((int) size);
+      if (buffer.position() > size) {
+        buffer.position((int) size);
+      }
+    }
+    return this;
+  }
+
+  /**
+   * Writes bytes from {@code src} into the backing buffer, starting at this channel's current
+   * position.
+   *
+   * <p>At most {@code min(src.remaining(), buffer.remaining())} bytes are transferred; the backing
+   * buffer is never grown. The positions of both this channel and {@code src} are advanced by the
+   * number of bytes written, as required by {@link java.nio.channels.WritableByteChannel}.
+   *
+   * @param src the buffer from which bytes are to be retrieved
+   * @return the number of bytes written, possibly zero
+   * @throws NonWritableChannelException if the backing {@link ByteBuffer} is read-only
+   */
+  @Override
+  public synchronized int write(ByteBuffer src) {
+    if (buffer.isReadOnly()) {
+      throw new NonWritableChannelException();
+    }
+
+    int count = min(src.remaining(), buffer.remaining());
+    if (count > 0) {
+      ByteBuffer tempBuffer = src.slice();
+      tempBuffer.order(buffer.order()).limit(count);
+      buffer.put(tempBuffer);
+      // slice() shares content but not position, so the bytes just copied must be marked as
+      // consumed on src explicitly; otherwise a caller looping on src.hasRemaining() never ends.
+      src.position(src.position() + count);
+    }
+    return count;
+  }
+}
diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
new file mode 100644
index 00000000000..f22b914f269
--- /dev/null
+++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java
@@ -0,0 +1,247 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.metadata;
+
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.util.zip.ZipException;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.Tensor.QuantizationParams;
+import org.tensorflow.lite.schema.Tensor;
+import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
+
+/**
+ * Loads metadata from TFLite Model FlatBuffer.
+ *
+ * <p>TFLite Model FlatBuffer can be generated using the <a
+ * href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
+ * Model schema file.</a>
+ *
+ * <p>Some models contain a TFLite Metadata Flatbuffer, which records more information about what
+ * the model does and how to interpret the model. TFLite Metadata Flatbuffer can be generated using
+ * the <a
+ * href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs">TFLite
+ * Metadata schema file.</a>
+ *
+ * <p>It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking methods
+ * that read from TFLite metadata will cause runtime errors.
+ *
+ * <p>Similarly, it is allowed to pass in a model FlatBuffer without associated files. However,
+ * invoking methods that read the associated files will cause runtime errors.
+ *
+ * <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports a
+ * single subgraph so far. See the <a
+ * href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
+ * of how to specify subgraph during conversion for more information.</a> Therefore, {@link
+ * MetadataExtractor} omits subgraph index as an input in its methods.
+ */
+public class MetadataExtractor {
+  /** The helper class to load metadata from TFLite model FlatBuffer. */
+  private final ModelInfo modelInfo;
+
+  /** The helper class to load metadata from TFLite metadata FlatBuffer. */
+  @Nullable private final ModelMetadataInfo metadataInfo;
+
+  /** The handler to load associated files through zip. */
+  @Nullable private final ZipFile zipFile;
+
+  /**
+   * Creates a {@link MetadataExtractor} with TFLite model FlatBuffer.
+   *
+   * @param buffer the TFLite model FlatBuffer
+   * @throws IllegalArgumentException if the number of input or output tensors in the model does not
+   *     match that in the metadata
+   * @throws IOException if an error occurs while reading the model as a Zip file
+   */
+  public MetadataExtractor(ByteBuffer buffer) throws IOException {
+    modelInfo = new ModelInfo(buffer);
+    ByteBuffer metadataBuffer = modelInfo.getMetadataBuffer();
+    if (metadataBuffer != null) {
+      metadataInfo = new ModelMetadataInfo(metadataBuffer);
+      checkArgument(
+          modelInfo.getInputTensorCount() == metadataInfo.getInputTensorCount(),
+          String.format(
+              "The number of input tensors in the model is %d. The number of input tensors that"
+                  + " recorded in the metadata is %d. These two values does not match.",
+              modelInfo.getInputTensorCount(), metadataInfo.getInputTensorCount()));
+      checkArgument(
+          modelInfo.getOutputTensorCount() == metadataInfo.getOutputTensorCount(),
+          String.format(
+              "The number of output tensors in the model is %d. The number of output tensors that"
+                  + " recorded in the metadata is %d. These two values does not match.",
+              modelInfo.getOutputTensorCount(), metadataInfo.getOutputTensorCount()));
+    } else {
+      // It is allowed to pass in a model FlatBuffer without TFLite metadata. However, invoking
+      // methods that read from TFLite metadata will cause runtime errors.
+      metadataInfo = null;
+    }
+
+    zipFile = createZipFile(buffer);
+  }
+
+  /**
+   * Gets the packed associated file with the specified {@code fileName}.
+   *
+   * @param fileName the name of the associated file
+   * @return the raw input stream containing specified file
+   * @throws IllegalStateException if the model is not a zip file
+   * @throws IllegalArgumentException if the specified file does not exist in the model
+   */
+  public InputStream getAssociatedFile(String fileName) {
+    assertZipFile();
+    return zipFile.getRawInputStream(fileName);
+  }
+
+  /** Gets the count of input tensors in the model. */
+  public int getInputTensorCount() {
+    return modelInfo.getInputTensorCount();
+  }
+
+  /**
+   * Gets the metadata for the input tensor specified by {@code inputIndex}.
+   *
+   * @param inputIndex the index of the desired input tensor
+   * @throws IllegalStateException if this model does not contain model metadata
+   */
+  @Nullable
+  public TensorMetadata getInputTensorMetadata(int inputIndex) {
+    assertMetadataInfo();
+    return metadataInfo.getInputTensorMetadata(inputIndex);
+  }
+
+  /**
+   * Gets the quantization parameters for the input tensor specified by {@code inputIndex}.
+   *
+   * @param inputIndex the index of the desired input tensor
+   */
+  public QuantizationParams getInputTensorQuantizationParams(int inputIndex) {
+    Tensor tensor = modelInfo.getInputTensor(inputIndex);
+    return modelInfo.getQuantizationParams(tensor);
+  }
+
+  /**
+   * Gets the shape of the input tensor with {@code inputIndex}.
+   *
+   * @param inputIndex the index of the desired input tensor
+   */
+  public int[] getInputTensorShape(int inputIndex) {
+    return modelInfo.getInputTensorShape(inputIndex);
+  }
+
+  /**
+   * Gets the {@link DataType} of the input tensor with {@code inputIndex}.
+   *
+   * @param inputIndex the index of the desired input tensor
+   */
+  public DataType getInputTensorType(int inputIndex) {
+    return modelInfo.getInputTensorType(inputIndex);
+  }
+
+  /** Gets the count of output tensors in the model. */
+  public int getOutputTensorCount() {
+    return modelInfo.getOutputTensorCount();
+  }
+
+  /**
+   * Gets the metadata for the output tensor specified by {@code outputIndex}.
+   *
+   * @param outputIndex the index of the desired output tensor
+   * @throws IllegalStateException if this model does not contain model metadata
+   */
+  @Nullable
+  public TensorMetadata getOutputTensorMetadata(int outputIndex) {
+    assertMetadataInfo();
+    return metadataInfo.getOutputTensorMetadata(outputIndex);
+  }
+
+  /**
+   * Gets the quantization parameters for the output tensor specified by {@code outputIndex}.
+   *
+   * @param outputIndex the index of the desired output tensor
+   */
+  public QuantizationParams getOutputTensorQuantizationParams(int outputIndex) {
+    Tensor tensor = modelInfo.getOutputTensor(outputIndex);
+    return modelInfo.getQuantizationParams(tensor);
+  }
+
+  /**
+   * Gets the shape of the output tensor with {@code outputIndex}.
+   *
+   * @param outputIndex the index of the desired output tensor
+   */
+  public int[] getOutputTensorShape(int outputIndex) {
+    return modelInfo.getOutputTensorShape(outputIndex);
+  }
+
+  /**
+   * Gets the {@link DataType} of the output tensor with {@code outputIndex}.
+   *
+   * @param outputIndex the index of the desired output tensor
+   */
+  public DataType getOutputTensorType(int outputIndex) {
+    return modelInfo.getOutputTensorType(outputIndex);
+  }
+
+  /**
+   * Asserts that {@link #metadataInfo} is initialized. Some models may not have metadata and this
+   * is allowed. However, invoking methods that read the metadata is not allowed.
+   *
+   * @throws IllegalStateException if this model does not contain model metadata
+   */
+  private void assertMetadataInfo() {
+    if (metadataInfo == null) {
+      throw new IllegalStateException("This model does not contain model metadata.");
+    }
+  }
+
+  /**
+   * Asserts that {@link #zipFile} is initialized. Some models may not have associated files, thus
+   * are not Zip files. This is allowed. However, invoking methods that read those associated files
+   * is not allowed.
+   *
+   * @throws IllegalStateException if this model is not a Zip file
+   */
+  private void assertZipFile() {
+    if (zipFile == null) {
+      throw new IllegalStateException(
+          "This model does not contain associated files, and is not a Zip file.");
+    }
+  }
+
+  /**
+   * Creates a Zip file handler to read the associated files. If the model is not a zip file, i.e.
+   * it does not have associated files, return a null handler.
+   *
+   * @param buffer the TFLite model FlatBuffer
+   * @throws IOException if an error occurs while reading the model as a Zip file
+   */
+  @Nullable
+  private static ZipFile createZipFile(ByteBuffer buffer) throws IOException {
+    try {
+      // Creates the handler to hold the associated files through the Zip.
+      ByteBufferChannel byteBufferChannel = new ByteBufferChannel(buffer);
+      return ZipFile.createFrom(byteBufferChannel);
+    } catch (ZipException e) {
+      // Some models may not have associated files. Therefore, those models are not zip files.
+      // However, invoking methods that read associated files later will lead to errors.
+      return null;
+    }
+  }
+}
diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java
new file mode 100644
index 00000000000..f767bf8afe8
--- /dev/null
+++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java
@@ -0,0 +1,303 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.metadata;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.Tensor.QuantizationParams;
+import org.tensorflow.lite.schema.Buffer;
+import org.tensorflow.lite.schema.Metadata;
+import org.tensorflow.lite.schema.Model;
+import org.tensorflow.lite.schema.QuantizationParameters;
+import org.tensorflow.lite.schema.SubGraph;
+import org.tensorflow.lite.schema.Tensor;
+import org.tensorflow.lite.schema.TensorType;
+import org.tensorflow.lite.support.common.SupportPreconditions;
+
+/** Extracts model information out of TFLite model FlatBuffer. */
+final class ModelInfo {
+  /** The model that is loaded from TFLite model FlatBuffer. */
+  private final Model model;
+
+  /** A list of input tensors. */
+  private final List</* @Nullable */ Tensor> inputTensors;
+
+  /** A list of output tensors. */
+  private final List</* @Nullable */ Tensor> outputTensors;
+
+  /** Identifier of the TFLite model metadata in the Metadata array. */
+  static final String METADATA_FIELD_NAME = "TFLITE_METADATA";
+
+  /** Maps from TensorType in TFlite FlatBuffer to {@link DataType} in Java. */
+  private final Map<Byte, DataType> tensorTypeToDataTypeMap;
+
+  /**
+   * Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}.
+   *
+   * <p>Though TFLite model FlatBuffer supports multiple subgraphs, TFLite Interpreter only supports
+   * single subgraph so far. See the <a
+   * href="https://www.tensorflow.org/lite/convert/cmdline_examples#specifying_subgraphs">instruction
+   * of how to specify subgraph during conversion for more information.</a> Therefore, all methods
+   * in {@link ModelInfo} retrieve metadata of the first subgraph by default.
+   *
+   * @param buffer The TFLite model FlatBuffer.
+   * @throws NullPointerException if {@code buffer} is null.
+   * @throws IllegalArgumentException if the model does not contain any subgraph.
+   */
+  ModelInfo(ByteBuffer buffer) {
+    SupportPreconditions.checkNotNull(buffer, "Model flatbuffer cannot be null.");
+
+    model = Model.getRootAsModel(buffer);
+    SupportPreconditions.checkArgument(
+        model.subgraphsLength() > 0, "The model does not contain any subgraph.");
+
+    inputTensors = getInputTensors(model);
+    outputTensors = getOutputTensors(model);
+    tensorTypeToDataTypeMap = createTensorTypeToDataTypeMap();
+  }
+
+  /**
+   * Gets the input tensor with {@code inputIndex}.
+   *
+   * @param inputIndex The index of the desired input tensor.
+   * @throws IllegalArgumentException if the inputIndex specified is invalid.
+   */
+  @Nullable
+  Tensor getInputTensor(int inputIndex) {
+    SupportPreconditions.checkArgument(
+        inputIndex >= 0 && inputIndex < inputTensors.size(),
+        "The inputIndex specified is invalid.");
+    return inputTensors.get(inputIndex);
+  }
+
+  /** Gets the count of input tensors in the model. */
+  int getInputTensorCount() {
+    return inputTensors.size();
+  }
+
+  /**
+   * Gets shape of the input tensor with {@code inputIndex}.
+   *
+   * @param inputIndex The index of the desired input tensor.
+   */
+  int[] getInputTensorShape(int inputIndex) {
+    Tensor tensor = getInputTensor(inputIndex);
+    return getShape(tensor);
+  }
+
+  /**
+   * Gets {@link DataType} of the input tensor with {@code inputIndex}.
+   *
+   * @param inputIndex The index of the desired input tensor.
+   * @throws NullPointerException if the tensor at {@code inputIndex} is null.
+   * @throws IllegalArgumentException if {@code inputIndex} is invalid or the tensor type is not
+   *     supported.
+   */
+  DataType getInputTensorType(int inputIndex) {
+    Tensor tensor = getInputTensor(inputIndex);
+    SupportPreconditions.checkNotNull(tensor, "Tensor cannot be null.");
+    return getDataType(tensor.type());
+  }
+
+  /**
+   * Gets the metadata FlatBuffer from the model FlatBuffer.
+   *
+   * @return the data of the buffer referenced by the {@link #METADATA_FIELD_NAME} entry of the
+   *     model's metadata array, or null if the model has no such entry
+   * @throws IllegalArgumentException if that entry references a buffer index that is outside the
+   *     model's buffer array
+   */
+  @Nullable
+  ByteBuffer getMetadataBuffer() {
+    // Some models may not have metadata, and this is allowed.
+    if (model.metadataLength() == 0) {
+      return null;
+    }
+
+    for (int i = 0; i < model.metadataLength(); i++) {
+      Metadata meta = model.metadata(i);
+      if (METADATA_FIELD_NAME.equals(meta.name())) {
+        long bufferIndex = meta.buffer();
+        // The index is read from the model file itself, so validate it before narrowing it to int
+        // and using it to look up the buffer.
+        SupportPreconditions.checkArgument(
+            bufferIndex >= 0 && bufferIndex < model.buffersLength(),
+            String.format("The metadata buffer index %d is out of range.", bufferIndex));
+        Buffer metadataBuf = model.buffers((int) bufferIndex);
+        return metadataBuf.dataAsByteBuffer();
+      }
+    }
+    return null;
+  }
+
+  /**
+   * Gets the output tensor with {@code outputIndex}.
+   *
+   * @param outputIndex The index of the desired output tensor.
+   * @throws IllegalArgumentException if the outputIndex specified is invalid.
+   */
+  @Nullable
+  Tensor getOutputTensor(int outputIndex) {
+    SupportPreconditions.checkArgument(
+        outputIndex >= 0 && outputIndex < outputTensors.size(),
+        "The outputIndex specified is invalid.");
+    return outputTensors.get(outputIndex);
+  }
+
+  /** Gets the count of output tensors in the model. */
+  int getOutputTensorCount() {
+    return outputTensors.size();
+  }
+
+  /**
+   * Gets shape of the output tensor with {@code outputIndex}.
+   *
+   * @param outputIndex The index of the desired output tensor.
+   */
+  int[] getOutputTensorShape(int outputIndex) {
+    Tensor tensor = getOutputTensor(outputIndex);
+    return getShape(tensor);
+  }
+
+  /**
+   * Gets {@link DataType} of the output tensor {@code outputIndex}.
+   *
+   * @param outputIndex The index of the desired output tensor.
+   * @throws NullPointerException if the tensor at {@code outputIndex} is null.
+   * @throws IllegalArgumentException if {@code outputIndex} is invalid or the tensor type is not
+   *     supported.
+   */
+  DataType getOutputTensorType(int outputIndex) {
+    Tensor tensor = getOutputTensor(outputIndex);
+    SupportPreconditions.checkNotNull(tensor, "Tensor cannot be null.");
+    return getDataType(tensor.type());
+  }
+
+  private static Map<Byte, DataType> createTensorTypeToDataTypeMap() {
+    Map<Byte, DataType> map = new HashMap<>();
+    map.put(TensorType.FLOAT32, DataType.FLOAT32);
+    map.put(TensorType.INT32, DataType.INT32);
+    map.put(TensorType.UINT8, DataType.UINT8);
+    map.put(TensorType.INT64, DataType.INT64);
+    map.put(TensorType.STRING, DataType.STRING);
+    return Collections.unmodifiableMap(map);
+  }
+
+  /**
+   * Gets the quantization parameters of a tensor.
+   *
+   * <p>Only quantized tensors have valid {@code QuantizationParameters}. For tensors that are not
+   * quantized, the values of scale and zero_point are both 0.
+   *
+   * @param tensor The tensor whose quantization parameters are desired.
+   * @throws NullPointerException if the tensor is null.
+   * @throws IllegalArgumentException if {@code scale} and {@code zeroPoint} of the tensor's {@link
+   *     QuantizationParameters} are not single values.
+   */
+  QuantizationParams getQuantizationParams(Tensor tensor) {
+    SupportPreconditions.checkNotNull(tensor, "Tensor cannot be null.");
+
+    float scale;
+    int zeroPoint;
+    QuantizationParameters quantization = tensor.quantization();
+
+    // Tensors that are not quantized do not have quantization parameters, which can be null when
+    // being extracted from the flatbuffer.
+    if (quantization == null) {
+      scale = 0.0f;
+      zeroPoint = 0;
+      return new QuantizationParams(scale, zeroPoint);
+    }
+
+    // Tensors that are not quantized do not have quantization parameters.
+    // quantization.scaleLength() and quantization.zeroPointLength() may both return 0.
+    SupportPreconditions.checkArgument(
+        quantization.scaleLength() <= 1,
+        "Input and output tensors do not support per-channel quantization.");
+    SupportPreconditions.checkArgument(
+        quantization.zeroPointLength() <= 1,
+        "Input and output tensors do not support per-channel quantization.");
+
+    // For tensors that are not quantized, quantization.scale(0) and quantization.zeroPoint(0) will
+    // both be the default value in flatbuffer, 0. This behavior is consistent with the TFlite C++
+    // runtime.
+    scale = quantization.scale(0);
+    // zeroPoint is a long value in the schema, but an integer in the C++ runtime. Here we keep it
+    // consistent with the C++ runtime.
+    zeroPoint = (int) quantization.zeroPoint(0);
+
+    return new QuantizationParams(scale, zeroPoint);
+  }
+
+  /**
+   * Transforms from TensorType in TFlite FlatBuffer to {@link DataType} in Java.
+   *
+   * @param tensorType The tensor type to be converted.
+   * @throws IllegalArgumentException if the tensor type is not supported.
+   */
+  private DataType getDataType(byte tensorType) {
+    SupportPreconditions.checkArgument(
+        tensorTypeToDataTypeMap.containsKey(tensorType),
+        String.format("Tensor type %d is not supported.", tensorType));
+    return tensorTypeToDataTypeMap.get(tensorType);
+  }
+
+  /**
+   * Gets the shape of a tensor.
+   *
+   * @param tensor The tensor whose shape is desired.
+   * @throws NullPointerException if the tensor is null.
+   */
+  private static int[] getShape(Tensor tensor) {
+    SupportPreconditions.checkNotNull(tensor, "Tensor cannot be null.");
+    int shapeDim = tensor.shapeLength();
+    int[] tensorShape = new int[shapeDim];
+    for (int i = 0; i < shapeDim; i++) {
+      tensorShape[i] = tensor.shape(i);
+    }
+    return tensorShape;
+  }
+
+  /** Gets input tensors from a model. */
+  private static List<Tensor> getInputTensors(Model model) {
+    // TFLite only supports one subgraph currently.
+    SubGraph subgraph = model.subgraphs(0);
+    int tensorNum = subgraph.inputsLength();
+    ArrayList<Tensor> inputTensors = new ArrayList<>(tensorNum);
+    for (int i = 0; i < tensorNum; i++) {
+      inputTensors.add(subgraph.tensors(subgraph.inputs(i)));
+    }
+    return Collections.unmodifiableList(inputTensors);
+  }
+
+  /** Gets output tensors from a model. */
+  private static List<Tensor> getOutputTensors(Model model) {
+    // TFLite only supports one subgraph currently.
+    SubGraph subgraph = model.subgraphs(0);
+    int tensorNum = subgraph.outputsLength();
+    ArrayList<Tensor> outputTensors = new ArrayList<>(tensorNum);
+    for (int i = 0; i < tensorNum; i++) {
+      outputTensors.add(subgraph.tensors(subgraph.outputs(i)));
+    }
+    return Collections.unmodifiableList(outputTensors);
+  }
+}
diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
new file mode 100644
index 00000000000..a99150c7fd9
--- /dev/null
+++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelMetadataInfo.java
@@ -0,0 +1,114 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +package org.tensorflow.lite.support.metadata; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.metadata.schema.ModelMetadata; +import org.tensorflow.lite.support.metadata.schema.SubGraphMetadata; +import org.tensorflow.lite.support.metadata.schema.TensorMetadata; + +/** Extracts model metadata information out of TFLite metadata FlatBuffer. */ +final class ModelMetadataInfo { + /** Metadata array of input tensors. */ + private final List</* @Nullable */ TensorMetadata> inputsMetadata; + + /** Metadata array of output tensors. */ + private final List</* @Nullable */ TensorMetadata> outputsMetadata; + + /** + * Creates a {@link ModelMetadataInfo} with the metadata FlatBuffer, {@code buffer}. + * + * @param buffer The TFLite metadata FlatBuffer. + * @throws NullPointerException if {@code buffer} is null. + * @throws IllegalArgumentException if the metadata does not contain any subgraph metadata. + */ + ModelMetadataInfo(ByteBuffer buffer) { + SupportPreconditions.checkNotNull(buffer, "Metadata flatbuffer cannot be null."); + + ModelMetadata modelMetadata = ModelMetadata.getRootAsModelMetadata(buffer); + SupportPreconditions.checkArgument( + modelMetadata.subgraphMetadataLength() > 0, + "The metadata flatbuffer does not contain any subgraph metadata."); + + inputsMetadata = getInputsMetadata(modelMetadata); + outputsMetadata = getOutputsMetadata(modelMetadata); + } + + /** Gets the count of input tensors with metadata in the metadata FlatBuffer. */ + int getInputTensorCount() { + return inputsMetadata.size(); + } + + /** + * Gets the metadata for the input tensor specified by {@code inputIndex}. + * + * @param inputIndex The index of the desired input tensor. 
+ * @throws IllegalArgumentException if the inputIndex specified is invalid. + */ + @Nullable + TensorMetadata getInputTensorMetadata(int inputIndex) { + SupportPreconditions.checkArgument( + inputIndex >= 0 && inputIndex < inputsMetadata.size(), + "The inputIndex specified is invalid."); + return inputsMetadata.get(inputIndex); + } + + /** Gets the count of output tensors with metadata in the metadata FlatBuffer. */ + int getOutputTensorCount() { + return outputsMetadata.size(); + } + + /** + * Gets the metadata for the output tensor specified by {@code outputIndex}. + * + * @param outputIndex The index of the desired output tensor. + * @throws IllegalArgumentException if the outputIndex specified is invalid. + */ + @Nullable + TensorMetadata getOutputTensorMetadata(int outputIndex) { + SupportPreconditions.checkArgument( + outputIndex >= 0 && outputIndex < outputsMetadata.size(), + "The outputIndex specified is invalid."); + return outputsMetadata.get(outputIndex); + } + + /** Gets metadata for all input tensors. */ + private static List<TensorMetadata> getInputsMetadata(ModelMetadata modelMetadata) { + SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0); + int tensorNum = subgraphMetadata.inputTensorMetadataLength(); + ArrayList<TensorMetadata> inputsMetadata = new ArrayList<>(tensorNum); + for (int i = 0; i < tensorNum; i++) { + inputsMetadata.add(subgraphMetadata.inputTensorMetadata(i)); + } + return Collections.unmodifiableList(inputsMetadata); + } + + /** Gets metadata for all output tensors. 
*/ + private static List<TensorMetadata> getOutputsMetadata(ModelMetadata modelMetadata) { + SubGraphMetadata subgraphMetadata = modelMetadata.subgraphMetadata(0); + int tensorNum = subgraphMetadata.outputTensorMetadataLength(); + ArrayList<TensorMetadata> outputsMetadata = new ArrayList<>(tensorNum); + for (int i = 0; i < tensorNum; i++) { + outputsMetadata.add(subgraphMetadata.outputTensorMetadata(i)); + } + return Collections.unmodifiableList(outputsMetadata); + } +} diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java new file mode 100644 index 00000000000..c655786755b --- /dev/null +++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/SeekableByteChannelCompat.java @@ -0,0 +1,107 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.metadata; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channel; + +/** + * A byte channel that maintains a current <i>position</i> and allows the position to be changed. 
+ * {@link SeekableByteChannelCompat} is compatible with {@link + * java.nio.channels.SeekableByteChannel}. + * + * <p>{@link java.nio.channels.SeekableByteChannel} is not available in Android API 23 and under. + * Therefore, {@link SeekableByteChannelCompat} is introduced here to make the interfaces used in + * the MetadataExtractor library consistent with the commonly used Java libraries. + */ +interface SeekableByteChannelCompat extends Channel { + /** + * Reads a sequence of bytes from this channel into the given buffer. + * + * @param dst The buffer into which bytes are to be transferred + * @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached + * end-of-stream + * @throws NonReadableChannelException If this channel was not opened for reading + * @throws ClosedChannelException If this channel is closed + * @throws AsynchronousCloseException If another thread closes this channel while the read + * operation is in progress + * @throws ClosedByInterruptException If another thread interrupts the current thread while the + * read operation is in progress, thereby closing the channel and setting the current thread's + * interrupt status + * @throws IOException If some other I/O error occurs + */ + int read(ByteBuffer dst) throws IOException; + + /** + * Writes a sequence of bytes to this channel from the given buffer. 
+ * + * @param src The buffer from which bytes are to be retrieved + * @return The number of bytes written, possibly zero + * @throws NonWritableChannelException If this channel was not opened for writing + * @throws ClosedChannelException If this channel is closed + * @throws AsynchronousCloseException If another thread closes this channel while the write + * operation is in progress + * @throws ClosedByInterruptException If another thread interrupts the current thread while the + * write operation is in progress, thereby closing the channel and setting the current + * thread's interrupt status + * @throws IOException If some other I/O error occurs + */ + int write(ByteBuffer src) throws IOException; + + /** + * Returns this channel's position. + * + * @return This channel's position, a non-negative integer counting the number of bytes from the + * beginning of the entity to the current position + * @throws ClosedChannelException If this channel is closed + * @throws IOException If some other I/O error occurs + */ + long position() throws IOException; + + /** + * Sets this channel's position. + * + * @param newPosition The new position, a non-negative integer counting the number of bytes from + * the beginning of the entity + * @return This channel + * @throws ClosedChannelException If this channel is closed + * @throws IllegalArgumentException If the new position is negative + * @throws IOException If some other I/O error occurs + */ + SeekableByteChannelCompat position(long newPosition) throws IOException; + + /** + * Returns the current size of the entity to which this channel is connected. + * + * @return The current size, measured in bytes + * @throws ClosedChannelException If this channel is closed + * @throws IOException If some other I/O error occurs + */ + long size() throws IOException; + + /** + * Truncates the entity, to which this channel is connected, to the given size. 
+ * + * @param size The new size, a non-negative byte count + * @return This channel + * @throws NonWritableChannelException If this channel was not opened for writing + * @throws ClosedChannelException If this channel is closed + * @throws IllegalArgumentException If the new size is negative + * @throws IOException If some other I/O error occurs + */ + SeekableByteChannelCompat truncate(long size) throws IOException; +} diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java new file mode 100644 index 00000000000..b58e5ab29c1 --- /dev/null +++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ZipFile.java @@ -0,0 +1,427 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.support.metadata; + +import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; +import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull; + +import java.io.Closeable; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.zip.ZipException; + +/** + * Reads uncompressed files from the TFLite model, a zip file. + * + * <p>TODO(b/150237111): add a link to the webpage of MetadataPopulator once it's available. + * + * <p>A TFLite model file becomes a zip file when it contains associated files. The associated files + * can be packed to a TFLite model file using the MetadataPopulator. The associated files are not + * compressed when being added to the model file. + * + * <p>{@link ZipFile} does not support Zip64 format, because TFLite models are much smaller than the + * size limit for Zip64, which is 4GB. + */ +final class ZipFile implements Closeable { + /** Maps String to list of ZipEntrys, name -> actual entries. */ + private final Map<String, List<ZipEntry>> nameMap; + + /** The actual data source. */ + private final ByteBufferChannel archive; + + /** + * Opens the given {@link ByteBufferChannel} for reading, assuming "UTF8" for file names. {@link + * ZipFile} does not synchronize on the buffer that is passed into it. 
+ * + * @param channel the archive + * @throws IOException if an error occurs while creating this {@link ZipFile} + * @throws ZipException if the channel is not a zip archive + * @throws NullPointerException if the archive is null + */ + public static ZipFile createFrom(ByteBufferChannel channel) throws IOException { + checkNotNull(channel); + ZipParser zipParser = new ZipParser(channel); + Map<String, List<ZipEntry>> nameMap = zipParser.parseEntries(); + return new ZipFile(channel, nameMap); + } + + @Override + public void close() { + archive.close(); + } + + /** + * Exposes the raw stream of the archive entry. + * + * <p>Since the associated files will not be compressed when being packed to the zip file, the raw + * stream represents the non-compressed files. + * + * <p><b>WARNING:</b> The returned {@link InputStream} is <b>not</b> thread-safe. If multiple + * threads are concurrently reading from the returned {@link InputStream}, it must be synchronized + * externally. + * + * @param name name of the entry to get the stream for + * @return the raw input stream containing data + * @throws IllegalArgumentException if the specified file does not exist in the zip file + */ + public InputStream getRawInputStream(String name) { + checkArgument( + nameMap.containsKey(name), + String.format("The file, %s, does not exist in the zip file.", name)); + + List<ZipEntry> entriesWithTheSameName = nameMap.get(name); + ZipEntry entry = entriesWithTheSameName.get(0); + long start = entry.getDataOffset(); + long remaining = entry.getSize(); + return new BoundedInputStream(archive, start, remaining); + } + + private ZipFile(ByteBufferChannel channel, Map<String, List<ZipEntry>> nameMap) { + archive = channel; + this.nameMap = nameMap; + } + + /* Parses a Zip archive and gets the information for each {@link ZipEntry}. 
*/ + private static class ZipParser { + private final ByteBufferChannel archive; + + // Cached buffers that will only be used locally in the class to reduce garbage collection. + private final ByteBuffer longBuffer = + ByteBuffer.allocate(ZipConstants.LONG_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN); + private final ByteBuffer intBuffer = + ByteBuffer.allocate(ZipConstants.INT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN); + private final ByteBuffer shortBuffer = + ByteBuffer.allocate(ZipConstants.SHORT_BYTE_SIZE).order(ByteOrder.LITTLE_ENDIAN); + + private ZipParser(ByteBufferChannel archive) { + this.archive = archive; + } + + /** + * Parses the underlying {@code archive} and returns its entries grouped by file name, as {@link + * ZipEntry} lists. + */ + private Map<String, List<ZipEntry>> parseEntries() throws IOException { + List<ZipEntry> entries = parseCentralDirectory(); + return parseLocalFileHeaderData(entries); + } + + /** + * Checks if the current position contains a central file header signature, {@link + * ZipConstants#CENSIG}. + */ + private boolean foundCentralFileheaderSignature() { + long signature = (long) getInt(); + return signature == ZipConstants.CENSIG; + } + + /** + * Gets the signed 16-bit value at the current position of the archive, widened to a Java int. + */ + private int getShort() { + shortBuffer.rewind(); + archive.read(shortBuffer); + shortBuffer.flip(); + return (int) shortBuffer.getShort(); + } + + /** + * Gets the value as a Java int from four bytes starting at the current position of the + * archive. + */ + private int getInt() { + intBuffer.rewind(); + archive.read(intBuffer); + intBuffer.flip(); + return intBuffer.getInt(); + } + + /** + * Gets the value as a Java long from eight bytes starting at the current position of the + * archive. + */ + private long getLong() { + longBuffer.rewind(); + archive.read(longBuffer); + longBuffer.flip(); + return longBuffer.getLong(); + } + + /** + * Positions the archive at the start of the central directory. 
+ * + * <p>First, it searches for the signature of the "end of central directory record", {@link + * ZipConstants#ENDSIG}. Position the stream at the start of the "end of central directory + * record". The zip file are created without archive comments, thus {@link ZipConstants#ENDSIG} + * should appear exactly at {@link ZipConstants#ENDHDR} from the end of the zip file. + * + * <p>Then, parse the "end of central dir record" and position the archive at the start of the + * central directory. + */ + private void locateCentralDirectory() throws IOException { + if (archive.size() < ZipConstants.ENDHDR) { + throw new ZipException("The archive is not a ZIP archive."); + } + + // Positions the archive at the start of the "end of central directory record". + long offsetRecord = archive.size() - ZipConstants.ENDHDR; + archive.position(offsetRecord); + + // Checks for {@link ZipConstants#ENDSIG}; getLong() spans 8 bytes, so bytes 4-7 must be 0. + long endSig = getLong(); + if (endSig != ZipConstants.ENDSIG) { + throw new ZipException("The archive is not a ZIP archive."); + } + + // Positions the archive at the “offset of central directory”. + skipBytes(ZipConstants.ENDOFF - ZipConstants.ENDSUB); + // Gets the offset to central directory + long offsetDirectory = getInt(); + // Goes to the central directory. + archive.position(offsetDirectory); + } + + /** + * Reads the central directory of the archive and returns its entries as a list of + * {@link ZipEntry} instances. + */ + private List<ZipEntry> parseCentralDirectory() throws IOException { + /** List of entries in the order they appear inside the central directory. */ + List<ZipEntry> entries = new ArrayList<>(); + locateCentralDirectory(); + + while (foundCentralFileheaderSignature()) { + ZipEntry entry = parseCentralDirectoryEntry(); + entries.add(entry); + } + + return entries; + } + + /** + * Reads an individual entry of the central directory and creates a {@link ZipEntry} holding its + * name, size and local file header offset. 
+ */ + private ZipEntry parseCentralDirectoryEntry() throws IOException { + // Positions the archive at the "compressed size" and reads the value. + skipBytes(ZipConstants.CENSIZ - ZipConstants.CENVEM); + long compressSize = getInt(); + + // Positions the archive at the "filename length" and reads the value. + skipBytes(ZipConstants.CENNAM - ZipConstants.CENLEN); + int fileNameLen = getShort(); + + // Reads the extra field length and the comment length. + int extraLen = getShort(); + int commentLen = getShort(); + + // Positions the archive at the "local file header offset" and reads the value. + skipBytes(ZipConstants.CENOFF - ZipConstants.CENDSK); + long localHeaderOffset = getInt(); + + // Reads the file name. + byte[] fileNameBuf = new byte[fileNameLen]; + archive.read(ByteBuffer.wrap(fileNameBuf)); + String fileName = new String(fileNameBuf, Charset.forName("UTF-8")); + + // Skips the extra field and the comment. + skipBytes(extraLen + commentLen); + + ZipEntry entry = new ZipEntry(); + entry.setSize(compressSize); + entry.setLocalHeaderOffset(localHeaderOffset); + entry.setName(fileName); + + return entry; + } + + /** Walks through all recorded entries and records the offsets for the entry data. */ + private Map<String, List<ZipEntry>> parseLocalFileHeaderData(List<ZipEntry> entries) { + /** Maps String to list of ZipEntrys, name -> actual entries. */ + Map<String, List<ZipEntry>> nameMap = new LinkedHashMap<>(); + + for (ZipEntry entry : entries) { + long offset = entry.getLocalHeaderOffset(); + archive.position(offset + ZipConstants.LOCNAM); + + // Gets the data offset of this entry. + int fileNameLen = getShort(); + int extraFieldLen = getShort(); + long dataOffset = + offset + + ZipConstants.LOCEXT + + ZipConstants.SHORT_BYTE_SIZE + + fileNameLen + + extraFieldLen; + entry.setDataOffset(dataOffset); + + // Puts the entry into the nameMap. 
+ String name = entry.getName(); + List<ZipEntry> entriesWithTheSameName; + if (nameMap.containsKey(name)) { + entriesWithTheSameName = nameMap.get(name); + } else { + entriesWithTheSameName = new ArrayList<>(); + nameMap.put(name, entriesWithTheSameName); + } + entriesWithTheSameName.add(entry); + } + + return nameMap; + } + + /** Skips the given number of bytes or throws an EOFException if skipping failed. */ + private void skipBytes(int count) throws IOException { + long currentPosition = archive.position(); + long newPosition = currentPosition + count; + if (newPosition > archive.size()) { + throw new EOFException(); + } + archive.position(newPosition); + } + } + + /** Stores the data offset and the size of an entry in the archive. */ + private static class ZipEntry { + + private String name; + private long dataOffset = -1; + private long size = -1; + private long localHeaderOffset = -1; + + public long getSize() { + return size; + } + + public long getDataOffset() { + return dataOffset; + } + + public String getName() { + return name; + } + + public long getLocalHeaderOffset() { + return localHeaderOffset; + } + + public void setSize(long size) { + this.size = size; + } + + public void setDataOffset(long dataOffset) { + this.dataOffset = dataOffset; + } + + public void setName(String name) { + this.name = name; + } + + public void setLocalHeaderOffset(long localHeaderOffset) { + this.localHeaderOffset = localHeaderOffset; + } + } + + /** + * Various constants for this {@link ZipFile}. + * + * <p>Referenced from {@link java.util.zip.ZipConstants}. + */ + private static class ZipConstants { + /** length of Java short in bytes. */ + static final int SHORT_BYTE_SIZE = Short.SIZE / 8; + + /** length of Java int in bytes. */ + static final int INT_BYTE_SIZE = Integer.SIZE / 8; + + /** length of Java long in bytes. 
*/ + static final int LONG_BYTE_SIZE = Long.SIZE / 8; + + /* + * Header signatures + */ + static final long LOCSIG = 0x04034b50L; // "PK\003\004" + static final long EXTSIG = 0x08074b50L; // "PK\007\008" + static final long CENSIG = 0x02014b50L; // "PK\001\002" + static final long ENDSIG = 0x06054b50L; // "PK\005\006" + + /* + * Header sizes in bytes (including signatures) + */ + static final int LOCHDR = 30; // LOC header size + static final int EXTHDR = 16; // EXT header size + static final int CENHDR = 46; // CEN header size + static final int ENDHDR = 22; // END header size + + /* + * Local file (LOC) header field offsets + */ + static final int LOCVER = 4; // version needed to extract + static final int LOCFLG = 6; // general purpose bit flag + static final int LOCHOW = 8; // compression method + static final int LOCTIM = 10; // modification time + static final int LOCCRC = 14; // uncompressed file crc-32 value + static final int LOCSIZ = 18; // compressed size + static final int LOCLEN = 22; // uncompressed size + static final int LOCNAM = 26; // filename length + static final int LOCEXT = 28; // extra field length + + /* + * Extra local (EXT) header field offsets + */ + static final int EXTCRC = 4; // uncompressed file crc-32 value + static final int EXTSIZ = 8; // compressed size + static final int EXTLEN = 12; // uncompressed size + + /* + * Central directory (CEN) header field offsets + */ + static final int CENVEM = 4; // version made by + static final int CENVER = 6; // version needed to extract + static final int CENFLG = 8; // encrypt, decrypt flags + static final int CENHOW = 10; // compression method + static final int CENTIM = 12; // modification time + static final int CENCRC = 16; // uncompressed file crc-32 value + static final int CENSIZ = 20; // compressed size + static final int CENLEN = 24; // uncompressed size + static final int CENNAM = 28; // filename length + static final int CENEXT = 30; // extra field length + static final int CENCOM = 
32; // comment length + static final int CENDSK = 34; // disk number start + static final int CENATT = 36; // internal file attributes + static final int CENATX = 38; // external file attributes + static final int CENOFF = 42; // LOC header offset + + /* + * End of central directory (END) header field offsets + */ + static final int ENDSUB = 8; // number of entries on this disk + static final int ENDTOT = 10; // total number of entries + static final int ENDSIZ = 12; // central directory size in bytes + static final int ENDOFF = 16; // offset of first CEN header + static final int ENDCOM = 20; // zip file comment length + + private ZipConstants() {} + } +} diff --git a/tensorflow/lite/experimental/support/metadata/metadata.py b/tensorflow/lite/experimental/support/metadata/metadata.py new file mode 100644 index 00000000000..042e2b222c7 --- /dev/null +++ b/tensorflow/lite/experimental/support/metadata/metadata.py @@ -0,0 +1,542 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""TensorFlow Lite metadata tools.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import os +import shutil +import subprocess +import tempfile +import warnings +import zipfile + +from flatbuffers.python import flatbuffers +from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb +from tensorflow.python.platform import resource_loader + +_FLATC_BINARY_PATH = resource_loader.get_path_to_datafile( + "../../../../../external/flatbuffers/flatc") +_FLATC_TFLITE_METADATA_SCHEMA_FILE = resource_loader.get_path_to_datafile( + "metadata_schema.fbs") + + +# TODO(b/141467403): add delete method for associated files. +class MetadataPopulator(object): + """Packs metadata and associated files into TensorFlow Lite model file. + + MetadataPopulator can be used to populate metadata and model associated files + into a model file or a model buffer (in bytearray). It can also help to + inspect list of files that have been packed into the model or are supposed to + be packed into the model. + + The metadata file (or buffer) should be generated based on the metadata + schema: + third_party/tensorflow/lite/schema/metadata_schema.fbs + + Example usage: + Populate matadata and label file into an image classifier model. + + First, based on metadata_schema.fbs, generate the metadata for this image + classifer model using Flatbuffers API. Attach the label file onto the ouput + tensor (the tensor of probabilities) in the metadata. + + Then, pack the metadata and lable file into the model as follows. 
+ + ```python + # Populating a metadata file (or a metadta buffer) and associated files to + a model file: + populator = MetadataPopulator.with_model_file(model_file) + # For metadata buffer (bytearray read from the metadata file), use: + # populator.load_metadata_buffer(metadata_buf) + populator.load_metadata_file(metadata_file) + populator.load_associated_files([label.txt]) + populator.populate() + + # Populating a metadata file (or a metadta buffer) and associated files to + a model buffer: + populator = MetadataPopulator.with_model_buffer(model_buf) + populator.load_metadata_file(metadata_file) + populator.load_associated_files([label.txt]) + populator.populate() + # Writing the updated model buffer into a file. + updated_model_buf = populator.get_model_buffer() + with open("updated_model.tflite", "wb") as f: + f.write(updated_model_buf) + ``` + """ + # As Zip API is used to concatenate associated files after tflite model file, + # the populating operation is developed based on a model file. For in-memory + # model buffer, we create a tempfile to serve the populating operation. + # Creating the deleting such a tempfile is handled by the class, + # _MetadataPopulatorWithBuffer. + + METADATA_FIELD_NAME = "TFLITE_METADATA" + TFLITE_FILE_IDENTIFIER = b"TFL3" + METADATA_FILE_IDENTIFIER = b"M001" + + def __init__(self, model_file): + """Constructor for MetadataPopulator. + + Args: + model_file: valid path to a TensorFlow Lite model file. + + Raises: + IOError: File not found. + """ + _assert_exist(model_file) + self._model_file = model_file + self._metadata_buf = None + self._associated_files = set() + + @classmethod + def with_model_file(cls, model_file): + """Creates a MetadataPopulator object that populates data to a model file. + + Args: + model_file: valid path to a TensorFlow Lite model file. + + Returns: + MetadataPopulator object. + + Raises: + IOError: File not found. 
+ """ + return cls(model_file) + + # TODO(b/141468993): investigate if type check can be applied to model_buf for + # FB. + @classmethod + def with_model_buffer(cls, model_buf): + """Creates a MetadataPopulator object that populates data to a model buffer. + + Args: + model_buf: TensorFlow Lite model buffer in bytearray. + + Returns: + A MetadataPopulator(_MetadataPopulatorWithBuffer) object. + """ + return _MetadataPopulatorWithBuffer(model_buf) + + def get_model_buffer(self): + """Gets the buffer of the model with packed metadata and associated files. + + Returns: + Model buffer (in bytearray). + """ + with open(self._model_file, "rb") as f: + return f.read() + + def get_packed_associated_file_list(self): + """Gets a list of associated files packed to the model file. + + Returns: + List of packed associated files. + """ + if not zipfile.is_zipfile(self._model_file): + return [] + + with zipfile.ZipFile(self._model_file, "r") as zf: + return zf.namelist() + + def get_recorded_associated_file_list(self): + """Gets a list of associated files recorded in metadata of the model file. + + Associated files may be attached to a model, a subgraph, or an input/output + tensor. + + Returns: + List of recorded associated files. 
+ """ + recorded_files = [] + + if not self._metadata_buf: + return recorded_files + + metadata = _metadata_fb.ModelMetadata.GetRootAsModelMetadata( + self._metadata_buf, 0) + + # Add associated files attached to ModelMetadata + self._get_associated_files_from_metadata_struct(metadata, recorded_files) + + # Add associated files attached to each SubgraphMetadata + for j in range(metadata.SubgraphMetadataLength()): + subgraph = metadata.SubgraphMetadata(j) + self._get_associated_files_from_metadata_struct(subgraph, recorded_files) + + # Add associated files attached to each input tensor + for k in range(subgraph.InputTensorMetadataLength()): + tensor = subgraph.InputTensorMetadata(k) + self._get_associated_files_from_metadata_struct(tensor, recorded_files) + + # Add associated files attached to each output tensor + for k in range(subgraph.OutputTensorMetadataLength()): + tensor = subgraph.OutputTensorMetadata(k) + self._get_associated_files_from_metadata_struct(tensor, recorded_files) + + return recorded_files + + def load_associated_files(self, associated_files): + """Loads associated files that to be concatenated after the model file. + + Args: + associated_files: list of file paths. + + Raises: + IOError: + File not found. + """ + for af in associated_files: + _assert_exist(af) + self._associated_files.add(af) + + def load_metadata_buffer(self, metadata_buf): + """Loads the metadata buffer (in bytearray) to be populated. + + Args: + metadata_buf: metadata buffer (in bytearray) to be populated. + + Raises: + ValueError: + The metadata to be populated is empty. + """ + if not metadata_buf: + raise ValueError("The metadata to be populated is empty.") + + self._metadata_buf = metadata_buf + + def load_metadata_file(self, metadata_file): + """Loads the metadata file to be populated. + + Args: + metadata_file: path to the metadata file to be populated. + + Raises: + IOError: + File not found. 
+ """ + _assert_exist(metadata_file) + with open(metadata_file, "rb") as f: + metadata_buf = f.read() + self.load_metadata_buffer(bytearray(metadata_buf)) + + def populate(self): + """Populates loaded metadata and associated files into the model file.""" + self._assert_validate() + self._populate_metadata_buffer() + self._populate_associated_files() + + def _assert_validate(self): + """Validates the metadata and associated files to be populated. + + Raises: + ValueError: + File is recorded in the metadata, but is not going to be populated. + File has already been packed. + """ + # Gets files that are recorded in metadata. + recorded_files = self.get_recorded_associated_file_list() + + # Gets files that have been packed to self._model_file. + packed_files = self.get_packed_associated_file_list() + + # Gets the file name of those associated files to be populated. + to_be_populated_files = [] + for af in self._associated_files: + to_be_populated_files.append(os.path.basename(af)) + + # Checks all files recorded in the metadata will be populated. + for rf in recorded_files: + if rf not in to_be_populated_files and rf not in packed_files: + raise ValueError("File, '{0}', is recorded in the metadata, but has " + "not been loaded into the populator.".format(rf)) + + for f in to_be_populated_files: + if f in packed_files: + raise ValueError("File, '{0}', has already been packed.".format(f)) + + if f not in recorded_files: + warnings.warn( + "File, '{0}', does not exsit in the metadata. 
But packing it to " + "tflite model is still allowed.".format(f)) + + def _copy_archived_files(self, src_zip, dst_zip, file_list): + """Copy archived files in file_list from src_zip to dst_zip.""" + + if not zipfile.is_zipfile(src_zip): + raise ValueError("File, '{0}', is not a zipfile.".format(src_zip)) + + with zipfile.ZipFile(src_zip, + "r") as src_zf, zipfile.ZipFile(dst_zip, + "a") as dst_zf: + src_list = src_zf.namelist() + for f in file_list: + if f not in src_list: + raise ValueError( + "File, '{0}', does not exist in the zipfile, {1}.".format( + f, src_zip)) + file_buffer = src_zf.read(f) + dst_zf.writestr(f, file_buffer) + + def _get_associated_files_from_metadata_struct(self, file_holder, file_list): + for j in range(file_holder.AssociatedFilesLength()): + file_list.append(file_holder.AssociatedFiles(j).Name().decode("utf-8")) + + def _populate_associated_files(self): + """Concatenates associated files after TensorFlow Lite model file. + + If the MetadataPopulator object is created using the method, + with_model_file(model_file), the model file will be updated. + """ + # Opens up the model file in "appending" mode. + # If self._model_file already has packed files, zipfile will concatenate + # additional files after self._model_file. For example, suppose we have + # self._model_file = old_tflite_file | label1.txt | label2.txt + # Then after triggering populate() to add label3.txt, self._model_file becomes + # self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt + with zipfile.ZipFile(self._model_file, "a") as zf: + for af in self._associated_files: + filename = os.path.basename(af) + zf.write(af, filename) + + def _populate_metadata_buffer(self): + """Populates the metadata buffer (in bytearray) into the model file. + + Inserts metadata_buf into the metadata field of schema.Model. If the + MetadataPopulator object is created using the method, + with_model_file(model_file), the model file will be updated. 
+ """ + + with open(self._model_file, "rb") as f: + model_buf = f.read() + + model = _schema_fb.ModelT.InitFromObj( + _schema_fb.Model.GetRootAsModel(model_buf, 0)) + buffer_field = _schema_fb.BufferT() + buffer_field.data = self._metadata_buf + + is_populated = False + if not model.metadata: + model.metadata = [] + else: + # Check if metadata has already been populated. + for meta in model.metadata: + if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME: + is_populated = True + model.buffers[meta.buffer] = buffer_field + + if not is_populated: + if not model.buffers: + model.buffers = [] + model.buffers.append(buffer_field) + # Creates a new metadata field. + metadata_field = _schema_fb.MetadataT() + metadata_field.name = self.METADATA_FIELD_NAME + metadata_field.buffer = len(model.buffers) - 1 + model.metadata.append(metadata_field) + + # Packs model back to a flatbuffer binary file. + b = flatbuffers.Builder(0) + b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER) + model_buf = b.Output() + + # Saves the updated model buffer to model file. + # Gets files that have been packed to self._model_file. + packed_files = self.get_packed_associated_file_list() + if packed_files: + # Writes the updated model buffer and associated files into a new model + # file. Then overwrites the original model file. + with tempfile.NamedTemporaryFile() as temp: + new_file = temp.name + with open(new_file, "wb") as f: + f.write(model_buf) + self._copy_archived_files(self._model_file, new_file, packed_files) + shutil.copy(new_file, self._model_file) + os.remove(new_file) + else: + with open(self._model_file, "wb") as f: + f.write(model_buf) + + +class _MetadataPopulatorWithBuffer(MetadataPopulator): + """Subclass of MetadataPopulator that populates metadata to a model buffer. + + This class is used to populate metadata into an in-memory model buffer. 
As we + use Zip API to concatenate associated files after tflite model file, the + populating operation is developed based on a model file. For in-memory model + buffer, we create a tempfile to serve the populating operation. This class is + then used to generate this tempfile, and delete the file when the + MetadataPopulator object is deleted. + """ + + def __init__(self, model_buf): + """Constructor for _MetadataPopulatorWithBuffer. + + Args: + model_buf: TensorFlow Lite model buffer in bytearray. + + Raises: + ValueError: model_buf is empty. + """ + if not model_buf: + raise ValueError("model_buf cannot be empty.") + + with tempfile.NamedTemporaryFile() as temp: + model_file = temp.name + + with open(model_file, "wb") as f: + f.write(model_buf) + + MetadataPopulator.__init__(self, model_file) + + def __del__(self): + """Destructor of _MetadataPopulatorWithBuffer. + + Deletes the tempfile. + """ + if os.path.exists(self._model_file): + os.remove(self._model_file) + + +class MetadataDisplayer(object): + """Displays metadata and associated file info in human-readable format.""" + + def __init__(self, model_file, metadata_file, associated_file_list): + """Constructor for MetadataDisplayer. + + Args: + model_file: valid path to the model file. + metadata_file: valid path to the metadata file. + associated_file_list: list of associate files in the model file. + """ + self._model_file = model_file + self._metadata_file = metadata_file + self._associated_file_list = associated_file_list + + @classmethod + def with_model_file(cls, model_file): + """Creates a MetadataDisplayer object for the model file. + + Args: + model_file: valid path to a TensorFlow Lite model file. + + Returns: + MetadataDisplayer object. + + Raises: + IOError: File not found. + ValueError: The model does not have metadata. 
+ """ + _assert_exist(model_file) + metadata_file = cls._save_temporary_metadata_file(model_file) + associated_file_list = cls._parse_packed_associted_file_list(model_file) + return cls(model_file, metadata_file, associated_file_list) + + def export_metadata_json_file(self, export_dir): + """Converts the metadata into a json file. + + Args: + export_dir: the directory that the json file will be exported to. The json + file will be named after the model file, but with ".json" as extension. + """ + subprocess.check_call([ + _FLATC_BINARY_PATH, "-o", export_dir, "--json", + _FLATC_TFLITE_METADATA_SCHEMA_FILE, "--", self._metadata_file, + "--strict-json" + ]) + temp_name = os.path.join( + export_dir, + os.path.splitext(os.path.basename(self._metadata_file))[0] + ".json") + expected_name = os.path.join( + export_dir, + os.path.splitext(os.path.basename(self._model_file))[0] + ".json") + os.rename(temp_name, expected_name) + + def get_packed_associated_file_list(self): + """Returns a list of associated files that are packed in the model. + + Returns: + A name list of associated files. + """ + return copy.deepcopy(self._associated_file_list) + + @staticmethod + def _save_temporary_metadata_file(model_file): + """Saves the metadata in the model file to a temporary file. + + Args: + model_file: valid path to the model file. + + Returns: + Path to the metadata temporary file. + + Raises: + ValueError: The model does not have metadata. + """ + with open(model_file, "rb") as f: + model_buf = f.read() + + tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0) + + # Gets metadata from the model file. + for i in range(tflite_model.MetadataLength()): + meta = tflite_model.Metadata(i) + if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME: + buffer_index = meta.Buffer() + metadata = tflite_model.Buffers(buffer_index) + metadata_buf = metadata.DataAsNumpy().tobytes() + # Creates a temporary file to store the metadata. 
+ with tempfile.NamedTemporaryFile() as temp: + metadata_file = temp.name + # Saves the metadata into the temporary file. + with open(metadata_file, "wb") as f: + f.write(metadata_buf) + return metadata_file + + raise ValueError("The model does not have metadata.") + + @staticmethod + def _parse_packed_associted_file_list(model_file): + """Gets a list of associated files packed to the model file. + + Args: + model_file: valid path to the model file. + + Returns: + List of packed associated files. + """ + if not zipfile.is_zipfile(model_file): + return [] + + with zipfile.ZipFile(model_file, "r") as zf: + return zf.namelist() + + def __del__(self): + """Destructor of MetadataDisplayer. + + Deletes the tempfile. + """ + if os.path.exists(self._metadata_file): + os.remove(self._metadata_file) + + +def _assert_exist(filename): + """Checks if a file exists.""" + if not os.path.exists(filename): + raise IOError("File, '{0}', does not exist.".format(filename)) diff --git a/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs b/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs new file mode 100644 index 00000000000..a70dd044849 --- /dev/null +++ b/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs @@ -0,0 +1,499 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +namespace tflite; + +// TFLite metadata contains both human readable and machine readable information +// about what the model does and how to use the model. It can be used as a +// README file, which elaborates the details of the model, each input/output +// tensor, and each associated file. +// +// An important use case of TFLite metadata is the TFLite codegen tool, which +// automatically generates the model interface based on the properties of the +// model and the tensors. The model interface provides high-level APIs to +// interact with the model, such as preprocessing the input data and running +// inferences. +// +// Entries marked with "<Codegen usage>" are used in TFLite codegen tool to +// generate the model interface. It is recommended to fill in at least those +// entries to boost the codegen performance. + +// This corresponds to the schema version. +file_identifier "M001"; +// File extension of any written files. +file_extension "tflitemeta"; + +enum AssociatedFileType : byte { + UNKNOWN = 0, + // Files such as readme.txt + DESCRIPTIONS = 1, + + // Contains labels that annotate certain axis of the tensor. For example, + // the label file in image classification. Those labels annotate the + // output tensor, such that each value in the output tensor is the + // probability of that corresponding category specified by the label. + // + // <Codegen usage>: + // If an output tensor has an associated file as TENSOR_AXIS_LABELS, return + // the output as a mapping between the labels and probability in the model + // interface. + // If multiple files of the same type are present, the first one is used by + // default; additional ones are to be distinguished from one another by their + // specified locale. + TENSOR_AXIS_LABELS = 2, + + // Contains labels that tensor values correspond to. For example, in + // the object detection model, one of the output tensors is the detected + // classes. 
And each value in the tensor refers to the index of label in the + // category label file. + // + // <Codegen usage>: + // If an output tensor has an associated file as TENSOR_VALUE_LABELS, convert + // the tensor values into labels, and return a list of string as the output. + // If multiple files of the same type are present, the first one is used by + // default; additional ones are to be distinguished from one another by their + // specified locale. + TENSOR_VALUE_LABELS = 3, + + // Contains sigmoid-based score calibration parameters, formatted as CSV. + // Lines contain for each index of an output tensor the scale, slope, offset + // and min_score parameters to be used for sigmoid fitting (in this order and + // in `strtof`-compatible [1] format). + // A line may be left empty to default calibrated scores for this index to + // default_score. See documentation for ScoreCalibrationOptions for details. + // + // [1]: https://en.cppreference.com/w/c/string/byte/strtof + TENSOR_AXIS_SCORE_CALIBRATION = 4, +} + +table AssociatedFile { + // Name of this file. Need to be exact the same as the name of the actual file + // packed into the TFLite model as a zip file. + // + // <Codegen usage>: + // Locates to the actual file in the TFLite model. + name:string; + + // A description of what the file is. + description:string; + + // Type of the associated file. There may be special pre/post processing for + // some types. For example in image classification, a label file of the output + // will be used to convert object index into string. + // + // <Codegen usage>: + // Determines how to process the corresponding tensor. + type:AssociatedFileType; + + // An optional locale for this associated file (if applicable). It is + // recommended to use an ISO 639-1 letter code (e.g. "en" for English), + // optionally completed by a two letter region code (e.g. "en-US" for US + // English and "en-CA" for Canadian English). 
+ // Leverage this in order to specify e.g. multiple label files translated in + // different languages. + locale:string; +} + +// The basic content type for all tensors. +// +// <Codegen usage>: +// Input feature tensors: +// 1. Generates the method to load data from a TensorBuffer. +// 2. Creates the preprocessing logic. The default processing pipeline is: +// [NormalizeOp, QuantizeOp]. +// Output feature tensors: +// 1. Generates the method to return the output data to a TensorBuffer. +// 2. Creates the post-processing logic. The default processing pipeline is: +// [DeQuantizeOp]. +table FeatureProperties { +} + +// The type of color space of an image. +enum ColorSpaceType : byte { + UNKNOWN = 0, + RGB = 1, + GRAYSCALE = 2, +} + +table ImageSize { + width:uint; + height:uint; +} + +// The properties for image tensors. +// +// <Codegen usage>: +// Input image tensors: +// 1. Generates the method to load an image from a TensorImage. +// 2. Creates the preprocessing logic. The default processing pipeline is: +// [ResizeOp, NormalizeOp, QuantizeOp]. +// Output image tensors: +// 1. Generates the method to return the output data to a TensorImage. +// 2. Creates the post-processing logic. The default processing pipeline is: +// [DeQuantizeOp]. +table ImageProperties { + // The color space of the image. + // + // <Codegen usage>: + // Determines how to convert the color space of a given image from users. + color_space:ColorSpaceType; + + // Indicates the default value of image width and height if the tensor shape + // is dynamic. For fixed-size tensor, this size will be consistent with the + // expected size. + default_size:ImageSize; +} + +// The properties for tensors representing bounding boxes. +// +// <Codegen usage>: +// Input image tensors: NA. +// Output image tensors: parses the values into a data structure that represents +// bounding boxes. For example, in the generated wrapper for Android, it returns +// the output as android.graphics.Rect objects. 
+enum BoundingBoxType : byte { + UNKNOWN = 0, + // Represents the bounding box by using the combination of boundaries, + // {left, top, right, bottom}. + // The default order is {left, top, right, bottom}. Other orders can be + // indicated by BoundingBoxProperties.index. + BOUNDARIES = 1, + + // Represents the bounding box by using the upper_left corner, width and + // height. + // The default order is {upper_left_x, upper_left_y, width, height}. Other + // orders can be indicated by BoundingBoxProperties.index. + UPPER_LEFT = 2, + + // Represents the bounding box by using the center of the box, width and + // height. The default order is {center_x, center_y, width, height}. Other + // orders can be indicated by BoundingBoxProperties.index. + CENTER = 3, + +} + +enum CoordinateType : byte { + // The coordinates are float values from 0 to 1. + RATIO = 0, + // The coordinates are integers. + PIXEL = 1, +} + +table BoundingBoxProperties { + // Denotes the order of the elements defined in each bounding box type. An + // empty index array represents the default order of each bounding box type. + // For example, to denote the default order of BOUNDARIES, {left, top, right, + // bottom}, the index should be {0, 1, 2, 3}. To denote the order {left, + // right, top, bottom}, the index should be {0, 2, 1, 3}. + // + // The index array can be applied to all bounding box types to adjust the + // order of their corresponding underlying elements. + // + // <Codegen usage>: + // Indicates how to parse the bounding box values. + index:[uint]; + + // <Codegen usage>: + // Indicates how to parse the bounding box values. + type:BoundingBoxType; + + // <Codegen usage>: + // Indicates how to convert the bounding box back to the original image in + // pixels. 
+ coordinate_type:CoordinateType; +} + +union ContentProperties { + FeatureProperties, + ImageProperties, + BoundingBoxProperties, +} + +table ValueRange { + min:int; + max:int; +} + +table Content { + // The properties that the content may have, indicating the type of the + // Content. + // + // <Codegen usage>: + // Indicates how to process the tensor. + content_properties:ContentProperties; + + // The range of dimensions that the content corresponds to. A NULL + // "range" indicates that the content uses up all dimensions, + // except the batch axis if applied. + // + // Here are all the possible situations of how a tensor is composed. + // Case 1: The tensor is a single object, such as an image. + // For example, the input of an image classifier + // (https://www.tensorflow.org/lite/models/image_classification/overview), + // a tensor of shape [1, 224, 224, 3]. Dimensions 1 to 3 correspond to the + // image. Since dimension 0 is a batch axis, which can be ignored, + // "range" can be left as NULL. + // + // Case 2: The tensor contains multiple instances of the same object. + // For example, the output tensor of detected bounding boxes of an object + // detection model + // (https://www.tensorflow.org/lite/models/object_detection/overview). + // The tensor shape is [1, 10, 4]. Here is what the three dimensions + // represent: + // dimension 0: the batch axis. + // dimension 1: the 10 objects detected with the highest confidence. + // dimension 2: the bounding boxes of the 10 detected objects. + // The tensor is essentially 10 bounding boxes. In this case, + // "range" should be {min=2; max=2;}. + // Another example is the pose estimation model + // (https://www.tensorflow.org/lite/models/pose_estimation/overview). + // The output tensor of heatmaps is in the shape of [1, 9, 9, 17]. + // Here is what the four dimensions represent: + // dimension 0: the batch axis. + // dimension 1/2: the heatmap image. + // dimension 3: 17 body parts of a person. 
+ // Even though the last axis is body part, the real content of this tensor is + // the heatmap. "range" should be [min=1; max=2]. + // + // Case 3: The tensor contains multiple different objects. (Not supported by + // Content at this point). + // Sometimes a tensor may contain multiple different objects, thus different + // contents. It is very common for regression models. For example, a model + // to predict the fuel efficiency + // (https://www.tensorflow.org/tutorials/keras/regression). + // The input tensor has shape [1, 9], consisting of 9 features, such as + // "Cylinders", "Displacement", "Weight", etc. In this case, dimension 1 + // contains 9 different contents. However, since these sub-dimension objects + // barely need to be specifically processed, their contents are not recorded + // in the metadata. Though, the name of each dimension can be set through + // TensorMetadata.dimension_names. + // + // Note that if it is not case 3, a tensor can only have one content type. + // + // <Codegen usage>: + // Case 1: return a processed single object of certain content type. + // Case 2: return a list of processed objects of certain content type. The + // generated model interface has APIs to randomly access those objects from + // the output. + range:ValueRange; +} + +// Parameters that are used when normalizing the tensor. +table NormalizationOptions{ + // mean and std are normalization parameters. Tensor values are normalized + // per-channel by, + // (x - mean) / std. + // For example, a float MobileNet model will have + // mean = 127.5f and std = 127.5f. + // A quantized MobileNet model will have + // mean = 0.0f and std = 1.0f. + // If there is only one value in mean or std, we'll propagate the value to + // all channels. + + // Per-channel mean of the possible values used in normalization. + // + // <Codegen usage>: + // Apply normalization to input tensors accordingly. + mean:[float]; + + // Per-channel standard dev. 
of the possible values used in normalization. + // + // <Codegen usage>: + // Apply normalization to input tensors accordingly. + std:[float]; +} + +// The different possible score transforms to apply to uncalibrated scores +// before applying score calibration. +enum ScoreTransformationType : byte { + // Identity function: g(x) = x. + IDENTITY = 0, + // Log function: g(x) = log(x). + LOG = 1, + // Inverse logistic function: g(x) = log(x) - log(1-x). + INVERSE_LOGISTIC = 2, +} + +// Options to perform score calibration on an output tensor through sigmoid +// functions. One of the main purposes of score calibration is to make scores +// across classes comparable, so that a common threshold can be used for all +// output classes. This is meant for models producing class predictions as +// output, e.g. image classification or detection models. +// +// For each index in the output tensor, this applies: +// * `f(x) = scale / (1 + e^-(slope*g(x)+offset))` if `x > min_score`, +// * `f(x) = default_score` otherwise or if no scale, slope, offset and +// min_score have been specified. +// Where: +// * scale, slope, offset and min_score are index-specific parameters +// * g(x) is an index-independent transform among those defined in +// ScoreTransformationType +// * default_score is an index-independent parameter. +// An AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION specifying the +// index-specific parameters must be associated with the corresponding +// TensorMetadata for score calibration to be applied. +table ScoreCalibrationOptions { + // The function to use for transforming the uncalibrated score before + // applying score calibration. + score_transformation:ScoreTransformationType; + + // The default calibrated score to apply if the uncalibrated score is + // below min_score or if no parameters were specified for a given index. + default_score:float; +} + +// Performs thresholding on output tensor values, in order to filter out +// low-confidence results. 
+table ScoreThresholdingOptions { + // The recommended global threshold below which results are considered + // low-confidence and should be filtered out. + global_score_threshold:float; +} + +// Options that are used when processing the tensor. +union ProcessUnitOptions { + NormalizationOptions, + ScoreCalibrationOptions, + ScoreThresholdingOptions, +} + +// A process unit that is used to process the tensor out-of-graph. +table ProcessUnit { + options:ProcessUnitOptions; +} + + +// Statistics to describe a tensor. +table Stats { + // Max and min are not currently used in tflite.support codegen. They mainly + // serve as references for users to better understand the model. They can also + // be used to validate model pre/post processing results. + // If there is only one value in max or min, we'll propagate the value to + // all channels. + + // Per-channel maximum value of the tensor. + max:[float]; + + // Per-channel minimum value of the tensor. + min:[float]; +} + +// Detailed information of an input or output tensor. +table TensorMetadata { + // Name of the tensor. + // + // <Codegen usage>: + // The name of this tensor in the generated model interface. + name:string; + + // A description of the tensor. + description:string; + + // A list of names of the dimensions in this tensor. The length of + // dimension_names needs to match the number of dimensions in this tensor. + // + // <Codegen usage>: + // The name of each dimension in the generated model interface. See "Case 2" + // in the comments of Content.range. + dimension_names:[string]; + + // The content that represents this tensor. + // + // <Codegen usage>: + // Determines how to process this tensor. See each item in ContentProperties + // for the default process units that will be applied to the tensor. + content:Content; + + // The process units that are used to process the tensor out-of-graph. 
+ // + // <Codegen usage>: + // Contains the parameters of the default processing pipeline for each content + // type, such as the normalization parameters in all content types. See the + // items under ContentProperties for the details of the default processing + // pipeline. + process_units:[ProcessUnit]; + + // The statistics of the tensor values. + stats:Stats; + + // A list of associated files of this tensor. + // + // <Codegen usage>: + // Contains processing parameters of this tensor, such as normalization. + associated_files:[AssociatedFile]; +} + +table SubGraphMetadata { + // Name of the subgraph. + // + // Note that, since TFLite only support one subgraph at this moment, the + // Codegen tool will use the name in ModelMetadata in the generated model + // interface. + name:string; + + // A description explains details about what the subgraph does. + description:string; + + // Metadata of all input tensors used in this subgraph. + // + // <Codegen usage>: + // Determines how to process the inputs. + input_tensor_metadata:[TensorMetadata]; + + // Metadata of all output tensors used in this subgraph. + // + // <Codegen usage>: + // Determines how to process the outputs. + output_tensor_metadata:[TensorMetadata]; + + // A list of associated files of this subgraph. + associated_files:[AssociatedFile]; +} + +table ModelMetadata { + // Name of the model. + // + // <Codegen usage>: + // The name of the model in the generated model interface. + name:string; + + // Model description in schema. + description:string; + + // Version of the model that specified by model creators. + version:string; + + // Noted that, the minimum required TFLite runtime version that the model is + // compatible with, has already been added as a metadata entry in tflite + // schema. We'll decide later if we want to move it here, and keep it with + // other metadata entries. + + // Metadata of all the subgraphs of the model. The 0th is assumed to be the + // main subgraph. 
+ // + // <Codegen usage>: + // Determines how to process the inputs and outputs. + subgraph_metadata:[SubGraphMetadata]; + + // The person who creates this model. + author:string; + + // Licenses that may apply to this model. + license:string; + + // A list of associated files of this model. + associated_files:[AssociatedFile]; +} + +root_type ModelMetadata; diff --git a/tensorflow/lite/experimental/support/metadata/metadata_test.py b/tensorflow/lite/experimental/support/metadata/metadata_test.py new file mode 100644 index 00000000000..00ee23f0b41 --- /dev/null +++ b/tensorflow/lite/experimental/support/metadata/metadata_test.py @@ -0,0 +1,381 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.lite.experimental.support.metadata.metadata.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import six + +from flatbuffers.python import flatbuffers +from tensorflow.lite.experimental.support.metadata import metadata as _metadata +from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb +from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb +from tensorflow.python.framework import test_util +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test + + +class MetadataTest(test_util.TensorFlowTestCase): + + def setUp(self): + super(MetadataTest, self).setUp() + self._invalid_model_buf = None + self._invalid_file = "not_existed_file" + self._empty_model_buf = self._create_empty_model_buf() + self._empty_model_file = self.create_tempfile().full_path + with open(self._empty_model_file, "wb") as f: + f.write(self._empty_model_buf) + self._model_file = self._create_model_file_with_metadata_and_buf_fields() + self._metadata_file = self._create_metadata_file() + self._file1 = self.create_tempfile("file1").full_path + self._file2 = self.create_tempfile("file2").full_path + self._file3 = self.create_tempfile("file3").full_path + + def _create_empty_model_buf(self): + model = _schema_fb.ModelT() + model_builder = flatbuffers.Builder(0) + model_builder.Finish( + model.Pack(model_builder), + _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER) + return model_builder.Output() + + def _create_model_file_with_metadata_and_buf_fields(self): + metadata_field = _schema_fb.MetadataT() + metadata_field.name = "meta" + buffer_field = _schema_fb.BufferT() + model = _schema_fb.ModelT() + model.metadata = [metadata_field, metadata_field] + model.buffers = [buffer_field, 
buffer_field, buffer_field] + model_builder = flatbuffers.Builder(0) + model_builder.Finish( + model.Pack(model_builder), + _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER) + + mnodel_file = self.create_tempfile().full_path + with open(mnodel_file, "wb") as f: + f.write(model_builder.Output()) + + return mnodel_file + + def _create_metadata_file(self): + associated_file1 = _metadata_fb.AssociatedFileT() + associated_file1.name = b"file1" + associated_file2 = _metadata_fb.AssociatedFileT() + associated_file2.name = b"file2" + self.expected_recorded_files = [ + six.ensure_str(associated_file1.name), + six.ensure_str(associated_file2.name) + ] + + output_meta = _metadata_fb.TensorMetadataT() + output_meta.associatedFiles = [associated_file2] + subgraph = _metadata_fb.SubGraphMetadataT() + subgraph.outputTensorMetadata = [output_meta] + + model_meta = _metadata_fb.ModelMetadataT() + model_meta.name = "Mobilenet_quantized" + model_meta.associatedFiles = [associated_file1] + model_meta.subgraphMetadata = [subgraph] + b = flatbuffers.Builder(0) + b.Finish( + model_meta.Pack(b), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + + metadata_file = self.create_tempfile().full_path + with open(metadata_file, "wb") as f: + f.write(b.Output()) + return metadata_file + + +class MetadataPopulatorTest(MetadataTest): + + def testToValidModelFile(self): + populator = _metadata.MetadataPopulator.with_model_file( + self._empty_model_file) + self.assertIsInstance(populator, _metadata.MetadataPopulator) + + def testToInvalidModelFile(self): + with self.assertRaises(IOError) as error: + _metadata.MetadataPopulator.with_model_file(self._invalid_file) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testToValidModelBuffer(self): + populator = _metadata.MetadataPopulator.with_model_buffer( + self._empty_model_buf) + self.assertIsInstance(populator, _metadata.MetadataPopulator) + + def testToInvalidModelBuffer(self): + 
with self.assertRaises(ValueError) as error: + _metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf) + self.assertEqual("model_buf cannot be empty.", str(error.exception)) + + def testSinglePopulateAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer( + self._empty_model_buf) + populator.load_associated_files([self._file1]) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [os.path.basename(self._file1)] + self.assertEqual(set(packed_files), set(expected_packed_files)) + + def testRepeatedPopulateAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_file( + self._empty_model_file) + populator.load_associated_files([self._file1, self._file2]) + # Loads file2 multiple times. + populator.load_associated_files([self._file2]) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [ + os.path.basename(self._file1), + os.path.basename(self._file2) + ] + self.assertEqual(len(packed_files), 2) + self.assertEqual(set(packed_files), set(expected_packed_files)) + + # Check if the model buffer read from file is the same as that read from + # get_model_buffer(). 
+ with open(self._empty_model_file, "rb") as f: + model_buf_from_file = f.read() + model_buf_from_getter = populator.get_model_buffer() + self.assertEqual(model_buf_from_file, model_buf_from_getter) + + def testPopulateInvalidAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer( + self._empty_model_buf) + with self.assertRaises(IOError) as error: + populator.load_associated_files([self._invalid_file]) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testPopulatePackedAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer( + self._empty_model_buf) + populator.load_associated_files([self._file1]) + populator.populate() + with self.assertRaises(ValueError) as error: + populator.load_associated_files([self._file1]) + populator.populate() + self.assertEqual( + "File, '{0}', has already been packed.".format( + os.path.basename(self._file1)), str(error.exception)) + + def testGetPackedAssociatedFileList(self): + populator = _metadata.MetadataPopulator.with_model_buffer( + self._empty_model_buf) + packed_files = populator.get_packed_associated_file_list() + self.assertEqual(packed_files, []) + + def testPopulateMetadataFileToEmptyModelFile(self): + populator = _metadata.MetadataPopulator.with_model_file( + self._empty_model_file) + populator.load_metadata_file(self._metadata_file) + populator.load_associated_files([self._file1, self._file2]) + populator.populate() + + with open(self._empty_model_file, "rb") as f: + model_buf_from_file = f.read() + model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0) + metadata_field = model.Metadata(0) + self.assertEqual( + six.ensure_str(metadata_field.Name()), + six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME)) + + buffer_index = metadata_field.Buffer() + buffer_data = model.Buffers(buffer_index) + metadata_buf_np = buffer_data.DataAsNumpy() + metadata_buf = metadata_buf_np.tobytes() + with 
open(self._metadata_file, "rb") as f: + expected_metadata_buf = bytearray(f.read()) + self.assertEqual(metadata_buf, expected_metadata_buf) + + recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(self.expected_recorded_files)) + + # Up to now, we've proved the correctness of the model buffer read from the + # file. Then we'll test if get_model_buffer() gives the same model buffer. + model_buf_from_getter = populator.get_model_buffer() + self.assertEqual(model_buf_from_file, model_buf_from_getter) + + def testPopulateMetadataFileWithoutAssociatedFiles(self): + populator = _metadata.MetadataPopulator.with_model_file( + self._empty_model_file) + populator.load_metadata_file(self._metadata_file) + populator.load_associated_files([self._file1]) + # Supposed to populate self._file2, because it is recorded in the metadata. + with self.assertRaises(ValueError) as error: + populator.populate() + self.assertEqual(("File, '{0}', is recorded in the metadata, but has " + "not been loaded into the populator.").format( + os.path.basename(self._file2)), str(error.exception)) + + def _assert_golden_metadata(self, model_file): + with open(model_file, "rb") as f: + model_buf_from_file = f.read() + model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0) + # There are two elements in model.Metadata array before the population. + # Metadata should be packed to the third element in the array. 
+ metadata_field = model.Metadata(2) + self.assertEqual( + six.ensure_str(metadata_field.Name()), + six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME)) + + buffer_index = metadata_field.Buffer() + buffer_data = model.Buffers(buffer_index) + metadata_buf_np = buffer_data.DataAsNumpy() + metadata_buf = metadata_buf_np.tobytes() + with open(self._metadata_file, "rb") as f: + expected_metadata_buf = bytearray(f.read()) + self.assertEqual(metadata_buf, expected_metadata_buf) + + def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self): + # First, creates a dummy metadata. Populates it and the associated files + # into the model. + model_meta = _metadata_fb.ModelMetadataT() + model_meta.name = "Mobilenet_quantized" + b = flatbuffers.Builder(0) + b.Finish( + model_meta.Pack(b), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_buf = b.Output() + + populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator1.load_metadata_buffer(metadata_buf) + populator1.load_associated_files([self._file1, self._file2]) + populator1.populate() + + # Then, populates the metadata again. + populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator2.load_metadata_file(self._metadata_file) + populator2.populate() + + # Tests if the metadata is populated correctly. + self._assert_golden_metadata(self._model_file) + + def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self): + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_file(self._metadata_file) + populator.load_associated_files([self._file1, self._file2]) + populator.populate() + + # Tests if the metadata is populated correctly. 
+ self._assert_golden_metadata(self._model_file) + + recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(self.expected_recorded_files)) + + # Up to now, we've proved the correctness of the model buffer read from the + # file. Then we'll test if get_model_buffer() gives the same model buffer. + with open(self._model_file, "rb") as f: + model_buf_from_file = f.read() + model_buf_from_getter = populator.get_model_buffer() + self.assertEqual(model_buf_from_file, model_buf_from_getter) + + def testPopulateInvalidMetadataFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer( + self._empty_model_buf) + with self.assertRaises(IOError) as error: + populator.load_metadata_file(self._invalid_file) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testPopulateInvalidMetadataBuffer(self): + populator = _metadata.MetadataPopulator.with_model_buffer( + self._empty_model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer([]) + self.assertEqual("The metadata to be populated is empty.", + str(error.exception)) + + def testGetModelBufferBeforePopulatingData(self): + populator = _metadata.MetadataPopulator.with_model_buffer( + self._empty_model_buf) + model_buf = populator.get_model_buffer() + expected_model_buf = self._empty_model_buf + self.assertEqual(model_buf, expected_model_buf) + + +class MetadataDisplayerTest(MetadataTest): + + def setUp(self): + super(MetadataDisplayerTest, self).setUp() + self._model_file = self._create_model_with_metadata_and_associated_files() + + def _create_model_with_metadata_and_associated_files(self): + model_buf = self._create_empty_model_buf() + model_file = self.create_tempfile().full_path + with open(model_file, "wb") as f: + f.write(model_buf) + + populator = _metadata.MetadataPopulator.with_model_file(model_file) + populator.load_metadata_file(self._metadata_file) + 
populator.load_associated_files([self._file1, self._file2]) + populator.populate() + return model_file + + def test_load_model_file_invalidModelFile_throwsException(self): + with self.assertRaises(IOError) as error: + _metadata.MetadataDisplayer.with_model_file(self._invalid_file) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def test_load_model_file_modelWithoutMetadata_throwsException(self): + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_file(self._empty_model_file) + self.assertEqual("The model does not have metadata.", str(error.exception)) + + def test_load_model_file_modelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file) + self.assertIsInstance(displayer, _metadata.MetadataDisplayer) + + def test_export_metadata_json_file_modelWithMetadata(self): + export_dir = self.create_tempdir().full_path + + displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file) + displayer.export_metadata_json_file(export_dir) + + # Verifies the generated json file. 
+ golden_json_file_path = resource_loader.get_path_to_datafile( + "testdata/golden_json.json") + json_file_path = os.path.join( + export_dir, + os.path.splitext(os.path.basename(self._model_file))[0] + ".json") + with open(json_file_path, "r") as json_file, open(golden_json_file_path, + "r") as golden_json_file: + json_contents = json_file.read() + golden_json_contents = golden_json_file.read() + self.assertEqual(json_contents, golden_json_contents) + + def test_get_packed_associated_file_list_modelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_file(self._model_file) + packed_files = displayer.get_packed_associated_file_list() + + expected_packed_files = [ + os.path.basename(self._file1), + os.path.basename(self._file2) + ] + self.assertEqual(len(packed_files), 2) + self.assertEqual(set(packed_files), set(expected_packed_files)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/lite/experimental/support/metadata/testdata/golden_json.json b/tensorflow/lite/experimental/support/metadata/testdata/golden_json.json new file mode 100644 index 00000000000..bc3001e685a --- /dev/null +++ b/tensorflow/lite/experimental/support/metadata/testdata/golden_json.json @@ -0,0 +1,21 @@ +{ + "name": "Mobilenet_quantized", + "subgraph_metadata": [ + { + "output_tensor_metadata": [ + { + "associated_files": [ + { + "name": "file2" + } + ] + } + ] + } + ], + "associated_files": [ + { + "name": "file1" + } + ] +} diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index bcad2b20305..9eacc19ed28 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -206,4 +206,12 @@ cc_test( ], ) +py_binary( + name = "zip_files", + srcs = ["zip_files.py"], + python_version = "PY3", + visibility = ["//visibility:public"], + deps = ["@absl_py//absl:app"], +) + tflite_portable_test_suite() diff --git a/tensorflow/lite/tools/zip_files.py b/tensorflow/lite/tools/zip_files.py new file mode 100644 index 
00000000000..9dc662360f7 --- /dev/null +++ b/tensorflow/lite/tools/zip_files.py @@ -0,0 +1,41 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Lint as: python3 +"""Creates a zip package of the files passed in.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import zipfile + +from absl import app +from absl import flags + +FLAGS = flags.FLAGS +flags.DEFINE_string("export_zip_path", None, "Path to zip file.") +flags.DEFINE_string("file_directory", None, "Path to the files to be zipped.") + + +def main(_): + with zipfile.ZipFile(FLAGS.export_zip_path, mode="w") as zf: + for root, _, files in os.walk(FLAGS.file_directory): + for f in files: + if f.endswith(".java"): + zf.write(os.path.join(root, f)) + + +if __name__ == "__main__": + app.run(main) diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel index 6fb508db841..3b21a73154a 100644 --- a/third_party/flatbuffers/BUILD.bazel +++ b/third_party/flatbuffers/BUILD.bazel @@ -1,3 +1,7 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +package(default_visibility = ["//visibility:public"]) + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE.txt"]) @@ -116,3 +120,20 @@ py_library( srcs = [":runtime_py_srcs"], visibility = 
["//visibility:public"], ) + +filegroup( + name = "runtime_java_srcs", + srcs = glob(["java/com/google/flatbuffers/**/*.java"]), +) + +java_library( + name = "runtime_java", + srcs = [":runtime_java_srcs"], + visibility = ["//visibility:public"], +) + +android_library( + name = "runtime_android", + srcs = [":runtime_java_srcs"], + visibility = ["//visibility:public"], +) diff --git a/third_party/flatbuffers/build_defs.bzl b/third_party/flatbuffers/build_defs.bzl index a5e9eac654b..d07ad18630f 100644 --- a/third_party/flatbuffers/build_defs.bzl +++ b/third_party/flatbuffers/build_defs.bzl @@ -1,6 +1,15 @@ """BUILD rules for generating flatbuffer files.""" +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + flatc_path = "@flatbuffers//:flatc" +zip_files = "//tensorflow/lite/tools:zip_files" + +DEFAULT_INCLUDE_PATHS = [ + "./", + "$(GENDIR)", + "$(BINDIR)", +] DEFAULT_FLATC_ARGS = [ "--no-union-value-namespacing", @@ -422,3 +431,181 @@ def flatbuffer_py_library( "@flatbuffers//:runtime_py", ], ) + +def flatbuffer_java_library( + name, + srcs, + custom_package = "", + package_prefix = "", + include_paths = DEFAULT_INCLUDE_PATHS, + flatc_args = DEFAULT_FLATC_ARGS, + visibility = None): + """A java library with the generated reader/writers for the given flatbuffer definitions. + + Args: + name: Rule name. (required) + srcs: List of source .fbs files including all includes. (required) + custom_package: Package name of generated Java files. If not specified + namespace in the schema files will be used. (optional) + package_prefix: like custom_package, but prefixes to the existing + namespace. (optional) + include_paths: List of paths that includes files can be found in. (optional) + flatc_args: List of additional arguments to pass to flatc. (optional) + visibility: Visibility setting for the java_library rule. 
(optional) + """ + out_srcjar = "java_%s_all.srcjar" % name + flatbuffer_java_srcjar( + name = "%s_srcjar" % name, + srcs = srcs, + out = out_srcjar, + custom_package = custom_package, + flatc_args = flatc_args, + include_paths = include_paths, + package_prefix = package_prefix, + ) + + native.filegroup( + name = "%s.srcjar" % name, + srcs = [out_srcjar], + ) + + native.java_library( + name = name, + srcs = [out_srcjar], + deps = [ + "@flatbuffers//:runtime_java", + ], + visibility = visibility, + ) + +def flatbuffer_java_srcjar( + name, + srcs, + out, + custom_package = "", + package_prefix = "", + include_paths = DEFAULT_INCLUDE_PATHS, + flatc_args = DEFAULT_FLATC_ARGS): + """Generate flatbuffer Java source files. + + Args: + name: Rule name. (required) + srcs: List of source .fbs files including all includes. (required) + out: Output file name. (required) + custom_package: Package name of generated Java files. If not specified + namespace in the schema files will be used. (optional) + package_prefix: like custom_package, but prefixes to the existing + namespace. (optional) + include_paths: List of paths that includes files can be found in. (optional) + flatc_args: List of additional arguments to pass to flatc. 
(optional) + """ + command_fmt = """set -e + tmpdir=$(@D) + schemas=$$tmpdir/schemas + java_root=$$tmpdir/java + rm -rf $$schemas + rm -rf $$java_root + mkdir -p $$schemas + mkdir -p $$java_root + + for src in $(SRCS); do + dest=$$schemas/$$src + rm -rf $$(dirname $$dest) + mkdir -p $$(dirname $$dest) + if [ -z "{custom_package}" ] && [ -z "{package_prefix}" ]; then + cp -f $$src $$dest + else + if [ -z "{package_prefix}" ]; then + sed -e "s/namespace\\s.*/namespace {custom_package};/" $$src > $$dest + else + sed -e "s/namespace \\([^;]\\+\\);/namespace {package_prefix}.\\1;/" $$src > $$dest + fi + fi + done + + flatc_arg_I="-I $$tmpdir/schemas" + for include_path in {include_paths}; do + flatc_arg_I="$$flatc_arg_I -I $$schemas/$$include_path" + done + + flatc_additional_args= + for arg in {flatc_args}; do + flatc_additional_args="$$flatc_additional_args $$arg" + done + + for src in $(SRCS); do + $(location {flatc_path}) $$flatc_arg_I --java $$flatc_additional_args -o $$java_root $$schemas/$$src + done + + $(location {zip_files}) -export_zip_path=$@ -file_directory=$$java_root + """ + genrule_cmd = command_fmt.format( + package_name = native.package_name(), + custom_package = custom_package, + package_prefix = package_prefix, + flatc_path = flatc_path, + zip_files = zip_files, + include_paths = " ".join(include_paths), + flatc_args = " ".join(flatc_args), + ) + + native.genrule( + name = name, + srcs = srcs, + outs = [out], + tools = [flatc_path, zip_files], + cmd = genrule_cmd, + ) + +def flatbuffer_android_library( + name, + srcs, + custom_package = "", + package_prefix = "", + javacopts = None, + include_paths = DEFAULT_INCLUDE_PATHS, + flatc_args = DEFAULT_FLATC_ARGS, + visibility = None): + """An android_library with the generated reader/writers for the given flatbuffer definitions. + + Args: + name: Rule name. (required) + srcs: List of source .fbs files including all includes. (required) + custom_package: Package name of generated Java files. 
If not specified + namespace in the schema files will be used. (optional) + package_prefix: like custom_package, but prefixes to the existing + namespace. (optional) + javacopts: List of options to pass to javac. + include_paths: List of paths that includes files can be found in. (optional) + flatc_args: List of additional arguments to pass to flatc. (optional) + visibility: Visibility setting for the android_library rule. (optional) + """ + out_srcjar = "android_%s_all.srcjar" % name + flatbuffer_java_srcjar( + name = "%s_srcjar" % name, + srcs = srcs, + out = out_srcjar, + custom_package = custom_package, + flatc_args = flatc_args, + include_paths = include_paths, + package_prefix = package_prefix, + ) + + native.filegroup( + name = "%s.srcjar" % name, + srcs = [out_srcjar], + ) + + # To support org.checkerframework.dataflow.qual.Pure. + checkerframework_annotations = [ + "@org_checkerframework_qual", + ] if "--java-checkerframework" in flatc_args else [] + + android_library( + name = name, + srcs = [out_srcjar], + visibility = visibility, + deps = [ + "@flatbuffers//:runtime_android", + ] + checkerframework_annotations, + )