Add Interpreter.allocateTensors()
call in Java
Normally this isn't necessary in Java, but it is useful in cases where the output tensor shape is invalidated due to a resize, and the client needs to know that shape before executing inference. BUG=152671540 PiperOrigin-RevId: 304701829 Change-Id: Idbfaa956b307d9a24482a7c0731efddcda3e15a3
This commit is contained in:
parent
88ada68c6d
commit
9e4434c9ff
tensorflow/lite/java/src
main/java/org/tensorflow/lite
test/java/org/tensorflow/lite
@ -314,6 +314,32 @@ public final class Interpreter implements AutoCloseable {
|
||||
wrapper.run(inputs, outputs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Expicitly updates allocations for all tensors, if necessary.
|
||||
*
|
||||
* <p>This will propagate shapes and memory allocations for all dependent tensors using the input
|
||||
* tensor shape(s) as given.
|
||||
*
|
||||
* <p>Note: This call is *purely optional*. Tensor allocation will occur automatically during
|
||||
* execution if any input tensors have been resized. This call is most useful in determining the
|
||||
* shapes for any output tensors before executing the graph, e.g.,
|
||||
* <pre>{@code
|
||||
* interpreter.resizeInput(0, new int[]{1, 4, 4, 3}));
|
||||
* interpreter.allocateTensors();
|
||||
* FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0),numElements());
|
||||
* // Populate inputs...
|
||||
* FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements());
|
||||
* interpreter.run(input, output)
|
||||
* // Process outputs...
|
||||
* }</pre>
|
||||
*
|
||||
* @throws IllegalStateException if the graph's tensors could not be successfully allocated.
|
||||
*/
|
||||
public void allocateTensors() {
|
||||
checkNotClosed();
|
||||
wrapper.allocateTensors();
|
||||
}
|
||||
|
||||
/**
|
||||
* Resizes idx-th input of the native model to the given dims.
|
||||
*
|
||||
@ -373,6 +399,13 @@ public final class Interpreter implements AutoCloseable {
|
||||
/**
|
||||
* Gets the Tensor associated with the provided output index.
|
||||
*
|
||||
* <p>Note: Output tensor details (e.g., shape) may not be fully populated until after inference
|
||||
* is executed. If you need updated details *before* running inference (e.g., after resizing an
|
||||
* input tensor, which may invalidate output tensor shapes), use {@link #allocateTensors()} to
|
||||
* explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes
|
||||
* that are dependent on input *values*, the output shape may not be fully determined until
|
||||
* running inference.
|
||||
*
|
||||
* @throws IllegalArgumentException if {@code outputIndex} is negative or is not smaller than the
|
||||
* number of model outputs.
|
||||
*/
|
||||
|
@ -175,6 +175,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
/** Resizes dimensions of a specific input. */
|
||||
void resizeInput(int idx, int[] dims) {
|
||||
if (resizeInput(interpreterHandle, errorHandle, idx, dims)) {
|
||||
// Tensor allocation is deferred until either an explicit `allocateTensors()` call or
|
||||
// `invoke()`, avoiding redundant allocations if multiple tensors are simultaneously resized.
|
||||
isMemoryAllocated = false;
|
||||
if (inputTensors[idx] != null) {
|
||||
inputTensors[idx].refreshShape();
|
||||
@ -185,6 +187,23 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
private static native boolean resizeInput(
|
||||
long interpreterHandle, long errorHandle, int inputIdx, int[] dims);
|
||||
|
||||
/** Triggers explicit allocation of tensors. */
|
||||
void allocateTensors() {
|
||||
if (isMemoryAllocated) {
|
||||
return;
|
||||
}
|
||||
|
||||
isMemoryAllocated = true;
|
||||
allocateTensors(interpreterHandle, errorHandle);
|
||||
for (int i = 0; i < outputTensors.length; ++i) {
|
||||
if (outputTensors[i] != null) {
|
||||
outputTensors[i].refreshShape();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static native long allocateTensors(long interpreterHandle, long errorHandle);
|
||||
|
||||
void setUseNNAPI(boolean useNNAPI) {
|
||||
useNNAPI(interpreterHandle, useNNAPI);
|
||||
}
|
||||
@ -385,8 +404,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
// List of owned delegates that must be closed when the interpreter is closed.
|
||||
private final List<AutoCloseable> ownedDelegates = new ArrayList<>();
|
||||
|
||||
private static native long allocateTensors(long interpreterHandle, long errorHandle);
|
||||
|
||||
private static native boolean hasUnresolvedFlexOp(long interpreterHandle);
|
||||
|
||||
private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);
|
||||
|
@ -222,6 +222,32 @@ public final class InterpreterTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAllocateTensors() {
|
||||
try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) {
|
||||
// Redundant allocateTensors() should have no effect.
|
||||
interpreter.allocateTensors();
|
||||
|
||||
// allocateTensors() should propagate resizes.
|
||||
int[] inputDims = {1};
|
||||
assertThat(interpreter.getOutputTensor(0).shape()).isNotEqualTo(inputDims);
|
||||
interpreter.resizeInput(0, inputDims);
|
||||
assertThat(interpreter.getOutputTensor(0).shape()).isNotEqualTo(inputDims);
|
||||
interpreter.allocateTensors();
|
||||
assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
|
||||
|
||||
// Additional redundant calls should have no effect.
|
||||
interpreter.allocateTensors();
|
||||
assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
|
||||
|
||||
// Execution should succeed as expected.
|
||||
ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
|
||||
ByteBuffer output = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
|
||||
interpreter.run(input, output);
|
||||
assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(inputDims);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUnknownDims() {
|
||||
try (Interpreter interpreter = new Interpreter(UNKNOWN_DIMS_MODEL_PATH_BUFFER)) {
|
||||
|
Loading…
Reference in New Issue
Block a user