Internal change

PiperOrigin-RevId: 226412019
This commit is contained in:
Jared Duke 2018-12-20 16:23:58 -08:00 committed by TensorFlower Gardener
parent d5a8035691
commit e7bac9435d
7 changed files with 122 additions and 10 deletions

View File

@ -234,11 +234,15 @@ public final class Interpreter implements AutoCloseable {
* including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large
* input data for primitive types, whereas string types require using the (multi-dimensional)
* array input path. When {@link ByteBuffer} is used, its content should remain unchanged
* until model inference is done.
* until model inference is done. A {@code null} value is allowed only if the caller is using
* a {@link Delegate} that allows buffer handle interop, and such a buffer has been bound to
* the input {@link Tensor}.
* @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive
* types including int, float, long, and byte.
* types including int, float, long, and byte. A null value is allowed only if the caller is
* using a {@link Delegate} that allows buffer handle interop, and such a buffer has been
* bound to the output {@link Tensor}. See also {@link Options#setAllowBufferHandleOutput()}.
*/
public void run(@NonNull Object input, @NonNull Object output) {
public void run(Object input, Object output) {
Object[] inputs = {input};
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, output);
@ -251,6 +255,10 @@ public final class Interpreter implements AutoCloseable {
* <p>Warning: The API runs much faster if {@link ByteBuffer} is used as input data type. Please
* consider using {@link ByteBuffer} to feed primitive input data for better performance.
*
* <p>Note: {@code null} values for invididual elements of {@code inputs} and {@code outputs} is
* allowed only if the caller is using a {@link Delegate} that allows buffer handle interop, and
* such a buffer has been bound to the corresponding input or output {@link Tensor}(s).
*
* @param inputs an array of input data. The inputs should be in the same order as inputs of the
* model. Each input can be an array or multidimensional array, or a {@link ByteBuffer} of
* primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred

View File

@ -93,12 +93,20 @@ public final class Tensor {
* Copies the contents of the provided {@code src} object to the Tensor.
*
* <p>The {@code src} should either be a (multi-dimensional) array with a shape matching that of
* this tensor, or a {@link ByteByffer} of compatible primitive type with a matching flat size.
* this tensor, a {@link ByteByffer} of compatible primitive type with a matching flat size, or
* {@code null} iff the tensor has an underlying delegate buffer handle.
*
* @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible
* with the tensor (for example, mismatched data types or shapes).
*/
void setTo(Object src) {
if (src == null) {
if (hasDelegateBufferHandle(nativeHandle)) {
return;
}
throw new IllegalArgumentException(
"Null inputs are allowed only if the Tensor is bound to a buffer handle.");
}
throwExceptionIfTypeIsIncompatible(src);
if (isByteBuffer(src)) {
ByteBuffer srcBuffer = (ByteBuffer) src;
@ -117,11 +125,19 @@ public final class Tensor {
/**
* Copies the contents of the tensor to {@code dst} and returns {@code dst}.
*
* @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}.
* @param dst the destination buffer, either an explicitly-typed array, a {@link ByteBuffer} or
* {@code null} iff the tensor has an underlying delegate buffer handle.
* @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example,
* mismatched data types or shapes).
*/
Object copyTo(Object dst) {
if (dst == null) {
if (hasDelegateBufferHandle(nativeHandle)) {
return dst;
}
throw new IllegalArgumentException(
"Null outputs are allowed only if the Tensor is bound to a buffer handle.");
}
throwExceptionIfTypeIsIncompatible(dst);
if (dst instanceof ByteBuffer) {
ByteBuffer dstByteBuffer = (ByteBuffer) dst;
@ -135,6 +151,9 @@ public final class Tensor {
/** Returns the provided buffer's shape if specified and different from this Tensor's shape. */
// TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs.
int[] getInputShapeIfDifferent(Object input) {
if (input == null) {
return null;
}
// Implicit resizes based on ByteBuffer capacity isn't supported, so short-circuit that path.
// The ByteBuffer's size will be validated against this Tensor's size in {@link #setTo(Object)}.
if (isByteBuffer(input)) {
@ -287,9 +306,7 @@ public final class Tensor {
private static native int numBytes(long handle);
private static native int setBufferHandle(long handle, long delegateHandle, int bufferHandle);
private static native int bufferHandle(long handle);
private static native boolean hasDelegateBufferHandle(long handle);
private static native void readMultiDimensionalArray(long handle, Object dst);

View File

@ -410,6 +410,17 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
return static_cast<jint>(tensor->bytes);
}
JNIEXPORT jboolean JNICALL
Java_org_tensorflow_lite_Tensor_hasDelegateBufferHandle(JNIEnv* env,
jclass clazz,
jlong handle) {
const TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return false;
return tensor->delegate && (tensor->buffer_handle != kTfLiteNullBufferHandle)
? JNI_TRUE
: JNI_FALSE;
}
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_index(JNIEnv* env,
jclass clazz,
jlong handle) {

View File

@ -84,6 +84,16 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
jclass clazz,
jlong handle);
/*
* Class: org_tensorflow_lite_Tensor
* Method: hasDelegateBufferHandle
* Signature: (J)Z
*/
JNIEXPORT jboolean JNICALL
Java_org_tensorflow_lite_Tensor_hasDelegateBufferHandle(JNIEnv* env,
jclass clazz,
jlong handle);
/*
* Class: org_tensorflow_lite_Tensor
* Method: readMultiDimensionalArray

View File

@ -334,6 +334,30 @@ public final class InterpreterTest {
interpreter.close();
}
@Test
public void testNullInputs() throws Exception {
Interpreter interpreter = new Interpreter(MODEL_FILE);
try {
interpreter.run(null, new float[2][8][8][3]);
fail();
} catch (IllegalArgumentException e) {
// Expected failure.
}
interpreter.close();
}
@Test
public void testNullOutputs() throws Exception {
Interpreter interpreter = new Interpreter(MODEL_FILE);
try {
interpreter.run(new float[2][8][8][3], null);
fail();
} catch (IllegalArgumentException e) {
// Expected failure.
}
interpreter.close();
}
/** Smoke test validating that flex model loading fails when the flex delegate is not linked. */
@Test
public void testFlexModel() throws Exception {
@ -372,6 +396,25 @@ public final class InterpreterTest {
interpreter.close();
}
@Test
public void testNullInputsAndOutputsWithDelegate() throws Exception {
System.loadLibrary("tensorflowlite_test_jni");
Delegate delegate =
new Delegate() {
@Override
public long getNativeHandle() {
return getNativeHandleForDelegate();
}
};
Interpreter interpreter =
new Interpreter(MODEL_FILE, new Interpreter.Options().addDelegate(delegate));
// The delegate installs a custom buffer handle for all tensors, in turn allowing null to be
// provided for the inputs/outputs (as the client can reference the buffer directly).
interpreter.run(new float[2][8][8][3], null);
interpreter.run(null, new float[2][8][8][3]);
interpreter.close();
}
@Test
public void testModifyGraphWithDelegate() throws Exception {
System.loadLibrary("tensorflowlite_test_jni");

View File

@ -78,6 +78,16 @@ public final class TensorTest {
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
@Test
public void testCopyToNull() {
try {
tensor.copyTo(null);
fail();
} catch (IllegalArgumentException e) {
// Success.
}
}
@Test
public void testCopyToByteBuffer() {
ByteBuffer parsedOutput =
@ -150,6 +160,16 @@ public final class TensorTest {
assertThat(output[0][0][0][0]).isEqualTo(3.0f);
}
@Test
public void testSetToNull() {
try {
tensor.setTo(null);
fail();
} catch (IllegalArgumentException e) {
// Success.
}
}
@Test
public void testSetToInvalidByteBuffer() {
ByteBuffer input = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());

View File

@ -49,8 +49,6 @@ Java_org_tensorflow_lite_InterpreterTest_getNativeHandleForDelegate(
.custom_name = "",
.version = 1,
};
// A simple delegate which replaces all ops with a single op that outputs a
// vector of length 1 with the value [7].
static TfLiteDelegate delegate = {
.data_ = nullptr,
.Prepare = [](TfLiteContext* context,
@ -60,6 +58,11 @@ Java_org_tensorflow_lite_InterpreterTest_getNativeHandleForDelegate(
context->GetExecutionPlan(context, &execution_plan));
context->ReplaceNodeSubsetsWithDelegateKernels(
context, registration, execution_plan, delegate);
// Now bind delegate buffer handles for all tensors.
for (size_t i = 0; i < context->tensors_size; ++i) {
context->tensors[i].delegate = delegate;
context->tensors[i].buffer_handle = static_cast<int>(i);
}
return kTfLiteOk;
},
.CopyFromBufferHandle = nullptr,