Provide more verbose error message in SetAllDimensions()

Error messages will include actual tensor dimension information.

PiperOrigin-RevId: 313529606
Change-Id: I88631ea2ebba796fe266a5d0ea3ea73e4e3ad3ed
This commit is contained in:
Terry Heo 2020-05-27 22:47:03 -07:00 committed by TensorFlower Gardener
parent 8404e3e296
commit ad74ce7491

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <fp16.h>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
@ -220,14 +221,18 @@ absl::Status CreateVectorCopyData<float>(const TfLiteTensor& tensor,
return absl::OkStatus();
}
const std::string GetDimensionString(const TfLiteIntArray* dimensions) {
return absl::StrJoin(TfLiteIntArrayView(dimensions), "x");
}
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape) {
if (dimensions->size < 0) {
return absl::InvalidArgumentError("Invalid Scalar dimensions");
}
for (int i = 0; i < dimensions->size; ++i) {
if (dimensions->data[i] != 1) {
return absl::InvalidArgumentError(
"Dimension can not be reduced to scalar.");
return absl::InvalidArgumentError(absl::StrCat(
GetDimensionString(dimensions), " cannot be reduced to scalar."));
}
}
shape->v = 1;
@ -240,8 +245,8 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
}
for (int i = 0; i < dimensions->size - 1; ++i) {
if (dimensions->data[i] != 1) {
return absl::InvalidArgumentError(
"Dimension can not be reduced to linear.");
return absl::InvalidArgumentError(absl::StrCat(
GetDimensionString(dimensions), " cannot be reduced to linear."));
}
}
shape->v = dimensions->data[dimensions->size - 1];
@ -250,7 +255,9 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) {
if (dimensions->size != 4) {
return absl::InvalidArgumentError("Dimensions are not HWC");
return absl::InvalidArgumentError(
absl::StrCat("Expected a 4D tensor of shape 1xHxWxC but got ",
GetDimensionString(dimensions)));
}
if (dimensions->data[0] != 1) {
return absl::UnimplementedError("Batch size is not equal to 1.");
@ -263,7 +270,9 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) {
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {
if (dimensions->size != 2) {
return absl::InvalidArgumentError("Dimensions are not HW");
return absl::InvalidArgumentError(
absl::StrCat("Expected a 2D tensor of shape HxW but got ",
GetDimensionString(dimensions)));
}
shape->h = dimensions->data[0];
shape->w = dimensions->data[1];
@ -273,7 +282,8 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape) {
if (dimensions->size != 4) {
return absl::InvalidArgumentError(
absl::StrCat("Dimensions are not OHWI: ", dimensions->size));
absl::StrCat("Expected a 4D tensor of shape OxHxWxI but got ",
GetDimensionString(dimensions)));
}
shape->o = dimensions->data[0];
shape->h = dimensions->data[1];
@ -284,7 +294,9 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape) {
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, BHWC* shape) {
if (dimensions->size != 4) {
return absl::InvalidArgumentError("Dimensions are not BHWC");
return absl::InvalidArgumentError(
absl::StrCat("Expected a 4D tensor of shape BxHxWxC but got ",
GetDimensionString(dimensions)));
}
shape->b = dimensions->data[0];
shape->h = dimensions->data[1];