[tf.lite] Re-land "Fix JNI memory leak when constructor exceptions are thrown"

Properly dispose of native memory handles if an exception is thrown
in the Interpreter constructor. This change landed previously but was rolled
back due to a bug in how the ByteBuffer constructor argument was stored;
that bug has been fixed here.

PiperOrigin-RevId: 351844621
Change-Id: Ibd335a4d3a4412a910d859af9f3a5b0a60829cc0
Author:    Jared Duke <2021-01-14 11:47:08 -08:00>
Committer: TensorFlower Gardener
Parent:    00915ea833
Commit:    3d21db1f00
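To make the fix concrete, here is a minimal sketch (not the actual TF Lite code) of the pattern the change applies: zero-initialize the native handles, then guard creation with try/finally so a constructor exception cannot leak them. The nativeCreate()/nativeDelete() methods are hypothetical stand-ins for the real JNI bindings such as createErrorReporter() and delete().

final class NativeHandleOwner implements AutoCloseable {
  private long handle; // 0 means "not created yet"
  private boolean initialized;

  NativeHandleOwner() {
    handle = 0; // ensures the finally block never frees an uninitialized pointer
    try {
      handle = nativeCreate();
      // ... further constructor work; any step here may throw ...
      initialized = true;
    } finally {
      // If construction failed after the handle was created, release it now:
      // the caller never receives an object it could close().
      if (!initialized && handle != 0) {
        nativeDelete(handle);
        handle = 0;
      }
    }
  }

  @Override
  public void close() {
    if (handle != 0) {
      nativeDelete(handle);
      handle = 0;
    }
  }

  private static native long nativeCreate();
  private static native void nativeDelete(long handle);
}

The same reasoning explains the new ModelCreator callback in the diff below: by deferring createModel()/createModelWithBuffer() into init(), the model handle is created inside the guarded try block rather than in the constructor, before any cleanup is in place.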


@@ -44,14 +44,19 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     this(byteBuffer, /* options= */ null);
   }
 
-  NativeInterpreterWrapper(String modelPath, Interpreter.Options options) {
+  NativeInterpreterWrapper(final String modelPath, Interpreter.Options options) {
     TensorFlowLite.init();
-    long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
-    long modelHandle = createModel(modelPath, errorHandle);
-    init(errorHandle, modelHandle, options);
+    ModelCreator modelCreator =
+        new ModelCreator() {
+          @Override
+          public long create(long errorHandle) {
+            return createModel(modelPath, errorHandle);
+          }
+        };
+    init(modelCreator, options);
   }
 
-  NativeInterpreterWrapper(ByteBuffer buffer, Interpreter.Options options) {
+  NativeInterpreterWrapper(final ByteBuffer buffer, Interpreter.Options options) {
     TensorFlowLite.init();
     if (buffer == null
         || (!(buffer instanceof MappedByteBuffer)
@@ -61,45 +66,67 @@ final class NativeInterpreterWrapper implements AutoCloseable {
               + "ByteBuffer using ByteOrder.nativeOrder() which contains bytes of model content.");
     }
     this.modelByteBuffer = buffer;
-    long errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
-    long modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
-    init(errorHandle, modelHandle, options);
+    ModelCreator modelCreator =
+        new ModelCreator() {
+          @Override
+          public long create(long errorHandle) {
+            return createModelWithBuffer(buffer, errorHandle);
+          }
+        };
+    init(modelCreator, options);
   }
 
-  private void init(long errorHandle, long modelHandle, Interpreter.Options options) {
+  private interface ModelCreator {
+    public long create(long errorHandle);
+  }
+
+  private void init(ModelCreator modelCreator, Interpreter.Options options) {
     if (options == null) {
       options = new Interpreter.Options();
     }
-    this.errorHandle = errorHandle;
-    this.modelHandle = modelHandle;
-    this.interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
-    if (options.allowCancellation != null && options.allowCancellation) {
-      this.cancellationFlagHandle = createCancellationFlag(interpreterHandle);
-    }
-    this.inputTensors = new Tensor[getInputCount(interpreterHandle)];
-    this.outputTensors = new Tensor[getOutputCount(interpreterHandle)];
-    if (options.allowFp16PrecisionForFp32 != null) {
-      allowFp16PrecisionForFp32(
-          interpreterHandle, options.allowFp16PrecisionForFp32.booleanValue());
-    }
-    if (options.allowBufferHandleOutput != null) {
-      allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
-    }
-    applyDelegates(options);
-
-    // Simply use "-1" to represent the default mode.
-    int applyXNNPACKMode = -1;
-    if (options.useXNNPACK != null) {
-      applyXNNPACKMode = options.useXNNPACK.booleanValue() ? 1 : 0;
-    }
-
-    // TODO(b/171856982): uncomment the following when applying XNNPACK delegate by default is
-    // enabled for C++ TfLite library on Android platform.
-    if (applyXNNPACKMode == 1 /*|| applyXNNPACKMode == -1*/) {
-      useXNNPACK(interpreterHandle, errorHandle, applyXNNPACKMode, options.numThreads);
-    }
-
-    allocateTensors(interpreterHandle, errorHandle);
-    this.isMemoryAllocated = true;
+    // First initialize native handles to zero. If an exception is encountered, we will dispose of
+    // them, and this avoids deleting an uninitialized pointer if handle creation fails.
+    errorHandle = 0;
+    modelHandle = 0;
+    interpreterHandle = 0;
+    try {
+      errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
+      modelHandle = modelCreator.create(errorHandle);
+      interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
+      if (options.allowCancellation != null && options.allowCancellation) {
+        cancellationFlagHandle = createCancellationFlag(interpreterHandle);
+      }
+      inputTensors = new Tensor[getInputCount(interpreterHandle)];
+      outputTensors = new Tensor[getOutputCount(interpreterHandle)];
+      if (options.allowFp16PrecisionForFp32 != null) {
+        allowFp16PrecisionForFp32(
+            interpreterHandle, options.allowFp16PrecisionForFp32.booleanValue());
+      }
+      if (options.allowBufferHandleOutput != null) {
+        allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
+      }
+      applyDelegates(options);
+
+      // Simply use "-1" to represent the default mode.
+      int applyXNNPACKMode = -1;
+      if (options.useXNNPACK != null) {
+        applyXNNPACKMode = options.useXNNPACK.booleanValue() ? 1 : 0;
+      }
+
+      // TODO(b/171856982): uncomment the following when applying XNNPACK delegate by default is
+      // enabled for C++ TfLite library on Android platform.
+      if (applyXNNPACKMode == 1 /*|| applyXNNPACKMode == -1*/) {
+        useXNNPACK(interpreterHandle, errorHandle, applyXNNPACKMode, options.numThreads);
+      }
+
+      allocateTensors(interpreterHandle, errorHandle);
+      isMemoryAllocated = true;
+    } finally {
+      // If any of the native handles were created successfully, we should dispose of them if
+      // creation and allocation did not succeed. This avoids leaks in the event of an error.
+      if (!isMemoryAllocated && (errorHandle != 0 || modelHandle != 0 || interpreterHandle != 0)) {
+        delete(errorHandle, modelHandle, interpreterHandle);
+      }
+    }
   }
 
   /** Releases resources associated with this {@code NativeInterpreterWrapper}. */
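For context, a hedged usage sketch of the failure path this change guards (not part of the commit): a direct, native-ordered buffer that is not a valid model passes the Java-side checks in the constructor and fails during native model creation, which previously leaked the already-created error-reporter handle. The exception type shown is an assumption based on how the JNI layer typically reports invalid models.

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.tensorflow.lite.Interpreter;

public final class ConstructorLeakDemo {
  public static void main(String[] args) {
    // Direct and native-ordered, so NativeInterpreterWrapper's Java-side
    // validation accepts the buffer; native model parsing then fails.
    ByteBuffer notAModel = ByteBuffer.allocateDirect(64).order(ByteOrder.nativeOrder());
    try (Interpreter interpreter = new Interpreter(notAModel)) {
      // Unreachable: the constructor throws before an Interpreter is returned.
    } catch (IllegalArgumentException e) {
      // With this change, any handles created before the failure are released
      // in init()'s finally block instead of being leaked.
      System.out.println("Constructor failed cleanly: " + e.getMessage());
    }
  }
}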