From 9e4434c9ffb2e74dac7abe51c3bf43f9144cb20c Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Fri, 3 Apr 2020 15:10:14 -0700 Subject: [PATCH] Add `Interpreter.allocateTensors()` call in Java Normally this isn't necessary in Java, but it is useful in cases where the output tensor shape is invalidated due to a resize, and the client needs to know that shape before executing inference. BUG=152671540 PiperOrigin-RevId: 304701829 Change-Id: Idbfaa956b307d9a24482a7c0731efddcda3e15a3 --- .../java/org/tensorflow/lite/Interpreter.java | 33 +++++++++++++++++++ .../lite/NativeInterpreterWrapper.java | 21 ++++++++++-- .../org/tensorflow/lite/InterpreterTest.java | 26 +++++++++++++++ 3 files changed, 78 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index 6aeb06355b4..efcdc0e4c65 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -314,6 +314,32 @@ public final class Interpreter implements AutoCloseable { wrapper.run(inputs, outputs); } + /** + * Expicitly updates allocations for all tensors, if necessary. + * + *

This will propagate shapes and memory allocations for all dependent tensors using the input + * tensor shape(s) as given. + * + *

Note: This call is *purely optional*. Tensor allocation will occur automatically during + * execution if any input tensors have been resized. This call is most useful in determining the + * shapes for any output tensors before executing the graph, e.g., + *

{@code
+   * interpreter.resizeInput(0, new int[]{1, 4, 4, 3});
+   * interpreter.allocateTensors();
+   * FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0).numElements());
+   * // Populate inputs...
+   * FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements());
+   * interpreter.run(input, output);
+   * // Process outputs...
+   * }
+ * + * @throws IllegalStateException if the graph's tensors could not be successfully allocated. + */ + public void allocateTensors() { + checkNotClosed(); + wrapper.allocateTensors(); + } + /** * Resizes idx-th input of the native model to the given dims. * @@ -373,6 +399,13 @@ public final class Interpreter implements AutoCloseable { /** * Gets the Tensor associated with the provdied output index. * + *

Note: Output tensor details (e.g., shape) may not be fully populated until after inference + * is executed. If you need updated details *before* running inference (e.g., after resizing an + * input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to + * explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes + * that are dependent on input *values*, the output shape may not be fully determined until + * running inference. + * * @throws IllegalArgumentException if {@code outputIndex} is negtive or is not smaller than the * number of model outputs. */ diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index ca21ec5c7ea..73fe506f131 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -175,6 +175,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { /** Resizes dimensions of a specific input. */ void resizeInput(int idx, int[] dims) { if (resizeInput(interpreterHandle, errorHandle, idx, dims)) { + // Tensor allocation is deferred until either an explicit `allocateTensors()` call or + // `invoke()` avoiding redundant allocations if multiple tensors are simultaneosly resized. isMemoryAllocated = false; if (inputTensors[idx] != null) { inputTensors[idx].refreshShape(); @@ -185,6 +187,23 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native boolean resizeInput( long interpreterHandle, long errorHandle, int inputIdx, int[] dims); + /** Triggers explicit allocation of tensors. 
*/ + void allocateTensors() { + if (isMemoryAllocated) { + return; + } + + isMemoryAllocated = true; + allocateTensors(interpreterHandle, errorHandle); + for (int i = 0; i < outputTensors.length; ++i) { + if (outputTensors[i] != null) { + outputTensors[i].refreshShape(); + } + } + } + + private static native long allocateTensors(long interpreterHandle, long errorHandle); + void setUseNNAPI(boolean useNNAPI) { useNNAPI(interpreterHandle, useNNAPI); } @@ -385,8 +404,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { // List of owned delegates that must be closed when the interpreter is closed. private final List ownedDelegates = new ArrayList<>(); - private static native long allocateTensors(long interpreterHandle, long errorHandle); - private static native boolean hasUnresolvedFlexOp(long interpreterHandle); private static native int getInputTensorIndex(long interpreterHandle, int inputIdx); diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 8b18e1764ce..b38f1ad771d 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -222,6 +222,32 @@ public final class InterpreterTest { } } + @Test + public void testAllocateTensors() { + try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) { + // Redundant allocateTensors() should have no effect. + interpreter.allocateTensors(); + + // allocateTensors() should propagate resizes. + int[] inputDims = {1}; + assertThat(interpreter.getOutputTensor(0).shape()).isNotEqualTo(inputDims); + interpreter.resizeInput(0, inputDims); + assertThat(interpreter.getOutputTensor(0).shape()).isNotEqualTo(inputDims); + interpreter.allocateTensors(); + assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims); + + // Additional redundant calls should have no effect. 
+ interpreter.allocateTensors(); + assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims); + + // Execution should succeed as expected. + ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder()); + ByteBuffer output = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder()); + interpreter.run(input, output); + assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims); + } + } + @Test public void testUnknownDims() { try (Interpreter interpreter = new Interpreter(UNKNOWN_DIMS_MODEL_PATH_BUFFER)) {