Allow ByteBuffer outputs from TFLite interpreter
PiperOrigin-RevId: 203029983
This commit is contained in:
parent
8a652f7979
commit
eacdfdf6c0
tensorflow/contrib/lite/java/src
main
test/java/org/tensorflow/lite
@ -135,7 +135,8 @@ public final class Interpreter implements AutoCloseable {
|
||||
* including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large
|
||||
* input data. When {@link ByteBuffer} is used, its content should remain unchanged until
|
||||
* model inference is done.
|
||||
* @param output a multidimensional array of output data.
|
||||
* @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive
|
||||
* types including int, float, long, and byte.
|
||||
*/
|
||||
public void run(@NonNull Object input, @NonNull Object output) {
|
||||
Object[] inputs = {input};
|
||||
@ -155,8 +156,9 @@ public final class Interpreter implements AutoCloseable {
|
||||
* primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred
|
||||
* way to pass large input data. When {@link ByteBuffer} is used, its content should remain
|
||||
* unchanged until model inference is done.
|
||||
* @param outputs a map mapping output indices to multidimensional arrays of output data. It only
|
||||
* needs to keep entries for the outputs to be used.
|
||||
* @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
|
||||
* ByteBuffer}s of primitive types including int, float, long, and byte. It only needs to keep
|
||||
* entries for the outputs to be used.
|
||||
*/
|
||||
public void runForMultipleInputsOutputs(
|
||||
@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
package org.tensorflow.lite;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
@ -29,8 +31,21 @@ final class Tensor {
|
||||
return new Tensor(nativeHandle);
|
||||
}
|
||||
|
||||
/** Reads Tensor content into an array. */
|
||||
/**
|
||||
* Copies the contents of the tensor to {@code dst} and returns {@code dst}.
|
||||
*
|
||||
* @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}.
|
||||
* @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example,
|
||||
* mismatched data types or shapes).
|
||||
* @throws BufferOverflowException If {@code dst} is a ByteBuffer with insufficient space for the
|
||||
* data in this tensor.
|
||||
*/
|
||||
<T> T copyTo(T dst) {
|
||||
if (dst instanceof ByteBuffer) {
|
||||
ByteBuffer dstByteBuffer = (ByteBuffer) dst;
|
||||
dstByteBuffer.put(buffer());
|
||||
return dst;
|
||||
}
|
||||
if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
@ -60,6 +75,12 @@ final class Tensor {
|
||||
this.shapeCopy = shape(nativeHandle);
|
||||
}
|
||||
|
||||
private ByteBuffer buffer() {
|
||||
return buffer(nativeHandle).order(ByteOrder.nativeOrder());
|
||||
}
|
||||
|
||||
private static native ByteBuffer buffer(long handle);
|
||||
|
||||
private static native int dtype(long handle);
|
||||
|
||||
private static native int[] shape(long handle);
|
||||
|
@ -203,6 +203,16 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
|
||||
}
|
||||
}
|
||||
|
||||
JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle) {
|
||||
TfLiteTensor* tensor = convertLongToTensor(env, handle);
|
||||
if (tensor == nullptr) return nullptr;
|
||||
|
||||
return env->NewDirectByteBuffer(static_cast<void*>(tensor->data.raw),
|
||||
static_cast<jlong>(tensor->bytes));
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
|
||||
jclass clazz,
|
||||
|
@ -24,8 +24,17 @@ extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_lite_TfLiteTensor
|
||||
* Method:
|
||||
* Class: org_tensorflow_lite_Tensor
|
||||
* Method: buffer
|
||||
* Signature: (J)Ljava/nio/ByteBuffer;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_lite_Tensor
|
||||
* Method: dtype
|
||||
* Signature: (J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
|
||||
@ -33,8 +42,8 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
|
||||
jlong handle);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_lite_TfLiteTensor
|
||||
* Method:
|
||||
* Class: org_tensorflow_lite_Tensor
|
||||
* Method: shape
|
||||
* Signature: (J)[I
|
||||
*/
|
||||
JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
|
||||
@ -42,8 +51,8 @@ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
|
||||
jlong handle);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_lite_TfLiteTensor
|
||||
* Method:
|
||||
* Class: org_tensorflow_lite_Tensor
|
||||
* Method: readMultiDimensionalArray
|
||||
* Signature: (JLjava/lang/Object;)
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
|
@ -164,6 +164,24 @@ public final class InterpreterTest {
|
||||
interpreter.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRunWithByteBufferOutput() {
|
||||
float[] oneD = {1.23f, 6.54f, 7.81f};
|
||||
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
|
||||
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
|
||||
float[][][][] fourD = {threeD, threeD};
|
||||
ByteBuffer parsedOutput =
|
||||
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
|
||||
try (Interpreter interpreter = new Interpreter(MODEL_FILE)) {
|
||||
interpreter.run(fourD, parsedOutput);
|
||||
}
|
||||
float[] outputOneD = {
|
||||
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
|
||||
};
|
||||
float[] expected = {3.69f, 19.62f, 23.43f};
|
||||
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMobilenetRun() {
|
||||
// Create a gray image.
|
||||
|
@ -111,6 +111,27 @@ public final class NativeInterpreterWrapperTest {
|
||||
wrapper.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRunWithBufferOutput() {
|
||||
try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) {
|
||||
float[] oneD = {1.23f, -6.54f, 7.81f};
|
||||
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
|
||||
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
|
||||
float[][][][] fourD = {threeD, threeD};
|
||||
Object[] inputs = {fourD};
|
||||
Tensor[] outputs = wrapper.run(inputs);
|
||||
assertThat(outputs).hasLength(1);
|
||||
ByteBuffer parsedOutput =
|
||||
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
|
||||
outputs[0].copyTo(parsedOutput);
|
||||
float[] outputOneD = {
|
||||
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
|
||||
};
|
||||
float[] expected = {3.69f, -19.62f, 23.43f};
|
||||
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRunWithInputsOfSameDims() {
|
||||
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
|
||||
|
@ -18,6 +18,9 @@ package org.tensorflow.lite;
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import java.nio.BufferOverflowException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
@ -70,6 +73,32 @@ public final class TensorTest {
|
||||
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCopyToByteBuffer() {
|
||||
Tensor tensor = Tensor.fromHandle(nativeHandle);
|
||||
ByteBuffer parsedOutput =
|
||||
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
|
||||
tensor.copyTo(parsedOutput);
|
||||
assertThat(parsedOutput.position()).isEqualTo(2 * 8 * 8 * 3 * 4);
|
||||
float[] outputOneD = {
|
||||
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
|
||||
};
|
||||
float[] expected = {3.69f, 19.62f, 23.43f};
|
||||
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCopyToInvalidByteBuffer() {
|
||||
Tensor tensor = Tensor.fromHandle(nativeHandle);
|
||||
ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
|
||||
try {
|
||||
tensor.copyTo(parsedOutput);
|
||||
fail();
|
||||
} catch (BufferOverflowException e) {
|
||||
// Expected.
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCopyToWrongType() {
|
||||
Tensor tensor = Tensor.fromHandle(nativeHandle);
|
||||
|
Loading…
Reference in New Issue
Block a user