diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java index b19ef2e3b62..bced23e6f67 100644 --- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java +++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java @@ -231,6 +231,26 @@ public class TensorImage { return container.getDataType(); } + /** + * Gets the image width. + * + * @throws IllegalStateException if the TensorImage never loads data. + * @throws IllegalArgumentException if the container data is corrupted. + */ + public int getWidth() { + return container.getWidth(); + } + + /** + * Gets the image height. + * + * @throws IllegalStateException if the TensorImage never loads data. + * @throws IllegalArgumentException if the container data is corrupted. + */ + public int getHeight() { + return container.getHeight(); + } + // Requires tensor shape [h, w, 3] or [1, h, w, 3]. static void checkImageTensorShape(int[] shape) { SupportPreconditions.checkArgument( @@ -273,6 +293,41 @@ public class TensorImage { isBufferUpdated = true; } + int getWidth() { + SupportPreconditions.checkState( + isBitmapUpdated || isBufferUpdated, + "Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?"); + if (isBitmapUpdated) { + return bitmapImage.getWidth(); + } + return getBufferDimensionSize(-2); + } + + int getHeight() { + SupportPreconditions.checkState( + isBitmapUpdated || isBufferUpdated, + "Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?"); + if (isBitmapUpdated) { + return bitmapImage.getHeight(); + } + return getBufferDimensionSize(-3); + } + + // Internal helper method to get the size of one dimension in the shape of the `bufferImage`. + // Requires `isBufferUpdated` is true. + // Throws `IllegalArgumentException` if data is corrupted. + private int getBufferDimensionSize(int dim) { + int[] shape = bufferImage.getShape(); + // The defensive check is needed because bufferImage might be invalidly changed by user + // (a.k.a internal data is corrupted) + TensorImage.checkImageTensorShape(shape); + dim = dim % shape.length; + if (dim < 0) { + dim += shape.length; + } + return shape[dim]; + } + public DataType getDataType() { return dataType; } @@ -284,7 +339,8 @@ public class TensorImage { return bitmapImage; } if (!isBufferUpdated) { - throw new IllegalStateException("Both buffer and bitmap data are obsolete."); + throw new IllegalStateException( + "Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?"); } if (bufferImage.getDataType() != DataType.UINT8) { throw new IllegalStateException( @@ -310,7 +366,8 @@ public class TensorImage { return bufferImage; } SupportPreconditions.checkArgument( - isBitmapUpdated, "Both buffer and bitmap data are obsolete."); + isBitmapUpdated, + "Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?"); int requiredFlatSize = bitmapImage.getWidth() * bitmapImage.getHeight() * 3; if (bufferImage == null || (!bufferImage.isDynamic() && bufferImage.getFlatSize() != requiredFlatSize)) { diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java index 16622a25333..fa05be363a6 100644 --- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java +++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java @@ -379,13 +379,13 @@ public abstract class TensorBuffer { // Check if the new shape is the same as current shape. int newFlatSize = computeFlatSize(shape); + this.shape = shape.clone(); if (flatSize == newFlatSize) { return; } // Update to the new shape. flatSize = newFlatSize; - this.shape = shape.clone(); buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize()); buffer.order(ByteOrder.nativeOrder()); }