Provide more verbose error message in SetAllDimensions()

Error messages will include actual tensor dimension information.

PiperOrigin-RevId: 313529606
Change-Id: I88631ea2ebba796fe266a5d0ea3ea73e4e3ad3ed
This commit is contained in:
Terry Heo 2020-05-27 22:47:03 -07:00 committed by TensorFlower Gardener
parent 8404e3e296
commit ad74ce7491

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <fp16.h> #include <fp16.h>
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h" #include "tensorflow/lite/context.h"
@ -220,14 +221,18 @@ absl::Status CreateVectorCopyData<float>(const TfLiteTensor& tensor,
return absl::OkStatus(); return absl::OkStatus();
} }
const std::string GetDimensionString(const TfLiteIntArray* dimensions) {
return absl::StrJoin(TfLiteIntArrayView(dimensions), "x");
}
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape) { absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape) {
if (dimensions->size < 0) { if (dimensions->size < 0) {
return absl::InvalidArgumentError("Invalid Scalar dimensions"); return absl::InvalidArgumentError("Invalid Scalar dimensions");
} }
for (int i = 0; i < dimensions->size; ++i) { for (int i = 0; i < dimensions->size; ++i) {
if (dimensions->data[i] != 1) { if (dimensions->data[i] != 1) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(absl::StrCat(
"Dimension can not be reduced to scalar."); GetDimensionString(dimensions), " cannot be reduced to scalar."));
} }
} }
shape->v = 1; shape->v = 1;
@ -240,8 +245,8 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
} }
for (int i = 0; i < dimensions->size - 1; ++i) { for (int i = 0; i < dimensions->size - 1; ++i) {
if (dimensions->data[i] != 1) { if (dimensions->data[i] != 1) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(absl::StrCat(
"Dimension can not be reduced to linear."); GetDimensionString(dimensions), " cannot be reduced to linear."));
} }
} }
shape->v = dimensions->data[dimensions->size - 1]; shape->v = dimensions->data[dimensions->size - 1];
@ -250,7 +255,9 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) { absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) {
if (dimensions->size != 4) { if (dimensions->size != 4) {
return absl::InvalidArgumentError("Dimensions are not HWC"); return absl::InvalidArgumentError(
absl::StrCat("Expected a 4D tensor of shape 1xHxWxC but got ",
GetDimensionString(dimensions)));
} }
if (dimensions->data[0] != 1) { if (dimensions->data[0] != 1) {
return absl::UnimplementedError("Batch size is not equal to 1."); return absl::UnimplementedError("Batch size is not equal to 1.");
@ -263,7 +270,9 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) {
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) { absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {
if (dimensions->size != 2) { if (dimensions->size != 2) {
return absl::InvalidArgumentError("Dimensions are not HW"); return absl::InvalidArgumentError(
absl::StrCat("Expected a 2D tensor of shape HxW but got ",
GetDimensionString(dimensions)));
} }
shape->h = dimensions->data[0]; shape->h = dimensions->data[0];
shape->w = dimensions->data[1]; shape->w = dimensions->data[1];
@ -273,7 +282,8 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape) { absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape) {
if (dimensions->size != 4) { if (dimensions->size != 4) {
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
absl::StrCat("Dimensions are not OHWI: ", dimensions->size)); absl::StrCat("Expected a 4D tensor of shape OxHxWxI but got ",
GetDimensionString(dimensions)));
} }
shape->o = dimensions->data[0]; shape->o = dimensions->data[0];
shape->h = dimensions->data[1]; shape->h = dimensions->data[1];
@ -284,7 +294,9 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape) {
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, BHWC* shape) { absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, BHWC* shape) {
if (dimensions->size != 4) { if (dimensions->size != 4) {
return absl::InvalidArgumentError("Dimensions are not BHWC"); return absl::InvalidArgumentError(
absl::StrCat("Expected a 4D tensor of shape BxHxWxC but got ",
GetDimensionString(dimensions)));
} }
shape->b = dimensions->data[0]; shape->b = dimensions->data[0];
shape->h = dimensions->data[1]; shape->h = dimensions->data[1];