Added batch support to SliceAttributes.

PiperOrigin-RevId: 274208252
This commit is contained in:
A. Unique TensorFlower 2019-10-11 11:05:18 -07:00 committed by TensorFlower Gardener
parent 0f6190f86d
commit 8c21158cd4
6 changed files with 75 additions and 61 deletions

View File

@ -44,9 +44,9 @@ TEST_F(OpenCLOperationTest, StridedSlice) {
half(21.1f), half(21.2f), half(21.3f), half(21.4)}; half(21.1f), half(21.2f), half(21.3f), half(21.4)};
SliceAttributes attr; SliceAttributes attr;
attr.starts = HWC(1, 0, 1); attr.starts = BHWC(0, 1, 0, 1);
attr.ends = HWC(2, 2, 3); attr.ends = BHWC(src_tensor.shape.b, 2, 2, 3);
attr.strides = HWC(1, 2, 2); attr.strides = BHWC(1, 1, 2, 2);
for (auto storage : env_.GetSupportedStorages()) { for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) { for (auto precision : env_.GetSupportedPrecisions()) {

View File

@ -1806,12 +1806,18 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
RETURN_IF_ERROR( RETURN_IF_ERROR(
ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr)); ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr));
} }
if (attr.strides.h == 0 || attr.strides.w == 0 || attr.strides.c == 0) { if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 ||
attr.strides.c == 0) {
return InvalidArgumentError("stride values must be non-zero"); return InvalidArgumentError("stride values must be non-zero");
} }
if (attr.strides.h < 0 || attr.strides.w < 0 || attr.strides.c < 0) { if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 ||
attr.strides.c < 0) {
return UnimplementedError("Reverse slices are not supported."); return UnimplementedError("Reverse slices are not supported.");
} }
if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b !=
out_shape.b) {
return UnimplementedError("Output batch don't match");
}
if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h != if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h !=
out_shape.h) { out_shape.h) {
return UnimplementedError("Output height doesn't match"); return UnimplementedError("Output height doesn't match");
@ -1830,8 +1836,8 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
private: private:
Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options, Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options,
const BHWC& input_shape, int ignore_h, int ignore_w, const BHWC& input_shape, int ignore_b, int ignore_h,
int ignore_c, SliceAttributes* attr) { int ignore_w, int ignore_c, SliceAttributes* attr) {
if (tf_options->begin_mask & ignore_h) { if (tf_options->begin_mask & ignore_h) {
attr->starts.h = 0; attr->starts.h = 0;
} }
@ -1841,6 +1847,9 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
if (tf_options->begin_mask & ignore_c) { if (tf_options->begin_mask & ignore_c) {
attr->starts.c = 0; attr->starts.c = 0;
} }
if (tf_options->begin_mask & ignore_b) {
attr->starts.b = 0;
}
if (tf_options->end_mask & ignore_h) { if (tf_options->end_mask & ignore_h) {
attr->ends.h = input_shape.h; attr->ends.h = input_shape.h;
@ -1851,6 +1860,9 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
if (tf_options->end_mask & ignore_c) { if (tf_options->end_mask & ignore_c) {
attr->ends.c = input_shape.c; attr->ends.c = input_shape.c;
} }
if (tf_options->end_mask & ignore_b) {
attr->ends.b = input_shape.b;
}
return OkStatus(); return OkStatus();
} }
@ -1864,29 +1876,27 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
if (attr->ends.c < 0) { if (attr->ends.c < 0) {
attr->ends.c = input_shape.c + attr->ends.c; attr->ends.c = input_shape.c + attr->ends.c;
} }
if (attr->ends.b < 0) {
attr->ends.b = input_shape.b + attr->ends.b;
}
return OkStatus(); return OkStatus();
} }
Status ReadAttribsWithBatch(const ObjectReader* reader, Status ReadAttribsWithBatch(const ObjectReader* reader,
const TfLiteStridedSliceParams* tf_options, const TfLiteStridedSliceParams* tf_options,
const BHWC& input_shape, SliceAttributes* attr) { const BHWC& input_shape, SliceAttributes* attr) {
auto read_hwc = [&](int tensor_index, HWC* hwc) -> Status { auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> Status {
Tensor<Linear, DataType::INT32> t; Tensor<Linear, DataType::INT32> t;
RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t)); RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
if (t.data[0] != 1 && t.data[0] != 0) { *bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]);
return UnimplementedError(
"Slicing for BATCH channel is not supported. If you use batch it "
"should be 0 or 1");
}
*hwc = HWC(t.data[1], t.data[2], t.data[3]);
return OkStatus(); return OkStatus();
}; };
RETURN_IF_ERROR(read_hwc(1, &attr->starts)); RETURN_IF_ERROR(read_bhwc(1, &attr->starts));
RETURN_IF_ERROR(read_hwc(2, &attr->ends)); RETURN_IF_ERROR(read_bhwc(2, &attr->ends));
RETURN_IF_ERROR(read_hwc(3, &attr->strides)); RETURN_IF_ERROR(read_bhwc(3, &attr->strides));
RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr)); RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 2, 4, 8, attr)); RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr));
return OkStatus(); return OkStatus();
} }
@ -1894,10 +1904,10 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
const TfLiteStridedSliceParams* tf_options, const TfLiteStridedSliceParams* tf_options,
const BHWC& input_shape, const BHWC& input_shape,
SliceAttributes* attr) { SliceAttributes* attr) {
auto read_hwc = [&](int tensor_index, HWC* hwc) -> Status { auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> Status {
Tensor<Linear, DataType::INT32> t; Tensor<Linear, DataType::INT32> t;
RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t)); RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
*hwc = HWC(t.data[0], t.data[1], t.data[2]); *bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]);
return OkStatus(); return OkStatus();
}; };
@ -1905,7 +1915,10 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
RETURN_IF_ERROR(read_hwc(2, &attr->ends)); RETURN_IF_ERROR(read_hwc(2, &attr->ends));
RETURN_IF_ERROR(read_hwc(3, &attr->strides)); RETURN_IF_ERROR(read_hwc(3, &attr->strides));
RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr)); RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, attr)); RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 0, 1, 2, 4, attr));
attr->starts.b = 0;
attr->ends.b = input_shape.b;
attr->strides.b = 1;
return OkStatus(); return OkStatus();
} }
Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) { Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) {

View File

@ -322,7 +322,8 @@ BHWC CalculateOutputShape(const BHWC& input,
} }
BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr) { BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr) {
return BHWC(input.b, StridedSize(attr.ends.h - attr.starts.h, attr.strides.h), return BHWC(StridedSize(attr.ends.b - attr.starts.b, attr.strides.b),
StridedSize(attr.ends.h - attr.starts.h, attr.strides.h),
StridedSize(attr.ends.w - attr.starts.w, attr.strides.w), StridedSize(attr.ends.w - attr.starts.w, attr.strides.w),
StridedSize(attr.ends.c - attr.starts.c, attr.strides.c)); StridedSize(attr.ends.c - attr.starts.c, attr.strides.c));
} }

