TFLite GPU: Expand how an HWC tensor is read.

https://github.com/tensorflow/tensorflow/issues/39749

PiperOrigin-RevId: 313801850
Change-Id: I6017483d960abbc67572806f943c0b41cb6b5410
This commit is contained in:
Juhyun Lee 2020-05-29 10:08:56 -07:00 committed by TensorFlower Gardener
parent c3769e5ed3
commit 8383415398
1 changed files with 17 additions and 10 deletions

View File

@ -259,11 +259,13 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
} }
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) { absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) {
if (dimensions->size != 4) { if (dimensions->size == 3) {
return absl::InvalidArgumentError( shape->h = dimensions->data[0];
absl::StrCat("Expected a 4D tensor of shape 1xHxWxC but got ", shape->w = dimensions->data[1];
GetDimensionString(dimensions))); shape->c = dimensions->data[2];
return absl::OkStatus();
} }
if (dimensions->size == 4) {
if (dimensions->data[0] != 1) { if (dimensions->data[0] != 1) {
return absl::UnimplementedError("Batch size is not equal to 1."); return absl::UnimplementedError("Batch size is not equal to 1.");
} }
@ -271,6 +273,11 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) {
shape->w = dimensions->data[2]; shape->w = dimensions->data[2];
shape->c = dimensions->data[3]; shape->c = dimensions->data[3];
return absl::OkStatus(); return absl::OkStatus();
}
return absl::InvalidArgumentError(
absl::StrCat("Expected a 3D tensor of shape HxWxC or a 4D tensor of "
"shape 1xHxWxC but got ",
GetDimensionString(dimensions)));
} }
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) { absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {