From a1b64bb516f8eb089d53e3ceb216d1826b8e9ecd Mon Sep 17 00:00:00 2001 From: Terry Heo Date: Tue, 26 May 2020 22:59:59 -0700 Subject: [PATCH] Check PAD tensor shape in IsSupported() phase PiperOrigin-RevId: 313333989 Change-Id: I5a47cfaf2f5aedca919d737274e2d94c1b5825ce --- .../lite/delegates/gpu/common/model_builder.cc | 12 ++++++++++++ tensorflow/lite/kernels/kernel_util.h | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 64b335f10a5..daedc277869 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -1348,6 +1348,17 @@ class PadOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); + auto pad_tensor = tflite::GetInput(context, tflite_node, 1); + if (pad_tensor->dims->size != 2) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid paddings tensor dimension: expected 2 dim, got ", + pad_tensor->dims->size, " dim")); + } + if (pad_tensor->dims->data[0] != 4 || pad_tensor->dims->data[1] != 2) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid paddings tensor shape: expected 4x2, got ", + pad_tensor->dims->data[0], "x", pad_tensor->dims->data[1])); + } return absl::OkStatus(); } @@ -1371,6 +1382,7 @@ class PadOperationParser : public TFLiteOperationParser { // 4x2 tensor with paddings. if (paddings.shape.h != 4 || paddings.shape.w != 2) { + // It shouldn't fail here since it's checked at IsSupported(). return absl::InvalidArgumentError( "Paddings tensor has unexpected shape."); } diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h index 5793b08616d..d6a2dac8583 100644 --- a/tensorflow/lite/kernels/kernel_util.h +++ b/tensorflow/lite/kernels/kernel_util.h @@ -28,7 +28,7 @@ inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; } inline int SizeOfDimension(const TfLiteTensor* t, int dim) { return t->dims->data[dim]; } -inline const TfLiteTensor* GetInput(TfLiteContext* context, +inline const TfLiteTensor* GetInput(const TfLiteContext* context, const TfLiteNode* node, int index) { return &context ->tensors[flatbuffers::EndianScalar(node->inputs->data[index])];