Fixed SliceOperation parser.
PiperOrigin-RevId: 276527702 Change-Id: Iec24be8dab1fba0e0a41a5c23792d3ce18dbdd93
This commit is contained in:
parent
ee06cf6e0a
commit
95a6f4571c
@ -1777,23 +1777,23 @@ class SliceOperationParser : public TFLiteOperationParser {
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.strides = BHWC(1, 1, 1, 1);
|
||||
Tensor<Linear, DataType::INT32> starts, ends;
|
||||
Tensor<Linear, DataType::INT32> starts, sizes;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &starts));
|
||||
RETURN_IF_ERROR(reader->ReadTensor(2, &ends));
|
||||
RETURN_IF_ERROR(reader->ReadTensor(2, &sizes));
|
||||
if (starts.data.size() != sizes.data.size()) {
|
||||
return InvalidArgumentError("Starts amount != sizes amount.");
|
||||
}
|
||||
if (starts.data.size() == 4) {
|
||||
attr.starts =
|
||||
BHWC(starts.data[0], starts.data[1], starts.data[2], starts.data[3]);
|
||||
attr.ends =
|
||||
BHWC(starts.data[0] + sizes.data[0], starts.data[1] + sizes.data[1],
|
||||
starts.data[2] + sizes.data[2], starts.data[3] + sizes.data[3]);
|
||||
} else if (starts.data.size() == 3) {
|
||||
attr.starts = BHWC(0, starts.data[0], starts.data[1], starts.data[2]);
|
||||
} else {
|
||||
return UnimplementedError(
|
||||
"Slicing is supported for 3 or 4 dimensional tensors only.");
|
||||
}
|
||||
if (ends.data.size() == 4) {
|
||||
attr.ends = BHWC(ends.data[0], ends.data[1], ends.data[2], ends.data[3]);
|
||||
} else if (ends.data.size() == 3) {
|
||||
attr.ends =
|
||||
BHWC(input->tensor.shape.b, ends.data[0], ends.data[1], ends.data[2]);
|
||||
BHWC(input->tensor.shape.b, starts.data[0] + sizes.data[0],
|
||||
starts.data[1] + sizes.data[1], starts.data[2] + sizes.data[2]);
|
||||
} else {
|
||||
return UnimplementedError(
|
||||
"Slicing is supported for 3 or 4 dimensional tensors only.");
|
||||
|
Loading…
Reference in New Issue
Block a user