Remove TFLite Java runtime dependency in the metadata java lib

PiperOrigin-RevId: 312601579
Change-Id: I57d7bfe06d36e62a6fa203c39225687861fa4580
This commit is contained in:
Lu Wang 2020-05-20 19:26:00 -07:00 committed by TensorFlower Gardener
parent bdf665b504
commit cf739c4104
4 changed files with 72 additions and 47 deletions
tensorflow/lite/experimental/support
java/src/java/org/tensorflow/lite/support/model
metadata/java
BUILD
src/java/org/tensorflow/lite/support/metadata

View File

@ -22,6 +22,7 @@ import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.Tensor;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.common.SupportPreconditions;
@ -218,6 +219,24 @@ public class Model {
return modelPath;
}
/**
* Gets the Tensor associated with the provdied input index.
*
* @throws IllegalStateException if the interpreter is closed.
*/
public Tensor getInputTensor(int inputIndex) {
return interpreter.getInputTensor(inputIndex);
}
/**
* Gets the Tensor associated with the provdied output index.
*
* @throws IllegalStateException if the interpreter is closed.
*/
public Tensor getOutputTensor(int outputIndex) {
return interpreter.getOutputTensor(outputIndex);
}
/**
* Returns the output shape. Useful if output shape is only determined when graph is created.
*

View File

@ -16,7 +16,6 @@ android_library(
deps = [
"//tensorflow/lite/experimental/support/metadata:metadata_schema_fbs_android",
"//tensorflow/lite/experimental/support/metadata:schema_fbs_android",
"//tensorflow/lite/java:tensorflowlite_java",
"@org_checkerframework_qual",
],
)
@ -32,7 +31,6 @@ java_library(
deps = [
"//tensorflow/lite/experimental/support/metadata:metadata_schema_java",
"//tensorflow/lite/experimental/support/metadata:schema_fbs_java",
"//tensorflow/lite/java:tensorflowlite_javalib",
"@org_checkerframework_qual",
],
)

View File

@ -22,8 +22,6 @@ import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.zip.ZipException;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Tensor.QuantizationParams;
import org.tensorflow.lite.schema.Tensor;
import org.tensorflow.lite.support.metadata.schema.ModelMetadata;
import org.tensorflow.lite.support.metadata.schema.TensorMetadata;
@ -111,6 +109,48 @@ public class MetadataExtractor {
zipFile = createZipFile(buffer);
}
/**
* Quantization parameters that corresponds to the table, {@code QuantizationParameters}, in the
* <a
* href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite
* Model schema file.</a>
*
* <p>Since per-channel quantization does not apply to input and output tensors, {@code scale} and
* {@code zero_point} are both single values instead of arrays.
*
* <p>For tensor that are not quantized, the values of scale and zero_point are both 0.
*
* <p>Given a quantized value q, the corresponding float value f should be: <br>
* f = scale * (q - zero_point) <br>
*/
public static class QuantizationParams {
/** The scale value used in quantization. */
private final float scale;
/** The zero point value used in quantization. */
private final int zeroPoint;
/**
* Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}.
*
* @param scale The scale value used in quantization.
* @param zeroPoint The zero point value used in quantization.
*/
public QuantizationParams(final float scale, final int zeroPoint) {
this.scale = scale;
this.zeroPoint = zeroPoint;
}
/** Returns the scale value. */
public float getScale() {
return scale;
}
/** Returns the zero point value. */
public int getZeroPoint() {
return zeroPoint;
}
}
/** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */
public boolean hasMetadata() {
return metadataInfo != null;
@ -166,11 +206,11 @@ public class MetadataExtractor {
}
/**
* Gets the {@link DataType} of the input tensor with {@code inputIndex}.
* Gets the {@link TensorType} of the input tensor with {@code inputIndex}.
*
* @param inputIndex the index of the desired input tensor
*/
public DataType getInputTensorType(int inputIndex) {
public byte getInputTensorType(int inputIndex) {
return modelInfo.getInputTensorType(inputIndex);
}
@ -221,11 +261,11 @@ public class MetadataExtractor {
}
/**
* Gets the {@link DataType} of the output tensor with {@code outputIndex}.
* Gets the {@link TensorType} of the output tensor with {@code outputIndex}.
*
* @param outputIndex the index of the desired output tensor
*/
public DataType getOutputTensorType(int outputIndex) {
public byte getOutputTensorType(int outputIndex) {
return modelInfo.getOutputTensorType(outputIndex);
}

View File

@ -21,12 +21,8 @@ import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Tensor.QuantizationParams;
import org.tensorflow.lite.schema.Buffer;
import org.tensorflow.lite.schema.Metadata;
import org.tensorflow.lite.schema.Model;
@ -34,6 +30,7 @@ import org.tensorflow.lite.schema.QuantizationParameters;
import org.tensorflow.lite.schema.SubGraph;
import org.tensorflow.lite.schema.Tensor;
import org.tensorflow.lite.schema.TensorType;
import org.tensorflow.lite.support.metadata.MetadataExtractor.QuantizationParams;
/** Extracts model information out of TFLite model FLatBuffer. */
final class ModelInfo {
@ -49,9 +46,6 @@ final class ModelInfo {
/** Identifier of the TFLite model metadata in the Metadata array. */
static final String METADATA_FIELD_NAME = "TFLITE_METADATA";
/** Maps from TensorType in TFlite FlatBuffer to {@link DataType} in Java. */
private final Map<Byte, DataType> tensorTypeToDataTypeMap;
/**
* Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}.
*
@ -74,7 +68,6 @@ final class ModelInfo {
inputTensors = getInputTensors(model);
outputTensors = getOutputTensors(model);
tensorTypeToDataTypeMap = createTensorTypeToDataTypeMap();
}
/**
@ -106,13 +99,12 @@ final class ModelInfo {
}
/**
* Gets {@link DataType} of the input tensor with {@code inputIndex}.
* Gets the {@link TensorType} in byte of the input tensor with {@code inputIndex}.
*
* @param inputIndex The index of the desired intput tensor.
*/
DataType getInputTensorType(int inputIndex) {
Tensor tensor = getInputTensor(inputIndex);
return getDataType(tensor.type());
byte getInputTensorType(int inputIndex) {
return getInputTensor(inputIndex).type();
}
/** Gets the metadata FlatBuffer from the model FlatBuffer. */
@ -163,13 +155,12 @@ final class ModelInfo {
}
/**
* Gets {@link DataType} of the output tensor {@code outputIndex}.
* Gets the {@link TensorType} in byte of the output tensor {@code outputIndex}.
*
* @param outputIndex The index of the desired outtput tensor.
*/
DataType getOutputTensorType(int outputIndex) {
Tensor tensor = getOutputTensor(outputIndex);
return getDataType(tensor.type());
byte getOutputTensorType(int outputIndex) {
return getOutputTensor(outputIndex).type();
}
/**
@ -233,29 +224,6 @@ final class ModelInfo {
+ " flatbuffer.");
}
private static Map<Byte, DataType> createTensorTypeToDataTypeMap() {
Map<Byte, DataType> map = new HashMap<>();
map.put(TensorType.FLOAT32, DataType.FLOAT32);
map.put(TensorType.INT32, DataType.INT32);
map.put(TensorType.UINT8, DataType.UINT8);
map.put(TensorType.INT64, DataType.INT64);
map.put(TensorType.STRING, DataType.STRING);
return Collections.unmodifiableMap(map);
}
/**
* Transforms from TensorType in TFlite FlatBuffer to {@link DataType} in Java.
*
* @param tensorType The tensor type to be converted.
* @throws IllegalArgumentException if the tensor type is not supported.
*/
private DataType getDataType(byte tensorType) {
checkArgument(
tensorTypeToDataTypeMap.containsKey(tensorType),
String.format("Tensor type %d is not supported.", tensorType));
return tensorTypeToDataTypeMap.get(tensorType);
}
/**
* Gets the shape of a tensor.
*