TFLite GPU: Expand how an HWC tensor is read.

https://github.com/tensorflow/tensorflow/issues/39749

PiperOrigin-RevId: 313801850
Change-Id: I6017483d960abbc67572806f943c0b41cb6b5410
This commit is contained in:
Juhyun Lee 2020-05-29 10:08:56 -07:00 committed by TensorFlower Gardener
parent c3769e5ed3
commit 8383415398
1 changed files with 17 additions and 10 deletions

View File

@ -259,18 +259,25 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
}
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) {
if (dimensions->size != 4) {
return absl::InvalidArgumentError(
absl::StrCat("Expected a 4D tensor of shape 1xHxWxC but got ",
GetDimensionString(dimensions)));
if (dimensions->size == 3) {
shape->h = dimensions->data[0];
shape->w = dimensions->data[1];
shape->c = dimensions->data[2];
return absl::OkStatus();
}
if (dimensions->data[0] != 1) {
return absl::UnimplementedError("Batch size is not equal to 1.");
if (dimensions->size == 4) {
if (dimensions->data[0] != 1) {
return absl::UnimplementedError("Batch size is not equal to 1.");
}
shape->h = dimensions->data[1];
shape->w = dimensions->data[2];
shape->c = dimensions->data[3];
return absl::OkStatus();
}
shape->h = dimensions->data[1];
shape->w = dimensions->data[2];
shape->c = dimensions->data[3];
return absl::OkStatus();
return absl::InvalidArgumentError(
absl::StrCat("Expected a 3D tensor of shape HxWxC or a 4D tensor of "
"shape 1xHxWxC but got ",
GetDimensionString(dimensions)));
}
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {