Internal change
PiperOrigin-RevId: 226412019
This commit is contained in:
parent
d5a8035691
commit
e7bac9435d
@ -234,11 +234,15 @@ public final class Interpreter implements AutoCloseable {
|
||||
* including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large
|
||||
* input data for primitive types, whereas string types require using the (multi-dimensional)
|
||||
* array input path. When {@link ByteBuffer} is used, its content should remain unchanged
|
||||
* until model inference is done.
|
||||
* until model inference is done. A {@code null} value is allowed only if the caller is using
|
||||
* a {@link Delegate} that allows buffer handle interop, and such a buffer has been bound to
|
||||
* the input {@link Tensor}.
|
||||
* @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive
|
||||
* types including int, float, long, and byte.
|
||||
* types including int, float, long, and byte. A null value is allowed only if the caller is
|
||||
* using a {@link Delegate} that allows buffer handle interop, and such a buffer has been
|
||||
* bound to the output {@link Tensor}. See also {@link Options#setAllowBufferHandleOutput()}.
|
||||
*/
|
||||
public void run(@NonNull Object input, @NonNull Object output) {
|
||||
public void run(Object input, Object output) {
|
||||
Object[] inputs = {input};
|
||||
Map<Integer, Object> outputs = new HashMap<>();
|
||||
outputs.put(0, output);
|
||||
@ -251,6 +255,10 @@ public final class Interpreter implements AutoCloseable {
|
||||
* <p>Warning: The API runs much faster if {@link ByteBuffer} is used as input data type. Please
|
||||
* consider using {@link ByteBuffer} to feed primitive input data for better performance.
|
||||
*
|
||||
* <p>Note: {@code null} values for invididual elements of {@code inputs} and {@code outputs} is
|
||||
* allowed only if the caller is using a {@link Delegate} that allows buffer handle interop, and
|
||||
* such a buffer has been bound to the corresponding input or output {@link Tensor}(s).
|
||||
*
|
||||
* @param inputs an array of input data. The inputs should be in the same order as inputs of the
|
||||
* model. Each input can be an array or multidimensional array, or a {@link ByteBuffer} of
|
||||
* primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred
|
||||
|
@ -93,12 +93,20 @@ public final class Tensor {
|
||||
* Copies the contents of the provided {@code src} object to the Tensor.
|
||||
*
|
||||
* <p>The {@code src} should either be a (multi-dimensional) array with a shape matching that of
|
||||
* this tensor, or a {@link ByteByffer} of compatible primitive type with a matching flat size.
|
||||
* this tensor, a {@link ByteByffer} of compatible primitive type with a matching flat size, or
|
||||
* {@code null} iff the tensor has an underlying delegate buffer handle.
|
||||
*
|
||||
* @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible
|
||||
* with the tensor (for example, mismatched data types or shapes).
|
||||
*/
|
||||
void setTo(Object src) {
|
||||
if (src == null) {
|
||||
if (hasDelegateBufferHandle(nativeHandle)) {
|
||||
return;
|
||||
}
|
||||
throw new IllegalArgumentException(
|
||||
"Null inputs are allowed only if the Tensor is bound to a buffer handle.");
|
||||
}
|
||||
throwExceptionIfTypeIsIncompatible(src);
|
||||
if (isByteBuffer(src)) {
|
||||
ByteBuffer srcBuffer = (ByteBuffer) src;
|
||||
@ -117,11 +125,19 @@ public final class Tensor {
|
||||
/**
|
||||
* Copies the contents of the tensor to {@code dst} and returns {@code dst}.
|
||||
*
|
||||
* @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}.
|
||||
* @param dst the destination buffer, either an explicitly-typed array, a {@link ByteBuffer} or
|
||||
* {@code null} iff the tensor has an underlying delegate buffer handle.
|
||||
* @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example,
|
||||
* mismatched data types or shapes).
|
||||
*/
|
||||
Object copyTo(Object dst) {
|
||||
if (dst == null) {
|
||||
if (hasDelegateBufferHandle(nativeHandle)) {
|
||||
return dst;
|
||||
}
|
||||
throw new IllegalArgumentException(
|
||||
"Null outputs are allowed only if the Tensor is bound to a buffer handle.");
|
||||
}
|
||||
throwExceptionIfTypeIsIncompatible(dst);
|
||||
if (dst instanceof ByteBuffer) {
|
||||
ByteBuffer dstByteBuffer = (ByteBuffer) dst;
|
||||
@ -135,6 +151,9 @@ public final class Tensor {
|
||||
/** Returns the provided buffer's shape if specified and different from this Tensor's shape. */
|
||||
// TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs.
|
||||
int[] getInputShapeIfDifferent(Object input) {
|
||||
if (input == null) {
|
||||
return null;
|
||||
}
|
||||
// Implicit resizes based on ByteBuffer capacity isn't supported, so short-circuit that path.
|
||||
// The ByteBuffer's size will be validated against this Tensor's size in {@link #setTo(Object)}.
|
||||
if (isByteBuffer(input)) {
|
||||
@ -287,9 +306,7 @@ public final class Tensor {
|
||||
|
||||
private static native int numBytes(long handle);
|
||||
|
||||
private static native int setBufferHandle(long handle, long delegateHandle, int bufferHandle);
|
||||
|
||||
private static native int bufferHandle(long handle);
|
||||
private static native boolean hasDelegateBufferHandle(long handle);
|
||||
|
||||
private static native void readMultiDimensionalArray(long handle, Object dst);
|
||||
|
||||
|
@ -410,6 +410,17 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
|
||||
return static_cast<jint>(tensor->bytes);
|
||||
}
|
||||
|
||||
JNIEXPORT jboolean JNICALL
|
||||
Java_org_tensorflow_lite_Tensor_hasDelegateBufferHandle(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle) {
|
||||
const TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
|
||||
if (tensor == nullptr) return false;
|
||||
return tensor->delegate && (tensor->buffer_handle != kTfLiteNullBufferHandle)
|
||||
? JNI_TRUE
|
||||
: JNI_FALSE;
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_index(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle) {
|
||||
|
@ -84,6 +84,16 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_lite_Tensor
|
||||
* Method: hasDelegateBufferHandle
|
||||
* Signature: (J)Z
|
||||
*/
|
||||
JNIEXPORT jboolean JNICALL
|
||||
Java_org_tensorflow_lite_Tensor_hasDelegateBufferHandle(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle);
|
||||
|
||||
/*
|
||||
* Class: org_tensorflow_lite_Tensor
|
||||
* Method: readMultiDimensionalArray
|
||||
|
@ -334,6 +334,30 @@ public final class InterpreterTest {
|
||||
interpreter.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNullInputs() throws Exception {
|
||||
Interpreter interpreter = new Interpreter(MODEL_FILE);
|
||||
try {
|
||||
interpreter.run(null, new float[2][8][8][3]);
|
||||
fail();
|
||||
} catch (IllegalArgumentException e) {
|
||||
// Expected failure.
|
||||
}
|
||||
interpreter.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNullOutputs() throws Exception {
|
||||
Interpreter interpreter = new Interpreter(MODEL_FILE);
|
||||
try {
|
||||
interpreter.run(new float[2][8][8][3], null);
|
||||
fail();
|
||||
} catch (IllegalArgumentException e) {
|
||||
// Expected failure.
|
||||
}
|
||||
interpreter.close();
|
||||
}
|
||||
|
||||
/** Smoke test validating that flex model loading fails when the flex delegate is not linked. */
|
||||
@Test
|
||||
public void testFlexModel() throws Exception {
|
||||
@ -372,6 +396,25 @@ public final class InterpreterTest {
|
||||
interpreter.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNullInputsAndOutputsWithDelegate() throws Exception {
|
||||
System.loadLibrary("tensorflowlite_test_jni");
|
||||
Delegate delegate =
|
||||
new Delegate() {
|
||||
@Override
|
||||
public long getNativeHandle() {
|
||||
return getNativeHandleForDelegate();
|
||||
}
|
||||
};
|
||||
Interpreter interpreter =
|
||||
new Interpreter(MODEL_FILE, new Interpreter.Options().addDelegate(delegate));
|
||||
// The delegate installs a custom buffer handle for all tensors, in turn allowing null to be
|
||||
// provided for the inputs/outputs (as the client can reference the buffer directly).
|
||||
interpreter.run(new float[2][8][8][3], null);
|
||||
interpreter.run(null, new float[2][8][8][3]);
|
||||
interpreter.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testModifyGraphWithDelegate() throws Exception {
|
||||
System.loadLibrary("tensorflowlite_test_jni");
|
||||
|
@ -78,6 +78,16 @@ public final class TensorTest {
|
||||
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCopyToNull() {
|
||||
try {
|
||||
tensor.copyTo(null);
|
||||
fail();
|
||||
} catch (IllegalArgumentException e) {
|
||||
// Success.
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCopyToByteBuffer() {
|
||||
ByteBuffer parsedOutput =
|
||||
@ -150,6 +160,16 @@ public final class TensorTest {
|
||||
assertThat(output[0][0][0][0]).isEqualTo(3.0f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSetToNull() {
|
||||
try {
|
||||
tensor.setTo(null);
|
||||
fail();
|
||||
} catch (IllegalArgumentException e) {
|
||||
// Success.
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSetToInvalidByteBuffer() {
|
||||
ByteBuffer input = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
|
||||
|
@ -49,8 +49,6 @@ Java_org_tensorflow_lite_InterpreterTest_getNativeHandleForDelegate(
|
||||
.custom_name = "",
|
||||
.version = 1,
|
||||
};
|
||||
// A simple delegate which replaces all ops with a single op that outputs a
|
||||
// vector of length 1 with the value [7].
|
||||
static TfLiteDelegate delegate = {
|
||||
.data_ = nullptr,
|
||||
.Prepare = [](TfLiteContext* context,
|
||||
@ -60,6 +58,11 @@ Java_org_tensorflow_lite_InterpreterTest_getNativeHandleForDelegate(
|
||||
context->GetExecutionPlan(context, &execution_plan));
|
||||
context->ReplaceNodeSubsetsWithDelegateKernels(
|
||||
context, registration, execution_plan, delegate);
|
||||
// Now bind delegate buffer handles for all tensors.
|
||||
for (size_t i = 0; i < context->tensors_size; ++i) {
|
||||
context->tensors[i].delegate = delegate;
|
||||
context->tensors[i].buffer_handle = static_cast<int>(i);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
},
|
||||
.CopyFromBufferHandle = nullptr,
|
||||
|
Loading…
x
Reference in New Issue
Block a user