Check PAD tensor shape in IsSupported() phase

PiperOrigin-RevId: 313333989
Change-Id: I5a47cfaf2f5aedca919d737274e2d94c1b5825ce
This commit is contained in:
Terry Heo 2020-05-26 22:59:59 -07:00 committed by TensorFlower Gardener
parent ca47cbd37c
commit a1b64bb516
2 changed files with 13 additions and 1 deletions
tensorflow/lite
delegates/gpu/common
kernels

View File

@ -1348,6 +1348,17 @@ class PadOperationParser : public TFLiteOperationParser {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
auto pad_tensor = tflite::GetInput(context, tflite_node, 1);
if (pad_tensor->dims->size != 2) {
return absl::InvalidArgumentError(absl::StrCat(
"Invalid paddings tensor dimension: expected 2 dim, got ",
pad_tensor->dims->size, " dim"));
}
if (pad_tensor->dims->data[0] != 4 || pad_tensor->dims->data[1] != 2) {
return absl::InvalidArgumentError(absl::StrCat(
"Invalid paddings tensor shape: expected 4x2, got ",
pad_tensor->dims->data[0], "x", pad_tensor->dims->data[1]));
}
return absl::OkStatus();
}
@ -1371,6 +1382,7 @@ class PadOperationParser : public TFLiteOperationParser {
// 4x2 tensor with paddings.
if (paddings.shape.h != 4 || paddings.shape.w != 2) {
// It shouldn't fail here since it's checked at IsSupported().
return absl::InvalidArgumentError(
"Paddings tensor has unexpected shape.");
}

View File

@ -28,7 +28,7 @@ inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; }
inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
return t->dims->data[dim];
}
inline const TfLiteTensor* GetInput(TfLiteContext* context,
inline const TfLiteTensor* GetInput(const TfLiteContext* context,
const TfLiteNode* node, int index) {
return &context
->tensors[flatbuffers::EndianScalar(node->inputs->data[index])];