From 1d535df0d19611bc22a500990779855d69b9b6d8 Mon Sep 17 00:00:00 2001 From: Terry Heo Date: Thu, 24 Sep 2020 20:03:59 -0700 Subject: [PATCH] GPU delegate: Check input shapes for Slice, StrideSlice Check the condition at IsSupported() phase instead of Parse() which makes a runtime error. PiperOrigin-RevId: 333652702 Change-Id: I5a3943ef6bca297876a6acc4649525d8d7853e12 --- .../delegates/gpu/common/model_builder.cc | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index e56afa30497..85cacbe226e 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -1164,8 +1164,8 @@ class MulOperationParser : public TFLiteOperationParser { if (tflite_node->inputs->size != 2) { return absl::UnimplementedError("MUL requires two input tensors."); } - auto input0 = tflite::GetInput(context, tflite_node, 0); - auto input1 = tflite::GetInput(context, tflite_node, 1); + const TfLiteTensor* input0 = GetInput(context, tflite_node, 0); + const TfLiteTensor* input1 = GetInput(context, tflite_node, 1); if (input0 == nullptr || input1 == nullptr) { return absl::InvalidArgumentError("At least one input tensor is null"); } @@ -1383,7 +1383,7 @@ class PadOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); - const TfLiteTensor* pad_tensor = tflite::GetInput(context, tflite_node, 1); + const TfLiteTensor* pad_tensor = GetInput(context, tflite_node, 1); if (pad_tensor == nullptr) { return absl::InvalidArgumentError("Padding tensor was null"); } @@ -1775,6 +1775,15 @@ class SliceOperationParser : public TFLiteOperationParser { const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); + if (tflite_node->inputs->size < 3) { + return absl::UnimplementedError("SLICE requires 3 inputs."); + } + const TfLiteTensor* input = GetInput(context, tflite_node, 0); + if (input->dims->size != 3 && input->dims->size != 4) { + return absl::UnimplementedError( + "SLICE supports for 3 or 4 dimensional tensors only."); + } + return absl::OkStatus(); } @@ -1823,6 +1832,7 @@ class SliceOperationParser : public TFLiteOperationParser { BHWC(in_shape.b, starts.data[0] + sizes.data[0], starts.data[1] + sizes.data[1], starts.data[2] + sizes.data[2]); } else { + // Error: Must be catched in IsSupported() return absl::UnimplementedError( "Slicing is supported for 3 or 4 dimensional tensors only."); } @@ -1952,6 +1962,15 @@ class StridedSliceOperationParser : public TFLiteOperationParser { const TfLiteStridedSliceParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR(CheckOptionsSupport(tf_options)); + + if (tflite_node->inputs->size < 4) { + return absl::UnimplementedError("STRIDED_SLICE requires 4 inputs."); + } + const TfLiteTensor* input = GetInput(context, tflite_node, 0); + if (input->dims->size != 3 && input->dims->size != 4) { + return absl::UnimplementedError( + "STRIDED_SLICE supports for 3 or 4 dimensional tensors only."); + } return absl::OkStatus(); } @@ -1971,6 +1990,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser { bool read_without_batch = tmp.data.size() == 3; bool read_with_batch = tmp.data.size() == 4; if (!read_without_batch && !read_with_batch) { + // Error: Must be catched in IsSupported() return absl::UnimplementedError( "Slicing is supported for 3 or 4 dimensional tensors only."); }