Allow larger Java output buffers for TFLite outputs

Allow the user to provide a larger output buffer than is necessary
when copying from an output tensor in the TFLite Java bindings.
This makes it easier to accommodate outputs that might have variable
output size using a single, pre-allocated output.

See also PR #39266.

PiperOrigin-RevId: 314454310
Change-Id: I83fd82344196831cdd240f106a588996ad87e88b
This commit is contained in:
Jared Duke 2020-06-02 20:02:17 -07:00 committed by TensorFlower Gardener
parent 5febf24c80
commit 6afec5ecbb
3 changed files with 66 additions and 26 deletions

View File

@ -185,7 +185,8 @@ public final class Tensor {
throw new IllegalArgumentException(
"Null inputs are allowed only if the Tensor is bound to a buffer handle.");
}
throwIfDataIsIncompatible(src);
throwIfTypeIsIncompatible(src);
throwIfSrcShapeIsIncompatible(src);
if (isBuffer(src)) {
setTo((Buffer) src);
} else if (src.getClass().isArray()) {
@ -247,7 +248,8 @@ public final class Tensor {
throw new IllegalArgumentException(
"Null outputs are allowed only if the Tensor is bound to a buffer handle.");
}
throwIfDataIsIncompatible(dst);
throwIfTypeIsIncompatible(dst);
throwIfDstShapeIsIncompatible(dst);
if (isBuffer(dst)) {
copyTo((Buffer) dst);
} else {
@ -387,11 +389,6 @@ public final class Tensor {
}
}
private void throwIfDataIsIncompatible(Object o) {
throwIfTypeIsIncompatible(o);
throwIfShapeIsIncompatible(o);
}
private void throwIfTypeIsIncompatible(Object o) {
// ByteBuffer payloads can map to any type, so exempt it from the check.
if (isByteBuffer(o)) {
@ -413,29 +410,58 @@ public final class Tensor {
}
}
private void throwIfShapeIsIncompatible(Object o) {
if (isBuffer(o)) {
Buffer oBuffer = (Buffer) o;
private void throwIfSrcShapeIsIncompatible(Object src) {
if (isBuffer(src)) {
Buffer srcBuffer = (Buffer) src;
int bytes = numBytes();
// Note that we allow the client to provide a ByteBuffer even for non-byte Tensors.
// In such cases, we only care that the raw byte capacity matches the tensor byte capacity.
int oBytes = isByteBuffer(o) ? oBuffer.capacity() : oBuffer.capacity() * dtype.byteSize();
if (bytes != oBytes) {
int srcBytes =
isByteBuffer(src) ? srcBuffer.capacity() : srcBuffer.capacity() * dtype.byteSize();
if (bytes != srcBytes) {
throw new IllegalArgumentException(
String.format(
"Cannot convert between a TensorFlowLite buffer with %d bytes and a "
"Cannot copy to a TensorFlowLite tensor (%s) with %d bytes from a "
+ "Java Buffer with %d bytes.",
bytes, oBytes));
name(), bytes, srcBytes));
}
return;
}
int[] oShape = computeShapeOf(o);
if (!Arrays.equals(oShape, shapeCopy)) {
int[] srcShape = computeShapeOf(src);
if (!Arrays.equals(srcShape, shapeCopy)) {
throw new IllegalArgumentException(
String.format(
"Cannot copy between a TensorFlowLite tensor with shape %s and a Java object "
"Cannot copy to a TensorFlowLite tensor (%s) with shape %s from a Java object "
+ "with shape %s.",
Arrays.toString(shapeCopy), Arrays.toString(oShape)));
name(), Arrays.toString(shapeCopy), Arrays.toString(srcShape)));
}
}
private void throwIfDstShapeIsIncompatible(Object dst) {
if (isBuffer(dst)) {
Buffer dstBuffer = (Buffer) dst;
int bytes = numBytes();
// Note that we allow the client to provide a ByteBuffer even for non-byte Tensors.
// In such cases, we only care that the raw byte capacity fits the tensor byte capacity.
// This is subtly different than Buffer *inputs*, where the size should be exact.
int dstBytes =
isByteBuffer(dst) ? dstBuffer.capacity() : dstBuffer.capacity() * dtype.byteSize();
if (bytes > dstBytes) {
throw new IllegalArgumentException(
String.format(
"Cannot copy from a TensorFlowLite tensor (%s) with %d bytes to a "
+ "Java Buffer with %d bytes.",
name(), bytes, dstBytes));
}
return;
}
int[] dstShape = computeShapeOf(dst);
if (!Arrays.equals(dstShape, shapeCopy)) {
throw new IllegalArgumentException(
String.format(
"Cannot copy from a TensorFlowLite tensor (%s) with shape %s to a Java object "
+ "with shape %s.",
name(), Arrays.toString(shapeCopy), Arrays.toString(dstShape)));
}
}

View File

@ -301,8 +301,8 @@ public final class NativeInterpreterWrapperTest {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot copy between a TensorFlowLite tensor with shape [2, 4, 4, 12] and "
+ "a Java object with shape [2, 4, 4, 10]");
"Cannot copy from a TensorFlowLite tensor (output_tensor) with shape [2, 4, 4, 12] "
+ "to a Java object with shape [2, 4, 4, 10]");
}
}
}
@ -365,7 +365,7 @@ public final class NativeInterpreterWrapperTest {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot convert between a TensorFlowLite buffer with 768 bytes and a "
"Cannot copy to a TensorFlowLite tensor (input) with 768 bytes from a "
+ "Java Buffer with 3072 bytes.");
}
int[] inputDims = {4, 8, 8, 3};
@ -393,7 +393,7 @@ public final class NativeInterpreterWrapperTest {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot convert between a TensorFlowLite buffer with 192 bytes and a "
"Cannot copy to a TensorFlowLite tensor (input) with 192 bytes from a "
+ "Java Buffer with 336 bytes.");
}
}
@ -494,7 +494,7 @@ public final class NativeInterpreterWrapperTest {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot copy between a TensorFlowLite tensor with shape [8, 7, 3] and a "
"Cannot copy from a TensorFlowLite tensor (output) with shape [8, 7, 3] to a "
+ "Java object with shape [2, 8, 8, 3].");
}
}
@ -518,7 +518,7 @@ public final class NativeInterpreterWrapperTest {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot copy between a TensorFlowLite tensor with shape [2, 8, 7, 3] and a "
"Cannot copy from a TensorFlowLite tensor (output) with shape [2, 8, 7, 3] to a "
+ "Java object with shape [2, 8, 8, 3].");
}
}

View File

@ -117,6 +117,20 @@ public final class TensorTest {
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
@Test
public void testCopyToLargerByteBuffer() {
// Allocate a ByteBuffer that is larger than the Tensor, and ensure we can copy to it.
ByteBuffer parsedOutput =
ByteBuffer.allocateDirect(10 * 2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
tensor.copyTo(parsedOutput);
assertThat(parsedOutput.position()).isEqualTo(2 * 8 * 8 * 3 * 4);
float[] outputOneD = {
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
};
float[] expected = {3.69f, 19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
@Test
public void testCopyToByteBufferAsFloatBuffer() {
FloatBuffer parsedOutput =
@ -203,8 +217,8 @@ public final class TensorTest {
assertThat(e)
.hasMessageThat()
.contains(
"Cannot copy between a TensorFlowLite tensor with shape [2, 8, 8, 3] "
+ "and a Java object with shape [1, 8, 8, 3].");
"Cannot copy from a TensorFlowLite tensor (output) with shape [2, 8, 8, 3] "
+ "to a Java object with shape [1, 8, 8, 3].");
}
}