diff --git a/tensorflow/lite/kernels/internal/reference/strided_slice.h b/tensorflow/lite/kernels/internal/reference/strided_slice.h
index 921c49ea77b..ba6d4c22554 100644
--- a/tensorflow/lite/kernels/internal/reference/strided_slice.h
+++ b/tensorflow/lite/kernels/internal/reference/strided_slice.h
@@ -18,7 +18,6 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
 #include "tensorflow/lite/kernels/internal/types.h"
-
 namespace tflite {
 
 namespace reference_ops {
@@ -28,47 +27,60 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
                          const T* input_data,
                          const RuntimeShape& unextended_output_shape,
                          T* output_data) {
+  using strided_slice::LoopCondition;
+  using strided_slice::StartForAxis;
+  using strided_slice::StopForAxis;
   // Note that the output_shape is not used herein.
   tflite::StridedSliceParams params_copy = op_params;
 
-  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
   const RuntimeShape input_shape =
-      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+      RuntimeShape::ExtendedShape(5, unextended_input_shape);
   const RuntimeShape output_shape =
-      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+      RuntimeShape::ExtendedShape(5, unextended_output_shape);
 
-  // Reverse and pad to 4 dimensions because that is what the runtime code
-  // requires (ie. all shapes must be 4D and are given backwards).
-  strided_slice::StridedSlicePadIndices(&params_copy, 4);
+  // Reverse and pad to 5 dimensions because that is what the runtime code
+  // requires (ie. all shapes must be 5D and are given backwards).
+  strided_slice::StridedSlicePadIndices(&params_copy, 5);
 
-  const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0);
-  const int stop_b =
-      strided_slice::StopForAxis(params_copy, input_shape, 0, start_b);
-  const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1);
-  const int stop_h =
-      strided_slice::StopForAxis(params_copy, input_shape, 1, start_h);
-  const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2);
-  const int stop_w =
-      strided_slice::StopForAxis(params_copy, input_shape, 2, start_w);
-  const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3);
-  const int stop_d =
-      strided_slice::StopForAxis(params_copy, input_shape, 3, start_d);
+  const int start_0 = StartForAxis(params_copy, input_shape, 0);
+  const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
+  const int start_1 = StartForAxis(params_copy, input_shape, 1);
+  const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
+  const int start_2 = StartForAxis(params_copy, input_shape, 2);
+  const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
+  const int start_3 = StartForAxis(params_copy, input_shape, 3);
+  const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
+  const int start_4 = StartForAxis(params_copy, input_shape, 4);
+  const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
 
   T* out_ptr = output_data;
-  for (int in_b = start_b;
-       !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]);
-       in_b += params_copy.strides[0]) {
-    for (int in_h = start_h;
-         !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]);
-         in_h += params_copy.strides[1]) {
-      for (int in_w = start_w;
-           !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]);
-           in_w += params_copy.strides[2]) {
-        for (int in_d = start_d; !strided_slice::LoopCondition(
-                                     in_d, stop_d, params_copy.strides[3]);
-             in_d += params_copy.strides[3]) {
-          *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)];
+  for (int offset_0 = start_0 * input_shape.Dims(1),
+           end_0 = stop_0 * input_shape.Dims(1),
+           step_0 = params_copy.strides[0] * input_shape.Dims(1);
+       !LoopCondition(offset_0, end_0, params_copy.strides[0]);
+       offset_0 += step_0) {
+    for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
+             end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
+             step_1 = params_copy.strides[1] * input_shape.Dims(2);
+         !LoopCondition(offset_1, end_1, params_copy.strides[1]);
+         offset_1 += step_1) {
+      for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
+               end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
+               step_2 = params_copy.strides[2] * input_shape.Dims(3);
+           !LoopCondition(offset_2, end_2, params_copy.strides[2]);
+           offset_2 += step_2) {
+        for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
+                 end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
+                 step_3 = params_copy.strides[3] * input_shape.Dims(4);
+             !LoopCondition(offset_3, end_3, params_copy.strides[3]);
+             offset_3 += step_3) {
+          for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
+               !LoopCondition(offset_4, end_4, params_copy.strides[4]);
+               offset_4 += params_copy.strides[4]) {
+            *out_ptr++ = input_data[offset_4];
+          }
         }
       }
     }
diff --git a/tensorflow/lite/kernels/internal/strided_slice_logic.h b/tensorflow/lite/kernels/internal/strided_slice_logic.h
index 3022ac7b8e9..12dd33d3296 100644
--- a/tensorflow/lite/kernels/internal/strided_slice_logic.h
+++ b/tensorflow/lite/kernels/internal/strided_slice_logic.h
@@ -35,7 +35,7 @@ inline int Clamp(const int v, const int lo, const int hi) {
 inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
                                    int dim_count) {
   // Add indices and mask bits to fully include extra dimensions
-  TFLITE_CHECK_LE(dim_count, 4);
+  TFLITE_CHECK_LE(dim_count, 5);
   TFLITE_CHECK_GE(dim_count, p->start_indices_count);
   TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
   TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
diff --git a/tensorflow/lite/kernels/internal/types.h b/tensorflow/lite/kernels/internal/types.h
index e96e50209bb..f4d2b5c0dad 100644
--- a/tensorflow/lite/kernels/internal/types.h
+++ b/tensorflow/lite/kernels/internal/types.h
@@ -128,9 +128,9 @@ struct Dims {
 
 class RuntimeShape {
  public:
-  // Shapes with dimensions up to 4 are stored directly in the structure, while
+  // Shapes with dimensions up to 5 are stored directly in the structure, while
   // larger shapes are separately allocated.
-  static constexpr int kMaxSmallSize = 4;
+  static constexpr int kMaxSmallSize = 5;
 
   RuntimeShape& operator=(RuntimeShape const&) = delete;
@@ -207,8 +207,8 @@ class RuntimeShape {
   inline const int32* DimsData() const {
     return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
   }
-  // The caller must ensure that the shape is no bigger than 4-D.
-  inline const int32* DimsDataUpTo4D() const { return dims_; }
+  // The caller must ensure that the shape is no bigger than 5-D.
+  inline const int32* DimsDataUpTo5D() const { return dims_; }
 
   inline void Resize(int dimensions_count) {
     if (size_ > kMaxSmallSize) {
@@ -378,7 +378,7 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
 
 inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
   TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
-  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo4D());
+  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
   TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
   TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
   TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
@@ -1049,11 +1049,11 @@ struct SqueezeParams {
 
 struct StridedSliceParams {
   int8 start_indices_count;
-  int32 start_indices[4];
+  int32 start_indices[5];
   int8 stop_indices_count;
-  int32 stop_indices[4];
+  int32 stop_indices[5];
   int8 strides_count;
-  int32 strides[4];
+  int32 strides[5];
 
   int16 begin_mask;
   int16 ellipsis_mask;
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index e8eebd81025..51534375e5f 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -156,7 +156,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
   AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(),
              /* min_version */ 1,
-             /* max_version */ 3);
+             /* max_version */ 4);
   AddBuiltin(BuiltinOperator_EXP, Register_EXP());
   AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(),
              /* min_version */ 1,
diff --git a/tensorflow/lite/kernels/strided_slice.cc b/tensorflow/lite/kernels/strided_slice.cc
index ba39b016624..e2ca812d193 100644
--- a/tensorflow/lite/kernels/strided_slice.cc
+++ b/tensorflow/lite/kernels/strided_slice.cc
@@ -142,8 +142,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
   TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
   TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
-  TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
-                     "StridedSlice op only supports 1D-4D input arrays.");
+  TF_LITE_ENSURE_MSG(context, op_context.dims <= 5,
+                     "StridedSlice op only supports 1D-5D input arrays.");
 
   // TODO(b/138098220): Remove when bug is resolved.
   // Currently, working on using the compiler to cannonize strided_slice,
diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc
index 83093a09eed..8db98dba0e9 100644
--- a/tensorflow/lite/kernels/strided_slice_test.cc
+++ b/tensorflow/lite/kernels/strided_slice_test.cc
@@ -82,9 +82,9 @@ TYPED_TEST_SUITE(StridedSliceOpTest, DataTypes);
 
 #ifdef GTEST_HAS_DEATH_TEST
 TYPED_TEST(StridedSliceOpTest, UnsupportedInputSize) {
-  EXPECT_DEATH(StridedSliceOpModel<TypeParam>({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0,
-                                              0, 0, 0, 0),
-               "StridedSlice op only supports 1D-4D input arrays.");
+  EXPECT_DEATH(StridedSliceOpModel<TypeParam>({2, 2, 2, 2, 2, 2}, {5}, {5}, {5},
+                                              0, 0, 0, 0, 0),
+               "StridedSlice op only supports 1D-5D input arrays.");
 }
 
 TYPED_TEST(StridedSliceOpTest, UnssupportedArgs) {
@@ -612,5 +612,29 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) {
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2}));
   EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
 }
+
+TYPED_TEST(StridedSliceOpTest, In5D_Identity) {
+  StridedSliceOpModel<TypeParam> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0,
+                                   0);
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+  m.SetBegin({0, 0, 0, 0, 0});
+  m.SetEnd({2, 1, 2, 1, 2});
+  m.SetStrides({1, 1, 1, 1, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 1, 2}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 9, 10, 11, 12}));
+}
+
+TYPED_TEST(StridedSliceOpTest, In5D_IdentityShrinkAxis1) {
+  StridedSliceOpModel<TypeParam> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0,
+                                   1);
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+  m.SetBegin({0, 0, 0, 0, 0});
+  m.SetEnd({2, 1, 2, 1, 2});
+  m.SetStrides({1, 1, 1, 1, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4}));
+}
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/testing/op_tests/strided_slice_np_style.py b/tensorflow/lite/testing/op_tests/strided_slice_np_style.py
index 95f7acabdf7..45f2e4b867a 100644
--- a/tensorflow/lite/testing/op_tests/strided_slice_np_style.py
+++ b/tensorflow/lite/testing/op_tests/strided_slice_np_style.py
@@ -68,6 +68,18 @@ def make_strided_slice_np_style_tests(options):
                    [slice(1, 11, 3), Ellipsis,
                     slice(3, 7, 2)]],
       },
+      # Ellipsis 5d.
+      {
+          "dtype": [tf.float32],
+          "shape": [[11, 21, 15, 7, 9]],
+          "spec": [[
+              slice(3, 7, 2),
+              slice(None),
+              slice(None),
+              slice(None),
+              slice(None)
+          ], [Ellipsis, slice(3, 7, 2)]],
+      },
       # All combinations.
       {
           "dtype": [tf.float32],
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index 3238d8ef032..0c310d15020 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -1165,6 +1165,15 @@ class StridedSlice
     op->new_axis_mask = options.new_axis_mask();
     op->shrink_axis_mask = options.shrink_axis_mask();
   }
+
+  int GetVersion(const OperatorSignature& op_signature) const override {
+    const auto& ss_op =
+        static_cast<const StridedSliceOperator&>(*op_signature.op);
+    ::tflite::OpSignature op_sig =
+        GetVersioningOpSig(builtin_op(), op_signature);
+    op_sig.options.strided_slice.num_dims = ss_op.start_indices.size();
+    return ::tflite::GetBuiltinOperatorVersion(op_sig);
+  }
 };
 
 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc
--- a/tensorflow/lite/tools/versioning/op_version.cc
+++ b/tensorflow/lite/tools/versioning/op_version.cc
@@ ... @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
     case BuiltinOperator_STRIDED_SLICE:
+      if (op_sig.options.strided_slice.num_dims > 4) {
+        return 4;
+      }
       // If the op takes bool input, it is version 3.
       if (op_sig.input_types.at(0) == TensorType_BOOL) {
         return 3;
@@ -431,6 +434,11 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
             resize_bilinear_option->half_pixel_centers();
       }
     } break;
+    // TODO(b/150176627): Add tests for GetOpSignature.
+    case BuiltinOperator_STRIDED_SLICE: {
+      op_sig.options.strided_slice.num_dims =
+          subgraph->tensors()->Get(op->inputs()->Get(0))->shape()->size();
+    } break;
 
     default:
       break;
diff --git a/tensorflow/lite/tools/versioning/op_version.h b/tensorflow/lite/tools/versioning/op_version.h
index 7fbc5a056e5..364d1a299cc 100644
--- a/tensorflow/lite/tools/versioning/op_version.h
+++ b/tensorflow/lite/tools/versioning/op_version.h
@@ -49,6 +49,9 @@ typedef struct {
     struct {
       bool half_pixel_centers;
     } resize_bilinear;
+    struct {
+      int32_t num_dims;
+    } strided_slice;
   } options;
 } OpSignature;
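
Illustration (not part of the patch): a minimal standalone C++ sketch of the index arithmetic behind the new 5-D reference loop, using the shape, begin, end, and strides values from the In5D_Identity test above. Each loop level folds its index into a running row-major offset (the same Horner-style scheme the offset_0..offset_4 variables accumulate incrementally), so the innermost variable is already a flat element offset. All names below are local and illustrative.

// Standalone sketch: reproduces the flat offsets the new 5-D reference
// StridedSlice loop visits for the In5D_Identity test case.
// Assumes row-major (last-dimension-fastest) layout.
#include <cstdio>

int main() {
  const int dims[5] = {2, 2, 2, 1, 2};    // input shape from the test
  const int begin[5] = {0, 0, 0, 0, 0};   // begin indices
  const int end[5] = {2, 1, 2, 1, 2};     // end indices (exclusive)
  const int stride[5] = {1, 1, 1, 1, 1};  // per-axis strides

  for (int i0 = begin[0]; i0 < end[0]; i0 += stride[0])
    for (int i1 = begin[1]; i1 < end[1]; i1 += stride[1])
      for (int i2 = begin[2]; i2 < end[2]; i2 += stride[2])
        for (int i3 = begin[3]; i3 < end[3]; i3 += stride[3])
          for (int i4 = begin[4]; i4 < end[4]; i4 += stride[4]) {
            // Row-major flat offset; equivalent to the incremental
            // pre-multiplied offsets used in the reference kernel.
            const int flat =
                (((i0 * dims[1] + i1) * dims[2] + i2) * dims[3] + i3) *
                    dims[4] +
                i4;
            std::printf("%d ", flat);  // prints: 0 1 2 3 8 9 10 11
          }
  std::printf("\n");
  return 0;
}

For the test input values 1..16, offsets 0 1 2 3 8 9 10 11 correspond to elements {1, 2, 3, 4, 9, 10, 11, 12}, matching the expected output of the In5D_Identity test.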