Fixed SliceOperation parser.

PiperOrigin-RevId: 276527702
Change-Id: Iec24be8dab1fba0e0a41a5c23792d3ce18dbdd93
This commit is contained in:
A. Unique TensorFlower 2019-10-24 11:17:17 -07:00 committed by TensorFlower Gardener
parent ee06cf6e0a
commit 95a6f4571c

View File

@ -1777,23 +1777,23 @@ class SliceOperationParser : public TFLiteOperationParser {
SliceAttributes attr;
attr.strides = BHWC(1, 1, 1, 1);
Tensor<Linear, DataType::INT32> starts, ends;
Tensor<Linear, DataType::INT32> starts, sizes;
RETURN_IF_ERROR(reader->ReadTensor(1, &starts));
RETURN_IF_ERROR(reader->ReadTensor(2, &ends));
RETURN_IF_ERROR(reader->ReadTensor(2, &sizes));
if (starts.data.size() != sizes.data.size()) {
return InvalidArgumentError("Starts amount != sizes amount.");
}
if (starts.data.size() == 4) {
attr.starts =
BHWC(starts.data[0], starts.data[1], starts.data[2], starts.data[3]);
attr.ends =
BHWC(starts.data[0] + sizes.data[0], starts.data[1] + sizes.data[1],
starts.data[2] + sizes.data[2], starts.data[3] + sizes.data[3]);
} else if (starts.data.size() == 3) {
attr.starts = BHWC(0, starts.data[0], starts.data[1], starts.data[2]);
} else {
return UnimplementedError(
"Slicing is supported for 3 or 4 dimensional tensors only.");
}
if (ends.data.size() == 4) {
attr.ends = BHWC(ends.data[0], ends.data[1], ends.data[2], ends.data[3]);
} else if (ends.data.size() == 3) {
attr.ends =
BHWC(input->tensor.shape.b, ends.data[0], ends.data[1], ends.data[2]);
BHWC(input->tensor.shape.b, starts.data[0] + sizes.data[0],
starts.data[1] + sizes.data[1], starts.data[2] + sizes.data[2]);
} else {
return UnimplementedError(
"Slicing is supported for 3 or 4 dimensional tensors only.");