[tf.lite] Fix issue with direct ByteBuffer inputs and dynamic graphs

In these graphs, the input tensor pointers may get "refreshed" during
invocation. This refresh is fine if the original pointer came from the
arena, but if it came from something like a direct ByteBuffer's raw
address, the input data will be lost.

Avoid this by copying from the direct ByteBuffer with memcpy. This is
still quite fast, and it avoids the hack of simply injecting the direct
ByteBuffer address as the tensor buffer pointer.

A longer-term solution will formally allow providing "custom" allocated
regions for tensor inputs; until then, do the safe thing.
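
For context, here is a minimal sketch of the user-facing pattern this change keeps correct: feeding a direct, native-ordered ByteBuffer to an Interpreter. The model path, input size, and output shape below are hypothetical placeholders, not part of this change.

import java.io.File;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.tensorflow.lite.Interpreter;

public class DirectBufferInputSketch {
  public static void main(String[] args) {
    // Hypothetical model; assume a single float input and a [1, 1001] float output.
    try (Interpreter interpreter = new Interpreter(new File("model.tflite"))) {
      // A direct, native-ordered buffer: previously its raw address was injected as the
      // tensor buffer pointer; with this change its contents are copied instead.
      ByteBuffer input =
          ByteBuffer.allocateDirect(1 * 224 * 224 * 3 * 4).order(ByteOrder.nativeOrder());
      float[][] output = new float[1][1001];
      interpreter.run(input, output);
    }
  }
}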

PiperOrigin-RevId: 310643333
Change-Id: I05dfebd24617ebb1af7eb281ff9e530b01669093
Jared Duke 2020-05-08 15:45:44 -07:00 committed by TensorFlower Gardener
parent bca497ef47
commit b8d991c9b4
5 changed files with 60 additions and 4 deletions


@@ -240,6 +240,7 @@ java_test(
    data = [
        "src/testdata/add.bin",
        "src/testdata/add_unknown_dimensions.bin",
+        "//tensorflow/lite:testdata/dynamic_shapes.bin",
        "//tensorflow/lite:testdata/multi_add.bin",
        "//tensorflow/lite:testdata/multi_add_flex.bin",
    ],


@@ -196,7 +196,7 @@ public final class Tensor {
  }
  private void setTo(Buffer src) {
-    // Note that we attempt to use zero-copy optimization for direct, native-ordered buffers.
+    // Note that we attempt to use a direct memcpy optimization for direct, native-ordered buffers.
    // There are no base Buffer#order() or Buffer#put() methods, so again we have to ugly cast.
    if (src instanceof ByteBuffer) {
      ByteBuffer srcBuffer = (ByteBuffer) src;
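
As the updated comment notes, only direct, native-ordered buffers take the memcpy path. A small standalone sketch of the distinction (plain Java, not TF Lite internals; sizes are arbitrary placeholders):

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

public class BufferKindsSketch {
  public static void main(String[] args) {
    int numFloats = 8 * 8 * 3;
    // Direct, native-ordered: the kind of buffer eligible for the memcpy optimization.
    ByteBuffer direct =
        ByteBuffer.allocateDirect(numFloats * 4).order(ByteOrder.nativeOrder());
    // Heap-allocated: not direct, so it cannot take that path.
    FloatBuffer heap = FloatBuffer.allocate(numFloats);
    System.out.println(direct.isDirect() + " / " + heap.isDirect());  // true / false
  }
}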


@@ -402,14 +402,26 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
  TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
  if (tensor == nullptr) return;
-  char* src_data_raw = static_cast<char*>(env->GetDirectBufferAddress(src));
+  void* src_data_raw = env->GetDirectBufferAddress(src);
  if (!src_data_raw) {
    ThrowException(env, kIllegalArgumentException,
                   "Input ByteBuffer is not a direct buffer");
    return;
  }
-  tensor->data.raw = src_data_raw;
+  if (!tensor->data.data) {
+    ThrowException(env, kIllegalArgumentException,
+                   "Internal error: Tensor hasn't been allocated.");
+    return;
+  }
+  // Historically, we would simply overwrite the tensor buffer pointer with
+  // the direct Buffer address. However, that is generally unsafe, and
+  // specifically wrong if the graph happens to have dynamic shapes where
+  // arena-allocated input buffers will be refreshed during invocation.
+  // TODO(b/156094015): Explore whether this is actually faster than
+  // using ByteBuffer.put(ByteBuffer).
+  memcpy(tensor->data.data, src_data_raw, tensor->bytes);
}
JNIEXPORT void JNICALL
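
The TODO above mentions ByteBuffer.put(ByteBuffer) as a possible alternative to the native memcpy. At the NIO level the idea is the same: copy the caller's bytes into a buffer the runtime owns rather than aliasing the caller's address. A standalone sketch of that bulk-copy idiom (plain Java; the buffers here are stand-ins, not TF Lite internals):

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class BulkCopySketch {
  public static void main(String[] args) {
    ByteBuffer src = ByteBuffer.allocateDirect(16).order(ByteOrder.nativeOrder());
    ByteBuffer owned = ByteBuffer.allocateDirect(16).order(ByteOrder.nativeOrder());
    src.putFloat(0, 2.0f);
    // Bulk copy; duplicate() leaves src's position untouched. After this, 'owned'
    // holds its own copy of the bytes, independent of whatever later happens to src.
    owned.put(src.duplicate());
    System.out.println(owned.getFloat(0));  // 2.0
  }
}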


@@ -40,6 +40,8 @@ public final class InterpreterTest {
      "tensorflow/lite/testdata/multi_add_flex.bin";
  private static final String UNKNOWN_DIMS_MODEL_PATH =
      "tensorflow/lite/java/src/testdata/add_unknown_dimensions.bin";
+  private static final String DYNAMIC_SHAPES_MODEL_PATH =
+      "tensorflow/lite/testdata/dynamic_shapes.bin";
  private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
  private static final ByteBuffer MULTIPLE_INPUTS_MODEL_BUFFER =
@@ -48,6 +50,8 @@ public final class InterpreterTest {
      TestUtils.getTestFileAsBuffer(FLEX_MODEL_PATH);
  private static final ByteBuffer UNKNOWN_DIMS_MODEL_PATH_BUFFER =
      TestUtils.getTestFileAsBuffer(UNKNOWN_DIMS_MODEL_PATH);
+  private static final ByteBuffer DYNAMIC_SHAPES_MODEL_BUFFER =
+      TestUtils.getTestFileAsBuffer(DYNAMIC_SHAPES_MODEL_PATH);
  @Test
  public void testInterpreter() throws Exception {
@@ -434,7 +438,7 @@ public final class InterpreterTest {
    interpreter.close();
  }
-  /** Smoke test validating that flex model loading fails when the flex delegate is not linked. */
+  // Smoke test validating that flex model loading fails when the flex delegate is not linked.
  @Test
  public void testFlexModel() throws Exception {
    try {
@@ -573,6 +577,45 @@ public final class InterpreterTest {
    }
  }
+  private static FloatBuffer fill(FloatBuffer buffer, float value) {
+    while (buffer.hasRemaining()) {
+      buffer.put(value);
+    }
+    buffer.rewind();
+    return buffer;
+  }
+  // Regression test case to ensure that graphs with dynamically computed shapes work properly.
+  // Historically, direct ByteBuffer addresses would overwrite the arena-allocated tensor input
+  // pointers. Normally this works fine, but for dynamic graphs, the original input tensor pointers
+  // may be "restored" at invocation time by the arena allocator, resetting the direct ByteBuffer
+  // address and leading to stale input data being used.
+  @Test
+  public void testDynamicShapesWithDirectBufferInputs() {
+    try (Interpreter interpreter = new Interpreter(DYNAMIC_SHAPES_MODEL_BUFFER)) {
+      ByteBuffer input0 =
+          ByteBuffer.allocateDirect(8 * 42 * 1024 * 4).order(ByteOrder.nativeOrder());
+      ByteBuffer input1 =
+          ByteBuffer.allocateDirect(1 * 90 * 1024 * 4).order(ByteOrder.nativeOrder());
+      ByteBuffer input2 = ByteBuffer.allocateDirect(1 * 4).order(ByteOrder.nativeOrder());
+      Object[] inputs = {input0, input1, input2};
+      fill(input0.asFloatBuffer(), 2.0f);
+      fill(input1.asFloatBuffer(), 0.5f);
+      // Note that the value of this input dictates the shape of the output.
+      fill(input2.asFloatBuffer(), 1.0f);
+      FloatBuffer output = FloatBuffer.allocate(8 * 1 * 1024);
+      Map<Integer, Object> outputs = new HashMap<>();
+      outputs.put(0, output);
+      interpreter.runForMultipleInputsOutputs(inputs, outputs);
+      FloatBuffer expected = fill(FloatBuffer.allocate(8 * 1 * 1024), 2.0f);
+      assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
+    }
+  }
  private static native long getNativeHandleForDelegate();
  private static native long getNativeHandleForInvalidDelegate();
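
A related usage note, not part of this change: when output shapes are computed from input values, as in the test above, the realized output shape can be read back after invocation. A minimal sketch, assuming a hypothetical single-input model file and a float output that fits in 8 * 1024 elements:

import java.io.File;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.tensorflow.lite.Interpreter;

public class DynamicShapeInspectionSketch {
  public static void main(String[] args) {
    try (Interpreter interpreter = new Interpreter(new File("dynamic_shapes.tflite"))) {
      ByteBuffer input = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
      input.putFloat(0, 1.0f);  // as in the test, this value drives the output shape
      Map<Integer, Object> outputs = new HashMap<>();
      outputs.put(0, FloatBuffer.allocate(8 * 1024));  // capacity large enough for any realized shape
      interpreter.runForMultipleInputsOutputs(new Object[] {input}, outputs);
      // After invocation, the output tensor reports the shape that was actually computed.
      System.out.println(Arrays.toString(interpreter.getOutputTensor(0).shape()));
    }
  }
}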

Binary file not shown.