Handle edge case in TFLite StridedSlice: the length of begin is different from the rank of input
PiperOrigin-RevId: 332241089 Change-Id: I41b1436d78db87ac78c60e2d72bc8c9016e0106e
This commit is contained in:
parent
df1688cae1
commit
65b5275b55
@ -71,17 +71,27 @@ StridedSliceParams BuildStridedSliceParams(StridedSliceContext* op_context) {
|
||||
op_params.stop_indices_count = op_context->dims;
|
||||
op_params.strides_count = op_context->dims;
|
||||
|
||||
for (int i = 0; i < op_context->dims; ++i) {
|
||||
op_params.start_indices[i] = GetTensorData<int32_t>(op_context->begin)[i];
|
||||
op_params.stop_indices[i] = GetTensorData<int32_t>(op_context->end)[i];
|
||||
op_params.strides[i] = GetTensorData<int32_t>(op_context->strides)[i];
|
||||
}
|
||||
|
||||
op_params.begin_mask = op_context->params->begin_mask;
|
||||
op_params.ellipsis_mask = 0;
|
||||
op_params.end_mask = op_context->params->end_mask;
|
||||
op_params.new_axis_mask = 0;
|
||||
op_params.shrink_axis_mask = op_context->params->shrink_axis_mask;
|
||||
|
||||
int begin_count = GetTensorShape(op_context->begin).Dims(0);
|
||||
for (int i = 0; i < begin_count; ++i) {
|
||||
op_params.start_indices[i] = GetTensorData<int32_t>(op_context->begin)[i];
|
||||
op_params.stop_indices[i] = GetTensorData<int32_t>(op_context->end)[i];
|
||||
op_params.strides[i] = GetTensorData<int32_t>(op_context->strides)[i];
|
||||
}
|
||||
|
||||
// If the length of begin and end smaller than number of input dims, set the
|
||||
// mask bit of begin and end for that index.
|
||||
for (int i = begin_count; i < op_context->dims; ++i) {
|
||||
op_params.start_indices[i] = op_params.stop_indices[i] = 0;
|
||||
op_params.strides[i] = 1;
|
||||
op_params.begin_mask |= (1 << i);
|
||||
op_params.end_mask |= (1 << i);
|
||||
}
|
||||
return op_params;
|
||||
}
|
||||
|
||||
@ -95,7 +105,7 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
|
||||
RuntimeShape input_shape = GetTensorShape(op_context->input);
|
||||
|
||||
for (int idx = op_context->dims - 1; idx >= 0; --idx) {
|
||||
int32_t stride = GetTensorData<int32_t>(op_context->strides)[idx];
|
||||
int32_t stride = op_params.strides[idx];
|
||||
TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
|
||||
|
||||
int32_t begin =
|
||||
|
@ -649,5 +649,28 @@ TYPED_TEST(StridedSliceOpTest, In5D_IdentityShrinkAxis1) {
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4}));
|
||||
}
|
||||
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_SmallBegin) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0});
|
||||
m.SetEnd({1});
|
||||
m.SetStrides({1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 2}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_SmallBeginWithhrinkAxis1) {
|
||||
StridedSliceOpModel<TypeParam> m({2, 3, 2}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
m.SetBegin({0});
|
||||
m.SetEnd({1});
|
||||
m.SetStrides({1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -43,17 +43,17 @@ def _make_strided_slice_tests(options, test_parameters, expected_tf_failures=0):
|
||||
begin = tf.compat.v1.placeholder(
|
||||
dtype=parameters["index_type"],
|
||||
name="begin",
|
||||
shape=[len(parameters["input_shape"])])
|
||||
shape=[len(parameters["begin"])])
|
||||
end = tf.compat.v1.placeholder(
|
||||
dtype=parameters["index_type"],
|
||||
name="end",
|
||||
shape=[len(parameters["input_shape"])])
|
||||
shape=[len(parameters["end"])])
|
||||
strides = None
|
||||
if parameters["strides"] is not None:
|
||||
strides = tf.compat.v1.placeholder(
|
||||
dtype=parameters["index_type"],
|
||||
name="strides",
|
||||
shape=[len(parameters["input_shape"])])
|
||||
shape=[len(parameters["strides"])])
|
||||
tensors = [input_tensor, begin, end]
|
||||
if strides is not None:
|
||||
tensors.append(strides)
|
||||
@ -141,7 +141,7 @@ def make_strided_slice_tests(options):
|
||||
"begin_mask": [0],
|
||||
"end_mask": [0],
|
||||
"shrink_axis_mask": [1],
|
||||
"constant_indices": [True],
|
||||
"constant_indices": [True, False],
|
||||
"fully_quantize": [False],
|
||||
},
|
||||
# 2-D
|
||||
|
Loading…
Reference in New Issue
Block a user