From e7bac9435dd57740ea2b5a1a0a8bb14a1fbf093f Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Thu, 20 Dec 2018 16:23:58 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 226412019 --- .../java/org/tensorflow/lite/Interpreter.java | 14 ++++-- .../main/java/org/tensorflow/lite/Tensor.java | 27 +++++++++--- .../lite/java/src/main/native/tensor_jni.cc | 11 +++++ .../lite/java/src/main/native/tensor_jni.h | 10 +++++ .../org/tensorflow/lite/InterpreterTest.java | 43 +++++++++++++++++++ .../java/org/tensorflow/lite/TensorTest.java | 20 +++++++++ .../src/test/native/interpreter_test_jni.cc | 7 ++- 7 files changed, 122 insertions(+), 10 deletions(-) diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index 1b2d0d5aa84..5aef4fb0572 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -234,11 +234,15 @@ public final class Interpreter implements AutoCloseable { * including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large * input data for primitive types, whereas string types require using the (multi-dimensional) * array input path. When {@link ByteBuffer} is used, its content should remain unchanged - * until model inference is done. + * until model inference is done. A {@code null} value is allowed only if the caller is using + * a {@link Delegate} that allows buffer handle interop, and such a buffer has been bound to + * the input {@link Tensor}. * @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive - * types including int, float, long, and byte. + * types including int, float, long, and byte. A null value is allowed only if the caller is + * using a {@link Delegate} that allows buffer handle interop, and such a buffer has been + * bound to the output {@link Tensor}. See also {@link Options#setAllowBufferHandleOutput()}. */ - public void run(@NonNull Object input, @NonNull Object output) { + public void run(Object input, Object output) { Object[] inputs = {input}; Map outputs = new HashMap<>(); outputs.put(0, output); @@ -251,6 +255,10 @@ public final class Interpreter implements AutoCloseable { *

Warning: The API runs much faster if {@link ByteBuffer} is used as input data type. Please * consider using {@link ByteBuffer} to feed primitive input data for better performance. * + *

Note: {@code null} values for invididual elements of {@code inputs} and {@code outputs} is + * allowed only if the caller is using a {@link Delegate} that allows buffer handle interop, and + * such a buffer has been bound to the corresponding input or output {@link Tensor}(s). + * * @param inputs an array of input data. The inputs should be in the same order as inputs of the * model. Each input can be an array or multidimensional array, or a {@link ByteBuffer} of * primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index b56fcd772b1..725bb326ba1 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -93,12 +93,20 @@ public final class Tensor { * Copies the contents of the provided {@code src} object to the Tensor. * *

The {@code src} should either be a (multi-dimensional) array with a shape matching that of - * this tensor, or a {@link ByteByffer} of compatible primitive type with a matching flat size. + * this tensor, a {@link ByteByffer} of compatible primitive type with a matching flat size, or + * {@code null} iff the tensor has an underlying delegate buffer handle. * * @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible * with the tensor (for example, mismatched data types or shapes). */ void setTo(Object src) { + if (src == null) { + if (hasDelegateBufferHandle(nativeHandle)) { + return; + } + throw new IllegalArgumentException( + "Null inputs are allowed only if the Tensor is bound to a buffer handle."); + } throwExceptionIfTypeIsIncompatible(src); if (isByteBuffer(src)) { ByteBuffer srcBuffer = (ByteBuffer) src; @@ -117,11 +125,19 @@ public final class Tensor { /** * Copies the contents of the tensor to {@code dst} and returns {@code dst}. * - * @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}. + * @param dst the destination buffer, either an explicitly-typed array, a {@link ByteBuffer} or + * {@code null} iff the tensor has an underlying delegate buffer handle. * @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example, * mismatched data types or shapes). */ Object copyTo(Object dst) { + if (dst == null) { + if (hasDelegateBufferHandle(nativeHandle)) { + return dst; + } + throw new IllegalArgumentException( + "Null outputs are allowed only if the Tensor is bound to a buffer handle."); + } throwExceptionIfTypeIsIncompatible(dst); if (dst instanceof ByteBuffer) { ByteBuffer dstByteBuffer = (ByteBuffer) dst; @@ -135,6 +151,9 @@ public final class Tensor { /** Returns the provided buffer's shape if specified and different from this Tensor's shape. */ // TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs. int[] getInputShapeIfDifferent(Object input) { + if (input == null) { + return null; + } // Implicit resizes based on ByteBuffer capacity isn't supported, so short-circuit that path. // The ByteBuffer's size will be validated against this Tensor's size in {@link #setTo(Object)}. if (isByteBuffer(input)) { @@ -287,9 +306,7 @@ public final class Tensor { private static native int numBytes(long handle); - private static native int setBufferHandle(long handle, long delegateHandle, int bufferHandle); - - private static native int bufferHandle(long handle); + private static native boolean hasDelegateBufferHandle(long handle); private static native void readMultiDimensionalArray(long handle, Object dst); diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc index cc81eb8d517..f07437e7f31 100644 --- a/tensorflow/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc @@ -410,6 +410,17 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env, return static_cast(tensor->bytes); } +JNIEXPORT jboolean JNICALL +Java_org_tensorflow_lite_Tensor_hasDelegateBufferHandle(JNIEnv* env, + jclass clazz, + jlong handle) { + const TfLiteTensor* tensor = GetTensorFromHandle(env, handle); + if (tensor == nullptr) return false; + return tensor->delegate && (tensor->buffer_handle != kTfLiteNullBufferHandle) + ? JNI_TRUE + : JNI_FALSE; +} + JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_index(JNIEnv* env, jclass clazz, jlong handle) { diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.h b/tensorflow/lite/java/src/main/native/tensor_jni.h index 52150bf3ab3..a14f24a47d0 100644 --- a/tensorflow/lite/java/src/main/native/tensor_jni.h +++ b/tensorflow/lite/java/src/main/native/tensor_jni.h @@ -84,6 +84,16 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env, jclass clazz, jlong handle); +/* + * Class: org_tensorflow_lite_Tensor + * Method: hasDelegateBufferHandle + * Signature: (J)Z + */ +JNIEXPORT jboolean JNICALL +Java_org_tensorflow_lite_Tensor_hasDelegateBufferHandle(JNIEnv* env, + jclass clazz, + jlong handle); + /* * Class: org_tensorflow_lite_Tensor * Method: readMultiDimensionalArray diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index f89062ba458..c5496e3a21e 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -334,6 +334,30 @@ public final class InterpreterTest { interpreter.close(); } + @Test + public void testNullInputs() throws Exception { + Interpreter interpreter = new Interpreter(MODEL_FILE); + try { + interpreter.run(null, new float[2][8][8][3]); + fail(); + } catch (IllegalArgumentException e) { + // Expected failure. + } + interpreter.close(); + } + + @Test + public void testNullOutputs() throws Exception { + Interpreter interpreter = new Interpreter(MODEL_FILE); + try { + interpreter.run(new float[2][8][8][3], null); + fail(); + } catch (IllegalArgumentException e) { + // Expected failure. + } + interpreter.close(); + } + /** Smoke test validating that flex model loading fails when the flex delegate is not linked. */ @Test public void testFlexModel() throws Exception { @@ -372,6 +396,25 @@ public final class InterpreterTest { interpreter.close(); } + @Test + public void testNullInputsAndOutputsWithDelegate() throws Exception { + System.loadLibrary("tensorflowlite_test_jni"); + Delegate delegate = + new Delegate() { + @Override + public long getNativeHandle() { + return getNativeHandleForDelegate(); + } + }; + Interpreter interpreter = + new Interpreter(MODEL_FILE, new Interpreter.Options().addDelegate(delegate)); + // The delegate installs a custom buffer handle for all tensors, in turn allowing null to be + // provided for the inputs/outputs (as the client can reference the buffer directly). + interpreter.run(new float[2][8][8][3], null); + interpreter.run(null, new float[2][8][8][3]); + interpreter.close(); + } + @Test public void testModifyGraphWithDelegate() throws Exception { System.loadLibrary("tensorflowlite_test_jni"); diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java index be6a706b8d4..d9b20510106 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -78,6 +78,16 @@ public final class TensorTest { assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); } + @Test + public void testCopyToNull() { + try { + tensor.copyTo(null); + fail(); + } catch (IllegalArgumentException e) { + // Success. + } + } + @Test public void testCopyToByteBuffer() { ByteBuffer parsedOutput = @@ -150,6 +160,16 @@ public final class TensorTest { assertThat(output[0][0][0][0]).isEqualTo(3.0f); } + @Test + public void testSetToNull() { + try { + tensor.setTo(null); + fail(); + } catch (IllegalArgumentException e) { + // Success. + } + } + @Test public void testSetToInvalidByteBuffer() { ByteBuffer input = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); diff --git a/tensorflow/lite/java/src/test/native/interpreter_test_jni.cc b/tensorflow/lite/java/src/test/native/interpreter_test_jni.cc index 000e718ba7a..f5bcc1249f0 100644 --- a/tensorflow/lite/java/src/test/native/interpreter_test_jni.cc +++ b/tensorflow/lite/java/src/test/native/interpreter_test_jni.cc @@ -49,8 +49,6 @@ Java_org_tensorflow_lite_InterpreterTest_getNativeHandleForDelegate( .custom_name = "", .version = 1, }; - // A simple delegate which replaces all ops with a single op that outputs a - // vector of length 1 with the value [7]. static TfLiteDelegate delegate = { .data_ = nullptr, .Prepare = [](TfLiteContext* context, @@ -60,6 +58,11 @@ Java_org_tensorflow_lite_InterpreterTest_getNativeHandleForDelegate( context->GetExecutionPlan(context, &execution_plan)); context->ReplaceNodeSubsetsWithDelegateKernels( context, registration, execution_plan, delegate); + // Now bind delegate buffer handles for all tensors. + for (size_t i = 0; i < context->tensors_size; ++i) { + context->tensors[i].delegate = delegate; + context->tensors[i].buffer_handle = static_cast(i); + } return kTfLiteOk; }, .CopyFromBufferHandle = nullptr,