From 68134a60241cb3b778f2b27699f98eca87bd940a Mon Sep 17 00:00:00 2001 From: Thai Nguyen <thaink@google.com> Date: Wed, 11 Nov 2020 18:58:32 -0800 Subject: [PATCH] Support string input in TFLite StridedSlice kernel PiperOrigin-RevId: 341957475 Change-Id: I96c79ba6a95b09861fe90120f3b6431f3d8e3a53 --- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 4 +- .../compiler/mlir/lite/tests/legalize-tf.mlir | 7 +++ tensorflow/compiler/mlir/lite/tests/ops.mlir | 6 +++ .../internal/reference/strided_slice.h | 31 +++++++++-- tensorflow/lite/kernels/register.cc | 2 +- tensorflow/lite/kernels/strided_slice.cc | 14 +++-- tensorflow/lite/kernels/strided_slice_test.cc | 53 +++++++++++++++++++ .../lite/testing/op_tests/strided_slice.py | 14 +++++ .../lite/tools/versioning/op_version.cc | 3 ++ .../lite/tools/versioning/runtime_version.cc | 1 + 10 files changed, 123 insertions(+), 12 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index ae2e424ec81..a4f67c5afe9 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -3405,7 +3405,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$input, + TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$input, TFL_I32Tensor:$begin, TFL_I32Tensor:$end, TFL_I32Tensor:$strides, @@ -3418,7 +3418,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [ ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$output + TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$output ); let hasOptions = 1; diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 5e36f4af802..dd8bbdb8372 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1122,6 +1122,13 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1: // CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32> } +func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> { + %0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> + return %0 : tensor<1x2x2x5x!tf.string> + // CHECK-LABEL: strided_slice_with_string + // CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> +} + func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> { %0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32> return %0 : tensor<?x3x5xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 3a98f6db0c4..a3aea7bd593 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -1458,6 +1458,12 @@ func @testStridedSliceTFType(%arg0: tensor<12x2x2x5xui8>, %arg1: tensor<1xi32>, return %0 : tensor<1x2x2x5x!tf.quint8> } +// CHECK-LABEL: testStridedSliceWithString +func @testStridedSliceWithString(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> { + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> + return %0 : tensor<1x2x2x5x!tf.string> +} + // ----- func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xi32> { diff --git a/tensorflow/lite/kernels/internal/reference/strided_slice.h b/tensorflow/lite/kernels/internal/reference/strided_slice.h index 8b6f0c13da1..24aa798d9c9 100644 --- a/tensorflow/lite/kernels/internal/reference/strided_slice.h +++ b/tensorflow/lite/kernels/internal/reference/strided_slice.h @@ -17,18 +17,19 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/portable_tensor.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace reference_ops { + template <typename T> inline void StridedSlice(const tflite::StridedSliceParams& op_params, const RuntimeShape& unextended_input_shape, - const T* input_data, const RuntimeShape& unextended_output_shape, - T* output_data) { + SequentialTensorWriter<T>* writer) { using strided_slice::LoopCondition; using strided_slice::StartForAxis; using strided_slice::StopForAxis; @@ -57,7 +58,6 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params, const int start_4 = StartForAxis(params_copy, input_shape, 4); const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4); - T* out_ptr = output_data; for (int offset_0 = start_0 * input_shape.Dims(1), end_0 = stop_0 * input_shape.Dims(1), step_0 = params_copy.strides[0] * input_shape.Dims(1); @@ -81,13 +81,36 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params, for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4; !LoopCondition(offset_4, end_4, params_copy.strides[4]); offset_4 += params_copy.strides[4]) { - *out_ptr++ = input_data[offset_4]; + writer->Write(offset_4); } } } } } } + +template <typename T> +inline void StridedSlice(const tflite::StridedSliceParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + SequentialTensorWriter<T> writer(input_data, output_data); + StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape, + &writer); +} + +template <typename T> +inline void StridedSlice(const tflite::StridedSliceParams& op_params, + const RuntimeShape& unextended_input_shape, + const TfLiteTensor* input, + const RuntimeShape& unextended_output_shape, + TfLiteTensor* output) { + SequentialTensorWriter<T> writer(input, output); + StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape, + &writer); +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index cd0c297a545..9aa14e579d4 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -157,7 +157,7 @@ BuiltinOpResolver::BuiltinOpResolver() { /* max_version = */ 2); AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(), /* min_version = */ 1, - /* max_version = */ 4); + /* max_version = */ 5); AddBuiltin(BuiltinOperator_EXP, Register_EXP()); AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(), /* min_version = */ 1, diff --git a/tensorflow/lite/kernels/strided_slice.cc b/tensorflow/lite/kernels/strided_slice.cc index d10e99c1997..3f2fd580a0b 100644 --- a/tensorflow/lite/kernels/strided_slice.cc +++ b/tensorflow/lite/kernels/strided_slice.cc @@ -190,11 +190,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } StridedSliceParams op_params = BuildStridedSliceParams(&op_context); -#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ - kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \ - GetTensorData<data_type>(op_context.input), \ - GetTensorShape(op_context.output), \ - GetTensorData<data_type>(op_context.output)) +#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ + kernel_type::StridedSlice<data_type>( \ + op_params, GetTensorShape(op_context.input), op_context.input, \ + GetTensorShape(op_context.output), op_context.output) switch (op_context.input->type) { case kTfLiteFloat32: @@ -232,6 +231,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_STRIDED_SLICE(reference_ops, bool); } break; + case kTfLiteString: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, string); + } + break; default: TF_LITE_KERNEL_LOG(context, "Type %s is currently not supported " diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc index d66cf884474..98521b889f9 100644 --- a/tensorflow/lite/kernels/strided_slice_test.cc +++ b/tensorflow/lite/kernels/strided_slice_test.cc @@ -55,6 +55,9 @@ class StridedSliceOpModel : public SingleOpModel { void SetInput(const std::vector<input_type> data) { PopulateTensor<input_type>(input_, data); } + void SetStringInput(std::initializer_list<string> data) { + PopulateStringTensor(input_, data); + } void SetBegin(std::initializer_list<int32_t> data) { PopulateTensor<int32_t>(begin_, data); } @@ -68,6 +71,9 @@ class StridedSliceOpModel : public SingleOpModel { std::vector<input_type> GetOutput() { return ExtractVector<input_type>(output_); } + std::vector<string> GetStringOutput() { + return ExtractVector<string>(output_); + } std::vector<int> GetOutputShape() { return GetTensorShape(output_); } private: @@ -692,5 +698,52 @@ TYPED_TEST(StridedSliceOpTest, In3D_Backward) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2})); } +TEST(StridedSliceOpTest, In1D_String_NegativeBegin) { + StridedSliceOpModel<std::string> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetStringInput({"a", "b", "c", "d"}); + m.SetBegin({-3}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"b", "c"})); +} + +TEST(StridedSliceOpTest, In3D_String_BackwardSmallBegin) { + StridedSliceOpModel<std::string> m({1, 1, 2}, {1}, {1}, {1}, 0, 1, 0, 0, 0); + m.SetStringInput({"a", "b"}); + m.SetBegin({1}); + m.SetEnd({0}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2})); +} + +TEST(StridedSliceOpTest, In3D_String_SmallBeginWithhrinkAxis1) { + StridedSliceOpModel<std::string> m({2, 3, 2}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetStringInput( + {"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"}); + m.SetBegin({0}); + m.SetEnd({1}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); + EXPECT_THAT(m.GetStringOutput(), + ElementsAreArray({"1", "2", "3", "4", "5", "6"})); +} + +TEST(StridedSliceOpTest, In5D_String_IdentityShrinkAxis1) { + StridedSliceOpModel<std::string> m({2, 2, 2, 1, 2}, {5}, {5}, {5}, 0, 0, 0, 0, + 1); + m.SetStringInput({"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", + "12", "13", "14", "15", "16"}); + m.SetBegin({0, 0, 0, 0, 0}); + m.SetEnd({2, 1, 2, 1, 2}); + m.SetStrides({1, 1, 1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2})); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"1", "2", "3", "4"})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/testing/op_tests/strided_slice.py b/tensorflow/lite/testing/op_tests/strided_slice.py index 3a04354c202..8668e139f34 100644 --- a/tensorflow/lite/testing/op_tests/strided_slice.py +++ b/tensorflow/lite/testing/op_tests/strided_slice.py @@ -230,6 +230,20 @@ def make_strided_slice_tests(options): "shrink_axis_mask": [0], "constant_indices": [True, False], "fully_quantize": [False], + }, + # String input. + { + "dtype": [tf.string], + "index_type": [tf.int32], + "input_shape": [[12, 2, 2, 5]], + "begin": [[0, 0, 0, 0]], + "end": [[8, 2, 2, 3]], + "strides": [[2, 1, 3, 1]], + "begin_mask": [8], + "end_mask": [3], + "shrink_axis_mask": [None, -1], + "constant_indices": [True, False], + "fully_quantize": [False], } ] _make_strided_slice_tests(options, test_parameters, expected_tf_failures=2) diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 6b9ff9c1dcf..1f84c261cdb 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -387,6 +387,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { return 1; case BuiltinOperator_STRIDED_SLICE: + if (op_sig.input_types.at(0) == TensorType_STRING) { + return 5; + } if (op_sig.options.single_input_op.num_dims > 4) { return 4; } diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index 2e71882f469..fa0b01fc939 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -218,6 +218,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_STRIDED_SLICE, 2}, "1.14.0"}, {{BuiltinOperator_STRIDED_SLICE, 3}, "2.1.0"}, {{BuiltinOperator_STRIDED_SLICE, 4}, "2.2.0"}, + {{BuiltinOperator_STRIDED_SLICE, 5}, kPendingReleaseVersion}, {{BuiltinOperator_TOPK_V2, 1}, "1.7.0"}, {{BuiltinOperator_TOPK_V2, 2}, "1.14.0"}, {{BuiltinOperator_ARG_MAX, 1}, "1.9.0"},