diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc index a1705e6cf78..a0f7db25210 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc @@ -259,18 +259,25 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) { } absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) { - if (dimensions->size != 4) { - return absl::InvalidArgumentError( - absl::StrCat("Expected a 4D tensor of shape 1xHxWxC but got ", - GetDimensionString(dimensions))); + if (dimensions->size == 3) { + shape->h = dimensions->data[0]; + shape->w = dimensions->data[1]; + shape->c = dimensions->data[2]; + return absl::OkStatus(); } - if (dimensions->data[0] != 1) { - return absl::UnimplementedError("Batch size is not equal to 1."); + if (dimensions->size == 4) { + if (dimensions->data[0] != 1) { + return absl::UnimplementedError("Batch size is not equal to 1."); + } + shape->h = dimensions->data[1]; + shape->w = dimensions->data[2]; + shape->c = dimensions->data[3]; + return absl::OkStatus(); } - shape->h = dimensions->data[1]; - shape->w = dimensions->data[2]; - shape->c = dimensions->data[3]; - return absl::OkStatus(); + return absl::InvalidArgumentError( + absl::StrCat("Expected a 3D tensor of shape HxWxC or a 4D tensor of " + "shape 1xHxWxC but got ", + GetDimensionString(dimensions))); } absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {