Add Interpreter.allocateTensors() call in Java

Normally this isn't necessary in Java, but it is useful in cases where
the output tensor shape is invalidated due to a resize, and the client
needs to know that shape before executing inference.
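The intended call pattern, as a minimal sketch (the `modelBuffer` variable and the shapes here are placeholder assumptions for illustration, not part of this change):

    try (Interpreter interpreter = new Interpreter(modelBuffer)) {  // modelBuffer: assumed pre-loaded model.
      // Resizing an input invalidates any previously computed output shapes.
      interpreter.resizeInput(0, new int[] {1, 4, 4, 3});
      // Explicitly propagate shapes so output shapes are known before inference.
      interpreter.allocateTensors();
      int[] outputShape = interpreter.getOutputTensor(0).shape();
      // ... size output buffers from outputShape, then run inference.
    }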

BUG=152671540
PiperOrigin-RevId: 304701829
Change-Id: Idbfaa956b307d9a24482a7c0731efddcda3e15a3
Jared Duke, 2020-04-03 15:10:14 -07:00, committed by TensorFlower Gardener
parent 88ada68c6d
commit 9e4434c9ff
3 changed files with 78 additions and 2 deletions
tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java

tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java

@@ -314,6 +314,32 @@ public final class Interpreter implements AutoCloseable {
wrapper.run(inputs, outputs);
}
/**
* Explicitly updates allocations for all tensors, if necessary.
*
* <p>This will propagate shapes and memory allocations for all dependent tensors using the input
* tensor shape(s) as given.
*
* <p>Note: This call is *purely optional*. Tensor allocation will occur automatically during
* execution if any input tensors have been resized. This call is most useful in determining the
* shapes for any output tensors before executing the graph, e.g.,
* <pre>{@code
* interpreter.resizeInput(0, new int[]{1, 4, 4, 3});
* interpreter.allocateTensors();
* FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0).numElements());
* // Populate inputs...
* FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements());
* interpreter.run(input, output);
* // Process outputs...
* }</pre>
*
* @throws IllegalStateException if the graph's tensors could not be successfully allocated.
*/
public void allocateTensors() {
checkNotClosed();
wrapper.allocateTensors();
}
/**
* Resizes idx-th input of the native model to the given dims.
*
@@ -373,6 +399,13 @@ public final class Interpreter implements AutoCloseable {
/**
* Gets the Tensor associated with the provided output index.
*
* <p>Note: Output tensor details (e.g., shape) may not be fully populated until after inference
* is executed. If you need updated details *before* running inference (e.g., after resizing an
* input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to
* explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes
* that are dependent on input *values*, the output shape may not be fully determined until
* running inference.
*
* @throws IllegalArgumentException if {@code outputIndex} is negative or is not smaller than the
* number of model outputs.
*/
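Where output shapes depend on input values, a hedged sketch of the pattern this note implies (`input` and `output` are assumed to be appropriately sized buffers for a hypothetical dynamic-shape model):

    interpreter.allocateTensors();
    int[] shapeBeforeRun = interpreter.getOutputTensor(0).shape();  // May still be tentative.
    interpreter.run(input, output);
    int[] shapeAfterRun = interpreter.getOutputTensor(0).shape();   // Fully determined after inference.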

tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java

@@ -175,6 +175,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
/** Resizes dimensions of a specific input. */
void resizeInput(int idx, int[] dims) {
if (resizeInput(interpreterHandle, errorHandle, idx, dims)) {
// Tensor allocation is deferred until either an explicit `allocateTensors()` call or
// `invoke()`, avoiding redundant allocations if multiple tensors are simultaneously resized.
isMemoryAllocated = false;
if (inputTensors[idx] != null) {
inputTensors[idx].refreshShape();
@@ -185,6 +187,23 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private static native boolean resizeInput(
long interpreterHandle, long errorHandle, int inputIdx, int[] dims);
/** Triggers explicit allocation of tensors. */
void allocateTensors() {
if (isMemoryAllocated) {
return;
}
isMemoryAllocated = true;
allocateTensors(interpreterHandle, errorHandle);
for (int i = 0; i < outputTensors.length; ++i) {
if (outputTensors[i] != null) {
outputTensors[i].refreshShape();
}
}
}
private static native long allocateTensors(long interpreterHandle, long errorHandle);
void setUseNNAPI(boolean useNNAPI) {
useNNAPI(interpreterHandle, useNNAPI);
}
@@ -385,8 +404,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
// List of owned delegates that must be closed when the interpreter is closed.
private final List<AutoCloseable> ownedDelegates = new ArrayList<>();
private static native long allocateTensors(long interpreterHandle, long errorHandle);
private static native boolean hasUnresolvedFlexOp(long interpreterHandle);
private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);
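Because resizeInput() only clears the isMemoryAllocated flag, consecutive resizes coalesce into a single allocation pass. A minimal client-side sketch of the pattern this enables, assuming a hypothetical model with two resizable inputs:

    interpreter.resizeInput(0, new int[] {1, 8});
    interpreter.resizeInput(1, new int[] {1, 16});
    // One allocation pass covers both resizes, via allocateTensors() or the next run().
    interpreter.allocateTensors();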

tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java

@@ -222,6 +222,32 @@ public final class InterpreterTest {
}
}
@Test
public void testAllocateTensors() {
try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) {
// Redundant allocateTensors() should have no effect.
interpreter.allocateTensors();
// allocateTensors() should propagate resizes.
int[] inputDims = {1};
assertThat(interpreter.getOutputTensor(0).shape()).isNotEqualTo(inputDims);
interpreter.resizeInput(0, inputDims);
assertThat(interpreter.getOutputTensor(0).shape()).isNotEqualTo(inputDims);
interpreter.allocateTensors();
assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
// Additional redundant calls should have no effect.
interpreter.allocateTensors();
assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
// Execution should succeed as expected.
ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
ByteBuffer output = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
interpreter.run(input, output);
assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
}
}
@Test
public void testUnknownDims() {
try (Interpreter interpreter = new Interpreter(UNKNOWN_DIMS_MODEL_PATH_BUFFER)) {