[tfls.image] Provide TensorImage#getWidth and TensorImage#getHeight as they are not trivial.

PiperOrigin-RevId: 311060642
Change-Id: Ie6d8043ffe82cb6276cb919d9d799c0740ef29c0
This commit is contained in:
Xunkai Zhang 2020-05-11 22:10:49 -07:00 committed by TensorFlower Gardener
parent bec0b38233
commit 4926e23ba4
2 changed files with 60 additions and 3 deletions

View File

@ -231,6 +231,26 @@ public class TensorImage {
return container.getDataType(); return container.getDataType();
} }
/**
* Gets the image width.
*
* @throws IllegalStateException if the TensorImage never loads data.
* @throws IllegalArgumentException if the container data is corrupted.
*/
public int getWidth() {
return container.getWidth();
}
/**
* Gets the image height.
*
* @throws IllegalStateException if the TensorImage never loads data.
* @throws IllegalArgumentException if the container data is corrupted.
*/
public int getHeight() {
return container.getHeight();
}
// Requires tensor shape [h, w, 3] or [1, h, w, 3]. // Requires tensor shape [h, w, 3] or [1, h, w, 3].
static void checkImageTensorShape(int[] shape) { static void checkImageTensorShape(int[] shape) {
SupportPreconditions.checkArgument( SupportPreconditions.checkArgument(
@ -273,6 +293,41 @@ public class TensorImage {
isBufferUpdated = true; isBufferUpdated = true;
} }
int getWidth() {
SupportPreconditions.checkState(
isBitmapUpdated || isBufferUpdated,
"Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
if (isBitmapUpdated) {
return bitmapImage.getWidth();
}
return getBufferDimensionSize(-2);
}
int getHeight() {
SupportPreconditions.checkState(
isBitmapUpdated || isBufferUpdated,
"Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
if (isBitmapUpdated) {
return bitmapImage.getHeight();
}
return getBufferDimensionSize(-3);
}
// Internal helper method to get the size of one dimension in the shape of the `bufferImage`.
// Requires `isBufferUpdated` is true.
// Throws `IllegalArgumentException` if data is corrupted.
private int getBufferDimensionSize(int dim) {
int[] shape = bufferImage.getShape();
// The defensive check is needed because bufferImage might be invalidly changed by user
// (a.k.a internal data is corrupted)
TensorImage.checkImageTensorShape(shape);
dim = dim % shape.length;
if (dim < 0) {
dim += shape.length;
}
return shape[dim];
}
public DataType getDataType() { public DataType getDataType() {
return dataType; return dataType;
} }
@ -284,7 +339,8 @@ public class TensorImage {
return bitmapImage; return bitmapImage;
} }
if (!isBufferUpdated) { if (!isBufferUpdated) {
throw new IllegalStateException("Both buffer and bitmap data are obsolete."); throw new IllegalStateException(
"Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
} }
if (bufferImage.getDataType() != DataType.UINT8) { if (bufferImage.getDataType() != DataType.UINT8) {
throw new IllegalStateException( throw new IllegalStateException(
@ -310,7 +366,8 @@ public class TensorImage {
return bufferImage; return bufferImage;
} }
SupportPreconditions.checkArgument( SupportPreconditions.checkArgument(
isBitmapUpdated, "Both buffer and bitmap data are obsolete."); isBitmapUpdated,
"Both buffer and bitmap data are obsolete. Forgot to call TensorImage#load?");
int requiredFlatSize = bitmapImage.getWidth() * bitmapImage.getHeight() * 3; int requiredFlatSize = bitmapImage.getWidth() * bitmapImage.getHeight() * 3;
if (bufferImage == null if (bufferImage == null
|| (!bufferImage.isDynamic() && bufferImage.getFlatSize() != requiredFlatSize)) { || (!bufferImage.isDynamic() && bufferImage.getFlatSize() != requiredFlatSize)) {

View File

@ -379,13 +379,13 @@ public abstract class TensorBuffer {
// Check if the new shape is the same as current shape. // Check if the new shape is the same as current shape.
int newFlatSize = computeFlatSize(shape); int newFlatSize = computeFlatSize(shape);
this.shape = shape.clone();
if (flatSize == newFlatSize) { if (flatSize == newFlatSize) {
return; return;
} }
// Update to the new shape. // Update to the new shape.
flatSize = newFlatSize; flatSize = newFlatSize;
this.shape = shape.clone();
buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize()); buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize());
buffer.order(ByteOrder.nativeOrder()); buffer.order(ByteOrder.nativeOrder());
} }