GPU delegate: Check input shapes for Slice, StrideSlice
Check the condition at IsSupported() phase instead of Parse() which makes a runtime error. PiperOrigin-RevId: 333652702 Change-Id: I5a3943ef6bca297876a6acc4649525d8d7853e12
This commit is contained in:
parent
7893e4bcc1
commit
1d535df0d1
@ -1164,8 +1164,8 @@ class MulOperationParser : public TFLiteOperationParser {
|
||||
if (tflite_node->inputs->size != 2) {
|
||||
return absl::UnimplementedError("MUL requires two input tensors.");
|
||||
}
|
||||
auto input0 = tflite::GetInput(context, tflite_node, 0);
|
||||
auto input1 = tflite::GetInput(context, tflite_node, 1);
|
||||
const TfLiteTensor* input0 = GetInput(context, tflite_node, 0);
|
||||
const TfLiteTensor* input1 = GetInput(context, tflite_node, 1);
|
||||
if (input0 == nullptr || input1 == nullptr) {
|
||||
return absl::InvalidArgumentError("At least one input tensor is null");
|
||||
}
|
||||
@ -1383,7 +1383,7 @@ class PadOperationParser : public TFLiteOperationParser {
|
||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||
/*runtime_inputs=*/1, /*outputs=*/1));
|
||||
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
||||
const TfLiteTensor* pad_tensor = tflite::GetInput(context, tflite_node, 1);
|
||||
const TfLiteTensor* pad_tensor = GetInput(context, tflite_node, 1);
|
||||
if (pad_tensor == nullptr) {
|
||||
return absl::InvalidArgumentError("Padding tensor was null");
|
||||
}
|
||||
@ -1775,6 +1775,15 @@ class SliceOperationParser : public TFLiteOperationParser {
|
||||
const TfLiteNode* tflite_node,
|
||||
const TfLiteRegistration* registration) final {
|
||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
|
||||
if (tflite_node->inputs->size < 3) {
|
||||
return absl::UnimplementedError("SLICE requires 3 inputs.");
|
||||
}
|
||||
const TfLiteTensor* input = GetInput(context, tflite_node, 0);
|
||||
if (input->dims->size != 3 && input->dims->size != 4) {
|
||||
return absl::UnimplementedError(
|
||||
"SLICE supports for 3 or 4 dimensional tensors only.");
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
@ -1823,6 +1832,7 @@ class SliceOperationParser : public TFLiteOperationParser {
|
||||
BHWC(in_shape.b, starts.data[0] + sizes.data[0],
|
||||
starts.data[1] + sizes.data[1], starts.data[2] + sizes.data[2]);
|
||||
} else {
|
||||
// Error: Must be catched in IsSupported()
|
||||
return absl::UnimplementedError(
|
||||
"Slicing is supported for 3 or 4 dimensional tensors only.");
|
||||
}
|
||||
@ -1952,6 +1962,15 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
|
||||
const TfLiteStridedSliceParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
|
||||
|
||||
if (tflite_node->inputs->size < 4) {
|
||||
return absl::UnimplementedError("STRIDED_SLICE requires 4 inputs.");
|
||||
}
|
||||
const TfLiteTensor* input = GetInput(context, tflite_node, 0);
|
||||
if (input->dims->size != 3 && input->dims->size != 4) {
|
||||
return absl::UnimplementedError(
|
||||
"STRIDED_SLICE supports for 3 or 4 dimensional tensors only.");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
@ -1971,6 +1990,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
|
||||
bool read_without_batch = tmp.data.size() == 3;
|
||||
bool read_with_batch = tmp.data.size() == 4;
|
||||
if (!read_without_batch && !read_with_batch) {
|
||||
// Error: Must be catched in IsSupported()
|
||||
return absl::UnimplementedError(
|
||||
"Slicing is supported for 3 or 4 dimensional tensors only.");
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user