From b8d991c9b4d4ba8c9156d4958c83504791ff9929 Mon Sep 17 00:00:00 2001 From: Jared Duke <jdduke@google.com> Date: Fri, 8 May 2020 15:45:44 -0700 Subject: [PATCH] [tf.lite] Fix issue with direct ByteBuffer inputs and dynamic graphs In these graphs, the input tensor pointers may get "refreshed" during invocation. This refresh is fine if the original pointer came from the arena, but if it comes from something like the direct ByteBuffer raw address, the input data will be lost. Avoid this by simply using memcpy from the direct ByteBuffer. This is still quite fast, but avoids the hack where we simply inject the direct ByteBuffer address as the tensor buffer pointer. A longer term solution will formally allow providing "custom" allocated regions to tensor inputs, but until then, do the safe thing. PiperOrigin-RevId: 310643333 Change-Id: I05dfebd24617ebb1af7eb281ff9e530b01669093 --- tensorflow/lite/java/BUILD | 1 + .../main/java/org/tensorflow/lite/Tensor.java | 2 +- .../lite/java/src/main/native/tensor_jni.cc | 16 ++++++- .../org/tensorflow/lite/InterpreterTest.java | 45 +++++++++++++++++- tensorflow/lite/testdata/dynamic_shapes.bin | Bin 0 -> 5264 bytes 5 files changed, 60 insertions(+), 4 deletions(-) create mode 100644 tensorflow/lite/testdata/dynamic_shapes.bin diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index 49c2136ffb4..2fcb4b631be 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -240,6 +240,7 @@ java_test( data = [ "src/testdata/add.bin", "src/testdata/add_unknown_dimensions.bin", + "//tensorflow/lite:testdata/dynamic_shapes.bin", "//tensorflow/lite:testdata/multi_add.bin", "//tensorflow/lite:testdata/multi_add_flex.bin", ], diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index 89a2a6a0639..cc9a6a451ac 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -196,7 +196,7 @@ public final class Tensor { } private void setTo(Buffer src) { - // Note that we attempt to use zero-copy optimization for direct, native-ordered buffers. + // Note that we attempt to use a direct memcpy optimization for direct, native-ordered buffers. // There are no base Buffer#order() or Buffer#put() methods, so again we have to ugly cast. if (src instanceof ByteBuffer) { ByteBuffer srcBuffer = (ByteBuffer) src; diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc index 99be71ba37d..dfa4e22162a 100644 --- a/tensorflow/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc @@ -402,14 +402,26 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer( TfLiteTensor* tensor = GetTensorFromHandle(env, handle); if (tensor == nullptr) return; - char* src_data_raw = static_cast<char*>(env->GetDirectBufferAddress(src)); + void* src_data_raw = env->GetDirectBufferAddress(src); if (!src_data_raw) { ThrowException(env, kIllegalArgumentException, "Input ByteBuffer is not a direct buffer"); return; } - tensor->data.raw = src_data_raw; + if (!tensor->data.data) { + ThrowException(env, kIllegalArgumentException, + "Internal error: Tensor hasn't been allocated."); + return; + } + + // Historically, we would simply overwrite the tensor buffer pointer with + // the direct Buffer address. However, that is generally unsafe, and + // specifically wrong if the graph happens to have dynamic shapes where + // arena-allocated input buffers will be refreshed during invocation. + // TODO(b/156094015): Explore whether this is actually faster than + // using ByteBuffer.put(ByteBuffer). + memcpy(tensor->data.data, src_data_raw, tensor->bytes); } JNIEXPORT void JNICALL diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index cd782c7f5aa..6b6799eaad9 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -40,6 +40,8 @@ public final class InterpreterTest { "tensorflow/lite/testdata/multi_add_flex.bin"; private static final String UNKNOWN_DIMS_MODEL_PATH = "tensorflow/lite/java/src/testdata/add_unknown_dimensions.bin"; + private static final String DYNAMIC_SHAPES_MODEL_PATH = + "tensorflow/lite/testdata/dynamic_shapes.bin"; private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH); private static final ByteBuffer MULTIPLE_INPUTS_MODEL_BUFFER = @@ -48,6 +50,8 @@ public final class InterpreterTest { TestUtils.getTestFileAsBuffer(FLEX_MODEL_PATH); private static final ByteBuffer UNKNOWN_DIMS_MODEL_PATH_BUFFER = TestUtils.getTestFileAsBuffer(UNKNOWN_DIMS_MODEL_PATH); + private static final ByteBuffer DYNAMIC_SHAPES_MODEL_BUFFER = + TestUtils.getTestFileAsBuffer(DYNAMIC_SHAPES_MODEL_PATH); @Test public void testInterpreter() throws Exception { @@ -434,7 +438,7 @@ public final class InterpreterTest { interpreter.close(); } - /** Smoke test validating that flex model loading fails when the flex delegate is not linked. */ + // Smoke test validating that flex model loading fails when the flex delegate is not linked. @Test public void testFlexModel() throws Exception { try { @@ -573,6 +577,45 @@ public final class InterpreterTest { } } + private static FloatBuffer fill(FloatBuffer buffer, float value) { + while (buffer.hasRemaining()) { + buffer.put(value); + } + buffer.rewind(); + return buffer; + } + + // Regression test case to ensure that graphs with dynamically computed shapes work properly. + // Historically, direct ByteBuffer addresses would overwrite the arena-allocated tensor input + // pointers. Normally this works fine, but for dynamic graphs, the original input tensor pointers + // may be "restored" at invocation time by the arena allocator, resetting the direct ByteBuffer + // address and leading to stale input data being used. + @Test + public void testDynamicShapesWithDirectBufferInputs() { + try (Interpreter interpreter = new Interpreter(DYNAMIC_SHAPES_MODEL_BUFFER)) { + ByteBuffer input0 = + ByteBuffer.allocateDirect(8 * 42 * 1024 * 4).order(ByteOrder.nativeOrder()); + ByteBuffer input1 = + ByteBuffer.allocateDirect(1 * 90 * 1024 * 4).order(ByteOrder.nativeOrder()); + ByteBuffer input2 = ByteBuffer.allocateDirect(1 * 4).order(ByteOrder.nativeOrder()); + Object[] inputs = {input0, input1, input2}; + + fill(input0.asFloatBuffer(), 2.0f); + fill(input1.asFloatBuffer(), 0.5f); + // Note that the value of this input dictates the shape of the output. + fill(input2.asFloatBuffer(), 1.0f); + + FloatBuffer output = FloatBuffer.allocate(8 * 1 * 1024); + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, output); + + interpreter.runForMultipleInputsOutputs(inputs, outputs); + + FloatBuffer expected = fill(FloatBuffer.allocate(8 * 1 * 1024), 2.0f); + assertThat(output.array()).usingTolerance(0.1f).containsExactly(expected.array()).inOrder(); + } + } + private static native long getNativeHandleForDelegate(); private static native long getNativeHandleForInvalidDelegate(); diff --git a/tensorflow/lite/testdata/dynamic_shapes.bin b/tensorflow/lite/testdata/dynamic_shapes.bin new file mode 100644 index 0000000000000000000000000000000000000000..268d457131a9b0e664378214bd588a908228edc9 GIT binary patch literal 5264 zcmaKwUyM~(6~?!Sg;K;3=|3FXYaK1tDFeg5Nx?X#qf#9j!Vp6UalGwZ!;N?D&^vc< zXbfo@V~A;pF+BLtG(0qf7}5|jF{B}8Vt62AYz&zg9*FUw52nG2wIbHB$KP+CyDw+B zw1<3i_St*w^?hrtz4tkn?hwM>Zx3$@VRh&ZOT+T8GIWJC^gcvS2%iM=;1byTiNJUW z3*bqx8?#}s6AXdfpQ968Wa2c~0Iq#DgcD#8T<@X{c7Pk73E?Ez4VHqLRm{c50`qTy z_rVQt9sCu%4gLbIfvaF1{1ME7OW-1y1sA}1FayqlGvG9M9ZW+(2;nUFG`Pi|onU?? ze(npQcLihj5)(K;p!4^Hun)|B3M@s>l6*eqF;8a8e288EMHsKvM;lZ1W_7$WdbrY< ztWMO!TYt+krjyqYHm={eY5hhTt&7mZuHk3)l%AfbW8bWdt*2)d*}q*2Q}Ze~27U@i z6QeoC%hftf?1=$k`@lT041!tW83yNxrx%<>=m3~mMn9M%ww>S-@eP90I6Xtb2f!?` z4ud%oHVo$QZ|<G|t_d%nF$rG-UDWy#HU2d?2$q960)7b$fa?VMDzM%Q1oS!>2G-yV zH~?0H3+S(b3+Uh3ieX19h5=$Yf&PbDG4v3_Ecyq)dGv1pGsMs>5P1S|-CzcxgWxne zH-r8E$}(;h+{L^>7xVeEce5<TXk1?fUVHzYWzI9sZZH7Mzj;7P$hhpV{mfW};x~`b zY>(^H@wf562X5!O7Q8zqk9_Mu30#D4kh{5-ZwP2BpIqsBkmp8a@?iN;W%SU*WOKCq ze6xDE+;q)ptM7y#{vGa{G46YfSHDR^{f7N7eM4I{cz+!9=XrYA=Y6-bte3l5U2C@0 zr0%4La)gnIA2r9z)3HBsUxWLtocn%loS$-ji?PI~ude{g7%sK=)bmk%+Ko@uMmKVI zZsHpzo;TrBUy9K&eQWn+@F?j0PnPZAZme?HzB8r==xgV=Q7NW)5$7zN>#3V-pvWng zoIQYL98SXNT2hNO7yxn>*OGN->4#>cT%V+{6(UUI#^Lyee`MKX<cQ`<ZHqY_W=uOd zy?+UOnCJh#*#AM6ou*H1YB&Rm8uV{Xk8-b1G{(!d>Wh`_?N;B&3sdFBqK>KYzD8wi zisd&^AFbBMs?S#@ixZ~C3s1Ys)75ci9O>*U^z4~@q0zjZ7d5GU-FsQKmK@Uj-U|Mx zKjnMtKeOx)b8}x$F?)Xud_Y~eJ?C@jzxm!WH_jabcZ>ag`VUrWwTaQ6Of+g^eZ!UV zkB{uB)TZc9wMd#Spy`cVQ^%Qdm9eD5b>Q_ya6GTwQO9P`w1K;IQj_cBZkjme!me_2 z*Ho=j!yvJ%;m-9Zm%N_`=IfqZ=i~DBF63*`9LZZ;i)l}L_nd66McT&i{Dcc~v!>|U z?pU!Vm*3B_WM5*;wr9W>fbm>moZAv(R;&8Wh2!mgxA<)9izgoA^6PI0KGu^%njQS# z<qR`^Z65`;9|3Ydo1bryyTo(iR)3p&(VlNpM}PV{5BG0!?vAr<o-ujl)7F@b{TF$z zRy%g3rNRBHzEnTs;%wtD`NU@JhcUiVF}KI5pFV`XT{DjJDh|AU<GnxM_e<9CtPcH` zfO}@L*{F_H#zrS=?AOt4{dC1^K5AMO)?)3~;4;2AcvX9CGjBa0oi}E3ITrSFEm4aa zyu0QO=egPC^Y2<(99OSe^kYs&;#+J>oM-G^*d8T5&-k>zlTUNG7C8S;x&LoDwm^<t zJNg&JeN<e}=4o%WR^isMmXE=E6JBkKyz<HWFz~(mEH>O8khh6*d5!nuc#U&U{=V+u z-5JlQRq@--tS^JTmGe@{wVKvpC#O1m_M6jhNNv=+hL-(#<!I2ehQH>*7_>`yF`kqc zKE1+7z8*$g4lV}A<q{|1{Z+mOJNDQcoMp*Y|F(T^dTyM{&$+n%)NIXqfLe=ueU<tc zU6ITAUHU{E%?`dtIKx~c>arHLjX80?^S?N6#PR<8yjvot+S2+r-c~Kzn_7E)y#n{~ z)_&{dxdhCi-{sEFucfg*&Jl8}!S{xBa}SDhE&g@1)tdN?fARj?>}O28_QpB>2I$y- z51Jd}SPA4Cz*c^F9aE<<yqoLqacpaT-ozZJGigxgcgoF!mB#n`xAXgBvYE%?oU2j8 z3hI>q4)NV#zq;kMEthwBe~{N|i_fQ8eC`Ex$!FVnPnTs`ytam-{i>M5JJ<eZY|Wiq zX10sS^`|%W%jel6Y`Y(lKEI31lePBk=dMzXv2Z=9{cqt;{`G6Svo>2c(H-O5Zv%p} z%sa)|BUj>kEkBFmT33hudO)YYZL!w!iv@hWSMX&%-)Z@BzUxR|CE#9Pyl%EdANm<$ z&Qq<kOZ()1YkpS6`uko;zV6)figJ$pzr@$i3%=-z*UJT8#+7`T=i+|WIQhDUuPZHI z=F01G!H+Q|Kb`SyZpYEXoVDz2eaM&M^KL)4n0wcS{_dXpEgo%guKQ#fpY=O@rW)Ii zN%U#1KE`Lep8A}_=kE(X>1tp0^AKxc3?Ci$?RQ}}b5^sTQ`|bUM#gQwkH+eHPyWn7 z`<t!rP^B?`JLVexo@@Ejrwe$0!H@exKknJidc>^|&f?=#L6i2W9`^GXHTe!GYVy6? z$A?X`dZ<>JysarP=Vr^Fnw--*!=fLb#oaZ<l7=DXKgBuZeDMV6&(Eb;!+iILP@=zc zPj1{mcf4M~&&vfrSby`+^vUfxqYlp_{l(v(eevghxQNZ0g?Z-VT#LVlw%;<v7#6*M z`f~c+(cW*WrI9Pe!c?yOov@76@Bh(RW3DCXY)!i4X&vX$-~c{f1HLz0pW4}0--{rv zC7+LBce!P2zqTddm{B_4cG@GVRHNF~#8&>!Qd{CkF*|P>Bn{TbDQp*N=D!gB4>H*7 AG5`Po literal 0 HcmV?d00001