View File

@ -302,11 +302,11 @@ struct ConstTensorAttributes {
// Simple slicing without advanced support for shrinking, reverse slicing etc. // Simple slicing without advanced support for shrinking, reverse slicing etc.
struct SliceAttributes { struct SliceAttributes {
// Specifies start and end dimensions for slicing. // Specifies start and end dimensions for slicing.
HWC starts; BHWC starts;
HWC ends; BHWC ends;
// Stride should be >= 1. // Stride should be >= 1.
HWC strides; BHWC strides;
}; };
// @return shape of a tensor after Slice2D operation is applied to the given // @return shape of a tensor after Slice2D operation is applied to the given

View File

@ -42,9 +42,9 @@ TEST(SliceTest, Identity) {
output.shape = BHWC(1, 1, 2, 2); output.shape = BHWC(1, 1, 2, 2);
SliceAttributes attr; SliceAttributes attr;
attr.starts = HWC(0, 0, 0); attr.starts = BHWC(0, 0, 0, 0);
attr.ends = HWC(1, 2, 2); attr.ends = BHWC(input.shape.b, 1, 2, 2);
attr.strides = HWC(1, 1, 1); attr.strides = BHWC(1, 1, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
{output}); {output});
@ -65,9 +65,9 @@ TEST(SliceTest, NoStrides) {
output.shape = BHWC(1, 1, 2, 1); output.shape = BHWC(1, 1, 2, 1);
SliceAttributes attr; SliceAttributes attr;
attr.starts = HWC(0, 0, 0); attr.starts = BHWC(0, 0, 0, 0);
attr.ends = HWC(1, 2, 1); attr.ends = BHWC(input.shape.b, 1, 2, 1);
attr.strides = HWC(1, 1, 1); attr.strides = BHWC(1, 1, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
{output}); {output});
@ -88,9 +88,9 @@ TEST(SliceTest, NoStridesStartOffset) {
output.shape = BHWC(1, 1, 1, 2); output.shape = BHWC(1, 1, 1, 2);
SliceAttributes attr; SliceAttributes attr;
attr.starts = HWC(0, 1, 0); attr.starts = BHWC(0, 0, 1, 0);
attr.ends = HWC(1, 2, 2); attr.ends = BHWC(input.shape.b, 1, 2, 2);
attr.strides = HWC(1, 1, 1); attr.strides = BHWC(1, 1, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
{output}); {output});
@ -111,9 +111,9 @@ TEST(SliceTest, StridesByHeight) {
output.shape = BHWC(1, 2, 1, 1); output.shape = BHWC(1, 2, 1, 1);
SliceAttributes attr; SliceAttributes attr;
attr.starts = HWC(0, 0, 0); attr.starts = BHWC(0, 0, 0, 0);
attr.ends = HWC(4, 1, 1); attr.ends = BHWC(input.shape.b, 4, 1, 1);
attr.strides = HWC(2, 1, 1); attr.strides = BHWC(1, 2, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
{output}); {output});
@ -134,9 +134,9 @@ TEST(SliceTest, StridesByWidth) {
output.shape = BHWC(1, 1, 2, 1); output.shape = BHWC(1, 1, 2, 1);
SliceAttributes attr; SliceAttributes attr;
attr.starts = HWC(0, 1, 0); attr.starts = BHWC(0, 0, 1, 0);
attr.ends = HWC(1, 4, 1); attr.ends = BHWC(input.shape.b, 1, 4, 1);
attr.strides = HWC(1, 2, 1); attr.strides = BHWC(1, 1, 2, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
{output}); {output});
@ -157,9 +157,9 @@ TEST(SliceTest, StridesByChannels) {
output.shape = BHWC(1, 1, 1, 2); output.shape = BHWC(1, 1, 1, 2);
SliceAttributes attr; SliceAttributes attr;
attr.starts = HWC(0, 0, 1); attr.starts = BHWC(0, 0, 0, 1);
attr.ends = HWC(1, 1, 4); attr.ends = BHWC(input.shape.b, 1, 1, 4);
attr.strides = HWC(1, 1, 2); attr.strides = BHWC(1, 1, 1, 2);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, SingleOpModel model({ToString(OperationType::SLICE), attr}, {input},
{output}); {output});

View File

@ -57,9 +57,9 @@ using ::tflite::gpu::metal::SingleOpModel;
output.shape = BHWC(1, 1, 2, 2); output.shape = BHWC(1, 1, 2, 2);
SliceAttributes attr; SliceAttributes attr;
attr.starts = HWC(0, 0, 0); attr.starts = BHWC(0, 0, 0, 0);
attr.ends = HWC(1, 2, 2); attr.ends = BHWC(input.shape.b, 1, 2, 2);
attr.strides = HWC(1, 1, 1); attr.strides = BHWC(1, 1, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
@ -81,9 +81,9 @@ using ::tflite::gpu::metal::SingleOpModel;
output.shape = BHWC(1, 1, 2, 1); output.shape = BHWC(1, 1, 2, 1);
SliceAttributes attr; SliceAttributes attr;
attr.starts = HWC(0, 0, 0); attr.starts = BHWC(0, 0, 0, 0);
attr.ends = HWC(1, 2, 1); attr.ends = BHWC(input.shape.b, 1, 2, 1);
attr.strides = HWC(1, 1, 1); attr.strides = BHWC(1, 1, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
@ -105,9 +105,9 @@ using ::tflite::gpu::metal::SingleOpModel;
output.shape = BHWC(1, 1, 1, 2); output.shape = BHWC(1, 1, 1, 2);
SliceAttributes attr; SliceAttributes attr;
attr.starts = HWC(0, 1, 0); attr.starts = BHWC(0, 0, 1, 0);
attr.ends = HWC(1, 2, 2); attr.ends = BHWC(input.shape.b, 1, 2, 2);
attr.strides = HWC(1, 1, 1); attr.strides = BHWC(1, 1, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
@ -129,9 +129,9 @@ using ::tflite::gpu::metal::SingleOpModel;
output.shape = BHWC(1, 2, 1, 1); output.shape = BHWC(1, 2, 1, 1);
SliceAttributes attr; SliceAttributes attr;
attr.starts = HWC(0, 0, 0); attr.starts = BHWC(0, 0, 0, 0);
attr.ends = HWC(4, 1, 1); attr.ends = BHWC(input.shape.b, 4, 1, 1);
attr.strides = HWC(2, 1, 1); attr.strides = BHWC(1, 2, 1, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
@ -153,9 +153,9 @@ using ::tflite::gpu::metal::SingleOpModel;
output.shape = BHWC(1, 1, 2, 1); output.shape = BHWC(1, 1, 2, 1);
SliceAttributes attr; SliceAttributes attr;
attr.starts = HWC(0, 1, 0); attr.starts = BHWC(0, 0, 1, 0);
attr.ends = HWC(1, 4, 1); attr.ends = BHWC(input.shape.b, 1, 4, 1);
attr.strides = HWC(1, 2, 1); attr.strides = BHWC(1, 1, 2, 1);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
@ -177,9 +177,9 @@ using ::tflite::gpu::metal::SingleOpModel;
output.shape = BHWC(1, 1, 1, 2); output.shape = BHWC(1, 1, 1, 2);
SliceAttributes attr; SliceAttributes attr;
attr.starts = HWC(0, 0, 1); attr.starts = BHWC(0, 0, 0, 1);
attr.ends = HWC(1, 1, 4); attr.ends = BHWC(input.shape.b, 1, 1, 4);
attr.strides = HWC(1, 1, 2); attr.strides = BHWC(1, 1, 1, 2);
SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output}); SingleOpModel model({ToString(OperationType::SLICE), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4})); XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));