From 838341539843b1b972f9a8bbb3afa2d2288a6c63 Mon Sep 17 00:00:00 2001 From: Juhyun Lee Date: Fri, 29 May 2020 10:08:56 -0700 Subject: [PATCH] TFLite GPU: Expand how an HWC tensor is read. https://github.com/tensorflow/tensorflow/issues/39749 PiperOrigin-RevId: 313801850 Change-Id: I6017483d960abbc67572806f943c0b41cb6b5410 --- .../gpu/common/model_builder_helper.cc | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc index a1705e6cf78..a0f7db25210 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc @@ -259,18 +259,25 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) { } absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) { - if (dimensions->size != 4) { - return absl::InvalidArgumentError( - absl::StrCat("Expected a 4D tensor of shape 1xHxWxC but got ", - GetDimensionString(dimensions))); + if (dimensions->size == 3) { + shape->h = dimensions->data[0]; + shape->w = dimensions->data[1]; + shape->c = dimensions->data[2]; + return absl::OkStatus(); } - if (dimensions->data[0] != 1) { - return absl::UnimplementedError("Batch size is not equal to 1."); + if (dimensions->size == 4) { + if (dimensions->data[0] != 1) { + return absl::UnimplementedError("Batch size is not equal to 1."); + } + shape->h = dimensions->data[1]; + shape->w = dimensions->data[2]; + shape->c = dimensions->data[3]; + return absl::OkStatus(); } - shape->h = dimensions->data[1]; - shape->w = dimensions->data[2]; - shape->c = dimensions->data[3]; - return absl::OkStatus(); + return absl::InvalidArgumentError( + absl::StrCat("Expected a 3D tensor of shape HxWxC or a 4D tensor of " + "shape 1xHxWxC but got ", + GetDimensionString(dimensions))); } absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {