[tflite] Validate segment ids for segment_sum.
Segment identifiers in segment_sum should be in a 1-D tensor of same size as the first dimension of the input. The values of the tensor should be integers from {0, 1, 2, ... k-1}, where k is the first dimension of the input. The segment identifiers must not contain jumps and must be increasing. See https://www.tensorflow.org/api_docs/python/tf/math#Segmentation as the source for these constraints. PiperOrigin-RevId: 332510942 Change-Id: I898beaba00642c918bcd4b4d4ce893ebb190d869
This commit is contained in:
parent
2369d14f9d
commit
00c7ed7ce8
tensorflow/lite/kernels
@ -34,11 +34,24 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
|
||||
const TfLiteTensor* data,
|
||||
const TfLiteTensor* segment_ids,
|
||||
TfLiteTensor* output) {
|
||||
int max_index = -1;
|
||||
// Segment ids should be of same cardinality as first input dimension and they
|
||||
// should be increasing by at most 1, from 0 (e.g., [0, 0, 1, 2, 3] is valid)
|
||||
const int segment_id_size = segment_ids->dims->data[0];
|
||||
if (segment_id_size > 0) {
|
||||
max_index = segment_ids->data.i32[segment_id_size - 1];
|
||||
TF_LITE_ENSURE_EQ(context, segment_id_size, data->dims->data[0]);
|
||||
int previous_segment_id = -1;
|
||||
for (int i = 0; i < segment_id_size; i++) {
|
||||
const int current_segment_id = GetTensorData<int32_t>(segment_ids)[i];
|
||||
if (i == 0) {
|
||||
TF_LITE_ENSURE_EQ(context, current_segment_id, 0);
|
||||
} else {
|
||||
int delta = current_segment_id - previous_segment_id;
|
||||
TF_LITE_ENSURE(context, delta == 0 || delta == 1);
|
||||
}
|
||||
previous_segment_id = current_segment_id;
|
||||
}
|
||||
|
||||
const int max_index = previous_segment_id;
|
||||
|
||||
const int data_rank = NumDimensions(data);
|
||||
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data));
|
||||
output_shape->data[0] = max_index + 1;
|
||||
|
@ -110,5 +110,37 @@ TEST(SegmentSumOpModelTest, Float32Test_ThreeDimensions) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 1}));
|
||||
}
|
||||
|
||||
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotSorted) {
|
||||
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
|
||||
{TensorType_INT32, {3}});
|
||||
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
|
||||
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 3, 1});
|
||||
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
|
||||
}
|
||||
|
||||
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotConsecutive) {
|
||||
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
|
||||
{TensorType_INT32, {3}});
|
||||
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
|
||||
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 3, 5});
|
||||
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
|
||||
}
|
||||
|
||||
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNegative) {
|
||||
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
|
||||
{TensorType_INT32, {3}});
|
||||
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
|
||||
model.PopulateTensor<int32_t>(model.segment_ids(), {-1, 0, 1});
|
||||
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
|
||||
}
|
||||
|
||||
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotTheRightCardinality) {
|
||||
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
|
||||
{TensorType_INT32, {2}});
|
||||
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
|
||||
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1});
|
||||
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
Loading…
Reference in New Issue
Block a user