[tf.lite] Fix issue with direct ByteBuffer inputs and dynamic graphs
In these graphs, the input tensor pointers may get "refreshed" during invocation. This refresh is fine if the original pointer came from the arena, but if it comes from something like the direct ByteBuffer raw address, the input data will be lost. Avoid this by simply using memcpy from the direct ByteBuffer. This is still quite fast, but avoids the hack where we simply inject the direct ByteBuffer address as the tensor buffer pointer. A longer term solution will formally allow providing "custom" allocated regions to tensor inputs, but until then, do the safe thing. PiperOrigin-RevId: 310643333 Change-Id: I05dfebd24617ebb1af7eb281ff9e530b01669093
This commit is contained in:
parent
bca497ef47
commit
b8d991c9b4
@ -240,6 +240,7 @@ java_test(
|
||||
data = [
|
||||
"src/testdata/add.bin",
|
||||
"src/testdata/add_unknown_dimensions.bin",
|
||||
"//tensorflow/lite:testdata/dynamic_shapes.bin",
|
||||
"//tensorflow/lite:testdata/multi_add.bin",
|
||||
"//tensorflow/lite:testdata/multi_add_flex.bin",
|
||||
],
|
||||
|
@ -196,7 +196,7 @@ public final class Tensor {
|
||||
}
|
||||
|
||||
private void setTo(Buffer src) {
|
||||
// Note that we attempt to use zero-copy optimization for direct, native-ordered buffers.
|
||||
// Note that we attempt to use a direct memcpy optimization for direct, native-ordered buffers.
|
||||
// There are no base Buffer#order() or Buffer#put() methods, so again we have to ugly cast.
|
||||
if (src instanceof ByteBuffer) {
|
||||
ByteBuffer srcBuffer = (ByteBuffer) src;
|
||||
|
@ -402,14 +402,26 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
|
||||
TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
|
||||
if (tensor == nullptr) return;
|
||||
|
||||
char* src_data_raw = static_cast<char*>(env->GetDirectBufferAddress(src));
|
||||
void* src_data_raw = env->GetDirectBufferAddress(src);
|
||||
if (!src_data_raw) {
|
||||
ThrowException(env, kIllegalArgumentException,
|
||||
"Input ByteBuffer is not a direct buffer");
|
||||
return;
|
||||
}
|
||||
|
||||
tensor->data.raw = src_data_raw;
|
||||
if (!tensor->data.data) {
|
||||
ThrowException(env, kIllegalArgumentException,
|
||||
"Internal error: Tensor hasn't been allocated.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Historically, we would simply overwrite the tensor buffer pointer with
|
||||
// the direct Buffer address. However, that is generally unsafe, and
|
||||
// specifically wrong if the graph happens to have dynamic shapes where
|
||||
// arena-allocated input buffers will be refreshed during invocation.
|
||||
// TODO(b/156094015): Explore whether this is actually faster than
|
||||
// using ByteBuffer.put(ByteBuffer).
|
||||
memcpy(tensor->data.data, src_data_raw, tensor->bytes);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
|
@ -40,6 +40,8 @@ public final class InterpreterTest {
|
||||
"tensorflow/lite/testdata/multi_add_flex.bin";
|
||||
private static final String UNKNOWN_DIMS_MODEL_PATH =
|
||||
"tensorflow/lite/java/src/testdata/add_unknown_dimensions.bin";
|
||||
private static final String DYNAMIC_SHAPES_MODEL_PATH =
|
||||
"tensorflow/lite/testdata/dynamic_shapes.bin";
|
||||
|
||||
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
|
||||
private static final ByteBuffer MULTIPLE_INPUTS_MODEL_BUFFER =
|
||||
@ -48,6 +50,8 @@ public final class InterpreterTest {
|
||||
TestUtils.getTestFileAsBuffer(FLEX_MODEL_PATH);
|
||||
private static final ByteBuffer UNKNOWN_DIMS_MODEL_PATH_BUFFER =
|
||||
TestUtils.getTestFileAsBuffer(UNKNOWN_DIMS_MODEL_PATH);
|
||||
private static final ByteBuffer DYNAMIC_SHAPES_MODEL_BUFFER =
|
||||
TestUtils.getTestFileAsBuffer(DYNAMIC_SHAPES_MODEL_PATH);
|
||||
|
||||
@Test
|
||||
public void testInterpreter() throws Exception {
|
||||
@ -434,7 +438,7 @@ public final class InterpreterTest {
|
||||
interpreter.close();
|
||||
}
|
||||
|
||||
/** Smoke test validating that flex model loading fails when the flex delegate is not linked. */
|
||||
// Smoke test validating that flex model loading fails when the flex delegate is not linked.
|
||||
@Test
|
||||
public void testFlexModel() throws Exception {
|
||||
try {
|
||||
@ -573,6 +577,45 @@ public final class InterpreterTest {
|
||||
}
|
||||
}
|
||||
|
||||
private static FloatBuffer fill(FloatBuffer buffer, float value) {
|
||||
while (buffer.hasRemaining()) {
|
||||
buffer.put(value);
|
||||
}
|
||||
buffer.rewind();
|
||||
return buffer;
|
||||
}
|
||||
|
||||
// Regression test case to ensure that graphs with dynamically computed shapes work properly.
|
||||
// Historically, direct ByteBuffer addresses would overwrite the arena-allocated tensor input
|
||||
// pointers. Normally this works fine, but for dynamic graphs, the original input tensor pointers
|
||||
// may be "restored" at invocation time by the arena allocator, resetting the direct ByteBuffer
|
||||
// address and leading to stale input data being used.
|
||||
@Test
|
||||
public void testDynamicShapesWithDirectBufferInputs() {
|
||||
try (Interpreter interpreter = new Interpreter(DYNAMIC_SHAPES_MODEL_BUFFER)) {
|
||||
ByteBuffer input0 =
|
||||
ByteBuffer.allocateDirect(8 * 42 * 1024 * 4).order(ByteOrder.nativeOrder());
|
||||
ByteBuffer input1 =
|
||||
ByteBuffer.allocateDirect(1 * 90 * 1024 * 4).order(ByteOrder.nativeOrder());
|
||||
ByteBuffer input2 = ByteBuffer.allocateDirect(1 * 4).order(ByteOrder.nativeOrder());
|
||||
Object[] inputs = {input0, input1, input2};
|
||||
|
||||
fill(input0.asFloatBuffer(), 2.0f);
|
||||
fill(input1.asFloatBuffer(), 0.5f);
|
||||
// Note that the value of this input dictates the shape of the output.
|
||||
fill(input2.asFloatBuffer(), 1.0f);
|
||||
|
||||
FloatBuffer output = FloatBuffer.allocate(8 * 1 * 1024);
|
||||
Map<Integer, Object> outputs = new HashMap<>();
|
||||
outputs.put(0, output);
|
||||
|
||||
interpreter.runForMultipleInputsOutputs(inputs, outputs);
|
||||
|
||||
FloatBuffer expected = fill(FloatBuffer.allocate(8 * 1 * 1024), 2.0f);
|
||||
assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
|
||||
}
|
||||
}
|
||||
|
||||
private static native long getNativeHandleForDelegate();
|
||||
|
||||
private static native long getNativeHandleForInvalidDelegate();
|
||||
|
BIN
tensorflow/lite/testdata/dynamic_shapes.bin
vendored
Normal file
BIN
tensorflow/lite/testdata/dynamic_shapes.bin
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user