Allow ByteBuffer outputs from TFLite interpreter
PiperOrigin-RevId: 203029983
This commit is contained in:
parent
8a652f7979
commit
eacdfdf6c0
@ -135,7 +135,8 @@ public final class Interpreter implements AutoCloseable {
|
|||||||
* including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large
|
* including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large
|
||||||
* input data. When {@link ByteBuffer} is used, its content should remain unchanged until
|
* input data. When {@link ByteBuffer} is used, its content should remain unchanged until
|
||||||
* model inference is done.
|
* model inference is done.
|
||||||
* @param output a multidimensional array of output data.
|
* @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive
|
||||||
|
* types including int, float, long, and byte.
|
||||||
*/
|
*/
|
||||||
public void run(@NonNull Object input, @NonNull Object output) {
|
public void run(@NonNull Object input, @NonNull Object output) {
|
||||||
Object[] inputs = {input};
|
Object[] inputs = {input};
|
||||||
@ -155,8 +156,9 @@ public final class Interpreter implements AutoCloseable {
|
|||||||
* primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred
|
* primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred
|
||||||
* way to pass large input data. When {@link ByteBuffer} is used, its content should remain
|
* way to pass large input data. When {@link ByteBuffer} is used, its content should remain
|
||||||
* unchanged until model inference is done.
|
* unchanged until model inference is done.
|
||||||
* @param outputs a map mapping output indices to multidimensional arrays of output data. It only
|
* @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
|
||||||
* needs to keep entries for the outputs to be used.
|
* ByteBuffer}s of primitive types including int, float, long, and byte. It only needs to keep
|
||||||
|
* entries for the outputs to be used.
|
||||||
*/
|
*/
|
||||||
public void runForMultipleInputsOutputs(
|
public void runForMultipleInputsOutputs(
|
||||||
@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
|
@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
|
||||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||||||
|
|
||||||
package org.tensorflow.lite;
|
package org.tensorflow.lite;
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.ByteOrder;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -29,8 +31,21 @@ final class Tensor {
|
|||||||
return new Tensor(nativeHandle);
|
return new Tensor(nativeHandle);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Reads Tensor content into an array. */
|
/**
|
||||||
|
* Copies the contents of the tensor to {@code dst} and returns {@code dst}.
|
||||||
|
*
|
||||||
|
* @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}.
|
||||||
|
* @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example,
|
||||||
|
* mismatched data types or shapes).
|
||||||
|
* @throws BufferOverflowException If {@code dst} is a ByteBuffer with insufficient space for the
|
||||||
|
* data in this tensor.
|
||||||
|
*/
|
||||||
<T> T copyTo(T dst) {
|
<T> T copyTo(T dst) {
|
||||||
|
if (dst instanceof ByteBuffer) {
|
||||||
|
ByteBuffer dstByteBuffer = (ByteBuffer) dst;
|
||||||
|
dstByteBuffer.put(buffer());
|
||||||
|
return dst;
|
||||||
|
}
|
||||||
if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) {
|
if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
String.format(
|
String.format(
|
||||||
@ -60,6 +75,12 @@ final class Tensor {
|
|||||||
this.shapeCopy = shape(nativeHandle);
|
this.shapeCopy = shape(nativeHandle);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private ByteBuffer buffer() {
|
||||||
|
return buffer(nativeHandle).order(ByteOrder.nativeOrder());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static native ByteBuffer buffer(long handle);
|
||||||
|
|
||||||
private static native int dtype(long handle);
|
private static native int dtype(long handle);
|
||||||
|
|
||||||
private static native int[] shape(long handle);
|
private static native int[] shape(long handle);
|
||||||
|
@ -203,6 +203,16 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
|
||||||
|
jclass clazz,
|
||||||
|
jlong handle) {
|
||||||
|
TfLiteTensor* tensor = convertLongToTensor(env, handle);
|
||||||
|
if (tensor == nullptr) return nullptr;
|
||||||
|
|
||||||
|
return env->NewDirectByteBuffer(static_cast<void*>(tensor->data.raw),
|
||||||
|
static_cast<jlong>(tensor->bytes));
|
||||||
|
}
|
||||||
|
|
||||||
JNIEXPORT void JNICALL
|
JNIEXPORT void JNICALL
|
||||||
Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
|
Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
|
||||||
jclass clazz,
|
jclass clazz,
|
||||||
|
@ -24,8 +24,17 @@ extern "C" {
|
|||||||
#endif // __cplusplus
|
#endif // __cplusplus
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_tensorflow_lite_TfLiteTensor
|
* Class: org_tensorflow_lite_Tensor
|
||||||
* Method:
|
* Method: buffer
|
||||||
|
* Signature: (J)Ljava/nio/ByteBuffer;
|
||||||
|
*/
|
||||||
|
JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
|
||||||
|
jclass clazz,
|
||||||
|
jlong handle);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: org_tensorflow_lite_Tensor
|
||||||
|
* Method: dtype
|
||||||
* Signature: (J)I
|
* Signature: (J)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
|
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
|
||||||
@ -33,8 +42,8 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
|
|||||||
jlong handle);
|
jlong handle);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_tensorflow_lite_TfLiteTensor
|
* Class: org_tensorflow_lite_Tensor
|
||||||
* Method:
|
* Method: shape
|
||||||
* Signature: (J)[I
|
* Signature: (J)[I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
|
JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
|
||||||
@ -42,8 +51,8 @@ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
|
|||||||
jlong handle);
|
jlong handle);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_tensorflow_lite_TfLiteTensor
|
* Class: org_tensorflow_lite_Tensor
|
||||||
* Method:
|
* Method: readMultiDimensionalArray
|
||||||
* Signature: (JLjava/lang/Object;)
|
* Signature: (JLjava/lang/Object;)
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL
|
JNIEXPORT void JNICALL
|
||||||
|
@ -164,6 +164,24 @@ public final class InterpreterTest {
|
|||||||
interpreter.close();
|
interpreter.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRunWithByteBufferOutput() {
|
||||||
|
float[] oneD = {1.23f, 6.54f, 7.81f};
|
||||||
|
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
|
||||||
|
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
|
||||||
|
float[][][][] fourD = {threeD, threeD};
|
||||||
|
ByteBuffer parsedOutput =
|
||||||
|
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
|
||||||
|
try (Interpreter interpreter = new Interpreter(MODEL_FILE)) {
|
||||||
|
interpreter.run(fourD, parsedOutput);
|
||||||
|
}
|
||||||
|
float[] outputOneD = {
|
||||||
|
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
|
||||||
|
};
|
||||||
|
float[] expected = {3.69f, 19.62f, 23.43f};
|
||||||
|
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMobilenetRun() {
|
public void testMobilenetRun() {
|
||||||
// Create a gray image.
|
// Create a gray image.
|
||||||
|
@ -111,6 +111,27 @@ public final class NativeInterpreterWrapperTest {
|
|||||||
wrapper.close();
|
wrapper.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRunWithBufferOutput() {
|
||||||
|
try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) {
|
||||||
|
float[] oneD = {1.23f, -6.54f, 7.81f};
|
||||||
|
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
|
||||||
|
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
|
||||||
|
float[][][][] fourD = {threeD, threeD};
|
||||||
|
Object[] inputs = {fourD};
|
||||||
|
Tensor[] outputs = wrapper.run(inputs);
|
||||||
|
assertThat(outputs).hasLength(1);
|
||||||
|
ByteBuffer parsedOutput =
|
||||||
|
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
|
||||||
|
outputs[0].copyTo(parsedOutput);
|
||||||
|
float[] outputOneD = {
|
||||||
|
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
|
||||||
|
};
|
||||||
|
float[] expected = {3.69f, -19.62f, 23.43f};
|
||||||
|
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRunWithInputsOfSameDims() {
|
public void testRunWithInputsOfSameDims() {
|
||||||
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
|
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
|
||||||
|
@ -18,6 +18,9 @@ package org.tensorflow.lite;
|
|||||||
import static com.google.common.truth.Truth.assertThat;
|
import static com.google.common.truth.Truth.assertThat;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
|
import java.nio.BufferOverflowException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.ByteOrder;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
@ -70,6 +73,32 @@ public final class TensorTest {
|
|||||||
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCopyToByteBuffer() {
|
||||||
|
Tensor tensor = Tensor.fromHandle(nativeHandle);
|
||||||
|
ByteBuffer parsedOutput =
|
||||||
|
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
|
||||||
|
tensor.copyTo(parsedOutput);
|
||||||
|
assertThat(parsedOutput.position()).isEqualTo(2 * 8 * 8 * 3 * 4);
|
||||||
|
float[] outputOneD = {
|
||||||
|
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
|
||||||
|
};
|
||||||
|
float[] expected = {3.69f, 19.62f, 23.43f};
|
||||||
|
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCopyToInvalidByteBuffer() {
|
||||||
|
Tensor tensor = Tensor.fromHandle(nativeHandle);
|
||||||
|
ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
|
||||||
|
try {
|
||||||
|
tensor.copyTo(parsedOutput);
|
||||||
|
fail();
|
||||||
|
} catch (BufferOverflowException e) {
|
||||||
|
// Expected.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCopyToWrongType() {
|
public void testCopyToWrongType() {
|
||||||
Tensor tensor = Tensor.fromHandle(nativeHandle);
|
Tensor tensor = Tensor.fromHandle(nativeHandle);
|
||||||
|
Loading…
Reference in New Issue
Block a user