Added batch support to SliceAttributes.
PiperOrigin-RevId: 274208252
This commit is contained in:
parent
0f6190f86d
commit
8c21158cd4
tensorflow/lite/delegates/gpu
cl/kernels
common
gl/kernels
metal/kernels
@ -44,9 +44,9 @@ TEST_F(OpenCLOperationTest, StridedSlice) {
|
||||
half(21.1f), half(21.2f), half(21.3f), half(21.4)};
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.starts = HWC(1, 0, 1);
|
||||
attr.ends = HWC(2, 2, 3);
|
||||
attr.strides = HWC(1, 2, 2);
|
||||
attr.starts = BHWC(0, 1, 0, 1);
|
||||
attr.ends = BHWC(src_tensor.shape.b, 2, 2, 3);
|
||||
attr.strides = BHWC(1, 1, 2, 2);
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
|
@ -1806,12 +1806,18 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
|
||||
RETURN_IF_ERROR(
|
||||
ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr));
|
||||
}
|
||||
if (attr.strides.h == 0 || attr.strides.w == 0 || attr.strides.c == 0) {
|
||||
if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 ||
|
||||
attr.strides.c == 0) {
|
||||
return InvalidArgumentError("stride values must be non-zero");
|
||||
}
|
||||
if (attr.strides.h < 0 || attr.strides.w < 0 || attr.strides.c < 0) {
|
||||
if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 ||
|
||||
attr.strides.c < 0) {
|
||||
return UnimplementedError("Reverse slices are not supported.");
|
||||
}
|
||||
if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b !=
|
||||
out_shape.b) {
|
||||
return UnimplementedError("Output batch don't match");
|
||||
}
|
||||
if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h !=
|
||||
out_shape.h) {
|
||||
return UnimplementedError("Output height doesn't match");
|
||||
@ -1830,8 +1836,8 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
|
||||
|
||||
private:
|
||||
Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options,
|
||||
const BHWC& input_shape, int ignore_h, int ignore_w,
|
||||
int ignore_c, SliceAttributes* attr) {
|
||||
const BHWC& input_shape, int ignore_b, int ignore_h,
|
||||
int ignore_w, int ignore_c, SliceAttributes* attr) {
|
||||
if (tf_options->begin_mask & ignore_h) {
|
||||
attr->starts.h = 0;
|
||||
}
|
||||
@ -1841,6 +1847,9 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
|
||||
if (tf_options->begin_mask & ignore_c) {
|
||||
attr->starts.c = 0;
|
||||
}
|
||||
if (tf_options->begin_mask & ignore_b) {
|
||||
attr->starts.b = 0;
|
||||
}
|
||||
|
||||
if (tf_options->end_mask & ignore_h) {
|
||||
attr->ends.h = input_shape.h;
|
||||
@ -1851,6 +1860,9 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
|
||||
if (tf_options->end_mask & ignore_c) {
|
||||
attr->ends.c = input_shape.c;
|
||||
}
|
||||
if (tf_options->end_mask & ignore_b) {
|
||||
attr->ends.b = input_shape.b;
|
||||
}
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
@ -1864,29 +1876,27 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
|
||||
if (attr->ends.c < 0) {
|
||||
attr->ends.c = input_shape.c + attr->ends.c;
|
||||
}
|
||||
if (attr->ends.b < 0) {
|
||||
attr->ends.b = input_shape.b + attr->ends.b;
|
||||
}
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
Status ReadAttribsWithBatch(const ObjectReader* reader,
|
||||
const TfLiteStridedSliceParams* tf_options,
|
||||
const BHWC& input_shape, SliceAttributes* attr) {
|
||||
auto read_hwc = [&](int tensor_index, HWC* hwc) -> Status {
|
||||
auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> Status {
|
||||
Tensor<Linear, DataType::INT32> t;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
|
||||
if (t.data[0] != 1 && t.data[0] != 0) {
|
||||
return UnimplementedError(
|
||||
"Slicing for BATCH channel is not supported. If you use batch it "
|
||||
"should be 0 or 1");
|
||||
}
|
||||
*hwc = HWC(t.data[1], t.data[2], t.data[3]);
|
||||
*bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]);
|
||||
return OkStatus();
|
||||
};
|
||||
|
||||
RETURN_IF_ERROR(read_hwc(1, &attr->starts));
|
||||
RETURN_IF_ERROR(read_hwc(2, &attr->ends));
|
||||
RETURN_IF_ERROR(read_hwc(3, &attr->strides));
|
||||
RETURN_IF_ERROR(read_bhwc(1, &attr->starts));
|
||||
RETURN_IF_ERROR(read_bhwc(2, &attr->ends));
|
||||
RETURN_IF_ERROR(read_bhwc(3, &attr->strides));
|
||||
RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
|
||||
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 2, 4, 8, attr));
|
||||
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr));
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
@ -1894,10 +1904,10 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
|
||||
const TfLiteStridedSliceParams* tf_options,
|
||||
const BHWC& input_shape,
|
||||
SliceAttributes* attr) {
|
||||
auto read_hwc = [&](int tensor_index, HWC* hwc) -> Status {
|
||||
auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> Status {
|
||||
Tensor<Linear, DataType::INT32> t;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
|
||||
*hwc = HWC(t.data[0], t.data[1], t.data[2]);
|
||||
*bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]);
|
||||
return OkStatus();
|
||||
};
|
||||
|
||||
@ -1905,7 +1915,10 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
|
||||
RETURN_IF_ERROR(read_hwc(2, &attr->ends));
|
||||
RETURN_IF_ERROR(read_hwc(3, &attr->strides));
|
||||
RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
|
||||
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, attr));
|
||||
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 0, 1, 2, 4, attr));
|
||||
attr->starts.b = 0;
|
||||
attr->ends.b = input_shape.b;
|
||||
attr->strides.b = 1;
|
||||
return OkStatus();
|
||||
}
|
||||
Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) {
|
||||
|
@ -322,7 +322,8 @@ BHWC CalculateOutputShape(const BHWC& input,
|
||||
}
|
||||
|
||||
BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr) {
|
||||
return BHWC(input.b, StridedSize(attr.ends.h - attr.starts.h, attr.strides.h),
|
||||
return BHWC(StridedSize(attr.ends.b - attr.starts.b, attr.strides.b),
|
||||
StridedSize(attr.ends.h - attr.starts.h, attr.strides.h),
|
||||
StridedSize(attr.ends.w - attr.starts.w, attr.strides.w),
|
||||
StridedSize(attr.ends.c - attr.starts.c, attr.strides.c));
|
||||
}
|
||||
|
@ -302,11 +302,11 @@ struct ConstTensorAttributes {
|
||||
// Simple slicing without advanced support for shrinking, reverse slicing etc.
|
||||
struct SliceAttributes {
|
||||
// Specifies start and end dimensions for slicing.
|
||||
HWC starts;
|
||||
HWC ends;
|
||||
BHWC starts;
|
||||
BHWC ends;
|
||||
|
||||
// Stride should be >= 1.
|
||||
HWC strides;
|
||||
BHWC strides;
|
||||
};
|
||||
|
||||
// @return shape of a tensor after Slice2D operation is applied to the given
|
||||
|
@ -42,9 +42,9 @@ TEST(SliceTest, Identity) {
|
||||
output.shape = BHWC(1, 1, 2, 2);
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.starts = HWC(0, 0, 0);
|
||||
attr.ends = HWC(1, 2, 2);
|
||||
attr.strides = HWC(1, 1, 1);
|
||||
attr.starts = BHWC(0, 0, 0, 0);
|
||||
attr.ends = BHWC(input.shape.b, 1, 2, 2);
|
||||
attr.strides = BHWC(1, 1, 1, 1);
|
||||
|
||||
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
|
||||
{output});
|
||||
@ -65,9 +65,9 @@ TEST(SliceTest, NoStrides) {
|
||||
output.shape = BHWC(1, 1, 2, 1);
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.starts = HWC(0, 0, 0);
|
||||
attr.ends = HWC(1, 2, 1);
|
||||
attr.strides = HWC(1, 1, 1);
|
||||
attr.starts = BHWC(0, 0, 0, 0);
|
||||
attr.ends = BHWC(input.shape.b, 1, 2, 1);
|
||||
attr.strides = BHWC(1, 1, 1, 1);
|
||||
|
||||
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
|
||||
{output});
|
||||
@ -88,9 +88,9 @@ TEST(SliceTest, NoStridesStartOffset) {
|
||||
output.shape = BHWC(1, 1, 1, 2);
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.starts = HWC(0, 1, 0);
|
||||
attr.ends = HWC(1, 2, 2);
|
||||
attr.strides = HWC(1, 1, 1);
|
||||
attr.starts = BHWC(0, 0, 1, 0);
|
||||
attr.ends = BHWC(input.shape.b, 1, 2, 2);
|
||||
attr.strides = BHWC(1, 1, 1, 1);
|
||||
|
||||
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
|
||||
{output});
|
||||
@ -111,9 +111,9 @@ TEST(SliceTest, StridesByHeight) {
|
||||
output.shape = BHWC(1, 2, 1, 1);
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.starts = HWC(0, 0, 0);
|
||||
attr.ends = HWC(4, 1, 1);
|
||||
attr.strides = HWC(2, 1, 1);
|
||||
attr.starts = BHWC(0, 0, 0, 0);
|
||||
attr.ends = BHWC(input.shape.b, 4, 1, 1);
|
||||
attr.strides = BHWC(1, 2, 1, 1);
|
||||
|
||||
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
|
||||
{output});
|
||||
@ -134,9 +134,9 @@ TEST(SliceTest, StridesByWidth) {
|
||||
output.shape = BHWC(1, 1, 2, 1);
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.starts = HWC(0, 1, 0);
|
||||
attr.ends = HWC(1, 4, 1);
|
||||
attr.strides = HWC(1, 2, 1);
|
||||
attr.starts = BHWC(0, 0, 1, 0);
|
||||
attr.ends = BHWC(input.shape.b, 1, 4, 1);
|
||||
attr.strides = BHWC(1, 1, 2, 1);
|
||||
|
||||
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
|
||||
{output});
|
||||
@ -157,9 +157,9 @@ TEST(SliceTest, StridesByChannels) {
|
||||
output.shape = BHWC(1, 1, 1, 2);
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.starts = HWC(0, 0, 1);
|
||||
attr.ends = HWC(1, 1, 4);
|
||||
attr.strides = HWC(1, 1, 2);
|
||||
attr.starts = BHWC(0, 0, 0, 1);
|
||||
attr.ends = BHWC(input.shape.b, 1, 1, 4);
|
||||
attr.strides = BHWC(1, 1, 1, 2);
|
||||
|
||||
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
|
||||
{output});
|
||||
|
@ -57,9 +57,9 @@ using ::tflite::gpu::metal::SingleOpModel;
|
||||
output.shape = BHWC(1, 1, 2, 2);
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.starts = HWC(0, 0, 0);
|
||||
attr.ends = HWC(1, 2, 2);
|
||||
attr.strides = HWC(1, 1, 1);
|
||||
attr.starts = BHWC(0, 0, 0, 0);
|
||||
attr.ends = BHWC(input.shape.b, 1, 2, 2);
|
||||
attr.strides = BHWC(1, 1, 1, 1);
|
||||
|
||||
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||
@ -81,9 +81,9 @@ using ::tflite::gpu::metal::SingleOpModel;
|
||||
output.shape = BHWC(1, 1, 2, 1);
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.starts = HWC(0, 0, 0);
|
||||
attr.ends = HWC(1, 2, 1);
|
||||
attr.strides = HWC(1, 1, 1);
|
||||
attr.starts = BHWC(0, 0, 0, 0);
|
||||
attr.ends = BHWC(input.shape.b, 1, 2, 1);
|
||||
attr.strides = BHWC(1, 1, 1, 1);
|
||||
|
||||
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||
@ -105,9 +105,9 @@ using ::tflite::gpu::metal::SingleOpModel;
|
||||
output.shape = BHWC(1, 1, 1, 2);
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.starts = HWC(0, 1, 0);
|
||||
attr.ends = HWC(1, 2, 2);
|
||||
attr.strides = HWC(1, 1, 1);
|
||||
attr.starts = BHWC(0, 0, 1, 0);
|
||||
attr.ends = BHWC(input.shape.b, 1, 2, 2);
|
||||
attr.strides = BHWC(1, 1, 1, 1);
|
||||
|
||||
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||
@ -129,9 +129,9 @@ using ::tflite::gpu::metal::SingleOpModel;
|
||||
output.shape = BHWC(1, 2, 1, 1);
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.starts = HWC(0, 0, 0);
|
||||
attr.ends = HWC(4, 1, 1);
|
||||
attr.strides = HWC(2, 1, 1);
|
||||
attr.starts = BHWC(0, 0, 0, 0);
|
||||
attr.ends = BHWC(input.shape.b, 4, 1, 1);
|
||||
attr.strides = BHWC(1, 2, 1, 1);
|
||||
|
||||
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||
@ -153,9 +153,9 @@ using ::tflite::gpu::metal::SingleOpModel;
|
||||
output.shape = BHWC(1, 1, 2, 1);
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.starts = HWC(0, 1, 0);
|
||||
attr.ends = HWC(1, 4, 1);
|
||||
attr.strides = HWC(1, 2, 1);
|
||||
attr.starts = BHWC(0, 0, 1, 0);
|
||||
attr.ends = BHWC(input.shape.b, 1, 4, 1);
|
||||
attr.strides = BHWC(1, 1, 2, 1);
|
||||
|
||||
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||
@ -177,9 +177,9 @@ using ::tflite::gpu::metal::SingleOpModel;
|
||||
output.shape = BHWC(1, 1, 1, 2);
|
||||
|
||||
SliceAttributes attr;
|
||||
attr.starts = HWC(0, 0, 1);
|
||||
attr.ends = HWC(1, 1, 4);
|
||||
attr.strides = HWC(1, 1, 2);
|
||||
attr.starts = BHWC(0, 0, 0, 1);
|
||||
attr.ends = BHWC(input.shape.b, 1, 1, 4);
|
||||
attr.strides = BHWC(1, 1, 1, 2);
|
||||
|
||||
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
|
||||
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||
|
Loading…
Reference in New Issue
Block a user