GPU delegate: Check input shapes for Slice, StrideSlice

Check the condition at IsSupported() phase instead of Parse() which makes
a runtime error.

PiperOrigin-RevId: 333652702
Change-Id: I5a3943ef6bca297876a6acc4649525d8d7853e12
This commit is contained in:
Terry Heo 2020-09-24 20:03:59 -07:00 committed by TensorFlower Gardener
parent 7893e4bcc1
commit 1d535df0d1

View File

@ -1164,8 +1164,8 @@ class MulOperationParser : public TFLiteOperationParser {
if (tflite_node->inputs->size != 2) {
return absl::UnimplementedError("MUL requires two input tensors.");
}
auto input0 = tflite::GetInput(context, tflite_node, 0);
auto input1 = tflite::GetInput(context, tflite_node, 1);
const TfLiteTensor* input0 = GetInput(context, tflite_node, 0);
const TfLiteTensor* input1 = GetInput(context, tflite_node, 1);
if (input0 == nullptr || input1 == nullptr) {
return absl::InvalidArgumentError("At least one input tensor is null");
}
@ -1383,7 +1383,7 @@ class PadOperationParser : public TFLiteOperationParser {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
const TfLiteTensor* pad_tensor = tflite::GetInput(context, tflite_node, 1);
const TfLiteTensor* pad_tensor = GetInput(context, tflite_node, 1);
if (pad_tensor == nullptr) {
return absl::InvalidArgumentError("Padding tensor was null");
}
@ -1775,6 +1775,15 @@ class SliceOperationParser : public TFLiteOperationParser {
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
if (tflite_node->inputs->size < 3) {
return absl::UnimplementedError("SLICE requires 3 inputs.");
}
const TfLiteTensor* input = GetInput(context, tflite_node, 0);
if (input->dims->size != 3 && input->dims->size != 4) {
return absl::UnimplementedError(
"SLICE supports for 3 or 4 dimensional tensors only.");
}
return absl::OkStatus();
}
@ -1823,6 +1832,7 @@ class SliceOperationParser : public TFLiteOperationParser {
BHWC(in_shape.b, starts.data[0] + sizes.data[0],
starts.data[1] + sizes.data[1], starts.data[2] + sizes.data[2]);
} else {
// Error: Must be catched in IsSupported()
return absl::UnimplementedError(
"Slicing is supported for 3 or 4 dimensional tensors only.");
}
@ -1952,6 +1962,15 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
const TfLiteStridedSliceParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
if (tflite_node->inputs->size < 4) {
return absl::UnimplementedError("STRIDED_SLICE requires 4 inputs.");
}
const TfLiteTensor* input = GetInput(context, tflite_node, 0);
if (input->dims->size != 3 && input->dims->size != 4) {
return absl::UnimplementedError(
"STRIDED_SLICE supports for 3 or 4 dimensional tensors only.");
}
return absl::OkStatus();
}
@ -1971,6 +1990,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
bool read_without_batch = tmp.data.size() == 3;
bool read_with_batch = tmp.data.size() == 4;
if (!read_without_batch && !read_with_batch) {
// Error: Must be catched in IsSupported()
return absl::UnimplementedError(
"Slicing is supported for 3 or 4 dimensional tensors only.");
}