Added batch support to SliceAttributes.

PiperOrigin-RevId: 274208252
This commit is contained in:
A. Unique TensorFlower 2019-10-11 11:05:18 -07:00 committed by TensorFlower Gardener
parent 0f6190f86d
commit 8c21158cd4
6 changed files with 75 additions and 61 deletions
tensorflow/lite/delegates/gpu

View File

@ -44,9 +44,9 @@ TEST_F(OpenCLOperationTest, StridedSlice) {
half(21.1f), half(21.2f), half(21.3f), half(21.4)};
SliceAttributes attr;
attr.starts = HWC(1, 0, 1);
attr.ends = HWC(2, 2, 3);
attr.strides = HWC(1, 2, 2);
attr.starts = BHWC(0, 1, 0, 1);
attr.ends = BHWC(src_tensor.shape.b, 2, 2, 3);
attr.strides = BHWC(1, 1, 2, 2);
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {

View File

@ -1806,12 +1806,18 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
RETURN_IF_ERROR(
ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr));
}
if (attr.strides.h == 0 || attr.strides.w == 0 || attr.strides.c == 0) {
if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 ||
attr.strides.c == 0) {
return InvalidArgumentError("stride values must be non-zero");
}
if (attr.strides.h < 0 || attr.strides.w < 0 || attr.strides.c < 0) {
if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 ||
attr.strides.c < 0) {
return UnimplementedError("Reverse slices are not supported.");
}
if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b !=
out_shape.b) {
return UnimplementedError("Output batch don't match");
}
if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h !=
out_shape.h) {
return UnimplementedError("Output height doesn't match");
@ -1830,8 +1836,8 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
private:
Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options,
const BHWC& input_shape, int ignore_h, int ignore_w,
int ignore_c, SliceAttributes* attr) {
const BHWC& input_shape, int ignore_b, int ignore_h,
int ignore_w, int ignore_c, SliceAttributes* attr) {
if (tf_options->begin_mask & ignore_h) {
attr->starts.h = 0;
}
@ -1841,6 +1847,9 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
if (tf_options->begin_mask & ignore_c) {
attr->starts.c = 0;
}
if (tf_options->begin_mask & ignore_b) {
attr->starts.b = 0;
}
if (tf_options->end_mask & ignore_h) {
attr->ends.h = input_shape.h;
@ -1851,6 +1860,9 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
if (tf_options->end_mask & ignore_c) {
attr->ends.c = input_shape.c;
}
if (tf_options->end_mask & ignore_b) {
attr->ends.b = input_shape.b;
}
return OkStatus();
}
@ -1864,29 +1876,27 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
if (attr->ends.c < 0) {
attr->ends.c = input_shape.c + attr->ends.c;
}
if (attr->ends.b < 0) {
attr->ends.b = input_shape.b + attr->ends.b;
}
return OkStatus();
}
Status ReadAttribsWithBatch(const ObjectReader* reader,
const TfLiteStridedSliceParams* tf_options,
const BHWC& input_shape, SliceAttributes* attr) {
auto read_hwc = [&](int tensor_index, HWC* hwc) -> Status {
auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> Status {
Tensor<Linear, DataType::INT32> t;
RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
if (t.data[0] != 1 && t.data[0] != 0) {
return UnimplementedError(
"Slicing for BATCH channel is not supported. If you use batch it "
"should be 0 or 1");
}
*hwc = HWC(t.data[1], t.data[2], t.data[3]);
*bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]);
return OkStatus();
};
RETURN_IF_ERROR(read_hwc(1, &attr->starts));
RETURN_IF_ERROR(read_hwc(2, &attr->ends));
RETURN_IF_ERROR(read_hwc(3, &attr->strides));
RETURN_IF_ERROR(read_bhwc(1, &attr->starts));
RETURN_IF_ERROR(read_bhwc(2, &attr->ends));
RETURN_IF_ERROR(read_bhwc(3, &attr->strides));
RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 2, 4, 8, attr));
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr));
return OkStatus();
}
@ -1894,10 +1904,10 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
const TfLiteStridedSliceParams* tf_options,
const BHWC& input_shape,
SliceAttributes* attr) {
auto read_hwc = [&](int tensor_index, HWC* hwc) -> Status {
auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> Status {
Tensor<Linear, DataType::INT32> t;
RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
*hwc = HWC(t.data[0], t.data[1], t.data[2]);
*bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]);
return OkStatus();
};
@ -1905,7 +1915,10 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
RETURN_IF_ERROR(read_hwc(2, &attr->ends));
RETURN_IF_ERROR(read_hwc(3, &attr->strides));
RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, attr));
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 0, 1, 2, 4, attr));
attr->starts.b = 0;
attr->ends.b = input_shape.b;
attr->strides.b = 1;
return OkStatus();
}
Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) {

View File

@ -322,7 +322,8 @@ BHWC CalculateOutputShape(const BHWC& input,
}
BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr) {
return BHWC(input.b, StridedSize(attr.ends.h - attr.starts.h, attr.strides.h),
return BHWC(StridedSize(attr.ends.b - attr.starts.b, attr.strides.b),
StridedSize(attr.ends.h - attr.starts.h, attr.strides.h),
StridedSize(attr.ends.w - attr.starts.w, attr.strides.w),
StridedSize(attr.ends.c - attr.starts.c, attr.strides.c));
}

View File

@ -302,11 +302,11 @@ struct ConstTensorAttributes {
// Simple slicing without advanced support for shrinking, reverse slicing etc.
struct SliceAttributes {
// Specifies start and end dimensions for slicing.
HWC starts;
HWC ends;
BHWC starts;
BHWC ends;
// Stride should be >= 1.
HWC strides;
BHWC strides;
};
// @return shape of a tensor after Slice2D operation is applied to the given

View File

@ -42,9 +42,9 @@ TEST(SliceTest, Identity) {
output.shape = BHWC(1, 1, 2, 2);
SliceAttributes attr;
attr.starts = HWC(0, 0, 0);
attr.ends = HWC(1, 2, 2);
attr.strides = HWC(1, 1, 1);
attr.starts = BHWC(0, 0, 0, 0);
attr.ends = BHWC(input.shape.b, 1, 2, 2);
attr.strides = BHWC(1, 1, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
{output});
@ -65,9 +65,9 @@ TEST(SliceTest, NoStrides) {
output.shape = BHWC(1, 1, 2, 1);
SliceAttributes attr;
attr.starts = HWC(0, 0, 0);
attr.ends = HWC(1, 2, 1);
attr.strides = HWC(1, 1, 1);
attr.starts = BHWC(0, 0, 0, 0);
attr.ends = BHWC(input.shape.b, 1, 2, 1);
attr.strides = BHWC(1, 1, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
{output});
@ -88,9 +88,9 @@ TEST(SliceTest, NoStridesStartOffset) {
output.shape = BHWC(1, 1, 1, 2);
SliceAttributes attr;
attr.starts = HWC(0, 1, 0);
attr.ends = HWC(1, 2, 2);
attr.strides = HWC(1, 1, 1);
attr.starts = BHWC(0, 0, 1, 0);
attr.ends = BHWC(input.shape.b, 1, 2, 2);
attr.strides = BHWC(1, 1, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
{output});
@ -111,9 +111,9 @@ TEST(SliceTest, StridesByHeight) {
output.shape = BHWC(1, 2, 1, 1);
SliceAttributes attr;
attr.starts = HWC(0, 0, 0);
attr.ends = HWC(4, 1, 1);
attr.strides = HWC(2, 1, 1);
attr.starts = BHWC(0, 0, 0, 0);
attr.ends = BHWC(input.shape.b, 4, 1, 1);
attr.strides = BHWC(1, 2, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
{output});
@ -134,9 +134,9 @@ TEST(SliceTest, StridesByWidth) {
output.shape = BHWC(1, 1, 2, 1);
SliceAttributes attr;
attr.starts = HWC(0, 1, 0);
attr.ends = HWC(1, 4, 1);
attr.strides = HWC(1, 2, 1);
attr.starts = BHWC(0, 0, 1, 0);
attr.ends = BHWC(input.shape.b, 1, 4, 1);
attr.strides = BHWC(1, 1, 2, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
{output});
@ -157,9 +157,9 @@ TEST(SliceTest, StridesByChannels) {
output.shape = BHWC(1, 1, 1, 2);
SliceAttributes attr;
attr.starts = HWC(0, 0, 1);
attr.ends = HWC(1, 1, 4);
attr.strides = HWC(1, 1, 2);
attr.starts = BHWC(0, 0, 0, 1);
attr.ends = BHWC(input.shape.b, 1, 1, 4);
attr.strides = BHWC(1, 1, 1, 2);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
{output});

View File

@ -57,9 +57,9 @@ using ::tflite::gpu::metal::SingleOpModel;
output.shape = BHWC(1, 1, 2, 2);
SliceAttributes attr;
attr.starts = HWC(0, 0, 0);
attr.ends = HWC(1, 2, 2);
attr.strides = HWC(1, 1, 1);
attr.starts = BHWC(0, 0, 0, 0);
attr.ends = BHWC(input.shape.b, 1, 2, 2);
attr.strides = BHWC(1, 1, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
@ -81,9 +81,9 @@ using ::tflite::gpu::metal::SingleOpModel;
output.shape = BHWC(1, 1, 2, 1);
SliceAttributes attr;
attr.starts = HWC(0, 0, 0);
attr.ends = HWC(1, 2, 1);
attr.strides = HWC(1, 1, 1);
attr.starts = BHWC(0, 0, 0, 0);
attr.ends = BHWC(input.shape.b, 1, 2, 1);
attr.strides = BHWC(1, 1, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
@ -105,9 +105,9 @@ using ::tflite::gpu::metal::SingleOpModel;
output.shape = BHWC(1, 1, 1, 2);
SliceAttributes attr;
attr.starts = HWC(0, 1, 0);
attr.ends = HWC(1, 2, 2);
attr.strides = HWC(1, 1, 1);
attr.starts = BHWC(0, 0, 1, 0);
attr.ends = BHWC(input.shape.b, 1, 2, 2);
attr.strides = BHWC(1, 1, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
@ -129,9 +129,9 @@ using ::tflite::gpu::metal::SingleOpModel;
output.shape = BHWC(1, 2, 1, 1);
SliceAttributes attr;
attr.starts = HWC(0, 0, 0);
attr.ends = HWC(4, 1, 1);
attr.strides = HWC(2, 1, 1);
attr.starts = BHWC(0, 0, 0, 0);
attr.ends = BHWC(input.shape.b, 4, 1, 1);
attr.strides = BHWC(1, 2, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
@ -153,9 +153,9 @@ using ::tflite::gpu::metal::SingleOpModel;
output.shape = BHWC(1, 1, 2, 1);
SliceAttributes attr;
attr.starts = HWC(0, 1, 0);
attr.ends = HWC(1, 4, 1);
attr.strides = HWC(1, 2, 1);
attr.starts = BHWC(0, 0, 1, 0);
attr.ends = BHWC(input.shape.b, 1, 4, 1);
attr.strides = BHWC(1, 1, 2, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
@ -177,9 +177,9 @@ using ::tflite::gpu::metal::SingleOpModel;
output.shape = BHWC(1, 1, 1, 2);
SliceAttributes attr;
attr.starts = HWC(0, 0, 1);
attr.ends = HWC(1, 1, 4);
attr.strides = HWC(1, 1, 2);
attr.starts = BHWC(0, 0, 0, 1);
attr.ends = BHWC(input.shape.b, 1, 1, 4);
attr.strides = BHWC(1, 1, 1, 2);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));