From b8d991c9b4d4ba8c9156d4958c83504791ff9929 Mon Sep 17 00:00:00 2001
From: Jared Duke <jdduke@google.com>
Date: Fri, 8 May 2020 15:45:44 -0700
Subject: [PATCH] [tf.lite] Fix issue with direct ByteBuffer inputs and dynamic
 graphs

In these graphs, the input tensor pointers may get "refreshed" during
invocation. This refresh is fine if the original pointer came from the
arena, but if it comes from something like the direct ByteBuffer raw
address, the input data will be lost.

Avoid this by simply using memcpy from the direct ByteBuffer. This is
still quite fast, but avoids the hack where we simply inject the
direct ByteBuffer address as the tensor buffer pointer.

A longer-term solution will formally allow providing "custom" allocated
regions to tensor inputs, but until then, do the safe thing.

PiperOrigin-RevId: 310643333
Change-Id: I05dfebd24617ebb1af7eb281ff9e530b01669093
---
 tensorflow/lite/java/BUILD                    |   1 +
 .../main/java/org/tensorflow/lite/Tensor.java |   2 +-
 .../lite/java/src/main/native/tensor_jni.cc   |  16 ++++++-
 .../org/tensorflow/lite/InterpreterTest.java  |  45 +++++++++++++++++-
 tensorflow/lite/testdata/dynamic_shapes.bin   | Bin 0 -> 5264 bytes
 5 files changed, 60 insertions(+), 4 deletions(-)
 create mode 100644 tensorflow/lite/testdata/dynamic_shapes.bin

diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD
index 49c2136ffb4..2fcb4b631be 100644
--- a/tensorflow/lite/java/BUILD
+++ b/tensorflow/lite/java/BUILD
@@ -240,6 +240,7 @@ java_test(
     data = [
         "src/testdata/add.bin",
         "src/testdata/add_unknown_dimensions.bin",
+        "//tensorflow/lite:testdata/dynamic_shapes.bin",
         "//tensorflow/lite:testdata/multi_add.bin",
         "//tensorflow/lite:testdata/multi_add_flex.bin",
     ],
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
index 89a2a6a0639..cc9a6a451ac 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -196,7 +196,7 @@ public final class Tensor {
   }
 
   private void setTo(Buffer src) {
-    // Note that we attempt to use zero-copy optimization for direct, native-ordered buffers.
+    // Note that we attempt to use a direct memcpy optimization for direct, native-ordered buffers.
     // There are no base Buffer#order() or Buffer#put() methods, so again we have to ugly cast.
     if (src instanceof ByteBuffer) {
       ByteBuffer srcBuffer = (ByteBuffer) src;
diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc
index 99be71ba37d..dfa4e22162a 100644
--- a/tensorflow/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc
@@ -402,14 +402,26 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
   TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
   if (tensor == nullptr) return;
 
-  char* src_data_raw = static_cast<char*>(env->GetDirectBufferAddress(src));
+  void* src_data_raw = env->GetDirectBufferAddress(src);
   if (!src_data_raw) {
     ThrowException(env, kIllegalArgumentException,
                    "Input ByteBuffer is not a direct buffer");
     return;
   }
 
-  tensor->data.raw = src_data_raw;
+  if (!tensor->data.data) {
+    ThrowException(env, kIllegalArgumentException,
+                   "Internal error: Tensor hasn't been allocated.");
+    return;
+  }
+
+  // Historically, we would simply overwrite the tensor buffer pointer with
+  // the direct Buffer address. However, that is generally unsafe, and
+  // specifically wrong if the graph happens to have dynamic shapes where
+  // arena-allocated input buffers will be refreshed during invocation.
+  // TODO(b/156094015): Explore whether this is actually faster than
+  // using ByteBuffer.put(ByteBuffer).
+  memcpy(tensor->data.data, src_data_raw, tensor->bytes);
 }
 
 JNIEXPORT void JNICALL
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index cd782c7f5aa..6b6799eaad9 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -40,6 +40,8 @@ public final class InterpreterTest {
       "tensorflow/lite/testdata/multi_add_flex.bin";
   private static final String UNKNOWN_DIMS_MODEL_PATH =
       "tensorflow/lite/java/src/testdata/add_unknown_dimensions.bin";
+  private static final String DYNAMIC_SHAPES_MODEL_PATH =
+      "tensorflow/lite/testdata/dynamic_shapes.bin";
 
   private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
   private static final ByteBuffer MULTIPLE_INPUTS_MODEL_BUFFER =
@@ -48,6 +50,8 @@ public final class InterpreterTest {
       TestUtils.getTestFileAsBuffer(FLEX_MODEL_PATH);
   private static final ByteBuffer UNKNOWN_DIMS_MODEL_PATH_BUFFER =
       TestUtils.getTestFileAsBuffer(UNKNOWN_DIMS_MODEL_PATH);
+  private static final ByteBuffer DYNAMIC_SHAPES_MODEL_BUFFER =
+      TestUtils.getTestFileAsBuffer(DYNAMIC_SHAPES_MODEL_PATH);
 
   @Test
   public void testInterpreter() throws Exception {
@@ -434,7 +438,7 @@ public final class InterpreterTest {
     interpreter.close();
   }
 
-  /** Smoke test validating that flex model loading fails when the flex delegate is not linked. */
+  // Smoke test validating that flex model loading fails when the flex delegate is not linked.
   @Test
   public void testFlexModel() throws Exception {
     try {
@@ -573,6 +577,45 @@ public final class InterpreterTest {
     }
   }
 
+  private static FloatBuffer fill(FloatBuffer buffer, float value) {
+    while (buffer.hasRemaining()) { // Writes from the current position up to the limit.
+      buffer.put(value);
+    }
+    buffer.rewind(); // Reset the position so the returned buffer reads from the start.
+    return buffer;
+  }
+
+  // Regression test: graphs with dynamically computed shapes must work with direct ByteBuffer
+  // inputs. Historically, the direct ByteBuffer address would overwrite the arena-allocated
+  // input tensor pointer. That normally works, but for dynamic graphs the arena allocator may
+  // "restore" the original input tensor pointers at invocation time, discarding the direct
+  // ByteBuffer address and leading to stale input data being used.
+  @Test
+  public void testDynamicShapesWithDirectBufferInputs() {
+    try (Interpreter interpreter = new Interpreter(DYNAMIC_SHAPES_MODEL_BUFFER)) {
+      ByteBuffer input0 =
+          ByteBuffer.allocateDirect(8 * 42 * 1024 * 4).order(ByteOrder.nativeOrder());
+      ByteBuffer input1 =
+          ByteBuffer.allocateDirect(1 * 90 * 1024 * 4).order(ByteOrder.nativeOrder());
+      ByteBuffer input2 = ByteBuffer.allocateDirect(1 * 4).order(ByteOrder.nativeOrder());
+      Object[] inputs = {input0, input1, input2};
+
+      fill(input0.asFloatBuffer(), 2.0f);
+      fill(input1.asFloatBuffer(), 0.5f);
+      // Note that the value of this input dictates the shape of the output.
+      fill(input2.asFloatBuffer(), 1.0f);
+
+      FloatBuffer output = FloatBuffer.allocate(8 * 1 * 1024); // Heap-backed so array() works.
+      Map<Integer, Object> outputs = new HashMap<>();
+      outputs.put(0, output);
+
+      interpreter.runForMultipleInputsOutputs(inputs, outputs);
+
+      FloatBuffer expected = fill(FloatBuffer.allocate(8 * 1 * 1024), 2.0f);
+      assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder();
+    }
+  }
+
   private static native long getNativeHandleForDelegate();
 
   private static native long getNativeHandleForInvalidDelegate();
diff --git a/tensorflow/lite/testdata/dynamic_shapes.bin b/tensorflow/lite/testdata/dynamic_shapes.bin
new file mode 100644
index 0000000000000000000000000000000000000000..268d457131a9b0e664378214bd588a908228edc9
GIT binary patch
literal 5264
zcmaKwUyM~(6~?!Sg;K;3=|3FXYaK1tDFeg5Nx?X#qf#9j!Vp6UalGwZ!;N?D&^vc<
zXbfo@V~A;pF+BLtG(0qf7}5|jF{B}8Vt62AYz&zg9*FUw52nG2wIbHB$KP+CyDw+B
zw1<3i_St*w^?hrtz4tkn?hwM>Zx3$@VRh&ZOT+T8GIWJC^gcvS2%iM=;1byTiNJUW
z3*bqx8?#}s6AXdfpQ968Wa2c~0Iq#DgcD#8T<@X{c7Pk73E?Ez4VHqLRm{c50`qTy
z_rVQt9sCu%4gLbIfvaF1{1ME7OW-1y1sA}1FayqlGvG9M9ZW+(2;nUFG`Pi|onU??
ze(npQcLihj5)(K;p!4^Hun)|B3M@s>l6*eqF;8a8e288EMHsKvM;lZ1W_7$WdbrY<
ztWMO!TYt+krjyqYHm={eY5hhTt&7mZuHk3)l%AfbW8bWdt*2)d*}q*2Q}Ze~27U@i
z6QeoC%hftf?1=$k`@lT041!tW83yNxrx%<>=m3~mMn9M%ww>S-@eP90I6Xtb2f!?`
z4ud%oHVo$QZ|<G|t_d%nF$rG-UDWy#HU2d?2$q960)7b$fa?VMDzM%Q1oS!>2G-yV
zH~?0H3+S(b3+Uh3ieX19h5=$Yf&PbDG4v3_Ecyq)dGv1pGsMs>5P1S|-CzcxgWxne
zH-r8E$}(;h+{L^>7xVeEce5<TXk1?fUVHzYWzI9sZZH7Mzj;7P$hhpV{mfW};x~`b
zY>(^H@wf562X5!O7Q8zqk9_Mu30#D4kh{5-ZwP2BpIqsBkmp8a@?iN;W%SU*WOKCq
ze6xDE+;q)ptM7y#{vGa{G46YfSHDR^{f7N7eM4I{cz+!9=XrYA=Y6-bte3l5U2C@0
zr0%4La)gnIA2r9z)3HBsUxWLtocn%loS$-ji?PI~ude{g7%sK=)bmk%+Ko@uMmKVI
zZsHpzo;TrBUy9K&eQWn+@F?j0PnPZAZme?HzB8r==xgV=Q7NW)5$7zN>#3V-pvWng
zoIQYL98SXNT2hNO7yxn>*OGN->4#>cT%V+{6(UUI#^Lyee`MKX<cQ`<ZHqY_W=uOd
zy?+UOnCJh#*#AM6ou*H1YB&Rm8uV{Xk8-b1G{(!d>Wh`_?N;B&3sdFBqK>KYzD8wi
zisd&^AFbBMs?S#@ixZ~C3s1Ys)75ci9O>*U^z4~@q0zjZ7d5GU-FsQKmK@Uj-U|Mx
zKjnMtKeOx)b8}x$F?)Xud_Y~eJ?C@jzxm!WH_jabcZ>ag`VUrWwTaQ6Of+g^eZ!UV
zkB{uB)TZc9wMd#Spy`cVQ^%Qdm9eD5b>Q_ya6GTwQO9P`w1K;IQj_cBZkjme!me_2
z*Ho=j!yvJ%;m-9Zm%N_`=IfqZ=i~DBF63*`9LZZ;i)l}L_nd66McT&i{Dcc~v!>|U
z?pU!Vm*3B_WM5*;wr9W>fbm>moZAv(R;&8Wh2!mgxA<)9izgoA^6PI0KGu^%njQS#
z<qR`^Z65`;9|3Ydo1bryyTo(iR)3p&(VlNpM}PV{5BG0!?vAr<o-ujl)7F@b{TF$z
zRy%g3rNRBHzEnTs;%wtD`NU@JhcUiVF}KI5pFV`XT{DjJDh|AU<GnxM_e<9CtPcH`
zfO}@L*{F_H#zrS=?AOt4{dC1^K5AMO)?)3~;4;2AcvX9CGjBa0oi}E3ITrSFEm4aa
zyu0QO=egPC^Y2<(99OSe^kYs&;#+J>oM-G^*d8T5&-k>zlTUNG7C8S;x&LoDwm^<t
zJNg&JeN<e}=4o%WR^isMmXE=E6JBkKyz<HWFz~(mEH>O8khh6*d5!nuc#U&U{=V+u
z-5JlQRq@--tS^JTmGe@{wVKvpC#O1m_M6jhNNv=+hL-(#<!I2ehQH>*7_>`yF`kqc
zKE1+7z8*$g4lV}A<q{|1{Z+mOJNDQcoMp*Y|F(T^dTyM{&$+n%)NIXqfLe=ueU<tc
zU6ITAUHU{E%?`dtIKx~c>arHLjX80?^S?N6#PR<8yjvot+S2+r-c~Kzn_7E)y#n{~
z)_&{dxdhCi-{sEFucfg*&Jl8}!S{xBa}SDhE&g@1)tdN?fARj?>}O28_QpB>2I$y-
z51Jd}SPA4Cz*c^F9aE<<yqoLqacpaT-ozZJGigxgcgoF!mB#n`xAXgBvYE%?oU2j8
z3hI>q4)NV#zq;kMEthwBe~{N|i_fQ8eC`Ex$!FVnPnTs`ytam-{i>M5JJ<eZY|Wiq
zX10sS^`|%W%jel6Y`Y(lKEI31lePBk=dMzXv2Z=9{cqt;{`G6Svo>2c(H-O5Zv%p}
z%sa)|BUj>kEkBFmT33hudO)YYZL!w!iv@hWSMX&%-)Z@BzUxR|CE#9Pyl%EdANm<$
z&Qq<kOZ()1YkpS6`uko;zV6)figJ$pzr@$i3%=-z*UJT8#+7`T=i+|WIQhDUuPZHI
z=F01G!H+Q|Kb`SyZpYEXoVDz2eaM&M^KL)4n0wcS{_dXpEgo%guKQ#fpY=O@rW)Ii
zN%U#1KE`Lep8A}_=kE(X>1tp0^AKxc3?Ci$?RQ}}b5^sTQ`|bUM#gQwkH+eHPyWn7
z`<t!rP^B?`JLVexo@@Ejrwe$0!H@exKknJidc>^|&f?=#L6i2W9`^GXHTe!GYVy6?
z$A?X`dZ<>JysarP=Vr^Fnw--*!=fLb#oaZ<l7=DXKgBuZeDMV6&(Eb;!+iILP@=zc
zPj1{mcf4M~&&vfrSby`+^vUfxqYlp_{l(v(eevghxQNZ0g?Z-VT#LVlw%;<v7#6*M
z`f~c+(cW*WrI9Pe!c?yOov@76@Bh(RW3DCXY)!i4X&vX$-~c{f1HLz0pW4}0--{rv
zC7+LBce!P2zqTddm{B_4cG@GVRHNF~#8&>!Qd{CkF*|P>Bn{TbDQp*N=D!gB4>H*7
AG5`Po

literal 0
HcmV?d00001