diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index baf2a5c904a..79ab5e92be3 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -1777,23 +1777,23 @@ class SliceOperationParser : public TFLiteOperationParser { SliceAttributes attr; attr.strides = BHWC(1, 1, 1, 1); - Tensor starts, ends; + Tensor starts, sizes; RETURN_IF_ERROR(reader->ReadTensor(1, &starts)); - RETURN_IF_ERROR(reader->ReadTensor(2, &ends)); + RETURN_IF_ERROR(reader->ReadTensor(2, &sizes)); + if (starts.data.size() != sizes.data.size()) { + return InvalidArgumentError("Starts amount != sizes amount."); + } if (starts.data.size() == 4) { attr.starts = BHWC(starts.data[0], starts.data[1], starts.data[2], starts.data[3]); + attr.ends = + BHWC(starts.data[0] + sizes.data[0], starts.data[1] + sizes.data[1], + starts.data[2] + sizes.data[2], starts.data[3] + sizes.data[3]); } else if (starts.data.size() == 3) { attr.starts = BHWC(0, starts.data[0], starts.data[1], starts.data[2]); - } else { - return UnimplementedError( - "Slicing is supported for 3 or 4 dimensional tensors only."); - } - if (ends.data.size() == 4) { - attr.ends = BHWC(ends.data[0], ends.data[1], ends.data[2], ends.data[3]); - } else if (ends.data.size() == 3) { attr.ends = - BHWC(input->tensor.shape.b, ends.data[0], ends.data[1], ends.data[2]); + BHWC(input->tensor.shape.b, starts.data[0] + sizes.data[0], + starts.data[1] + sizes.data[1], starts.data[2] + sizes.data[2]); } else { return UnimplementedError( "Slicing is supported for 3 or 4 dimensional tensors only.");