TFLite GPU: Expand how an HWC tensor is read.
https://github.com/tensorflow/tensorflow/issues/39749 PiperOrigin-RevId: 313801850 Change-Id: I6017483d960abbc67572806f943c0b41cb6b5410
This commit is contained in:
parent
c3769e5ed3
commit
8383415398
|
@ -259,18 +259,25 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
|
|||
}
|
||||
|
||||
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) {
|
||||
if (dimensions->size != 4) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Expected a 4D tensor of shape 1xHxWxC but got ",
|
||||
GetDimensionString(dimensions)));
|
||||
if (dimensions->size == 3) {
|
||||
shape->h = dimensions->data[0];
|
||||
shape->w = dimensions->data[1];
|
||||
shape->c = dimensions->data[2];
|
||||
return absl::OkStatus();
|
||||
}
|
||||
if (dimensions->data[0] != 1) {
|
||||
return absl::UnimplementedError("Batch size is not equal to 1.");
|
||||
if (dimensions->size == 4) {
|
||||
if (dimensions->data[0] != 1) {
|
||||
return absl::UnimplementedError("Batch size is not equal to 1.");
|
||||
}
|
||||
shape->h = dimensions->data[1];
|
||||
shape->w = dimensions->data[2];
|
||||
shape->c = dimensions->data[3];
|
||||
return absl::OkStatus();
|
||||
}
|
||||
shape->h = dimensions->data[1];
|
||||
shape->w = dimensions->data[2];
|
||||
shape->c = dimensions->data[3];
|
||||
return absl::OkStatus();
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Expected a 3D tensor of shape HxWxC or a 4D tensor of "
|
||||
"shape 1xHxWxC but got ",
|
||||
GetDimensionString(dimensions)));
|
||||
}
|
||||
|
||||
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {
|
||||
|
|
Loading…
Reference in New Issue