Check PAD tensor shape in IsSupported() phase
PiperOrigin-RevId: 313333989 Change-Id: I5a47cfaf2f5aedca919d737274e2d94c1b5825ce
This commit is contained in:
parent
ca47cbd37c
commit
a1b64bb516
@ -1348,6 +1348,17 @@ class PadOperationParser : public TFLiteOperationParser {
|
|||||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
/*runtime_inputs=*/1, /*outputs=*/1));
|
/*runtime_inputs=*/1, /*outputs=*/1));
|
||||||
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
||||||
|
auto pad_tensor = tflite::GetInput(context, tflite_node, 1);
|
||||||
|
if (pad_tensor->dims->size != 2) {
|
||||||
|
return absl::InvalidArgumentError(absl::StrCat(
|
||||||
|
"Invalid paddings tensor dimension: expected 2 dim, got ",
|
||||||
|
pad_tensor->dims->size, " dim"));
|
||||||
|
}
|
||||||
|
if (pad_tensor->dims->data[0] != 4 || pad_tensor->dims->data[1] != 2) {
|
||||||
|
return absl::InvalidArgumentError(absl::StrCat(
|
||||||
|
"Invalid paddings tensor shape: expected 4x2, got ",
|
||||||
|
pad_tensor->dims->data[0], "x", pad_tensor->dims->data[1]));
|
||||||
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1371,6 +1382,7 @@ class PadOperationParser : public TFLiteOperationParser {
|
|||||||
|
|
||||||
// 4x2 tensor with paddings.
|
// 4x2 tensor with paddings.
|
||||||
if (paddings.shape.h != 4 || paddings.shape.w != 2) {
|
if (paddings.shape.h != 4 || paddings.shape.w != 2) {
|
||||||
|
// It shouldn't fail here since it's checked at IsSupported().
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
"Paddings tensor has unexpected shape.");
|
"Paddings tensor has unexpected shape.");
|
||||||
}
|
}
|
||||||
|
@ -28,7 +28,7 @@ inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
|
|||||||
inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
|
inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
|
||||||
return t->dims->data[dim];
|
return t->dims->data[dim];
|
||||||
}
|
}
|
||||||
inline const TfLiteTensor* GetInput(TfLiteContext* context,
|
inline const TfLiteTensor* GetInput(const TfLiteContext* context,
|
||||||
const TfLiteNode* node, int index) {
|
const TfLiteNode* node, int index) {
|
||||||
return &context
|
return &context
|
||||||
->tensors[flatbuffers::EndianScalar(node->inputs->data[index])];
|
->tensors[flatbuffers::EndianScalar(node->inputs->data[index])];
|
||||||
|
Loading…
x
Reference in New Issue
Block a user