Allow ByteBuffer outputs from TFLite interpreter

PiperOrigin-RevId: 203029983
This commit is contained in:
Jared Duke 2018-07-02 16:14:32 -07:00 committed by TensorFlower Gardener
parent 8a652f7979
commit eacdfdf6c0
7 changed files with 120 additions and 10 deletions
tensorflow/contrib/lite/java/src

View File

@ -135,7 +135,8 @@ public final class Interpreter implements AutoCloseable {
* including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large
* input data. When {@link ByteBuffer} is used, its content should remain unchanged until
* model inference is done.
* @param output a multidimensional array of output data.
* @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive
* types including int, float, long, and byte.
*/
public void run(@NonNull Object input, @NonNull Object output) {
Object[] inputs = {input};
@ -155,8 +156,9 @@ public final class Interpreter implements AutoCloseable {
* primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred
* way to pass large input data. When {@link ByteBuffer} is used, its content should remain
* unchanged until model inference is done.
* @param outputs a map mapping output indices to multidimensional arrays of output data. It only
* needs to keep entries for the outputs to be used.
* @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
* ByteBuffer}s of primitive types including int, float, long, and byte. It only needs to keep
* entries for the outputs to be used.
*/
public void runForMultipleInputsOutputs(
@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {

View File

@ -15,6 +15,8 @@ limitations under the License.
package org.tensorflow.lite;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
/**
@ -29,8 +31,21 @@ final class Tensor {
return new Tensor(nativeHandle);
}
/** Reads Tensor content into an array. */
/**
* Copies the contents of the tensor to {@code dst} and returns {@code dst}.
*
* @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}.
* @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example,
* mismatched data types or shapes).
* @throws BufferOverflowException If {@code dst} is a ByteBuffer with insufficient space for the
* data in this tensor.
*/
<T> T copyTo(T dst) {
if (dst instanceof ByteBuffer) {
ByteBuffer dstByteBuffer = (ByteBuffer) dst;
dstByteBuffer.put(buffer());
return dst;
}
if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) {
throw new IllegalArgumentException(
String.format(
@ -60,6 +75,12 @@ final class Tensor {
this.shapeCopy = shape(nativeHandle);
}
private ByteBuffer buffer() {
return buffer(nativeHandle).order(ByteOrder.nativeOrder());
}
private static native ByteBuffer buffer(long handle);
private static native int dtype(long handle);
private static native int[] shape(long handle);

View File

@ -203,6 +203,16 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
}
}
JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
jclass clazz,
jlong handle) {
TfLiteTensor* tensor = convertLongToTensor(env, handle);
if (tensor == nullptr) return nullptr;
return env->NewDirectByteBuffer(static_cast<void*>(tensor->data.raw),
static_cast<jlong>(tensor->bytes));
}
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
jclass clazz,

View File

@ -24,8 +24,17 @@ extern "C" {
#endif // __cplusplus
/*
* Class: org_tensorflow_lite_TfLiteTensor
* Method:
* Class: org_tensorflow_lite_Tensor
* Method: buffer
* Signature: (J)Ljava/nio/ByteBuffer;
*/
JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
jclass clazz,
jlong handle);
/*
* Class: org_tensorflow_lite_Tensor
* Method: dtype
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
@ -33,8 +42,8 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
jlong handle);
/*
* Class: org_tensorflow_lite_TfLiteTensor
* Method:
* Class: org_tensorflow_lite_Tensor
* Method: shape
* Signature: (J)[I
*/
JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
@ -42,8 +51,8 @@ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
jlong handle);
/*
* Class: org_tensorflow_lite_TfLiteTensor
* Method:
* Class: org_tensorflow_lite_Tensor
* Method: readMultiDimensionalArray
* Signature: (JLjava/lang/Object;)
*/
JNIEXPORT void JNICALL

View File

@ -164,6 +164,24 @@ public final class InterpreterTest {
interpreter.close();
}
@Test
public void testRunWithByteBufferOutput() {
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
ByteBuffer parsedOutput =
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
try (Interpreter interpreter = new Interpreter(MODEL_FILE)) {
interpreter.run(fourD, parsedOutput);
}
float[] outputOneD = {
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
};
float[] expected = {3.69f, 19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
@Test
public void testMobilenetRun() {
// Create a gray image.

View File

@ -111,6 +111,27 @@ public final class NativeInterpreterWrapperTest {
wrapper.close();
}
@Test
public void testRunWithBufferOutput() {
try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) {
float[] oneD = {1.23f, -6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
Tensor[] outputs = wrapper.run(inputs);
assertThat(outputs).hasLength(1);
ByteBuffer parsedOutput =
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
outputs[0].copyTo(parsedOutput);
float[] outputOneD = {
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
};
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
}
@Test
public void testRunWithInputsOfSameDims() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);

View File

@ -18,6 +18,9 @@ package org.tensorflow.lite;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import java.nio.BufferOverflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@ -70,6 +73,32 @@ public final class TensorTest {
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
@Test
public void testCopyToByteBuffer() {
Tensor tensor = Tensor.fromHandle(nativeHandle);
ByteBuffer parsedOutput =
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
tensor.copyTo(parsedOutput);
assertThat(parsedOutput.position()).isEqualTo(2 * 8 * 8 * 3 * 4);
float[] outputOneD = {
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
};
float[] expected = {3.69f, 19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
@Test
public void testCopyToInvalidByteBuffer() {
Tensor tensor = Tensor.fromHandle(nativeHandle);
ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
try {
tensor.copyTo(parsedOutput);
fail();
} catch (BufferOverflowException e) {
// Expected.
}
}
@Test
public void testCopyToWrongType() {
Tensor tensor = Tensor.fromHandle(nativeHandle);