Provide more verbose error message in SetAllDimensions()
Error messages will include actual tensor dimension information. PiperOrigin-RevId: 313529606 Change-Id: I88631ea2ebba796fe266a5d0ea3ea73e4e3ad3ed
This commit is contained in:
parent
8404e3e296
commit
ad74ce7491
@ -19,6 +19,7 @@ limitations under the License.
|
||||
|
||||
#include <fp16.h>
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/lite/builtin_ops.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/context.h"
|
||||
@ -220,14 +221,18 @@ absl::Status CreateVectorCopyData<float>(const TfLiteTensor& tensor,
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const std::string GetDimensionString(const TfLiteIntArray* dimensions) {
|
||||
return absl::StrJoin(TfLiteIntArrayView(dimensions), "x");
|
||||
}
|
||||
|
||||
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape) {
|
||||
if (dimensions->size < 0) {
|
||||
return absl::InvalidArgumentError("Invalid Scalar dimensions");
|
||||
}
|
||||
for (int i = 0; i < dimensions->size; ++i) {
|
||||
if (dimensions->data[i] != 1) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Dimension can not be reduced to scalar.");
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
GetDimensionString(dimensions), " cannot be reduced to scalar."));
|
||||
}
|
||||
}
|
||||
shape->v = 1;
|
||||
@ -240,8 +245,8 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
|
||||
}
|
||||
for (int i = 0; i < dimensions->size - 1; ++i) {
|
||||
if (dimensions->data[i] != 1) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Dimension can not be reduced to linear.");
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
GetDimensionString(dimensions), " cannot be reduced to linear."));
|
||||
}
|
||||
}
|
||||
shape->v = dimensions->data[dimensions->size - 1];
|
||||
@ -250,7 +255,9 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
|
||||
|
||||
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) {
|
||||
if (dimensions->size != 4) {
|
||||
return absl::InvalidArgumentError("Dimensions are not HWC");
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Expected a 4D tensor of shape 1xHxWxC but got ",
|
||||
GetDimensionString(dimensions)));
|
||||
}
|
||||
if (dimensions->data[0] != 1) {
|
||||
return absl::UnimplementedError("Batch size is not equal to 1.");
|
||||
@ -263,7 +270,9 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) {
|
||||
|
||||
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {
|
||||
if (dimensions->size != 2) {
|
||||
return absl::InvalidArgumentError("Dimensions are not HW");
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Expected a 2D tensor of shape HxW but got ",
|
||||
GetDimensionString(dimensions)));
|
||||
}
|
||||
shape->h = dimensions->data[0];
|
||||
shape->w = dimensions->data[1];
|
||||
@ -273,7 +282,8 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {
|
||||
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape) {
|
||||
if (dimensions->size != 4) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Dimensions are not OHWI: ", dimensions->size));
|
||||
absl::StrCat("Expected a 4D tensor of shape OxHxWxI but got ",
|
||||
GetDimensionString(dimensions)));
|
||||
}
|
||||
shape->o = dimensions->data[0];
|
||||
shape->h = dimensions->data[1];
|
||||
@ -284,7 +294,9 @@ absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape) {
|
||||
|
||||
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, BHWC* shape) {
|
||||
if (dimensions->size != 4) {
|
||||
return absl::InvalidArgumentError("Dimensions are not BHWC");
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Expected a 4D tensor of shape BxHxWxC but got ",
|
||||
GetDimensionString(dimensions)));
|
||||
}
|
||||
shape->b = dimensions->data[0];
|
||||
shape->h = dimensions->data[1];
|
||||
|
Loading…
Reference in New Issue
Block a user